Skip to content

Commit

Permalink
More advanced table extractor to handle compound queries (#107)
Browse files Browse the repository at this point in the history
This makes things much more useful.

Wasn't as bad as I expected!
  • Loading branch information
crisptrutski authored Oct 25, 2024
1 parent 41f6132 commit 791ab33
Show file tree
Hide file tree
Showing 18 changed files with 322 additions and 74 deletions.
16 changes: 16 additions & 0 deletions java/com/metabase/macaw/AnalysisError.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
package com.metabase.macaw;

public class AnalysisError extends RuntimeException {
public final AnalysisErrorType errorType;
public final Throwable cause;

public AnalysisError(AnalysisErrorType errorType) {
this.errorType = errorType;
this.cause = null;
}

public AnalysisError(AnalysisErrorType errorType, Throwable cause) {
this.errorType = errorType;
this.cause = cause;
}
}
7 changes: 7 additions & 0 deletions java/com/metabase/macaw/AnalysisErrorType.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
package com.metabase.macaw;

public enum AnalysisErrorType {
UNSUPPORTED_EXPRESSION,
INVALID_QUERY,
UNABLE_TO_PARSE
}
52 changes: 13 additions & 39 deletions java/com/metabase/macaw/BasicTableExtractor.java
Original file line number Diff line number Diff line change
Expand Up @@ -7,54 +7,28 @@
import java.util.*;
import java.util.function.Consumer;

/**
* Return a simplified query representation we can work with further, if possible.
*/
@SuppressWarnings({
"PatternVariableCanBeUsed", "IfCanBeSwitch"} // don't force a newer JVM version
)
@SuppressWarnings({"PatternVariableCanBeUsed", "IfCanBeSwitch"}) // don't force a newer JVM version
public final class BasicTableExtractor {

public enum ErrorType {
UNSUPPORTED_EXPRESSION,
INVALID_QUERY,
UNABLE_TO_PARSE
}

public static class AnalysisError extends RuntimeException {
public final ErrorType errorType;
public final Throwable cause;

public AnalysisError(ErrorType errorType) {
this.errorType = errorType;
this.cause = null;
}

public AnalysisError(ErrorType errorType, Throwable cause) {
this.errorType = errorType;
this.cause = cause;
}
}

public static Set<Table> getTables(Statement statement) {
try {
if (statement instanceof Select) {
return getTables((Select) statement);
}
// This is not a query, it's probably a statement.
throw new AnalysisError(ErrorType.INVALID_QUERY);
throw new AnalysisError(AnalysisErrorType.INVALID_QUERY);
} catch (IllegalArgumentException e) {
// This query uses features that we do not yet support translating.
throw new AnalysisError(ErrorType.UNABLE_TO_PARSE, e);
throw new AnalysisError(AnalysisErrorType.UNABLE_TO_PARSE, e);
}
}

private static Set<Table> getTables(Select select) {
if (select instanceof PlainSelect) {
if (select instanceof PlainSelect) {
return getTables(select.getPlainSelect());
} else {
// We don't support more complex kinds of select statements yet.
throw new AnalysisError(ErrorType.UNSUPPORTED_EXPRESSION);
throw new AnalysisError(AnalysisErrorType.UNSUPPORTED_EXPRESSION);
}
}

Expand All @@ -77,12 +51,12 @@ private static Set<Table> getTables(PlainSelect select) {
if (select.getLateralViews() != null ||
select.getOracleHierarchical() != null ||
select.getWindowDefinitions() != null) {
throw new AnalysisError(ErrorType.UNABLE_TO_PARSE);
throw new AnalysisError(AnalysisErrorType.UNABLE_TO_PARSE);
}

// Not currently supported
if (select.getWithItemsList() != null) {
throw new AnalysisError(ErrorType.UNSUPPORTED_EXPRESSION);
throw new AnalysisError(AnalysisErrorType.UNSUPPORTED_EXPRESSION);
}

// any of these - nope out
Expand All @@ -95,15 +69,15 @@ private static Set<Table> getTables(PlainSelect select) {
select.getIntoTables() != null ||
select.getIsolation() != null ||
select.getSkip() != null) {
throw new AnalysisError(ErrorType.INVALID_QUERY);
throw new AnalysisError(AnalysisErrorType.INVALID_QUERY);
}

Set<Table> tables = new HashSet<>();

for (SelectItem item : select.getSelectItems()) {
for (SelectItem<?> item : select.getSelectItems()) {
if (item.getExpression() instanceof ParenthesedSelect) {
// Do not allow sub-selects.
throw new AnalysisError(ErrorType.UNSUPPORTED_EXPRESSION);
throw new AnalysisError(AnalysisErrorType.UNSUPPORTED_EXPRESSION);
}
}

Expand All @@ -112,15 +86,15 @@ private static Set<Table> getTables(PlainSelect select) {
Table table = (Table) item;
if (table.getName().contains("*")) {
// Do not allow table wildcards.
throw new AnalysisError(ErrorType.INVALID_QUERY);
throw new AnalysisError(AnalysisErrorType.INVALID_QUERY);
}
tables.add(table);
} else if (item instanceof TableFunction) {
// Do not allow dynamic tables
throw new AnalysisError(ErrorType.INVALID_QUERY);
throw new AnalysisError(AnalysisErrorType.INVALID_QUERY);
} else {
// Only allow simple table references.
throw new AnalysisError(ErrorType.UNSUPPORTED_EXPRESSION);
throw new AnalysisError(AnalysisErrorType.UNSUPPORTED_EXPRESSION);
}
};

Expand Down
148 changes: 148 additions & 0 deletions java/com/metabase/macaw/CompoundTableExtractor.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
package com.metabase.macaw;

import net.sf.jsqlparser.expression.Expression;
import net.sf.jsqlparser.schema.Table;
import net.sf.jsqlparser.statement.Statement;
import net.sf.jsqlparser.statement.select.*;

import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.Stack;
import java.util.function.Consumer;

@SuppressWarnings({"PatternVariableCanBeUsed", "IfCanBeSwitch"}) // don't force a newer JVM version
public final class CompoundTableExtractor {

/**
* BEWARE this may return duplicates if the same table is referenced multiple times in the query.
* We need to deduplicate these with value semantics.
* A better solution would be to have our own TableValue class to convert to, before accumulating.
*/
public static Set<Table> getTables(Statement statement) {
try {
if (statement instanceof Select) {
return accTables((Select) statement);
}
// This is not a query, it's probably a statement.
throw new AnalysisError(AnalysisErrorType.INVALID_QUERY);
} catch (IllegalArgumentException e) {
// This query uses features that we do not yet support translating.
throw new AnalysisError(AnalysisErrorType.UNABLE_TO_PARSE, e);
}
}

private static Set<Table> accTables(Select select) {
Set<Table> tables = new HashSet<>();
Stack<Set<String>> cteAliasScopes = new Stack<>();
accTables(select, tables, cteAliasScopes);
return tables;
}

private static void accTables(Select select, Set<Table> tables, Stack<Set<String>> cteAliasScopes) {
if (select instanceof PlainSelect) {
accTables(select.getPlainSelect(), tables, cteAliasScopes);
} else if (select instanceof ParenthesedSelect) {
accTables(((ParenthesedSelect) select).getSelect(), tables, cteAliasScopes);
} else if (select instanceof SetOperationList) {
for (Select innerSelect : ((SetOperationList) select).getSelects()) {
accTables(innerSelect, tables, cteAliasScopes);
}
} else {
// We don't support more complex kinds of select statements yet.
throw new AnalysisError(AnalysisErrorType.UNSUPPORTED_EXPRESSION);
}
}

private static void accTables(PlainSelect select, Set<Table> tables, Stack<Set<String>> cteAliasScopes) {
// these are fine, but irrelevant
//
// - select.getDistinct()
// - select.getHaving()
// - select.getKsqlWindow()
// - select.getMySqlHintStraightJoin()
// - select.getMySqlSqlCacheFlag()
// - select.getLimitBy()
// - select.getOffset()
// - select.getOptimizeFor()
// - select.getOracleHint()
// - select.getTop()
// - select.getWait()

// Not currently parseable
if (select.getLateralViews() != null ||
select.getOracleHierarchical() != null ||
select.getWindowDefinitions() != null) {
throw new AnalysisError(AnalysisErrorType.UNABLE_TO_PARSE);
}

final Set<String> cteAliases = new HashSet<>();
cteAliasScopes.push(cteAliases);

if (select.getWithItemsList() != null) {
for (WithItem withItem : select.getWithItemsList()) {
if (withItem.isRecursive()) {
cteAliases.add(withItem.getAlias().getName());
}
accTables(withItem.getSelect(), tables, cteAliasScopes);
// No hard in adding twice to a set
cteAliases.add(withItem.getAlias().getName());
}
}

// any of these - nope out
if (select.getFetch() != null ||
select.getFirst() != null ||
select.getForClause() != null ||
select.getForMode() != null ||
select.getForUpdateTable() != null ||
select.getForXmlPath() != null ||
select.getIntoTables() != null ||
select.getIsolation() != null ||
select.getSkip() != null) {
throw new AnalysisError(AnalysisErrorType.INVALID_QUERY);
}

for (SelectItem<?> item : select.getSelectItems()) {
Expression expr = item.getExpression();
if (expr instanceof Select) {
accTables((Select) expr, tables, cteAliasScopes);
}
}

Consumer<FromItem> pushOrThrow = (FromItem item) -> {
if (item instanceof Table) {
Table table = (Table) item;
String tableName = table.getName();
// Skip aliases
if (cteAliasScopes.stream().noneMatch(scope -> scope.contains(tableName))) {
if (tableName.contains("*")) {
// Do not allow table wildcards.
throw new AnalysisError(AnalysisErrorType.INVALID_QUERY);
} else {
tables.add(table);
}
}
} else if (item instanceof TableFunction) {
// Do not allow dynamic tables
throw new AnalysisError(AnalysisErrorType.INVALID_QUERY);
} else if (item instanceof Select) {
accTables((Select) item, tables, cteAliasScopes);
} else if (item != null) {
// Only allow simple table references.
throw new AnalysisError(AnalysisErrorType.UNSUPPORTED_EXPRESSION);
}
};

if (select.getFromItem() != null) {
pushOrThrow.accept(select.getFromItem());
List<Join> joins = select.getJoins();
if (joins != null) {
joins.stream().map(Join::getFromItem).forEach(pushOrThrow);
}
}

cteAliasScopes.pop();
}

}
23 changes: 11 additions & 12 deletions src/macaw/core.clj
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
[macaw.collect :as collect]
[macaw.rewrite :as rewrite])
(:import
(com.metabase.macaw AstWalker$Scope BasicTableExtractor BasicTableExtractor$AnalysisError)
(com.metabase.macaw AstWalker$Scope BasicTableExtractor AnalysisError CompoundTableExtractor)
(java.util.function Consumer)
(net.sf.jsqlparser JSQLParserException)
(net.sf.jsqlparser.parser CCJSqlParser CCJSqlParserUtil)
Expand Down Expand Up @@ -93,29 +93,28 @@
:table (.getName t)}
{:table (.getName t)}))

(defn- ->macaw-error [^BasicTableExtractor$AnalysisError analysis-error]
(defn- ->macaw-error [^AnalysisError analysis-error]
(keyword "macaw.error" (-> (.-errorType analysis-error)
str/lower-case
(str/replace #"_" "-"))))

(defmacro ^:private kw-or-tables [expr]
`(try (map table->identifier ~expr)
(catch BasicTableExtractor$AnalysisError e#
`(try (set (map table->identifier ~expr))
(catch AnalysisError e#
(->macaw-error e#))
(catch JSQLParserException _e#
:macaw.error/unable-to-parse)))

(defn query->tables
"Given a parsed query (i.e., a [subclass of] `Statement`) return a set of all the table identifiers found within it."
[sql & {:keys [mode] :as opts}]
(case mode
:ast-walker-1 (-> (parsed-query sql opts)
(query->components opts)
:tables
raw-components)
:basic-select (->> (parsed-query sql opts)
(BasicTableExtractor/getTables)
kw-or-tables)))
;; We delay parsing so that kw-or-tables is able to catch exceptions.
;; This will no longer be necessary when we update :ast-walker-1 to catch exceptions too.
(let [query (delay (parsed-query sql opts))]
(case mode
:ast-walker-1 (-> (query->components @query opts) :tables raw-components)
:basic-select (-> (BasicTableExtractor/getTables @query) kw-or-tables)
:compound-select (-> (CompoundTableExtractor/getTables @query) kw-or-tables))))

(defn replace-names
"Given an SQL query, apply the given table, column, and schema renames.
Expand Down
Loading

0 comments on commit 791ab33

Please sign in to comment.