Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: minimum collect candidates boundary to fix parse performance #378

Merged
merged 3 commits into from
Dec 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions scripts/benchmark.js
Original file line number Diff line number Diff line change
Expand Up @@ -43,15 +43,19 @@ function checkVersion() {
if (semver.lt(currentVersion, MIN_VERSION)) {
console.error(
chalk.bold.red(
`Current Node.js version (v${currentVersion}) is lower than required version (v${semver.major(MIN_VERSION)}.x)`
`Current Node.js version (v${currentVersion}) is lower than required version (v${semver.major(
MIN_VERSION
)}.x)`
)
);
return false;
} else {
if (isRelease && semver.lt(currentVersion, RELEASE_VERSION)) {
console.error(
chalk.bold.red(
`Node.js version v${semver.major(RELEASE_VERSION)}.x+ is required for release benchmark!`
`Node.js version v${semver.major(
RELEASE_VERSION
)}.x+ is required for release benchmark!`
)
);
return false;
Expand Down Expand Up @@ -81,7 +85,9 @@ function prompt() {
'Cold start' +
(isNodeVersionOk
? ''
: ` (Only supported on Node.js v${semver.major(RECOMMENDED_VERSION)}.x+)`),
: ` (Only supported on Node.js v${semver.major(
RECOMMENDED_VERSION
)}.x+)`),
value: 'cold',
disabled: !isNodeVersionOk,
},
Expand Down
94 changes: 69 additions & 25 deletions src/parser/common/basicSQL.ts
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,13 @@ export abstract class BasicSQL<
return this._parseErrors;
}

/**
 * Get the input string that has been parsed.
 * @returns the current value of `_parsedInput`
 * (presumably the input most recently handed to the parser — confirm where `_parsedInput` is assigned)
 */
public getParsedInput(): string {
return this._parsedInput;
}

/**
* Get all tokens of the input string; '<EOF>' is not included.
* @param input source string
Expand Down Expand Up @@ -252,35 +259,35 @@ export abstract class BasicSQL<
}

/**
* Get suggestions of syntax and token at caretPosition
* @param input source string
* @param caretPosition caret position, such as cursor position
* @returns suggestion
* Get a parser covering the minimum boundary around tokenIndex.
* @param input source string.
* @param tokenIndex the token index around which the boundary is minimized.
* @param originParseTree the parse tree that needs to be minimized; defaults to the result of parsing `input`.
* @returns minimum parser info
*/
public getSuggestionAtCaretPosition(
public getMinimumParserInfo(
input: string,
caretPosition: CaretPosition
): Suggestions | null {
const splitListener = this.splitListener;

this.parseWithCache(input);
if (!this._parseTree) return null;

let sqlParserIns = this._parser;
const allTokens = this.getAllTokens(input);
let caretTokenIndex = findCaretTokenIndex(caretPosition, allTokens);
let c3Context: ParserRuleContext = this._parseTree;
let tokenIndexOffset: number = 0;
tokenIndex: number,
originParseTree?: ParserRuleContext | null
) {
if (arguments.length <= 2) {
this.parseWithCache(input);
originParseTree = this._parseTree;
}

if (!caretTokenIndex && caretTokenIndex !== 0) return null;
if (!originParseTree || !input?.length) return null;

const splitListener = this.splitListener;
/**
* Split sql by statement.
* Try to collect candidates in as small a range as possible.
*/
this.listen(splitListener, this._parseTree);
this.listen(splitListener, originParseTree);
const statementCount = splitListener.statementsContext?.length;
const statementsContext = splitListener.statementsContext;
let tokenIndexOffset = 0;
let sqlParserIns = this._parser;
let parseTree = originParseTree;

// If there are multiple statements.
if (statementCount > 1) {
Expand All @@ -305,14 +312,14 @@ export abstract class BasicSQL<
const isNextCtxValid =
index === statementCount - 1 || !statementsContext[index + 1]?.exception;

if (ctx.stop && ctx.stop.tokenIndex < caretTokenIndex && isPrevCtxValid) {
if (ctx.stop && ctx.stop.tokenIndex < tokenIndex && isPrevCtxValid) {
startStatement = ctx;
}

if (
ctx.start &&
!stopStatement &&
ctx.start.tokenIndex > caretTokenIndex &&
ctx.start.tokenIndex > tokenIndex &&
isNextCtxValid
) {
stopStatement = ctx;
Expand All @@ -329,7 +336,7 @@ export abstract class BasicSQL<
* compared to the tokenIndex in the whole input
*/
tokenIndexOffset = startStatement?.start?.tokenIndex ?? 0;
caretTokenIndex = caretTokenIndex - tokenIndexOffset;
tokenIndex = tokenIndex - tokenIndexOffset;

/**
* Reparse the input fragment,
Expand All @@ -349,17 +356,54 @@ export abstract class BasicSQL<
parser.errorHandler = new ErrorStrategy();

sqlParserIns = parser;
c3Context = parser.program();
parseTree = parser.program();
}

return {
parser: sqlParserIns,
parseTree,
tokenIndexOffset,
newTokenIndex: tokenIndex,
};
}

/**
* Get suggestions of syntax and token at caretPosition
* @param input source string
* @param caretPosition caret position, such as cursor position
* @returns suggestion
*/
public getSuggestionAtCaretPosition(
input: string,
caretPosition: CaretPosition
): Suggestions | null {
this.parseWithCache(input);

if (!this._parseTree) return null;

const allTokens = this.getAllTokens(input);
let caretTokenIndex = findCaretTokenIndex(caretPosition, allTokens);

if (!caretTokenIndex && caretTokenIndex !== 0) return null;

const minimumParser = this.getMinimumParserInfo(input, caretTokenIndex);

if (!minimumParser) return null;

const {
parser: sqlParserIns,
tokenIndexOffset,
newTokenIndex,
parseTree: c3Context,
} = minimumParser;
const core = new CodeCompletionCore(sqlParserIns);
core.preferredRules = this.preferredRules;

const candidates = core.collectCandidates(caretTokenIndex, c3Context);
const candidates = core.collectCandidates(newTokenIndex, c3Context);
const originalSuggestions = this.processCandidates(
candidates,
allTokens,
caretTokenIndex,
newTokenIndex,
tokenIndexOffset
);

Expand Down
17 changes: 13 additions & 4 deletions src/parser/common/parseErrorListener.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ import {
InputMismatchException,
NoViableAltException,
} from 'antlr4ng';
import { LOCALE_TYPE } from './types';
import { transform } from './transform';
import { BasicSQL } from './basicSQL';

/**
* Converted from {@link SyntaxError}.
Expand Down Expand Up @@ -48,10 +48,19 @@ export type ErrorListener = (parseError: ParseError, originalError: SyntaxError)

export abstract class ParseErrorListener implements ANTLRErrorListener {
private _errorListener: ErrorListener;
private locale: LOCALE_TYPE;
protected preferredRules: Set<number>;
protected get locale() {
return this.parserContext.locale;
}
protected parserContext: BasicSQL;

constructor(errorListener: ErrorListener, locale: LOCALE_TYPE = 'en_US') {
this.locale = locale;
constructor(
errorListener: ErrorListener,
parserContext: BasicSQL,
preferredRules: Set<number>
) {
this.parserContext = parserContext;
this.preferredRules = preferredRules;
this._errorListener = errorListener;
}

Expand Down
31 changes: 20 additions & 11 deletions src/parser/flink/flinkErrorListener.ts
Original file line number Diff line number Diff line change
@@ -1,12 +1,9 @@
import { CodeCompletionCore } from 'antlr4-c3';
import { ErrorListener, ParseErrorListener } from '../common/parseErrorListener';
import { ParseErrorListener } from '../common/parseErrorListener';
import { Parser, Token } from 'antlr4ng';
import { FlinkSqlParser } from '../../lib/flink/FlinkSqlParser';
import { LOCALE_TYPE } from '../common/types';

export class FlinkErrorListener extends ParseErrorListener {
private preferredRules: Set<number>;

private objectNames: Map<number, string> = new Map([
[FlinkSqlParser.RULE_catalogPath, 'catalog'],
[FlinkSqlParser.RULE_catalogPathCreate, 'catalog'],
Expand All @@ -22,22 +19,34 @@ export class FlinkErrorListener extends ParseErrorListener {
[FlinkSqlParser.RULE_columnNameCreate, 'column'],
]);

constructor(errorListener: ErrorListener, preferredRules: Set<number>, locale: LOCALE_TYPE) {
super(errorListener, locale);
this.preferredRules = preferredRules;
}

public getExpectedText(parser: Parser, token: Token) {
let expectedText = '';
const input = this.parserContext.getParsedInput();

/**
* Get the program context.
* When the error listener is called, `this._parseTree` is still `undefined`,
* so we can't use the cached parseTree in `getMinimumParserInfo`.
*/
let currentContext = parser.context ?? undefined;
while (currentContext?.parent) {
currentContext = currentContext.parent;
}

const core = new CodeCompletionCore(parser);
const parserInfo = this.parserContext.getMinimumParserInfo(
input,
token.tokenIndex,
currentContext
);

if (!parserInfo) return '';

const { parser: c3Parser, newTokenIndex, parseTree: c3Context } = parserInfo;

const core = new CodeCompletionCore(c3Parser);
core.preferredRules = this.preferredRules;
const candidates = core.collectCandidates(token.tokenIndex, currentContext);

const candidates = core.collectCandidates(newTokenIndex, c3Context);

if (candidates.rules.size) {
const result: string[] = [];
Expand Down
5 changes: 3 additions & 2 deletions src/parser/flink/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,9 @@ export class FlinkSQL extends BasicSQL<FlinkSqlLexer, ProgramContext, FlinkSqlPa
return new FlinkSqlSplitListener();
}

protected createErrorListener(_errorListener: ErrorListener) {
return new FlinkErrorListener(_errorListener, this.preferredRules, this.locale);
protected createErrorListener(_errorListener: ErrorListener): FlinkErrorListener {
const parserContext = this;
return new FlinkErrorListener(_errorListener, parserContext, this.preferredRules);
}

protected createEntityCollector(input: string, caretTokenIndex?: number) {
Expand Down
31 changes: 20 additions & 11 deletions src/parser/hive/hiveErrorListener.ts
Original file line number Diff line number Diff line change
@@ -1,12 +1,9 @@
import { CodeCompletionCore } from 'antlr4-c3';
import { ErrorListener, ParseErrorListener } from '../common/parseErrorListener';
import { ParseErrorListener } from '../common/parseErrorListener';
import { Parser, Token } from 'antlr4ng';
import { HiveSqlParser } from '../../lib/hive/HiveSqlParser';
import { LOCALE_TYPE } from '../common/types';

export class HiveErrorListener extends ParseErrorListener {
private preferredRules: Set<number>;

private objectNames: Map<number, string> = new Map([
[HiveSqlParser.RULE_dbSchemaName, 'database'],
[HiveSqlParser.RULE_dbSchemaNameCreate, 'database'],
Expand All @@ -21,22 +18,34 @@ export class HiveErrorListener extends ParseErrorListener {
[HiveSqlParser.RULE_columnNameCreate, 'column'],
]);

constructor(errorListener: ErrorListener, preferredRules: Set<number>, locale: LOCALE_TYPE) {
super(errorListener, locale);
this.preferredRules = preferredRules;
}

public getExpectedText(parser: Parser, token: Token) {
let expectedText = '';
const input = this.parserContext.getParsedInput();

/**
* Get the program context.
* When the error listener is called, `this._parseTree` is still `undefined`,
* so we can't use the cached parseTree in `getMinimumParserInfo`.
*/
let currentContext = parser.context ?? undefined;
while (currentContext?.parent) {
currentContext = currentContext.parent;
}

const core = new CodeCompletionCore(parser);
const parserInfo = this.parserContext.getMinimumParserInfo(
input,
token.tokenIndex,
currentContext
);

if (!parserInfo) return '';

const { parser: c3Parser, newTokenIndex, parseTree: c3Context } = parserInfo;

const core = new CodeCompletionCore(c3Parser);
core.preferredRules = this.preferredRules;
const candidates = core.collectCandidates(token.tokenIndex, currentContext);

const candidates = core.collectCandidates(newTokenIndex, c3Context);

if (candidates.rules.size) {
const result: string[] = [];
Expand Down
5 changes: 3 additions & 2 deletions src/parser/hive/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,9 @@ export class HiveSQL extends BasicSQL<HiveSqlLexer, ProgramContext, HiveSqlParse
return new HiveSqlSplitListener();
}

protected createErrorListener(_errorListener: ErrorListener) {
return new HiveErrorListener(_errorListener, this.preferredRules, this.locale);
protected createErrorListener(_errorListener: ErrorListener): HiveErrorListener {
const parserContext = this;
return new HiveErrorListener(_errorListener, parserContext, this.preferredRules);
}

protected createEntityCollector(input: string, caretTokenIndex?: number) {
Expand Down
Loading
Loading