From 6be566a5bbeabdb9ec5523051d9e1bb1e33e4ed4 Mon Sep 17 00:00:00 2001 From: Aditi Srinivasan Date: Wed, 8 Jan 2025 10:36:24 -0500 Subject: [PATCH] add duplicate message validation to mempool (#202) Throw away messages that already exist in the trie in the mempool before forwarding to the engine. --- src/main.rs | 1 + src/mempool/mempool.rs | 93 ++++++++++++++++++++++++++++++++----- src/mempool/mempool_test.rs | 46 ++++++++++++++++++ src/mempool/mod.rs | 3 ++ src/network/server_tests.rs | 2 +- tests/consensus_test.rs | 7 ++- 6 files changed, 139 insertions(+), 13 deletions(-) create mode 100644 src/mempool/mempool_test.rs diff --git a/src/main.rs b/src/main.rs index d5cef86c..25a7ef16 100644 --- a/src/main.rs +++ b/src/main.rs @@ -137,6 +137,7 @@ async fn main() -> Result<(), Box> { mempool_rx, app_config.consensus.num_shards, node.shard_senders.clone(), + node.shard_stores.clone(), ); tokio::spawn(async move { mempool.run().await }); diff --git a/src/mempool/mempool.rs b/src/mempool/mempool.rs index 615c2b0f..52acf5ef 100644 --- a/src/mempool/mempool.rs +++ b/src/mempool/mempool.rs @@ -3,7 +3,13 @@ use std::collections::HashMap; use serde::{Deserialize, Serialize}; use tokio::sync::mpsc; -use crate::storage::store::engine::{MempoolMessage, Senders}; +use crate::storage::{ + store::{ + engine::{MempoolMessage, Senders}, + stores::Stores, + }, + trie::merkle_trie::{self, TrieKey}, +}; use super::routing::{MessageRouter, ShardRouter}; use tracing::error; @@ -21,6 +27,7 @@ impl Default for Config { pub struct Mempool { shard_senders: HashMap, + shard_stores: HashMap, message_router: Box, num_shards: u32, mempool_rx: mpsc::Receiver, @@ -31,27 +38,91 @@ impl Mempool { mempool_rx: mpsc::Receiver, num_shards: u32, shard_senders: HashMap, + shard_stores: HashMap, ) -> Self { Mempool { shard_senders, + shard_stores, num_shards, mempool_rx, message_router: Box::new(ShardRouter {}), } } + fn message_exists_in_trie(&mut self, fid: u64, trie_key: Vec) -> bool { + let shard = self.message_router.route_message(fid, self.num_shards); + let stores = self.shard_stores.get_mut(&shard); + match stores { + None => { + error!("Error finding store for shard: {}", shard); + false + } + Some(stores) => { + // TODO(aditi): The engine reloads its ref to the trie on commit but we maintain a separate ref to the trie here. + stores.trie.reload(&stores.db).unwrap(); + match stores.trie.exists( + &merkle_trie::Context::new(), + &stores.db, + trie_key.as_ref(), + ) { + Err(err) => { + error!("Error finding key in trie: {}", err); + false + } + Ok(exists) => exists, + } + } + } + } + + fn is_message_already_merged(&mut self, message: &MempoolMessage) -> bool { + let fid = message.fid(); + match message { + MempoolMessage::UserMessage(message) => { + self.message_exists_in_trie(fid, TrieKey::for_message(message)) + } + MempoolMessage::ValidatorMessage(validator_message) => { + if let Some(onchain_event) = &validator_message.on_chain_event { + return self + .message_exists_in_trie(fid, TrieKey::for_onchain_event(&onchain_event)); + } + + if let Some(fname_transfer) = &validator_message.fname_transfer { + if let Some(proof) = &fname_transfer.proof { + let name = String::from_utf8(proof.name.clone()).unwrap(); + return self.message_exists_in_trie( + fid, + TrieKey::for_fname(fname_transfer.id, &name), + ); + } + } + false + } + } + } + + pub fn message_is_valid(&mut self, message: &MempoolMessage) -> bool { + if self.is_message_already_merged(message) { + return false; + } + + return true; + } + pub async fn run(&mut self) { while let Some(message) = self.mempool_rx.recv().await { - let fid = message.fid(); - let shard = self.message_router.route_message(fid, self.num_shards); - let senders = self.shard_senders.get(&shard); - match senders { - None => { - error!("Unable to find shard to send message to") - } - Some(senders) => { - if let Err(err) = senders.messages_tx.send(message).await { - error!("Unable to send message to engine: {}", err.to_string()) + if self.message_is_valid(&message) { + let fid = message.fid(); + let shard = self.message_router.route_message(fid, self.num_shards); + let senders = self.shard_senders.get(&shard); + match senders { + None => { + error!("Unable to find shard to send message to") + } + Some(senders) => { + if let Err(err) = senders.messages_tx.send(message).await { + error!("Unable to send message to engine: {}", err.to_string()) + } } } } diff --git a/src/mempool/mempool_test.rs b/src/mempool/mempool_test.rs new file mode 100644 index 00000000..183cd314 --- /dev/null +++ b/src/mempool/mempool_test.rs @@ -0,0 +1,46 @@ +#[cfg(test)] +mod tests { + use std::collections::HashMap; + + use tokio::sync::mpsc; + + use crate::{ + mempool::mempool::Mempool, + storage::store::{ + engine::{MempoolMessage, ShardEngine}, + test_helper, + }, + utils::factory::messages_factory, + }; + + use self::test_helper::{default_custody_address, default_signer}; + + fn setup() -> (ShardEngine, Mempool) { + let (_mempool_tx, mempool_rx) = mpsc::channel(100); + let (engine, _) = test_helper::new_engine(); + let mut shard_senders = HashMap::new(); + shard_senders.insert(1, engine.get_senders()); + let mut shard_stores = HashMap::new(); + shard_stores.insert(1, engine.get_stores()); + let mempool = Mempool::new(mempool_rx, 1, shard_senders, shard_stores); + (engine, mempool) + } + + #[tokio::test] + async fn test_duplicate_message_is_invalid() { + let (mut engine, mut mempool) = setup(); + test_helper::register_user( + 1234, + default_signer(), + default_custody_address(), + &mut engine, + ) + .await; + let cast = messages_factory::casts::create_cast_add(1234, "hello", None, None); + let valid = mempool.message_is_valid(&MempoolMessage::UserMessage(cast.clone())); + assert!(valid); + test_helper::commit_message(&mut engine, &cast).await; + let valid = mempool.message_is_valid(&MempoolMessage::UserMessage(cast.clone())); + assert!(!valid) + } +} diff --git a/src/mempool/mod.rs b/src/mempool/mod.rs index f08e8801..e789342e 100644 --- a/src/mempool/mod.rs +++ b/src/mempool/mod.rs @@ -1,2 +1,5 @@ pub mod mempool; pub mod routing; + +#[cfg(test)] +mod mempool_test; diff --git a/src/network/server_tests.rs b/src/network/server_tests.rs index d2b1ee76..be4e05ee 100644 --- a/src/network/server_tests.rs +++ b/src/network/server_tests.rs @@ -169,7 +169,7 @@ mod tests { assert_eq!(message_router.route_message(SHARD2_FID, 2), 2); let (mempool_tx, mempool_rx) = mpsc::channel(1000); - let mut mempool = Mempool::new(mempool_rx, num_shards, senders.clone()); + let mut mempool = Mempool::new(mempool_rx, num_shards, senders.clone(), stores.clone()); tokio::spawn(async move { mempool.run().await }); ( diff --git a/tests/consensus_test.rs b/tests/consensus_test.rs index 75ed293e..8b951efa 100644 --- a/tests/consensus_test.rs +++ b/tests/consensus_test.rs @@ -113,7 +113,12 @@ impl NodeForTest { let grpc_shard_stores = node.shard_stores.clone(); let grpc_shard_senders = node.shard_senders.clone(); let (mempool_tx, mempool_rx) = mpsc::channel(100); - let mut mempool = Mempool::new(mempool_rx, num_shards, node.shard_senders.clone()); + let mut mempool = Mempool::new( + mempool_rx, + num_shards, + node.shard_senders.clone(), + node.shard_stores.clone(), + ); tokio::spawn(async move { mempool.run().await }); tokio::spawn(async move {