diff --git a/src/main.rs b/src/main.rs
index 25a7ef1..c2fec39 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -119,12 +119,14 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
     // Use the new non-global metrics registry when we upgrade to newer version of malachite
     let _ = Metrics::register(registry);
 
+    let (messages_request_tx, messages_request_rx) = mpsc::channel(100);
     let node = SnapchainNode::create(
         keypair.clone(),
         app_config.consensus.clone(),
         Some(app_config.rpc_address.clone()),
         gossip_tx.clone(),
         None,
+        messages_request_tx,
         block_store.clone(),
         app_config.rocksdb_dir.clone(),
         statsd_client.clone(),
@@ -135,8 +137,8 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
     let (mempool_tx, mempool_rx) = mpsc::channel(app_config.mempool.queue_size as usize);
     let mut mempool = Mempool::new(
         mempool_rx,
+        messages_request_rx,
         app_config.consensus.num_shards,
-        node.shard_senders.clone(),
         node.shard_stores.clone(),
     );
     tokio::spawn(async move { mempool.run().await });
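The wiring above replaces the per-shard `messages_tx` path with a request channel into the mempool: block builders ask the mempool for messages and get a reply back. The underlying pattern is an `mpsc` request that carries a `oneshot` sender for the response. A minimal, self-contained sketch of that pattern (illustrative names, not code from this PR):

```rust
use tokio::sync::{mpsc, oneshot};

struct PullRequest {
    shard_id: u32,
    reply_tx: oneshot::Sender<Vec<String>>, // String stands in for MempoolMessage
}

#[tokio::main]
async fn main() {
    let (request_tx, mut request_rx) = mpsc::channel::<PullRequest>(8);

    // Responder side (the mempool's role): answer each request on its oneshot.
    tokio::spawn(async move {
        while let Some(req) = request_rx.recv().await {
            let PullRequest { shard_id, reply_tx } = req;
            let _ = reply_tx.send(vec![format!("message for shard {}", shard_id)]);
        }
    });

    // Requester side (the engine's role): send a request carrying the reply handle.
    let (reply_tx, reply_rx) = oneshot::channel();
    request_tx
        .send(PullRequest { shard_id: 1, reply_tx })
        .await
        .unwrap();
    let messages = reply_rx.await.unwrap();
    assert_eq!(messages.len(), 1);
}
```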
diff --git a/src/mempool/mempool.rs b/src/mempool/mempool.rs
index 52acf5e..65297b1 100644
--- a/src/mempool/mempool.rs
+++ b/src/mempool/mempool.rs
@@ -1,13 +1,13 @@
-use std::collections::HashMap;
+use std::collections::{BTreeMap, HashMap};
 
 use serde::{Deserialize, Serialize};
-use tokio::sync::mpsc;
+use tokio::{
+    sync::{mpsc, oneshot},
+    time::Instant,
+};
 
 use crate::storage::{
-    store::{
-        engine::{MempoolMessage, Senders},
-        stores::Stores,
-    },
+    store::{engine::MempoolMessage, stores::Stores},
     trie::merkle_trie::{self, TrieKey},
 };
 
@@ -25,27 +25,40 @@ impl Default for Config {
     }
 }
 
+#[derive(Debug, PartialEq, Eq, PartialOrd, Ord)]
+pub struct MempoolKey {
+    inserted_at: Instant,
+}
+
+pub struct MempoolMessagesRequest {
+    pub shard_id: u32,
+    pub message_tx: oneshot::Sender<Vec<MempoolMessage>>,
+    pub max_messages_per_block: u32,
+}
+
 pub struct Mempool {
-    shard_senders: HashMap<u32, Senders>,
     shard_stores: HashMap<u32, Stores>,
     message_router: Box<dyn MessageRouter>,
     num_shards: u32,
     mempool_rx: mpsc::Receiver<MempoolMessage>,
+    messages_request_rx: mpsc::Receiver<MempoolMessagesRequest>,
+    messages: HashMap<u32, BTreeMap<MempoolKey, MempoolMessage>>,
 }
 
 impl Mempool {
     pub fn new(
         mempool_rx: mpsc::Receiver<MempoolMessage>,
+        messages_request_rx: mpsc::Receiver<MempoolMessagesRequest>,
         num_shards: u32,
-        shard_senders: HashMap<u32, Senders>,
         shard_stores: HashMap<u32, Stores>,
     ) -> Self {
        Mempool {
-            shard_senders,
             shard_stores,
             num_shards,
             mempool_rx,
             message_router: Box::new(ShardRouter {}),
+            messages: HashMap::new(),
+            messages_request_rx,
         }
     }
 
@@ -75,32 +88,59 @@ impl Mempool {
         }
     }
 
-    fn is_message_already_merged(&mut self, message: &MempoolMessage) -> bool {
-        let fid = message.fid();
-        match message {
-            MempoolMessage::UserMessage(message) => {
-                self.message_exists_in_trie(fid, TrieKey::for_message(message))
+    async fn pull_messages(&mut self, request: MempoolMessagesRequest) {
+        let mut messages = vec![];
+        while messages.len() < request.max_messages_per_block as usize {
+            let shard_messages = self.messages.get_mut(&request.shard_id);
+            match shard_messages {
+                None => break,
+                Some(shard_messages) => {
+                    match shard_messages.pop_first() {
+                        None => break,
+                        Some((_, next_message)) => {
+                            if self.message_is_valid(&next_message) {
+                                messages.push(next_message);
+                            }
+                        }
+                    };
+                }
             }
+        }
+
+        if let Err(_) = request.message_tx.send(messages) {
+            error!("Unable to send message from mempool");
+        }
+    }
+
+    fn get_trie_key(message: &MempoolMessage) -> Option<Vec<u8>> {
+        match message {
+            MempoolMessage::UserMessage(message) => return Some(TrieKey::for_message(message)),
             MempoolMessage::ValidatorMessage(validator_message) => {
                 if let Some(onchain_event) = &validator_message.on_chain_event {
-                    return self
-                        .message_exists_in_trie(fid, TrieKey::for_onchain_event(&onchain_event));
+                    return Some(TrieKey::for_onchain_event(&onchain_event));
                 }
 
                 if let Some(fname_transfer) = &validator_message.fname_transfer {
                     if let Some(proof) = &fname_transfer.proof {
                         let name = String::from_utf8(proof.name.clone()).unwrap();
-                        return self.message_exists_in_trie(
-                            fid,
-                            TrieKey::for_fname(fname_transfer.id, &name),
-                        );
+                        return Some(TrieKey::for_fname(fname_transfer.id, &name));
                     }
                 }
-                false
+
+                return None;
             }
         }
     }
 
+    fn is_message_already_merged(&mut self, message: &MempoolMessage) -> bool {
+        let fid = message.fid();
+        let trie_key = Self::get_trie_key(&message);
+        match trie_key {
+            Some(trie_key) => self.message_exists_in_trie(fid, trie_key),
+            None => false,
+        }
+    }
+
     pub fn message_is_valid(&mut self, message: &MempoolMessage) -> bool {
         if self.is_message_already_merged(message) {
             return false;
@@ -110,18 +150,32 @@ impl Mempool {
     }
 
     pub async fn run(&mut self) {
-        while let Some(message) = self.mempool_rx.recv().await {
-            if self.message_is_valid(&message) {
-                let fid = message.fid();
-                let shard = self.message_router.route_message(fid, self.num_shards);
-                let senders = self.shard_senders.get(&shard);
-                match senders {
-                    None => {
-                        error!("Unable to find shard to send message to")
+        loop {
+            tokio::select! {
+                biased;
+
+                message_request = self.messages_request_rx.recv() => {
+                    if let Some(messages_request) = message_request {
+                        self.pull_messages(messages_request).await
                    }
-                    Some(senders) => {
-                        if let Err(err) = senders.messages_tx.send(message).await {
-                            error!("Unable to send message to engine: {}", err.to_string())
+                }
+                message = self.mempool_rx.recv() => {
+                    if let Some(message) = message {
+                        // TODO(aditi): Maybe we don't need to run validations here?
+                        if self.message_is_valid(&message) {
+                            let fid = message.fid();
+                            let shard_id = self.message_router.route_message(fid, self.num_shards);
+                            // TODO(aditi): We need a size limit on the mempool and we need to figure out what to do if it's exceeded
+                            match self.messages.get_mut(&shard_id) {
+                                None => {
+                                    let mut messages = BTreeMap::new();
+                                    messages.insert(MempoolKey { inserted_at: Instant::now()}, message.clone());
+                                    self.messages.insert(shard_id, messages);
+                                }
+                                Some(messages) => {
+                                    messages.insert(MempoolKey { inserted_at: Instant::now()}, message.clone());
+                                }
+                            }
                         }
                     }
                 }
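The mempool now buffers messages per shard in a `BTreeMap` keyed by insertion time and drains them with `pop_first()`, re-checking validity at pull time. A small sketch of the ordering this key gives (the `String` payload stands in for `MempoolMessage`; the key shape mirrors the PR):

```rust
use std::collections::BTreeMap;
use tokio::time::Instant;

#[derive(Debug, PartialEq, Eq, PartialOrd, Ord)]
struct MempoolKey {
    inserted_at: Instant,
}

fn main() {
    let mut shard_queue: BTreeMap<MempoolKey, String> = BTreeMap::new();
    for i in 0..3 {
        // Caveat: two messages captured at the exact same Instant would share a key,
        // and the second insert would silently replace the first.
        shard_queue.insert(MempoolKey { inserted_at: Instant::now() }, format!("msg {}", i));
    }
    // pop_first() yields entries in ascending key order, i.e. oldest inserted_at first.
    while let Some((_, msg)) = shard_queue.pop_first() {
        println!("{}", msg);
    }
}
```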
diff --git a/src/mempool/mempool_test.rs b/src/mempool/mempool_test.rs
index 183cd31..f243d29 100644
--- a/src/mempool/mempool_test.rs
+++ b/src/mempool/mempool_test.rs
@@ -17,12 +17,13 @@ mod tests {
 
     fn setup() -> (ShardEngine, Mempool) {
         let (_mempool_tx, mempool_rx) = mpsc::channel(100);
+        let (_mempool_tx, messages_request_rx) = mpsc::channel(100);
         let (engine, _) = test_helper::new_engine();
         let mut shard_senders = HashMap::new();
         shard_senders.insert(1, engine.get_senders());
         let mut shard_stores = HashMap::new();
         shard_stores.insert(1, engine.get_stores());
-        let mempool = Mempool::new(mempool_rx, 1, shard_senders, shard_stores);
+        let mempool = Mempool::new(mempool_rx, messages_request_rx, 1, shard_stores);
         (engine, mempool)
     }
 
diff --git a/src/network/server.rs b/src/network/server.rs
index 45dfa8e..32dbe7c 100644
--- a/src/network/server.rs
+++ b/src/network/server.rs
@@ -104,6 +104,7 @@ impl MyHubService {
             stores.store_limits.clone(),
             self.statsd_client.clone(),
             100,
+            None,
         );
 
         let result = readonly_engine.simulate_message(&message);
diff --git a/src/network/server_tests.rs b/src/network/server_tests.rs
index be4e05e..d8e1a61 100644
--- a/src/network/server_tests.rs
+++ b/src/network/server_tests.rs
@@ -134,29 +134,31 @@ mod tests {
         let (engine1, _) = test_helper::new_engine_with_options(test_helper::EngineOptions {
             limits: Some(limits.clone()),
             db_name: Some("db1.db".to_string()),
+            messages_request_tx: None,
         });
 
         let (engine2, _) = test_helper::new_engine_with_options(test_helper::EngineOptions {
             limits: Some(limits.clone()),
             db_name: Some("db2.db".to_string()),
+            messages_request_tx: None,
         });
         let db1 = engine1.db.clone();
         let db2 = engine2.db.clone();
 
-        let (msgs_tx, _msgs_rx) = mpsc::channel(100);
+        let (_msgs_request_tx, msgs_request_rx) = mpsc::channel(100);
 
         let shard1_stores = Stores::new(
             db1,
             merkle_trie::MerkleTrie::new(16).unwrap(),
             limits.clone(),
         );
-        let shard1_senders = Senders::new(msgs_tx.clone());
+        let shard1_senders = Senders::new();
         let shard2_stores = Stores::new(
             db2,
             merkle_trie::MerkleTrie::new(16).unwrap(),
             limits.clone(),
         );
-        let shard2_senders = Senders::new(msgs_tx.clone());
+        let shard2_senders = Senders::new();
         let stores = HashMap::from([(1, shard1_stores), (2, shard2_stores)]);
         let senders = HashMap::from([(1, shard1_senders), (2, shard2_senders)]);
         let num_shards = senders.len() as u32;
@@ -169,7 +171,7 @@ mod tests {
         assert_eq!(message_router.route_message(SHARD2_FID, 2), 2);
 
         let (mempool_tx, mempool_rx) = mpsc::channel(1000);
-        let mut mempool = Mempool::new(mempool_rx, num_shards, senders.clone(), stores.clone());
+        let mut mempool = Mempool::new(mempool_rx, msgs_request_rx, num_shards, stores.clone());
         tokio::spawn(async move { mempool.run().await });
 
         (
diff --git a/src/node/snapchain_node.rs b/src/node/snapchain_node.rs
index 6e8318d..e929879 100644
--- a/src/node/snapchain_node.rs
+++ b/src/node/snapchain_node.rs
@@ -5,6 +5,7 @@ use crate::core::types::{
     Address, Height, ShardId, SnapchainShard, SnapchainValidator, SnapchainValidatorContext,
     SnapchainValidatorSet,
 };
+use crate::mempool::mempool::MempoolMessagesRequest;
 use crate::network::gossip::GossipEvent;
 use crate::proto::{Block, ShardChunk};
 use crate::storage::db::RocksDB;
@@ -39,6 +40,7 @@ impl SnapchainNode {
         rpc_address: Option<String>,
         gossip_tx: mpsc::Sender<GossipEvent<SnapchainValidatorContext>>,
         block_tx: Option<mpsc::Sender<Block>>,
+        messages_request_tx: mpsc::Sender<MempoolMessagesRequest>,
         block_store: BlockStore,
         rocksdb_dir: String,
         statsd_client: StatsdClientWrapper,
@@ -91,6 +93,7 @@ impl SnapchainNode {
                 StoreLimits::default(),
                 statsd_client.clone(),
                 config.max_messages_per_block,
+                Some(messages_request_tx.clone()),
             );
 
             shard_senders.insert(shard_id, engine.get_senders());
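Note that every shard engine receives `Some(messages_request_tx.clone())`, so all shards fan in to one multi-producer channel served by a single mempool task, which distinguishes requests by the `shard_id` they carry. A rough sketch of that fan-in (assumed shapes, not the PR's types):

```rust
use std::collections::HashMap;
use tokio::sync::{mpsc, oneshot};

struct Request {
    shard_id: u32,
    reply_tx: oneshot::Sender<Vec<String>>,
}

#[tokio::main]
async fn main() {
    let (request_tx, mut request_rx) = mpsc::channel::<Request>(100);

    // Single mempool task: one buffer per shard, dispatch by the request's shard_id.
    tokio::spawn(async move {
        let mut buffers: HashMap<u32, Vec<String>> =
            HashMap::from([(1, vec!["a".to_string()]), (2, vec!["b".to_string()])]);
        while let Some(req) = request_rx.recv().await {
            let batch = buffers.remove(&req.shard_id).unwrap_or_default();
            let _ = req.reply_tx.send(batch);
        }
    });

    // Each shard engine holds its own clone of the same request sender.
    for shard_id in [1u32, 2] {
        let tx = request_tx.clone();
        let (reply_tx, reply_rx) = oneshot::channel();
        tx.send(Request { shard_id, reply_tx }).await.unwrap();
        println!("shard {} pulled {:?}", shard_id, reply_rx.await.unwrap());
    }
}
```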
diff --git a/src/perf/engine_only_perftest.rs b/src/perf/engine_only_perftest.rs
index d9f4635..82142d1 100644
--- a/src/perf/engine_only_perftest.rs
+++ b/src/perf/engine_only_perftest.rs
@@ -1,8 +1,12 @@
+use tokio::sync::mpsc;
+
+use crate::mempool::mempool::Mempool;
 use crate::proto::{Height, ShardChunk, ShardHeader};
 use crate::storage::store::engine::{MempoolMessage, ShardStateChange};
 use crate::storage::store::stores::StoreLimits;
 use crate::storage::store::test_helper;
 use crate::utils::cli::compose_message;
+use std::collections::HashMap;
 use std::error::Error;
 use std::time::Duration;
 
@@ -28,16 +32,27 @@ fn state_change_to_shard_chunk(
 }
 
 pub async fn run() -> Result<(), Box<dyn Error>> {
+    let (mempool_tx, mempool_rx) = mpsc::channel(1000);
+    let (messages_request_tx, messages_request_rx) = mpsc::channel(100);
+
     let (mut engine, _tmpdir) = test_helper::new_engine_with_options(test_helper::EngineOptions {
         limits: Some(StoreLimits {
             limits: test_helper::limits::unlimited(),
             legacy_limits: test_helper::limits::unlimited(),
         }),
         db_name: None,
+        messages_request_tx: Some(messages_request_tx),
+    });
+
+    let mut shard_stores = HashMap::new();
+    shard_stores.insert(1, engine.get_stores());
+    let mut mempool = Mempool::new(mempool_rx, messages_request_rx, 1, shard_stores);
+
+    tokio::spawn(async move {
+        mempool.run().await;
     });
 
     let mut i = 0;
-    let messages_tx = engine.messages_tx();
 
     let fid = test_helper::FID_FOR_TEST;
 
@@ -54,7 +69,7 @@ pub async fn run() -> Result<(), Box<dyn Error>> {
         let text = format!("For benchmarking {}", i);
         let msg = compose_message(fid, text.as_str(), None, None);
 
-        messages_tx
+        mempool_tx
            .send(MempoolMessage::UserMessage(msg.clone()))
             .await
             .unwrap();
diff --git a/src/storage/store/engine.rs b/src/storage/store/engine.rs
index dd87bf5..fe2704e 100644
--- a/src/storage/store/engine.rs
+++ b/src/storage/store/engine.rs
@@ -2,6 +2,7 @@ use super::account::{IntoU8, OnchainEventStorageError, UserDataStore};
 use crate::core::error::HubError;
 use crate::core::types::Height;
 use crate::core::validations;
+use crate::mempool::mempool::MempoolMessagesRequest;
 use crate::proto::FarcasterNetwork;
 use crate::proto::HubEvent;
 use crate::proto::Message;
@@ -20,10 +21,11 @@ use merkle_trie::TrieKey;
 use std::collections::HashSet;
 use std::str;
 use std::sync::Arc;
-use std::time::{Duration, Instant};
+use std::time::Duration;
 use thiserror::Error;
+use tokio::sync::oneshot;
 use tokio::sync::{broadcast, mpsc};
-use tokio::time::sleep;
+use tokio::time::timeout;
 use tracing::{error, info, warn};
 
 #[derive(Error, Debug)]
@@ -47,7 +49,7 @@ pub enum EngineError {
     UsageCountError,
 
     #[error("message receive error")]
-    MessageReceiveError(#[from] mpsc::error::TryRecvError),
+    MessageReceiveError(#[from] oneshot::error::RecvError),
 
     #[error(transparent)]
     MergeOnchainEventError(#[from] OnchainEventStorageError),
@@ -92,7 +94,7 @@ impl EngineError {
     }
 }
 
-#[derive(Clone)]
+#[derive(Clone, Debug)]
 pub enum MempoolMessage {
     UserMessage(proto::Message),
     ValidatorMessage(proto::ValidatorMessage),
@@ -117,17 +119,13 @@ pub struct ShardStateChange {
 
 #[derive(Clone)]
 pub struct Senders {
-    pub messages_tx: mpsc::Sender<MempoolMessage>,
     pub events_tx: broadcast::Sender<HubEvent>,
 }
 
 impl Senders {
-    pub fn new(messages_tx: mpsc::Sender<MempoolMessage>) -> Senders {
+    pub fn new() -> Senders {
         let (events_tx, _events_rx) = broadcast::channel::<HubEvent>(100);
-        Senders {
-            events_tx,
-            messages_tx,
-        }
+        Senders { events_tx }
     }
 }
 
@@ -142,9 +140,9 @@ pub struct ShardEngine {
     pub db: Arc<RocksDB>,
     senders: Senders,
     stores: Stores,
-    messages_rx: mpsc::Receiver<MempoolMessage>,
     statsd_client: StatsdClientWrapper,
     max_messages_per_block: u32,
+    messages_request_tx: Option<mpsc::Sender<MempoolMessagesRequest>>,
 }
 
 impl ShardEngine {
@@ -155,24 +153,20 @@ impl ShardEngine {
         store_limits: StoreLimits,
         statsd_client: StatsdClientWrapper,
         max_messages_per_block: u32,
+        messages_request_tx: Option<mpsc::Sender<MempoolMessagesRequest>>,
     ) -> ShardEngine {
         // TODO: adding the trie here introduces many calls that want to return errors. Rethink unwrap strategy.
-        let (messages_tx, messages_rx) = mpsc::channel::<MempoolMessage>(100);
         ShardEngine {
             shard_id,
             stores: Stores::new(db.clone(), trie, store_limits),
-            senders: Senders::new(messages_tx),
-            messages_rx,
+            senders: Senders::new(),
             db,
             statsd_client,
             max_messages_per_block,
+            messages_request_tx,
         }
     }
 
-    pub fn messages_tx(&self) -> mpsc::Sender<MempoolMessage> {
-        self.senders.messages_tx.clone()
-    }
-
     // statsd
     fn count(&self, key: &str, count: u64) {
         let key = format!("engine.{}", key);
@@ -203,30 +197,37 @@ impl ShardEngine {
         &mut self,
         max_wait: Duration,
     ) -> Result<Vec<MempoolMessage>, EngineError> {
-        let mut messages = Vec::new();
-        let start_time = Instant::now();
-
-        loop {
-            if start_time.elapsed() >= max_wait {
-                break;
+        if let Some(messages_request_tx) = &self.messages_request_tx {
+            let (message_tx, message_rx) = oneshot::channel();
+
+            if let Err(err) = messages_request_tx
+                .send(MempoolMessagesRequest {
+                    shard_id: self.shard_id,
+                    message_tx,
+                    max_messages_per_block: self.max_messages_per_block,
+                })
+                .await
+            {
+                error!(
+                    "Could not send request for messages to mempool {}",
+                    err.to_string()
+                )
             }
 
-            while messages.len() < self.max_messages_per_block as usize {
-                match self.messages_rx.try_recv() {
-                    Ok(msg) => messages.push(msg),
-                    Err(mpsc::error::TryRecvError::Empty) => break,
-                    Err(err) => return Err(EngineError::from(err)),
+            match timeout(max_wait, message_rx).await {
+                Ok(response) => match response {
+                    Ok(new_messages) => Ok(new_messages),
+                    Err(err) => Err(EngineError::from(err)),
+                },
+                Err(_) => {
+                    error!("Did not receive messages from mempool in time");
+                    // Just proceed with no messages
+                    Ok(vec![])
                 }
             }
-
-            if messages.len() >= self.max_messages_per_block as usize {
-                break;
-            }
-
-            sleep(Duration::from_millis(5)).await;
+        } else {
+            Ok(vec![])
         }
-
-        Ok(messages)
     }
 
     fn prepare_proposal(
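On the engine side, `pull_messages` now sends a single request and waits on the oneshot reply under `timeout(max_wait, ...)`: a dropped sender surfaces as an error, while a timeout falls back to an empty batch so block production keeps moving. A minimal sketch of that pattern with placeholder types (not the PR's code):

```rust
use std::time::Duration;
use tokio::sync::oneshot;
use tokio::time::timeout;

async fn pull_or_empty(
    reply_rx: oneshot::Receiver<Vec<String>>,
    max_wait: Duration,
) -> Result<Vec<String>, oneshot::error::RecvError> {
    match timeout(max_wait, reply_rx).await {
        // Reply arrived within max_wait.
        Ok(Ok(messages)) => Ok(messages),
        // The reply sender was dropped before answering.
        Ok(Err(err)) => Err(err),
        // Timed out waiting: proceed with no messages instead of failing.
        Err(_elapsed) => Ok(vec![]),
    }
}

#[tokio::main]
async fn main() {
    // Keep the sender alive but never answer, so the timeout branch is exercised.
    let (_reply_tx, reply_rx) = oneshot::channel::<Vec<String>>();
    let messages = pull_or_empty(reply_rx, Duration::from_millis(50))
        .await
        .unwrap();
    assert!(messages.is_empty());
}
```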
diff --git a/src/storage/store/test_helper.rs b/src/storage/store/test_helper.rs
index da9cc4e..5c0c81b 100644
--- a/src/storage/store/test_helper.rs
+++ b/src/storage/store/test_helper.rs
@@ -1,3 +1,4 @@
+use crate::mempool::mempool::MempoolMessagesRequest;
 use crate::storage::db;
 use crate::storage::store::engine::ShardEngine;
 use crate::storage::store::stores::StoreLimits;
@@ -7,6 +8,7 @@ use ed25519_dalek::{SecretKey, SigningKey};
 use prost::Message;
 use std::sync::Arc;
 use tempfile;
+use tokio::sync::mpsc;
 
 use crate::core::error::HubError;
 use crate::proto;
@@ -92,6 +94,7 @@ pub mod limits {
 pub struct EngineOptions {
     pub limits: Option<StoreLimits>,
     pub db_name: Option<String>,
+    pub messages_request_tx: Option<mpsc::Sender<MempoolMessagesRequest>>,
 }
 
 pub fn new_engine_with_options(options: EngineOptions) -> (ShardEngine, tempfile::TempDir) {
@@ -120,6 +123,7 @@ pub fn new_engine_with_options(options: EngineOptions) -> (ShardEngine, tempfile
             test_limits,
             statsd_client,
             256,
+            options.messages_request_tx,
         ),
         dir,
     )
@@ -130,6 +134,7 @@ pub fn new_engine() -> (ShardEngine, tempfile::TempDir) {
     new_engine_with_options(EngineOptions {
         limits: None,
         db_name: None,
+        messages_request_tx: None,
    })
 }
 
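With `EngineOptions.messages_request_tx` defaulting to `None`, existing tests keep getting an engine whose `pull_messages` simply returns an empty batch; a test that wants real pulls has to spawn a mempool itself. A hypothetical helper sketching that wiring (not part of this PR; the helper name is invented, the calls mirror the diffs above):

```rust
use std::collections::HashMap;
use tokio::sync::mpsc;

use crate::mempool::mempool::Mempool;
use crate::storage::store::engine::{MempoolMessage, ShardEngine};
use crate::storage::store::test_helper::{self, EngineOptions};

/// Hypothetical helper: build an engine whose pull_messages() is served by a live Mempool task.
async fn new_engine_with_mempool(
) -> (ShardEngine, tempfile::TempDir, mpsc::Sender<MempoolMessage>) {
    let (mempool_tx, mempool_rx) = mpsc::channel(100);
    let (messages_request_tx, messages_request_rx) = mpsc::channel(100);

    let (engine, tmpdir) = test_helper::new_engine_with_options(EngineOptions {
        limits: None,
        db_name: None,
        messages_request_tx: Some(messages_request_tx),
    });

    let mut shard_stores = HashMap::new();
    shard_stores.insert(1, engine.get_stores());
    let mut mempool = Mempool::new(mempool_rx, messages_request_rx, 1, shard_stores);
    tokio::spawn(async move { mempool.run().await });

    // Keep the TempDir alive alongside the engine so the database directory isn't removed.
    (engine, tmpdir, mempool_tx)
}
```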
diff --git a/tests/consensus_test.rs b/tests/consensus_test.rs
index 8b951ef..7df4279 100644
--- a/tests/consensus_test.rs
+++ b/tests/consensus_test.rs
@@ -11,6 +11,7 @@ use snapchain::node::snapchain_node::SnapchainNode;
 use snapchain::proto::hub_service_server::HubServiceServer;
 use snapchain::proto::Block;
 use snapchain::storage::db::{PageOptions, RocksDB};
+use snapchain::storage::store::engine::MempoolMessage;
 use snapchain::storage::store::BlockStore;
 use snapchain::utils::factory::messages_factory;
 use snapchain::utils::statsd_wrapper::StatsdClientWrapper;
@@ -33,6 +34,7 @@ struct NodeForTest {
     grpc_addr: String,
     db: Arc<RocksDB>,
     block_store: BlockStore,
+    mempool_tx: mpsc::Sender<MempoolMessage>,
 }
 
 impl Drop for NodeForTest {
@@ -70,12 +72,14 @@ impl NodeForTest {
         let db = Arc::new(RocksDB::new(&make_tmp_path()));
         db.open().unwrap();
         let block_store = BlockStore::new(db.clone());
+        let (messages_request_tx, messages_request_rx) = mpsc::channel(100);
         let node = SnapchainNode::create(
             keypair.clone(),
             config,
             None,
             gossip_tx,
             Some(block_tx),
+            messages_request_tx,
             block_store.clone(),
             make_tmp_path(),
             statsd_client.clone(),
@@ -109,30 +113,27 @@ impl NodeForTest {
 
         let grpc_addr = format!("0.0.0.0:{}", grpc_port);
         let addr = grpc_addr.clone();
-        let grpc_block_store = block_store.clone();
-        let grpc_shard_stores = node.shard_stores.clone();
-        let grpc_shard_senders = node.shard_senders.clone();
 
         let (mempool_tx, mempool_rx) = mpsc::channel(100);
         let mut mempool = Mempool::new(
             mempool_rx,
+            messages_request_rx,
             num_shards,
-            node.shard_senders.clone(),
             node.shard_stores.clone(),
         );
         tokio::spawn(async move { mempool.run().await });
 
-        tokio::spawn(async move {
-            let service = MyHubService::new(
-                grpc_block_store,
-                grpc_shard_stores,
-                grpc_shard_senders,
-                statsd_client.clone(),
-                num_shards,
-                Box::new(routing::EvenOddRouterForTest {}),
-                mempool_tx,
-                None,
-            );
+        let service = MyHubService::new(
+            block_store.clone(),
+            node.shard_stores.clone(),
+            node.shard_senders.clone(),
+            statsd_client.clone(),
+            num_shards,
+            Box::new(routing::EvenOddRouterForTest {}),
+            mempool_tx.clone(),
+            None,
+        );
+        tokio::spawn(async move {
             let grpc_socket_addr: SocketAddr = addr.parse().unwrap();
             let resp = Server::builder()
                 .add_service(HubServiceServer::new(service))
@@ -154,6 +155,7 @@ impl NodeForTest {
             grpc_addr: grpc_addr.clone(),
             db: db.clone(),
             block_store,
+            mempool_tx,
         }
     }
 
@@ -327,13 +329,7 @@ async fn test_basic_consensus() {
     let num_shards = 2;
     let mut network = TestNetwork::create(3, num_shards, 3380).await;
 
-    let messages_tx1 = network.nodes[0]
-        .node
-        .shard_senders
-        .get(&1u32)
-        .expect("message channel should exist")
-        .messages_tx
-        .clone();
+    let messages_tx1 = network.nodes[0].mempool_tx.clone();
 
     tokio::spawn(async move {
         let mut i: i32 = 0;
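One closing note on the mempool's new run loop (in the mempool.rs hunk earlier): `biased;` makes `tokio::select!` poll its branches in declaration order rather than randomly, so a pending block-building pull request is served before more gossip is ingested. A self-contained illustration of that priority (not code from this PR):

```rust
use tokio::sync::mpsc;

#[tokio::main]
async fn main() {
    let (request_tx, mut request_rx) = mpsc::channel::<&str>(8);
    let (ingest_tx, mut ingest_rx) = mpsc::channel::<&str>(8);

    // Both channels have an item ready before we poll.
    request_tx.send("pull request").await.unwrap();
    ingest_tx.send("gossip message").await.unwrap();

    tokio::select! {
        biased;

        // With `biased;`, branches are polled top-to-bottom, so this one wins
        // whenever a pull request is pending.
        Some(req) = request_rx.recv() => println!("handled first: {}", req),
        Some(msg) = ingest_rx.recv() => println!("handled first: {}", msg),
    }
}
```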