Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
erikgrinaker committed Apr 7, 2024
1 parent b65f2a7 commit 0c6df37
Show file tree
Hide file tree
Showing 7 changed files with 100 additions and 61 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ async-channel = "~2.2.0"
bincode = "~1.3.3"
clap = { version = "~4.5.4", features = ["cargo", "derive"] }
config = "~0.14.0"
crossbeam-channel = "0.5.12"
derivative = "~2.2.0"
fs4 = "~0.8.1"
futures = "~0.3.15"
Expand Down
8 changes: 4 additions & 4 deletions src/bin/toydb.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

use serde_derive::Deserialize;
use std::collections::HashMap;
use tokio::net::TcpListener;
use std::net::TcpListener;
use toydb::error::{Error, Result};
use toydb::raft;
use toydb::sql;
Expand Down Expand Up @@ -59,10 +59,10 @@ async fn main() -> Result<()> {

let srv = Server::new(cfg.id, cfg.peers, raft_log, raft_state)?;

let raft_listener = TcpListener::bind(&cfg.listen_raft).await?;
let sql_listener = TcpListener::bind(&cfg.listen_sql).await?;
let raft_listener = TcpListener::bind(&cfg.listen_raft)?;
let sql_listener = TcpListener::bind(&cfg.listen_sql)?;

srv.serve(raft_listener, sql_listener).await
srv.serve(raft_listener, sql_listener)
}

#[derive(Debug, Deserialize)]
Expand Down
24 changes: 24 additions & 0 deletions src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,30 @@ impl From<config::ConfigError> for Error {
}
}

impl From<crossbeam_channel::RecvError> for Error {
fn from(err: crossbeam_channel::RecvError) -> Self {
Error::Internal(err.to_string())
}
}

impl<T> From<crossbeam_channel::SendError<T>> for Error {
fn from(err: crossbeam_channel::SendError<T>) -> Self {
Error::Internal(err.to_string())
}
}

impl From<crossbeam_channel::TryRecvError> for Error {
fn from(err: crossbeam_channel::TryRecvError) -> Self {
Error::Internal(err.to_string())
}
}

impl<T> From<crossbeam_channel::TrySendError<T>> for Error {
fn from(err: crossbeam_channel::TrySendError<T>) -> Self {
Error::Internal(err.to_string())
}
}

