Skip to content

Commit

Permalink
bigquery sql parser support create model
Browse files Browse the repository at this point in the history
  • Loading branch information
wenshao committed Dec 12, 2024
1 parent 8c9aee5 commit 4087fa9
Show file tree
Hide file tree
Showing 6 changed files with 241 additions and 9 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
package com.alibaba.druid.sql.dialect.bigquery.ast;

import com.alibaba.druid.sql.ast.SQLName;
import com.alibaba.druid.sql.ast.SQLStatement;
import com.alibaba.druid.sql.ast.SQLStatementImpl;
import com.alibaba.druid.sql.ast.statement.SQLAssignItem;
import com.alibaba.druid.sql.dialect.bigquery.visitor.BigQueryVisitor;
import com.alibaba.druid.sql.visitor.SQLASTVisitor;

import java.util.ArrayList;
import java.util.List;

public class BigQueryCreateModelStatement extends SQLStatementImpl implements BigQueryObject {
private boolean ifNotExists;
private boolean replace;

private SQLName name;
private final List<SQLAssignItem> options = new ArrayList<>();
private SQLStatement trainingData;
private SQLStatement customHoliday;

public SQLName getName() {
return name;
}

public void setName(SQLName x) {
if (x != null) {
x.setParent(this);
}
this.name = x;
}

public List<SQLAssignItem> getOptions() {
return options;
}

public boolean isIfNotExists() {
return ifNotExists;
}

public void setIfNotExists(boolean ifNotExists) {
this.ifNotExists = ifNotExists;
}

public boolean isReplace() {
return replace;
}

public void setReplace(boolean replace) {
this.replace = replace;
}

public SQLStatement getTrainingData() {
return trainingData;
}

public void setTrainingData(SQLStatement x) {
if (x != null) {
x.setParent(this);
}
this.trainingData = x;
}

public SQLStatement getCustomHoliday() {
return customHoliday;
}

public void setCustomHoliday(SQLStatement x) {
if (x != null) {
x.setParent(this);
}
this.customHoliday = x;
}

@Override
public void accept0(SQLASTVisitor v) {
if (v instanceof BigQueryVisitor) {
accept0((BigQueryVisitor) v);
} else {
super.accept0(v);
}
}

@Override
public void accept0(BigQueryVisitor visitor) {
if (visitor.visit(this)) {
acceptChild(visitor, name);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import com.alibaba.druid.sql.ast.expr.SQLCharExpr;
import com.alibaba.druid.sql.ast.statement.*;
import com.alibaba.druid.sql.dialect.bigquery.ast.BigQueryAssertStatement;
import com.alibaba.druid.sql.dialect.bigquery.ast.BigQueryCreateModelStatement;
import com.alibaba.druid.sql.dialect.bigquery.ast.BigQueryExecuteImmediateStatement;
import com.alibaba.druid.sql.parser.*;
import com.alibaba.druid.util.FnvHash;
Expand Down Expand Up @@ -261,4 +262,50 @@ protected void createViewAs(SQLCreateViewStatement createView) {
}
super.createViewAs(createView);
}

@Override
protected SQLStatement parseCreateModel() {
accept(Token.CREATE);
acceptIdentifier("MODEL");

BigQueryCreateModelStatement stmt = new BigQueryCreateModelStatement();
if (lexer.nextIf(Token.IF)) {
accept(Token.NOT);
accept(Token.EXISTS);
stmt.setIfNotExists(true);
} else if (lexer.nextIf(Token.OR)) {
accept(Token.REPLACE);
stmt.setReplace(true);
}
stmt.setName(
exprParser.name()
);

if (lexer.nextIfIdentifier("OPTIONS")) {
exprParser.parseAssignItem(stmt.getOptions(), stmt);
}

if (lexer.nextIf(Token.AS)) {
accept(Token.LPAREN);
acceptIdentifier("TRAINING_DATA");
accept(Token.AS);
accept(Token.LPAREN);
stmt.setTrainingData(
parseStatement0()
);
accept(Token.RPAREN);

accept(Token.COMMA);
acceptIdentifier("CUSTOM_HOLIDAY");
accept(Token.AS);
accept(Token.LPAREN);
stmt.setCustomHoliday(
parseStatement0()
);
accept(Token.RPAREN);
accept(Token.RPAREN);
}

return stmt;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -430,4 +430,45 @@ public boolean visit(BigQueryExecuteImmediateStatement x) {
}
return false;
}

public boolean visit(BigQueryCreateModelStatement x) {
print0(ucase ? " CREATE " : " create ");
if (x.isIfNotExists()) {
print0(ucase ? "IF NOT EXISTS " : "if not exists ");
}
if (x.isReplace()) {
print0(ucase ? "OR REPLACE " : "or replace ");
}
print0(ucase ? "MODEL " : "model ");
x.getName().accept(this);
println();

incrementIndent();
println(ucase ? "OPTIONS (" : "options (");
printlnAndAccept(x.getOptions(), ",");
decrementIndent();
println();
println(')');

print0(ucase ? "AS (" : "as (");
incrementIndent();
println();

incrementIndent();
println(ucase ? "TRAINING_DATA AS (" : "training_data AS (");
x.getTrainingData().accept(this);
decrementIndent();
println();

println("),");
incrementIndent();
println(ucase ? "CUSTOM_HOLIDAY AS (" : "custom_holiday AS (");
x.getCustomHoliday().accept(this);
decrementIndent();
println();
decrementIndent();
println(")");
print0(')');
return false;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -45,4 +45,11 @@ default boolean visit(BigQueryExecuteImmediateStatement x) {

default void endVisit(BigQueryExecuteImmediateStatement x) {
}

default boolean visit(BigQueryCreateModelStatement x) {
return true;
}

default void endVisit(BigQueryCreateModelStatement x) {
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -4124,6 +4124,9 @@ public SQLStatement parseCreate() {
lexer.reset(mark);
return parseCreateTable();
}
} else if (lexer.identifierEquals(Constants.MODEL)) {
lexer.reset(mark);
return parseCreateModel();
}

SQLStatement stmt = createTableRest(mark);
Expand Down Expand Up @@ -4160,6 +4163,10 @@ public SQLStatement parseCreateScan() {
throw new ParserException("TODO " + lexer.token);
}

protected SQLStatement parseCreateModel() {
throw new ParserException("TODO " + lexer.token);
}

public SQLStatement parseCreateRole() {
accept(Token.CREATE);
acceptIdentifier("ROLE");
Expand Down Expand Up @@ -5182,6 +5189,20 @@ protected SQLAlterTableReplaceColumn parseAlterTableReplaceColumn() {
}

public SQLStatement parseStatement() {
final SQLStatement ret = parseStatement0();

if (END_TOKEN_CHECKING_ENABLED) {
checkEndToken();
}

if (lexer.nextIf(SEMI)) {
ret.setAfterSemi(true);
}

return ret;
}

protected SQLStatement parseStatement0() {
final SQLStatement ret;
if (lexer.token == Token.SELECT) {
ret = this.parseSelect();
Expand All @@ -5196,15 +5217,6 @@ public SQLStatement parseStatement() {
this.parseStatementList(list, 1, null);
ret = list.get(0);
}

if (END_TOKEN_CHECKING_ENABLED) {
checkEndToken();
}

if (lexer.nextIf(SEMI)) {
ret.setAfterSemi(true);
}

return ret;
}

Expand Down
35 changes: 35 additions & 0 deletions core/src/test/resources/bvt/parser/bigquery/0.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,38 @@
CREATE MODEL m0
OPTIONS(
MODEL_TYPE='ARIMA_PLUS'
,TIME_SERIES_TIMESTAMP_COL='partition_date'
,TIME_SERIES_DATA_COL='increment'
,SEASONALITIES=['WEEKLY']
) AS (
training_data AS (
SELECT partition_date
FROM t1
),
custom_holiday AS (
SELECT DISTINCT region
FROM t2
)
);
--------------------
CREATE MODEL m0
OPTIONS (
MODEL_TYPE = 'ARIMA_PLUS',
TIME_SERIES_TIMESTAMP_COL = 'partition_date',
TIME_SERIES_DATA_COL = 'increment',
SEASONALITIES = ['WEEKLY']
)
AS (
TRAINING_DATA AS (
SELECT partition_date
FROM t1
),
CUSTOM_HOLIDAY AS (
SELECT DISTINCT region
FROM t2
)
);
------------------------------------------------------------------------------------------------------------------------
LOOP
SET x = x + 1;
IF x >= 10 THEN
Expand Down

0 comments on commit 4087fa9

Please sign in to comment.