diff --git a/Cargo.lock b/Cargo.lock index 9d9c63625..7c8dc01f9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -74,6 +74,19 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "async-channel" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f28243a43d821d11341ab73c80bed182dc015c514b951616cf79bd4af39af0c3" +dependencies = [ + "concurrent-queue", + "event-listener", + "event-listener-strategy", + "futures-core", + "pin-project-lite", +] + [[package]] name = "async-trait" version = "0.1.79" @@ -167,6 +180,12 @@ dependencies = [ "regex-automata 0.1.10", ] +[[package]] +name = "byteorder" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" + [[package]] name = "bytes" version = "1.6.0" @@ -199,7 +218,7 @@ checksum = "4ea181bf566f71cb9a5d17a59e1871af638180a18fb0035c92ae62b705207123" dependencies = [ "atty", "bitflags 1.3.2", - "clap_derive", + "clap_derive 3.2.25", "clap_lex 0.2.4", "indexmap 1.9.3", "once_cell", @@ -215,6 +234,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "90bc066a67923782aa8515dbaea16946c5bcc5addbd668bb80af688e53e548a0" dependencies = [ "clap_builder", + "clap_derive 4.5.4", ] [[package]] @@ -235,13 +255,25 @@ version = "3.2.25" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ae6371b8bdc8b7d3959e9cf7b22d4435ef3e79e138688421ec654acf8c81b008" dependencies = [ - "heck", + "heck 0.4.1", "proc-macro-error", "proc-macro2", "quote", "syn 1.0.109", ] +[[package]] +name = "clap_derive" +version = "4.5.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "528131438037fd55894f62d6e9f068b8f45ac57ffa77517819645d10aed04f64" +dependencies = [ + "heck 0.5.0", + "proc-macro2", + "quote", + "syn 2.0.58", +] + [[package]] name = "clap_lex" version = "0.2.4" @@ -272,6 +304,15 @@ version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "acbf1af155f9b9ef647e42cdc158db4b64a1b61f743629225fde6f3e0be2a7c7" +[[package]] +name = "concurrent-queue" +version = "2.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d16048cd947b08fa32c24458a22f5dc5e835264f689f4f5653210c69fd107363" +dependencies = [ + "crossbeam-utils", +] + [[package]] name = "config" version = "0.14.0" @@ -342,6 +383,30 @@ dependencies = [ "libc", ] +[[package]] +name = "crc32fast" +version = "1.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b3855a8a784b474f333699ef2bbca9db2c4a1f6d9088a90a2d25b1eb53111eaa" +dependencies = [ + "cfg-if", +] + +[[package]] +name = "crossbeam-channel" +version = "0.5.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ab3db02a9c5b5121e1e42fbdb1aeb65f5e02624cc58c43f2884c6ccac0b82f95" +dependencies = [ + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-utils" +version = "0.8.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "248e3bacc7dc6baa3b21e405ee045c3047101a49145e7e9eca583ab4c2ca5345" + [[package]] name = "crunchy" version = "0.2.2" @@ -428,6 +493,12 @@ dependencies = [ "syn 2.0.58", ] +[[package]] +name = "either" +version = "1.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "11157ac094ffbdde99aa67b23417ebdd801842852b500e395a45a9c0aac03e4a" + [[package]] name = "encode_unicode" version = "0.3.6" @@ -482,6 +553,27 @@ version = "3.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a0474425d51df81997e2f90a21591180b38eccf27292d755f3e30750225c175b" +[[package]] +name = "event-listener" +version = "5.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6d9944b8ca13534cdfb2800775f8dd4902ff3fc75a50101466decadfdf322a24" +dependencies = [ + "concurrent-queue", + "parking", + "pin-project-lite", +] + +[[package]] +name = "event-listener-strategy" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "332f51cb23d20b0de8458b86580878211da09bcd4503cb579c225b3d124cabb3" +dependencies = [ + "event-listener", + "pin-project-lite", +] + [[package]] name = "fastrand" version = "2.0.2" @@ -499,6 +591,16 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "flate2" +version = "1.0.28" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "46303f565772937ffe1d394a4fac6f411c6013172fadde9dcdb1e147a086940e" +dependencies = [ + "crc32fast", + "miniz_oxide", +] + [[package]] name = "fs4" version = "0.8.2" @@ -661,12 +763,32 @@ version = "0.14.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "290f1a1d9242c78d09ce40a5e87e7554ee637af1351968159f4952f028f75604" +[[package]] +name = "hdrhistogram" +version = "7.5.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "765c9198f173dd59ce26ff9f95ef0aafd0a0fe01fb9d72841bc5066a4c06511d" +dependencies = [ + "base64", + "byteorder", + "crossbeam-channel", + "flate2", + "nom", + "num-traits", +] + [[package]] name = "heck" version = "0.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8" +[[package]] +name = "heck" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" + [[package]] name = "hermit-abi" version = "0.1.19" @@ -717,6 +839,15 @@ dependencies = [ "hashbrown 0.14.3", ] +[[package]] +name = "itertools" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba291022dbbd398a455acf126c1e341954079855bc60dfdda641363bd6922569" +dependencies = [ + "either", +] + [[package]] name = "itoa" version = "1.0.11" @@ -853,6 +984,15 @@ version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "51d515d32fb182ee37cda2ccdcb92950d6a3c2893aa280e540671c2cd0f3b1d9" +[[package]] +name = "num-traits" +version = "0.2.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "da0df0e5185db44f69b44f26786fe401b6c293d1907744beaa7fa62b2e5a517a" +dependencies = [ + "autocfg", +] + [[package]] name = "num_cpus" version = "1.16.0" @@ -903,6 +1043,12 @@ version = "6.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e2355d85b9a3786f481747ced0e0ff2ba35213a1f9bd406ed906554d7af805a1" +[[package]] +name = "parking" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bb813b8af86854136c6922af0598d719255ecb2179515e6e7730d468f05c9cae" + [[package]] name = "parking_lot" version = "0.12.1" @@ -1675,6 +1821,7 @@ dependencies = [ name = "toydb" version = "0.1.0" dependencies = [ + "async-channel", "bincode", "clap 4.5.4", "config", @@ -1683,7 +1830,9 @@ dependencies = [ "futures", "futures-util", "goldenfile", + "hdrhistogram", "hex", + "itertools", "lazy_static", "log", "names", diff --git a/Cargo.toml b/Cargo.toml index 5d7443fb5..d338cd026 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,14 +10,17 @@ default-run = "toydb" doctest = false [dependencies] +async-channel = "~2.2.0" bincode = "~1.3.3" -clap = { version = "~4.5.4", features = ["cargo"] } +clap = { version = "~4.5.4", features = ["cargo", "derive"] } config = "~0.14.0" derivative = "~2.2.0" fs4 = "~0.8.1" futures = "~0.3.15" futures-util = "~0.3.15" +hdrhistogram = "~7.5.4" hex = "~0.4.3" +itertools = "0.12.1" lazy_static = "~1.4.0" log = "~0.4.14" names = "~0.14.0" diff --git a/docs/architecture.md b/docs/architecture.md index b5f5e2de6..4ae6f07ee 100644 --- a/docs/architecture.md +++ b/docs/architecture.md @@ -735,13 +735,7 @@ out of scope for the project. The toyDB [`Client`](https://github.com/erikgrinaker/toydb/blob/master/src/client.rs) provides a simple API for interacting with a server, mainly by executing SQL statements via `execute()` -returning `sql::ResultSet`. It also has the convenience method `with_txn()`, taking a closure -that executes a series of SQL statements while automatically catching and retrying serialization -errors. - -There is also `client::Pool`, which manages a set of pre-connected clients that can be retrieved -for running short-lived queries in a multi-threaded application without incurring connection -setup costs. +returning `sql::ResultSet`. The [`toysql`](https://github.com/erikgrinaker/toydb/blob/master/src/bin/toysql.rs) command-line client is a simple REPL client that connects to a server using the toyDB `Client` and continually diff --git a/src/bin/bank.rs b/src/bin/bank.rs deleted file mode 100644 index d23013ab7..000000000 --- a/src/bin/bank.rs +++ /dev/null @@ -1,279 +0,0 @@ -/* - * Simulates a bank, by creating a set of accounts and making concurrent transfers between them: - * - * - Connect to the given toyDB hosts (-H default 127.0.0.1:9605, can give multiple) - * - Create C customers (-C default 100) - * - Create a accounts per customer with initial balance 100 (-a default 10) - * - Spawn c concurrent workers (-c default 8) - * - Queue t transactions between two random customers (-t default 1000) - * - Begin a new transaction - * - Find the sender account with the largest balance - * - Find the receiver account with the lowest balance - * - Reduce the sender account by a random amount - * - Increase the receiver account by the same amount - * - Commit the transaction, or retry with exponential backoff on serialization errors - * - Check that invariants still hold (same total balance, no negative balances) - */ - -#![warn(clippy::all)] - -use futures::stream::TryStreamExt as _; -use rand::distributions::Distribution; -use rand::Rng as _; -use std::cell::Cell; -use std::rc::Rc; -use tokio::net::ToSocketAddrs; -use toydb::client::Pool; -use toydb::error::{Error, Result}; - -#[tokio::main] -async fn main() -> Result<()> { - let args = clap::command!() - .about("A bank workload, which makes concurrent transfers between accounts.") - .args([ - clap::Arg::new("host") - .short('H') - .long("host") - .help("Host to connect to, optionally with port number") - .num_args(1..) - .default_value("127.0.0.1:9605"), - clap::Arg::new("concurrency") - .short('c') - .long("concurrency") - .help("Concurrent workers to spawn") - .value_parser(clap::value_parser!(u64)) - .default_value("8"), - clap::Arg::new("customers") - .short('C') - .long("customers") - .help("Number of customers to create") - .value_parser(clap::value_parser!(u64)) - .default_value("100"), - clap::Arg::new("accounts") - .short('a') - .long("accounts") - .help("Number of accounts to create per customer") - .value_parser(clap::value_parser!(u64)) - .default_value("10"), - clap::Arg::new("transactions") - .short('t') - .long("transactions") - .help("Number of account transfers to execute") - .value_parser(clap::value_parser!(u64)) - .default_value("1000"), - ]) - .get_matches(); - - Bank::new( - args.get_many::("host").unwrap().collect(), - *args.get_one::("concurrency").unwrap(), - *args.get_one("customers").unwrap(), - *args.get_one("accounts").unwrap(), - ) - .await? - .run(*args.get_one("transactions").unwrap()) - .await -} - -struct Bank { - clients: Pool, - customers: u64, - customer_accounts: u64, -} - -impl Bank { - const INITIAL_BALANCE: u64 = 100; - - // Creates a new bank simulation. - async fn new( - addrs: Vec, - concurrency: u64, - customers: u64, - accounts: u64, - ) -> Result { - Ok(Self { - clients: Pool::new(addrs, concurrency).await?, - customers, - customer_accounts: accounts, - }) - } - - // Runs the bank simulation, making transfers between customer accounts. - async fn run(&self, transactions: u64) -> Result<()> { - self.setup().await?; - self.verify().await?; - println!(); - - let mut rng = rand::thread_rng(); - let customers = rand::distributions::Uniform::from(1..=self.customers); - let transfers = futures::stream::iter( - std::iter::from_fn(|| Some((customers.sample(&mut rng), customers.sample(&mut rng)))) - .filter(|(from, to)| from != to) - .map(Ok) - .take(transactions as usize), - ); - - let start = std::time::Instant::now(); - transfers - .try_for_each_concurrent(self.clients.size(), |(from, to)| self.transfer(from, to)) - .await?; - let elapsed = start.elapsed().as_secs_f64(); - - println!(); - println!( - "Ran {} transactions in {:.3}s ({:.3}/s)", - transactions, - elapsed, - transactions as f64 / elapsed - ); - - self.verify().await?; - Ok(()) - } - - // Sets up the database with customers and accounts. - async fn setup(&self) -> Result<()> { - let client = self.clients.get().await; - let start = std::time::Instant::now(); - client.execute("BEGIN").await?; - client - .execute( - "CREATE TABLE customer ( - id INTEGER PRIMARY KEY, - name STRING NOT NULL - )", - ) - .await?; - client - .execute( - "CREATE TABLE account ( - id INTEGER PRIMARY KEY, - customer_id INTEGER NOT NULL INDEX REFERENCES customer, - balance INTEGER NOT NULL - )", - ) - .await?; - client - .execute(&format!( - "INSERT INTO customer VALUES {}", - (1..=self.customers) - .zip(names::Generator::with_naming(names::Name::Plain)) - .map(|(id, name)| format!("({}, '{}')", id, name)) - .collect::>() - .join(", ") - )) - .await?; - client - .execute(&format!( - "INSERT INTO account VALUES {}", - (1..=self.customers) - .flat_map(|c| (1..=self.customer_accounts).map(move |a| (c, a))) - .map(|(c, a)| (c, (c - 1) * self.customer_accounts + a)) - .map(|(c, a)| (format!("({}, {}, {})", a, c, Self::INITIAL_BALANCE))) - .collect::>() - .join(", ") - )) - .await?; - client.execute("COMMIT").await?; - - println!( - "Created {} customers ({} accounts) in {:.3}s", - self.customers, - self.customers * self.customer_accounts, - start.elapsed().as_secs_f64() - ); - Ok(()) - } - - /// Verifies that all invariants hold (same total balance, no negative balances). - async fn verify(&self) -> Result<()> { - let client = self.clients.get().await; - let expect = self.customers * self.customer_accounts * Self::INITIAL_BALANCE; - let balance = - client.execute("SELECT SUM(balance) FROM account").await?.into_value()?.integer()? - as u64; - if balance != expect { - return Err(Error::Value(format!( - "Expected total balance {}, found {}", - expect, balance - ))); - } - let negative = client - .execute("SELECT COUNT(*) FROM account WHERE balance < 0") - .await? - .into_value()? - .integer()?; - if negative > 0 { - return Err(Error::Value(format!("Found {} accounts with negative balance", negative))); - } - println!("Verified that total balance is {} with no negative balances", balance); - Ok(()) - } - - /// Transfers a random amount between two customers, retrying serialization failures. - async fn transfer(&self, from: u64, to: u64) -> Result<()> { - let client = self.clients.get().await; - let attempts = Rc::new(Cell::new(0_u8)); - let start = std::time::Instant::now(); - - let (from_account, to_account, amount) = client - .with_txn(|txn| { - let attempts = attempts.clone(); - async move { - attempts.set(attempts.get() + 1); - let mut row = txn - .execute(&format!( - "SELECT a.id, a.balance - FROM account a JOIN customer c ON a.customer_id = c.id - WHERE c.id = {} - ORDER BY a.balance DESC - LIMIT 1", - from - )) - .await? - .into_row()?; - let from_account = row.remove(0).integer()?; - let from_balance = row.remove(0).integer()?; - - let to_account = txn - .execute(&format!( - "SELECT a.id, a.balance - FROM account a JOIN customer c ON a.customer_id = c.id - WHERE c.id = {} - ORDER BY a.balance ASC - LIMIT 1", - to - )) - .await? - .into_value()? - .integer()?; - - let amount = rand::thread_rng().gen_range(0..=from_balance); - txn.execute(&format!( - "UPDATE account SET balance = balance - {} WHERE id = {}", - amount, from_account, - )) - .await?; - txn.execute(&format!( - "UPDATE account SET balance = balance + {} WHERE id = {}", - amount, to_account, - )) - .await?; - Ok((from_account, to_account, amount)) - } - }) - .await?; - - println!( - "Thread {} transferred {: >4} from {: >3} ({:0>4}) to {: >3} ({:0>4}) in {:.3}s ({} attempts)", - client.id(), - amount, - from, - from_account, - to, - to_account, - start.elapsed().as_secs_f64(), - attempts.get()); - Ok(()) - } -} diff --git a/src/bin/workload.rs b/src/bin/workload.rs new file mode 100644 index 000000000..d3dd250a5 --- /dev/null +++ b/src/bin/workload.rs @@ -0,0 +1,374 @@ +//! Runs toyDB workload benchmarks. For example, a read-only +//! workload can be run as: +//! +//! cargo run --bin workload -- +//! --hosts localhost:9605,localhost:9604,localhost:9603 +//! --concurrency 32 --count 100000 +//! read --rows 1000 --size 65536 --batch 10 +//! +//! See --help for a list of available workloads and arguments. + +#![warn(clippy::all)] + +use clap::Parser; +use itertools::Itertools; +use rand::distributions::Distribution; +use rand::rngs::StdRng; +use rand::SeedableRng; +use std::collections::HashSet; +use std::io::Write as _; +use std::time::Duration; +use toydb::error::Result; +use toydb::{Client, ResultSet}; + +#[tokio::main] +async fn main() -> Result<()> { + // TODO: is there a better way to handle subcommands? + let runner = Runner::parse(); + match runner.subcommand { + Subcommand::Read(read) => runner.run(read).await, + Subcommand::Write(write) => runner.run(write).await, + } +} + +#[derive(clap::Parser)] +#[command(about = "Runs toyDB workload benchmarks.", version, propagate_version = true)] +/// Runs a workload benchmark. +struct Runner { + #[arg(short = 'H', long, value_delimiter = ',', default_value = "localhost:9605")] + /// Hosts to connect to (optionally with port number). + hosts: Vec, + + #[arg(short, long, default_value = "32")] + /// Number of concurrent workers to spawn. + concurrency: usize, + + #[arg(short = 'n', long, default_value = "100000")] + /// Number of transactions to execute. + count: usize, + + #[arg(short, long, default_value = "16791084677885396490")] + /// Seed to use for random number generation. + seed: u64, + + #[command(subcommand)] + /// The workload subcommand to execute. + subcommand: Subcommand, +} + +#[derive(clap::Subcommand)] +enum Subcommand { + Read(Read), + Write(Write), +} + +impl Runner { + /// Runs the specified workload. + async fn run(self, workload: impl Workload) -> Result<()> { + let mut rng = rand::rngs::StdRng::seed_from_u64(self.seed); + let mut client = Client::new(&self.hosts[0]).await?; + + // Set up a histogram recording txn latencies as nanoseconds. The + // buckets range from 0.001s to 10s. + let mut hist = + hdrhistogram::Histogram::::new_with_bounds(1_000, 10_000_000_000, 3)?.into_sync(); + + // Prepare the dataset. + print!("Preparing initial dataset... "); + std::io::stdout().flush()?; + let start = std::time::Instant::now(); + workload.prepare(&mut client, &mut rng).await?; + println!("done ({:.3}s)", start.elapsed().as_secs_f64()); + + // Spawn workers, round robin across hosts. + print!("Spawning {} workers... ", self.concurrency); + std::io::stdout().flush()?; + let start = std::time::Instant::now(); + + let mut js = tokio::task::JoinSet::>::new(); + let (work_tx, work_rx) = async_channel::bounded(self.concurrency); + + for addr in self.hosts.iter().cycle().take(self.concurrency) { + let mut client = Client::new(addr).await?; + let work_rx = work_rx.clone(); + let mut recorder = hist.recorder(); + js.spawn(async move { + while let Ok(item) = work_rx.recv().await { + let start = std::time::Instant::now(); + workload.execute(&mut client, item).await?; + recorder.record(start.elapsed().as_nanos() as u64)?; + } + Ok(()) + }); + } + + println!("done ({:.3}s)", start.elapsed().as_secs_f64()); + + // Spawn work generator. + { + println!("Running workload {}...", workload); + js.spawn(async move { + for item in workload.generate(rng).take(self.count) { + work_tx.send(item).await?; + } + work_tx.close(); + Ok(()) + }); + } + + // Wait for workers to complete, and periodically print stats. + let start = std::time::Instant::now(); + let mut ticker = tokio::time::interval(Duration::from_secs(1)); + ticker.tick().await; // skip first tick + + println!(); + println!("Time Progress Txns Rate p50 p90 p99 pMax"); + + let mut print_stats = || { + let duration = start.elapsed().as_secs_f64(); + hist.refresh(); + println!( + "{:<8} {:>5.1}% {:>7} {:>6.0}/s {:>6.1}ms {:>6.1}ms {:>6.1}ms {:>6.1}ms", + format!("{:.1}s", duration), + hist.len() as f64 / self.count as f64 * 100.0, + hist.len(), + hist.len() as f64 / duration, + Duration::from_nanos(hist.value_at_quantile(0.5)).as_secs_f64() * 1000.0, + Duration::from_nanos(hist.value_at_quantile(0.9)).as_secs_f64() * 1000.0, + Duration::from_nanos(hist.value_at_quantile(0.99)).as_secs_f64() * 1000.0, + Duration::from_nanos(hist.max()).as_secs_f64() * 1000.0, + ); + }; + + loop { + tokio::select! { + // Print stats every second. + _ = ticker.tick() => print_stats(), + + // Check if tasks are done. + result = js.join_next() => match result { + Some(result) => result??, + None => break, + }, + } + } + print_stats(); + println!(); + + // Verify the final dataset. + print!("Verifying dataset... "); + std::io::stdout().flush()?; + let start = std::time::Instant::now(); + workload.verify(&mut client, self.count).await?; + println!("done ({:.3}s)", start.elapsed().as_secs_f64()); + + Ok(()) + } +} + +/// A workload. +/// +/// TODO: Copy is mostly needed to use the workload in async tasks, remove this +/// when Tokio is removed. +trait Workload: Copy + std::fmt::Display + Send + 'static { + /// A work item. + type Item: Send; + + /// Prepares the workload by creating initial tables and data. + async fn prepare(&self, client: &mut Client, rng: &mut StdRng) -> Result<()>; + + /// Generates work items as an iterator. + fn generate(&self, rng: StdRng) -> impl Iterator + Send; + + /// Executes a single work item. + fn execute( + &self, + client: &mut Client, + item: Self::Item, + ) -> impl std::future::Future> + Send; + + /// Verifies the dataset after the workload has completed. + async fn verify(&self, _client: &mut Client, _txns: usize) -> Result<()> { + Ok(()) + } +} + +#[derive(clap::Args, Clone, Copy)] +#[command(about = "A read-only workload using primary key lookups")] +/// A read-only workload. Creates an id,value table and populates it with the +/// given row count and value size. Then runs batches of random primary key +/// lookups (SELECT * FROM read WHERE id = 1 OR id = 2 ...). +struct Read { + #[arg(short, long, default_value = "1000")] + /// Total number of rows in data set. + rows: u64, + #[arg(short, long, default_value = "64")] + /// Row value size (excluding primary key). + size: usize, + #[arg(short, long, default_value = "1")] + /// Number of rows to fetch in a single select. + batch: usize, +} + +impl std::fmt::Display for Read { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "read (rows={} size={} batch={})", self.rows, self.size, self.batch) + } +} + +impl Workload for Read { + type Item = HashSet; + + async fn prepare(&self, client: &mut Client, rng: &mut StdRng) -> Result<()> { + client.execute("BEGIN").await?; + client.execute(r#"DROP TABLE IF EXISTS "read""#).await?; + client + .execute(r#"CREATE TABLE "read" (id INT PRIMARY KEY, value STRING NOT NULL)"#) + .await?; + + let chars = &mut rand::distributions::Alphanumeric.sample_iter(rng).map(|b| b as char); + let rows = (1..=self.rows).map(|id| (id, chars.take(self.size).collect::())); + let chunks = rows.chunks(100); + let queries = chunks.into_iter().map(|chunk| { + format!( + r#"INSERT INTO "read" (id, value) VALUES ({})"#, + chunk.map(|(id, value)| format!("{}, '{}'", id, value)).join("), (") + ) + }); + for query in queries { + client.execute(&query).await?; + } + client.execute("COMMIT").await?; + Ok(()) + } + + fn generate(&self, rng: StdRng) -> impl Iterator { + ReadGenerator { + batch: self.batch, + dist: rand::distributions::Uniform::new(1, self.rows + 1), + rng, + } + } + + async fn execute(&self, client: &mut Client, item: Self::Item) -> Result<()> { + let query = format!( + r#"SELECT * FROM "read" WHERE {}"#, + item.into_iter().map(|id| format!("id = {}", id)).join(" OR ") + ); + let rows = client.execute(&query).await?.into_rows()?; + assert_eq!(rows.count(), self.batch, "Unexpected row count"); + Ok(()) + } + + async fn verify(&self, client: &mut Client, _: usize) -> Result<()> { + let count = + client.execute(r#"SELECT COUNT(*) FROM "read""#).await?.into_value()?.integer()?; + assert_eq!(count as u64, self.rows, "Unexpected row count"); + Ok(()) + } +} + +/// A Read workload generator, yielding batches of random, unique primary keys. +struct ReadGenerator { + batch: usize, + rng: StdRng, + dist: rand::distributions::Uniform, +} + +impl Iterator for ReadGenerator { + type Item = ::Item; + + fn next(&mut self) -> Option { + let mut ids = HashSet::new(); + for id in self.dist.sample_iter(&mut self.rng) { + ids.insert(id); + if ids.len() >= self.batch { + break; + } + } + Some(ids) + } +} + +#[derive(clap::Args, Clone, Copy)] +#[command(about = "A write-only workload writing sequential rows")] +/// A write-only workload. Creates an id,value table, and writes rows with +/// sequential primary keys and the given value size, in the given batch size +/// (INSERT INTO write (id, value) VALUES ...). The number of rows written +/// is given by Runner.count * Write.batch. +struct Write { + #[arg(short, long, default_value = "64")] + /// Row value size (excluding primary key). + size: usize, + #[arg(short, long, default_value = "1")] + /// Number of rows to write in a single insert query. + batch: usize, +} + +impl std::fmt::Display for Write { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "write (size={} batch={})", self.size, self.batch) + } +} + +impl Workload for Write { + type Item = Vec<(u64, String)>; + + async fn prepare(&self, client: &mut Client, _: &mut StdRng) -> Result<()> { + client.execute("BEGIN").await?; + client.execute(r#"DROP TABLE IF EXISTS "write""#).await?; + client + .execute(r#"CREATE TABLE "write" (id INT PRIMARY KEY, value STRING NOT NULL)"#) + .await?; + client.execute("COMMIT").await?; + Ok(()) + } + + fn generate(&self, rng: StdRng) -> impl Iterator { + WriteGenerator { next_id: 1, size: self.size, batch: self.batch, rng } + } + + async fn execute(&self, client: &mut Client, item: Self::Item) -> Result<()> { + let query = format!( + r#"INSERT INTO "write" (id, value) VALUES {}"#, + item.into_iter().map(|(id, value)| format!("({}, '{}')", id, value)).join(", ") + ); + if let ResultSet::Create { count } = client.execute(&query).await? { + assert_eq!(count as usize, self.batch, "Unexpected row count"); + } else { + panic!("Unexpected result") + } + Ok(()) + } + + async fn verify(&self, client: &mut Client, txns: usize) -> Result<()> { + let count = + client.execute(r#"SELECT COUNT(*) FROM "write""#).await?.into_value()?.integer()?; + assert_eq!(count as usize, txns * self.batch, "Unexpected row count"); + Ok(()) + } +} + +/// A Write workload generator, yielding batches of sequential primary keys and +/// random rows. +struct WriteGenerator { + next_id: u64, + size: usize, + batch: usize, + rng: StdRng, +} + +impl Iterator for WriteGenerator { + type Item = ::Item; + + fn next(&mut self) -> Option { + let chars = + &mut rand::distributions::Alphanumeric.sample_iter(&mut self.rng).map(|b| b as char); + let mut rows = Vec::with_capacity(self.batch); + while rows.len() < self.batch { + rows.push((self.next_id, chars.take(self.size).collect())); + self.next_id += 1; + } + Some(rows) + } +} diff --git a/src/client.rs b/src/client.rs index e5b7df8ec..f786312ea 100644 --- a/src/client.rs +++ b/src/client.rs @@ -4,16 +4,9 @@ use crate::sql::engine::Status; use crate::sql::execution::ResultSet; use crate::sql::schema::Table; -use futures::future::FutureExt as _; use futures::sink::SinkExt as _; use futures::stream::TryStreamExt as _; -use rand::Rng as _; -use std::cell::Cell; -use std::future::Future; -use std::ops::{Deref, Drop}; -use std::sync::Arc; use tokio::net::{TcpStream, ToSocketAddrs}; -use tokio::sync::{Mutex, MutexGuard}; use tokio_util::codec::{Framed, LengthDelimitedCodec}; type Connection = tokio_serde::Framed< @@ -23,59 +16,43 @@ type Connection = tokio_serde::Framed< tokio_serde::formats::Bincode, Request>, >; -/// Number of serialization retries in with_txn() -const WITH_TXN_RETRIES: u8 = 8; - /// A toyDB client -#[derive(Clone)] pub struct Client { - conn: Arc>, - txn: Cell>, + conn: Connection, + txn: Option<(u64, bool)>, } impl Client { /// Creates a new client pub async fn new(addr: A) -> Result { Ok(Self { - conn: Arc::new(Mutex::new(tokio_serde::Framed::new( + conn: tokio_serde::Framed::new( Framed::new(TcpStream::connect(addr).await?, LengthDelimitedCodec::new()), tokio_serde::formats::Bincode::default(), - ))), - txn: Cell::new(None), + ), + txn: None, }) } /// Call a server method - async fn call(&self, request: Request) -> Result { - let mut conn = self.conn.lock().await; - self.call_locked(&mut conn, request).await - } - - /// Call a server method while holding the mutex lock - async fn call_locked( - &self, - conn: &mut MutexGuard<'_, Connection>, - request: Request, - ) -> Result { - conn.send(request).await?; - match conn.try_next().await? { + async fn call(&mut self, request: Request) -> Result { + self.conn.send(request).await?; + match self.conn.try_next().await? { Some(result) => result, None => Err(Error::Internal("Server disconnected".into())), } } /// Executes a query - pub async fn execute(&self, query: &str) -> Result { - let mut conn = self.conn.lock().await; - let mut resultset = - match self.call_locked(&mut conn, Request::Execute(query.into())).await? { - Response::Execute(rs) => rs, - resp => return Err(Error::Internal(format!("Unexpected response {:?}", resp))), - }; + pub async fn execute(&mut self, query: &str) -> Result { + let mut resultset = match self.call(Request::Execute(query.into())).await? { + Response::Execute(rs) => rs, + resp => return Err(Error::Internal(format!("Unexpected response {:?}", resp))), + }; if let ResultSet::Query { columns, .. } = resultset { // FIXME We buffer rows for now to avoid lifetime hassles let mut rows = Vec::new(); - while let Some(result) = conn.try_next().await? { + while let Some(result) = self.conn.try_next().await? { match result? { Response::Row(Some(row)) => rows.push(row), Response::Row(None) => break, @@ -87,16 +64,16 @@ impl Client { resultset = ResultSet::Query { columns, rows: Box::new(rows.into_iter().map(Ok)) } }; match &resultset { - ResultSet::Begin { version, read_only } => self.txn.set(Some((*version, *read_only))), - ResultSet::Commit { .. } => self.txn.set(None), - ResultSet::Rollback { .. } => self.txn.set(None), + ResultSet::Begin { version, read_only } => self.txn = Some((*version, *read_only)), + ResultSet::Commit { .. } => self.txn = None, + ResultSet::Rollback { .. } => self.txn = None, _ => {} } Ok(resultset) } /// Fetches the table schema as SQL - pub async fn get_table(&self, table: &str) -> Result { + pub async fn get_table(&mut self, table: &str) -> Result
{ match self.call(Request::GetTable(table.into())).await? { Response::GetTable(t) => Ok(t), resp => Err(Error::Value(format!("Unexpected response: {:?}", resp))), @@ -104,7 +81,7 @@ impl Client { } /// Lists database tables - pub async fn list_tables(&self) -> Result> { + pub async fn list_tables(&mut self) -> Result> { match self.call(Request::ListTables).await? { Response::ListTables(t) => Ok(t), resp => Err(Error::Value(format!("Unexpected response: {:?}", resp))), @@ -112,7 +89,7 @@ impl Client { } /// Checks server status - pub async fn status(&self) -> Result { + pub async fn status(&mut self) -> Result { match self.call(Request::Status).await? { Response::Status(s) => Ok(s), resp => Err(Error::Value(format!("Unexpected response: {:?}", resp))), @@ -121,106 +98,6 @@ impl Client { /// Returns the version and read-only state of the txn pub fn txn(&self) -> Option<(u64, bool)> { - self.txn.get() - } - - /// Runs a query in a transaction, automatically retrying serialization failures with - /// exponential backoff. - pub async fn with_txn(&self, mut with: W) -> Result - where - W: FnMut(Client) -> F, - F: Future>, - { - for i in 0..WITH_TXN_RETRIES { - if i > 0 { - tokio::time::sleep(std::time::Duration::from_millis( - 2_u64.pow(i as u32 - 1) * rand::thread_rng().gen_range(25..=75), - )) - .await; - } - let result = async { - self.execute("BEGIN").await?; - let result = with(self.clone()).await?; - self.execute("COMMIT").await?; - Ok(result) - } - .await; - if result.is_err() { - self.execute("ROLLBACK").await.ok(); - if matches!(result, Err(Error::Serialization) | Err(Error::Abort)) { - continue; - } - } - return result; - } - Err(Error::Serialization) - } -} - -/// A toyDB client pool -pub struct Pool { - clients: Vec>, -} - -impl Pool { - /// Creates a new connection pool for the given servers, eagerly connecting clients. - pub async fn new(addrs: Vec, size: u64) -> Result { - let mut addrs = addrs.into_iter().cycle(); - let clients = futures::future::try_join_all( - std::iter::from_fn(|| { - Some(Client::new(addrs.next().unwrap()).map(|r| r.map(Mutex::new))) - }) - .take(size as usize), - ) - .await?; - Ok(Self { clients }) - } - - /// Fetches a client from the pool. It is reset (i.e. any open txns are rolled back) and - /// returned when it goes out of scope. - pub async fn get(&self) -> PoolClient<'_> { - let (client, index, _) = - futures::future::select_all(self.clients.iter().map(|m| m.lock().boxed())).await; - PoolClient::new(index, client) - } - - /// Returns the size of the pool - pub fn size(&self) -> usize { - self.clients.len() - } -} - -/// A client returned from the pool -pub struct PoolClient<'a> { - id: usize, - client: MutexGuard<'a, Client>, -} - -impl<'a> PoolClient<'a> { - /// Creates a new PoolClient - fn new(id: usize, client: MutexGuard<'a, Client>) -> Self { - Self { id, client } - } - - /// Returns the ID of the client in the pool - pub fn id(&self) -> usize { - self.id - } -} - -impl<'a> Deref for PoolClient<'a> { - type Target = MutexGuard<'a, Client>; - - fn deref(&self) -> &Self::Target { - &self.client - } -} - -impl<'a> Drop for PoolClient<'a> { - fn drop(&mut self) { - if self.txn().is_some() { - // FIXME This should disconnect or destroy the client if it errors. - futures::executor::block_on(self.client.execute("ROLLBACK")).ok(); - } + self.txn } } diff --git a/src/error.rs b/src/error.rs index e2b4484be..39de0a5f6 100644 --- a/src/error.rs +++ b/src/error.rs @@ -43,6 +43,12 @@ impl serde::de::Error for Error { } } +impl From> for Error { + fn from(err: async_channel::SendError) -> Self { + Error::Internal(err.to_string()) + } +} + impl From> for Error { fn from(err: Box) -> Self { Error::Internal(err.to_string()) @@ -55,6 +61,18 @@ impl From for Error { } } +impl From for Error { + fn from(err: hdrhistogram::CreationError) -> Self { + Error::Internal(err.to_string()) + } +} + +impl From for Error { + fn from(err: hdrhistogram::RecordError) -> Self { + Error::Internal(err.to_string()) + } +} + impl From for Error { fn from(err: hex::FromHexError) -> Self { Error::Internal(err.to_string()) diff --git a/src/lib.rs b/src/lib.rs index 9cd6a63d7..1a77a95be 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -11,3 +11,4 @@ pub mod storage; pub use client::Client; pub use server::Server; +pub use sql::execution::ResultSet; diff --git a/src/sql/execution/mod.rs b/src/sql/execution/mod.rs index 013c1df20..363dcbc67 100644 --- a/src/sql/execution/mod.rs +++ b/src/sql/execution/mod.rs @@ -129,8 +129,14 @@ impl ResultSet { /// Converts the ResultSet into a row, or errors if not a query result with rows. pub fn into_row(self) -> Result { - if let ResultSet::Query { mut rows, .. } = self { - rows.next().transpose()?.ok_or_else(|| Error::Value("No rows returned".into())) + self.into_rows()?.next().transpose()?.ok_or_else(|| Error::Value("No rows returned".into())) + } + + /// Converts the ResultSet into a row iterator, or errors if not a query + /// result with rows. + pub fn into_rows(self) -> Result { + if let ResultSet::Query { rows, .. } = self { + Ok(rows) } else { Err(Error::Value(format!("Not a query result: {:?}", self))) } diff --git a/tests/client/mod.rs b/tests/client/mod.rs index 2eaaaf112..dc6b6bd16 100644 --- a/tests/client/mod.rs +++ b/tests/client/mod.rs @@ -1,5 +1,3 @@ -mod pool; - use super::{assert_row, assert_rows, setup}; use toydb::error::{Error, Result}; @@ -18,7 +16,7 @@ use serial_test::serial; #[tokio::test(flavor = "multi_thread", worker_threads = 2)] #[serial] async fn get_table() -> Result<()> { - let (c, _teardown) = setup::server_with_client(setup::movies()).await?; + let (mut c, _teardown) = setup::server_with_client(setup::movies()).await?; assert_eq!( c.get_table("unknown").await, @@ -108,7 +106,7 @@ async fn get_table() -> Result<()> { #[tokio::test(flavor = "multi_thread", worker_threads = 2)] #[serial] async fn list_tables() -> Result<()> { - let (c, _teardown) = setup::server_with_client(setup::movies()).await?; + let (mut c, _teardown) = setup::server_with_client(setup::movies()).await?; assert_eq!(c.list_tables().await?, vec!["countries", "genres", "movies", "studios"]); Ok(()) @@ -117,7 +115,7 @@ async fn list_tables() -> Result<()> { #[tokio::test(flavor = "multi_thread", worker_threads = 2)] #[serial] async fn status() -> Result<()> { - let (c, _teardown) = setup::server_with_client(setup::movies()).await?; + let (mut c, _teardown) = setup::server_with_client(setup::movies()).await?; assert_eq!( c.status().await?, @@ -158,7 +156,7 @@ async fn status() -> Result<()> { #[tokio::test(flavor = "multi_thread", worker_threads = 2)] #[serial] async fn execute() -> Result<()> { - let (c, _teardown) = setup::server_with_client(setup::movies()).await?; + let (mut c, _teardown) = setup::server_with_client(setup::movies()).await?; // SELECT let result = c.execute("SELECT * FROM genres").await?; @@ -241,7 +239,7 @@ async fn execute() -> Result<()> { #[tokio::test(flavor = "multi_thread", worker_threads = 2)] #[serial] async fn execute_txn() -> Result<()> { - let (c, _teardown) = setup::server_with_client(setup::movies()).await?; + let (mut c, _teardown) = setup::server_with_client(setup::movies()).await?; assert_eq!(c.txn(), None); @@ -332,8 +330,8 @@ async fn execute_txn() -> Result<()> { #[tokio::test(flavor = "multi_thread", worker_threads = 2)] #[serial] async fn execute_txn_concurrent() -> Result<()> { - let (a, _teardown) = setup::server_with_client(setup::movies()).await?; - let b = Client::new("127.0.0.1:9605").await?; + let (mut a, _teardown) = setup::server_with_client(setup::movies()).await?; + let mut b = Client::new("127.0.0.1:9605").await?; // Concurrent updates should throw a serialization failure on conflict. assert_eq!(a.execute("BEGIN").await?, ResultSet::Begin { version: 2, read_only: false }); diff --git a/tests/client/pool.rs b/tests/client/pool.rs deleted file mode 100644 index 413958ba2..000000000 --- a/tests/client/pool.rs +++ /dev/null @@ -1,67 +0,0 @@ -use super::super::{assert_rows, setup}; - -use toydb::error::Result; -use toydb::sql::types::Value; - -use futures::future::FutureExt as _; -use pretty_assertions::assert_eq; -use serial_test::serial; -use std::collections::HashSet; -use std::iter::FromIterator as _; - -#[tokio::test(flavor = "multi_thread", worker_threads = 4)] -#[serial] -#[allow(clippy::many_single_char_names)] -async fn get() -> Result<()> { - let (pool, _teardown) = setup::cluster_with_pool(3, 5, setup::simple()).await?; - - // The clients are allocated to all servers - let a = pool.get().await; - let b = pool.get().await; - let c = pool.get().await; - let d = pool.get().await; - let e = pool.get().await; - - let mut servers = HashSet::new(); - let mut ids = HashSet::new(); - for client in [a, b, c, d, e] { - servers.insert(client.status().await?.raft.server); - ids.insert(client.id()); - } - assert_eq!(servers, HashSet::from_iter(vec![1, 2, 3])); - assert_eq!(ids, HashSet::from_iter(vec![0, 1, 2, 3, 4])); - - // Further clients won't be ready - assert!(tokio::spawn(async move { pool.get().await.id() }).now_or_never().is_none()); - - Ok(()) -} - -#[tokio::test(flavor = "multi_thread", worker_threads = 4)] -#[serial] -async fn drop_rollback() -> Result<()> { - let (pool, _teardown) = setup::cluster_with_pool(3, 1, setup::simple()).await?; - - // Starting a client and dropping it mid-transaction should work. - let a = pool.get().await; - assert_eq!(a.id(), 0); - assert_eq!(a.txn(), None); - a.execute("BEGIN").await?; - a.execute("INSERT INTO test VALUES (1, 'a')").await?; - assert_rows( - a.execute("SELECT * FROM test").await?, - vec![vec![Value::Integer(1), Value::String("a".into())]], - ); - std::mem::drop(a); - - // Fetching the client again from the pool should have reset it. - let a = pool.get().await; - assert_eq!(a.id(), 0); - assert_eq!(a.txn(), None); - assert_rows(a.execute("SELECT * FROM test").await?, Vec::new()); - a.execute("BEGIN").await?; - a.execute("INSERT INTO test VALUES (1, 'a')").await?; - a.execute("COMMIT").await?; - - Ok(()) -} diff --git a/tests/cluster/isolation.rs b/tests/cluster/isolation.rs index 8e0b0c8a3..53be778e7 100644 --- a/tests/cluster/isolation.rs +++ b/tests/cluster/isolation.rs @@ -9,7 +9,7 @@ use serial_test::serial; #[serial] // A dirty write is when b overwrites an uncommitted value written by a. async fn anomaly_dirty_write() -> Result<()> { - let (a, b, _, _teardown) = setup::cluster_simple().await?; + let (mut a, mut b, _, _teardown) = setup::cluster_simple().await?; a.execute("BEGIN").await?; a.execute("INSERT INTO test VALUES (1, 'a')").await?; @@ -29,7 +29,7 @@ async fn anomaly_dirty_write() -> Result<()> { #[serial] // A dirty read is when b can read an uncommitted value set by a. async fn anomaly_dirty_read() -> Result<()> { - let (a, b, _, _teardown) = setup::cluster_simple().await?; + let (mut a, mut b, _, _teardown) = setup::cluster_simple().await?; a.execute("BEGIN").await?; a.execute("INSERT INTO test VALUES (1, 'a')").await?; @@ -43,7 +43,7 @@ async fn anomaly_dirty_read() -> Result<()> { #[serial] // A lost update is when a and b both read a value and update it, where b's update replaces a. async fn anomaly_lost_update() -> Result<()> { - let (a, b, c, _teardown) = setup::cluster_simple().await?; + let (mut a, mut b, mut c, _teardown) = setup::cluster_simple().await?; c.execute("INSERT INTO test VALUES (1, 'c')").await?; @@ -69,7 +69,7 @@ async fn anomaly_lost_update() -> Result<()> { #[serial] // A fuzzy (or unrepeatable) read is when b sees a value change after a updates it. async fn anomaly_fuzzy_read() -> Result<()> { - let (a, b, c, _teardown) = setup::cluster_simple().await?; + let (mut a, mut b, mut c, _teardown) = setup::cluster_simple().await?; c.execute("INSERT INTO test VALUES (1, 'c')").await?; @@ -94,7 +94,7 @@ async fn anomaly_fuzzy_read() -> Result<()> { #[serial] // Read skew is when a reads 1 and 2, but b modifies 2 in between the reads. async fn anomaly_read_skew() -> Result<()> { - let (a, b, c, _teardown) = setup::cluster_simple().await?; + let (mut a, mut b, mut c, _teardown) = setup::cluster_simple().await?; c.execute("INSERT INTO test VALUES (1, 'c'), (2, 'c')").await?; @@ -120,7 +120,7 @@ async fn anomaly_read_skew() -> Result<()> { // A phantom read is when a reads entries matching some predicate, but a modification by // b changes the entries that match the predicate such that a later read by a returns them. async fn anomaly_phantom_read() -> Result<()> { - let (a, b, c, _teardown) = setup::cluster_simple().await?; + let (mut a, mut b, mut c, _teardown) = setup::cluster_simple().await?; c.execute("INSERT INTO test VALUES (1, 'true'), (2, 'false')").await?; diff --git a/tests/cluster/recovery.rs b/tests/cluster/recovery.rs index 73ebfa113..d6b02af27 100644 --- a/tests/cluster/recovery.rs +++ b/tests/cluster/recovery.rs @@ -9,7 +9,7 @@ use serial_test::serial; #[serial] // A client disconnect or termination should roll back its transaction. async fn client_disconnect_rollback() -> Result<()> { - let (a, b, _, _teardown) = setup::cluster_simple().await?; + let (mut a, mut b, _, _teardown) = setup::cluster_simple().await?; a.execute("BEGIN").await?; a.execute("INSERT INTO test VALUES (1, 'a')").await?; @@ -28,7 +28,7 @@ async fn client_disconnect_rollback() -> Result<()> { #[tokio::test(flavor = "multi_thread", worker_threads = 2)] #[serial] async fn client_commit_error() -> Result<()> { - let (a, b, _, _teardown) = setup::cluster_simple().await?; + let (mut a, mut b, _, _teardown) = setup::cluster_simple().await?; a.execute("BEGIN").await?; a.execute("INSERT INTO test VALUES (1, 'a')").await?; diff --git a/tests/setup.rs b/tests/setup.rs index 2e87bf9dc..2f7c9e695 100644 --- a/tests/setup.rs +++ b/tests/setup.rs @@ -1,6 +1,6 @@ #![allow(clippy::implicit_hasher)] -use toydb::client::{Client, Pool}; +use toydb::client::Client; use toydb::error::Result; use toydb::server::Server; use toydb::{raft, sql, storage}; @@ -97,7 +97,7 @@ pub async fn server( /// Sets up a server with a client pub async fn server_with_client(queries: Vec<&str>) -> Result<(Client, Teardown)> { let teardown = server(1, "127.0.0.1:9605", "127.0.0.1:9705", HashMap::new()).await?; - let client = Client::new("127.0.0.1:9605").await?; + let mut client = Client::new("127.0.0.1:9605").await?; if !queries.is_empty() { client.execute("BEGIN").await?; for query in queries { @@ -124,7 +124,7 @@ pub async fn cluster(nodes: HashMap) -> Result match client.status().await { + Ok(mut client) => match client.status().await { Ok(status) if status.raft.leader > 0 => break, Ok(_) => log::error!("no leader"), Err(err) => log::error!("Status failed for {}: {}", id, err), @@ -151,7 +151,7 @@ pub async fn cluster_with_clients(size: u8, queries: Vec<&str>) -> Result<(Vec::new(); for (id, (addr_sql, _)) in nodes { - let client = Client::new(addr_sql).await?; + let mut client = Client::new(addr_sql).await?; assert_eq!(id, client.status().await?.raft.server); clients.push(client); } @@ -168,36 +168,6 @@ pub async fn cluster_with_clients(size: u8, queries: Vec<&str>) -> Result<(Vec, -) -> Result<(Pool, Teardown)> { - let mut nodes = HashMap::new(); - for i in 1..=cluster_size { - nodes.insert( - i, - (format!("127.0.0.1:{}", 9605 + i as u64), format!("127.0.0.1:{}", 9705 + i as u64)), - ); - } - let teardown = cluster(nodes.clone()).await?; - - let pool = Pool::new(nodes.into_iter().map(|(_, (addr, _))| addr).collect(), pool_size).await?; - pool.get().await.status().await?; - - if !queries.is_empty() { - let c = pool.get().await; - c.execute("BEGIN").await?; - for query in queries { - c.execute(query).await?; - } - c.execute("COMMIT").await?; - } - - Ok((pool, teardown)) -} - /// Sets up a simple cluster with 3 clients and a test table pub async fn cluster_simple() -> Result<(Client, Client, Client, Teardown)> { let (mut clients, teardown) = cluster_with_clients(3, simple()).await?;