diff --git a/Cargo.lock b/Cargo.lock index fadf7408d..367e21c7b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5426,6 +5426,7 @@ dependencies = [ "test-log", "thiserror", "tokio", + "tokio-stream", "tonic 0.11.0", "tonic 0.12.1", "tonic-build", diff --git a/Cargo.toml b/Cargo.toml index d346940b6..d50c6987d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -109,7 +109,7 @@ flexbuffers = { version = "2.0.0" } futures = "0.3.25" futures-sink = "0.3.25" futures-util = "0.3.25" -googletest = "0.10" +googletest = { version = "0.10", features = ["anyhow"] } hostname = { version = "0.4.0" } http = "1.1.0" http-body = "1.0.1" diff --git a/crates/admin/Cargo.toml b/crates/admin/Cargo.toml index b6851af4b..9b4984ded 100644 --- a/crates/admin/Cargo.toml +++ b/crates/admin/Cargo.toml @@ -69,6 +69,7 @@ restate-types = { workspace = true, features = ["test-util"] } googletest = { workspace = true } tempfile = { workspace = true } test-log = { workspace = true } +tokio-stream = { workspace = true, features = ["net"] } tracing = { workspace = true } tracing-subscriber = { workspace = true } diff --git a/crates/admin/src/cluster_controller/cluster_state.rs b/crates/admin/src/cluster_controller/cluster_state.rs index 60ed5fb28..dd75ed325 100644 --- a/crates/admin/src/cluster_controller/cluster_state.rs +++ b/crates/admin/src/cluster_controller/cluster_state.rs @@ -16,7 +16,7 @@ use std::time::Instant; use tokio::sync::watch; use restate_core::network::rpc_router::RpcRouter; -use restate_core::network::{MessageRouterBuilder, NetworkSender, Outgoing}; +use restate_core::network::{MessageRouterBuilder, Networking, TransportConnect}; use restate_types::net::partition_processor_manager::GetProcessorsState; use restate_types::nodes_config::Role; use restate_types::time::MillisSinceEpoch; @@ -24,26 +24,24 @@ use restate_types::time::MillisSinceEpoch; use restate_core::{Metadata, ShutdownError, TaskCenter, TaskHandle}; use restate_types::Version; -pub struct ClusterStateRefresher { +pub struct ClusterStateRefresher { task_center: TaskCenter, metadata: Metadata, - get_state_router: RpcRouter, + network_sender: Networking, + get_state_router: RpcRouter, in_flight_refresh: Option>>, cluster_state_update_rx: watch::Receiver>, cluster_state_update_tx: Arc>>, } -impl ClusterStateRefresher -where - N: NetworkSender + 'static, -{ +impl ClusterStateRefresher { pub fn new( task_center: TaskCenter, metadata: Metadata, - networking: N, + network_sender: Networking, router_builder: &mut MessageRouterBuilder, ) -> Self { - let get_state_router = RpcRouter::new(networking.clone(), router_builder); + let get_state_router = RpcRouter::new(router_builder); let initial_state = ClusterState { last_refreshed: None, @@ -58,6 +56,7 @@ where Self { task_center, metadata, + network_sender, get_state_router, in_flight_refresh: None, cluster_state_update_rx, @@ -97,6 +96,7 @@ where self.in_flight_refresh = Self::start_refresh_task( self.task_center.clone(), self.get_state_router.clone(), + self.network_sender.clone(), Arc::clone(&self.cluster_state_update_tx), self.metadata.clone(), )?; @@ -106,7 +106,8 @@ where fn start_refresh_task( tc: TaskCenter, - get_state_router: RpcRouter, + get_state_router: RpcRouter, + network_sender: Networking, cluster_state_tx: Arc>>, metadata: Metadata, ) -> Result>>, ShutdownError> { @@ -138,6 +139,7 @@ where let rpc_router = get_state_router.clone(); let tc = tc.clone(); + let network_sender = network_sender.clone(); join_set .build_task() .name("get-processors-state") @@ -148,10 +150,11 @@ where tokio::time::timeout( // todo: make configurable std::time::Duration::from_secs(1), - rpc_router.call(Outgoing::new( + rpc_router.call( + &network_sender, node_id, GetProcessorsState::default(), - )), + ), ) .await, ) @@ -191,7 +194,7 @@ where node_id, NodeState::Alive(AliveNode { last_heartbeat_at: MillisSinceEpoch::now(), - generational_node_id: from, + generational_node_id: *from.peer(), partitions: msg.state, }), ); diff --git a/crates/admin/src/cluster_controller/scheduler.rs b/crates/admin/src/cluster_controller/scheduler.rs index fa6dfdb38..7563032c1 100644 --- a/crates/admin/src/cluster_controller/scheduler.rs +++ b/crates/admin/src/cluster_controller/scheduler.rs @@ -15,7 +15,7 @@ use rand::seq::IteratorRandom; use tracing::{debug, trace}; use restate_core::metadata_store::{MetadataStoreClient, Precondition, ReadError, WriteError}; -use restate_core::network::{NetworkSender, Outgoing}; +use restate_core::network::{NetworkSender, Networking, Outgoing, TransportConnect}; use restate_core::{ShutdownError, SyncError, TaskCenter, TaskKind}; use restate_types::cluster::cluster_state::{ClusterState, NodeState, RunMode}; use restate_types::cluster_controller::{ @@ -45,27 +45,24 @@ pub enum Error { Shutdown(#[from] ShutdownError), } -pub struct Scheduler { +pub struct Scheduler { scheduling_plan: SchedulingPlan, observed_cluster_state: ObservedClusterState, task_center: TaskCenter, metadata_store_client: MetadataStoreClient, - networking: N, + networking: Networking, } /// The scheduler is responsible for assigning partition processors to nodes and to electing /// leaders. It achieves it by deciding on a scheduling plan which is persisted to the metadata /// store and then driving the observed cluster state to the target state (represented by the /// scheduling plan). -impl Scheduler -where - N: NetworkSender + 'static, -{ +impl Scheduler { pub async fn init( task_center: TaskCenter, metadata_store_client: MetadataStoreClient, - networking: N, + networking: Networking, ) -> Result { let scheduling_plan = metadata_store_client .get(SCHEDULING_PLAN_KEY.clone()) @@ -83,7 +80,7 @@ where pub async fn on_attach_node( &mut self, - node: GenerationalNodeId, + node: &GenerationalNodeId, ) -> Result, ShutdownError> { trace!(node = %node, "Node is attaching to cluster"); // the convergence loop will make sure that the node receives its instructions @@ -418,9 +415,11 @@ impl ObservedPartitionState { #[cfg(test)] mod tests { - use crate::cluster_controller::scheduler::{ - ObservedClusterState, ObservedPartitionState, Scheduler, - }; + use std::collections::{BTreeMap, BTreeSet}; + use std::num::NonZero; + use std::sync::Arc; + use std::time::Duration; + use futures::StreamExt; use googletest::matcher::{Matcher, MatcherResult}; use googletest::matchers::{empty, eq}; @@ -428,24 +427,29 @@ mod tests { use http::Uri; use rand::prelude::ThreadRng; use rand::Rng; - use restate_core::TestCoreEnvBuilder; + use test_log::test; + use tokio::sync::mpsc; + use tokio_stream::wrappers::ReceiverStream; + + use restate_core::network::{ForwardingHandler, Incoming, MessageCollectorMockConnector}; + use restate_core::{TaskCenterBuilder, TestCoreEnvBuilder}; use restate_types::cluster::cluster_state::{ AliveNode, ClusterState, DeadNode, NodeState, PartitionProcessorStatus, RunMode, }; use restate_types::cluster_controller::{ReplicationStrategy, SchedulingPlan}; use restate_types::identifiers::PartitionId; use restate_types::metadata_store::keys::SCHEDULING_PLAN_KEY; + use restate_types::net::codec::WireDecode; use restate_types::net::partition_processor_manager::{ControlProcessors, ProcessorCommand}; - use restate_types::net::AdvertisedAddress; + use restate_types::net::{AdvertisedAddress, TargetName}; use restate_types::nodes_config::{LogServerConfig, NodeConfig, NodesConfiguration, Role}; use restate_types::partition_table::PartitionTable; use restate_types::time::MillisSinceEpoch; use restate_types::{GenerationalNodeId, PlainNodeId, Version}; - use std::collections::{BTreeMap, BTreeSet}; - use std::num::NonZero; - use std::sync::Arc; - use std::time::Duration; - use test_log::test; + + use crate::cluster_controller::scheduler::{ + ObservedClusterState, ObservedPartitionState, Scheduler, + }; impl ObservedClusterState { fn remove_node_from_partition( @@ -725,29 +729,57 @@ mod tests { nodes_config.upsert_node(node_config); } - let mut builder = TestCoreEnvBuilder::new_with_mock_network(); - let mut control_processors = builder - .router_builder - .subscribe_to_stream::(32); + let tc = TaskCenterBuilder::default_for_tests() + .build() + .expect("task_center builds"); + + // network messages going to other nodes are written to `tx` + let (tx, control_recv) = mpsc::channel(100); + let connector = MessageCollectorMockConnector::new(tc.clone(), 10, tx.clone()); + + let mut builder = TestCoreEnvBuilder::with_transport_connector(tc, connector); + builder.router_builder.add_raw_handler( + TargetName::ControlProcessors, + // network messages going to my node is also written to `tx` + Box::new(ForwardingHandler::new(GenerationalNodeId::new(1, 1), tx)), + ); + + let mut control_recv = ReceiverStream::new(control_recv) + .filter_map(|(node_id, message)| async move { + if message.body().target() == TargetName::ControlProcessors { + let message = message + .try_map(|mut m| { + ControlProcessors::decode( + &mut m.payload, + restate_types::net::CURRENT_PROTOCOL_VERSION, + ) + }) + .unwrap(); + Some((node_id, message)) + } else { + None + } + }) + .boxed(); let partition_table = PartitionTable::with_equally_sized_partitions(Version::MIN, num_partitions); let initial_scheduling_plan = SchedulingPlan::from(&partition_table, replication_strategy); let metadata_store_client = builder.metadata_store_client.clone(); - let network_sender = builder.network_sender.clone(); + let networking = builder.networking.clone(); let env = builder - .with_nodes_config(nodes_config) - .with_partition_table(partition_table.clone()) - .with_scheduling_plan(initial_scheduling_plan) + .set_nodes_config(nodes_config.clone()) + .set_partition_table(partition_table.clone()) + .set_scheduling_plan(initial_scheduling_plan) .build() .await; let tc = env.tc.clone(); env.tc .run_in_scope("test", None, async move { let mut scheduler = - Scheduler::init(tc, metadata_store_client.clone(), network_sender).await?; + Scheduler::init(tc, metadata_store_client.clone(), networking).await?; for _ in 0..num_scheduling_rounds { let cluster_state = random_cluster_state(&node_ids, num_partitions); @@ -757,10 +789,9 @@ mod tests { .on_cluster_state_update(Arc::clone(&cluster_state)) .await?; // collect all control messages from the network to build up the effective scheduling plan - let control_messages = control_processors + let control_messages = control_recv .as_mut() .take_until(tokio::time::sleep(Duration::from_secs(10))) - .map(|message| message.split()) .collect::>() .await; @@ -865,32 +896,33 @@ mod tests { fn derive_observed_cluster_state( cluster_state: &ClusterState, - control_messages: Vec<(GenerationalNodeId, ControlProcessors)>, + control_messages: Vec<(GenerationalNodeId, Incoming)>, ) -> ObservedClusterState { let mut observed_cluster_state = ObservedClusterState::default(); observed_cluster_state.update(cluster_state); // apply commands - for (node_id, control_processors) in control_messages { - for control_processor in control_processors.commands { + for (target_node, control_processors) in control_messages { + let plain_node_id = target_node.as_plain(); + for control_processor in control_processors.into_body().commands { match control_processor.command { ProcessorCommand::Stop => { observed_cluster_state.remove_node_from_partition( &control_processor.partition_id, - &node_id.as_plain(), + &plain_node_id, ); } ProcessorCommand::Follower => { observed_cluster_state.add_node_to_partition( control_processor.partition_id, - node_id.as_plain(), + plain_node_id, RunMode::Follower, ); } ProcessorCommand::Leader => { observed_cluster_state.add_node_to_partition( control_processor.partition_id, - node_id.as_plain(), + plain_node_id, RunMode::Leader, ); } diff --git a/crates/admin/src/cluster_controller/service.rs b/crates/admin/src/cluster_controller/service.rs index 09e4d4c5a..5c869d0ae 100644 --- a/crates/admin/src/cluster_controller/service.rs +++ b/crates/admin/src/cluster_controller/service.rs @@ -22,7 +22,7 @@ use tracing::{debug, warn}; use restate_bifrost::{Bifrost, BifrostAdmin}; use restate_core::metadata_store::MetadataStoreClient; -use restate_core::network::{Incoming, MessageRouterBuilder, NetworkSender}; +use restate_core::network::{Incoming, MessageRouterBuilder, Networking, TransportConnect}; use restate_core::{ cancellation_watcher, Metadata, MetadataWriter, ShutdownError, TargetVersion, TaskCenter, TaskKind, @@ -46,12 +46,12 @@ pub enum Error { Error, } -pub struct Service { +pub struct Service { task_center: TaskCenter, metadata: Metadata, - networking: N, + networking: Networking, incoming_messages: Pin> + Send + Sync + 'static>>, - cluster_state_refresher: ClusterStateRefresher, + cluster_state_refresher: ClusterStateRefresher, command_tx: mpsc::Sender, command_rx: mpsc::Receiver, @@ -63,15 +63,15 @@ pub struct Service { log_trim_threshold: Lsn, } -impl Service +impl Service where - N: NetworkSender + 'static, + T: TransportConnect, { pub fn new( mut configuration: Live, task_center: TaskCenter, metadata: Metadata, - networking: N, + networking: Networking, router_builder: &mut MessageRouterBuilder, metadata_writer: MetadataWriter, metadata_store_client: MetadataStoreClient, @@ -176,10 +176,7 @@ impl ClusterControllerHandle { } } -impl Service -where - N: NetworkSender + 'static, -{ +impl Service { pub fn handle(&self) -> ClusterControllerHandle { ClusterControllerHandle { tx: self.command_tx.clone(), @@ -344,7 +341,7 @@ where async fn on_attach_request( &self, - scheduler: &mut Scheduler, + scheduler: &mut Scheduler, request: Incoming, ) -> Result<(), ShutdownError> { let actions = scheduler.on_attach_node(request.peer()).await?; @@ -352,7 +349,12 @@ where TaskKind::Disposable, "attachment-response", None, - async move { Ok(request.respond_rpc(AttachResponse { actions }).await?) }, + async move { + Ok(request + .to_rpc_response(AttachResponse { actions }) + .send() + .await?) + }, )?; Ok(()) } @@ -406,13 +408,19 @@ async fn signal_all_partitions_started( #[cfg(test)] mod tests { use super::Service; + + use std::collections::BTreeSet; + use std::sync::atomic::{AtomicU64, Ordering}; + use std::sync::Arc; + use std::time::Duration; + use googletest::assert_that; use googletest::matchers::eq; + use test_log::test; + use restate_bifrost::Bifrost; - use restate_core::network::{Incoming, MessageHandler, NetworkSender}; - use restate_core::{ - MockNetworkSender, NoOpMessageHandler, TaskKind, TestCoreEnv, TestCoreEnvBuilder, - }; + use restate_core::network::{FailingConnector, Incoming, MessageHandler, MockPeerConnection}; + use restate_core::{NoOpMessageHandler, TaskKind, TestCoreEnv, TestCoreEnvBuilder}; use restate_types::cluster::cluster_state::PartitionProcessorStatus; use restate_types::config::{AdminOptions, Configuration}; use restate_types::identifiers::PartitionId; @@ -424,22 +432,17 @@ mod tests { use restate_types::net::AdvertisedAddress; use restate_types::nodes_config::{LogServerConfig, NodeConfig, NodesConfiguration, Role}; use restate_types::{GenerationalNodeId, Version}; - use std::collections::BTreeSet; - use std::sync::atomic::{AtomicU64, Ordering}; - use std::sync::Arc; - use std::time::Duration; - use test_log::test; #[test(tokio::test)] async fn manual_log_trim() -> anyhow::Result<()> { const LOG_ID: LogId = LogId::new(0); - let mut builder = TestCoreEnvBuilder::new_with_mock_network(); + let mut builder = TestCoreEnvBuilder::with_incoming_only_connector(); let svc = Service::new( Live::from_value(Configuration::default()), builder.tc.clone(), builder.metadata.clone(), - builder.network_sender.clone(), + builder.networking.clone(), &mut builder.router_builder, builder.metadata_writer.clone(), builder.metadata_store_client.clone(), @@ -482,7 +485,6 @@ mod tests { } struct PartitionProcessorStatusHandler { - network_sender: MockNetworkSender, persisted_lsn: Arc, // set of node ids for which the handler won't send a response to the caller, this allows to simulate // dead nodes @@ -493,7 +495,7 @@ mod tests { type MessageType = GetProcessorsState; async fn on_message(&self, msg: Incoming) { - if self.block_list.contains(&msg.peer()) { + if self.block_list.contains(msg.peer()) { return; } @@ -503,14 +505,11 @@ mod tests { }; let state = [(PartitionId::MIN, partition_processor_status)].into(); - let response = msg.prepare_rpc_response(ProcessorsStateResponse { state }); - - self.network_sender - // We are not really sending something back to target, we just need to provide a known - // node_id. The response will be sent to a handler running on the very same node. - .send(response) - .await - .expect("send should succeed"); + let response = msg.to_rpc_response(ProcessorsStateResponse { state }); + + // We are not really sending something back to target, we just need to provide a known + // node_id. The response will be sent to a handler running on the very same node. + response.send().await.expect("send should succeed"); } } @@ -528,22 +527,40 @@ mod tests { }; let persisted_lsn = Arc::new(AtomicU64::new(0)); - let (node_env, bifrost) = create_test_env(config, |builder| { - let get_processor_state_handler = PartitionProcessorStatusHandler { - network_sender: builder.network_sender.clone(), - persisted_lsn: Arc::clone(&persisted_lsn), - block_list: BTreeSet::new(), - }; + let get_processor_state_handler = Arc::new(PartitionProcessorStatusHandler { + persisted_lsn: Arc::clone(&persisted_lsn), + block_list: BTreeSet::new(), + }); + let (node_env, bifrost) = create_test_env(config, |builder| { builder - .add_message_handler(get_processor_state_handler) + .add_message_handler(get_processor_state_handler.clone()) .add_message_handler(NoOpMessageHandler::::default()) }) .await?; node_env .tc + .clone() .run_in_scope("test", None, async move { + // simulate a connection from node 2 so we can have a connection between the two + // nodes + let node_2 = MockPeerConnection::connect( + GenerationalNodeId::new(2, 2), + node_env.metadata.nodes_config_version(), + node_env + .metadata + .nodes_config_ref() + .cluster_name() + .to_owned(), + node_env.networking.connection_manager(), + 10, + ) + .await?; + // let node2 receive messages and use the same message handler as node1 + let (_node_2, _node2_reactor) = node_2 + .process_with_message_handler(&node_env.tc, get_processor_state_handler)?; + let mut appender = bifrost.create_appender(LOG_ID)?; for i in 1..=20 { let lsn = appender.append("").await?; @@ -593,22 +610,39 @@ mod tests { }; let persisted_lsn = Arc::new(AtomicU64::new(0)); + let get_processor_state_handler = Arc::new(PartitionProcessorStatusHandler { + persisted_lsn: Arc::clone(&persisted_lsn), + block_list: BTreeSet::new(), + }); let (node_env, bifrost) = create_test_env(config, |builder| { - let get_processor_state_handler = PartitionProcessorStatusHandler { - network_sender: builder.network_sender.clone(), - persisted_lsn: Arc::clone(&persisted_lsn), - block_list: BTreeSet::new(), - }; - builder - .add_message_handler(get_processor_state_handler) + .add_message_handler(get_processor_state_handler.clone()) .add_message_handler(NoOpMessageHandler::::default()) }) .await?; node_env .tc + .clone() .run_in_scope("test", None, async move { + // simulate a connection from node 2 so we can have a connection between the two + // nodes + let node_2 = MockPeerConnection::connect( + GenerationalNodeId::new(2, 2), + node_env.metadata.nodes_config_version(), + node_env + .metadata + .nodes_config_ref() + .cluster_name() + .to_owned(), + node_env.networking.connection_manager(), + 10, + ) + .await?; + // let node2 receive messages and use the same message handler as node1 + let (_node_2, _node2_reactor) = node_2 + .process_with_message_handler(&node_env.tc, get_processor_state_handler)?; + let mut appender = bifrost.create_appender(LOG_ID)?; for i in 1..=20 { let lsn = appender.append(format!("record{}", i)).await?; @@ -666,7 +700,6 @@ mod tests { .collect(); let get_processor_state_handler = PartitionProcessorStatusHandler { - network_sender: builder.network_sender.clone(), persisted_lsn: Arc::clone(&persisted_lsn), block_list: black_list, }; @@ -701,18 +734,18 @@ mod tests { async fn create_test_env( config: Configuration, mut modify_builder: F, - ) -> anyhow::Result<(TestCoreEnv, Bifrost)> + ) -> anyhow::Result<(TestCoreEnv, Bifrost)> where - F: FnMut(TestCoreEnvBuilder) -> TestCoreEnvBuilder, + F: FnMut(TestCoreEnvBuilder) -> TestCoreEnvBuilder, { - let mut builder = TestCoreEnvBuilder::new_with_mock_network(); + let mut builder = TestCoreEnvBuilder::with_incoming_only_connector(); let metadata = builder.metadata.clone(); let svc = Service::new( Live::from_value(config), builder.tc.clone(), builder.metadata.clone(), - builder.network_sender.clone(), + builder.networking.clone(), &mut builder.router_builder, builder.metadata_writer.clone(), builder.metadata_store_client.clone(), @@ -733,7 +766,7 @@ mod tests { Role::Worker.into(), LogServerConfig::default(), )); - let builder = modify_builder(builder.with_nodes_config(nodes_config)); + let builder = modify_builder(builder.set_nodes_config(nodes_config)); let node_env = builder.build().await; diff --git a/crates/bifrost/benches/util.rs b/crates/bifrost/benches/util.rs index 1d8984ba5..f547054a5 100644 --- a/crates/bifrost/benches/util.rs +++ b/crates/bifrost/benches/util.rs @@ -8,9 +8,11 @@ // the Business Source License, use of this software will be governed // by the Apache License, Version 2.0. +use tracing::warn; + +use restate_core::network::Networking; use restate_core::{ - spawn_metadata_manager, MetadataBuilder, MetadataManager, MockNetworkSender, TaskCenter, - TaskCenterBuilder, + spawn_metadata_manager, MetadataBuilder, MetadataManager, TaskCenter, TaskCenterBuilder, }; use restate_metadata_store::{MetadataStoreClient, Precondition}; use restate_rocksdb::RocksDbManager; @@ -18,7 +20,6 @@ use restate_types::config::Configuration; use restate_types::live::Constant; use restate_types::logs::metadata::ProviderKind; use restate_types::metadata_store::keys::BIFROST_CONFIG_KEY; -use tracing::warn; pub async fn spawn_environment( config: Configuration, @@ -35,13 +36,13 @@ pub async fn spawn_environment( restate_types::config::set_current_config(config.clone()); let metadata_builder = MetadataBuilder::default(); - let network_sender = MockNetworkSender::new(metadata_builder.to_metadata()); + let networking = Networking::new(metadata_builder.to_metadata(), config.networking.clone()); let metadata_store_client = MetadataStoreClient::new_in_memory(); let metadata = metadata_builder.to_metadata(); let metadata_manager = MetadataManager::new( metadata_builder, - network_sender.clone(), + networking.clone(), metadata_store_client.clone(), ); diff --git a/crates/bifrost/src/bifrost.rs b/crates/bifrost/src/bifrost.rs index c577b39e3..2d3504235 100644 --- a/crates/bifrost/src/bifrost.rs +++ b/crates/bifrost/src/bifrost.rs @@ -515,8 +515,8 @@ mod tests { #[traced_test] async fn test_append_smoke() -> googletest::Result<()> { let num_partitions = 5; - let node_env = TestCoreEnvBuilder::new_with_mock_network() - .with_partition_table(PartitionTable::with_equally_sized_partitions( + let node_env = TestCoreEnvBuilder::with_incoming_only_connector() + .set_partition_table(PartitionTable::with_equally_sized_partitions( Version::MIN, num_partitions, )) @@ -592,7 +592,7 @@ mod tests { #[tokio::test(start_paused = true)] async fn test_lazy_initialization() -> googletest::Result<()> { - let node_env = TestCoreEnv::create_with_mock_nodes_config(1, 1).await; + let node_env = TestCoreEnv::create_with_single_node(1, 1).await; let tc = node_env.tc; tc.run_in_scope("test", None, async { let delay = Duration::from_secs(5); @@ -614,7 +614,7 @@ mod tests { #[test(tokio::test(flavor = "multi_thread", worker_threads = 2))] async fn trim_log_smoke_test() -> googletest::Result<()> { const LOG_ID: LogId = LogId::new(0); - let node_env = TestCoreEnvBuilder::new_with_mock_network() + let node_env = TestCoreEnvBuilder::with_incoming_only_connector() .set_provider_kind(ProviderKind::Local) .build() .await; @@ -705,8 +705,8 @@ mod tests { #[tokio::test(start_paused = true)] async fn test_read_across_segments() -> googletest::Result<()> { const LOG_ID: LogId = LogId::new(0); - let node_env = TestCoreEnvBuilder::new_with_mock_network() - .with_partition_table(PartitionTable::with_equally_sized_partitions( + let node_env = TestCoreEnvBuilder::with_incoming_only_connector() + .set_partition_table(PartitionTable::with_equally_sized_partitions( Version::MIN, 1, )) @@ -907,8 +907,8 @@ mod tests { #[traced_test] async fn test_appends_correctly_handle_reconfiguration() -> googletest::Result<()> { const LOG_ID: LogId = LogId::new(0); - let node_env = TestCoreEnvBuilder::new_with_mock_network() - .with_partition_table(PartitionTable::with_equally_sized_partitions( + let node_env = TestCoreEnvBuilder::with_incoming_only_connector() + .set_partition_table(PartitionTable::with_equally_sized_partitions( Version::MIN, 1, )) diff --git a/crates/bifrost/src/providers/local_loglet/mod.rs b/crates/bifrost/src/providers/local_loglet/mod.rs index a1d1b20cc..0f84de1d7 100644 --- a/crates/bifrost/src/providers/local_loglet/mod.rs +++ b/crates/bifrost/src/providers/local_loglet/mod.rs @@ -300,7 +300,7 @@ mod tests { F: FnMut(Arc) -> O, O: std::future::Future>, { - let node_env = TestCoreEnvBuilder::new_with_mock_network() + let node_env = TestCoreEnvBuilder::with_incoming_only_connector() .set_provider_kind(ProviderKind::Local) .build() .await; @@ -346,7 +346,7 @@ mod tests { #[tokio::test(flavor = "multi_thread", worker_threads = 4)] async fn local_loglet_append_after_seal_concurrent() -> googletest::Result<()> { - let node_env = TestCoreEnvBuilder::new_with_mock_network() + let node_env = TestCoreEnvBuilder::with_incoming_only_connector() .set_provider_kind(ProviderKind::Local) .build() .await; diff --git a/crates/bifrost/src/providers/replicated_loglet/provider.rs b/crates/bifrost/src/providers/replicated_loglet/provider.rs index c2cdb9ca5..b79d1c69e 100644 --- a/crates/bifrost/src/providers/replicated_loglet/provider.rs +++ b/crates/bifrost/src/providers/replicated_loglet/provider.rs @@ -13,7 +13,7 @@ use std::sync::Arc; use async_trait::async_trait; use dashmap::DashMap; -use restate_core::network::{MessageRouterBuilder, Networking}; +use restate_core::network::{MessageRouterBuilder, Networking, TransportConnect}; use restate_core::{Metadata, TaskCenter}; use restate_metadata_store::MetadataStoreClient; use restate_types::config::ReplicatedLogletOptions; @@ -28,21 +28,21 @@ use crate::loglet::{Loglet, LogletProvider, LogletProviderFactory, OperationErro use crate::providers::replicated_loglet::error::ReplicatedLogletError; use crate::Error; -pub struct Factory { +pub struct Factory { task_center: TaskCenter, opts: BoxedLiveLoad, metadata: Metadata, metadata_store_client: MetadataStoreClient, - networking: Networking, + networking: Networking, } -impl Factory { +impl Factory { pub fn new( task_center: TaskCenter, opts: BoxedLiveLoad, metadata_store_client: MetadataStoreClient, metadata: Metadata, - networking: Networking, + networking: Networking, _router_builder: &mut MessageRouterBuilder, ) -> Self { // todo(asoli): @@ -59,7 +59,7 @@ impl Factory { } #[async_trait] -impl LogletProviderFactory for Factory { +impl LogletProviderFactory for Factory { fn kind(&self) -> ProviderKind { ProviderKind::Replicated } @@ -76,22 +76,22 @@ impl LogletProviderFactory for Factory { } } -struct ReplicatedLogletProvider { +struct ReplicatedLogletProvider { active_loglets: DashMap<(LogId, SegmentIndex), Arc>, _task_center: TaskCenter, _opts: BoxedLiveLoad, _metadata: Metadata, _metadata_store_client: MetadataStoreClient, - _networking: Networking, + _networking: Networking, } -impl ReplicatedLogletProvider { +impl ReplicatedLogletProvider { fn new( task_center: TaskCenter, opts: BoxedLiveLoad, metadata: Metadata, metadata_store_client: MetadataStoreClient, - networking: Networking, + networking: Networking, ) -> Self { // todo(asoli): create all global state here that'll be shared across loglet instances // - RecordCache. @@ -108,7 +108,7 @@ impl ReplicatedLogletProvider { } #[async_trait] -impl LogletProvider for ReplicatedLogletProvider { +impl LogletProvider for ReplicatedLogletProvider { async fn get_loglet( &self, log_id: LogId, diff --git a/crates/bifrost/src/read_stream.rs b/crates/bifrost/src/read_stream.rs index 2d212acdf..4b217e9c0 100644 --- a/crates/bifrost/src/read_stream.rs +++ b/crates/bifrost/src/read_stream.rs @@ -463,7 +463,7 @@ mod tests { setup_panic_handler(); const LOG_ID: LogId = LogId::new(0); - let node_env = TestCoreEnvBuilder::new_with_mock_network() + let node_env = TestCoreEnvBuilder::with_incoming_only_connector() .set_provider_kind(ProviderKind::Local) .build() .await; @@ -551,7 +551,7 @@ mod tests { setup_panic_handler(); const LOG_ID: LogId = LogId::new(0); - let node_env = TestCoreEnvBuilder::new_with_mock_network() + let node_env = TestCoreEnvBuilder::with_incoming_only_connector() .set_provider_kind(ProviderKind::Local) .build() .await; @@ -660,7 +660,7 @@ mod tests { setup_panic_handler(); const LOG_ID: LogId = LogId::new(0); - let node_env = TestCoreEnvBuilder::new_with_mock_network() + let node_env = TestCoreEnvBuilder::with_incoming_only_connector() .set_provider_kind(ProviderKind::Local) .build() .await; @@ -832,7 +832,7 @@ mod tests { setup_panic_handler(); const LOG_ID: LogId = LogId::new(0); - let node_env = TestCoreEnvBuilder::new_with_mock_network() + let node_env = TestCoreEnvBuilder::with_incoming_only_connector() .set_provider_kind(ProviderKind::Local) .build() .await; @@ -961,7 +961,7 @@ mod tests { setup_panic_handler(); const LOG_ID: LogId = LogId::new(0); - let node_env = TestCoreEnvBuilder::new_with_mock_network() + let node_env = TestCoreEnvBuilder::with_incoming_only_connector() .set_provider_kind(ProviderKind::Local) .build() .await; diff --git a/crates/core/src/lib.rs b/crates/core/src/lib.rs index ee62f62ce..634a4b983 100644 --- a/crates/core/src/lib.rs +++ b/crates/core/src/lib.rs @@ -29,7 +29,4 @@ pub use task_center_types::*; mod test_env; #[cfg(any(test, feature = "test-util"))] -pub use test_env::{ - create_mock_nodes_config, MockNetworkSender, NoOpMessageHandler, TestCoreEnv, - TestCoreEnvBuilder, -}; +pub use test_env::{create_mock_nodes_config, NoOpMessageHandler, TestCoreEnv, TestCoreEnvBuilder}; diff --git a/crates/core/src/metadata/manager.rs b/crates/core/src/metadata/manager.rs index 8cb30a11b..827dcfdf2 100644 --- a/crates/core/src/metadata/manager.rs +++ b/crates/core/src/metadata/manager.rs @@ -28,7 +28,7 @@ use restate_types::net::metadata::{GetMetadataRequest, MetadataMessage, Metadata use restate_types::nodes_config::NodesConfiguration; use restate_types::partition_table::PartitionTable; use restate_types::schema::Schema; -use restate_types::{GenerationalNodeId, NodeId}; +use restate_types::NodeId; use restate_types::{Version, Versioned}; use super::{Metadata, MetadataContainer, MetadataKind, MetadataWriter}; @@ -37,7 +37,10 @@ use crate::cancellation_watcher; use crate::is_cancellation_requested; use crate::metadata_store::{MetadataStoreClient, ReadError}; use crate::network::Incoming; +use crate::network::Networking; use crate::network::Outgoing; +use crate::network::Reciprocal; +use crate::network::TransportConnect; use crate::network::{MessageHandler, MessageRouterBuilder, NetworkError, NetworkSender}; use crate::task_center; @@ -87,51 +90,44 @@ pub(super) enum Command { /// A handler for processing network messages targeting metadata manager /// (dev.restate.common.TargetName = METADATA_MANAGER) -struct MetadataMessageHandler -where - N: NetworkSender + 'static + Clone, -{ +struct MetadataMessageHandler { sender: CommandSender, - networking: N, metadata: Metadata, } -impl MetadataMessageHandler -where - N: NetworkSender + 'static + Clone, -{ +impl MetadataMessageHandler { fn send_metadata( &self, - peer: GenerationalNodeId, + to: Reciprocal, metadata_kind: MetadataKind, min_version: Option, ) { match metadata_kind { - MetadataKind::NodesConfiguration => self.send_nodes_config(peer, min_version), - MetadataKind::PartitionTable => self.send_partition_table(peer, min_version), - MetadataKind::Logs => self.send_logs(peer, min_version), - MetadataKind::Schema => self.send_schema(peer, min_version), + MetadataKind::NodesConfiguration => self.send_nodes_config(to, min_version), + MetadataKind::PartitionTable => self.send_partition_table(to, min_version), + MetadataKind::Logs => self.send_logs(to, min_version), + MetadataKind::Schema => self.send_schema(to, min_version), }; } - fn send_nodes_config(&self, to: GenerationalNodeId, version: Option) { + fn send_nodes_config(&self, to: Reciprocal, version: Option) { let config = self.metadata.nodes_config_snapshot(); self.send_metadata_internal(to, version, config.deref(), "nodes_config"); } - fn send_partition_table(&self, to: GenerationalNodeId, version: Option) { + fn send_partition_table(&self, to: Reciprocal, version: Option) { let partition_table = self.metadata.partition_table_snapshot(); self.send_metadata_internal(to, version, partition_table.deref(), "partition_table"); } - fn send_logs(&self, to: GenerationalNodeId, version: Option) { + fn send_logs(&self, to: Reciprocal, version: Option) { let logs = self.metadata.logs(); if logs.version() != Version::INVALID { self.send_metadata_internal(to, version, logs.deref(), "logs"); } } - fn send_schema(&self, to: GenerationalNodeId, version: Option) { + fn send_schema(&self, to: Reciprocal, version: Option) { let schema = self.metadata.schema(); if schema.version != Version::INVALID { self.send_metadata_internal(to, version, schema.deref(), "schema"); @@ -140,7 +136,7 @@ where fn send_metadata_internal( &self, - to: GenerationalNodeId, + to: Reciprocal, version: Option, metadata: &T, metadata_name: &str, @@ -165,22 +161,17 @@ where version, ); let metadata = metadata.clone(); + let outgoing = to.prepare(MetadataMessage::MetadataUpdate(MetadataUpdate { + container: MetadataContainer::from(metadata), + })); let _ = task_center().spawn_child( crate::TaskKind::Disposable, "send-metadata-to-peer", None, { - let networking = self.networking.clone(); async move { - networking - .send(Outgoing::new( - to, - MetadataMessage::MetadataUpdate(MetadataUpdate { - container: MetadataContainer::from(metadata), - }), - )) - .await?; + outgoing.send().await?; Ok(()) } }, @@ -188,20 +179,17 @@ where } } -impl MessageHandler for MetadataMessageHandler -where - N: NetworkSender + 'static + Clone, -{ +impl MessageHandler for MetadataMessageHandler { type MessageType = MetadataMessage; async fn on_message(&self, envelope: Incoming) { - let (peer, msg) = envelope.split(); + let (reciprocal, msg) = envelope.split(); match msg { MetadataMessage::MetadataUpdate(update) => { info!( "Received '{}' metadata update from peer {}", update.container.kind(), - peer + reciprocal.peer(), ); if let Err(e) = self .sender @@ -213,8 +201,11 @@ where } } MetadataMessage::GetMetadataRequest(request) => { - debug!("Received GetMetadataRequest from peer {}", peer); - self.send_metadata(peer, request.metadata_kind, request.min_version); + debug!( + "Received GetMetadataRequest from peer {}", + reciprocal.peer() + ); + self.send_metadata(reciprocal, request.metadata_kind, request.min_version); } }; } @@ -239,21 +230,18 @@ where /// - Schema metadata /// - NodesConfiguration /// - Partition table -pub struct MetadataManager { +pub struct MetadataManager { metadata: Metadata, inbound: CommandReceiver, - networking: N, + networking: Networking, metadata_store_client: MetadataStoreClient, update_tasks: EnumMap>, } -impl MetadataManager -where - N: NetworkSender + 'static + Clone, -{ +impl MetadataManager { pub fn new( metadata_builder: MetadataBuilder, - networking: N, + networking: Networking, metadata_store_client: MetadataStoreClient, ) -> Self { Self { @@ -268,7 +256,6 @@ where pub fn register_in_message_router(&self, sr_builder: &mut MessageRouterBuilder) { sr_builder.add_message_handler(MetadataMessageHandler { sender: self.metadata.sender.clone(), - networking: self.networking.clone(), metadata: self.metadata.clone(), }); } @@ -443,7 +430,7 @@ where self.update_task_and_notify_watches(maybe_new_version, MetadataKind::Schema); } - fn update_internal(container: &ArcSwap, new_value: T) -> Version { + fn update_internal(container: &ArcSwap, new_value: M) -> Version { let current_value = container.load(); let mut maybe_new_version = new_value.version(); @@ -573,6 +560,7 @@ mod tests { use super::*; use googletest::prelude::*; + use restate_types::config::NetworkingOptions; use test_log::test; use restate_test_util::assert_eq; @@ -581,7 +569,6 @@ mod tests { use restate_types::{GenerationalNodeId, Version}; use crate::metadata::spawn_metadata_manager; - use crate::test_env::MockNetworkSender; use crate::TaskCenterBuilder; #[test] @@ -618,11 +605,12 @@ mod tests { let tc = TaskCenterBuilder::default().build()?; tc.block_on("test", None, async move { let metadata_builder = MetadataBuilder::default(); - let network_sender = MockNetworkSender::new(metadata_builder.to_metadata()); + let networking = + Networking::new(metadata_builder.to_metadata(), NetworkingOptions::default()); let metadata_store_client = MetadataStoreClient::new_in_memory(); let metadata = metadata_builder.to_metadata(); let metadata_manager = - MetadataManager::new(metadata_builder, network_sender, metadata_store_client); + MetadataManager::new(metadata_builder, networking, metadata_store_client); let metadata_writer = metadata_manager.writer(); assert_eq!(Version::INVALID, config_version(&metadata)); @@ -691,12 +679,13 @@ mod tests { let tc = TaskCenterBuilder::default().build()?; tc.block_on("test", None, async move { let metadata_builder = MetadataBuilder::default(); - let network_sender = MockNetworkSender::new(metadata_builder.to_metadata()); + let networking = + Networking::new(metadata_builder.to_metadata(), NetworkingOptions::default()); let metadata_store_client = MetadataStoreClient::new_in_memory(); let metadata = metadata_builder.to_metadata(); let metadata_manager = - MetadataManager::new(metadata_builder, network_sender, metadata_store_client); + MetadataManager::new(metadata_builder, networking, metadata_store_client); let metadata_writer = metadata_manager.writer(); assert_eq!(Version::INVALID, config_version(&metadata)); diff --git a/crates/core/src/metadata/mod.rs b/crates/core/src/metadata/mod.rs index b4ec3574e..732b2321c 100644 --- a/crates/core/src/metadata/mod.rs +++ b/crates/core/src/metadata/mod.rs @@ -11,8 +11,7 @@ mod manager; pub use manager::{MetadataManager, TargetVersion}; -use restate_types::live::{Live, Pinned}; -use restate_types::schema::Schema; +pub use restate_types::net::metadata::MetadataKind; use std::sync::{Arc, OnceLock}; @@ -20,16 +19,17 @@ use arc_swap::{ArcSwap, AsRaw}; use enum_map::EnumMap; use tokio::sync::{mpsc, oneshot, watch}; +use restate_types::live::{Live, Pinned}; use restate_types::logs::metadata::Logs; use restate_types::net::metadata::MetadataContainer; -pub use restate_types::net::metadata::MetadataKind; use restate_types::nodes_config::NodesConfiguration; use restate_types::partition_table::PartitionTable; +use restate_types::schema::Schema; use restate_types::{GenerationalNodeId, NodeId, Version, Versioned}; use crate::metadata::manager::Command; use crate::metadata_store::ReadError; -use crate::network::NetworkSender; +use crate::network::TransportConnect; use crate::{ShutdownError, TaskCenter, TaskId, TaskKind}; #[derive(Debug, thiserror::Error)] @@ -352,13 +352,10 @@ impl Default for VersionWatch { } } -pub fn spawn_metadata_manager( +pub fn spawn_metadata_manager( tc: &TaskCenter, - metadata_manager: MetadataManager, -) -> Result -where - N: NetworkSender + 'static, -{ + metadata_manager: MetadataManager, +) -> Result { tc.spawn( TaskKind::MetadataBackgroundSync, "metadata-manager", diff --git a/crates/core/src/network/connection.rs b/crates/core/src/network/connection.rs index 90a5e8ec2..bfb46781e 100644 --- a/crates/core/src/network/connection.rs +++ b/crates/core/src/network/connection.rs @@ -15,55 +15,98 @@ use std::time::Instant; use enum_map::{enum_map, EnumMap}; use tokio::sync::mpsc; use tokio::sync::mpsc::error::TrySendError; -use tracing::instrument; -use restate_types::live::Live; -use restate_types::logs::metadata::Logs; use restate_types::net::codec::Targeted; use restate_types::net::codec::{serialize_message, WireEncode}; use restate_types::net::metadata::MetadataKind; use restate_types::net::ProtocolVersion; -use restate_types::nodes_config::NodesConfiguration; -use restate_types::partition_table::PartitionTable; use restate_types::protobuf::node::message; use restate_types::protobuf::node::Header; use restate_types::protobuf::node::Message; -use restate_types::schema::Schema; -use restate_types::{GenerationalNodeId, Version, Versioned}; +use restate_types::{GenerationalNodeId, Version}; -use super::metric_definitions::CONNECTION_SEND_DURATION; use super::metric_definitions::MESSAGE_SENT; use super::NetworkError; -use super::NetworkSendError; use super::Outgoing; -use super::ProtocolError; -use crate::network::connection_manager::MetadataVersions; use crate::Metadata; +pub struct OwnedSendPermit { + _protocol_version: ProtocolVersion, + _permit: mpsc::OwnedPermit, + _phantom: std::marker::PhantomData, +} + +pub struct SendPermit<'a, M> { + protocol_version: ProtocolVersion, + permit: mpsc::Permit<'a, Message>, + _phantom: std::marker::PhantomData, +} + +impl<'a, M> SendPermit<'a, M> +where + M: WireEncode + Targeted, +{ + /// Sends a message over this permit. + /// + /// Note that sending messages over this permit won't use the peer information nor the connection + /// associated with the message. + pub fn send(self, message: Outgoing, metadata: &Metadata) { + let metadata_versions = HeaderMetadataVersions::from_metadata(metadata); + self.send_with_versions(message, metadata_versions); + } + + fn send_with_versions( + self, + message: Outgoing, + metadata_versions: HeaderMetadataVersions, + ) { + let header = Header::new( + metadata_versions[MetadataKind::NodesConfiguration] + .expect("nodes configuration version must be set"), + metadata_versions[MetadataKind::Logs], + metadata_versions[MetadataKind::Schema], + metadata_versions[MetadataKind::PartitionTable], + message.msg_id(), + message.in_response_to(), + ); + let body = serialize_message(message.into_body(), self.protocol_version) + .expect("message encoding infallible"); + self.send_raw(Message::new(header, body)); + } +} + +impl<'a, M> SendPermit<'a, M> { + /// Sends a raw pre-serialized message over this permit. + /// + /// Note that sending messages over this permit won't use the peer information nor the connection + /// associated with the message. + pub(crate) fn send_raw(self, raw_message: Message) { + self.permit.send(raw_message); + MESSAGE_SENT.increment(1); + } +} + /// A single streaming connection with a channel to the peer. A connection can be /// opened by either ends of the connection and has no direction. Any connection /// can be used to send or receive from a peer. /// /// The primary owner of a connection is the running reactor, all other components -/// should hold a Weak if caching access to a certain connection is +/// should hold a `WeakConnection` if access to a certain connection is /// needed. -pub struct Connection { - /// Connection identifier, randomly generated on this end of the connection. - pub(crate) cid: u64, +pub struct OwnedConnection { pub(crate) peer: GenerationalNodeId, pub(crate) protocol_version: ProtocolVersion, pub(crate) sender: mpsc::Sender, pub(crate) created: Instant, } -impl Connection { +impl OwnedConnection { pub(crate) fn new( peer: GenerationalNodeId, protocol_version: ProtocolVersion, sender: mpsc::Sender, ) -> Self { Self { - cid: rand::random(), peer, protocol_version, sender, @@ -106,118 +149,47 @@ impl Connection { /// A handle that sends messages through that connection. This hides the /// wire protocol from the user and guarantees order of messages. - pub fn sender(self: &Arc, metadata: &Metadata) -> ConnectionSender { - ConnectionSender { + pub fn downgrade(self: &Arc) -> WeakConnection { + WeakConnection { + peer: self.peer, connection: Arc::downgrade(self), - nodes_config: metadata.updateable_nodes_config(), - schema: metadata.updateable_schema(), - logs: metadata.updateable_logs_metadata(), - partition_table: metadata.updateable_partition_table(), - metadata_versions: MetadataVersions::default(), } } - /// Send a message on this connection. This returns Ok(()) when the message is: - /// - Successfully serialized to the wire format based on the negotiated protocol - /// - Serialized message was enqueued on the send buffer of the socket - /// - /// That means that this is not a guarantee that the message has been sent - /// over the network or that the peer has received it. - /// - /// If this is needed, the caller must design the wire protocol with a - /// request/response state machine and perform retries on other nodes/connections if needed. - /// - /// This roughly maps to the semantics of a POSIX write/send socket operation. - /// - /// This doesn't auto-retry connection resets or send errors, this is up to the user - /// for retrying externally. - #[instrument(level = "trace", skip_all, fields(peer_node_id = %self.peer, target_service = ?message.target(), msg = ?message.kind()))] - pub async fn send( - &self, - message: Outgoing, - metadata_versions: HeaderMetadataVersions, - ) -> Result<(), NetworkSendError> - where - M: WireEncode + Targeted, - { - let send_start = Instant::now(); - // do not serialize if we can't acquire capacity - let permit = match self.sender.reserve().await { - Ok(permit) => permit, - Err(_) => { - return Err(NetworkSendError::new( - message, - NetworkError::ConnectionClosed, - )) - } - }; - - let serialized_msg = match self.create_message(&message, metadata_versions) { - Ok(m) => m, - Err(e) => return Err(NetworkSendError::new(message, e)), - }; - permit.send(serialized_msg); - MESSAGE_SENT.increment(1); - CONNECTION_SEND_DURATION.record(send_start.elapsed()); - Ok(()) + /// Allocates capacity to send one message on this connection. If connection is closed, this + /// returns None. + pub async fn reserve(&self) -> Option> { + let permit = self.sender.reserve().await.ok()?; + Some(SendPermit { + permit, + protocol_version: self.protocol_version, + _phantom: std::marker::PhantomData, + }) } - fn create_message( - &self, - message: &Outgoing, - metadata_versions: HeaderMetadataVersions, - ) -> Result - where - M: WireEncode + Targeted, - { - let header = Header::new( - metadata_versions[MetadataKind::NodesConfiguration] - .expect("nodes configuration version must be set"), - metadata_versions[MetadataKind::Logs], - metadata_versions[MetadataKind::Schema], - metadata_versions[MetadataKind::PartitionTable], - message.msg_id(), - message.in_response_to(), - ); - let body = serialize_message(message.body(), self.protocol_version) - .map_err(ProtocolError::Codec)?; - Ok(Message::new(header, body)) + pub async fn reserve_owned(self) -> Option> { + let permit = self.sender.reserve_owned().await.ok()?; + Some(OwnedSendPermit { + _permit: permit, + _protocol_version: self.protocol_version, + _phantom: std::marker::PhantomData, + }) } - /// Tries sending a message on this connection. If there is no capacity, it will fail. Apart - /// from this, the method behaves similarly to [`Connection::send`]. - #[instrument(skip_all, fields(peer_node_id = %self.peer, target_service = ?message.target(), msg = ?message.kind()))] - pub fn try_send( - &self, - message: Outgoing, - metadata_versions: HeaderMetadataVersions, - ) -> Result<(), NetworkSendError> - where - M: WireEncode + Targeted, - { - let send_start = Instant::now(); - // do not serialize if we can't acquire capacity + /// Tries to allocate capacity to send one message on this connection. If there is no capacity, + /// it will fail with [`NetworkError::Full`]. If connection is closed it returns [`NetworkError::ConnectionClosed`] + pub fn try_reserve(&self) -> Result, NetworkError> { let permit = match self.sender.try_reserve() { Ok(permit) => permit, - Err(TrySendError::Full(_)) => { - return Err(NetworkSendError::new(message, NetworkError::Full)) - } - Err(TrySendError::Closed(_)) => { - return Err(NetworkSendError::new( - message, - NetworkError::ConnectionClosed, - )) - } + Err(TrySendError::Full(_)) => return Err(NetworkError::Full), + Err(TrySendError::Closed(_)) => return Err(NetworkError::ConnectionClosed), }; - let serialized_msg = match self.create_message(&message, metadata_versions) { - Ok(m) => m, - Err(e) => return Err(NetworkSendError::new(message, e)), - }; - permit.send(serialized_msg); - MESSAGE_SENT.increment(1); - CONNECTION_SEND_DURATION.record(send_start.elapsed()); - Ok(()) + Ok(SendPermit { + permit, + protocol_version: self.protocol_version, + _phantom: std::marker::PhantomData, + }) } } @@ -252,40 +224,32 @@ impl HeaderMetadataVersions { } } -impl PartialEq for Connection { +impl PartialEq for OwnedConnection { fn eq(&self, other: &Self) -> bool { - self.cid == other.cid && self.peer == other.peer + self.sender.same_channel(&other.sender) } } /// A handle to send messages through a connection. It's safe to hold and clone objects of this -/// even if the connection has been dropped. Cloning and holding comes at the cost of caching -/// all existing metadata which is not for free. -#[derive(Clone)] -pub struct ConnectionSender { - connection: Weak, - nodes_config: Live, - schema: Live, - logs: Live, - partition_table: Live, - metadata_versions: MetadataVersions, +/// even if the connection has been dropped. Cheap to clone. +#[derive(Clone, Debug)] +pub struct WeakConnection { + pub(crate) peer: GenerationalNodeId, + pub(crate) connection: Weak, } -impl ConnectionSender { - /// See [`Connection::send`]. - pub async fn send(&mut self, message: Outgoing) -> Result<(), NetworkSendError> - where - M: WireEncode + Targeted, - { - let Some(connection) = self.connection.upgrade() else { - return Err(NetworkSendError::new( - message, - NetworkError::ConnectionClosed, - )); - }; - connection - .send(message, self.header_metadata_versions()) - .await +static_assertions::assert_impl_all!(WeakConnection: Send, Sync); + +impl WeakConnection { + pub fn new_closed(peer: GenerationalNodeId) -> Self { + Self { + peer, + connection: Weak::new(), + } + } + + pub fn peer(&self) -> &GenerationalNodeId { + &self.peer } pub fn is_closed(&self) -> bool { @@ -305,35 +269,462 @@ impl ConnectionSender { connection.closed().await } } +} - /// See [`Connection::try_send`]. - pub fn try_send(&mut self, message: Outgoing) -> Result<(), NetworkSendError> - where - M: WireEncode + Targeted, - { - let Some(connection) = self.connection.upgrade() else { - return Err(NetworkSendError::new( - message, - NetworkError::ConnectionClosed, - )); - }; - connection.try_send(message, self.header_metadata_versions()) +#[cfg(any(test, feature = "test-util"))] +pub mod test_util { + use super::*; + + use std::sync::Arc; + use std::time::Instant; + + use async_trait::async_trait; + use futures::stream::BoxStream; + use futures::StreamExt; + use restate_types::net::CodecError; + use tokio::sync::mpsc; + use tokio::sync::mpsc::error::TrySendError; + use tokio_stream::wrappers::ReceiverStream; + use tracing::info; + use tracing::warn; + + use restate_types::net::codec::MessageBodyExt; + use restate_types::net::codec::Targeted; + use restate_types::net::codec::{serialize_message, WireEncode}; + use restate_types::net::ProtocolVersion; + use restate_types::nodes_config::NodesConfiguration; + use restate_types::protobuf::node::message; + use restate_types::protobuf::node::message::BinaryMessage; + use restate_types::protobuf::node::message::Body; + use restate_types::protobuf::node::message::ConnectionControl; + use restate_types::protobuf::node::Header; + use restate_types::protobuf::node::Hello; + use restate_types::protobuf::node::Message; + use restate_types::protobuf::node::Welcome; + use restate_types::NodeId; + use restate_types::{GenerationalNodeId, Version}; + + use crate::cancellation_watcher; + use crate::network::handshake::negotiate_protocol_version; + use crate::network::handshake::wait_for_hello; + use crate::network::ConnectionManager; + use crate::network::Handler; + use crate::network::Incoming; + use crate::network::MessageHandler; + use crate::network::MessageRouterBuilder; + use crate::network::NetworkError; + use crate::network::ProtocolError; + use crate::network::TransportConnect; + use crate::TaskCenter; + use crate::TaskHandle; + use crate::TaskKind; + + // For testing + // + // Used to simulate incoming connection. Gives control to reading and writing messages. + // + // Sending messages on this connection simulates a remote machine sending messages to our + // connection manager. Sending means "incoming messages". The recv_stream on the other hand + // can be used to read responses that we sent back. + #[derive(derive_more::Debug)] + pub struct MockPeerConnection { + /// The Id of the node that this connection represents + pub my_node_id: GenerationalNodeId, + /// The Id of the node we are connected to + pub(crate) peer: GenerationalNodeId, + pub protocol_version: ProtocolVersion, + pub sender: mpsc::Sender, + pub created: Instant, + + #[debug(skip)] + pub recv_stream: BoxStream<'static, Message>, } - fn header_metadata_versions(&mut self) -> HeaderMetadataVersions { - let mut version_updates = self.metadata_versions.update( - None, - Some(self.partition_table.live_load().version()), - Some(self.schema.live_load().version()), - Some(self.logs.live_load().version()), - ); - version_updates[MetadataKind::NodesConfiguration] = - Some(self.nodes_config.live_load().version()); + impl MockPeerConnection { + /// must run in task-center + pub async fn connect( + from_node_id: GenerationalNodeId, + my_node_config_version: Version, + my_cluster_name: String, + connection_manager: &ConnectionManager, + message_buffer: usize, + ) -> anyhow::Result { + let (sender, rx) = mpsc::channel(message_buffer); + let incoming = ReceiverStream::new(rx).map(Ok); + + let hello = Hello::new(from_node_id, my_cluster_name); + let hello = Message::new( + Header::new( + my_node_config_version, + None, + None, + None, + crate::network::generate_msg_id(), + None, + ), + hello, + ); + sender.send(hello).await?; - HeaderMetadataVersions { - versions: version_updates, + let created = Instant::now(); + let mut recv_stream = connection_manager + .accept_incoming_connection(incoming) + .await?; + let msg = recv_stream + .next() + .await + .ok_or(anyhow::anyhow!("expected welcome message"))?; + let welcome = match msg.body { + Some(message::Body::Welcome(welcome)) => welcome, + _ => anyhow::bail!("unexpected message, we expect Welcome instead"), + }; + + let peer: NodeId = welcome.my_node_id.expect("peer node id must be set").into(); + let peer = peer + .as_generational() + .expect("peer must be generational node id"); + + Ok(Self { + my_node_id: from_node_id, + peer, + protocol_version: welcome.protocol_version(), + sender, + recv_stream: Box::pin(recv_stream), + created, + }) + } + + /// fails only if receiver is terminated (connection terminated) + pub async fn send_raw(&self, message: M, header: Header) -> anyhow::Result<()> + where + M: WireEncode + Targeted, + { + let body = serialize_message(message, self.protocol_version).expect("serde unfallible"); + let message = Message::new(header, body); + + self.sender.send(message).await?; + + Ok(()) + } + + pub async fn reserve(&self) -> Option> { + let permit = self.sender.reserve().await.ok()?; + Some(SendPermit { + permit, + protocol_version: self.protocol_version, + _phantom: std::marker::PhantomData, + }) + } + + pub async fn reserve_owned(self) -> Option> { + let permit = self.sender.reserve_owned().await.ok()?; + Some(OwnedSendPermit { + _permit: permit, + _protocol_version: self.protocol_version, + _phantom: std::marker::PhantomData, + }) + } + + /// Tries to allocate capacity to send one message on this connection. If there is no capacity, + /// it will fail with [`NetworkError::Full`]. If connection is closed it returns [`NetworkError::ConnectionClosed`] + pub fn try_reserve(&self) -> Result, NetworkError> { + let permit = match self.sender.try_reserve() { + Ok(permit) => permit, + Err(TrySendError::Full(_)) => return Err(NetworkError::Full), + Err(TrySendError::Closed(_)) => return Err(NetworkError::ConnectionClosed), + }; + + Ok(SendPermit { + permit, + protocol_version: self.protocol_version, + _phantom: std::marker::PhantomData, + }) + } + + /// Allows you to use utilities in OwnedConnection or WeakConnection. + /// Reminder: Sending on this connection will cause message to arrive as incoming to the node + /// we are connected to. + pub fn to_owned_connection(&self) -> OwnedConnection { + OwnedConnection { + peer: self.peer, + protocol_version: self.protocol_version, + sender: self.sender.clone(), + created: self.created, + } + } + + // Allow for messages received on this connection to be processed by a given message handler. + pub fn process_with_message_handler( + self, + task_center: &TaskCenter, + handler: H, + ) -> anyhow::Result<(WeakConnection, TaskHandle>)> { + let mut router = MessageRouterBuilder::default(); + router.add_message_handler(handler); + let router = router.build(); + self.process_with_message_router(task_center, router) + } + + // Allow for messages received on this connection to be processed by a given message router. + // A task will be created that takes ownership of the receive stream. Stopping the task will + // drop the receive stream (simulates connection loss). + pub fn process_with_message_router( + self, + task_center: &TaskCenter, + router: R, + ) -> anyhow::Result<(WeakConnection, TaskHandle>)> { + let Self { + my_node_id, + peer, + protocol_version, + sender, + created, + recv_stream, + } = self; + + let connection = Arc::new(OwnedConnection { + peer, + protocol_version, + sender, + created, + }); + + let weak = connection.downgrade(); + let message_processor = MessageProcessor { + my_node_id, + router, + connection, + recv_stream, + }; + let handle = task_center.spawn_unmanaged( + TaskKind::ConnectionReactor, + "test-message-processor", + None, + async move { message_processor.run().await }, + )?; + Ok((weak, handle)) + } + + // Allow for messages received on this connection to be forwarded to the supplied sender. + pub fn forward_to_sender( + self, + task_center: &TaskCenter, + sender: mpsc::Sender<(GenerationalNodeId, Incoming)>, + ) -> anyhow::Result<(WeakConnection, TaskHandle>)> { + let handler = ForwardingHandler { + my_node_id: self.my_node_id, + inner_sender: sender, + }; + + self.process_with_message_router(task_center, handler) } } -} -static_assertions::assert_impl_all!(ConnectionSender: Send, Sync); + // Prepresents a partially connected peer connection in test environment. A connection must be + // handshaken in order to be converted into MockPeerConnection. + // + // This is used to represent an outgoing connection from (`peer` to `my_node_id`) + #[derive(derive_more::Debug)] + pub struct PartialPeerConnection { + /// The Id of the node that this connection represents + pub my_node_id: GenerationalNodeId, + /// The Id of the node id that started this connection + pub(crate) peer: GenerationalNodeId, + pub sender: mpsc::Sender, + pub created: Instant, + + #[debug(skip)] + pub recv_stream: BoxStream<'static, Message>, + } + + impl PartialPeerConnection { + // todo(asoli): replace implementation with body of accept_incoming_connection to unify + // handshake validations + pub async fn handshake( + self, + nodes_config: &NodesConfiguration, + ) -> anyhow::Result { + let Self { + my_node_id, + peer, + sender, + created, + mut recv_stream, + } = self; + let temp_stream = recv_stream.by_ref(); + let (header, hello) = wait_for_hello( + &mut temp_stream.map(Ok), + std::time::Duration::from_millis(500), + ) + .await?; + + // NodeId **must** be generational at this layer + let _peer_node_id = hello.my_node_id.ok_or(ProtocolError::HandshakeFailed( + "NodeId is not set in the Hello message", + ))?; + + // Are we both from the same cluster? + if hello.cluster_name != nodes_config.cluster_name() { + return Err(ProtocolError::HandshakeFailed("cluster name mismatch").into()); + } + + let selected_protocol_version = negotiate_protocol_version(&hello)?; + + // Enqueue the welcome message + let welcome = Welcome::new(my_node_id, selected_protocol_version); + + let welcome = Message::new( + Header::new( + nodes_config.version(), + None, + None, + None, + crate::network::generate_msg_id(), + Some(header.msg_id), + ), + welcome, + ); + sender.try_send(welcome)?; + + Ok(MockPeerConnection { + my_node_id, + peer, + protocol_version: selected_protocol_version, + sender, + created, + recv_stream, + }) + } + } + + pub struct ForwardingHandler { + my_node_id: GenerationalNodeId, + inner_sender: mpsc::Sender<(GenerationalNodeId, Incoming)>, + } + + impl ForwardingHandler { + pub fn new( + my_node_id: GenerationalNodeId, + inner_sender: mpsc::Sender<(GenerationalNodeId, Incoming)>, + ) -> Self { + Self { + my_node_id, + inner_sender, + } + } + } + + #[async_trait] + impl Handler for ForwardingHandler { + type Error = CodecError; + + async fn call( + &self, + message: Incoming, + _protocol_version: ProtocolVersion, + ) -> Result<(), Self::Error> { + if self + .inner_sender + .send((self.my_node_id, message)) + .await + .is_err() + { + warn!("Failed to send message to inner sender, connection is closed"); + } + Ok(()) + } + } + + struct MessageProcessor { + my_node_id: GenerationalNodeId, + router: R, + connection: Arc, + recv_stream: BoxStream<'static, Message>, + } + + impl MessageProcessor { + async fn run(mut self) -> anyhow::Result<()> { + let mut cancel = std::pin::pin!(cancellation_watcher()); + loop { + tokio::select! { + _ = &mut cancel => { + info!("Message processor cancelled for node {}", self.my_node_id); + break; + } + maybe_msg = self.recv_stream.next() => { + let Some(msg) = maybe_msg else { + info!("Terminating message processor because connection sender is dropped for node {}", self.my_node_id); + break; + }; + // header is required on all messages + let Some(header) = msg.header else { + self.connection.send_control_frame(ConnectionControl::codec_error( + "Header is missing on message", + )); + break; + }; + + // body are not allowed to be empty. + let Some(body) = msg.body else { + self.connection + .send_control_frame(ConnectionControl::codec_error("Body is missing on message")); + break; + }; + + // Welcome and hello are not allowed after handshake + if body.is_welcome() || body.is_hello() { + self.connection.send_control_frame(ConnectionControl::codec_error( + "Hello/Welcome are not allowed after handshake", + )); + break; + }; + + // if it's a control signal, handle it, otherwise, route with message router. + if let message::Body::ConnectionControl(ctrl_msg) = &body { + // do something + info!( + "Terminating connection based on signal from peer: {:?} {}", + ctrl_msg.signal(), + ctrl_msg.message + ); + break; + } + + + self.route_message(header, body).await?; + } + } + } + Ok(()) + } + + async fn route_message(&mut self, header: Header, body: Body) -> anyhow::Result<()> { + match body.try_as_binary_body(self.connection.protocol_version) { + Ok(msg) => { + if let Err(e) = self + .router + .call( + Incoming::from_parts( + msg, + self.connection.downgrade(), + header.msg_id, + header.in_response_to, + ), + self.connection.protocol_version, + ) + .await + { + warn!("Error processing message: {:?}", e); + } + } + Err(status) => { + // terminate the stream + info!("Error processing message, reporting error to peer: {status}"); + self.connection + .send_control_frame(ConnectionControl::codec_error(status.to_string())); + } + } + Ok(()) + } + } +} diff --git a/crates/core/src/network/connection_manager.rs b/crates/core/src/network/connection_manager.rs index 60fba478d..889d39de9 100644 --- a/crates/core/src/network/connection_manager.rs +++ b/crates/core/src/network/connection_manager.rs @@ -8,56 +8,49 @@ // the Business Source License, use of this software will be governed // by the Apache License, Version 2.0. -use std::collections::{hash_map, HashMap}; -use std::sync::{Arc, Mutex, Weak}; +use std::collections::HashMap; +use std::sync::{Arc, Weak}; use std::time::Instant; use enum_map::EnumMap; -use futures::stream::BoxStream; use futures::{Stream, StreamExt}; +use parking_lot::Mutex; use rand::seq::SliceRandom; use tokio::sync::mpsc; use tokio_stream::wrappers::ReceiverStream; -use tonic::transport::Channel; use tracing::{debug, info, trace, warn, Instrument, Span}; use restate_types::config::NetworkingOptions; use restate_types::net::codec::MessageBodyExt; use restate_types::net::metadata::MetadataKind; -use restate_types::net::AdvertisedAddress; use restate_types::nodes_config::NodesConfiguration; use restate_types::protobuf::node::message::{self, ConnectionControl}; use restate_types::protobuf::node::{Header, Hello, Message, Welcome}; use restate_types::{GenerationalNodeId, NodeId, PlainNodeId, Version}; -use super::connection::{Connection, ConnectionSender}; +use super::connection::{OwnedConnection, WeakConnection}; use super::error::{NetworkError, ProtocolError}; -use super::handshake::{negotiate_protocol_version, wait_for_hello, wait_for_welcome}; +use super::handshake::wait_for_welcome; use super::metric_definitions::{ self, CONNECTION_DROPPED, INCOMING_CONNECTION, MESSAGE_PROCESSING_DURATION, MESSAGE_RECEIVED, ONGOING_DRAIN, OUTGOING_CONNECTION, }; -use super::protobuf::node_svc::node_svc_client::NodeSvcClient; +use super::transport_connector::TransportConnect; use super::{Handler, MessageRouter}; use crate::metadata::Urgency; -use crate::network::net_util::create_tonic_channel_from_advertised_address; +use crate::network::handshake::{negotiate_protocol_version, wait_for_hello}; use crate::network::Incoming; use crate::Metadata; use crate::{cancellation_watcher, current_task_id, task_center, TaskId, TaskKind}; -// todo: make this configurable -const SEND_QUEUE_SIZE: usize = 1000; -static_assertions::const_assert!(SEND_QUEUE_SIZE >= 1); - struct ConnectionManagerInner { router: MessageRouter, - connections: HashMap>, - connection_by_gen_id: HashMap>>, + connections: HashMap>, + connection_by_gen_id: HashMap>>, /// This tracks the max generation we observed from connection attempts regardless of our nodes /// configuration. We cannot accept connections from nodes older than ones we have observed /// already. observed_generations: HashMap, - channel_cache: HashMap, } impl ConnectionManagerInner { @@ -67,11 +60,14 @@ impl ConnectionManagerInner { fn cleanup_stale_connections(&mut self, peer_node_id: &GenerationalNodeId) { if let Some(connections) = self.connection_by_gen_id.get_mut(peer_node_id) { - connections.retain(|c| c.upgrade().is_some()); + connections.retain(|c| c.upgrade().is_some_and(|c| !c.is_closed())); } } - fn get_random_connection(&self, peer_node_id: &GenerationalNodeId) -> Option> { + fn get_random_connection( + &self, + peer_node_id: &GenerationalNodeId, + ) -> Option> { self.connection_by_gen_id .get(peer_node_id) .and_then(|connections| connections.choose(&mut rand::thread_rng())?.upgrade()) @@ -86,41 +82,71 @@ impl Default for ConnectionManagerInner { connections: HashMap::default(), connection_by_gen_id: HashMap::default(), observed_generations: HashMap::default(), - channel_cache: HashMap::default(), } } } -#[derive(Clone)] -pub struct ConnectionManager { +pub struct ConnectionManager { inner: Arc>, + networking_options: NetworkingOptions, + transport_connector: Arc, metadata: Metadata, - options: NetworkingOptions, } -impl ConnectionManager { +impl Clone for ConnectionManager { + fn clone(&self) -> Self { + Self { + inner: self.inner.clone(), + networking_options: self.networking_options.clone(), + transport_connector: self.transport_connector.clone(), + metadata: self.metadata.clone(), + } + } +} + +#[cfg(any(test, feature = "test-util"))] +/// used for testing. Accepts connections but can't establish new connections +impl ConnectionManager { + pub fn new_incoming_only(metadata: Metadata) -> Self { + let inner = Arc::new(Mutex::new(ConnectionManagerInner::default())); + + Self { + metadata, + inner, + transport_connector: Arc::new(super::FailingConnector::default()), + networking_options: NetworkingOptions::default(), + } + } +} + +impl ConnectionManager { /// Creates the connection manager. - pub(super) fn new(metadata: Metadata, options: NetworkingOptions) -> Self { + pub fn new( + metadata: Metadata, + transport_connector: Arc, + networking_options: NetworkingOptions, + ) -> Self { let inner = Arc::new(Mutex::new(ConnectionManagerInner::default())); Self { metadata, inner, - options, + transport_connector, + networking_options, } } /// Updates the message router. Note that this only impacts new connections. /// In general, this should be called once on application start after /// initializing all message handlers. pub fn set_message_router(&self, router: MessageRouter) { - self.inner.lock().unwrap().router = router; + self.inner.lock().router = router; } /// Accept a new incoming connection stream and register a network reactor task for it. pub async fn accept_incoming_connection( &self, mut incoming: S, - ) -> Result>, NetworkError> + ) -> Result + Unpin + Send + 'static, NetworkError> where S: Stream> + Unpin + Send + 'static, { @@ -146,8 +172,11 @@ impl ConnectionManager { // window) to avoid dangling resources by misbehaving peers or under sever load conditions. // The client can retry with an exponential backoff on handshake timeout. debug!("Accepting incoming connection"); - let (header, hello) = - wait_for_hello(&mut incoming, self.options.handshake_timeout.into()).await?; + let (header, hello) = wait_for_hello( + &mut incoming, + self.networking_options.handshake_timeout.into(), + ) + .await?; let nodes_config = self.metadata.nodes_config_ref(); let my_node_id = self.metadata.my_node_id(); // NodeId **must** be generational at this layer @@ -187,7 +216,9 @@ impl ConnectionManager { self.verify_node_id(peer_node_id, header, &nodes_config)?; - let (tx, rx) = mpsc::channel(SEND_QUEUE_SIZE); + let (tx, output_stream) = + mpsc::channel(self.networking_options.outbound_queue_length.into()); + let output_stream = ReceiverStream::new(output_stream); // Enqueue the welcome message let welcome = Welcome::new(my_node_id, selected_protocol_version); @@ -205,136 +236,60 @@ impl ConnectionManager { tx.try_send(welcome) .expect("channel accept Welcome message"); + let connection = OwnedConnection::new(peer_node_id, selected_protocol_version, tx); INCOMING_CONNECTION.increment(1); - let connection = Connection::new(peer_node_id, selected_protocol_version, tx); // Register the connection. let _ = self.start_connection_reactor(connection, incoming)?; - // For uniformity with outbound connections, we map all responses to Ok, we never rely on - // sending tonic::Status errors explicitly. We use ConnectionControl frames to communicate - // errors and/or drop the stream when necessary. - let transformed = ReceiverStream::new(rx).map(Ok); - Ok(Box::pin(transformed)) - } - - fn verify_node_id( - &self, - peer_node_id: GenerationalNodeId, - header: Header, - nodes_config: &NodesConfiguration, - ) -> Result<(), NetworkError> { - if let Err(e) = nodes_config.find_node_by_id(peer_node_id) { - // If nodeId is unrecognized and peer is at higher nodes configuration version, - // then we have to update our NodesConfiguration - if let Some(other_nodes_config_version) = header.my_nodes_config_version.map(Into::into) - { - let peer_is_in_the_future = other_nodes_config_version > nodes_config.version(); - - if peer_is_in_the_future { - self.metadata.notify_observed_version( - MetadataKind::NodesConfiguration, - other_nodes_config_version, - None, - Urgency::High, - ); - debug!("Remote node '{}' with newer nodes configuration '{}' tried to connect. Trying to fetch newer version before accepting connection.", peer_node_id, other_nodes_config_version); - } else { - info!("Unknown remote node '{}' tried to connect to cluster. Rejecting connection.", peer_node_id); - } - } else { - info!("Unknown remote node '{}' w/o specifying its node configuration tried to connect to cluster. Rejecting connection.", peer_node_id); - } - - return Err(NetworkError::UnknownNode(e)); - } - - Ok(()) + // Our output stream, i.e. responses. + Ok(output_stream) } /// Always attempts to create a new connection with peer - pub async fn enforced_new_node_sender( + pub async fn enforce_new_connection( &self, node_id: GenerationalNodeId, - ) -> Result { + ) -> Result { let connection = self.connect(node_id).await?; - Ok(connection.sender(&self.metadata)) + Ok(connection.downgrade()) } /// Gets an existing connection or creates a new one if no active connection exists. If /// multiple connections already exist, it returns a random one. - pub async fn get_node_sender( + pub async fn get_or_connect( &self, node_id: GenerationalNodeId, - ) -> Result { + ) -> Result, NetworkError> { // find a connection by node_id - let maybe_connection: Option> = { - let guard = self.inner.lock().unwrap(); + let maybe_connection: Option> = { + let guard = self.inner.lock(); guard.get_random_connection(&node_id) // lock is dropped. }; if let Some(connection) = maybe_connection { - return Ok(connection.sender(&self.metadata)); + return Ok(connection); } - // We have no connection, or the connection we picked is stale. We attempt to create a - // new connection anyway. - let connection = self.connect(node_id).await?; - Ok(connection.sender(&self.metadata)) - } - - async fn connect(&self, node_id: GenerationalNodeId) -> Result, NetworkError> { - let address = self - .metadata - .nodes_config_ref() - .find_node_by_id(node_id)? - .address - .clone(); - - trace!("Attempting to connect to node {} at {}", node_id, address); - // Do we have a channel in cache for this address? - let channel = { - let mut guard = self.inner.lock().unwrap(); - if let hash_map::Entry::Vacant(entry) = guard.channel_cache.entry(address.clone()) { - let channel = create_tonic_channel_from_advertised_address(address) - .map_err(|e| NetworkError::BadNodeAddress(node_id.into(), e))?; - entry.insert(channel.clone()); - channel - } else { - guard.channel_cache.get(&address).unwrap().clone() - } - }; - - self.connect_with_channel(node_id, channel).await + // We have no connection. We attempt to create a new connection. + self.connect(node_id).await } - // Left here for future use. This allows the node to connect to itself and bypass the - // networking stack. - #[cfg(test)] - fn _connect_loopback( + async fn connect( &self, node_id: GenerationalNodeId, - ) -> Result, NetworkError> { - let (tx, rx) = mpsc::channel(SEND_QUEUE_SIZE); - let connection = Connection::new(node_id, restate_types::net::CURRENT_PROTOCOL_VERSION, tx); - - let transformed = ReceiverStream::new(rx).map(Ok); - let incoming = Box::pin(transformed); - OUTGOING_CONNECTION.increment(1); - INCOMING_CONNECTION.increment(1); - self.start_connection_reactor(connection, incoming) - } + ) -> Result, NetworkError> { + if node_id == self.metadata.my_node_id() { + return self.connect_loopback(); + } - async fn connect_with_channel( - &self, - node_id: GenerationalNodeId, - channel: Channel, - ) -> Result, NetworkError> { - let mut client = NodeSvcClient::new(channel); - let cluster_name = self.metadata.nodes_config_ref().cluster_name().to_owned(); let my_node_id = self.metadata.my_node_id(); + let nodes_config = self.metadata.nodes_config_snapshot(); + let cluster_name = nodes_config.cluster_name().to_owned(); - let (tx, rx) = mpsc::channel(SEND_QUEUE_SIZE); + let (tx, output_stream) = + mpsc::channel(self.networking_options.outbound_queue_length.into()); + let output_stream = ReceiverStream::new(output_stream); let hello = Hello::new(my_node_id, cluster_name); // perform handshake. @@ -354,15 +309,16 @@ impl ConnectionManager { tx.send(hello).await.expect("Channel accept hello message"); // Establish the connection - let incoming = client - .create_connection(ReceiverStream::new(rx)) - .await? - .into_inner(); - - let mut transformed = incoming.map(|x| x.map_err(ProtocolError::from)); + let mut incoming = self + .transport_connector + .connect(node_id, &nodes_config, output_stream) + .await?; // finish the handshake - let (_header, welcome) = - wait_for_welcome(&mut transformed, self.options.handshake_timeout.into()).await?; + let (_header, welcome) = wait_for_welcome( + &mut incoming, + self.networking_options.handshake_timeout.into(), + ) + .await?; let protocol_version = welcome.protocol_version(); if !protocol_version.is_supported() { @@ -387,8 +343,7 @@ impl ConnectionManager { .into()); } - OUTGOING_CONNECTION.increment(1); - let connection = Connection::new( + let connection = OwnedConnection::new( peer_node_id .as_generational() .expect("must be generational id"), @@ -396,14 +351,61 @@ impl ConnectionManager { tx, ); - self.start_connection_reactor(connection, transformed) + OUTGOING_CONNECTION.increment(1); + self.start_connection_reactor(connection, incoming) + } + + fn connect_loopback(&self) -> Result, NetworkError> { + let (tx, rx) = mpsc::channel(self.networking_options.outbound_queue_length.into()); + let connection = OwnedConnection::new( + self.metadata.my_node_id(), + restate_types::net::CURRENT_PROTOCOL_VERSION, + tx, + ); + + let incoming = ReceiverStream::new(rx).map(Ok); + self.start_connection_reactor(connection, incoming) + } + + fn verify_node_id( + &self, + peer_node_id: GenerationalNodeId, + header: Header, + nodes_config: &NodesConfiguration, + ) -> Result<(), NetworkError> { + if let Err(e) = nodes_config.find_node_by_id(peer_node_id) { + // If nodeId is unrecognized and peer is at higher nodes configuration version, + // then we have to update our NodesConfiguration + if let Some(other_nodes_config_version) = header.my_nodes_config_version.map(Into::into) + { + let peer_is_in_the_future = other_nodes_config_version > nodes_config.version(); + + if peer_is_in_the_future { + self.metadata.notify_observed_version( + MetadataKind::NodesConfiguration, + other_nodes_config_version, + None, + Urgency::High, + ); + debug!("Remote node '{}' with newer nodes configuration '{}' tried to connect. Trying to fetch newer version before accepting connection.", peer_node_id, other_nodes_config_version); + } else { + info!("Unknown remote node '{}' tried to connect to cluster. Rejecting connection.", peer_node_id); + } + } else { + info!("Unknown remote node '{}' w/o specifying its node configuration tried to connect to cluster. Rejecting connection.", peer_node_id); + } + + return Err(NetworkError::UnknownNode(e)); + } + + Ok(()) } fn start_connection_reactor( &self, - connection: Connection, + connection: OwnedConnection, incoming: S, - ) -> Result, NetworkError> + ) -> Result, NetworkError> where S: Stream> + Unpin + Send + 'static, { @@ -412,7 +414,7 @@ impl ConnectionManager { // If we have a connection with an older generation, we request to drop it. // However, more than one connection with the same generation is allowed. let mut _cleanup = false; - let mut guard = self.inner.lock().unwrap(); + let mut guard = self.inner.lock(); let known_generation = guard .observed_generations .get(&connection.peer.as_plain()) @@ -461,10 +463,13 @@ impl ConnectionManager { ) .instrument(span), )?; - debug!( - peer_node_id = %peer_node_id, - task_id = %task_id, - "Incoming connection accepted from node {}", peer_node_id); + if peer_node_id != self.metadata.my_node_id() { + debug!( + peer_node_id = %peer_node_id, + task_id = %task_id, + "Incoming connection accepted from node {}", peer_node_id + ); + } // Reactor has already started by now. guard.connections.insert(task_id, connection_weak.clone()); @@ -482,7 +487,7 @@ impl ConnectionManager { async fn run_reactor( connection_manager: Arc>, - connection: Arc, + connection: Arc, router: MessageRouter, mut incoming: S, metadata: Metadata, @@ -583,9 +588,8 @@ where if let Err(e) = router .call( Incoming::from_parts( - connection.peer, msg, - Arc::downgrade(&connection), + connection.downgrade(), header.msg_id, header.in_response_to, ), @@ -632,10 +636,9 @@ where if let Err(e) = router .call( Incoming::from_parts( - peer_node_id, msg, // This is a dying connection, don't pass it down. - Weak::new(), + WeakConnection::new_closed(peer_node_id), header.msg_id, header.in_response_to, ), @@ -666,8 +669,11 @@ where Ok(()) } -fn on_connection_draining(connection: &Connection, inner_manager: &Mutex) { - let mut guard = inner_manager.lock().unwrap(); +fn on_connection_draining( + connection: &OwnedConnection, + inner_manager: &Mutex, +) { + let mut guard = inner_manager.lock(); if let Some(connections) = guard.connection_by_gen_id.get_mut(&connection.peer) { // Remove this connection from connections map to reduce the chance // of picking it up as connection. @@ -681,7 +687,7 @@ fn on_connection_draining(connection: &Connection, inner_manager: &Mutex) { let task_id = current_task_id().expect("TaskId is set"); - let mut guard = inner_manager.lock().unwrap(); + let mut guard = inner_manager.lock(); guard.drop_connection(task_id); } @@ -737,34 +743,47 @@ impl MetadataVersions { mod tests { use super::*; - use crate::{MetadataBuilder, MockNetworkSender, TestCoreEnv, TestCoreEnvBuilder}; use googletest::prelude::*; + use test_log::test; + use tokio::sync::mpsc; + use restate_test_util::{assert_eq, let_assert}; - use restate_types::net::codec::{serialize_message, Targeted, WireDecode, WireEncode}; + use restate_types::net::codec::WireDecode; use restate_types::net::metadata::{GetMetadataRequest, MetadataMessage}; use restate_types::net::partition_processor_manager::GetProcessorsState; use restate_types::net::{ - ProtocolVersion, CURRENT_PROTOCOL_VERSION, MIN_SUPPORTED_PROTOCOL_VERSION, + AdvertisedAddress, ProtocolVersion, CURRENT_PROTOCOL_VERSION, + MIN_SUPPORTED_PROTOCOL_VERSION, + }; + use restate_types::nodes_config::{ + LogServerConfig, NodeConfig, NodesConfigError, NodesConfiguration, Role, }; - use restate_types::nodes_config::{LogServerConfig, NodeConfig, NodesConfigError, Role}; - use restate_types::protobuf::node::message; use restate_types::protobuf::node::message::Body; + use restate_types::protobuf::node::{Header, Hello}; use restate_types::Version; - use test_log::test; - use tonic::Status; + + use crate::network::MockPeerConnection; + use crate::{TestCoreEnv, TestCoreEnvBuilder}; // Test handshake with a client #[tokio::test] async fn test_hello_welcome_handshake() -> Result<()> { - let test_setup = TestCoreEnv::create_with_mock_nodes_config(1, 1).await; + let test_setup = TestCoreEnv::create_with_single_node(1, 1).await; test_setup .tc .run_in_scope("test", None, async { let metadata = crate::metadata(); - let connections = - ConnectionManager::new(metadata.clone(), NetworkingOptions::default()); + let connections = ConnectionManager::new_incoming_only(metadata.clone()); - let _ = establish_connection(metadata.my_node_id(), &metadata, &connections).await; + let _mock_connection = MockPeerConnection::connect( + GenerationalNodeId::new(1, 1), + metadata.nodes_config_version(), + metadata.nodes_config_ref().cluster_name().to_owned(), + &connections, + 10, + ) + .await + .unwrap(); Ok(()) }) @@ -773,13 +792,14 @@ mod tests { #[tokio::test(start_paused = true)] async fn test_hello_welcome_timeout() -> Result<()> { - let test_setup = TestCoreEnv::create_with_mock_nodes_config(1, 1).await; + let test_setup = TestCoreEnv::create_with_single_node(1, 1).await; let metadata = test_setup.metadata; + let net_opts = NetworkingOptions::default(); test_setup .tc .run_in_scope("test", None, async { let (_tx, rx) = mpsc::channel(1); - let connections = ConnectionManager::new(metadata, NetworkingOptions::default()); + let connections = ConnectionManager::new_incoming_only(metadata); let start = tokio::time::Instant::now(); let incoming = ReceiverStream::new(rx); @@ -791,7 +811,7 @@ mod tests { ProtocolError::HandshakeTimeout(_) )) )); - assert!(start.elapsed() >= connections.options.handshake_timeout.into()); + assert!(start.elapsed() >= net_opts.handshake_timeout.into()); Ok(()) }) .await @@ -799,7 +819,7 @@ mod tests { #[tokio::test] async fn test_bad_handshake() -> Result<()> { - let test_setup = TestCoreEnv::create_with_mock_nodes_config(1, 1).await; + let test_setup = TestCoreEnv::create_with_single_node(1, 1).await; let metadata = test_setup.metadata; test_setup .tc @@ -829,8 +849,7 @@ mod tests { .await .expect("Channel accept hello message"); - let connections = - ConnectionManager::new(metadata.clone(), NetworkingOptions::default()); + let connections = ConnectionManager::new_incoming_only(metadata.clone()); let incoming = ReceiverStream::new(rx); let resp = connections.accept_incoming_connection(incoming).await; assert!(resp.is_err()); @@ -863,7 +882,7 @@ mod tests { ); tx.send(Ok(hello)).await?; - let connections = ConnectionManager::new(metadata, NetworkingOptions::default()); + let connections = ConnectionManager::new_incoming_only(metadata.clone()); let incoming = ReceiverStream::new(rx); let err = connections .accept_incoming_connection(incoming) @@ -883,7 +902,7 @@ mod tests { #[tokio::test] async fn test_node_generation() -> Result<()> { - let test_setup = TestCoreEnv::create_with_mock_nodes_config(1, 2).await; + let test_setup = TestCoreEnv::create_with_single_node(1, 2).await; let metadata = test_setup.metadata; test_setup .tc @@ -913,8 +932,7 @@ mod tests { .await .expect("Channel accept hello message"); - let connections = - ConnectionManager::new(metadata.clone(), NetworkingOptions::default()); + let connections = ConnectionManager::new_incoming_only(metadata.clone()); let incoming = ReceiverStream::new(rx); let err = connections @@ -953,7 +971,7 @@ mod tests { .await .expect("Channel accept hello message"); - let connections = ConnectionManager::new(metadata, NetworkingOptions::default()); + let connections = ConnectionManager::new_incoming_only(metadata); let incoming = ReceiverStream::new(rx); let err = connections @@ -984,26 +1002,25 @@ mod tests { ); nodes_config.upsert_node(node_config); - let (network_tx, mut network_rx) = mpsc::unbounded_channel(); - let metadata_builder = MetadataBuilder::default(); - - let test_env = TestCoreEnvBuilder::new( - MockNetworkSender::from_sender(network_tx, metadata_builder.to_metadata()), - metadata_builder, - ) - .with_nodes_config(nodes_config) - .build() - .await; + let test_env = TestCoreEnvBuilder::with_incoming_only_connector() + .set_nodes_config(nodes_config) + .build() + .await; test_env .tc .run_in_scope("test", None, async { let metadata = crate::metadata(); - let connections = - ConnectionManager::new(metadata.clone(), NetworkingOptions::default()); - let (connection, _rx) = - establish_connection(node_id, &metadata, &connections).await; + let mut connection = MockPeerConnection::connect( + node_id, + metadata.nodes_config_version(), + metadata.nodes_config_ref().cluster_name().to_string(), + test_env.networking.connection_manager(), + 10, + ) + .await + .into_test_result()?; let request = GetProcessorsState {}; let partition_table_version = metadata.partition_table_version().next(); @@ -1016,10 +1033,13 @@ mod tests { None, ); - connection.send(request, header).await?; + connection + .send_raw(request, header) + .await + .into_test_result()?; - let (target, message) = network_rx.recv().await.expect("some message"); - assert_eq!(NodeId::from(target), node_id); + // we expect the request to go throught he existing open connection to my node + let message = connection.recv_stream.next().await.expect("some message"); assert_get_metadata_request( message, connection.protocol_version, @@ -1061,84 +1081,4 @@ mod tests { MetadataMessage::decode(&mut binary_message.payload, protocol_version)?; Ok(metadata_message) } - - async fn establish_connection( - node_id: GenerationalNodeId, - metadata: &Metadata, - connections: &ConnectionManager, - ) -> ( - TestConnection, - BoxStream<'static, std::result::Result>, - ) { - let (tx, rx) = mpsc::channel(1); - - let hello = Hello::new( - node_id, - metadata.nodes_config_ref().cluster_name().to_owned(), - ); - let hello = Message::new( - Header::new( - metadata.nodes_config_version(), - None, - None, - None, - crate::network::generate_msg_id(), - None, - ), - hello, - ); - tx.send(Ok(hello)) - .await - .expect("Channel accept hello message"); - - let incoming = ReceiverStream::new(rx); - let mut output_stream = connections - .accept_incoming_connection(incoming) - .await - .expect("handshake"); - let msg = output_stream - .next() - .await - .expect("welcome message") - .expect("ok"); - let welcome = match msg.body { - Some(message::Body::Welcome(welcome)) => welcome, - _ => panic!("unexpected message"), - }; - assert_eq!(welcome.my_node_id, Some(metadata.my_node_id().into())); - - ( - TestConnection::new(welcome.protocol_version(), tx), - output_stream, - ) - } - - struct TestConnection { - protocol_version: ProtocolVersion, - tx: mpsc::Sender>, - } - - impl TestConnection { - fn new( - protocol_version: ProtocolVersion, - tx: mpsc::Sender>, - ) -> Self { - Self { - protocol_version, - tx, - } - } - - async fn send(&self, message: M, header: Header) -> Result<()> - where - M: WireEncode + Targeted, - { - let body = serialize_message(&message, self.protocol_version)?; - let message = Message::new(header, body); - - self.tx.send(Ok(message)).await?; - - Ok(()) - } - } } diff --git a/crates/core/src/network/error.rs b/crates/core/src/network/error.rs index 0604a123a..c91bfb8b1 100644 --- a/crates/core/src/network/error.rs +++ b/crates/core/src/network/error.rs @@ -14,8 +14,6 @@ use restate_types::NodeId; use crate::{ShutdownError, SyncError}; -use super::Outgoing; - #[derive(Debug, thiserror::Error)] pub enum RouterError { #[error("codec error: {0}")] @@ -27,14 +25,14 @@ pub enum RouterError { #[derive(Debug, thiserror::Error)] #[error("send error: {source}")] pub struct NetworkSendError { - pub message: Outgoing, + pub original: M, #[source] pub source: NetworkError, } impl NetworkSendError { - pub fn new(message: Outgoing, source: NetworkError) -> Self { - Self { message, source } + pub fn new(original: M, source: NetworkError) -> Self { + Self { original, source } } } diff --git a/crates/core/src/network/message_router.rs b/crates/core/src/network/message_router.rs index 179f061ba..2e3d7ff36 100644 --- a/crates/core/src/network/message_router.rs +++ b/crates/core/src/network/message_router.rs @@ -39,10 +39,38 @@ pub trait MessageHandler { ) -> impl std::future::Future + Send; } +impl MessageHandler for Arc +where + T: MessageHandler, +{ + type MessageType = T::MessageType; + + fn on_message( + &self, + msg: Incoming, + ) -> impl std::future::Future + Send { + (**self).on_message(msg) + } +} + +impl MessageHandler for Box +where + T: MessageHandler, +{ + type MessageType = T::MessageType; + + fn on_message( + &self, + msg: Incoming, + ) -> impl std::future::Future + Send { + (**self).on_message(msg) + } +} + /// A low-level handler trait. #[async_trait] -pub trait Handler: Send + Sync { - type Error; +pub trait Handler: Send { + type Error: std::fmt::Debug; /// Deserialize and process the message asynchronously. async fn call( &self, @@ -51,6 +79,38 @@ pub trait Handler: Send + Sync { ) -> Result<(), Self::Error>; } +#[async_trait] +impl Handler for Arc +where + T: Handler + Send + Sync + 'static, +{ + type Error = T::Error; + + async fn call( + &self, + message: Incoming, + protocol_version: ProtocolVersion, + ) -> Result<(), Self::Error> { + (**self).call(message, protocol_version).await + } +} + +#[async_trait] +impl Handler for Box +where + T: Handler + Send + Sync + 'static, +{ + type Error = T::Error; + + async fn call( + &self, + message: Incoming, + protocol_version: ProtocolVersion, + ) -> Result<(), Self::Error> { + (**self).call(message, protocol_version).await + } +} + #[derive(Clone, Default)] pub struct MessageRouter(Arc); @@ -68,7 +128,7 @@ impl Handler for MessageRouter { message: Incoming, protocol_version: ProtocolVersion, ) -> Result<(), Self::Error> { - let target = message.target(); + let target = message.body().target(); let Some(handler) = self.0.handlers.get(&target) else { return Err(RouterError::NotRegisteredTarget(target.to_string())); }; @@ -86,7 +146,7 @@ impl MessageRouterBuilder { /// Attach a handler that implements [`MessageHandler`] to receive messages /// for the associated target. #[track_caller] - pub fn add_message_handler(&mut self, handler: H) + pub fn add_message_handler(&mut self, handler: H) -> &mut Self where H: MessageHandler + Send + Sync + 'static, { @@ -95,20 +155,20 @@ impl MessageRouterBuilder { if self.handlers.insert(target, Box::new(wrapped)).is_some() { panic!("Handler for target {} has been registered already!", target); } + self } /// Attach a handler that receives all messages targeting a certain [`TargetName`]. #[track_caller] - pub fn add_raw_handler( + pub fn add_raw_handler( &mut self, target: TargetName, handler: Box + Send + Sync>, - ) where - H: Handler + Send + Sync + 'static, - { + ) -> &mut Self { if self.handlers.insert(target, handler).is_some() { panic!("Handler for target {} has been registered already!", target); } + self } /// Subscribe to a stream of messages for a specific target. This enables consumers of messages diff --git a/crates/core/src/network/mod.rs b/crates/core/src/network/mod.rs index 69b8f67a2..441b69bc4 100644 --- a/crates/core/src/network/mod.rs +++ b/crates/core/src/network/mod.rs @@ -19,12 +19,20 @@ mod network_sender; mod networking; pub mod protobuf; pub mod rpc_router; +pub mod transport_connector; mod types; -pub use connection::{Connection, ConnectionSender}; +pub use connection::{OwnedConnection, WeakConnection}; pub use connection_manager::ConnectionManager; pub use error::*; pub use message_router::*; pub use network_sender::*; pub use networking::Networking; +pub use transport_connector::{GrpcConnector, TransportConnect}; pub use types::*; + +#[cfg(any(test, feature = "test-util"))] +pub use connection::test_util::*; + +#[cfg(any(test, feature = "test-util"))] +pub use transport_connector::test_util::*; diff --git a/crates/core/src/network/network_sender.rs b/crates/core/src/network/network_sender.rs index a58942c64..e4aaf3f85 100644 --- a/crates/core/src/network/network_sender.rs +++ b/crates/core/src/network/network_sender.rs @@ -10,10 +10,13 @@ use restate_types::net::codec::{Targeted, WireEncode}; -use super::{NetworkSendError, Outgoing}; +use super::{NetworkSendError, NoConnection, Outgoing}; /// Send NetworkMessage to nodes -pub trait NetworkSender: Send + Sync + Clone { +pub trait NetworkSender: Send + Sync + Clone +where + S: super::private::Sealed, +{ /// Send a message to a peer node. Order of messages is not guaranteed since underlying /// implementations might load balance message writes across multiple connections or re-order /// messages in-flight based on priority. If ordered delivery is required, then use @@ -35,8 +38,8 @@ pub trait NetworkSender: Send + Sync + Clone { /// over the network or that the peer have received it. fn send( &self, - message: Outgoing, - ) -> impl std::future::Future>> + Send + message: Outgoing, + ) -> impl std::future::Future>>> + Send where M: WireEncode + Targeted + Send + Sync; } diff --git a/crates/core/src/network/networking.rs b/crates/core/src/network/networking.rs index 9fc399be3..7d9044b8c 100644 --- a/crates/core/src/network/networking.rs +++ b/crates/core/src/network/networking.rs @@ -9,6 +9,7 @@ // by the Apache License, Version 2.0. use std::num::NonZeroUsize; +use std::sync::Arc; use tracing::{info, instrument, trace}; @@ -17,34 +18,63 @@ use restate_types::net::codec::{Targeted, WireEncode}; use restate_types::NodeId; use super::{ - ConnectionManager, ConnectionSender, NetworkError, NetworkSendError, NetworkSender, Outgoing, + ConnectionManager, HasConnection, NetworkError, NetworkSendError, NetworkSender, NoConnection, + Outgoing, WeakConnection, }; +use super::{GrpcConnector, TransportConnect}; use crate::Metadata; /// Access to node-to-node networking infrastructure. -#[derive(Clone)] -pub struct Networking { - connections: ConnectionManager, +pub struct Networking { + connections: ConnectionManager, metadata: Metadata, options: NetworkingOptions, } -impl Networking { +impl Clone for Networking { + fn clone(&self) -> Self { + Self { + connections: self.connections.clone(), + metadata: self.metadata.clone(), + options: self.options.clone(), + } + } +} + +impl Networking { pub fn new(metadata: Metadata, options: NetworkingOptions) -> Self { Self { - connections: ConnectionManager::new(metadata.clone(), options.clone()), + connections: ConnectionManager::new( + metadata.clone(), + Arc::new(GrpcConnector::new(options.clone())), + options.clone(), + ), + metadata, + options, + } + } +} + +impl Networking { + pub fn with_connection_manager( + metadata: Metadata, + options: NetworkingOptions, + connection_manager: ConnectionManager, + ) -> Self { + Self { + connections: connection_manager, metadata, options, } } - pub fn connection_manager(&self) -> ConnectionManager { - self.connections.clone() + pub fn connection_manager(&self) -> &ConnectionManager { + &self.connections } /// A connection sender is pinned to a single stream, thus guaranteeing ordered delivery of /// messages. - pub async fn node_connection(&self, node: NodeId) -> Result { + pub async fn node_connection(&self, node: NodeId) -> Result { // find latest generation if this is not generational node id let node = match node.as_generational() { Some(node) => node, @@ -56,20 +86,21 @@ impl Networking { } }; - self.connections.get_node_sender(node).await + Ok(self.connections.get_or_connect(node).await?.downgrade()) } } -impl NetworkSender for Networking { - #[instrument(level = "trace", skip(self, msg), fields(to = %msg.peer(), msg = ?msg.target()))] - async fn send(&self, mut msg: Outgoing) -> Result<(), NetworkSendError> +impl NetworkSender for Networking { + #[instrument(level = "trace", skip(self, msg), fields(to = %msg.peer(), msg = ?msg.body().target()))] + async fn send(&self, mut msg: Outgoing) -> Result<(), NetworkSendError>> where M: WireEncode + Targeted + Send + Sync, { let target_is_generational = msg.peer().is_generational(); - let original_peer = msg.peer(); + let original_peer = *msg.peer(); let mut attempts = 0; let mut retry_policy = self.options.connect_retry_policy.iter(); + let mut peer_as_generational = msg.peer().as_generational(); loop { // find latest generation if this is not generational node id. We do this in the loop // to ensure we get the latest if it has been updated since last attempt. @@ -82,7 +113,7 @@ impl NetworkSender for Networking { Ok(node) => node.current_generation, Err(e) => return Err(NetworkSendError::new(msg, NetworkError::UnknownNode(e))), }; - msg.set_peer(current_generation); + peer_as_generational = Some(current_generation); }; attempts += 1; @@ -104,87 +135,82 @@ impl NetworkSender for Networking { } } - let mut sender = { - // if we already know a connection to use, let's try that unless it's dropped. - if let Some(connection) = msg.get_connection() { - // let's try this connection - connection.sender(&self.metadata) - } else { - match self - .connections - .get_node_sender(msg.peer().as_generational().unwrap()) - .await - { - Ok(sender) => sender, - // retryable errors - Err( - e @ NetworkError::Timeout(_) - | e @ NetworkError::ConnectError(_) - | e @ NetworkError::ConnectionClosed, - ) => { - info!( - "Connection to node {} failed with {}, next retry is attempt {}/{}", - msg.peer(), - e, - attempts + 1, - self.options - .connect_retry_policy - .max_attempts() - .unwrap_or(NonZeroUsize::MAX), // max_attempts() be Some at this point - ); - continue; - } - // terminal errors - Err(NetworkError::OldPeerGeneration(e)) => { - if target_is_generational { - // Caller asked for this specific node generation and we know it's old. - return Err(NetworkSendError::new( - msg, - NetworkError::OldPeerGeneration(e), - )); - } - info!( - "Connection to node {} failed with {}, next retry is attempt {}/{}", - msg.peer(), - e, - attempts + 1, - self.options - .connect_retry_policy - .max_attempts() - .unwrap_or(NonZeroUsize::MAX), // max_attempts() be Some at this point - ); - continue; - } - Err(e) => { + let sender = { + match self + .connections + .get_or_connect(peer_as_generational.unwrap()) + .await + { + Ok(sender) => sender, + // retryable errors + Err( + e @ NetworkError::Timeout(_) + | e @ NetworkError::ConnectError(_) + | e @ NetworkError::ConnectionClosed, + ) => { + info!( + "Connection to node {} failed with {}, next retry is attempt {}/{}", + msg.peer(), + e, + attempts + 1, + self.options + .connect_retry_policy + .max_attempts() + .unwrap_or(NonZeroUsize::MAX), // max_attempts() be Some at this point + ); + continue; + } + // terminal errors + Err(NetworkError::OldPeerGeneration(e)) => { + if target_is_generational { + // Caller asked for this specific node generation and we know it's old. return Err(NetworkSendError::new( msg, - NetworkError::Unavailable(e.to_string()), - )) + NetworkError::OldPeerGeneration(e), + )); } + info!( + "Connection to node {} failed with {}, next retry is attempt {}/{}", + msg.peer(), + e, + attempts + 1, + self.options + .connect_retry_policy + .max_attempts() + .unwrap_or(NonZeroUsize::MAX), // max_attempts() be Some at this point + ); + continue; + } + Err(e) => { + return Err(NetworkSendError::new( + msg, + NetworkError::Unavailable(e.to_string()), + )) } } }; // can only fail due to codec errors or if connection is closed. Retry only if // connection closed. - match sender.send(msg).await { + let msg_with_connection = msg.assign_connection(sender.downgrade()); + match msg_with_connection.send().await { Ok(_) => return Ok(()), Err(NetworkSendError { - message, + original, source: NetworkError::ConnectionClosed, }) => { info!( "Sending message to node {} failed due to connection reset, next retry is attempt {}/{}", - message.peer(), + peer_as_generational.unwrap(), attempts + 1, self.options.connect_retry_policy.max_attempts().unwrap_or(NonZeroUsize::MAX), // max_attempts() be Some at this point ); - msg = message; + msg = original.forget_connection(); continue; } Err(e) => { return Err(NetworkSendError::new( - e.message, + e.original.forget_connection().set_peer(original_peer), NetworkError::Unavailable(e.source.to_string()), )) } @@ -193,4 +219,16 @@ impl NetworkSender for Networking { } } -static_assertions::assert_impl_all!(Networking: Send, Sync); +impl NetworkSender for Networking { + #[instrument(level = "trace", skip(self, msg), fields(to = %msg.peer(), msg = ?msg.body().target()))] + async fn send( + &self, + msg: Outgoing, + ) -> Result<(), NetworkSendError>> + where + M: WireEncode + Targeted + Send + Sync, + { + // connection is set. Just use it. + msg.send().await + } +} diff --git a/crates/core/src/network/rpc_router.rs b/crates/core/src/network/rpc_router.rs index 2b4d114ac..c7c7250a2 100644 --- a/crates/core/src/network/rpc_router.rs +++ b/crates/core/src/network/rpc_router.rs @@ -14,6 +14,7 @@ use dashmap::mapref::entry::Entry; use dashmap::DashMap; use futures::stream::BoxStream; use futures::StreamExt; +use restate_types::NodeId; use tokio::sync::oneshot; use tracing::{error, warn}; @@ -21,7 +22,8 @@ use restate_types::net::codec::{Targeted, WireDecode, WireEncode}; use restate_types::net::RpcRequest; use super::{ - Incoming, MessageHandler, MessageRouterBuilder, NetworkSendError, NetworkSender, Outgoing, + HasConnection, Incoming, MessageHandler, MessageRouterBuilder, NetworkSendError, NetworkSender, + Outgoing, }; use crate::{cancellation_watcher, ShutdownError}; @@ -32,11 +34,10 @@ use crate::{cancellation_watcher, ShutdownError}; /// /// This type is designed to be used by senders of RpcRequest(s). #[derive(Clone)] -pub struct RpcRouter +pub struct RpcRouter where T: RpcRequest, { - networking: N, response_tracker: ResponseTracker, } @@ -47,34 +48,52 @@ pub enum RpcError { Shutdown(#[from] ShutdownError), } -impl RpcRouter +impl RpcRouter where T: RpcRequest + WireEncode + Send + Sync + 'static, T::ResponseMessage: WireDecode + Send + Sync + 'static, - N: NetworkSender, { - pub fn new(networking: N, router_builder: &mut MessageRouterBuilder) -> Self { + pub fn new(router_builder: &mut MessageRouterBuilder) -> Self { let response_tracker = ResponseTracker::::default(); router_builder.add_message_handler(response_tracker.clone()); - Self { - networking, - response_tracker, - } + Self { response_tracker } } pub async fn call( &self, - msg: Outgoing, + network_sender: &impl NetworkSender, + peer: impl Into, + msg: T, ) -> Result, RpcError> { + let outgoing = Outgoing::new(peer, msg); let token = self .response_tracker - .new_token(msg.msg_id()) - .expect("msg-id is unique"); + .register(&outgoing) + .expect("msg-id is registered once"); - self.networking - .send(msg) + network_sender.send(outgoing).await.map_err(|e| { + RpcError::SendError(NetworkSendError::new( + Outgoing::into_body(e.original), + e.source, + )) + })?; + token + .recv() .await - .map_err(RpcError::SendError)?; + .map_err(|_| RpcError::Shutdown(ShutdownError)) + } + + /// Use this method when you have a connection associated with the outgoing request + pub async fn call_on_connection( + &self, + outgoing: Outgoing, + ) -> Result, RpcError>> { + let token = self + .response_tracker + .register(&outgoing) + .expect("msg-id is registered once"); + + outgoing.send().await.map_err(RpcError::SendError)?; token .recv() .await @@ -92,7 +111,7 @@ pub struct ResponseTracker where T: Targeted, { - inner: Arc>, + in_flight: Arc>>, } impl Clone for ResponseTracker @@ -101,27 +120,18 @@ where { fn clone(&self) -> Self { Self { - inner: self.inner.clone(), + in_flight: Arc::clone(&self.in_flight), } } } -struct Inner -where - T: Targeted, -{ - in_flight: DashMap>, -} - impl Default for ResponseTracker where T: Targeted, { fn default() -> Self { Self { - inner: Arc::new(Inner { - in_flight: Default::default(), - }), + in_flight: Default::default(), } } } @@ -131,50 +141,55 @@ where T: Targeted, { pub fn num_in_flight(&self) -> usize { - self.inner.in_flight.len() + self.in_flight.len() } /// Returns None if an in-flight request holds the same msg_id. - pub fn new_token(&self, msg_id: u64) -> Option> { - match self.inner.in_flight.entry(msg_id) { - Entry::Occupied(_) => { - error!( - "msg_id {:?} was already in-flight when this rpc was issued, this is an indicator that the msg_id is not unique across RPC calls", - msg_id - ); - None - } - Entry::Vacant(entry) => { - let (sender, receiver) = oneshot::channel(); - entry.insert(RpcTokenSender { sender }); - - Some(RpcToken { - msg_id, - router: Arc::downgrade(&self.inner), - receiver: Some(receiver), - }) - } - } + pub fn register(&self, outgoing: &Outgoing) -> Option> { + self.register_raw(outgoing.msg_id()) } - /// Handle a message through this response tracker. + /// Handle a message through this response tracker. Returns None on success or Some(incoming) + /// if the message doesn't correspond to an in-flight request. pub fn handle_message(&self, msg: Incoming) -> Option> { let Some(original_msg_id) = msg.in_response_to() else { warn!( - message_target = msg.kind(), + message_target = msg.body().kind(), "received a message with a `in_response_to` field unset! The message will be dropped", ); return None; }; // find the token and send, message is dropped on the floor if no valid match exist for the // msg id. - if let Some((_, token)) = self.inner.in_flight.remove(&original_msg_id) { + if let Some((_, token)) = self.in_flight.remove(&original_msg_id) { let _ = token.sender.send(msg); None } else { Some(msg) } } + + fn register_raw(&self, msg_id: u64) -> Option> { + match self.in_flight.entry(msg_id) { + Entry::Occupied(entry) => { + error!( + "msg_id {:?} was already in-flight when this rpc was issued, this is an indicator that the msg_id is not unique across RPC calls", + entry.key() + ); + None + } + Entry::Vacant(entry) => { + let (sender, receiver) = oneshot::channel(); + entry.insert(RpcTokenSender { sender }); + + Some(RpcToken { + msg_id, + router: Arc::downgrade(&self.in_flight), + receiver: Some(receiver), + }) + } + } + } } pub struct StreamingResponseTracker @@ -198,8 +213,12 @@ where } /// Returns None if an in-flight request holds the same msg_id. - pub fn new_token(&self, msg_id: u64) -> Option> { - self.flight_tracker.new_token(msg_id) + pub fn register(&self, outgoing: &Outgoing) -> Option> { + self.flight_tracker.register(outgoing) + } + + pub fn register_raw(&self, msg_id: u64) -> Option> { + self.flight_tracker.register_raw(msg_id) } /// Handles the next message. This will **return** the message if no correlated request is @@ -224,7 +243,7 @@ where T: Targeted, { msg_id: u64, - router: Weak>, + router: Weak>>, // This is Option to get around Rust's borrow checker rules when a type implements the Drop // trait. Without this, we cannot move receiver out. receiver: Option>>, @@ -271,7 +290,7 @@ where let Some(router) = self.router.upgrade() else { return; }; - let _ = router.in_flight.remove(&self.msg_id); + let _ = router.remove(&self.msg_id); } } @@ -292,6 +311,8 @@ where #[cfg(test)] mod test { + use crate::network::WeakConnection; + use super::*; use futures::future::join_all; use restate_types::net::{CodecError, TargetName}; @@ -326,12 +347,12 @@ mod test { async fn test_rpc_flight_tracker_drop() { let tracker = ResponseTracker::::default(); assert_eq!(tracker.num_in_flight(), 0); - let token = tracker.new_token(1).unwrap(); + let token = tracker.register_raw(1).unwrap(); assert_eq!(tracker.num_in_flight(), 1); drop(token); assert_eq!(tracker.num_in_flight(), 0); - let token = tracker.new_token(1).unwrap(); + let token = tracker.register_raw(1).unwrap(); assert_eq!(tracker.num_in_flight(), 1); // receive with timeout, this should drop the token let start = tokio::time::Instant::now(); @@ -346,17 +367,16 @@ mod test { async fn test_rpc_flight_tracker_send_recv() { let tracker = ResponseTracker::::default(); assert_eq!(tracker.num_in_flight(), 0); - let token = tracker.new_token(1).unwrap(); + let token = tracker.register_raw(1).unwrap(); assert_eq!(tracker.num_in_flight(), 1); // dropped on the floor tracker .on_message(Incoming::from_parts( - GenerationalNodeId::new(1, 1), TestResponse { text: "test".to_string(), }, - Weak::new(), + WeakConnection::new_closed(GenerationalNodeId::new(1, 1)), 1, Some(42), )) @@ -365,11 +385,10 @@ mod test { assert_eq!(tracker.num_in_flight(), 1); let maybe_msg = tracker.handle_message(Incoming::from_parts( - GenerationalNodeId::new(1, 1), TestResponse { text: "test".to_string(), }, - Weak::new(), + WeakConnection::new_closed(GenerationalNodeId::new(1, 1)), 1, Some(42), )); @@ -380,11 +399,10 @@ mod test { // matches msg id tracker .on_message(Incoming::from_parts( - GenerationalNodeId::new(1, 1), TestResponse { text: "a very real message".to_string(), }, - Weak::new(), + WeakConnection::new_closed(GenerationalNodeId::new(1, 1)), 1, Some(1), )) @@ -395,8 +413,8 @@ mod test { let msg = token.recv().await.unwrap(); assert_eq!(Some(1), msg.in_response_to()); - let (from, msg) = msg.split(); - assert_eq!(GenerationalNodeId::new(1, 1), from); + let (reciprocal, msg) = msg.split(); + assert_eq!(GenerationalNodeId::new(1, 1), *reciprocal.peer()); assert_eq!("a very real message", msg.text); } @@ -406,7 +424,11 @@ mod test { let response_tracker = ResponseTracker::default(); let rpc_tokens: Vec> = (0..num_responses) - .map(|idx| response_tracker.new_token(idx).expect("first time created")) + .map(|idx| { + response_tracker + .register_raw(idx) + .expect("first time created") + }) .collect(); let barrier = Arc::new(Barrier::new((2 * num_responses) as usize)); @@ -418,11 +440,10 @@ mod test { tokio::spawn(async move { barrier_handle_message.wait().await; response_tracker_handle_message.handle_message(Incoming::from_parts( - GenerationalNodeId::new(0, 0), TestResponse { text: format!("{}", idx), }, - Weak::new(), + WeakConnection::new_closed(GenerationalNodeId::new(0, 0)), 1, Some(idx), )); @@ -433,7 +454,7 @@ mod test { tokio::spawn(async move { barrier_new_token.wait().await; - response_tracker_new_token.new_token(idx); + response_tracker_new_token.register_raw(idx); }); } @@ -447,7 +468,7 @@ mod test { for result in results { assert_eq!( - Some(result.text.parse::().expect("valid u64")), + Some(result.body().text.parse::().expect("valid u64")), result.in_response_to() ); } diff --git a/crates/core/src/network/transport_connector.rs b/crates/core/src/network/transport_connector.rs new file mode 100644 index 000000000..9e34a8e32 --- /dev/null +++ b/crates/core/src/network/transport_connector.rs @@ -0,0 +1,246 @@ +// Copyright (c) 2024 - Restate Software, Inc., Restate GmbH. +// All rights reserved. +// +// Use of this software is governed by the Business Source License +// included in the LICENSE file. +// +// As of the Change Date specified in that file, in accordance with +// the Business Source License, use of this software will be governed +// by the Apache License, Version 2.0. + +use std::future::Future; + +use dashmap::DashMap; +use futures::{Stream, StreamExt}; +use tonic::transport::Channel; +use tracing::trace; + +use restate_types::config::NetworkingOptions; +use restate_types::net::AdvertisedAddress; +use restate_types::nodes_config::NodesConfiguration; +use restate_types::protobuf::node::Message; +use restate_types::GenerationalNodeId; + +use super::protobuf::node_svc::node_svc_client::NodeSvcClient; +use super::{NetworkError, ProtocolError}; +use crate::network::net_util::create_tonic_channel_from_advertised_address; + +pub trait TransportConnect: Send + Sync + 'static { + fn connect( + &self, + node_id: GenerationalNodeId, + nodes_config: &NodesConfiguration, + output_stream: impl Stream + Send + Unpin + 'static, + ) -> impl Future< + Output = Result< + impl Stream> + Send + Unpin + 'static, + NetworkError, + >, + > + Send; +} + +pub struct GrpcConnector { + _networking_options: NetworkingOptions, + channel_cache: DashMap, +} + +impl GrpcConnector { + pub fn new(_networking_options: NetworkingOptions) -> Self { + Self { + _networking_options, + channel_cache: DashMap::new(), + } + } +} + +impl TransportConnect for GrpcConnector { + async fn connect( + &self, + node_id: GenerationalNodeId, + nodes_config: &NodesConfiguration, + output_stream: impl Stream + Send + Unpin + 'static, + ) -> Result< + impl Stream> + Send + Unpin + 'static, + NetworkError, + > { + let address = nodes_config.find_node_by_id(node_id)?.address.clone(); + + trace!("Attempting to connect to node {} at {}", node_id, address); + // Do we have a channel in cache for this address? + let channel = { + if let dashmap::Entry::Vacant(entry) = self.channel_cache.entry(address.clone()) { + let channel = create_tonic_channel_from_advertised_address(address) + .map_err(|e| NetworkError::BadNodeAddress(node_id.into(), e))?; + entry.insert(channel.clone()); + channel + } else { + self.channel_cache.get(&address).unwrap().clone() + } + }; + + // Establish the connection + let mut client = NodeSvcClient::new(channel); + let incoming = client.create_connection(output_stream).await?.into_inner(); + Ok(incoming.map(|x| x.map_err(ProtocolError::from))) + } +} + +#[cfg(any(test, feature = "test-util"))] +pub mod test_util { + + use super::*; + + use std::sync::Arc; + use std::time::Instant; + + use futures::{Stream, StreamExt}; + use parking_lot::Mutex; + use tokio::sync::mpsc; + use tokio_stream::wrappers::ReceiverStream; + use tracing::info; + + use restate_types::nodes_config::NodesConfiguration; + use restate_types::protobuf::node::message::BinaryMessage; + use restate_types::protobuf::node::Message; + use restate_types::GenerationalNodeId; + + use super::{NetworkError, ProtocolError}; + use crate::network::{Incoming, MockPeerConnection, PartialPeerConnection, WeakConnection}; + use crate::{TaskCenter, TaskHandle, TaskKind}; + + #[derive(Clone)] + pub struct MockConnector { + pub sendbuf: usize, + pub new_connection_sender: mpsc::UnboundedSender, + } + + #[cfg(any(test, feature = "test-util"))] + impl MockConnector { + pub fn new(sendbuf: usize) -> (Self, mpsc::UnboundedReceiver) { + let (new_connection_sender, rx) = mpsc::unbounded_channel(); + ( + Self { + sendbuf, + new_connection_sender, + }, + rx, + ) + } + } + + impl TransportConnect for MockConnector { + async fn connect( + &self, + node_id: GenerationalNodeId, + nodes_config: &NodesConfiguration, + output_stream: impl Stream + Send + Unpin + 'static, + ) -> Result< + impl Stream> + Send + Unpin + 'static, + NetworkError, + > { + // validates that the node is known in the config + let current_generation = nodes_config.find_node_by_id(node_id)?.current_generation; + info!( + "Attempting to fake a connection to node {} and current_generation is {}", + node_id, current_generation + ); + + let (sender, rx) = mpsc::channel(self.sendbuf); + + let peer_connection = PartialPeerConnection { + my_node_id: node_id, + peer: crate::metadata().my_node_id(), + sender, + recv_stream: output_stream.boxed(), + created: Instant::now(), + }; + + let peer_connection = peer_connection.handshake(nodes_config).await.unwrap(); + + if self.new_connection_sender.send(peer_connection).is_err() { + // reciever has closed. Cannot accept connections + return Err(NetworkError::Unavailable(format!( + "MockConnector has been terminated, cannot connect to {}", + node_id + ))); + } + let incoming = ReceiverStream::new(rx).map(Ok); + Ok(incoming) + } + } + + /// Accepts all connections, performs handshake and sends all received messages to a single + /// stream + pub struct MessageCollectorMockConnector { + pub mock_connector: MockConnector, + pub tasks: Mutex>)>>, + } + + impl MessageCollectorMockConnector { + pub fn new( + task_center: TaskCenter, + sendbuf: usize, + sender: mpsc::Sender<(GenerationalNodeId, Incoming)>, + ) -> Arc { + let (mock_connector, mut new_connections) = MockConnector::new(sendbuf); + let connector = Arc::new(Self { + mock_connector, + tasks: Default::default(), + }); + + // start acceptor + let _ = task_center + .clone() + .spawn(TaskKind::TestRunner, "test-connection-acceptor", None, { + let connector = connector.clone(); + async move { + while let Some(connection) = new_connections.recv().await { + let (connection, task) = + connection.forward_to_sender(&task_center, sender.clone())?; + connector.tasks.lock().push((connection, task)); + } + Ok(()) + } + }) + .unwrap(); + connector + } + } + + impl TransportConnect for MessageCollectorMockConnector { + async fn connect( + &self, + node_id: GenerationalNodeId, + nodes_config: &NodesConfiguration, + output_stream: impl Stream + Send + Unpin + 'static, + ) -> Result< + impl Stream> + Send + Unpin + 'static, + NetworkError, + > { + self.mock_connector + .connect(node_id, nodes_config, output_stream) + .await + } + } + + /// Transport that fails all outgoing connections + #[derive(Default)] + pub struct FailingConnector {} + + #[cfg(any(test, feature = "test-util"))] + impl TransportConnect for FailingConnector { + async fn connect( + &self, + _node_id: GenerationalNodeId, + _nodes_config: &NodesConfiguration, + _output_stream: impl Stream + Send + Unpin + 'static, + ) -> Result< + impl Stream> + Send + Unpin + 'static, + NetworkError, + > { + Result::, _>::Err(NetworkError::ConnectError( + tonic::Status::unavailable("Trying to connect using failing transport"), + )) + } + } +} diff --git a/crates/core/src/network/types.rs b/crates/core/src/network/types.rs index dfe2cdc89..67ce0881e 100644 --- a/crates/core/src/network/types.rs +++ b/crates/core/src/network/types.rs @@ -8,9 +8,9 @@ // the Business Source License, use of this software will be governed // by the Apache License, Version 2.0. -use std::ops::{Deref, DerefMut}; use std::sync::atomic::AtomicU64; -use std::sync::{Arc, Weak}; +use std::sync::Arc; +use std::time::Instant; use restate_types::net::codec::{Targeted, WireEncode}; use restate_types::net::RpcRequest; @@ -18,8 +18,9 @@ use restate_types::{GenerationalNodeId, NodeId}; use crate::with_metadata; -use super::connection::{Connection, HeaderMetadataVersions}; -use super::{NetworkError, NetworkSendError}; +use super::connection::OwnedConnection; +use super::metric_definitions::CONNECTION_SEND_DURATION; +use super::{NetworkError, NetworkSendError, WeakConnection}; static NEXT_MSG_ID: AtomicU64 = const { AtomicU64::new(1) }; @@ -29,292 +30,345 @@ pub(crate) fn generate_msg_id() -> u64 { NEXT_MSG_ID.fetch_add(1, std::sync::atomic::Ordering::Relaxed) } -/// A wrapper for incoming messages that includes the sender information +macro_rules! bail_on_error { + ($original:ident, $expression:expr) => { + match $expression { + Ok(a) => a, + Err(e) => return Err(NetworkSendError::new($original, e)), + } + }; +} + +macro_rules! bail_on_none { + ($original:ident, $expression:expr, $err:expr) => { + match $expression { + Some(a) => a, + None => return Err(NetworkSendError::new($original, $err)), + } + }; +} + +// Using type-state pattern to model Outgoing +#[derive(Debug)] +pub struct HasConnection(WeakConnection); +#[derive(Debug)] +pub struct NoConnection(NodeId); + +pub(super) mod private { + use super::*; + + // Make sure that NetworkSender can be implemented on this set of types only. + pub trait Sealed {} + impl Sealed for HasConnection {} + impl Sealed for NoConnection {} +} + #[derive(Debug, Clone)] -pub struct Incoming { - peer: GenerationalNodeId, +struct MsgMeta { msg_id: u64, - connection: Weak, - body: M, in_response_to: Option, } -impl Deref for Incoming { - type Target = M; - fn deref(&self) -> &Self::Target { - &self.body - } -} - -impl DerefMut for Incoming { - fn deref_mut(&mut self) -> &mut Self::Target { - &mut self.body - } +/// A wrapper for incoming messages that includes the sender information +#[derive(Debug, Clone)] +pub struct Incoming { + meta: MsgMeta, + connection: WeakConnection, + body: M, } impl Incoming { pub(crate) fn from_parts( - peer: GenerationalNodeId, body: M, - connection: Weak, + connection: WeakConnection, msg_id: u64, in_response_to: Option, ) -> Self { Self { - peer, - msg_id, connection, body, - in_response_to, + meta: MsgMeta { + msg_id, + in_response_to, + }, } } #[cfg(any(test, feature = "test-util"))] - pub fn for_testing(connection: &Arc, body: M, in_response_to: Option) -> Self { - let peer = connection.peer; - let connection = Arc::downgrade(connection); + pub fn for_testing(connection: WeakConnection, body: M, in_response_to: Option) -> Self { let msg_id = generate_msg_id(); - Self::from_parts(peer, body, connection, msg_id, in_response_to) + Self::from_parts(body, connection, msg_id, in_response_to) } -} -impl Incoming { - pub fn peer(&self) -> GenerationalNodeId { - self.peer + pub fn peer(&self) -> &GenerationalNodeId { + self.connection.peer() } - pub fn msg_id(&self) -> u64 { - self.msg_id + /// Dissolve this incoming into [`Reciprocal`] which can be used to prepare responses, and the + /// body of this incoming message. + pub fn split(self) -> (Reciprocal, M) { + let reciprocal = Reciprocal::new(self.connection, self.meta.msg_id); + (reciprocal, self.body) } - pub fn split(self) -> (GenerationalNodeId, M) { - (self.peer, self.body) + pub fn into_body(self) -> M { + self.body } pub fn body(&self) -> &M { &self.body } - pub fn into_body(self) -> M { - self.body + pub fn msg_id(&self) -> u64 { + self.meta.msg_id + } + + pub fn in_response_to(&self) -> Option { + self.meta.in_response_to + } + + /// Creates a reciprocal for this incoming message without consuming it. This will internall + /// clone the original connection reference. + pub fn create_reciprocal(&self) -> Reciprocal { + Reciprocal::new(self.connection.clone(), self.meta.msg_id) } pub fn try_map(self, f: impl FnOnce(M) -> Result) -> Result, E> { Ok(Incoming { - peer: self.peer, - msg_id: self.msg_id, connection: self.connection, body: f(self.body)?, - in_response_to: self.in_response_to, + meta: self.meta, }) } pub fn map(self, f: impl FnOnce(M) -> O) -> Incoming { Incoming { - peer: self.peer, - msg_id: self.msg_id, connection: self.connection, body: f(self.body), - in_response_to: self.in_response_to, - } - } - - pub fn in_response_to(&self) -> Option { - self.in_response_to - } - - pub fn prepare_response(&self, body: O) -> Outgoing { - Outgoing { - peer: self.peer.into(), - connection: self.connection.clone(), - msg_id: generate_msg_id(), - body, - in_response_to: Some(self.msg_id), + meta: self.meta, } } - /// Sends a response on the same connection where we received the request. This will - /// fail with [`NetworkError::ConnectionClosed`] if the connection is terminated. + /// Create an [`Outgoing`] to respond to this request. /// - /// This fails immediately with [`NetworkError::Full`] if connection stream is out of capacity. - pub fn try_respond( - &self, - response: O, - ) -> Result<(), NetworkSendError> { - self.prepare_response(response).try_send() - } - - /// Sends a response on the same connection where we received the request. This will - /// fail with [`NetworkError::ConnectionClosed`] if the connection is terminated. - /// - /// This blocks until there is capacity on the connection stream. - pub async fn respond( - &self, - response: O, - ) -> Result<(), NetworkSendError> { - self.prepare_response(response).send().await + /// Sending this outgoing will reuse the same connection where this message arrived + pub fn into_outgoing(self, body: O) -> Outgoing { + let reciprocal = Reciprocal::new(self.connection, self.meta.msg_id); + reciprocal.prepare(body) } } +/// Only available if this in RpcRequest for convenience. impl Incoming { - pub fn prepare_rpc_response( - &self, + pub fn to_rpc_response( + self, response: M::ResponseMessage, - ) -> Outgoing { - self.prepare_response(response) + ) -> Outgoing { + self.into_outgoing(response) } +} - /// Sends a response on the same connection where we received the request. This will - /// fail with [`NetworkError::ConnectionClosed`] if the connection is terminated. - /// - /// This fails immediately with [`NetworkError::Full`] if connection stream is out of capacity. - pub fn try_respond_rpc( - &self, - response: M::ResponseMessage, - ) -> Result<(), NetworkSendError> { - self.prepare_response(response).try_send() +/// A type that represents a potential response (reciprocal to a request) that can be converted +/// into `Outgoing` once a message is ready. An [`Outgoing`] can be created with `prepare(body)` +#[derive(Debug)] +pub struct Reciprocal { + connection: WeakConnection, + in_response_to: u64, +} + +impl Reciprocal { + pub(crate) fn new(connection: WeakConnection, in_response_to: u64) -> Self { + Self { + connection, + in_response_to, + } } - /// Sends a response on the same connection where we received the request. This will - /// fail with [`NetworkError::ConnectionClosed`] if the connection is terminated. - /// - /// This blocks until there is capacity on the connection stream. - pub async fn respond_rpc( - &self, - response: M::ResponseMessage, - ) -> Result<(), NetworkSendError> { - self.prepare_response(response).send().await + pub fn peer(&self) -> &GenerationalNodeId { + self.connection.peer() + } + + /// Package this reciprocal as a ready-to-use Outgoing message that holds the connection + /// reference and the original message_id to response to. + pub fn prepare(self, body: O) -> Outgoing { + Outgoing { + connection: HasConnection(self.connection), + body, + meta: MsgMeta { + msg_id: generate_msg_id(), + in_response_to: Some(self.in_response_to), + }, + } } } /// A wrapper for outgoing messages that includes the correlation information if a message is in /// response to a request. #[derive(Debug, Clone)] -pub struct Outgoing { - peer: NodeId, - msg_id: u64, - connection: Weak, +pub struct Outgoing { + connection: State, body: M, - in_response_to: Option, -} - -impl Deref for Outgoing { - type Target = M; - fn deref(&self) -> &Self::Target { - &self.body - } + meta: MsgMeta, } -impl Outgoing { +impl Outgoing { pub fn new(peer: impl Into, body: M) -> Self { - let msg_id = generate_msg_id(); - Self { - peer: peer.into(), - msg_id, - connection: Weak::new(), - body, - in_response_to: None, - } - } - - pub fn from_parts(peer: NodeId, body: M, msg_id: u64, in_response_to: Option) -> Self { - Self { - peer, - msg_id, - connection: Weak::new(), + Outgoing { + connection: NoConnection(peer.into()), body, - in_response_to, + meta: MsgMeta { + msg_id: generate_msg_id(), + in_response_to: None, + }, } } } -impl Outgoing { - pub fn peer(&self) -> NodeId { - self.peer - } - pub fn set_peer(&mut self, peer: impl Into) { - self.peer = peer.into(); - // unset connection - self.reset_connection(); - } - - pub fn msg_id(&self) -> u64 { - self.msg_id - } - - pub(crate) fn get_connection(&self) -> Option> { - self.connection.upgrade() - } - - /// Detaches this message from the associated connection (if set). This allows this message to - /// be sent on any connection if NetworkSender is used to send this message. - pub fn reset_connection(&mut self) { - self.connection = Weak::new(); +impl Outgoing { + pub fn into_body(self) -> M { + self.body } pub fn body(&self) -> &M { &self.body } - pub fn into_body(self) -> M { - self.body + pub fn body_mut(&mut self) -> &mut M { + &mut self.body + } + + pub fn msg_id(&self) -> u64 { + self.meta.msg_id } - /// If this is a response to a request, what is the message id of the original request? pub fn in_response_to(&self) -> Option { - self.in_response_to + self.meta.in_response_to } - pub fn try_map(self, f: impl FnOnce(M) -> Result) -> Result, E> { + pub fn try_map(self, f: impl FnOnce(M) -> Result) -> Result, E> { Ok(Outgoing { - peer: self.peer, - msg_id: self.msg_id, connection: self.connection, body: f(self.body)?, - in_response_to: self.in_response_to, + meta: self.meta, }) } - pub fn map(self, f: impl FnOnce(M) -> O) -> Outgoing { + pub fn map(self, f: impl FnOnce(M) -> O) -> Outgoing { Outgoing { - peer: self.peer, - msg_id: self.msg_id, connection: self.connection, body: f(self.body), - in_response_to: self.in_response_to, + meta: self.meta, } } } -impl Outgoing { - /// Sends a response on the same connection where we received the request. This will - /// fail with [`NetworkError::ConnectionClosed`] if the connection is terminated. +/// Only available if this outgoing is pinned to a connection +impl Outgoing { + pub fn peer(&self) -> &GenerationalNodeId { + self.connection.0.peer() + } + + /// Unpins this outgoing from the connection. Note that the outgoing will still be pinned the + /// specific GenerationalNodeId originally associated with the connection. If you want to unset + /// the peer to use any generation. Call `to_any_generation()` on the returned Outgoing value. + pub fn forget_connection(self) -> Outgoing { + Outgoing { + connection: NoConnection((*self.peer()).into()), + body: self.body, + meta: self.meta, + } + } +} + +/// Only available if this outgoing is **not** pinned to a connection +impl Outgoing { + pub fn peer(&self) -> &NodeId { + &self.connection.0 + } + + pub fn set_peer(self, peer: NodeId) -> Self { + Self { + connection: NoConnection(peer), + ..self + } + } + + /// Ensures that this outgoing is not pinned to a specific node generation. + pub fn to_any_generation(self) -> Self { + Self { + connection: NoConnection(self.peer().id().into()), + ..self + } + } + + /// Panics (debug assertion) if connection doesn't match the plain node Id of the original message + pub fn assign_connection(self, connection: WeakConnection) -> Outgoing { + debug_assert_eq!(self.connection.0.id(), connection.peer().as_plain()); + Outgoing { + connection: HasConnection(connection), + body: self.body, + meta: self.meta, + } + } +} + +impl Outgoing { + /// Send a message on this connection. /// /// This blocks until there is capacity on the connection stream. - pub async fn send(self) -> Result<(), NetworkSendError> { - let (connection, versions, outgoing) = self.prepare_send()?; - connection.send(outgoing, versions).await + /// + /// This returns Ok(()) when the message is: + /// - Successfully serialized to the wire format based on the negotiated protocol + /// - Serialized message was enqueued on the send buffer of the socket + /// + /// That means that this is not a guarantee that the message has been sent + /// over the network or that the peer has received it. + /// + /// If this is needed, the caller must design the wire protocol with a + /// request/response state machine and perform retries on other nodes/connections if needed. + /// + /// This roughly maps to the semantics of a POSIX write/send socket operation. + /// + /// This doesn't auto-retry connection resets or send errors, this is up to the user + /// for retrying externally. + // #[instrument(level = "trace", skip_all, fields(peer_node_id = %self.peer, target_service = ?message.target(), msg = ?message.kind()))] + pub async fn send(self) -> Result<(), NetworkSendError> { + let send_start = Instant::now(); + let connection = bail_on_error!(self, self.try_upgrade()); + let permit = bail_on_none!( + self, + connection.reserve().await, + NetworkError::ConnectionClosed + ); + + with_metadata(|metadata| { + permit.send(self, metadata); + }); + CONNECTION_SEND_DURATION.record(send_start.elapsed()); + Ok(()) } /// Sends a response on the same connection where we received the request. This will /// fail with [`NetworkError::ConnectionClosed`] if the connection is terminated. /// /// This fails immediately with [`NetworkError::Full`] if connection stream is out of capacity. - pub fn try_send(self) -> Result<(), NetworkSendError> { - let (connection, versions, outgoing) = self.prepare_send()?; - connection.try_send(outgoing, versions) + pub fn try_send(self) -> Result<(), NetworkSendError> { + let send_start = Instant::now(); + let connection = bail_on_error!(self, self.try_upgrade()); + let permit = bail_on_error!(self, connection.try_reserve()); + + with_metadata(|metadata| { + permit.send(self, metadata); + }); + + CONNECTION_SEND_DURATION.record(send_start.elapsed()); + Ok(()) } - fn prepare_send( - self, - ) -> Result<(Arc, HeaderMetadataVersions, Self), NetworkSendError> { - let connection = match self.connection.upgrade() { - Some(connection) => connection, - None => { - return Err(NetworkSendError::new(self, NetworkError::ConnectionClosed)); - } - }; - let versions = with_metadata(HeaderMetadataVersions::from_metadata).unwrap_or_default(); - Ok((connection, versions, self)) + fn try_upgrade(&self) -> Result, NetworkError> { + match self.connection.0.connection.upgrade() { + Some(connection) => Ok(connection), + None => Err(NetworkError::ConnectionClosed), + } } } diff --git a/crates/core/src/task_center.rs b/crates/core/src/task_center.rs index 24c340a62..6fcb63725 100644 --- a/crates/core/src/task_center.rs +++ b/crates/core/src/task_center.rs @@ -1002,11 +1002,16 @@ pub fn metadata() -> Metadata { } #[track_caller] -pub fn with_metadata(f: F) -> Option +pub fn with_metadata(f: F) -> R where F: FnOnce(&Metadata) -> R, { - CONTEXT.with(|ctx| ctx.metadata.as_ref().map(f)) + CONTEXT.with(|ctx| { + f(ctx + .metadata + .as_ref() + .expect("metadata() is set in this task. Is global metadata set?")) + }) } /// Access to this node id. This is available in task-center tasks only! diff --git a/crates/core/src/test_env.rs b/crates/core/src/test_env.rs index fc9cd4a3e..cd7752b72 100644 --- a/crates/core/src/test_env.rs +++ b/crates/core/src/test_env.rs @@ -10,236 +10,94 @@ use std::marker::PhantomData; use std::str::FromStr; -use std::sync::{Arc, Weak}; +use std::sync::Arc; -use tokio::sync::{mpsc, RwLock}; -use tracing::info; +use futures::Stream; use restate_types::cluster_controller::{ReplicationStrategy, SchedulingPlan}; +use restate_types::config::NetworkingOptions; use restate_types::logs::metadata::{bootstrap_logs_metadata, ProviderKind}; use restate_types::metadata_store::keys::{ BIFROST_CONFIG_KEY, NODES_CONFIG_KEY, PARTITION_TABLE_KEY, SCHEDULING_PLAN_KEY, }; -use restate_types::net::codec::{ - serialize_message, MessageBodyExt, Targeted, WireDecode, WireEncode, -}; +use restate_types::net::codec::{Targeted, WireDecode}; use restate_types::net::metadata::MetadataKind; use restate_types::net::AdvertisedAddress; -use restate_types::net::CURRENT_PROTOCOL_VERSION; use restate_types::nodes_config::{LogServerConfig, NodeConfig, NodesConfiguration, Role}; use restate_types::partition_table::PartitionTable; -use restate_types::protobuf::node::{Header, Message}; +use restate_types::protobuf::node::Message; use restate_types::{GenerationalNodeId, Version}; use crate::metadata_store::{MetadataStoreClient, Precondition}; use crate::network::{ - Connection, Handler, Incoming, MessageHandler, MessageRouter, MessageRouterBuilder, - NetworkError, NetworkSendError, NetworkSender, Outgoing, -}; -use crate::{ - cancellation_watcher, metadata, spawn_metadata_manager, MetadataBuilder, ShutdownError, TaskId, + ConnectionManager, FailingConnector, Incoming, MessageHandler, MessageRouterBuilder, + NetworkError, Networking, ProtocolError, TransportConnect, }; +use crate::{spawn_metadata_manager, MetadataBuilder, TaskId}; use crate::{Metadata, MetadataManager, MetadataWriter}; use crate::{TaskCenter, TaskCenterBuilder}; -#[derive(Clone)] -pub struct MockNetworkSender { - sender: Option>, - metadata: Metadata, -} - -impl MockNetworkSender { - pub fn new(metadata: Metadata) -> Self { - Self { - sender: None, - metadata, - } - } -} - -impl NetworkSender for MockNetworkSender { - async fn send(&self, mut message: Outgoing) -> Result<(), NetworkSendError> - where - M: WireEncode + Targeted + Send + Sync, - { - let Some(sender) = &self.sender else { - info!("Not sending message, mock sender is not configured"); - return Ok(()); - }; - - if !message.peer().is_generational() { - let current_generation = match self - .metadata - .nodes_config_ref() - .find_node_by_id(message.peer()) - { - Ok(node) => node.current_generation, - Err(e) => return Err(NetworkSendError::new(message, NetworkError::UnknownNode(e))), - }; - message.set_peer(current_generation); - } - - let metadata = metadata(); - let header = Header::new( - metadata.nodes_config_version(), - None, - None, - None, - message.msg_id(), - message.in_response_to(), - ); - let body = match serialize_message(message.body(), CURRENT_PROTOCOL_VERSION) { - Ok(body) => body, - Err(e) => { - return Err(NetworkSendError::new( - message, - NetworkError::ProtocolError(e.into()), - )) - } - }; - sender - .send(( - message.peer().as_generational().unwrap(), - Message::new(header, body), - )) - .map_err(|_| NetworkSendError::new(message, NetworkError::Shutdown(ShutdownError)))?; - Ok(()) - } -} - -#[derive(Default)] -struct NetworkReceiver { - router: Arc>, -} - -impl NetworkReceiver { - async fn run( - self, - my_node_id: GenerationalNodeId, - mut receiver: mpsc::UnboundedReceiver<(GenerationalNodeId, Message)>, - ) -> anyhow::Result<()> { - let (reply_sender, mut reply_receiver) = mpsc::channel::(50); - // NOTE: rpc replies will only work if and only if the RpcRouter is using the same router_builder as the service you - // are trying to call. - // In other words, response will not be routed back if the Target component is registered with a different router_builder - // than the client you using to make the call. - let connection = Connection::new_fake(my_node_id, CURRENT_PROTOCOL_VERSION, reply_sender); - - loop { - tokio::select! { - _ = cancellation_watcher() => { - break; - } - maybe_msg = reply_receiver.recv() => { - let Some(msg) = maybe_msg else { - break; - }; - - let guard = self.router.read().await; - self.route_message(my_node_id, msg, &guard, Weak::new()).await?; - } - maybe_msg = receiver.recv() => { - let Some((from, msg)) = maybe_msg else { - break; - }; - { - let guard = self.router.read().await; - self.route_message(from, msg, &guard, Arc::downgrade(&connection)).await?; - } - } - } - } - Ok(()) - } - - async fn route_message( - &self, - peer: GenerationalNodeId, - msg: Message, - router: &MessageRouter, - connection: Weak, - ) -> anyhow::Result<()> { - let body = msg.body.expect("body must be set"); - let header = msg.header.expect("header must be set"); - - let msg = Incoming::from_parts( - peer, - body.try_as_binary_body(CURRENT_PROTOCOL_VERSION)?, - connection, - header.msg_id, - header.in_response_to, - ); - router.call(msg, CURRENT_PROTOCOL_VERSION).await?; - Ok(()) - } -} - -impl MockNetworkSender { - pub fn from_sender( - sender: mpsc::UnboundedSender<(GenerationalNodeId, Message)>, - metadata: Metadata, - ) -> Self { - Self { - sender: Some(sender), - metadata, - } - } - pub fn inner_sender(&self) -> Option> { - self.sender.clone() - } -} - -pub struct TestCoreEnvBuilder { +pub struct TestCoreEnvBuilder { pub tc: TaskCenter, pub my_node_id: GenerationalNodeId, - pub network_rx: Option>, - pub metadata_manager: MetadataManager, + pub metadata_manager: MetadataManager, pub metadata_writer: MetadataWriter, pub metadata: Metadata, + pub networking: Networking, pub nodes_config: NodesConfiguration, pub provider_kind: ProviderKind, pub router_builder: MessageRouterBuilder, - pub network_sender: N, pub partition_table: PartitionTable, pub scheduling_plan: SchedulingPlan, pub metadata_store_client: MetadataStoreClient, } -impl TestCoreEnvBuilder { - pub fn new_with_mock_network() -> TestCoreEnvBuilder { - let (tx, rx) = mpsc::unbounded_channel(); +impl TestCoreEnvBuilder { + pub fn with_incoming_only_connector() -> Self { + let tc = TaskCenterBuilder::default() + .default_runtime_handle(tokio::runtime::Handle::current()) + .ingress_runtime_handle(tokio::runtime::Handle::current()) + .build() + .expect("task_center builds"); let metadata_builder = MetadataBuilder::default(); - let network_sender = MockNetworkSender::from_sender(tx, metadata_builder.to_metadata()); + let net_opts = NetworkingOptions::default(); + let connection_manager = + ConnectionManager::new_incoming_only(metadata_builder.to_metadata()); + let networking = Networking::with_connection_manager( + metadata_builder.to_metadata(), + net_opts, + connection_manager, + ); - TestCoreEnvBuilder::new_with_network_tx_rx(network_sender, Some(rx), metadata_builder) + TestCoreEnvBuilder::with_networking(tc, networking, metadata_builder) } } +impl TestCoreEnvBuilder { + pub fn with_transport_connector(tc: TaskCenter, connector: Arc) -> TestCoreEnvBuilder { + let metadata_builder = MetadataBuilder::default(); + let net_opts = NetworkingOptions::default(); + let connection_manager = + ConnectionManager::new(metadata_builder.to_metadata(), connector, net_opts.clone()); + let networking = Networking::with_connection_manager( + metadata_builder.to_metadata(), + net_opts, + connection_manager, + ); -impl TestCoreEnvBuilder -where - N: NetworkSender + 'static, -{ - pub fn new(network_sender: N, metadata_builder: MetadataBuilder) -> Self { - TestCoreEnvBuilder::new_with_network_tx_rx(network_sender, None, metadata_builder) + TestCoreEnvBuilder::with_networking(tc, networking, metadata_builder) } - fn new_with_network_tx_rx( - network_sender: N, - network_rx: Option>, + pub fn with_networking( + tc: TaskCenter, + networking: Networking, metadata_builder: MetadataBuilder, ) -> Self { - let tc = TaskCenterBuilder::default() - .default_runtime_handle(tokio::runtime::Handle::current()) - .ingress_runtime_handle(tokio::runtime::Handle::current()) - .build() - .expect("task_center builds"); - let my_node_id = GenerationalNodeId::new(1, 1); let metadata_store_client = MetadataStoreClient::new_in_memory(); let metadata = metadata_builder.to_metadata(); let metadata_manager = MetadataManager::new( metadata_builder, - network_sender.clone(), + networking.clone(), metadata_store_client.clone(), ); let metadata_writer = metadata_manager.writer(); @@ -259,11 +117,10 @@ where TestCoreEnvBuilder { tc, my_node_id, - network_rx, metadata_manager, metadata_writer, metadata, - network_sender, + networking, nodes_config, router_builder, partition_table, @@ -273,17 +130,17 @@ where } } - pub fn with_nodes_config(mut self, nodes_config: NodesConfiguration) -> Self { + pub fn set_nodes_config(mut self, nodes_config: NodesConfiguration) -> Self { self.nodes_config = nodes_config; self } - pub fn with_partition_table(mut self, partition_table: PartitionTable) -> Self { + pub fn set_partition_table(mut self, partition_table: PartitionTable) -> Self { self.partition_table = partition_table; self } - pub fn with_scheduling_plan(mut self, scheduling_plan: SchedulingPlan) -> Self { + pub fn set_scheduling_plan(mut self, scheduling_plan: SchedulingPlan) -> Self { self.scheduling_plan = scheduling_plan; self } @@ -312,33 +169,16 @@ where self } - pub async fn build(mut self) -> TestCoreEnv { + pub async fn build(mut self) -> TestCoreEnv { self.metadata_manager .register_in_message_router(&mut self.router_builder); + self.networking + .connection_manager() + .set_message_router(self.router_builder.build()); - let router = Arc::new(RwLock::new(self.router_builder.build())); let metadata_manager_task = spawn_metadata_manager(&self.tc, self.metadata_manager) .expect("metadata manager should start"); - let network_task = match self.network_rx { - Some(network_rx) => { - let network_receiver = NetworkReceiver { - router: router.clone(), - }; - let network_task = self - .tc - .spawn( - crate::TaskKind::ConnectionReactor, - "test-network-receiver", - None, - async move { network_receiver.run(self.my_node_id, network_rx).await }, - ) - .unwrap(); - Some(network_task) - } - None => None, - }; - self.metadata_store_client .put( NODES_CONFIG_KEY.clone(), @@ -392,32 +232,28 @@ where TestCoreEnv { tc: self.tc, - network_task, metadata: self.metadata, metadata_manager_task, metadata_writer: self.metadata_writer, - network_sender: self.network_sender, - router, + networking: self.networking, metadata_store_client: self.metadata_store_client, } } } // This might need to be moved to a better place in the future. -pub struct TestCoreEnv { +pub struct TestCoreEnv { pub tc: TaskCenter, pub metadata: Metadata, pub metadata_writer: MetadataWriter, - pub network_sender: N, - pub network_task: Option, + pub networking: Networking, pub metadata_manager_task: TaskId, - pub router: Arc>, pub metadata_store_client: MetadataStoreClient, } -impl TestCoreEnv { - pub async fn create_with_mock_nodes_config(node_id: u32, generation: u32) -> Self { - TestCoreEnvBuilder::new_with_mock_network() +impl TestCoreEnv { + pub async fn create_with_single_node(node_id: u32, generation: u32) -> Self { + TestCoreEnvBuilder::with_incoming_only_connector() .set_my_node_id(GenerationalNodeId::new(node_id, generation)) .add_mock_nodes_config() .build() @@ -425,13 +261,18 @@ impl TestCoreEnv { } } -impl TestCoreEnv -where - N: NetworkSender, -{ - pub async fn set_message_router(&mut self, router: MessageRouter) { - let mut guard = self.router.write().await; - *guard = router; +impl TestCoreEnv { + pub async fn accept_incoming_connection( + &self, + incoming: S, + ) -> Result + Unpin + Send + 'static, NetworkError> + where + S: Stream> + Unpin + Send + 'static, + { + self.networking + .connection_manager() + .accept_incoming_connection(incoming) + .await } } diff --git a/crates/ingress-dispatcher/src/dispatcher.rs b/crates/ingress-dispatcher/src/dispatcher.rs index fd3fb5a10..d7d55abb7 100644 --- a/crates/ingress-dispatcher/src/dispatcher.rs +++ b/crates/ingress-dispatcher/src/dispatcher.rs @@ -146,8 +146,12 @@ impl MessageHandler for IngressDispatcher { type MessageType = IngressMessage; async fn on_message(&self, msg: Incoming) { - let (peer, msg) = msg.split(); - trace!("Processing message '{}' from '{}'", msg.kind(), peer); + let (reciprocal, msg) = msg.split(); + trace!( + "Processing message '{}' from '{}'", + msg.kind(), + reciprocal.peer() + ); match msg { IngressMessage::InvocationResponse(invocation_response) => { @@ -170,7 +174,7 @@ impl MessageHandler for IngressDispatcher { ); } else { trace!( - partition_processor_peer = %peer, + partition_processor_peer = %reciprocal.peer(), "Sent response of invocation {:?} out", invocation_response.invocation_id ); @@ -199,7 +203,7 @@ impl MessageHandler for IngressDispatcher { } else { trace!( restate.invocation.id = %attach_idempotent_invocation.original_invocation_id, - partition_processor_peer = %peer, + partition_processor_peer = %reciprocal.peer(), "Sent response of invocation out" ); } @@ -271,9 +275,9 @@ mod tests { #[test(tokio::test)] async fn idempotent_invoke() -> anyhow::Result<()> { // set it to 1 partition so that we know where the invocation for the IdempotentInvoker goes to - let mut env_builder = TestCoreEnvBuilder::new_with_mock_network() + let mut env_builder = TestCoreEnvBuilder::with_incoming_only_connector() .add_mock_nodes_config() - .with_partition_table(PartitionTable::with_equally_sized_partitions( + .set_partition_table(PartitionTable::with_equally_sized_partitions( Version::MIN, 1, )); @@ -353,7 +357,7 @@ mod tests { // Now check we get the response is routed back to the handler correctly let response = Bytes::from_static(b"vmoaifnuei"); node_env - .network_sender + .networking .send(Outgoing::new( metadata().my_node_id(), IngressMessage::InvocationResponse(InvocationResponse { @@ -385,9 +389,9 @@ mod tests { #[test(tokio::test)] async fn attach_invocation() { // set it to 1 partition so that we know where the invocation for the IdempotentInvoker goes to - let mut env_builder = TestCoreEnvBuilder::new_with_mock_network() + let mut env_builder = TestCoreEnvBuilder::with_incoming_only_connector() .add_mock_nodes_config() - .with_partition_table(PartitionTable::with_equally_sized_partitions( + .set_partition_table(PartitionTable::with_equally_sized_partitions( Version::MIN, 1, )); @@ -443,7 +447,7 @@ mod tests { // Now send the attach response let response = Bytes::from_static(b"vmoaifnuei"); node_env - .network_sender + .networking .send(Outgoing::new( metadata().my_node_id(), IngressMessage::InvocationResponse(InvocationResponse { diff --git a/crates/ingress-http/src/handler/tests.rs b/crates/ingress-http/src/handler/tests.rs index ead18e550..39808aff3 100644 --- a/crates/ingress-http/src/handler/tests.rs +++ b/crates/ingress-http/src/handler/tests.rs @@ -1083,7 +1083,7 @@ where ::Error: std::error::Error + Send + Sync + 'static, ::Data: Send + Sync + 'static, { - let node_env = TestCoreEnv::create_with_mock_nodes_config(1, 1).await; + let node_env = TestCoreEnv::create_with_single_node(1, 1).await; let (ingress_request_tx, mut ingress_request_rx) = mpsc::unbounded_channel(); let dispatcher = MockDispatcher::new(ingress_request_tx); diff --git a/crates/ingress-http/src/server.rs b/crates/ingress-http/src/server.rs index 19b9d7710..ccd730762 100644 --- a/crates/ingress-http/src/server.rs +++ b/crates/ingress-http/src/server.rs @@ -339,7 +339,7 @@ mod tests { JoinHandle>, TestHandle, ) { - let node_env = TestCoreEnv::create_with_mock_nodes_config(1, 1).await; + let node_env = TestCoreEnv::create_with_single_node(1, 1).await; let (ingress_request_tx, mut ingress_request_rx) = mpsc::unbounded_channel(); // Create the ingress and start it diff --git a/crates/invoker-impl/src/lib.rs b/crates/invoker-impl/src/lib.rs index a59f5b38e..0558cb6e2 100644 --- a/crates/invoker-impl/src/lib.rs +++ b/crates/invoker-impl/src/lib.rs @@ -1145,7 +1145,7 @@ mod tests { #[test(tokio::test)] async fn input_order_is_maintained() { - let node_env = TestCoreEnv::create_with_mock_nodes_config(1, 1).await; + let node_env = TestCoreEnv::create_with_single_node(1, 1).await; let tc = node_env.tc; let invoker_options = InvokerOptionsBuilder::default() // fixed amount of retries so that an invocation eventually completes with a failure diff --git a/crates/log-server/src/loglet_worker.rs b/crates/log-server/src/loglet_worker.rs index b41e2d4cf..0783da300 100644 --- a/crates/log-server/src/loglet_worker.rs +++ b/crates/log-server/src/loglet_worker.rs @@ -190,27 +190,27 @@ impl LogletWorker { } // RELEASE Some(msg) = release_rx.recv() => { - self.global_tail_tracker.maybe_update(msg.known_global_tail); - known_global_tail = known_global_tail.max(msg.known_global_tail); + self.global_tail_tracker.maybe_update(msg.body().known_global_tail); + known_global_tail = known_global_tail.max(msg.body().known_global_tail); } Some(msg) = seal_rx.recv() => { + let (reciprocal, msg) = msg.split(); // this message might be telling us about a higher `known_global_tail` self.global_tail_tracker.maybe_update(msg.known_global_tail); known_global_tail = known_global_tail.max(msg.known_global_tail); // If we have a seal operation in-flight, we'd want this request to wait for // seal to happen - let response = msg.prepare_response(Sealed::empty()); let tail_watcher = self.loglet_state.get_tail_watch(); waiting_for_seal.push(async move { let seal_watcher = tail_watcher.wait_for_seal(); if seal_watcher.await.is_ok() { - let msg = Sealed::new(*tail_watcher.get()).with_status(Status::Ok); - let response = response.map(|_| msg); + let body = Sealed::new(*tail_watcher.get()).with_status(Status::Ok); + let response = reciprocal.prepare(body); // send the response over the network let _ = response.send().await; } }); - let seal_token = self.process_seal(msg.into_body(), &mut sealing_in_progress).await; + let seal_token = self.process_seal(msg, &mut sealing_in_progress).await; if let Some(seal_token) = seal_token { in_flight_seal.set(Some(seal_token).into()); } @@ -218,36 +218,36 @@ impl LogletWorker { } // GET_LOGLET_INFO Some(msg) = get_loglet_info_rx.recv() => { - self.global_tail_tracker.maybe_update(msg.known_global_tail); - known_global_tail = known_global_tail.max(msg.known_global_tail); + self.global_tail_tracker.maybe_update(msg.body().known_global_tail); + known_global_tail = known_global_tail.max(msg.body().known_global_tail); // drop response if connection is lost/congested - if let Err(e) = msg.try_respond_rpc(LogletInfo::new(self.loglet_state.local_tail(), self.loglet_state.trim_point())) { - debug!(?e.source, peer = %msg.peer(), "Failed to respond to GetLogletInfo message due to peer channel capacity being full"); + let peer = *msg.peer(); + if let Err(e) = msg.to_rpc_response(LogletInfo::new(self.loglet_state.local_tail(), self.loglet_state.trim_point())).try_send() { + debug!(?e.source, peer = %peer, "Failed to respond to GetLogletInfo message due to peer channel capacity being full"); } } // GET_RECORDS Some(msg) = get_records_rx.recv() => { - self.global_tail_tracker.maybe_update(msg.known_global_tail); - known_global_tail = known_global_tail.max(msg.known_global_tail); + self.global_tail_tracker.maybe_update(msg.body().known_global_tail); + known_global_tail = known_global_tail.max(msg.body().known_global_tail); // read responses are spawned as disposable tasks self.process_get_records(msg).await; } // TRIM Some(msg) = trim_rx.recv() => { - self.global_tail_tracker.maybe_update(msg.known_global_tail); - known_global_tail = known_global_tail.max(msg.known_global_tail); - self.process_trim(msg, known_global_tail).await; + self.global_tail_tracker.maybe_update(msg.body().known_global_tail); + known_global_tail = known_global_tail.max(msg.body().known_global_tail); + self.process_trim(msg, known_global_tail); } // STORE Some(msg) = store_rx.recv() => { + let (reciprocal, msg) = msg.split(); // this message might be telling us about a higher `known_global_tail` self.global_tail_tracker.maybe_update(msg.known_global_tail); known_global_tail = known_global_tail.max(msg.known_global_tail); let next_ok_offset = std::cmp::max(staging_local_tail, known_global_tail ); - let response = - msg.prepare_response(Stored::empty()); - let peer = msg.peer(); - let (status, maybe_store_token) = self.process_store(peer, msg.into_body(), &mut staging_local_tail, next_ok_offset, &sealing_in_progress).await; + let peer = *reciprocal.peer(); + let (status, maybe_store_token) = self.process_store(peer, msg, &mut staging_local_tail, next_ok_offset, &sealing_in_progress).await; // if this store is complete, the last committed is updated to this value. let future_last_committed = staging_local_tail; if let Some(store_token) = maybe_store_token { @@ -262,23 +262,22 @@ impl LogletWorker { local_tail_watch.notify_offset_update(future_last_committed); // ignoring the error if we couldn't send the response let msg = Stored::new(*local_tail_watch.get()).with_status(status); - let response = response.map(|_| msg); + let response = reciprocal.prepare(msg); // send the response over the network let _ = response.send().await; } Err(e) => { // log-store in failsafe mode and cannot process stores anymore. warn!(?e, "Log-store is in failsafe mode, dropping store"); - let response = response.map(|msg| msg.with_status(Status::Disabled)); - let _ = response.send().await; + let _ = reciprocal.prepare(Stored::empty()).send().await; } } }); } else { // we didn't store, let's respond immediately with status let msg = Stored::new(self.loglet_state.local_tail()).with_status(status); + let response = reciprocal.prepare(msg); in_flight_network_sends.push(async move { - let response = response.map(|_| msg); // ignore send errors. let _ = response.send().await; }); @@ -386,28 +385,28 @@ impl LogletWorker { let _ = self .task_center .spawn(TaskKind::Disposable, "loglet-read", None, async move { + let (reciprocal, msg) = msg.split(); + let from_offset = msg.from_offset; // validate that from_offset <= to_offset if msg.from_offset > msg.to_offset { - let response = msg.prepare_response(Records::empty(msg.from_offset)); - let response = response.map(|m| m.with_status(Status::Malformed)); + let response = reciprocal + .prepare(Records::empty(from_offset).with_status(Status::Malformed)); // ship the response to the original connection let _ = response.send().await; return Ok(()); } - // initial response - let response = - msg.prepare_response(Records::new(loglet_state.local_tail(), msg.from_offset)); - let response = match log_store.read_records(msg.into_body(), loglet_state).await { - Ok(records) => response.map(|_| records), - Err(_) => response.map(|m| m.with_status(Status::Disabled)), + let records = match log_store.read_records(msg, &loglet_state).await { + Ok(records) => records, + Err(_) => Records::new(loglet_state.local_tail(), from_offset) + .with_status(Status::Disabled), }; // ship the response to the original connection - let _ = response.send().await; + let _ = reciprocal.prepare(records).send().await; Ok(()) }); } - async fn process_trim(&mut self, mut msg: Incoming, known_global_tail: LogletOffset) { + fn process_trim(&mut self, msg: Incoming, known_global_tail: LogletOffset) { // When trimming, we eagerly update the in-memory view of the trim-point _before_ we // perform the trim on the log-store since it's safer to over report the trim-point than // under report. @@ -418,24 +417,24 @@ impl LogletWorker { let _ = self .task_center .spawn(TaskKind::Disposable, "loglet-trim", None, async move { - let loglet_id = msg.loglet_id; - let new_trim_point = msg.trim_point; - let response = msg.prepare_response(Trimmed::empty()); + let loglet_id = msg.body().loglet_id; + let new_trim_point = msg.body().trim_point; // cannot trim beyond the global known tail (if known) or the local_tail whichever is higher. let local_tail = loglet_state.local_tail(); let high_watermark = known_global_tail.max(local_tail.offset()); if new_trim_point < LogletOffset::OLDEST || new_trim_point >= high_watermark { - let _ = msg.respond(Trimmed::new(loglet_state.local_tail()).with_status(Status::Malformed)).await; + let _ = msg.to_rpc_response(Trimmed::new(loglet_state.local_tail()).with_status(Status::Malformed)).send().await; return Ok(()); } + let (reciprocal, mut msg) = msg.split(); // The trim point cannot be at or exceed the local_tail, we clip to the // local_tail-1 if that's the case. msg.trim_point = msg.trim_point.min(local_tail.offset().prev()); let body = if loglet_state.update_trim_point(msg.trim_point) { - match log_store.enqueue_trim(msg.into_body()).await?.await { + match log_store.enqueue_trim(msg).await?.await { Ok(_) => Trimmed::new(loglet_state.local_tail()).with_status(Status::Ok), Err(_) => { warn!( @@ -452,8 +451,7 @@ impl LogletWorker { }; // ship the response to the original connection - let response = response.map(|_| body); - let _ = response.send().await; + let _ = reciprocal.prepare(body).send().await; Ok(()) }); } @@ -490,8 +488,8 @@ mod tests { use googletest::prelude::*; use test_log::test; - use restate_core::network::Connection; - use restate_core::{TaskCenter, TaskCenterBuilder}; + use restate_core::network::OwnedConnection; + use restate_core::{MetadataBuilder, TaskCenter, TaskCenterBuilder}; use restate_rocksdb::RocksDbManager; use restate_types::config::Configuration; use restate_types::live::Live; @@ -514,6 +512,8 @@ mod tests { let log_store = tc .run_in_scope("test-setup", None, async { RocksDbManager::init(common_rocks_opts); + let metadata_builder = MetadataBuilder::default(); + assert!(tc.try_set_global_metadata(metadata_builder.to_metadata())); // create logstore. let builder = RocksDbLogStoreBuilder::create( config.clone().map(|c| &c.log_server).boxed(), @@ -536,7 +536,7 @@ mod tests { let mut loglet_state_map = LogletStateMap::default(); let global_tail_tracker = GlobalTailTrackerMap::default(); let (net_tx, mut net_rx) = mpsc::channel(10); - let connection = Connection::new_fake(SEQUENCER, CURRENT_PROTOCOL_VERSION, net_tx); + let connection = OwnedConnection::new_fake(SEQUENCER, CURRENT_PROTOCOL_VERSION, net_tx); let loglet_state = loglet_state_map.get_or_load(LOGLET, &log_store).await?; let worker = LogletWorker::start( @@ -576,8 +576,8 @@ mod tests { payloads: payloads.clone(), }; - let msg1 = Incoming::for_testing(&connection, msg1, None); - let msg2 = Incoming::for_testing(&connection, msg2, None); + let msg1 = Incoming::for_testing(connection.downgrade(), msg1, None); + let msg2 = Incoming::for_testing(connection.downgrade(), msg2, None); let msg1_id = msg1.msg_id(); let msg2_id = msg2.msg_id(); @@ -621,7 +621,7 @@ mod tests { let mut loglet_state_map = LogletStateMap::default(); let global_tail_tracker = GlobalTailTrackerMap::default(); let (net_tx, mut net_rx) = mpsc::channel(10); - let connection = Connection::new_fake(SEQUENCER, CURRENT_PROTOCOL_VERSION, net_tx); + let connection = OwnedConnection::new_fake(SEQUENCER, CURRENT_PROTOCOL_VERSION, net_tx); let loglet_state = loglet_state_map.get_or_load(LOGLET, &log_store).await?; let worker = LogletWorker::start( @@ -673,10 +673,10 @@ mod tests { payloads: payloads.clone(), }; - let msg1 = Incoming::for_testing(&connection, msg1, None); - let seal1 = Incoming::for_testing(&connection, seal1, None); - let seal2 = Incoming::for_testing(&connection, seal2, None); - let msg2 = Incoming::for_testing(&connection, msg2, None); + let msg1 = Incoming::for_testing(connection.downgrade(), msg1, None); + let seal1 = Incoming::for_testing(connection.downgrade(), seal1, None); + let seal2 = Incoming::for_testing(connection.downgrade(), seal2, None); + let msg2 = Incoming::for_testing(connection.downgrade(), msg2, None); let msg1_id = msg1.msg_id(); let seal1_id = seal1.msg_id(); let seal2_id = seal2.msg_id(); @@ -744,7 +744,7 @@ mod tests { flags: StoreFlags::empty(), payloads: payloads.clone(), }; - let msg3 = Incoming::for_testing(&connection, msg3, None); + let msg3 = Incoming::for_testing(connection.downgrade(), msg3, None); let msg3_id = msg3.msg_id(); worker.enqueue_store(msg3).unwrap(); let response = net_rx.recv().await.unwrap(); @@ -763,7 +763,7 @@ mod tests { loglet_id: LOGLET, known_global_tail: LogletOffset::INVALID, }; - let msg = Incoming::for_testing(&connection, msg, None); + let msg = Incoming::for_testing(connection.downgrade(), msg, None); let msg_id = msg.msg_id(); worker.enqueue_get_loglet_info(msg).unwrap(); @@ -794,7 +794,7 @@ mod tests { let mut loglet_state_map = LogletStateMap::default(); let global_tail_tracker = GlobalTailTrackerMap::default(); let (net_tx, mut net_rx) = mpsc::channel(10); - let connection = Connection::new_fake(SEQUENCER, CURRENT_PROTOCOL_VERSION, net_tx); + let connection = OwnedConnection::new_fake(SEQUENCER, CURRENT_PROTOCOL_VERSION, net_tx); let loglet_state = loglet_state_map.get_or_load(LOGLET, &log_store).await?; let worker = LogletWorker::start( @@ -809,7 +809,7 @@ mod tests { // Note: dots mean we don't have records at those globally committed offsets. worker .enqueue_store(Incoming::for_testing( - &connection, + connection.downgrade(), Store { loglet_id: LOGLET, timeout_at: None, @@ -827,7 +827,7 @@ mod tests { worker .enqueue_store(Incoming::for_testing( - &connection, + connection.downgrade(), Store { loglet_id: LOGLET, timeout_at: None, @@ -845,7 +845,7 @@ mod tests { worker .enqueue_store(Incoming::for_testing( - &connection, + connection.downgrade(), Store { loglet_id: LOGLET, timeout_at: None, @@ -876,7 +876,7 @@ mod tests { // We expect to see [2, 5]. No trim gaps, no filtered gaps. worker .enqueue_get_records(Incoming::for_testing( - &connection, + connection.downgrade(), GetRecords { loglet_id: LOGLET, filter: KeyFilter::Any, @@ -916,7 +916,7 @@ mod tests { // We expect to see [2, FILTERED(5), 10, 11]. No trim gaps. worker .enqueue_get_records(Incoming::for_testing( - &connection, + connection.downgrade(), GetRecords { loglet_id: LOGLET, // no memory limits @@ -965,7 +965,7 @@ mod tests { // We expect to see [FILTERED(5), 10]. (11 is not returend due to budget) worker .enqueue_get_records(Incoming::for_testing( - &connection, + connection.downgrade(), GetRecords { loglet_id: LOGLET, // no memory limits @@ -1025,7 +1025,7 @@ mod tests { let mut loglet_state_map = LogletStateMap::default(); let global_tail_tracker = GlobalTailTrackerMap::default(); let (net_tx, mut net_rx) = mpsc::channel(10); - let connection = Connection::new_fake(SEQUENCER, CURRENT_PROTOCOL_VERSION, net_tx); + let connection = OwnedConnection::new_fake(SEQUENCER, CURRENT_PROTOCOL_VERSION, net_tx); let loglet_state = loglet_state_map.get_or_load(LOGLET, &log_store).await?; let worker = LogletWorker::start( @@ -1041,7 +1041,7 @@ mod tests { // The loglet has no knowledge of global commits, it shouldn't accept trims. worker .enqueue_trim(Incoming::for_testing( - &connection, + connection.downgrade(), Trim { loglet_id: LOGLET, known_global_tail: LogletOffset::OLDEST, @@ -1066,7 +1066,7 @@ mod tests { // won't move trim point beyond its local tail. worker .enqueue_trim(Incoming::for_testing( - &connection, + connection.downgrade(), Trim { loglet_id: LOGLET, known_global_tail: LogletOffset::new(10), @@ -1090,7 +1090,7 @@ mod tests { // let's store some records at offsets (5, 6) worker .enqueue_store(Incoming::for_testing( - &connection, + connection.downgrade(), Store { loglet_id: LOGLET, timeout_at: None, @@ -1118,7 +1118,7 @@ mod tests { // trim to 5 worker .enqueue_trim(Incoming::for_testing( - &connection, + connection.downgrade(), Trim { loglet_id: LOGLET, known_global_tail: LogletOffset::new(10), @@ -1142,7 +1142,7 @@ mod tests { // Attempt to read. We expect to see a trim gap (1->5, 6 (data-record)) worker .enqueue_get_records(Incoming::for_testing( - &connection, + connection.downgrade(), GetRecords { loglet_id: LOGLET, total_limit_in_bytes: None, @@ -1189,7 +1189,7 @@ mod tests { // trim everything worker .enqueue_trim(Incoming::for_testing( - &connection, + connection.downgrade(), Trim { loglet_id: LOGLET, known_global_tail: LogletOffset::new(10), @@ -1213,7 +1213,7 @@ mod tests { // Attempt to read again. We expect to see a trim gap (1->6) worker .enqueue_get_records(Incoming::for_testing( - &connection, + connection.downgrade(), GetRecords { loglet_id: LOGLET, total_limit_in_bytes: None, diff --git a/crates/log-server/src/logstore.rs b/crates/log-server/src/logstore.rs index 1a3233f73..7958f98d5 100644 --- a/crates/log-server/src/logstore.rs +++ b/crates/log-server/src/logstore.rs @@ -54,7 +54,7 @@ pub trait LogStore: Clone + Send + 'static { fn read_records( &mut self, get_records_message: GetRecords, - loglet_state: LogletState, + loglet_state: &LogletState, ) -> impl Future> + Send; } diff --git a/crates/log-server/src/network.rs b/crates/log-server/src/network.rs index 568cbf931..76a4bde1a 100644 --- a/crates/log-server/src/network.rs +++ b/crates/log-server/src/network.rs @@ -22,7 +22,7 @@ use tokio_stream::StreamExt as TokioStreamExt; use tracing::{debug, trace}; use xxhash_rust::xxh3::Xxh3Builder; -use restate_core::network::{Incoming, MessageRouterBuilder, NetworkSender}; +use restate_core::network::{Incoming, MessageRouterBuilder}; use restate_core::{cancellation_watcher, Metadata, TaskCenter}; use restate_types::config::Configuration; use restate_types::live::Live; @@ -85,12 +85,7 @@ impl RequestPump { } /// Starts the main processing loop, exits on error or shutdown. - pub async fn run( - self, - _networking: N, - log_store: S, - _storage_state: StorageState, - ) -> anyhow::Result<()> + pub async fn run(self, log_store: S, _storage_state: StorageState) -> anyhow::Result<()> where S: LogStore + Clone + Sync + Send + 'static, { @@ -140,7 +135,7 @@ impl RequestPump { // find the worker or create one. // enqueue. let worker = Self::find_or_create_worker( - release.loglet_id, + release.body().loglet_id, &log_store, &task_center, &global_tail_tracker, @@ -153,7 +148,7 @@ impl RequestPump { // find the worker or create one. // enqueue. let worker = Self::find_or_create_worker( - seal.loglet_id, + seal.body().loglet_id, &log_store, &task_center, &global_tail_tracker, @@ -166,7 +161,7 @@ impl RequestPump { // find the worker or create one. // enqueue. let worker = Self::find_or_create_worker( - get_loglet_info.loglet_id, + get_loglet_info.body().loglet_id, &log_store, &task_center, &global_tail_tracker, @@ -179,7 +174,7 @@ impl RequestPump { // find the worker or create one. // enqueue. let worker = Self::find_or_create_worker( - get_records.loglet_id, + get_records.body().loglet_id, &log_store, &task_center, &global_tail_tracker, @@ -192,7 +187,7 @@ impl RequestPump { // find the worker or create one. // enqueue. let worker = Self::find_or_create_worker( - trim.loglet_id, + trim.body().loglet_id, &log_store, &task_center, &global_tail_tracker, @@ -205,7 +200,7 @@ impl RequestPump { // find the worker or create one. // enqueue. let worker = Self::find_or_create_worker( - store.loglet_id, + store.body().loglet_id, &log_store, &task_center, &global_tail_tracker, @@ -232,8 +227,8 @@ impl RequestPump { fn on_store(worker: &LogletWorkerHandle, msg: Incoming) { if let Err(msg) = worker.enqueue_store(msg) { // worker has crashed or shutdown in progress. Notify the sender and drop the message. - if let Err(e) = msg.try_respond_rpc(Stored::empty()) { - debug!(?e.source, peer = %msg.peer(), "Failed to respond to Store message with status Disabled due to peer channel capacity being full"); + if let Err(e) = msg.to_rpc_response(Stored::empty()).try_send() { + debug!(?e.source, peer = %e.original.peer(), "Failed to respond to Store message with status Disabled due to peer channel capacity being full"); } } } @@ -241,8 +236,8 @@ impl RequestPump { fn on_release(worker: &LogletWorkerHandle, msg: Incoming) { if let Err(msg) = worker.enqueue_release(msg) { // worker has crashed or shutdown in progress. Notify the sender and drop the message. - if let Err(e) = msg.try_respond_rpc(Released::empty()) { - debug!(?e.source, peer = %msg.peer(), "Failed to respond to Release message with status Disabled due to peer channel capacity being full"); + if let Err(e) = msg.to_rpc_response(Released::empty()).try_send() { + debug!(?e.source, peer = %e.original.peer(), "Failed to respond to Release message with status Disabled due to peer channel capacity being full"); } } } @@ -250,8 +245,8 @@ impl RequestPump { fn on_seal(worker: &LogletWorkerHandle, msg: Incoming) { if let Err(msg) = worker.enqueue_seal(msg) { // worker has crashed or shutdown in progress. Notify the sender and drop the message. - if let Err(e) = msg.try_respond_rpc(Sealed::empty()) { - debug!(?e.source, peer = %msg.peer(), "Failed to respond to Seal message with status Disabled due to peer channel capacity being full"); + if let Err(e) = msg.to_rpc_response(Sealed::empty()).try_send() { + debug!(?e.source, peer = %e.original.peer(), "Failed to respond to Seal message with status Disabled due to peer channel capacity being full"); } } } @@ -259,18 +254,18 @@ impl RequestPump { fn on_get_loglet_info(worker: &LogletWorkerHandle, msg: Incoming) { if let Err(msg) = worker.enqueue_get_loglet_info(msg) { // worker has crashed or shutdown in progress. Notify the sender and drop the message. - if let Err(e) = msg.try_respond_rpc(LogletInfo::empty()) { - debug!(?e.source, peer = %msg.peer(), "Failed to respond to GetLogletInfo message with status Disabled due to peer channel capacity being full"); + if let Err(e) = msg.to_rpc_response(LogletInfo::empty()).try_send() { + debug!(?e.source, peer = %e.original.peer(), "Failed to respond to GetLogletInfo message with status Disabled due to peer channel capacity being full"); } } } fn on_get_records(worker: &LogletWorkerHandle, msg: Incoming) { if let Err(msg) = worker.enqueue_get_records(msg) { - let next_offset = msg.from_offset; + let next_offset = msg.body().from_offset; // worker has crashed or shutdown in progress. Notify the sender and drop the message. - if let Err(e) = msg.try_respond_rpc(Records::empty(next_offset)) { - debug!(?e.source, peer = %msg.peer(), "Failed to respond to GetRecords message with status Disabled due to peer channel capacity being full"); + if let Err(e) = msg.to_rpc_response(Records::empty(next_offset)).try_send() { + debug!(?e.source, peer = %e.original.peer(), "Failed to respond to GetRecords message with status Disabled due to peer channel capacity being full"); } } } @@ -278,8 +273,8 @@ impl RequestPump { fn on_trim(worker: &LogletWorkerHandle, msg: Incoming) { if let Err(msg) = worker.enqueue_trim(msg) { // worker has crashed or shutdown in progress. Notify the sender and drop the message. - if let Err(e) = msg.try_respond_rpc(Trimmed::empty()) { - debug!(?e.source, peer = %msg.peer(), "Failed to respond to Trim message with status Disabled due to peer channel capacity being full"); + if let Err(e) = msg.to_rpc_response(Trimmed::empty()).try_send() { + debug!(?e.source, peer = %e.original.peer(), "Failed to respond to Trim message with status Disabled due to peer channel capacity being full"); } } } diff --git a/crates/log-server/src/rocksdb_logstore/store.rs b/crates/log-server/src/rocksdb_logstore/store.rs index fcdf6f136..82c8041dc 100644 --- a/crates/log-server/src/rocksdb_logstore/store.rs +++ b/crates/log-server/src/rocksdb_logstore/store.rs @@ -202,7 +202,7 @@ impl LogStore for RocksDbLogStore { async fn read_records( &mut self, msg: GetRecords, - loglet_state: LogletState, + loglet_state: &LogletState, ) -> Result { let data_cf = self.data_cf(); let loglet_id = msg.loglet_id; diff --git a/crates/log-server/src/service.rs b/crates/log-server/src/service.rs index 955c9d2a4..fa8d26b7f 100644 --- a/crates/log-server/src/service.rs +++ b/crates/log-server/src/service.rs @@ -11,7 +11,7 @@ use anyhow::Context; use tracing::{debug, info, instrument}; -use restate_core::network::{MessageRouterBuilder, Networking}; +use restate_core::network::MessageRouterBuilder; use restate_core::{Metadata, MetadataWriter, TaskCenter, TaskKind}; use restate_metadata_store::MetadataStoreClient; use restate_types::config::Configuration; @@ -31,7 +31,6 @@ pub struct LogServerService { updateable_config: Live, task_center: TaskCenter, metadata: Metadata, - networking: Networking, request_processor: RequestPump, metadata_store_client: MetadataStoreClient, } @@ -41,7 +40,6 @@ impl LogServerService { updateable_config: Live, task_center: TaskCenter, metadata: Metadata, - networking: Networking, metadata_store_client: MetadataStoreClient, router_builder: &mut MessageRouterBuilder, ) -> Result { @@ -58,7 +56,6 @@ impl LogServerService { updateable_config, task_center, metadata, - networking, request_processor, metadata_store_client, }) @@ -79,7 +76,6 @@ impl LogServerService { task_center, metadata, request_processor: request_pump, - networking, mut metadata_store_client, } = self; // What do we need to start the log-server? @@ -110,7 +106,7 @@ impl LogServerService { TaskKind::NetworkMessageHandler, "log-server-req-pump", None, - request_pump.run(networking, log_store, storage_state), + request_pump.run(log_store, storage_state), )?; Ok(()) } diff --git a/crates/metadata-store/src/local/tests.rs b/crates/metadata-store/src/local/tests.rs index ea23b3e6a..1660675f6 100644 --- a/crates/metadata-store/src/local/tests.rs +++ b/crates/metadata-store/src/local/tests.rs @@ -13,13 +13,14 @@ use std::time::Duration; use bytestring::ByteString; use futures::stream::FuturesUnordered; use futures::StreamExt; +use restate_core::network::FailingConnector; use serde::{Deserialize, Serialize}; use test_log::test; use tonic_health::pb::health_client::HealthClient; use tonic_health::pb::HealthCheckRequest; use restate_core::network::net_util::create_tonic_channel_from_advertised_address; -use restate_core::{MockNetworkSender, TaskCenter, TaskKind, TestCoreEnv, TestCoreEnvBuilder}; +use restate_core::{TaskCenter, TaskKind, TestCoreEnv, TestCoreEnvBuilder}; use restate_rocksdb::RocksDbManager; use restate_types::config::{ self, reset_base_temp_dir_and_retain, Configuration, MetadataStoreClientOptions, @@ -313,7 +314,7 @@ async fn durable_storage() -> anyhow::Result<()> { /// connected to it. async fn create_test_environment( opts: &MetadataStoreOptions, -) -> anyhow::Result<(MetadataStoreClient, TestCoreEnv)> { +) -> anyhow::Result<(MetadataStoreClient, TestCoreEnv)> { // Setup metadata store on unix domain socket. let mut config = Configuration::default(); let uds_path = tempfile::tempdir()?.into_path().join("grpc-server"); @@ -328,7 +329,9 @@ async fn create_test_environment( restate_types::config::set_current_config(config.clone()); let config = Live::from_value(config); - let env = TestCoreEnvBuilder::new_with_mock_network().build().await; + let env = TestCoreEnvBuilder::with_incoming_only_connector() + .build() + .await; let task_center = &env.tc; diff --git a/crates/node/src/lib.rs b/crates/node/src/lib.rs index 61a1966e7..7633a7d5b 100644 --- a/crates/node/src/lib.rs +++ b/crates/node/src/lib.rs @@ -20,8 +20,8 @@ use tokio::sync::oneshot; use codederror::CodedError; use restate_bifrost::BifrostService; use restate_core::metadata_store::{MetadataStoreClientError, ReadWriteError}; -use restate_core::network::MessageRouterBuilder; use restate_core::network::Networking; +use restate_core::network::{GrpcConnector, MessageRouterBuilder}; use restate_core::{ spawn_metadata_manager, MetadataBuilder, MetadataKind, MetadataManager, TargetVersion, }; @@ -103,12 +103,12 @@ pub enum BuildError { pub struct Node { updateable_config: Live, - metadata_manager: MetadataManager, + metadata_manager: MetadataManager, metadata_store_client: MetadataStoreClient, bifrost: BifrostService, metadata_store_role: Option, - admin_role: Option, - worker_role: Option, + admin_role: Option>, + worker_role: Option>, #[cfg(feature = "replicated-loglet")] log_server: Option, server: NetworkServer, @@ -197,7 +197,6 @@ impl Node { updateable_config.clone(), tc.clone(), metadata.clone(), - networking.clone(), metadata_store_client.clone(), &mut router_builder, ) @@ -242,7 +241,7 @@ impl Node { }; let server = NetworkServer::new( - networking.connection_manager(), + networking.connection_manager().clone(), worker_role .as_ref() .map(|worker| WorkerDependencies::new(worker.storage_query_context().clone())), diff --git a/crates/node/src/network_server/handler/node.rs b/crates/node/src/network_server/handler/node.rs index a679f583c..d24d3bf7b 100644 --- a/crates/node/src/network_server/handler/node.rs +++ b/crates/node/src/network_server/handler/node.rs @@ -16,8 +16,8 @@ use futures::TryStreamExt; use restate_core::network::protobuf::node_svc::node_svc_server::NodeSvc; use restate_core::network::protobuf::node_svc::IdentResponse; use restate_core::network::protobuf::node_svc::{StorageQueryRequest, StorageQueryResponse}; -use restate_core::network::ConnectionManager; use restate_core::network::ProtocolError; +use restate_core::network::{ConnectionManager, GrpcConnector}; use restate_core::{metadata, TaskCenter}; use restate_types::protobuf::common::NodeStatus; use restate_types::protobuf::node::Message; @@ -31,14 +31,14 @@ use crate::network_server::WorkerDependencies; pub struct NodeSvcHandler { task_center: TaskCenter, worker: Option, - connections: ConnectionManager, + connections: ConnectionManager, } impl NodeSvcHandler { pub fn new( task_center: TaskCenter, worker: Option, - connections: ConnectionManager, + connections: ConnectionManager, ) -> Self { Self { task_center, @@ -121,7 +121,10 @@ impl NodeSvc for NodeSvcHandler { ) .await?; - Ok(Response::new(output_stream)) + // For uniformity with outbound connections, we map all responses to Ok, we never rely on + // sending tonic::Status errors explicitly. We use ConnectionControl frames to communicate + // errors and/or drop the stream when necessary. + Ok(Response::new(Box::pin(output_stream.map(Ok)))) } } diff --git a/crates/node/src/network_server/service.rs b/crates/node/src/network_server/service.rs index 8f5e919c5..15e542b05 100644 --- a/crates/node/src/network_server/service.rs +++ b/crates/node/src/network_server/service.rs @@ -22,7 +22,7 @@ use restate_admin::cluster_controller::ClusterControllerHandle; use restate_bifrost::Bifrost; use restate_core::network::net_util::run_hyper_server; use restate_core::network::protobuf::node_svc::node_svc_server::NodeSvcServer; -use restate_core::network::ConnectionManager; +use restate_core::network::{ConnectionManager, GrpcConnector}; use restate_core::task_center; use restate_metadata_store::MetadataStoreClient; use restate_storage_query_datafusion::context::QueryContext; @@ -36,14 +36,14 @@ use crate::network_server::multiplex::MultiplexService; use crate::network_server::state::NodeCtrlHandlerStateBuilder; pub struct NetworkServer { - connection_manager: ConnectionManager, + connection_manager: ConnectionManager, worker_deps: Option, admin_deps: Option, } impl NetworkServer { pub fn new( - connection_manager: ConnectionManager, + connection_manager: ConnectionManager, worker_deps: Option, admin_deps: Option, ) -> Self { diff --git a/crates/node/src/roles/admin.rs b/crates/node/src/roles/admin.rs index a843addc9..a5597eb5e 100644 --- a/crates/node/src/roles/admin.rs +++ b/crates/node/src/roles/admin.rs @@ -20,6 +20,7 @@ use restate_core::metadata_store::MetadataStoreClient; use restate_core::network::protobuf::node_svc::node_svc_client::NodeSvcClient; use restate_core::network::MessageRouterBuilder; use restate_core::network::Networking; +use restate_core::network::TransportConnect; use restate_core::{task_center, Metadata, MetadataWriter, TaskCenter, TaskKind}; use restate_service_client::{AssumeRoleCacheMode, ServiceClient}; use restate_service_protocol::discovery::ServiceDiscovery; @@ -43,18 +44,18 @@ pub enum AdminRoleBuildError { ServiceClient(#[from] restate_service_client::BuildError), } -pub struct AdminRole { +pub struct AdminRole { updateable_config: Live, - controller: cluster_controller::Service, + controller: cluster_controller::Service, admin: AdminService, } -impl AdminRole { +impl AdminRole { pub async fn create( task_center: TaskCenter, updateable_config: Live, metadata: Metadata, - networking: Networking, + networking: Networking, metadata_writer: MetadataWriter, router_builder: &mut MessageRouterBuilder, metadata_store_client: MetadataStoreClient, diff --git a/crates/node/src/roles/worker.rs b/crates/node/src/roles/worker.rs index b6cf4a1c7..870971384 100644 --- a/crates/node/src/roles/worker.rs +++ b/crates/node/src/roles/worker.rs @@ -9,6 +9,7 @@ // by the Apache License, Version 2.0. use codederror::CodedError; +use restate_core::network::TransportConnect; use tokio::sync::oneshot; use restate_bifrost::Bifrost; @@ -59,17 +60,17 @@ pub enum WorkerRoleBuildError { ), } -pub struct WorkerRole { +pub struct WorkerRole { metadata: Metadata, - worker: Worker, + worker: Worker, } -impl WorkerRole { +impl WorkerRole { pub async fn create( metadata: Metadata, updateable_config: Live, router_builder: &mut MessageRouterBuilder, - networking: Networking, + networking: Networking, bifrost: Bifrost, metadata_store_client: MetadataStoreClient, updating_schema_information: Live, diff --git a/crates/types/src/config/networking.rs b/crates/types/src/config/networking.rs index c390530f4..a256aa644 100644 --- a/crates/types/src/config/networking.rs +++ b/crates/types/src/config/networking.rs @@ -8,6 +8,7 @@ // the Business Source License, use of this software will be governed // by the Apache License, Version 2.0. +use std::num::NonZeroUsize; use std::time::Duration; use serde::{Deserialize, Serialize}; @@ -28,6 +29,12 @@ pub struct NetworkingOptions { /// Retry policy to use for internal node-to-node networking. pub connect_retry_policy: RetryPolicy, + /// # Connection Send Buffer + /// + /// The number of messages that can be queued on the outbound stream of a single + /// connection + pub outbound_queue_length: NonZeroUsize, + /// # Handshake timeout /// /// Timeout for handshake message for internal node-to-node networking. @@ -36,8 +43,6 @@ pub struct NetworkingOptions { pub handshake_timeout: humantime::Duration, } -impl NetworkingOptions {} - impl Default for NetworkingOptions { fn default() -> Self { Self { @@ -47,6 +52,8 @@ impl Default for NetworkingOptions { Some(10), Some(Duration::from_millis(500)), ), + + outbound_queue_length: NonZeroUsize::new(1000).expect("Non zero number"), handshake_timeout: Duration::from_secs(3).into(), } } diff --git a/crates/types/src/net/codec.rs b/crates/types/src/net/codec.rs index 20df27c04..53fdf23f7 100644 --- a/crates/types/src/net/codec.rs +++ b/crates/types/src/net/codec.rs @@ -71,7 +71,7 @@ where pub trait WireEncode { fn encode( - &self, + self, buf: &mut B, protocol_version: ProtocolVersion, ) -> Result<(), CodecError>; @@ -83,42 +83,16 @@ pub trait WireDecode { Self: Sized; } -impl WireEncode for &T -where - T: WireEncode, -{ - fn encode( - &self, - buf: &mut B, - protocol_version: ProtocolVersion, - ) -> Result<(), CodecError> { - (*self).encode(buf, protocol_version) - } -} - -impl WireEncode for &mut T -where - T: WireEncode, -{ - fn encode( - &self, - buf: &mut B, - protocol_version: ProtocolVersion, - ) -> Result<(), CodecError> { - (**self).encode(buf, protocol_version) - } -} - impl WireEncode for Box where T: WireEncode, { fn encode( - &self, + self, buf: &mut B, protocol_version: ProtocolVersion, ) -> Result<(), CodecError> { - (**self).encode(buf, protocol_version) + (*self).encode(buf, protocol_version) } } @@ -134,19 +108,6 @@ where } } -impl WireEncode for Arc -where - T: WireEncode, -{ - fn encode( - &self, - buf: &mut B, - protocol_version: ProtocolVersion, - ) -> Result<(), CodecError> { - (**self).encode(buf, protocol_version) - } -} - impl WireDecode for Arc where T: WireDecode, @@ -160,7 +121,7 @@ where } pub fn serialize_message( - msg: &M, + msg: M, protocol_version: ProtocolVersion, ) -> Result { let mut payload = BytesMut::new(); diff --git a/crates/types/src/net/mod.rs b/crates/types/src/net/mod.rs index 2ee604c4d..165bb42f4 100644 --- a/crates/types/src/net/mod.rs +++ b/crates/types/src/net/mod.rs @@ -132,7 +132,7 @@ macro_rules! define_message { impl $crate::net::codec::WireEncode for $message { fn encode( - &self, + self, buf: &mut B, protocol_version: $crate::net::ProtocolVersion, ) -> Result<(), $crate::net::CodecError> { diff --git a/crates/types/src/replicated_loglet/params.rs b/crates/types/src/replicated_loglet/params.rs index 70b792588..d82957a15 100644 --- a/crates/types/src/replicated_loglet/params.rs +++ b/crates/types/src/replicated_loglet/params.rs @@ -21,20 +21,20 @@ use super::ReplicationProperty; #[derive(serde::Serialize, serde::Deserialize, Debug, Clone, PartialEq)] pub struct ReplicatedLogletParams { /// Unique identifier for this loglet - loglet_id: ReplicatedLogletId, + pub loglet_id: ReplicatedLogletId, /// The sequencer node #[serde(with = "serde_with::As::")] - sequencer: GenerationalNodeId, + pub sequencer: GenerationalNodeId, /// Replication properties of this loglet - replication: ReplicationProperty, - nodeset: NodeSet, + pub replication: ReplicationProperty, + pub nodeset: NodeSet, /// The set of nodes the sequencer has been considering for writes after the last /// known_global_tail advance. /// /// If unset, the entire nodeset is considered as part of the write set /// If set, tail repair will attempt reading only from this set. #[serde(skip_serializing_if = "Option::is_none")] - write_set: Option, + pub write_set: Option, } impl ReplicatedLogletParams { diff --git a/crates/worker/src/lib.rs b/crates/worker/src/lib.rs index d8392ec74..864076fad 100644 --- a/crates/worker/src/lib.rs +++ b/crates/worker/src/lib.rs @@ -29,6 +29,7 @@ pub use crate::subscription_integration::SubscriptionControllerHandle; use restate_bifrost::Bifrost; use restate_core::network::MessageRouterBuilder; use restate_core::network::Networking; +use restate_core::network::TransportConnect; use restate_core::{cancellation_watcher, task_center, Metadata, TaskKind}; use restate_ingress_dispatcher::IngressDispatcher; use restate_ingress_http::HyperServerIngress; @@ -89,21 +90,21 @@ pub enum Error { }, } -pub struct Worker { +pub struct Worker { updateable_config: Live, storage_query_context: QueryContext, storage_query_postgres: PostgresQueryService, external_client_ingress: ExternalClientIngress, ingress_kafka: IngressKafkaService, subscription_controller_handle: SubscriptionControllerHandle, - partition_processor_manager: PartitionProcessorManager, + partition_processor_manager: PartitionProcessorManager, } -impl Worker { +impl Worker { pub async fn create( updateable_config: Live, metadata: Metadata, - networking: Networking, + networking: Networking, bifrost: Bifrost, router_builder: &mut MessageRouterBuilder, schema: Live, diff --git a/crates/worker/src/partition/cleaner.rs b/crates/worker/src/partition/cleaner.rs index 02428060d..514a1dbda 100644 --- a/crates/worker/src/partition/cleaner.rs +++ b/crates/worker/src/partition/cleaner.rs @@ -214,8 +214,8 @@ mod tests { // Start paused makes sure the timer is immediately fired #[test(tokio::test(start_paused = true))] pub async fn cleanup_works() { - let env = TestCoreEnvBuilder::new_with_mock_network() - .with_partition_table(PartitionTable::with_equally_sized_partitions( + let env = TestCoreEnvBuilder::with_incoming_only_connector() + .set_partition_table(PartitionTable::with_equally_sized_partitions( Version::MIN, 1, )) diff --git a/crates/worker/src/partition/leadership.rs b/crates/worker/src/partition/leadership.rs index cd1913bf4..c3b8df0b4 100644 --- a/crates/worker/src/partition/leadership.rs +++ b/crates/worker/src/partition/leadership.rs @@ -24,7 +24,7 @@ use tokio_stream::wrappers::ReceiverStream; use tracing::{debug, instrument, trace, warn}; use restate_bifrost::Bifrost; -use restate_core::network::{NetworkSender, Outgoing}; +use restate_core::network::{NetworkSender, Networking, Outgoing, TransportConnect}; use restate_core::{ current_task_partition_id, metadata, task_center, ShutdownError, TaskId, TaskKind, }; @@ -130,7 +130,7 @@ impl PartitionProcessorMetadata { } } -pub(crate) struct LeadershipState { +pub(crate) struct LeadershipState { state: State, last_seen_leader_epoch: Option, @@ -139,14 +139,14 @@ pub(crate) struct LeadershipState { cleanup_interval: Duration, channel_size: usize, invoker_tx: I, - network_tx: N, + network_tx: Networking, bifrost: Bifrost, } -impl LeadershipState +impl LeadershipState where I: restate_invoker_api::InvokerHandle>, - N: NetworkSender + 'static, + T: TransportConnect, { #[allow(clippy::too_many_arguments)] pub(crate) fn new( @@ -156,7 +156,7 @@ where channel_size: usize, invoker_tx: I, bifrost: Bifrost, - network_tx: N, + network_tx: Networking, last_seen_leader_epoch: Option, ) -> Self { Self { @@ -541,7 +541,7 @@ where shuffle_hint_tx: &HintSender, mut timer_service: Pin<&mut TimerService>, actions_effects: &mut VecDeque, - network_tx: &N, + network_tx: &Networking, ) -> Result<(), Error> { match action { Action::Invoke { @@ -635,7 +635,7 @@ where } async fn send_ingress_message( - network_tx: N, + network_tx: Networking, invocation_id: Option, target_node: GenerationalNodeId, ingress_message: ingress::IngressMessage, @@ -771,7 +771,7 @@ mod tests { #[test(tokio::test)] async fn become_leader_then_step_down() -> googletest::Result<()> { - let env = TestCoreEnv::create_with_mock_nodes_config(0, 0).await; + let env = TestCoreEnv::create_with_single_node(0, 0).await; let tc = env.tc.clone(); let storage_options = StorageOptions::default(); let rocksdb_options = RocksDbOptions::default(); @@ -804,7 +804,7 @@ mod tests { 42, invoker_tx, bifrost.clone(), - env.network_sender.clone(), + env.networking.clone(), None, ); diff --git a/crates/worker/src/partition/mod.rs b/crates/worker/src/partition/mod.rs index efc64f632..5cc464ac8 100644 --- a/crates/worker/src/partition/mod.rs +++ b/crates/worker/src/partition/mod.rs @@ -24,7 +24,7 @@ use tracing::{debug, error, info, instrument, trace, warn, Span}; use restate_bifrost::{Bifrost, FindTailAttributes}; use restate_core::cancellation_watcher; use restate_core::metadata; -use restate_core::network::Networking; +use restate_core::network::{Networking, TransportConnect}; use restate_partition_store::{PartitionStore, PartitionStoreTransaction}; use restate_storage_api::deduplication_table::{ DedupInformation, DedupSequenceNumber, DeduplicationTable, ProducerId, @@ -118,12 +118,12 @@ where } } - pub async fn build( + pub async fn build( self, - networking: Networking, + networking: Networking, bifrost: Bifrost, mut partition_store: PartitionStore, - ) -> Result, StorageError> { + ) -> Result, StorageError> { let PartitionProcessorBuilder { partition_id, partition_key_range, @@ -214,10 +214,10 @@ where } } -pub struct PartitionProcessor { +pub struct PartitionProcessor { partition_id: PartitionId, partition_key_range: RangeInclusive, - leadership_state: LeadershipState, + leadership_state: LeadershipState, state_machine: StateMachine, bifrost: Bifrost, control_rx: mpsc::Receiver, @@ -230,10 +230,11 @@ pub struct PartitionProcessor { partition_store: Option, } -impl PartitionProcessor +impl PartitionProcessor where Codec: RawEntryCodec + Default + Debug, InvokerSender: restate_invoker_api::InvokerHandle> + Clone, + T: TransportConnect, { #[instrument(level = "error", skip_all, fields(partition_id = %self.partition_id, is_leader = tracing::field::Empty))] pub async fn run(mut self) -> anyhow::Result<()> { diff --git a/crates/worker/src/partition/shuffle.rs b/crates/worker/src/partition/shuffle.rs index 55bc88a6d..62ce7cdef 100644 --- a/crates/worker/src/partition/shuffle.rs +++ b/crates/worker/src/partition/shuffle.rs @@ -435,17 +435,19 @@ mod state_machine { #[cfg(test)] mod tests { - use anyhow::anyhow; - use assert2::let_assert; - use futures::{Stream, StreamExt}; use std::iter; use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::Arc; + + use anyhow::anyhow; + use assert2::let_assert; + use futures::{Stream, StreamExt}; use test_log::test; use tokio::sync::mpsc; use restate_bifrost::{Bifrost, LogEntry}; - use restate_core::{MockNetworkSender, TaskKind, TestCoreEnv, TestCoreEnvBuilder}; + use restate_core::network::FailingConnector; + use restate_core::{TaskKind, TestCoreEnv, TestCoreEnvBuilder}; use restate_storage_api::outbox_table::OutboxMessage; use restate_storage_api::StorageError; use restate_types::identifiers::{InvocationId, LeaderEpoch, PartitionId}; @@ -612,7 +614,7 @@ mod tests { } struct ShuffleEnv { - env: TestCoreEnv, + env: TestCoreEnv, bifrost: Bifrost, shuffle: Shuffle, } @@ -621,8 +623,8 @@ mod tests { outbox_reader: OR, ) -> ShuffleEnv { // set numbers of partitions to 1 to easily find all sent messages by the shuffle - let env = TestCoreEnvBuilder::new_with_mock_network() - .with_partition_table(PartitionTable::with_equally_sized_partitions( + let env = TestCoreEnvBuilder::with_incoming_only_connector() + .set_partition_table(PartitionTable::with_equally_sized_partitions( Version::MIN, 1, )) diff --git a/crates/worker/src/partition_processor_manager.rs b/crates/worker/src/partition_processor_manager.rs index e9abae86e..e78de6a2c 100644 --- a/crates/worker/src/partition_processor_manager.rs +++ b/crates/worker/src/partition_processor_manager.rs @@ -27,9 +27,8 @@ use tracing::{debug, info, instrument, trace, warn}; use restate_bifrost::Bifrost; use restate_core::network::rpc_router::{RpcError, RpcRouter}; -use restate_core::network::Networking; -use restate_core::network::Outgoing; use restate_core::network::{Incoming, MessageRouterBuilder}; +use restate_core::network::{Networking, TransportConnect}; use restate_core::worker_api::{ProcessorsManagerCommand, ProcessorsManagerHandle}; use restate_core::{cancellation_watcher, Metadata, ShutdownError, TaskId, TaskKind}; use restate_core::{RuntimeError, TaskCenter}; @@ -75,7 +74,7 @@ use crate::partition::invoker_storage_reader::InvokerStorageReader; use crate::partition::PartitionProcessorControlCommand; use crate::PartitionProcessorBuilder; -pub struct PartitionProcessorManager { +pub struct PartitionProcessorManager { task_center: TaskCenter, updateable_config: Live, running_partition_processors: BTreeMap, @@ -84,12 +83,12 @@ pub struct PartitionProcessorManager { metadata: Metadata, metadata_store_client: MetadataStoreClient, partition_store_manager: PartitionStoreManager, - attach_router: RpcRouter, + attach_router: RpcRouter, incoming_get_state: Pin> + Send + Sync + 'static>>, incoming_update_processors: Pin> + Send + Sync + 'static>>, - networking: Networking, + networking: Networking, bifrost: Bifrost, rx: mpsc::Receiver, tx: mpsc::Sender, @@ -282,7 +281,7 @@ impl StatusHandle for MultiplexedInvokerStatusReader { } } -impl PartitionProcessorManager { +impl PartitionProcessorManager { #[allow(clippy::too_many_arguments)] pub fn new( task_center: TaskCenter, @@ -291,10 +290,10 @@ impl PartitionProcessorManager { metadata_store_client: MetadataStoreClient, partition_store_manager: PartitionStoreManager, router_builder: &mut MessageRouterBuilder, - networking: Networking, + networking: Networking, bifrost: Bifrost, ) -> Self { - let attach_router = RpcRouter::new(networking.clone(), router_builder); + let attach_router = RpcRouter::new(router_builder); let incoming_get_state = router_builder.subscribe_to_stream(2); let incoming_update_processors = router_builder.subscribe_to_stream(2); @@ -353,7 +352,7 @@ impl PartitionProcessorManager { match self .attach_router - .call(Outgoing::new(admin_node, AttachRequest::default())) + .call(&self.networking, admin_node, AttachRequest::default()) .await { Ok(response) => return Ok(response), @@ -380,8 +379,8 @@ impl PartitionProcessorManager { let (from, msg) = response.split(); self.apply_plan(&msg.actions).await?; - self.latest_attach_response = Some((from, msg)); - info!("Plan applied from attaching to controller {}", from); + self.latest_attach_response = Some((*from.peer(), msg)); + info!("Plan applied from attaching to controller {}", from.peer()); let (persisted_lsns_tx, persisted_lsns_rx) = watch::channel(BTreeMap::default()); self.persisted_lsns_rx = Some(persisted_lsns_rx); @@ -484,7 +483,8 @@ impl PartitionProcessorManager { None, async move { Ok(get_state_msg - .respond_rpc(ProcessorsStateResponse { state }) + .to_rpc_response(ProcessorsStateResponse { state }) + .send() .await?) }, ); @@ -715,7 +715,7 @@ impl PartitionProcessorManager { )?; pp_builder - .build::(networking, bifrost, partition_store) + .build::(networking, bifrost, partition_store) .await? .run() .await @@ -895,7 +895,7 @@ mod tests { #[test(tokio::test(start_paused = true))] async fn persisted_log_lsn_watchdog_detects_applied_lsns() -> anyhow::Result<()> { - let node_env = TestCoreEnv::create_with_mock_nodes_config(1, 1).await; + let node_env = TestCoreEnv::create_with_single_node(1, 1).await; let storage_options = StorageOptions::default(); let rocksdb_options = RocksDbOptions::default(); diff --git a/tools/bifrost-benchpress/src/main.rs b/tools/bifrost-benchpress/src/main.rs index 4c8b08b5c..e5c71f921 100644 --- a/tools/bifrost-benchpress/src/main.rs +++ b/tools/bifrost-benchpress/src/main.rs @@ -13,14 +13,14 @@ use std::time::Duration; use bifrost_benchpress::util::{print_prometheus_stats, print_rocksdb_stats}; use clap::Parser; use codederror::CodedError; +use restate_core::network::Networking; use tracing::trace; use bifrost_benchpress::{append_latency, write_to_read, Arguments, Command}; use metrics_exporter_prometheus::PrometheusBuilder; use restate_bifrost::{Bifrost, BifrostService}; use restate_core::{ - spawn_metadata_manager, MetadataBuilder, MetadataManager, MockNetworkSender, TaskCenter, - TaskCenterBuilder, + spawn_metadata_manager, MetadataBuilder, MetadataManager, TaskCenter, TaskCenterBuilder, }; use restate_errors::fmt::RestateCode; use restate_metadata_store::{MetadataStoreClient, Precondition}; @@ -151,12 +151,15 @@ fn spawn_environment(config: Live, num_logs: u16) -> (TaskCenter, let task_center = tc.clone(); let bifrost = tc.block_on("spawn", None, async move { let metadata_builder = MetadataBuilder::default(); - let network_sender = MockNetworkSender::new(metadata_builder.to_metadata()); + let networking = Networking::new( + metadata_builder.to_metadata(), + config.pinned().networking.clone(), + ); let metadata_store_client = MetadataStoreClient::new_in_memory(); let metadata = metadata_builder.to_metadata(); let metadata_manager = MetadataManager::new( metadata_builder, - network_sender.clone(), + networking.clone(), metadata_store_client.clone(), ); diff --git a/tools/xtask/src/main.rs b/tools/xtask/src/main.rs index e3cf390d3..332bcb0ab 100644 --- a/tools/xtask/src/main.rs +++ b/tools/xtask/src/main.rs @@ -104,7 +104,7 @@ async fn generate_rest_api_doc() -> anyhow::Result<()> { ); // We start the Meta service, then download the openapi schema generated - let node_env = TestCoreEnv::create_with_mock_nodes_config(1, 1).await; + let node_env = TestCoreEnv::create_with_single_node(1, 1).await; let bifrost = node_env .tc .run_in_scope(