diff --git a/java/com/metabase/macaw/AnalysisError.java b/java/com/metabase/macaw/AnalysisError.java
new file mode 100644
index 0000000..6e32efb
--- /dev/null
+++ b/java/com/metabase/macaw/AnalysisError.java
@@ -0,0 +1,41 @@
+package com.metabase.macaw;
+
+/**
+ * Unchecked exception raised when a query cannot be analysed; it carries an
+ * {@link AnalysisErrorType} describing why.
+ *
+ * <p>The underlying failure, if any, is exposed both through the public {@link #cause} field and
+ * through the standard {@link Throwable#getCause()} chain.
+ */
+public class AnalysisError extends RuntimeException {
+    /** The category of failure. */
+    public final AnalysisErrorType errorType;
+
+    /** The exception that triggered this error, or {@code null} when there is none. */
+    public final Throwable cause;
+
+    /**
+     * Creates an error with no underlying cause.
+     *
+     * @param errorType the category of failure; used as the exception message
+     */
+    public AnalysisError(AnalysisErrorType errorType) {
+        super(String.valueOf(errorType));
+        this.errorType = errorType;
+        this.cause = null;
+    }
+
+    /**
+     * Creates an error wrapping an underlying failure.
+     *
+     * @param errorType the category of failure; used as the exception message
+     * @param cause the exception that triggered this error; may be {@code null}
+     */
+    public AnalysisError(AnalysisErrorType errorType, Throwable cause) {
+        // Hand the cause to RuntimeException as well: the public field alone is invisible to
+        // getCause(), so stack traces and logging would otherwise drop the "Caused by" chain.
+        super(String.valueOf(errorType), cause);
+        this.errorType = errorType;
+        this.cause = cause;
+    }
+}
diff --git a/java/com/metabase/macaw/AnalysisErrorType.java b/java/com/metabase/macaw/AnalysisErrorType.java
new file mode 100644
index 0000000..6c47bcb
--- /dev/null
+++ b/java/com/metabase/macaw/AnalysisErrorType.java
@@ -0,0 +1,7 @@
+package com.metabase.macaw;
+
+public enum AnalysisErrorType { // Categories of failure carried by AnalysisError.
+ UNSUPPORTED_EXPRESSION, // A select kind or FROM item that the table extractors do not support.
+ INVALID_QUERY, // Not a SELECT, or uses a rejected construct such as a table wildcard or table function.
+ UNABLE_TO_PARSE // Analysis hit an IllegalArgumentException, or the select uses lateral views, Oracle hierarchical clauses, or window definitions.
+}
diff --git a/java/com/metabase/macaw/BasicTableExtractor.java b/java/com/metabase/macaw/BasicTableExtractor.java
index 8da5606..771f26f 100644
--- a/java/com/metabase/macaw/BasicTableExtractor.java
+++ b/java/com/metabase/macaw/BasicTableExtractor.java
@@ -7,54 +7,28 @@
import java.util.*;
import java.util.function.Consumer;
-/**
- * Return a simplified query representation we can work with further, if possible.
- */
-@SuppressWarnings({
- "PatternVariableCanBeUsed", "IfCanBeSwitch"} // don't force a newer JVM version
-)
+@SuppressWarnings({"PatternVariableCanBeUsed", "IfCanBeSwitch"}) // don't force a newer JVM version
public final class BasicTableExtractor {
- public enum ErrorType {
- UNSUPPORTED_EXPRESSION,
- INVALID_QUERY,
- UNABLE_TO_PARSE
- }
-
- public static class AnalysisError extends RuntimeException {
- public final ErrorType errorType;
- public final Throwable cause;
-
- public AnalysisError(ErrorType errorType) {
- this.errorType = errorType;
- this.cause = null;
- }
-
- public AnalysisError(ErrorType errorType, Throwable cause) {
- this.errorType = errorType;
- this.cause = cause;
- }
- }
-
public static Set
getTables(Statement statement) {
try {
if (statement instanceof Select) {
return getTables((Select) statement);
}
// This is not a query, it's probably a statement.
- throw new AnalysisError(ErrorType.INVALID_QUERY);
+ throw new AnalysisError(AnalysisErrorType.INVALID_QUERY);
} catch (IllegalArgumentException e) {
// This query uses features that we do not yet support translating.
- throw new AnalysisError(ErrorType.UNABLE_TO_PARSE, e);
+ throw new AnalysisError(AnalysisErrorType.UNABLE_TO_PARSE, e);
}
}
private static Set getTables(Select select) {
- if (select instanceof PlainSelect) {
+ if (select instanceof PlainSelect) {
return getTables(select.getPlainSelect());
} else {
// We don't support more complex kinds of select statements yet.
- throw new AnalysisError(ErrorType.UNSUPPORTED_EXPRESSION);
+ throw new AnalysisError(AnalysisErrorType.UNSUPPORTED_EXPRESSION);
}
}
@@ -77,12 +51,12 @@ private static Set getTables(PlainSelect select) {
if (select.getLateralViews() != null ||
select.getOracleHierarchical() != null ||
select.getWindowDefinitions() != null) {
- throw new AnalysisError(ErrorType.UNABLE_TO_PARSE);
+ throw new AnalysisError(AnalysisErrorType.UNABLE_TO_PARSE);
}
// Not currently supported
if (select.getWithItemsList() != null) {
- throw new AnalysisError(ErrorType.UNSUPPORTED_EXPRESSION);
+ throw new AnalysisError(AnalysisErrorType.UNSUPPORTED_EXPRESSION);
}
// any of these - nope out
@@ -95,15 +69,15 @@ private static Set getTables(PlainSelect select) {
select.getIntoTables() != null ||
select.getIsolation() != null ||
select.getSkip() != null) {
- throw new AnalysisError(ErrorType.INVALID_QUERY);
+ throw new AnalysisError(AnalysisErrorType.INVALID_QUERY);
}
Set tables = new HashSet<>();
- for (SelectItem item : select.getSelectItems()) {
+ for (SelectItem> item : select.getSelectItems()) {
if (item.getExpression() instanceof ParenthesedSelect) {
// Do not allow sub-selects.
- throw new AnalysisError(ErrorType.UNSUPPORTED_EXPRESSION);
+ throw new AnalysisError(AnalysisErrorType.UNSUPPORTED_EXPRESSION);
}
}
@@ -112,15 +86,15 @@ private static Set getTables(PlainSelect select) {
Table table = (Table) item;
if (table.getName().contains("*")) {
// Do not allow table wildcards.
- throw new AnalysisError(ErrorType.INVALID_QUERY);
+ throw new AnalysisError(AnalysisErrorType.INVALID_QUERY);
}
tables.add(table);
} else if (item instanceof TableFunction) {
// Do not allow dynamic tables
- throw new AnalysisError(ErrorType.INVALID_QUERY);
+ throw new AnalysisError(AnalysisErrorType.INVALID_QUERY);
} else {
// Only allow simple table references.
- throw new AnalysisError(ErrorType.UNSUPPORTED_EXPRESSION);
+ throw new AnalysisError(AnalysisErrorType.UNSUPPORTED_EXPRESSION);
}
};
diff --git a/java/com/metabase/macaw/CompoundTableExtractor.java b/java/com/metabase/macaw/CompoundTableExtractor.java
new file mode 100644
index 0000000..cbb29aa
--- /dev/null
+++ b/java/com/metabase/macaw/CompoundTableExtractor.java
@@ -0,0 +1,201 @@
+package com.metabase.macaw;
+
+import net.sf.jsqlparser.expression.Expression;
+import net.sf.jsqlparser.schema.Table;
+import net.sf.jsqlparser.statement.Statement;
+import net.sf.jsqlparser.statement.select.*;
+
+import java.util.HashSet;
+import java.util.List;
+import java.util.Set;
+import java.util.Stack;
+import java.util.function.Consumer;
+
+/**
+ * Collects the tables referenced by a SELECT, descending into parenthesised selects, set
+ * operations, CTE bodies, and sub-selects used as select items or as FROM/JOIN items.
+ */
+@SuppressWarnings({"PatternVariableCanBeUsed", "IfCanBeSwitch"}) // don't force a newer JVM version
+public final class CompoundTableExtractor {
+
+    /**
+     * Returns the tables referenced by the given statement, skipping references to CTE aliases.
+     *
+     * <p>BEWARE this may return duplicates if the same table is referenced multiple times in the query.
+     * We need to deduplicate these with value semantics.
+     * A better solution would be to have our own TableValue class to convert to, before accumulating.
+     *
+     * @param statement the parsed statement to analyse
+     * @return the table references that were found
+     * @throws AnalysisError if the statement is not a {@code Select} or uses a rejected construct
+     */
+    public static Set<Table> getTables(Statement statement) {
+        try {
+            if (statement instanceof Select) {
+                return accTables((Select) statement);
+            }
+            // This is not a query, it's probably a statement.
+            throw new AnalysisError(AnalysisErrorType.INVALID_QUERY);
+        } catch (IllegalArgumentException e) {
+            // This query uses features that we do not yet support translating.
+            throw new AnalysisError(AnalysisErrorType.UNABLE_TO_PARSE, e);
+        }
+    }
+
+    /** Walks {@code select} with a fresh accumulator and an empty stack of CTE alias scopes. */
+    private static Set<Table> accTables(Select select) {
+        Set<Table> tables = new HashSet<>();
+        Stack<Set<String>> cteAliasScopes = new Stack<>();
+        accTables(select, tables, cteAliasScopes);
+        return tables;
+    }
+
+    /**
+     * Dispatches on the kind of select and accumulates its tables into {@code tables}.
+     *
+     * <p>A WITH clause attached to a parenthesised select or to a set operation is processed here, in
+     * its own alias scope, so that its CTE bodies are analysed and its aliases are not reported as
+     * tables. NOTE(review): this presumes the parser hangs the WITH clause of e.g.
+     * {@code WITH t AS (...) SELECT ... UNION SELECT ...} on the enclosing node rather than on the
+     * first plain select - confirm against the JSqlParser version in use.
+     */
+    private static void accTables(Select select, Set<Table> tables, Stack<Set<String>> cteAliasScopes) {
+        if (select instanceof PlainSelect) {
+            // Opens its own alias scope, after its parseability checks.
+            accTables(select.getPlainSelect(), tables, cteAliasScopes);
+            return;
+        }
+        if (!(select instanceof ParenthesedSelect) && !(select instanceof SetOperationList)) {
+            // We don't support more complex kinds of select statements yet.
+            throw new AnalysisError(AnalysisErrorType.UNSUPPORTED_EXPRESSION);
+        }
+
+        cteAliasScopes.push(new HashSet<>());
+        try {
+            accWithItems(select, tables, cteAliasScopes);
+            if (select instanceof ParenthesedSelect) {
+                accTables(((ParenthesedSelect) select).getSelect(), tables, cteAliasScopes);
+            } else {
+                for (Select innerSelect : ((SetOperationList) select).getSelects()) {
+                    accTables(innerSelect, tables, cteAliasScopes);
+                }
+            }
+        } finally {
+            cteAliasScopes.pop();
+        }
+    }
+
+    /**
+     * Analyses the CTE bodies of {@code select}'s WITH clause, if any, and records each CTE alias in
+     * the innermost scope of {@code cteAliasScopes}.
+     */
+    private static void accWithItems(Select select, Set<Table> tables, Stack<Set<String>> cteAliasScopes) {
+        if (select.getWithItemsList() == null) {
+            return;
+        }
+        final Set<String> cteAliases = cteAliasScopes.peek();
+        for (WithItem withItem : select.getWithItemsList()) {
+            String alias = withItem.getAlias().getName();
+            if (withItem.isRecursive()) {
+                // A recursive CTE may refer to itself, so its alias must be known up front.
+                cteAliases.add(alias);
+            }
+            accTables(withItem.getSelect(), tables, cteAliasScopes);
+            // No harm in adding twice to a set
+            cteAliases.add(alias);
+        }
+    }
+
+    /**
+     * Accumulates the tables of a plain select: its CTE bodies, sub-selects among its select items,
+     * and its FROM and JOIN items.
+     */
+    private static void accTables(PlainSelect select, Set<Table> tables, Stack<Set<String>> cteAliasScopes) {
+        // these are fine, but irrelevant
+        //
+        // - select.getDistinct()
+        // - select.getHaving()
+        // - select.getKsqlWindow()
+        // - select.getMySqlHintStraightJoin()
+        // - select.getMySqlSqlCacheFlag()
+        // - select.getLimitBy()
+        // - select.getOffset()
+        // - select.getOptimizeFor()
+        // - select.getOracleHint()
+        // - select.getTop()
+        // - select.getWait()
+
+        // Not currently parseable
+        if (select.getLateralViews() != null ||
+                select.getOracleHierarchical() != null ||
+                select.getWindowDefinitions() != null) {
+            throw new AnalysisError(AnalysisErrorType.UNABLE_TO_PARSE);
+        }
+
+        cteAliasScopes.push(new HashSet<>());
+        try {
+            accWithItems(select, tables, cteAliasScopes);
+
+            // any of these - nope out
+            if (select.getFetch() != null ||
+                    select.getFirst() != null ||
+                    select.getForClause() != null ||
+                    select.getForMode() != null ||
+                    select.getForUpdateTable() != null ||
+                    select.getForXmlPath() != null ||
+                    select.getIntoTables() != null ||
+                    select.getIsolation() != null ||
+                    select.getSkip() != null) {
+                throw new AnalysisError(AnalysisErrorType.INVALID_QUERY);
+            }
+
+            for (SelectItem<?> item : select.getSelectItems()) {
+                Expression expr = item.getExpression();
+                if (expr instanceof Select) {
+                    accTables((Select) expr, tables, cteAliasScopes);
+                }
+            }
+
+            if (select.getFromItem() != null) {
+                Consumer<FromItem> visit = item -> accFromItem(item, tables, cteAliasScopes);
+                visit.accept(select.getFromItem());
+                List<Join> joins = select.getJoins();
+                if (joins != null) {
+                    joins.stream().map(Join::getFromItem).forEach(visit);
+                }
+            }
+        } finally {
+            // Keep the scope stack balanced even when analysis aborts part-way through.
+            cteAliasScopes.pop();
+        }
+    }
+
+    /**
+     * Accumulates the tables contributed by a single FROM or JOIN item.
+     *
+     * @throws AnalysisError for table wildcards, table functions, and any other unsupported item
+     */
+    private static void accFromItem(FromItem item, Set<Table> tables, Stack<Set<String>> cteAliasScopes) {
+        if (item instanceof Table) {
+            Table table = (Table) item;
+            String tableName = table.getName();
+            // Skip aliases
+            if (cteAliasScopes.stream().noneMatch(scope -> scope.contains(tableName))) {
+                if (tableName.contains("*")) {
+                    // Do not allow table wildcards.
+                    throw new AnalysisError(AnalysisErrorType.INVALID_QUERY);
+                }
+                tables.add(table);
+            }
+        } else if (item instanceof TableFunction) {
+            // Do not allow dynamic tables
+            throw new AnalysisError(AnalysisErrorType.INVALID_QUERY);
+        } else if (item instanceof Select) {
+            accTables((Select) item, tables, cteAliasScopes);
+        } else if (item != null) {
+            // Only allow simple table references.
+            throw new AnalysisError(AnalysisErrorType.UNSUPPORTED_EXPRESSION);
+        }
+    }
+
+}
diff --git a/src/macaw/core.clj b/src/macaw/core.clj
index 2e5d05f..6f4b92a 100644
--- a/src/macaw/core.clj
+++ b/src/macaw/core.clj
@@ -5,7 +5,7 @@
[macaw.collect :as collect]
[macaw.rewrite :as rewrite])
(:import
- (com.metabase.macaw AstWalker$Scope BasicTableExtractor BasicTableExtractor$AnalysisError)
+ (com.metabase.macaw AstWalker$Scope BasicTableExtractor AnalysisError CompoundTableExtractor)
(java.util.function Consumer)
(net.sf.jsqlparser JSQLParserException)
(net.sf.jsqlparser.parser CCJSqlParser CCJSqlParserUtil)
@@ -93,14 +93,14 @@
:table (.getName t)}
{:table (.getName t)}))
-(defn- ->macaw-error [^BasicTableExtractor$AnalysisError analysis-error]
+(defn- ->macaw-error [^AnalysisError analysis-error]
(keyword "macaw.error" (-> (.-errorType analysis-error)
str/lower-case
(str/replace #"_" "-"))))
(defmacro ^:private kw-or-tables [expr]
- `(try (map table->identifier ~expr)
- (catch BasicTableExtractor$AnalysisError e#
+ `(try (set (map table->identifier ~expr))
+ (catch AnalysisError e#
(->macaw-error e#))
(catch JSQLParserException _e#
:macaw.error/unable-to-parse)))
@@ -108,14 +108,13 @@
(defn query->tables
"Given a parsed query (i.e., a [subclass of] `Statement`) return a set of all the table identifiers found within it."
[sql & {:keys [mode] :as opts}]
- (case mode
- :ast-walker-1 (-> (parsed-query sql opts)
- (query->components opts)
- :tables
- raw-components)
- :basic-select (->> (parsed-query sql opts)
- (BasicTableExtractor/getTables)
- kw-or-tables)))
+ ;; We delay parsing so that kw-or-tables is able to catch exceptions.
+ ;; This will no longer be necessary when we update :ast-walker-1 to catch exceptions too.
+ (let [query (delay (parsed-query sql opts))]
+ (case mode
+ :ast-walker-1 (-> (query->components @query opts) :tables raw-components)
+ :basic-select (-> (BasicTableExtractor/getTables @query) kw-or-tables)
+ :compound-select (-> (CompoundTableExtractor/getTables @query) kw-or-tables))))
(defn replace-names
"Given an SQL query, apply the given table, column, and schema renames.
diff --git a/test/macaw/acceptance_test.clj b/test/macaw/acceptance_test.clj
index df12cc9..0b6ab33 100644
--- a/test/macaw/acceptance_test.clj
+++ b/test/macaw/acceptance_test.clj
@@ -37,15 +37,26 @@
(def ^:private test-modes
#{:ast-walker-1
- :basic-select})
+ :basic-select
+ :compound-select})
+
+(def override-hierarchy
+ (-> (make-hierarchy)
+ (derive :basic-select :select-only)
+ (derive :compound-select :select-only)))
+
+;; Returns a seq of `k` followed by its chain of ancestors in hierarchy `h`, nearest first,
+;; or nil when `k` is nil. Each step follows the single parent of the current key; the assert
+;; rejects hierarchies in which a key on the chain has more than one parent.
+(defn- lineage [h k]
+ (when k
+ (assert (<= (count (parents h k)) 1) "Multiple inheritance not supported for override hierarchy.")
+ (cons k (lineage h (first (parents h k))))))
(def global-overrides
{})
(def ns-overrides
- {:basic-select {"compound" :macaw.error/unsupported-expression
- "mutation" :macaw.error/invalid-query
- "dynamic" :macaw.error/invalid-query}})
+ {:select-only {"mutation" :macaw.error/invalid-query
+ "dynamic" :macaw.error/invalid-query}
+ :basic-select {"compound" :macaw.error/unsupported-expression}})
(def ^:private merged-fixtures-file "test/resources/acceptance/queries.sql")
@@ -76,15 +87,21 @@
(when (keyword? x)
x))
-(defn- get-override [expected-cs mode fixture ck]
+(defn- get-override* [expected-cs mode fixture ck]
(or (get global-overrides mode)
(get-in ns-overrides [mode (namespace fixture)])
(get-in expected-cs [:overrides mode :error])
- (get-in expected-cs [:overrides :error])
(get-in expected-cs [:overrides mode ck])
- (get-in expected-cs [:overrides ck])
- (when-keyword (get-in expected-cs [:overrides mode]))
- (when-keyword (get expected-cs :overrides))))
+ (when-keyword (get-in expected-cs [:overrides mode]))))
+
+;; Resolves the override that applies to component `ck` of `fixture` under `mode`.
+;; Mode-specific overrides are tried first, for `mode` and then for each of its ancestors in
+;; `override-hierarchy` (nearest first). If none applies, falls back to the mode-independent
+;; entries of `expected-cs`: `[:overrides :error]`, then `[:overrides ck]`, then `:overrides`
+;; itself when it is a bare keyword. Returns nil when nothing matches.
+(defn- get-override [expected-cs mode fixture ck]
+ (or
+ (some #(get-override* expected-cs % fixture ck)
+ (lineage override-hierarchy mode))
+
+ (get-in expected-cs [:overrides :error])
+ (get-in expected-cs [:overrides ck])
+ (when-keyword (get expected-cs :overrides))))
(defn- test-fixture
"Test that we can parse a given fixture, and compare against expected analysis and rewrites, where they are defined."
@@ -95,7 +112,7 @@
expected-cs (fixture-analysis fixture)
renames (fixture-renames fixture)
expected-rw (fixture-rewritten fixture)
- base-opts {:non-reserved-words [:final]}
+ base-opts {:non-reserved-words [:final], :allow-unused? true}
opts-mode (fn [mode] (assoc base-opts :mode mode))]
(assert sql "Fixture exists")
(doseq [m test-modes
@@ -107,7 +124,7 @@
(is (thrown-with-msg? Exception expected-msg (ct/components sql opts))))
(let [cs (testing (str prefix " analysis does not throw")
(is (ct/components sql opts)))]
- (doseq [[ck cv] (dissoc expected-cs :overrides :error)]
+ (doseq [[ck cv] (dissoc expected-cs :overrides :error :skip)]
(testing (str prefix " analysis is correct: " (name ck))
(let [actual-cv (get-component cs ck)
override (get-override expected-cs m fixture ck)]
@@ -118,7 +135,9 @@
;; For now, we only support (and test) :tables
tables (testing (str prefix " table analysis does not throw for mode " m)
(is (ct/tables sql opts)))]
- (when-not (and (nil? correct) (nil? override))
+ (if (and (nil? correct) (nil? override))
+ (testing "Must define expected tables, or explicitly skip analysis"
+ (is (:skip expected-cs)))
(testing (str prefix " table analysis is correct for mode " m)
(validate-analysis correct override tables))))))
@@ -161,7 +180,7 @@
(cons 'do
(for [f fixtures
:let [test-name (symbol (str/replace (ct/fixture->filename f "-test") #"(?