impl From<hdrhistogram::CreationError> for Error {
fn from(err: hdrhistogram::CreationError) -> Self {
Error::Internal(err.to_string())
Expand Down
87 changes: 46 additions & 41 deletions src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,66 +5,70 @@ use crate::sql::engine::Engine as _;
use crate::sql::execution::ResultSet;
use crate::sql::schema::{Catalog as _, Table};
use crate::sql::types::Row;
use crate::storage::bincode;

use ::log::{debug, error, info};
use futures::sink::SinkExt as _;
use serde_derive::{Deserialize, Serialize};
use std::collections::HashMap;
use tokio::net::{TcpListener, TcpStream};
use tokio::sync::mpsc;
use tokio_stream::wrappers::TcpListenerStream;
use tokio_stream::StreamExt as _;
use tokio_util::codec::{Framed, LengthDelimitedCodec};
use std::net::{TcpListener, TcpStream};

/// A toyDB server.
pub struct Server {
raft: raft::Server,
}
pub struct Server {}

impl Server {
/// Creates a new toyDB server.
pub fn new(
id: raft::NodeID,
peers: HashMap<raft::NodeID, String>,
raft_log: raft::Log,
raft_state: Box<dyn raft::State>,
_id: raft::NodeID,
_peers: HashMap<raft::NodeID, String>,
_raft_log: raft::Log,
_raft_state: Box<dyn raft::State>,
) -> Result<Self> {
Ok(Server { raft: raft::Server::new(id, peers, raft_log, raft_state)? })
Ok(Server {})
}

/// Serves Raft and SQL requests until the returned future is dropped. Consumes the server.
pub async fn serve(self, raft_listener: TcpListener, sql_listener: TcpListener) -> Result<()> {
/// Serves Raft and SQL requests indefinitely. Consumes the server.
pub fn serve(self, raft_listener: TcpListener, sql_listener: TcpListener) -> Result<()> {
info!(
"Listening on {} (SQL) and {} (Raft)",
sql_listener.local_addr()?,
raft_listener.local_addr()?
);

let (raft_tx, raft_rx) = mpsc::unbounded_channel();
let sql_engine = sql::engine::Raft::new(raft_tx);
let (raft_tx, _raft_rx) = crossbeam_channel::unbounded();

std::thread::scope(|s| {
s.spawn(|| Self::serve_sql(sql_listener, raft_tx));
});

tokio::try_join!(
self.raft.serve(raft_listener, raft_rx),
Self::serve_sql(sql_listener, sql_engine),
)?;
Ok(())
}

/// Serves SQL clients.
async fn serve_sql(listener: TcpListener, engine: sql::engine::Raft) -> Result<()> {
let mut listener = TcpListenerStream::new(listener);
while let Some(socket) = listener.try_next().await? {
let peer = socket.peer_addr()?;
let session = Session::new(engine.clone());
tokio::spawn(async move {
fn serve_sql(
listener: TcpListener,
raft_tx: crossbeam_channel::Sender<(
raft::Request,
crossbeam_channel::Sender<Result<raft::Response>>,
)>,
) {
std::thread::scope(|s| loop {
let (socket, peer) = match listener.accept() {
Ok(r) => r,
Err(err) => {
error!("Connection failed: {}", err);
continue;
}
};
let raft_tx = raft_tx.clone();
s.spawn(move || {
let session = Session::new(sql::engine::Raft::new(raft_tx));
info!("Client {} connected", peer);
match session.handle(socket).await {
match session.handle(socket) {
Ok(()) => info!("Client {} disconnected", peer),
Err(err) => error!("Client {} error: {}", peer, err),
}
});
}
Ok(())
});
}
}

Expand Down Expand Up @@ -100,18 +104,16 @@ impl Session {
}

/// Handles a client connection.
async fn handle(mut self, socket: TcpStream) -> Result<()> {
let mut stream = tokio_serde::Framed::new(
Framed::new(socket, LengthDelimitedCodec::new()),
tokio_serde::formats::Bincode::default(),
);
while let Some(request) = stream.try_next().await? {
let mut response = tokio::task::block_in_place(|| self.request(request));
fn handle(mut self, mut socket: TcpStream) -> Result<()> {
loop {
let request = bincode::deserialize_from(&socket)?;
let mut response = self.request(request);
let mut rows: Box<dyn Iterator<Item = Result<Response>> + Send> =
Box::new(std::iter::empty());
if let Ok(Response::Execute(ResultSet::Query { rows: ref mut resultrows, .. })) =
&mut response
{
// TODO: don't stream results, for simplicity.
rows = Box::new(
std::mem::replace(resultrows, Box::new(std::iter::empty()))
.map(|result| result.map(|row| Response::Row(Some(row))))
Expand All @@ -127,10 +129,13 @@ impl Session {
.fuse(),
);
}
stream.send(response).await?;
stream.send_all(&mut tokio_stream::iter(rows.map(Ok))).await?;

bincode::serialize_into(&mut socket, &response)?;

for row in rows {
bincode::serialize_into(&mut socket, &row)?;
}
}
Ok(())
}

/// Executes a request.
Expand Down
20 changes: 14 additions & 6 deletions src/sql/engine/raft.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ use crate::storage::{self, bincode, mvcc::TransactionState};

use serde::{de::DeserializeOwned, Deserialize, Serialize};
use std::collections::HashSet;
use tokio::sync::{mpsc, oneshot};

/// A Raft state machine mutation.
///
Expand Down Expand Up @@ -67,22 +66,28 @@ pub struct Status {
/// A client for the local Raft node.
#[derive(Clone)]
struct Client {
tx: mpsc::UnboundedSender<(raft::Request, oneshot::Sender<Result<raft::Response>>)>,
tx: crossbeam_channel::Sender<(
raft::Request,
crossbeam_channel::Sender<Result<raft::Response>>,
)>,
}

impl Client {
/// Creates a new Raft client.
fn new(
tx: mpsc::UnboundedSender<(raft::Request, oneshot::Sender<Result<raft::Response>>)>,
tx: crossbeam_channel::Sender<(
raft::Request,
crossbeam_channel::Sender<Result<raft::Response>>,
)>,
) -> Self {
Self { tx }
}

/// Executes a request against the Raft cluster.
fn execute(&self, request: raft::Request) -> Result<raft::Response> {
let (response_tx, response_rx) = oneshot::channel();
let (response_tx, response_rx) = crossbeam_channel::bounded(1);
self.tx.send((request, response_tx))?;
futures::executor::block_on(response_rx)?
response_rx.recv()?
}

/// Mutates the Raft state machine, deserializing the response into the
Expand Down Expand Up @@ -121,7 +126,10 @@ pub struct Raft {
impl Raft {
/// Creates a new Raft-based SQL engine.
pub fn new(
tx: mpsc::UnboundedSender<(raft::Request, oneshot::Sender<Result<raft::Response>>)>,
tx: crossbeam_channel::Sender<(
raft::Request,
crossbeam_channel::Sender<Result<raft::Response>>,
)>,
) -> Self {
Self { client: Client::new(tx) }
}
Expand Down
20 changes: 10 additions & 10 deletions tests/setup.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,11 @@ use toydb::error::Result;
use toydb::server::Server;
use toydb::{raft, sql, storage};

use futures_util::future::FutureExt as _;
use pretty_assertions::assert_eq;
use std::collections::HashMap;
use std::net::TcpListener;
use std::time::Duration;
use tempdir::TempDir;
use tokio::net::TcpListener;

// Movie data
pub fn movies() -> Vec<&'static str> {
Expand Down Expand Up @@ -79,17 +78,18 @@ pub async fn server(
let dir = TempDir::new("toydb")?;
let raft_log = raft::Log::new(storage::engine::BitCask::new(dir.path().join("log"))?, false)?;
let raft_state = Box::new(sql::engine::Raft::new_state(storage::engine::Memory::new())?);
let raft_listener = TcpListener::bind(addr_raft).await?;
let sql_listener = TcpListener::bind(addr_sql).await?;
let raft_listener = TcpListener::bind(addr_raft)?;
let sql_listener = TcpListener::bind(addr_sql)?;

let (task, abort) = Server::new(id, peers, raft_log, raft_state)?
.serve(raft_listener, sql_listener)
.remote_handle();

tokio::spawn(task);
std::thread::spawn(move || {
Server::new(id, peers, raft_log, raft_state)
.unwrap()
.serve(raft_listener, sql_listener)
.unwrap()
});

Ok(Teardown::new(move || {
std::mem::drop(abort);
// TODO: shut down the server here.
std::mem::drop(dir);
}))
}
Expand Down

0 comments on commit 0c6df37

Please sign in to comment.