From 5de7d6b1c64721254b89643507cca6c978d854b4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mario=20Alejandro=20Montoya=20Corte=CC=81s?= Date: Fri, 25 Oct 2024 13:55:27 -0500 Subject: [PATCH] Logical to physical plan lowering --- crates/core/src/sql/ast.rs | 4 +- .../subscription/module_subscription_actor.rs | 4 +- crates/expr/src/check.rs | 85 ++++--- crates/expr/src/expr.rs | 23 +- crates/expr/src/lib.rs | 19 ++ crates/expr/src/statement.rs | 14 +- crates/physical-plan/src/compile.rs | 218 ++++++++++++++++++ crates/physical-plan/src/lib.rs | 1 + crates/physical-plan/src/plan.rs | 47 +++- 9 files changed, 371 insertions(+), 44 deletions(-) create mode 100644 crates/physical-plan/src/compile.rs diff --git a/crates/core/src/sql/ast.rs b/crates/core/src/sql/ast.rs index 9ad49639863..15ab7443faa 100644 --- a/crates/core/src/sql/ast.rs +++ b/crates/core/src/sql/ast.rs @@ -2,7 +2,7 @@ use crate::db::relational_db::{MutTx, RelationalDB, Tx}; use crate::error::{DBError, PlanError}; use spacetimedb_data_structures::map::{HashCollectionExt as _, IntMap}; use spacetimedb_expr::check::SchemaView; -use spacetimedb_expr::statement::parse_and_type_sql; +use spacetimedb_expr::statement::compile_sql_stmt; use spacetimedb_lib::db::auth::StAccess; use spacetimedb_lib::db::error::RelationError; use spacetimedb_lib::identity::AuthCtx; @@ -952,7 +952,7 @@ pub(crate) fn compile_to_ast( ) -> Result, DBError> { // NOTE: The following ensures compliance with the 1.0 sql api. // Come 1.0, it will have replaced the current compilation stack. - parse_and_type_sql(sql_text, &SchemaViewer::new(db, tx, auth))?; + compile_sql_stmt(sql_text, &SchemaViewer::new(db, tx, auth))?; let dialect = PostgreSqlDialect {}; let ast = Parser::parse_sql(&dialect, sql_text).map_err(|error| DBError::SqlParser { diff --git a/crates/core/src/subscription/module_subscription_actor.rs b/crates/core/src/subscription/module_subscription_actor.rs index 9940f968877..c7cc1dbff27 100644 --- a/crates/core/src/subscription/module_subscription_actor.rs +++ b/crates/core/src/subscription/module_subscription_actor.rs @@ -15,7 +15,7 @@ use crate::vm::check_row_limit; use crate::worker_metrics::WORKER_METRICS; use parking_lot::RwLock; use spacetimedb_client_api_messages::websocket::FormatSwitch; -use spacetimedb_expr::check::parse_and_type_sub; +use spacetimedb_expr::check::compile_sql_sub; use spacetimedb_expr::ty::TyCtx; use spacetimedb_lib::identity::AuthCtx; use spacetimedb_lib::Identity; @@ -89,7 +89,7 @@ impl ModuleSubscriptions { } else { // NOTE: The following ensures compliance with the 1.0 sql api. // Come 1.0, it will have replaced the current compilation stack. - parse_and_type_sub( + compile_sql_sub( &mut TyCtx::default(), sql, &SchemaViewer::new(&self.relational_db, &*tx, &auth), diff --git a/crates/expr/src/check.rs b/crates/expr/src/check.rs index 72d2e382a94..7c65d0c93ed 100644 --- a/crates/expr/src/check.rs +++ b/crates/expr/src/check.rs @@ -1,5 +1,7 @@ use std::sync::Arc; +use crate::statement::Statement; +use crate::ty::TyId; use spacetimedb_schema::schema::{ColumnSchema, TableSchema}; use spacetimedb_sql_parser::{ ast::{ @@ -10,14 +12,12 @@ use spacetimedb_sql_parser::{ parser::sub::parse_subscription, }; -use crate::ty::TyId; - use super::{ assert_eq_types, errors::{DuplicateName, TypingError, Unresolved, Unsupported}, expr::{Expr, Let, RelExpr}, ty::{Symbol, TyCtx, TyEnv}, - type_expr, type_proj, type_select, + type_expr, type_proj, type_select, StatementCtx, StatementSource, }; /// The result of type checking and name resolution @@ -179,15 +179,24 @@ pub fn parse_and_type_sub(ctx: &mut TyCtx, sql: &str, tx: &impl SchemaView) -> T expect_table_type(ctx, expr) } +/// Parse and type check a *subscription* query into a `StatementCtx` +pub fn compile_sql_sub<'a>(ctx: &mut TyCtx, sql: &'a str, tx: &impl SchemaView) -> TypingResult> { + let expr = parse_and_type_sub(ctx, sql, tx)?; + Ok(StatementCtx { + statement: Statement::Select(expr), + sql, + source: StatementSource::Subscription, + }) +} + /// Returns an error if the input type is not a table type or relvar fn expect_table_type(ctx: &TyCtx, expr: RelExpr) -> TypingResult { let _ = expr.ty(ctx)?.expect_relvar().map_err(|_| Unsupported::ReturnType)?; Ok(expr) } -#[cfg(test)] -mod tests { - use spacetimedb_lib::{db::raw_def::v9::RawModuleDefV9Builder, AlgebraicType, ProductType}; +pub mod test_utils { + use spacetimedb_lib::{db::raw_def::v9::RawModuleDefV9Builder, ProductType}; use spacetimedb_primitives::TableId; use spacetimedb_schema::{ def::ModuleDef, @@ -195,36 +204,17 @@ mod tests { }; use std::sync::Arc; - use crate::ty::TyCtx; + use super::SchemaView; - use super::{parse_and_type_sub, SchemaView}; - - fn module_def() -> ModuleDef { + pub fn build_module_def(types: Vec<(&str, ProductType)>) -> ModuleDef { let mut builder = RawModuleDefV9Builder::new(); - builder.build_table_with_new_type( - "t", - ProductType::from([ - ("u32", AlgebraicType::U32), - ("f32", AlgebraicType::F32), - ("str", AlgebraicType::String), - ("arr", AlgebraicType::array(AlgebraicType::String)), - ]), - true, - ); - builder.build_table_with_new_type( - "s", - ProductType::from([ - ("id", AlgebraicType::identity()), - ("u32", AlgebraicType::U32), - ("arr", AlgebraicType::array(AlgebraicType::String)), - ("bytes", AlgebraicType::bytes()), - ]), - true, - ); + for (name, ty) in types { + builder.build_table_with_new_type(name, ty, true); + } builder.finish().try_into().expect("failed to generate module def") } - struct SchemaViewer(ModuleDef); + pub struct SchemaViewer(pub ModuleDef); impl SchemaView for SchemaViewer { fn schema(&self, name: &str) -> Option> { @@ -238,6 +228,39 @@ mod tests { }) } } +} + +#[cfg(test)] +mod tests { + use crate::check::test_utils::{build_module_def, SchemaViewer}; + use crate::ty::TyCtx; + use spacetimedb_lib::{AlgebraicType, ProductType}; + use spacetimedb_schema::def::ModuleDef; + + use super::parse_and_type_sub; + + fn module_def() -> ModuleDef { + build_module_def(vec![ + ( + "t", + ProductType::from([ + ("u32", AlgebraicType::U32), + ("f32", AlgebraicType::F32), + ("str", AlgebraicType::String), + ("arr", AlgebraicType::array(AlgebraicType::String)), + ]), + ), + ( + "s", + ProductType::from([ + ("id", AlgebraicType::identity()), + ("u32", AlgebraicType::U32), + ("arr", AlgebraicType::array(AlgebraicType::String)), + ("bytes", AlgebraicType::bytes()), + ]), + ), + ]) + } #[test] fn valid() { diff --git a/crates/expr/src/expr.rs b/crates/expr/src/expr.rs index edcf25e7c25..a799953d684 100644 --- a/crates/expr/src/expr.rs +++ b/crates/expr/src/expr.rs @@ -10,7 +10,7 @@ use spacetimedb_sql_parser::ast::BinOp; use super::ty::{InvalidTypeId, Symbol, TyCtx, TyId, Type, TypeWithCtx}; /// A logical relational expression -#[derive(Debug)] +#[derive(Debug, Clone)] pub enum RelExpr { /// A base table RelVar(Arc, TyId), @@ -65,7 +65,7 @@ impl RelExpr { } /// A relational select operation or filter -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct Select { /// The input relation pub input: RelExpr, @@ -74,7 +74,7 @@ pub struct Select { } /// A relational project operation or map -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct Project { /// The input relation pub input: RelExpr, @@ -86,7 +86,7 @@ pub struct Project { /// /// Relational operators take a single input paramter. /// Let variables explicitly destructure the input row. -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct Let { /// The variable definitions for this let expression pub vars: Vec<(Symbol, Expr)>, @@ -94,8 +94,21 @@ pub struct Let { pub exprs: Vec, } +/// A let context for variable resolution +pub struct LetCtx<'a> { + pub vars: &'a [(Symbol, Expr)], +} + +impl<'a> LetCtx<'a> { + pub fn get_var(&self, sym: Symbol) -> Option<&Expr> { + self.vars + .iter() + .find_map(|(s, e)| if *s == sym { Some(e) } else { None }) + } +} + /// A typed scalar expression -#[derive(Debug)] +#[derive(Debug, Clone)] pub enum Expr { /// A binary expression Bin(BinOp, Box, Box), diff --git a/crates/expr/src/lib.rs b/crates/expr/src/lib.rs index 71e8583de15..8f6f86dfdaf 100644 --- a/crates/expr/src/lib.rs +++ b/crates/expr/src/lib.rs @@ -1,5 +1,6 @@ use std::collections::HashSet; +use crate::statement::Statement; use check::TypingResult; use errors::{DuplicateName, InvalidLiteral, InvalidWildcard, UnexpectedType, Unresolved}; use expr::{Expr, Let, RelExpr}; @@ -113,6 +114,9 @@ pub(crate) fn type_proj( // Create a single field expression for the projection. // Note the variable reference has been inlined. // Hence no let variables are needed for this expression. + // + // This is because the expression here don't flatten the row, ie: + // `SELECT * FROM a JOIN b` = `Row{a:Row{...}, b:Row{...}}` Ok(RelExpr::project( input, Let { @@ -348,3 +352,18 @@ pub(crate) fn parse(value: String, ty: TypeWithCtx) -> Result Err(InvalidLiteral::new(value, &ty)), } } + +/// The source of a statement +pub enum StatementSource { + Subscription, + Query, +} + +/// A statement context. +/// +/// This is a wrapper around a statement, its source, and the original SQL text. +pub struct StatementCtx<'a> { + pub statement: Statement, + pub sql: &'a str, + pub source: StatementSource, +} diff --git a/crates/expr/src/statement.rs b/crates/expr/src/statement.rs index f6142bea047..05e631d0945 100644 --- a/crates/expr/src/statement.rs +++ b/crates/expr/src/statement.rs @@ -21,7 +21,7 @@ use super::{ expr::{Expr, RelExpr}, parse, ty::{TyCtx, TyEnv}, - type_expr, type_proj, type_select, + type_expr, type_proj, type_select, StatementCtx, StatementSource, }; pub enum Statement { @@ -348,7 +348,7 @@ impl TypeChecker for SqlChecker { } } -pub fn parse_and_type_sql(sql: &str, tx: &impl SchemaView) -> TypingResult { +fn parse_and_type_sql(sql: &str, tx: &impl SchemaView) -> TypingResult { match parse_sql(sql)? { SqlAst::Insert(insert) => Ok(Statement::Insert(type_insert(&mut TyCtx::default(), insert, tx)?)), SqlAst::Delete(delete) => Ok(Statement::Delete(type_delete(&mut TyCtx::default(), delete, tx)?)), @@ -358,3 +358,13 @@ pub fn parse_and_type_sql(sql: &str, tx: &impl SchemaView) -> TypingResult Ok(Statement::Show(type_show(show)?)), } } + +/// Parse and type check a *general* query into a [StatementCtx]. +pub fn compile_sql_stmt<'a>(sql: &'a str, tx: &impl SchemaView) -> TypingResult> { + let statement = parse_and_type_sql(sql, tx)?; + Ok(StatementCtx { + statement, + sql, + source: StatementSource::Query, + }) +} diff --git a/crates/physical-plan/src/compile.rs b/crates/physical-plan/src/compile.rs new file mode 100644 index 00000000000..a504507b86c --- /dev/null +++ b/crates/physical-plan/src/compile.rs @@ -0,0 +1,218 @@ +//! Lowering from the logical plan to the physical plan. + +use crate::plan; +use crate::plan::{CrossJoin, Filter, PhysicalCtx, PhysicalExpr, PhysicalPlan}; +use spacetimedb_expr::expr::{Expr, Let, LetCtx, Project, RelExpr, Select}; +use spacetimedb_expr::statement::Statement; +use spacetimedb_expr::ty::TyId; +use spacetimedb_expr::StatementCtx; +use spacetimedb_sql_parser::ast::BinOp; + +fn compile_expr(ctx: &LetCtx, expr: Expr) -> PhysicalExpr { + match expr { + Expr::Bin(op, lhs, rhs) => { + let lhs = compile_expr(ctx, *lhs); + let rhs = compile_expr(ctx, *rhs); + PhysicalExpr::BinOp(op, Box::new(lhs), Box::new(rhs)) + } + Expr::Var(sym, _ty) => { + let var = ctx.get_var(sym).cloned().unwrap(); + compile_expr(ctx, var) + } + Expr::Row(row, ty) => PhysicalExpr::Tuple( + row.iter() + // The `sym` is inline in `expr` + .map(|(_sym, expr)| compile_expr(ctx, expr.clone())) + .collect(), + ty, + ), + Expr::Lit(value, ty) => PhysicalExpr::Value(value, ty), + Expr::Field(expr, pos, ty) => { + let expr = compile_expr(ctx, *expr); + PhysicalExpr::Field(Box::new(expr), pos, ty) + } + Expr::Input(ty) => PhysicalExpr::Input(ty), + } +} + +fn join_exprs(exprs: Vec) -> Option { + exprs + .into_iter() + .reduce(|lhs, rhs| PhysicalExpr::BinOp(BinOp::And, Box::new(lhs), Box::new(rhs))) +} + +fn compile_let(expr: Let) -> Vec { + let ctx = LetCtx { vars: &expr.vars }; + + expr.exprs.into_iter().map(|expr| compile_expr(&ctx, expr)).collect() +} + +fn compile_filter(select: Select) -> PhysicalPlan { + let input = compile_rel_expr(select.input); + if let Some(op) = join_exprs(compile_let(select.expr)) { + PhysicalPlan::Filter(Filter { + input: Box::new(input), + op, + }) + } else { + input + } +} + +fn compile_project(expr: Project) -> PhysicalPlan { + let proj = plan::Project { + input: Box::new(compile_rel_expr(expr.input)), + op: join_exprs(compile_let(expr.expr)).unwrap(), + }; + + PhysicalPlan::Project(proj) +} + +fn compile_cross_joins(joins: &[RelExpr], ty: TyId) -> PhysicalPlan { + joins + .iter() + .map(|expr| compile_rel_expr(expr.clone())) + .reduce(|lhs, rhs| { + PhysicalPlan::CrossJoin(CrossJoin { + lhs: Box::new(lhs), + rhs: Box::new(rhs), + ty, + }) + }) + .unwrap() +} + +fn compile_rel_expr(ast: RelExpr) -> PhysicalPlan { + match ast { + RelExpr::RelVar(table, _ty) => PhysicalPlan::TableScan(table), + RelExpr::Select(select) => compile_filter(*select), + RelExpr::Proj(proj) => compile_project(*proj), + RelExpr::Join(joins, ty) => compile_cross_joins(&joins, ty), + RelExpr::Union(_, _) | RelExpr::Minus(_, _) | RelExpr::Dedup(_) => { + unreachable!("DISTINCT is not implemented") + } + } +} + +/// Compile a SQL statement into a physical plan. +/// +/// The input [Statement] is assumed to be valid so the lowering is not expected to fail. +/// +/// **NOTE:** It does not optimize the plan. +pub fn compile(ast: StatementCtx) -> PhysicalCtx { + let plan = match ast.statement { + Statement::Select(expr) => compile_rel_expr(expr), + _ => { + unreachable!("Only `SELECT` is implemented") + } + }; + + PhysicalCtx { + plan, + sql: ast.sql, + source: ast.source, + } +} + +#[cfg(test)] +mod tests { + use super::*; + use spacetimedb_expr::check::compile_sql_sub; + use spacetimedb_expr::check::test_utils::{build_module_def, SchemaViewer}; + use spacetimedb_expr::statement::compile_sql_stmt; + use spacetimedb_expr::ty::TyCtx; + use spacetimedb_expr::StatementCtx; + use spacetimedb_lib::error::ResultTest; + use spacetimedb_lib::{AlgebraicType, ProductType}; + use spacetimedb_schema::def::ModuleDef; + + fn module_def() -> ModuleDef { + build_module_def(vec![ + ( + "t", + ProductType::from([ + ("u32", AlgebraicType::U32), + ("f32", AlgebraicType::F32), + ("str", AlgebraicType::String), + ]), + ), + ( + "u", + ProductType::from([ + ("u32", AlgebraicType::U32), + ("f32", AlgebraicType::F32), + ("str", AlgebraicType::String), + ]), + ), + ("x", ProductType::from([("u32", AlgebraicType::U32)])), + ]) + } + + fn compile_sql_sub_test(sql: &str) -> ResultTest { + let tx = SchemaViewer(module_def()); + let expr = compile_sql_sub(&mut TyCtx::default(), sql, &tx)?; + Ok(expr) + } + + fn compile_sql_stmt_test(sql: &str) -> ResultTest { + let tx = SchemaViewer(module_def()); + let statement = compile_sql_stmt(sql, &tx)?; + Ok(statement) + } + + #[test] + fn test_project() -> ResultTest<()> { + let ast = compile_sql_sub_test("SELECT * FROM t")?; + assert!(matches!(compile(ast).plan, PhysicalPlan::TableScan(_))); + + let ast = compile_sql_stmt_test("SELECT u32 FROM t")?; + assert!(matches!(compile(ast).plan, PhysicalPlan::Project(_))); + + Ok(()) + } + + #[test] + fn test_select() -> ResultTest<()> { + let ast = compile_sql_sub_test("SELECT * FROM t WHERE u32 = 1")?; + assert!(matches!(compile(ast).plan, PhysicalPlan::Filter(_))); + + let ast = compile_sql_sub_test("SELECT * FROM t WHERE u32 = 1 AND f32 = f32")?; + assert!(matches!(compile(ast).plan, PhysicalPlan::Filter(_))); + Ok(()) + } + + #[test] + fn test_joins() -> ResultTest<()> { + // Check we can do a cross join + let ast = compile(compile_sql_sub_test("SELECT t.* FROM t JOIN u")?).plan; + let plan::Project { input, op } = ast.as_project().unwrap(); + let CrossJoin { lhs, rhs, ty: _ } = input.as_cross().unwrap(); + + assert!(matches!(op, PhysicalExpr::Field(_, _, _))); + assert!(matches!(&**lhs, PhysicalPlan::TableScan(_))); + assert!(matches!(&**rhs, PhysicalPlan::TableScan(_))); + + // Check we can do multiple joins + let ast = compile(compile_sql_sub_test("SELECT t.* FROM t JOIN u JOIN x")?).plan; + let plan::Project { input, op: _ } = ast.as_project().unwrap(); + let CrossJoin { lhs, rhs, ty: _ } = input.as_cross().unwrap(); + assert!(matches!(&**rhs, PhysicalPlan::TableScan(_))); + + let CrossJoin { lhs, rhs, ty: _ } = lhs.as_cross().unwrap(); + assert!(matches!(&**lhs, PhysicalPlan::TableScan(_))); + assert!(matches!(&**rhs, PhysicalPlan::TableScan(_))); + + // Check we can do a join with a filter + let ast = compile(compile_sql_stmt_test("SELECT t.* FROM t JOIN u ON t.u32 = u.u32")?).plan; + + let plan::Project { input, op: _ } = ast.as_project().unwrap(); + let Filter { input, op } = input.as_filter().unwrap(); + assert!(matches!(op, PhysicalExpr::BinOp(_, _, _))); + + let CrossJoin { lhs, rhs, ty: _ } = input.as_cross().unwrap(); + assert!(matches!(&**lhs, PhysicalPlan::TableScan(_))); + assert!(matches!(&**rhs, PhysicalPlan::TableScan(_))); + + Ok(()) + } +} diff --git a/crates/physical-plan/src/lib.rs b/crates/physical-plan/src/lib.rs index 7764a5c307e..b79989e66e4 100644 --- a/crates/physical-plan/src/lib.rs +++ b/crates/physical-plan/src/lib.rs @@ -1 +1,2 @@ +pub mod compile; pub mod plan; diff --git a/crates/physical-plan/src/plan.rs b/crates/physical-plan/src/plan.rs index 9606940ece2..e0cd9a0a32f 100644 --- a/crates/physical-plan/src/plan.rs +++ b/crates/physical-plan/src/plan.rs @@ -1,13 +1,14 @@ -use std::{ops::Bound, sync::Arc}; - use spacetimedb_expr::ty::TyId; +use spacetimedb_expr::StatementSource; use spacetimedb_lib::AlgebraicValue; use spacetimedb_primitives::{ColId, IndexId}; use spacetimedb_schema::schema::TableSchema; use spacetimedb_sql_parser::ast::BinOp; +use std::{ops::Bound, sync::Arc}; /// A physical plan is a concrete query evaluation strategy. /// As such, we can reason about its energy consumption. +#[derive(Debug)] pub enum PhysicalPlan { /// Scan a table row by row, returning row ids TableScan(Arc), @@ -25,7 +26,34 @@ pub enum PhysicalPlan { Project(Project), } +impl PhysicalPlan { + pub fn as_project(&self) -> Option<&Project> { + if let PhysicalPlan::Project(p) = self { + Some(p) + } else { + None + } + } + + pub fn as_filter(&self) -> Option<&Filter> { + if let PhysicalPlan::Filter(p) = self { + Some(p) + } else { + None + } + } + + pub fn as_cross(&self) -> Option<&CrossJoin> { + if let PhysicalPlan::CrossJoin(p) = self { + Some(p) + } else { + None + } + } +} + /// Fetch and return row ids from a btree index +#[derive(Debug)] pub struct IndexScan { /// The table on which this index is defined pub table_schema: Arc, @@ -43,6 +71,7 @@ pub struct IndexScan { } /// BTrees support equality and range scans +#[derive(Debug)] pub enum IndexOp { Eq(AlgebraicValue, TyId), Range(Bound, Bound, TyId), @@ -50,6 +79,7 @@ pub enum IndexOp { /// Join an input relation with a base table using an index. /// Returns a 2-tuple of its lhs and rhs input rows. +#[derive(Debug)] pub struct IndexJoin { /// The lhs input used to probe the index pub input: Box, @@ -71,6 +101,7 @@ pub struct IndexJoin { /// An index join + projection. /// Returns tuples from the lhs (or rhs) exclusively. +#[derive(Debug)] pub struct IndexSemiJoin { /// The lhs input used to probe the index pub input: Box, @@ -90,6 +121,7 @@ pub struct IndexSemiJoin { } /// Which side of a semijoin to project? +#[derive(Debug)] pub enum SemiJoinProj { Lhs, Rhs, @@ -97,6 +129,7 @@ pub enum SemiJoinProj { /// Returns the cross product of two input relations. /// Returns a 2-tuple of its lhs and rhs input rows. +#[derive(Debug)] pub struct CrossJoin { /// The lhs input relation pub lhs: Box, @@ -108,6 +141,7 @@ pub struct CrossJoin { } /// A streaming or non-leaf filter operation +#[derive(Debug)] pub struct Filter { /// A generic filter always has an input pub input: Box, @@ -116,6 +150,7 @@ pub struct Filter { } /// A streaming project or map operation +#[derive(Debug)] pub struct Project { /// A projection always has an input pub input: Box, @@ -125,6 +160,7 @@ pub struct Project { } /// A physical scalar expression +#[derive(Debug)] pub enum PhysicalExpr { /// A binary expression BinOp(BinOp, Box, Box), @@ -137,3 +173,10 @@ pub enum PhysicalExpr { /// The input tuple to a relop Input(TyId), } + +/// A physical context for the result of a query compilation. +pub struct PhysicalCtx<'a> { + pub plan: PhysicalPlan, + pub sql: &'a str, + pub source: StatementSource, +}