diff --git a/java/com/metabase/macaw/AnalysisError.java b/java/com/metabase/macaw/AnalysisError.java new file mode 100644 index 0000000..6e32efb --- /dev/null +++ b/java/com/metabase/macaw/AnalysisError.java @@ -0,0 +1,16 @@ +package com.metabase.macaw; + +public class AnalysisError extends RuntimeException { + public final AnalysisErrorType errorType; + public final Throwable cause; + + public AnalysisError(AnalysisErrorType errorType) { + this.errorType = errorType; + this.cause = null; + } + + public AnalysisError(AnalysisErrorType errorType, Throwable cause) { + this.errorType = errorType; + this.cause = cause; + } +} diff --git a/java/com/metabase/macaw/AnalysisErrorType.java b/java/com/metabase/macaw/AnalysisErrorType.java new file mode 100644 index 0000000..6c47bcb --- /dev/null +++ b/java/com/metabase/macaw/AnalysisErrorType.java @@ -0,0 +1,7 @@ +package com.metabase.macaw; + +public enum AnalysisErrorType { + UNSUPPORTED_EXPRESSION, + INVALID_QUERY, + UNABLE_TO_PARSE +} diff --git a/java/com/metabase/macaw/BasicTableExtractor.java b/java/com/metabase/macaw/BasicTableExtractor.java index 8da5606..771f26f 100644 --- a/java/com/metabase/macaw/BasicTableExtractor.java +++ b/java/com/metabase/macaw/BasicTableExtractor.java @@ -7,54 +7,28 @@ import java.util.*; import java.util.function.Consumer; -/** - * Return a simplified query representation we can work with further, if possible. - */ -@SuppressWarnings({ - "PatternVariableCanBeUsed", "IfCanBeSwitch"} // don't force a newer JVM version -) +@SuppressWarnings({"PatternVariableCanBeUsed", "IfCanBeSwitch"}) // don't force a newer JVM version public final class BasicTableExtractor { - public enum ErrorType { - UNSUPPORTED_EXPRESSION, - INVALID_QUERY, - UNABLE_TO_PARSE - } - - public static class AnalysisError extends RuntimeException { - public final ErrorType errorType; - public final Throwable cause; - - public AnalysisError(ErrorType errorType) { - this.errorType = errorType; - this.cause = null; - } - - public AnalysisError(ErrorType errorType, Throwable cause) { - this.errorType = errorType; - this.cause = cause; - } - } - public static Set getTables(Statement statement) { try { if (statement instanceof Select) { return getTables((Select) statement); } // This is not a query, it's probably a statement. - throw new AnalysisError(ErrorType.INVALID_QUERY); + throw new AnalysisError(AnalysisErrorType.INVALID_QUERY); } catch (IllegalArgumentException e) { // This query uses features that we do not yet support translating. - throw new AnalysisError(ErrorType.UNABLE_TO_PARSE, e); + throw new AnalysisError(AnalysisErrorType.UNABLE_TO_PARSE, e); } } private static Set
getTables(Select select) { - if (select instanceof PlainSelect) { + if (select instanceof PlainSelect) { return getTables(select.getPlainSelect()); } else { // We don't support more complex kinds of select statements yet. - throw new AnalysisError(ErrorType.UNSUPPORTED_EXPRESSION); + throw new AnalysisError(AnalysisErrorType.UNSUPPORTED_EXPRESSION); } } @@ -77,12 +51,12 @@ private static Set
getTables(PlainSelect select) { if (select.getLateralViews() != null || select.getOracleHierarchical() != null || select.getWindowDefinitions() != null) { - throw new AnalysisError(ErrorType.UNABLE_TO_PARSE); + throw new AnalysisError(AnalysisErrorType.UNABLE_TO_PARSE); } // Not currently supported if (select.getWithItemsList() != null) { - throw new AnalysisError(ErrorType.UNSUPPORTED_EXPRESSION); + throw new AnalysisError(AnalysisErrorType.UNSUPPORTED_EXPRESSION); } // any of these - nope out @@ -95,15 +69,15 @@ private static Set
getTables(PlainSelect select) { select.getIntoTables() != null || select.getIsolation() != null || select.getSkip() != null) { - throw new AnalysisError(ErrorType.INVALID_QUERY); + throw new AnalysisError(AnalysisErrorType.INVALID_QUERY); } Set
tables = new HashSet<>(); - for (SelectItem item : select.getSelectItems()) { + for (SelectItem item : select.getSelectItems()) { if (item.getExpression() instanceof ParenthesedSelect) { // Do not allow sub-selects. - throw new AnalysisError(ErrorType.UNSUPPORTED_EXPRESSION); + throw new AnalysisError(AnalysisErrorType.UNSUPPORTED_EXPRESSION); } } @@ -112,15 +86,15 @@ private static Set
getTables(PlainSelect select) { Table table = (Table) item; if (table.getName().contains("*")) { // Do not allow table wildcards. - throw new AnalysisError(ErrorType.INVALID_QUERY); + throw new AnalysisError(AnalysisErrorType.INVALID_QUERY); } tables.add(table); } else if (item instanceof TableFunction) { // Do not allow dynamic tables - throw new AnalysisError(ErrorType.INVALID_QUERY); + throw new AnalysisError(AnalysisErrorType.INVALID_QUERY); } else { // Only allow simple table references. - throw new AnalysisError(ErrorType.UNSUPPORTED_EXPRESSION); + throw new AnalysisError(AnalysisErrorType.UNSUPPORTED_EXPRESSION); } }; diff --git a/java/com/metabase/macaw/CompoundTableExtractor.java b/java/com/metabase/macaw/CompoundTableExtractor.java new file mode 100644 index 0000000..cbb29aa --- /dev/null +++ b/java/com/metabase/macaw/CompoundTableExtractor.java @@ -0,0 +1,148 @@ +package com.metabase.macaw; + +import net.sf.jsqlparser.expression.Expression; +import net.sf.jsqlparser.schema.Table; +import net.sf.jsqlparser.statement.Statement; +import net.sf.jsqlparser.statement.select.*; + +import java.util.HashSet; +import java.util.List; +import java.util.Set; +import java.util.Stack; +import java.util.function.Consumer; + +@SuppressWarnings({"PatternVariableCanBeUsed", "IfCanBeSwitch"}) // don't force a newer JVM version +public final class CompoundTableExtractor { + + /** + * BEWARE this may return duplicates if the same table is referenced multiple times in the query. + * We need to deduplicate these with value semantics. + * A better solution would be to have our own TableValue class to convert to, before accumulating. + */ + public static Set
getTables(Statement statement) { + try { + if (statement instanceof Select) { + return accTables((Select) statement); + } + // This is not a query, it's probably a statement. + throw new AnalysisError(AnalysisErrorType.INVALID_QUERY); + } catch (IllegalArgumentException e) { + // This query uses features that we do not yet support translating. + throw new AnalysisError(AnalysisErrorType.UNABLE_TO_PARSE, e); + } + } + + private static Set
accTables(Select select) { + Set
tables = new HashSet<>(); + Stack> cteAliasScopes = new Stack<>(); + accTables(select, tables, cteAliasScopes); + return tables; + } + + private static void accTables(Select select, Set
tables, Stack> cteAliasScopes) { + if (select instanceof PlainSelect) { + accTables(select.getPlainSelect(), tables, cteAliasScopes); + } else if (select instanceof ParenthesedSelect) { + accTables(((ParenthesedSelect) select).getSelect(), tables, cteAliasScopes); + } else if (select instanceof SetOperationList) { + for (Select innerSelect : ((SetOperationList) select).getSelects()) { + accTables(innerSelect, tables, cteAliasScopes); + } + } else { + // We don't support more complex kinds of select statements yet. + throw new AnalysisError(AnalysisErrorType.UNSUPPORTED_EXPRESSION); + } + } + + private static void accTables(PlainSelect select, Set
tables, Stack> cteAliasScopes) { + // these are fine, but irrelevant + // + // - select.getDistinct() + // - select.getHaving() + // - select.getKsqlWindow() + // - select.getMySqlHintStraightJoin() + // - select.getMySqlSqlCacheFlag() + // - select.getLimitBy() + // - select.getOffset() + // - select.getOptimizeFor() + // - select.getOracleHint() + // - select.getTop() + // - select.getWait() + + // Not currently parseable + if (select.getLateralViews() != null || + select.getOracleHierarchical() != null || + select.getWindowDefinitions() != null) { + throw new AnalysisError(AnalysisErrorType.UNABLE_TO_PARSE); + } + + final Set cteAliases = new HashSet<>(); + cteAliasScopes.push(cteAliases); + + if (select.getWithItemsList() != null) { + for (WithItem withItem : select.getWithItemsList()) { + if (withItem.isRecursive()) { + cteAliases.add(withItem.getAlias().getName()); + } + accTables(withItem.getSelect(), tables, cteAliasScopes); + // No hard in adding twice to a set + cteAliases.add(withItem.getAlias().getName()); + } + } + + // any of these - nope out + if (select.getFetch() != null || + select.getFirst() != null || + select.getForClause() != null || + select.getForMode() != null || + select.getForUpdateTable() != null || + select.getForXmlPath() != null || + select.getIntoTables() != null || + select.getIsolation() != null || + select.getSkip() != null) { + throw new AnalysisError(AnalysisErrorType.INVALID_QUERY); + } + + for (SelectItem item : select.getSelectItems()) { + Expression expr = item.getExpression(); + if (expr instanceof Select) { + accTables((Select) expr, tables, cteAliasScopes); + } + } + + Consumer pushOrThrow = (FromItem item) -> { + if (item instanceof Table) { + Table table = (Table) item; + String tableName = table.getName(); + // Skip aliases + if (cteAliasScopes.stream().noneMatch(scope -> scope.contains(tableName))) { + if (tableName.contains("*")) { + // Do not allow table wildcards. + throw new AnalysisError(AnalysisErrorType.INVALID_QUERY); + } else { + tables.add(table); + } + } + } else if (item instanceof TableFunction) { + // Do not allow dynamic tables + throw new AnalysisError(AnalysisErrorType.INVALID_QUERY); + } else if (item instanceof Select) { + accTables((Select) item, tables, cteAliasScopes); + } else if (item != null) { + // Only allow simple table references. + throw new AnalysisError(AnalysisErrorType.UNSUPPORTED_EXPRESSION); + } + }; + + if (select.getFromItem() != null) { + pushOrThrow.accept(select.getFromItem()); + List joins = select.getJoins(); + if (joins != null) { + joins.stream().map(Join::getFromItem).forEach(pushOrThrow); + } + } + + cteAliasScopes.pop(); + } + +} diff --git a/src/macaw/core.clj b/src/macaw/core.clj index 2e5d05f..6f4b92a 100644 --- a/src/macaw/core.clj +++ b/src/macaw/core.clj @@ -5,7 +5,7 @@ [macaw.collect :as collect] [macaw.rewrite :as rewrite]) (:import - (com.metabase.macaw AstWalker$Scope BasicTableExtractor BasicTableExtractor$AnalysisError) + (com.metabase.macaw AstWalker$Scope BasicTableExtractor AnalysisError CompoundTableExtractor) (java.util.function Consumer) (net.sf.jsqlparser JSQLParserException) (net.sf.jsqlparser.parser CCJSqlParser CCJSqlParserUtil) @@ -93,14 +93,14 @@ :table (.getName t)} {:table (.getName t)})) -(defn- ->macaw-error [^BasicTableExtractor$AnalysisError analysis-error] +(defn- ->macaw-error [^AnalysisError analysis-error] (keyword "macaw.error" (-> (.-errorType analysis-error) str/lower-case (str/replace #"_" "-")))) (defmacro ^:private kw-or-tables [expr] - `(try (map table->identifier ~expr) - (catch BasicTableExtractor$AnalysisError e# + `(try (set (map table->identifier ~expr)) + (catch AnalysisError e# (->macaw-error e#)) (catch JSQLParserException _e# :macaw.error/unable-to-parse))) @@ -108,14 +108,13 @@ (defn query->tables "Given a parsed query (i.e., a [subclass of] `Statement`) return a set of all the table identifiers found within it." [sql & {:keys [mode] :as opts}] - (case mode - :ast-walker-1 (-> (parsed-query sql opts) - (query->components opts) - :tables - raw-components) - :basic-select (->> (parsed-query sql opts) - (BasicTableExtractor/getTables) - kw-or-tables))) + ;; We delay parsing so that kw-or-tables is able to catch exceptions. + ;; This will no longer be necessary when we update :ast-walker-1 to catch exceptions too. + (let [query (delay (parsed-query sql opts))] + (case mode + :ast-walker-1 (-> (query->components @query opts) :tables raw-components) + :basic-select (-> (BasicTableExtractor/getTables @query) kw-or-tables) + :compound-select (-> (CompoundTableExtractor/getTables @query) kw-or-tables)))) (defn replace-names "Given an SQL query, apply the given table, column, and schema renames. diff --git a/test/macaw/acceptance_test.clj b/test/macaw/acceptance_test.clj index df12cc9..0b6ab33 100644 --- a/test/macaw/acceptance_test.clj +++ b/test/macaw/acceptance_test.clj @@ -37,15 +37,26 @@ (def ^:private test-modes #{:ast-walker-1 - :basic-select}) + :basic-select + :compound-select}) + +(def override-hierarchy + (-> (make-hierarchy) + (derive :basic-select :select-only) + (derive :compound-select :select-only))) + +(defn- lineage [h k] + (when k + (assert (<= (count (parents h k)) 1) "Multiple inheritance not supported for override hierarchy.") + (cons k (lineage h (first (parents h k)))))) (def global-overrides {}) (def ns-overrides - {:basic-select {"compound" :macaw.error/unsupported-expression - "mutation" :macaw.error/invalid-query - "dynamic" :macaw.error/invalid-query}}) + {:select-only {"mutation" :macaw.error/invalid-query + "dynamic" :macaw.error/invalid-query} + :basic-select {"compound" :macaw.error/unsupported-expression}}) (def ^:private merged-fixtures-file "test/resources/acceptance/queries.sql") @@ -76,15 +87,21 @@ (when (keyword? x) x)) -(defn- get-override [expected-cs mode fixture ck] +(defn- get-override* [expected-cs mode fixture ck] (or (get global-overrides mode) (get-in ns-overrides [mode (namespace fixture)]) (get-in expected-cs [:overrides mode :error]) - (get-in expected-cs [:overrides :error]) (get-in expected-cs [:overrides mode ck]) - (get-in expected-cs [:overrides ck]) - (when-keyword (get-in expected-cs [:overrides mode])) - (when-keyword (get expected-cs :overrides)))) + (when-keyword (get-in expected-cs [:overrides mode])))) + +(defn- get-override [expected-cs mode fixture ck] + (or + (some #(get-override* expected-cs % fixture ck) + (lineage override-hierarchy mode)) + + (get-in expected-cs [:overrides :error]) + (get-in expected-cs [:overrides ck]) + (when-keyword (get expected-cs :overrides)))) (defn- test-fixture "Test that we can parse a given fixture, and compare against expected analysis and rewrites, where they are defined." @@ -95,7 +112,7 @@ expected-cs (fixture-analysis fixture) renames (fixture-renames fixture) expected-rw (fixture-rewritten fixture) - base-opts {:non-reserved-words [:final]} + base-opts {:non-reserved-words [:final], :allow-unused? true} opts-mode (fn [mode] (assoc base-opts :mode mode))] (assert sql "Fixture exists") (doseq [m test-modes @@ -107,7 +124,7 @@ (is (thrown-with-msg? Exception expected-msg (ct/components sql opts)))) (let [cs (testing (str prefix " analysis does not throw") (is (ct/components sql opts)))] - (doseq [[ck cv] (dissoc expected-cs :overrides :error)] + (doseq [[ck cv] (dissoc expected-cs :overrides :error :skip)] (testing (str prefix " analysis is correct: " (name ck)) (let [actual-cv (get-component cs ck) override (get-override expected-cs m fixture ck)] @@ -118,7 +135,9 @@ ;; For now, we only support (and test) :tables tables (testing (str prefix " table analysis does not throw for mode " m) (is (ct/tables sql opts)))] - (when-not (and (nil? correct) (nil? override)) + (if (and (nil? correct) (nil? override)) + (testing "Must define expected tables, or explicitly skip analysis" + (is (:skip expected-cs))) (testing (str prefix " table analysis is correct for mode " m) (validate-analysis correct override tables)))))) @@ -161,7 +180,7 @@ (cons 'do (for [f fixtures :let [test-name (symbol (str/replace (ct/fixture->filename f "-test") #"(?