Skip to content

Commit

Permalink
pull messages out of mempool rather than pushing messages into engine (
Browse files Browse the repository at this point in the history
…#203)

Pull messages out of the mempool and run validations at that point so
all messages fed to the engine are likely to be valid.
  • Loading branch information
aditiharini authored Jan 9, 2025
1 parent 6be566a commit 23fe464
Show file tree
Hide file tree
Showing 10 changed files with 179 additions and 99 deletions.
4 changes: 3 additions & 1 deletion src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -119,12 +119,14 @@ async fn main() -> Result<(), Box<dyn Error>> {
// Use the new non-global metrics registry when we upgrade to newer version of malachite
let _ = Metrics::register(registry);

let (messages_request_tx, messages_request_rx) = mpsc::channel(100);
let node = SnapchainNode::create(
keypair.clone(),
app_config.consensus.clone(),
Some(app_config.rpc_address.clone()),
gossip_tx.clone(),
None,
messages_request_tx,
block_store.clone(),
app_config.rocksdb_dir.clone(),
statsd_client.clone(),
Expand All @@ -135,8 +137,8 @@ async fn main() -> Result<(), Box<dyn Error>> {
let (mempool_tx, mempool_rx) = mpsc::channel(app_config.mempool.queue_size as usize);
let mut mempool = Mempool::new(
mempool_rx,
messages_request_rx,
app_config.consensus.num_shards,
node.shard_senders.clone(),
node.shard_stores.clone(),
);
tokio::spawn(async move { mempool.run().await });
Expand Down
118 changes: 86 additions & 32 deletions src/mempool/mempool.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
use std::collections::HashMap;
use std::collections::{BTreeMap, HashMap};

use serde::{Deserialize, Serialize};
use tokio::sync::mpsc;
use tokio::{
sync::{mpsc, oneshot},
time::Instant,
};

use crate::storage::{
store::{
engine::{MempoolMessage, Senders},
stores::Stores,
},
store::{engine::MempoolMessage, stores::Stores},
trie::merkle_trie::{self, TrieKey},
};

Expand All @@ -25,27 +25,40 @@ impl Default for Config {
}
}

#[derive(Debug, PartialEq, Eq, PartialOrd, Ord)]
pub struct MempoolKey {
inserted_at: Instant,
}

pub struct MempoolMessagesRequest {
pub shard_id: u32,
pub message_tx: oneshot::Sender<Vec<MempoolMessage>>,
pub max_messages_per_block: u32,
}

pub struct Mempool {
shard_senders: HashMap<u32, Senders>,
shard_stores: HashMap<u32, Stores>,
message_router: Box<dyn MessageRouter>,
num_shards: u32,
mempool_rx: mpsc::Receiver<MempoolMessage>,
messages_request_rx: mpsc::Receiver<MempoolMessagesRequest>,
messages: HashMap<u32, BTreeMap<MempoolKey, MempoolMessage>>,
}

impl Mempool {
pub fn new(
mempool_rx: mpsc::Receiver<MempoolMessage>,
messages_request_rx: mpsc::Receiver<MempoolMessagesRequest>,
num_shards: u32,
shard_senders: HashMap<u32, Senders>,
shard_stores: HashMap<u32, Stores>,
) -> Self {
Mempool {
shard_senders,
shard_stores,
num_shards,
mempool_rx,
message_router: Box::new(ShardRouter {}),
messages: HashMap::new(),
messages_request_rx,
}
}

Expand Down Expand Up @@ -75,32 +88,59 @@ impl Mempool {
}
}

fn is_message_already_merged(&mut self, message: &MempoolMessage) -> bool {
let fid = message.fid();
match message {
MempoolMessage::UserMessage(message) => {
self.message_exists_in_trie(fid, TrieKey::for_message(message))
async fn pull_messages(&mut self, request: MempoolMessagesRequest) {
let mut messages = vec![];
while messages.len() < request.max_messages_per_block as usize {
let shard_messages = self.messages.get_mut(&request.shard_id);
match shard_messages {
None => break,
Some(shard_messages) => {
match shard_messages.pop_first() {
None => break,
Some((_, next_message)) => {
if self.message_is_valid(&next_message) {
messages.push(next_message);
}
}
};
}
}
}

if let Err(_) = request.message_tx.send(messages) {
error!("Unable to send message from mempool");
}
}

fn get_trie_key(message: &MempoolMessage) -> Option<Vec<u8>> {
match message {
MempoolMessage::UserMessage(message) => return Some(TrieKey::for_message(message)),
MempoolMessage::ValidatorMessage(validator_message) => {
if let Some(onchain_event) = &validator_message.on_chain_event {
return self
.message_exists_in_trie(fid, TrieKey::for_onchain_event(&onchain_event));
return Some(TrieKey::for_onchain_event(&onchain_event));
}

if let Some(fname_transfer) = &validator_message.fname_transfer {
if let Some(proof) = &fname_transfer.proof {
let name = String::from_utf8(proof.name.clone()).unwrap();
return self.message_exists_in_trie(
fid,
TrieKey::for_fname(fname_transfer.id, &name),
);
return Some(TrieKey::for_fname(fname_transfer.id, &name));
}
}
false

return None;
}
}
}

fn is_message_already_merged(&mut self, message: &MempoolMessage) -> bool {
let fid = message.fid();
let trie_key = Self::get_trie_key(&message);
match trie_key {
Some(trie_key) => self.message_exists_in_trie(fid, trie_key),
None => false,
}
}

pub fn message_is_valid(&mut self, message: &MempoolMessage) -> bool {
if self.is_message_already_merged(message) {
return false;
Expand All @@ -110,18 +150,32 @@ impl Mempool {
}

pub async fn run(&mut self) {
while let Some(message) = self.mempool_rx.recv().await {
if self.message_is_valid(&message) {
let fid = message.fid();
let shard = self.message_router.route_message(fid, self.num_shards);
let senders = self.shard_senders.get(&shard);
match senders {
None => {
error!("Unable to find shard to send message to")
loop {
tokio::select! {
biased;

message_request = self.messages_request_rx.recv() => {
if let Some(messages_request) = message_request {
self.pull_messages(messages_request).await
}
Some(senders) => {
if let Err(err) = senders.messages_tx.send(message).await {
error!("Unable to send message to engine: {}", err.to_string())
}
message = self.mempool_rx.recv() => {
if let Some(message) = message {
// TODO(aditi): Maybe we don't need to run validations here?
if self.message_is_valid(&message) {
let fid = message.fid();
let shard_id = self.message_router.route_message(fid, self.num_shards);
// TODO(aditi): We need a size limit on the mempool and we need to figure out what to do if it's exceeded
match self.messages.get_mut(&shard_id) {
None => {
let mut messages = BTreeMap::new();
messages.insert(MempoolKey { inserted_at: Instant::now()}, message.clone());
self.messages.insert(shard_id, messages);
}
Some(messages) => {
messages.insert(MempoolKey { inserted_at: Instant::now()}, message.clone());
}
}
}
}
}
Expand Down
3 changes: 2 additions & 1 deletion src/mempool/mempool_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,13 @@ mod tests {

fn setup() -> (ShardEngine, Mempool) {
let (_mempool_tx, mempool_rx) = mpsc::channel(100);
let (_mempool_tx, messages_request_rx) = mpsc::channel(100);
let (engine, _) = test_helper::new_engine();
let mut shard_senders = HashMap::new();
shard_senders.insert(1, engine.get_senders());
let mut shard_stores = HashMap::new();
shard_stores.insert(1, engine.get_stores());
let mempool = Mempool::new(mempool_rx, 1, shard_senders, shard_stores);
let mempool = Mempool::new(mempool_rx, messages_request_rx, 1, shard_stores);
(engine, mempool)
}

Expand Down
1 change: 1 addition & 0 deletions src/network/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ impl MyHubService {
stores.store_limits.clone(),
self.statsd_client.clone(),
100,
None,
);
let result = readonly_engine.simulate_message(&message);

Expand Down
10 changes: 6 additions & 4 deletions src/network/server_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -134,29 +134,31 @@ mod tests {
let (engine1, _) = test_helper::new_engine_with_options(test_helper::EngineOptions {
limits: Some(limits.clone()),
db_name: Some("db1.db".to_string()),
messages_request_tx: None,
});
let (engine2, _) = test_helper::new_engine_with_options(test_helper::EngineOptions {
limits: Some(limits.clone()),
db_name: Some("db2.db".to_string()),
messages_request_tx: None,
});
let db1 = engine1.db.clone();
let db2 = engine2.db.clone();

let (msgs_tx, _msgs_rx) = mpsc::channel(100);
let (_msgs_request_tx, msgs_request_rx) = mpsc::channel(100);

let shard1_stores = Stores::new(
db1,
merkle_trie::MerkleTrie::new(16).unwrap(),
limits.clone(),
);
let shard1_senders = Senders::new(msgs_tx.clone());
let shard1_senders = Senders::new();

let shard2_stores = Stores::new(
db2,
merkle_trie::MerkleTrie::new(16).unwrap(),
limits.clone(),
);
let shard2_senders = Senders::new(msgs_tx.clone());
let shard2_senders = Senders::new();
let stores = HashMap::from([(1, shard1_stores), (2, shard2_stores)]);
let senders = HashMap::from([(1, shard1_senders), (2, shard2_senders)]);
let num_shards = senders.len() as u32;
Expand All @@ -169,7 +171,7 @@ mod tests {
assert_eq!(message_router.route_message(SHARD2_FID, 2), 2);

let (mempool_tx, mempool_rx) = mpsc::channel(1000);
let mut mempool = Mempool::new(mempool_rx, num_shards, senders.clone(), stores.clone());
let mut mempool = Mempool::new(mempool_rx, msgs_request_rx, num_shards, stores.clone());
tokio::spawn(async move { mempool.run().await });

(
Expand Down
3 changes: 3 additions & 0 deletions src/node/snapchain_node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use crate::core::types::{
Address, Height, ShardId, SnapchainShard, SnapchainValidator, SnapchainValidatorContext,
SnapchainValidatorSet,
};
use crate::mempool::mempool::MempoolMessagesRequest;
use crate::network::gossip::GossipEvent;
use crate::proto::{Block, ShardChunk};
use crate::storage::db::RocksDB;
Expand Down Expand Up @@ -39,6 +40,7 @@ impl SnapchainNode {
rpc_address: Option<String>,
gossip_tx: mpsc::Sender<GossipEvent<SnapchainValidatorContext>>,
block_tx: Option<mpsc::Sender<Block>>,
messages_request_tx: mpsc::Sender<MempoolMessagesRequest>,
block_store: BlockStore,
rocksdb_dir: String,
statsd_client: StatsdClientWrapper,
Expand Down Expand Up @@ -91,6 +93,7 @@ impl SnapchainNode {
StoreLimits::default(),
statsd_client.clone(),
config.max_messages_per_block,
Some(messages_request_tx.clone()),
);

shard_senders.insert(shard_id, engine.get_senders());
Expand Down
19 changes: 17 additions & 2 deletions src/perf/engine_only_perftest.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
use tokio::sync::mpsc;

use crate::mempool::mempool::Mempool;
use crate::proto::{Height, ShardChunk, ShardHeader};
use crate::storage::store::engine::{MempoolMessage, ShardStateChange};
use crate::storage::store::stores::StoreLimits;
use crate::storage::store::test_helper;
use crate::utils::cli::compose_message;
use std::collections::HashMap;
use std::error::Error;
use std::time::Duration;

Expand All @@ -28,16 +32,27 @@ fn state_change_to_shard_chunk(
}

pub async fn run() -> Result<(), Box<dyn Error>> {
let (mempool_tx, mempool_rx) = mpsc::channel(1000);
let (messages_request_tx, messages_request_rx) = mpsc::channel(100);

let (mut engine, _tmpdir) = test_helper::new_engine_with_options(test_helper::EngineOptions {
limits: Some(StoreLimits {
limits: test_helper::limits::unlimited(),
legacy_limits: test_helper::limits::unlimited(),
}),
db_name: None,
messages_request_tx: Some(messages_request_tx),
});

let mut shard_stores = HashMap::new();
shard_stores.insert(1, engine.get_stores());
let mut mempool = Mempool::new(mempool_rx, messages_request_rx, 1, shard_stores);

tokio::spawn(async move {
mempool.run().await;
});

let mut i = 0;
let messages_tx = engine.messages_tx();

let fid = test_helper::FID_FOR_TEST;

Expand All @@ -54,7 +69,7 @@ pub async fn run() -> Result<(), Box<dyn Error>> {
let text = format!("For benchmarking {}", i);
let msg = compose_message(fid, text.as_str(), None, None);

messages_tx
mempool_tx
.send(MempoolMessage::UserMessage(msg.clone()))
.await
.unwrap();
Expand Down
Loading

0 comments on commit 23fe464

Please sign in to comment.