Skip to content

Commit

Permalink
Record all table references in schemaStatVisitor.
Browse files Browse the repository at this point in the history
  • Loading branch information
lingo-xp authored and wenshao committed Dec 30, 2024
1 parent 7040893 commit c83288e
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -36,13 +36,15 @@
import com.alibaba.druid.stat.TableStat.Mode;
import com.alibaba.druid.stat.TableStat.Relationship;
import com.alibaba.druid.util.FnvHash;
import org.apache.commons.lang3.tuple.Pair;

import java.util.*;

public class SchemaStatVisitor extends SQLASTVisitorAdapter {
protected SchemaRepository repository;

protected final List<SQLName> originalTables = new ArrayList<SQLName>();
protected final List<Pair<SQLName, String>> tableReferences = new ArrayList<>();

protected final HashMap<TableStat.Name, TableStat> tableStats = new LinkedHashMap<TableStat.Name, TableStat>();
protected final Map<Long, Column> columns = new LinkedHashMap<Long, Column>();
Expand Down Expand Up @@ -123,6 +125,10 @@ public TableStat getTableStat(String tableName) {
return stat;
}

public List<Pair<SQLName, String>> getTableReferences() {
return tableReferences;
}

public TableStat getTableStat(SQLName tableName) {
String strName;
if (tableName instanceof SQLIdentifierExpr) {
Expand Down Expand Up @@ -1972,6 +1978,11 @@ public TableStat getTableStat(SQLExprTableSource tableSource) {
tableSource.getExpr());
}

protected void recordTableReference(SQLExprTableSource x) {
if (x.getExpr() instanceof SQLName) {
tableReferences.add(Pair.of(((SQLName) x.getExpr()), x.getAlias()));
}
}
protected TableStat getTableStatWithUnwrap(SQLExpr expr) {
SQLExpr identExpr = null;

Expand Down Expand Up @@ -2023,6 +2034,7 @@ public boolean visit(SQLExprTableSource x) {
}

if (isSimpleExprTableSource(x)) {
recordTableReference(x);
TableStat stat = getTableStatWithUnwrap(expr);
if (stat == null) {
return false;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
package com.alibaba.druid.benckmark.sql;

import java.util.List;

import com.alibaba.druid.DbType;
import com.alibaba.druid.sql.SQLUtils;
import com.alibaba.druid.sql.ast.SQLName;
import com.alibaba.druid.sql.ast.SQLStatement;
import com.alibaba.druid.sql.parser.SQLParserUtils;
import com.alibaba.druid.sql.parser.SQLStatementParser;
import com.alibaba.druid.sql.visitor.SchemaStatVisitor;
import org.apache.commons.lang3.tuple.Pair;
import org.junit.Assert;
import org.junit.Test;

public class SchemaStatVisitorTest {
@Test
public void testGetTableReferences() {
String sql =
"CREATE TEMP TABLE temp_participant_log AS (\n"
+ "WITH\n"
+ "transform AS (\n"
+ " SELECT\n"
+ " `godata-platform.udfs.standardRule`(status, ['cleanup']) as status_name\n"
+ " FROM patch_source\n"
+ ")\n"
+ "SELECT\n"
+ " event.order_no,\n"
+ " event.status_name\n"
+ " FROM\n"
+ " (\n"
+ " SELECT\n"
+ " ARRAY_AGG(\n"
+ " table ORDER BY event_timestamp DESC LIMIT 1\n"
+ " )[OFFSET(0)] event\n"
+ " FROM\n"
+ " transform table\n"
+ " GROUP BY\n"
+ " order_no, status_name, bid_id, iteration_number, participant_id, participant_uuid\n"
+ " )\n"
+ ");";
SQLStatementParser parser = SQLParserUtils.createSQLStatementParser(sql, DbType.bigquery);
SchemaStatVisitor schemaStatVisitor = new SchemaStatVisitor(DbType.bigquery);
SQLStatement stmt = parser.parseStatement();
stmt.accept(schemaStatVisitor);
List<Pair<SQLName, String>> tableReferences = schemaStatVisitor.getTableReferences();
Assert.assertEquals("patch_source", tableReferences.get(0).getKey().getSimpleName());
Assert.assertEquals("transform", tableReferences.get(1).getKey().getSimpleName());
Assert.assertEquals("table", tableReferences.get(1).getValue());
}
}

0 comments on commit c83288e

Please sign in to comment.