From be4fb90c87030f9901ecb0febac992fa57266f9d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9E=97=E6=9E=B8?= Date: Fri, 27 Dec 2024 15:29:06 +0800 Subject: [PATCH] Record all table references in schemaStatVisitor. --- .../druid/sql/visitor/SchemaStatVisitor.java | 12 +++++ .../benckmark/sql/SchemaStatVisitorTest.java | 51 +++++++++++++++++++ 2 files changed, 63 insertions(+) create mode 100644 core/src/test/java/com/alibaba/druid/benckmark/sql/SchemaStatVisitorTest.java diff --git a/core/src/main/java/com/alibaba/druid/sql/visitor/SchemaStatVisitor.java b/core/src/main/java/com/alibaba/druid/sql/visitor/SchemaStatVisitor.java index 455a708c35..240361ffdd 100644 --- a/core/src/main/java/com/alibaba/druid/sql/visitor/SchemaStatVisitor.java +++ b/core/src/main/java/com/alibaba/druid/sql/visitor/SchemaStatVisitor.java @@ -36,6 +36,7 @@ import com.alibaba.druid.stat.TableStat.Mode; import com.alibaba.druid.stat.TableStat.Relationship; import com.alibaba.druid.util.FnvHash; +import org.apache.commons.lang3.tuple.Pair; import java.util.*; @@ -43,6 +44,7 @@ public class SchemaStatVisitor extends SQLASTVisitorAdapter { protected SchemaRepository repository; protected final List originalTables = new ArrayList(); + protected final List> tableReferences = new ArrayList<>(); protected final HashMap tableStats = new LinkedHashMap(); protected final Map columns = new LinkedHashMap(); @@ -123,6 +125,10 @@ public TableStat getTableStat(String tableName) { return stat; } + public List> getTableReferences() { + return tableReferences; + } + public TableStat getTableStat(SQLName tableName) { String strName; if (tableName instanceof SQLIdentifierExpr) { @@ -1972,6 +1978,11 @@ public TableStat getTableStat(SQLExprTableSource tableSource) { tableSource.getExpr()); } + protected void recordTableReference(SQLExprTableSource x) { + if (x.getExpr() instanceof SQLName) { + tableReferences.add(Pair.of(((SQLName) x.getExpr()), x.getAlias())); + } + } protected TableStat getTableStatWithUnwrap(SQLExpr expr) { SQLExpr identExpr = null; @@ -2023,6 +2034,7 @@ public boolean visit(SQLExprTableSource x) { } if (isSimpleExprTableSource(x)) { + recordTableReference(x); TableStat stat = getTableStatWithUnwrap(expr); if (stat == null) { return false; diff --git a/core/src/test/java/com/alibaba/druid/benckmark/sql/SchemaStatVisitorTest.java b/core/src/test/java/com/alibaba/druid/benckmark/sql/SchemaStatVisitorTest.java new file mode 100644 index 0000000000..89d3bf5322 --- /dev/null +++ b/core/src/test/java/com/alibaba/druid/benckmark/sql/SchemaStatVisitorTest.java @@ -0,0 +1,51 @@ +package com.alibaba.druid.benckmark.sql; + +import java.util.List; + +import com.alibaba.druid.DbType; +import com.alibaba.druid.sql.SQLUtils; +import com.alibaba.druid.sql.ast.SQLName; +import com.alibaba.druid.sql.ast.SQLStatement; +import com.alibaba.druid.sql.parser.SQLParserUtils; +import com.alibaba.druid.sql.parser.SQLStatementParser; +import com.alibaba.druid.sql.visitor.SchemaStatVisitor; +import org.apache.commons.lang3.tuple.Pair; +import org.junit.Assert; +import org.junit.Test; + +public class SchemaStatVisitorTest { + @Test + public void testGetTableReferences() { + String sql = + "CREATE TEMP TABLE temp_participant_log AS (\n" + + "WITH\n" + + "transform AS (\n" + + " SELECT\n" + + " `godata-platform.udfs.standardRule`(status, ['cleanup']) as status_name\n" + + " FROM patch_source\n" + + ")\n" + + "SELECT\n" + + " event.order_no,\n" + + " event.status_name\n" + + " FROM\n" + + " (\n" + + " SELECT\n" + + " ARRAY_AGG(\n" + + " table ORDER BY event_timestamp DESC LIMIT 1\n" + + " )[OFFSET(0)] event\n" + + " FROM\n" + + " transform table\n" + + " GROUP BY\n" + + " order_no, status_name, bid_id, iteration_number, participant_id, participant_uuid\n" + + " )\n" + + ");"; + SQLStatementParser parser = SQLParserUtils.createSQLStatementParser(sql, DbType.bigquery); + SchemaStatVisitor schemaStatVisitor = new SchemaStatVisitor(DbType.bigquery); + SQLStatement stmt = parser.parseStatement(); + stmt.accept(schemaStatVisitor); + List> tableReferences = schemaStatVisitor.getTableReferences(); + Assert.assertEquals("patch_source", tableReferences.get(0).getKey().getSimpleName()); + Assert.assertEquals("transform", tableReferences.get(1).getKey().getSimpleName()); + Assert.assertEquals("table", tableReferences.get(1).getValue()); + } +}