Skip to content

Commit

Permalink
Logical to physical plan lowering
Browse files Browse the repository at this point in the history
  • Loading branch information
mamcx committed Oct 25, 2024
1 parent fa960b3 commit 5de7d6b
Show file tree
Hide file tree
Showing 9 changed files with 371 additions and 44 deletions.
4 changes: 2 additions & 2 deletions crates/core/src/sql/ast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use crate::db::relational_db::{MutTx, RelationalDB, Tx};
use crate::error::{DBError, PlanError};
use spacetimedb_data_structures::map::{HashCollectionExt as _, IntMap};
use spacetimedb_expr::check::SchemaView;
use spacetimedb_expr::statement::parse_and_type_sql;
use spacetimedb_expr::statement::compile_sql_stmt;
use spacetimedb_lib::db::auth::StAccess;
use spacetimedb_lib::db::error::RelationError;
use spacetimedb_lib::identity::AuthCtx;
Expand Down Expand Up @@ -952,7 +952,7 @@ pub(crate) fn compile_to_ast<T: TableSchemaView>(
) -> Result<Vec<SqlAst>, DBError> {
// NOTE: The following ensures compliance with the 1.0 sql api.
// Come 1.0, it will have replaced the current compilation stack.
parse_and_type_sql(sql_text, &SchemaViewer::new(db, tx, auth))?;
compile_sql_stmt(sql_text, &SchemaViewer::new(db, tx, auth))?;

let dialect = PostgreSqlDialect {};
let ast = Parser::parse_sql(&dialect, sql_text).map_err(|error| DBError::SqlParser {
Expand Down
4 changes: 2 additions & 2 deletions crates/core/src/subscription/module_subscription_actor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ use crate::vm::check_row_limit;
use crate::worker_metrics::WORKER_METRICS;
use parking_lot::RwLock;
use spacetimedb_client_api_messages::websocket::FormatSwitch;
use spacetimedb_expr::check::parse_and_type_sub;
use spacetimedb_expr::check::compile_sql_sub;
use spacetimedb_expr::ty::TyCtx;
use spacetimedb_lib::identity::AuthCtx;
use spacetimedb_lib::Identity;
Expand Down Expand Up @@ -89,7 +89,7 @@ impl ModuleSubscriptions {
} else {
// NOTE: The following ensures compliance with the 1.0 sql api.
// Come 1.0, it will have replaced the current compilation stack.
parse_and_type_sub(
compile_sql_sub(
&mut TyCtx::default(),
sql,
&SchemaViewer::new(&self.relational_db, &*tx, &auth),
Expand Down
85 changes: 54 additions & 31 deletions crates/expr/src/check.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
use std::sync::Arc;

use crate::statement::Statement;
use crate::ty::TyId;
use spacetimedb_schema::schema::{ColumnSchema, TableSchema};
use spacetimedb_sql_parser::{
ast::{
Expand All @@ -10,14 +12,12 @@ use spacetimedb_sql_parser::{
parser::sub::parse_subscription,
};

use crate::ty::TyId;

use super::{
assert_eq_types,
errors::{DuplicateName, TypingError, Unresolved, Unsupported},
expr::{Expr, Let, RelExpr},
ty::{Symbol, TyCtx, TyEnv},
type_expr, type_proj, type_select,
type_expr, type_proj, type_select, StatementCtx, StatementSource,
};

/// The result of type checking and name resolution
Expand Down Expand Up @@ -179,52 +179,42 @@ pub fn parse_and_type_sub(ctx: &mut TyCtx, sql: &str, tx: &impl SchemaView) -> T
expect_table_type(ctx, expr)
}

/// Parse and type check a *subscription* query into a `StatementCtx`
pub fn compile_sql_sub<'a>(ctx: &mut TyCtx, sql: &'a str, tx: &impl SchemaView) -> TypingResult<StatementCtx<'a>> {
let expr = parse_and_type_sub(ctx, sql, tx)?;
Ok(StatementCtx {
statement: Statement::Select(expr),
sql,
source: StatementSource::Subscription,
})
}

/// Returns an error if the input type is not a table type or relvar
fn expect_table_type(ctx: &TyCtx, expr: RelExpr) -> TypingResult<RelExpr> {
let _ = expr.ty(ctx)?.expect_relvar().map_err(|_| Unsupported::ReturnType)?;
Ok(expr)
}

#[cfg(test)]
mod tests {
use spacetimedb_lib::{db::raw_def::v9::RawModuleDefV9Builder, AlgebraicType, ProductType};
pub mod test_utils {
use spacetimedb_lib::{db::raw_def::v9::RawModuleDefV9Builder, ProductType};
use spacetimedb_primitives::TableId;
use spacetimedb_schema::{
def::ModuleDef,
schema::{Schema, TableSchema},
};
use std::sync::Arc;

use crate::ty::TyCtx;
use super::SchemaView;

use super::{parse_and_type_sub, SchemaView};

fn module_def() -> ModuleDef {
pub fn build_module_def(types: Vec<(&str, ProductType)>) -> ModuleDef {
let mut builder = RawModuleDefV9Builder::new();
builder.build_table_with_new_type(
"t",
ProductType::from([
("u32", AlgebraicType::U32),
("f32", AlgebraicType::F32),
("str", AlgebraicType::String),
("arr", AlgebraicType::array(AlgebraicType::String)),
]),
true,
);
builder.build_table_with_new_type(
"s",
ProductType::from([
("id", AlgebraicType::identity()),
("u32", AlgebraicType::U32),
("arr", AlgebraicType::array(AlgebraicType::String)),
("bytes", AlgebraicType::bytes()),
]),
true,
);
for (name, ty) in types {
builder.build_table_with_new_type(name, ty, true);
}
builder.finish().try_into().expect("failed to generate module def")
}

struct SchemaViewer(ModuleDef);
pub struct SchemaViewer(pub ModuleDef);

impl SchemaView for SchemaViewer {
fn schema(&self, name: &str) -> Option<Arc<TableSchema>> {
Expand All @@ -238,6 +228,39 @@ mod tests {
})
}
}
}

#[cfg(test)]
mod tests {
use crate::check::test_utils::{build_module_def, SchemaViewer};
use crate::ty::TyCtx;
use spacetimedb_lib::{AlgebraicType, ProductType};
use spacetimedb_schema::def::ModuleDef;

use super::parse_and_type_sub;

fn module_def() -> ModuleDef {
build_module_def(vec![
(
"t",
ProductType::from([
("u32", AlgebraicType::U32),
("f32", AlgebraicType::F32),
("str", AlgebraicType::String),
("arr", AlgebraicType::array(AlgebraicType::String)),
]),
),
(
"s",
ProductType::from([
("id", AlgebraicType::identity()),
("u32", AlgebraicType::U32),
("arr", AlgebraicType::array(AlgebraicType::String)),
("bytes", AlgebraicType::bytes()),
]),
),
])
}

#[test]
fn valid() {
Expand Down
23 changes: 18 additions & 5 deletions crates/expr/src/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use spacetimedb_sql_parser::ast::BinOp;
use super::ty::{InvalidTypeId, Symbol, TyCtx, TyId, Type, TypeWithCtx};

/// A logical relational expression
#[derive(Debug)]
#[derive(Debug, Clone)]
pub enum RelExpr {
/// A base table
RelVar(Arc<TableSchema>, TyId),
Expand Down Expand Up @@ -65,7 +65,7 @@ impl RelExpr {
}

/// A relational select operation or filter
#[derive(Debug)]
#[derive(Debug, Clone)]
pub struct Select {
/// The input relation
pub input: RelExpr,
Expand All @@ -74,7 +74,7 @@ pub struct Select {
}

/// A relational project operation or map
#[derive(Debug)]
#[derive(Debug, Clone)]
pub struct Project {
/// The input relation
pub input: RelExpr,
Expand All @@ -86,16 +86,29 @@ pub struct Project {
///
/// Relational operators take a single input paramter.
/// Let variables explicitly destructure the input row.
#[derive(Debug)]
#[derive(Debug, Clone)]
pub struct Let {
/// The variable definitions for this let expression
pub vars: Vec<(Symbol, Expr)>,
/// The expressions for which the above variables are in scope
pub exprs: Vec<Expr>,
}

/// A let context for variable resolution
pub struct LetCtx<'a> {
pub vars: &'a [(Symbol, Expr)],
}

impl<'a> LetCtx<'a> {
pub fn get_var(&self, sym: Symbol) -> Option<&Expr> {
self.vars
.iter()
.find_map(|(s, e)| if *s == sym { Some(e) } else { None })
}
}

/// A typed scalar expression
#[derive(Debug)]
#[derive(Debug, Clone)]
pub enum Expr {
/// A binary expression
Bin(BinOp, Box<Expr>, Box<Expr>),
Expand Down
19 changes: 19 additions & 0 deletions crates/expr/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use std::collections::HashSet;

use crate::statement::Statement;
use check::TypingResult;
use errors::{DuplicateName, InvalidLiteral, InvalidWildcard, UnexpectedType, Unresolved};
use expr::{Expr, Let, RelExpr};
Expand Down Expand Up @@ -113,6 +114,9 @@ pub(crate) fn type_proj(
// Create a single field expression for the projection.
// Note the variable reference has been inlined.
// Hence no let variables are needed for this expression.
//
// This is because the expression here don't flatten the row, ie:
// `SELECT * FROM a JOIN b` = `Row{a:Row{...}, b:Row{...}}`
Ok(RelExpr::project(
input,
Let {
Expand Down Expand Up @@ -348,3 +352,18 @@ pub(crate) fn parse(value: String, ty: TypeWithCtx) -> Result<AlgebraicValue, In
_ => Err(InvalidLiteral::new(value, &ty)),
}
}

/// The source of a statement
pub enum StatementSource {
Subscription,
Query,
}

/// A statement context.
///
/// This is a wrapper around a statement, its source, and the original SQL text.
pub struct StatementCtx<'a> {
pub statement: Statement,
pub sql: &'a str,
pub source: StatementSource,
}
14 changes: 12 additions & 2 deletions crates/expr/src/statement.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ use super::{
expr::{Expr, RelExpr},
parse,
ty::{TyCtx, TyEnv},
type_expr, type_proj, type_select,
type_expr, type_proj, type_select, StatementCtx, StatementSource,
};

pub enum Statement {
Expand Down Expand Up @@ -348,7 +348,7 @@ impl TypeChecker for SqlChecker {
}
}

pub fn parse_and_type_sql(sql: &str, tx: &impl SchemaView) -> TypingResult<Statement> {
fn parse_and_type_sql(sql: &str, tx: &impl SchemaView) -> TypingResult<Statement> {
match parse_sql(sql)? {
SqlAst::Insert(insert) => Ok(Statement::Insert(type_insert(&mut TyCtx::default(), insert, tx)?)),
SqlAst::Delete(delete) => Ok(Statement::Delete(type_delete(&mut TyCtx::default(), delete, tx)?)),
Expand All @@ -358,3 +358,13 @@ pub fn parse_and_type_sql(sql: &str, tx: &impl SchemaView) -> TypingResult<State
SqlAst::Show(show) => Ok(Statement::Show(type_show(show)?)),
}
}

/// Parse and type check a *general* query into a [StatementCtx].
pub fn compile_sql_stmt<'a>(sql: &'a str, tx: &impl SchemaView) -> TypingResult<StatementCtx<'a>> {
let statement = parse_and_type_sql(sql, tx)?;
Ok(StatementCtx {
statement,
sql,
source: StatementSource::Query,
})
}
Loading

0 comments on commit 5de7d6b

Please sign in to comment.