diff --git a/formatter/chain.go b/formatter/chain.go new file mode 100644 index 0000000..59acf51 --- /dev/null +++ b/formatter/chain.go @@ -0,0 +1,140 @@ +package formatter + +import ( + "fmt" + + "github.com/antlr4-go/antlr/v4" + "github.com/octoberswimmer/apexfmt/parser" + log "github.com/sirupsen/logrus" +) + +type ChainVisitor struct { + parser.BaseApexParserVisitor +} + +func NewChainVisitor() *ChainVisitor { + return &ChainVisitor{} +} + +func (v *ChainVisitor) visitRule(node antlr.RuleNode) interface{} { + result := node.Accept(v) + if r, ok := result.(int); ok { + return r + } + if result == nil { + log.Debug(fmt.Sprintf("missing ChainVisitor function for %T", node)) + } + return 0 +} + +func (v *ChainVisitor) VisitStatement(ctx *parser.StatementContext) interface{} { + child := ctx.GetChild(0).(antlr.RuleNode) + return v.visitRule(child) +} + +func (v *ChainVisitor) VisitExpressionStatement(ctx *parser.ExpressionStatementContext) interface{} { + return v.visitRule(ctx.Expression()) +} + +func (v *ChainVisitor) VisitEqualityExpression(ctx *parser.EqualityExpressionContext) interface{} { + return 1 + v.visitRule(ctx.Expression(0)).(int) + v.visitRule(ctx.Expression(1)).(int) +} + +func (v *ChainVisitor) VisitPrimaryExpression(ctx *parser.PrimaryExpressionContext) interface{} { + switch e := ctx.Primary().(type) { + case *parser.ThisPrimaryContext: + return 0 + case *parser.SuperPrimaryContext: + return 0 + case *parser.LiteralPrimaryContext: + return 0 + case *parser.TypeRefPrimaryContext: + return 0 + case *parser.IdPrimaryContext: + return 0 + case *parser.SoqlPrimaryContext: + return 1 + case *parser.SoslPrimaryContext: + return 1 + default: + return fmt.Sprintf("UNHANDLED PRIMARY EXPRESSION: %T %s", e, e.GetText()) + } +} + +func (v *ChainVisitor) VisitDotExpression(ctx *parser.DotExpressionContext) interface{} { + if ctx.DotMethodCall() != nil { + return 1 + v.visitRule(ctx.Expression()).(int) + } + return v.visitRule(ctx.Expression()) +} + +func (v *ChainVisitor) VisitLogAndExpression(ctx *parser.LogAndExpressionContext) interface{} { + return 1 + v.visitRule(ctx.Expression(0)).(int) + v.visitRule(ctx.Expression(1)).(int) +} + +func (v *ChainVisitor) VisitLogOrExpression(ctx *parser.LogOrExpressionContext) interface{} { + return 1 + v.visitRule(ctx.Expression(0)).(int) + v.visitRule(ctx.Expression(1)).(int) +} + +func (v *ChainVisitor) VisitSubExpression(ctx *parser.SubExpressionContext) interface{} { + return 1 + v.visitRule(ctx.Expression()).(int) +} + +func (v *ChainVisitor) VisitQuery(ctx *parser.QueryContext) interface{} { + score := v.visitRule(ctx.SelectList()).(int) + v.visitRule(ctx.FromNameList()).(int) + if scope := ctx.UsingScope(); scope != nil { + score++ + } + if where := ctx.WhereClause(); where != nil { + score += v.visitRule(where).(int) + } + if groupBy := ctx.GroupByClause(); groupBy != nil { + score += v.visitRule(groupBy).(int) + } + if orderBy := ctx.OrderByClause(); orderBy != nil { + score += v.visitRule(orderBy).(int) + } + if limit := ctx.LimitClause(); limit != nil { + score++ + } + if offset := ctx.OffsetClause(); offset != nil { + score++ + } + score += v.visitRule(ctx.ForClauses()).(int) + if update := ctx.UpdateList(); update != nil { + score++ + } + return score +} + +func (v *ChainVisitor) VisitSelectList(ctx *parser.SelectListContext) interface{} { + score := 0 + for _, p := range ctx.AllSelectEntry() { + score += v.visitRule(p).(int) + } + return score +} + +func (v *ChainVisitor) VisitSelectEntry(ctx *parser.SelectEntryContext) interface{} { + if ctx.SubQuery() != nil { + // estimate complexity; probably close enough + return 5 + } + return 1 +} + +func (v *ChainVisitor) VisitFromNameList(ctx *parser.FromNameListContext) interface{} { + return len(ctx.AllFieldNameAlias()) +} + +func (v *ChainVisitor) VisitWhereClause(ctx *parser.WhereClauseContext) interface{} { + return v.visitRule(ctx.LogicalExpression()) +} + +func (v *ChainVisitor) VisitLogicalExpression(ctx *parser.LogicalExpressionContext) interface{} { + return len(ctx.AllSOQLOR()) + len(ctx.AllSOQLAND()) + 1 +} + +func (v *ChainVisitor) VisitForClauses(ctx *parser.ForClausesContext) interface{} { + return len(ctx.AllForClause()) +} diff --git a/formatter/chain_test.go b/formatter/chain_test.go new file mode 100644 index 0000000..5482a92 --- /dev/null +++ b/formatter/chain_test.go @@ -0,0 +1,43 @@ +package formatter + +import ( + "testing" + + "github.com/antlr4-go/antlr/v4" + "github.com/octoberswimmer/apexfmt/parser" + + log "github.com/sirupsen/logrus" +) + +func TestChain(t *testing.T) { + if testing.Verbose() { + log.SetLevel(log.DebugLevel) + + } + tests := + []struct { + input string + output int + }{ + {`Schema.SObjectType.Account.getRecordTypeInfosByDeveloperName().get('Business').getRecordTypeId()`, 3}, + {`Fixtures.Contact(account).put(Contact.RecordTypeId, Schema.SObjectType.Contact.getRecordTypeInfosByDeveloperName().get('Person').getRecordTypeId()).put(Contact.My_Lookup__c, newRecord[0].Id).save()`, 4}, + } + + for _, tt := range tests { + input := antlr.NewInputStream(tt.input) + lexer := parser.NewApexLexer(input) + stream := antlr.NewCommonTokenStream(lexer, antlr.TokenDefaultChannel) + + p := parser.NewApexParser(stream) + p.RemoveErrorListeners() + + v := NewChainVisitor() + out, ok := v.visitRule(p.Statement()).(int) + if !ok { + t.Errorf("Unexpected result parsing apex") + } + if out != tt.output { + t.Errorf("unexpected result. expected: %d; got: %d", tt.output, out) + } + } +} diff --git a/formatter/format_test.go b/formatter/format_test.go new file mode 100644 index 0000000..1265b66 --- /dev/null +++ b/formatter/format_test.go @@ -0,0 +1,204 @@ +package formatter + +import ( + "testing" + + "github.com/antlr4-go/antlr/v4" + "github.com/octoberswimmer/apexfmt/parser" + + log "github.com/sirupsen/logrus" +) + +func TestStatement(t *testing.T) { + if testing.Verbose() { + log.SetLevel(log.DebugLevel) + + } + tests := + []struct { + input string + output string + }{ + { + `Account a = new Account(Name='Acme', RecordTypeId=myRecordTypeId, BillingCity='Los Angeles', BillingState = 'CA');`, + `Account a = new Account( + Name = 'Acme', + RecordTypeId = myRecordTypeId, + BillingCity = 'Los Angeles', + BillingState = 'CA' +);`}, + + { + `System.runAs(u) { + facility = Fixtures.account().put(Schema.Account.RecordTypeId, facilityRecordType).put(Schema.Account.HealthCloudGA__SourceSystemId__c, '0001') + .put(Schema.Account.Patient_Owner__c, u.Id) + .save(); +}`, `System.runAs(u) { + facility = Fixtures.account() + .put(Schema.Account.RecordTypeId, facilityRecordType) + .put(Schema.Account.HealthCloudGA__SourceSystemId__c, '0001') + .put(Schema.Account.Patient_Owner__c, u.Id) + .save(); +}`}, + + { + + `System.assertEquals(UserInfo.getUserId(), [SELECT OwnerId FROM Account WHERE Id = :person.Id].OwnerId, 'Account should be owned by correct user');`, + `System.assertEquals(UserInfo.getUserId(), [SELECT OwnerId FROM Account WHERE Id = :person.Id].OwnerId, 'Account should be owned by correct user');`}, + { + `System.assert(lsr[0].getErrors()[0].getMessage().contains(constants.ERR_MSG_NO_CLIENT_DEMOGRAPHICS), 'error message');`, + `System.assert(lsr[0].getErrors()[0].getMessage().contains(constants.ERR_MSG_NO_CLIENT_DEMOGRAPHICS), 'error message');`, + }, + + { + `RecordType referralType = [ SELECT Id FROM RecordType WHERE SobjectType = 'Contact' AND DeveloperName = 'Referral_Contact' ];`, + `RecordType referralType = [ + SELECT + Id + FROM + RecordType + WHERE + SobjectType = 'Contact' AND + DeveloperName = 'Referral_Contact' +];`}, + + { + `update [SELECT Id FROM Territory_Coverage__c WHERE Named_Account__c IN :accountIds];`, + `update [SELECT Id FROM Territory_Coverage__c WHERE Named_Account__c IN :accountIds];`, + }, + + { + `for (Referral__c ref : [SELECT Summary_Name__c, Name FROM Referral__c WHERE Id IN :referralIdSet]) { + System.assertEquals(ref.Name, ref.Summary_Name__c); +}`, + `for (Referral__c ref : [ + SELECT + Summary_Name__c, + Name + FROM + Referral__c + WHERE + Id IN :referralIdSet +]) { + System.assertEquals(ref.Name, ref.Summary_Name__c); +}`}, + + { + `for (Referral__c ref : [SELECT Name FROM Referral__c WHERE Id IN :referralIdSet]) { + System.assertEquals('test', ref.Summary_Name__c); +}`, + `for (Referral__c ref : [SELECT Name FROM Referral__c WHERE Id IN :referralIdSet]) { + System.assertEquals('test', ref.Summary_Name__c); +}`}, + + { + `if (!r.isSuccess()) { + throw new BenefitCheckNotificationException( + 'Failed to send Benefit Check notification. First error: ' + + r.getErrors()[0].getMessage() + ); +}`, + `if (!r.isSuccess()) { + throw new BenefitCheckNotificationException( + 'Failed to send Benefit Check notification. First error: ' + + r.getErrors()[0].getMessage() + ); +}`}, + + { + `return new List{ + new CountryZip( new Territory_Zip_Lookup__c( Id = zip.Id, Name = zip.Name, City__c = zip.City__c, State_2_Letter_Code__c = zip.State_2_Letter_Code__c, Country__c = zip.Country__c)) +};`, + `return new List{ + new CountryZip( + new Territory_Zip_Lookup__c( + Id = zip.Id, + Name = zip.Name, + City__c = zip.City__c, + State_2_Letter_Code__c = zip.State_2_Letter_Code__c, + Country__c = zip.Country__c + ) + ) +};`}, + + { + `Psychological__c psyc = Fixtures.psychological(inq).put(Psychological__c.RecordTypeId, Schema.SObjectType.Psychological__c.getRecordTypeInfosByDeveloperName().get('ICD_10').getRecordTypeId()).put(Psychological__c.Diagnosis_Lookup__c, newDiagnosis[0].Id).save();`, + `Psychological__c psyc = Fixtures.psychological(inq) + .put(Psychological__c.RecordTypeId, Schema.SObjectType.Psychological__c + .getRecordTypeInfosByDeveloperName() + .get('ICD_10') + .getRecordTypeId()) + .put(Psychological__c.Diagnosis_Lookup__c, newDiagnosis[0].Id) + .save();`}, + + { + `this.assertPassed(Assert.isNumericallyInner(2, '~0.5', 2.4, null));`, + `this.assertPassed(Assert.isNumericallyInner(2, '~0.5', 2.4, null));`}, + + { + `assertFailed(Assert.consistOfInner(new List{ a, b }, new List{ b, c }, 'doodle'), 'doodle: expected (1, 2) to consist of (2, 3)\nextra elements:\n\t(1)\nmissing elements:\n\t(3)');`, + `assertFailed(Assert.consistOfInner(new List{ a, b }, new List{ b, c }, 'doodle'), + 'doodle: expected (1, 2) to consist of (2, 3)\nextra elements:\n\t(1)\nmissing elements:\n\t(3)');`}, + + { + `if (cl_record.Last_Placement__c == true && + (Trigger.isInsert || (Trigger.isUpdate && cl_record.Last_Placement__c != Trigger.OldMap.get(cl_record.Id).Last_Placement__c))) {x=1;}`, + `if (cl_record.Last_Placement__c == true && + (Trigger.isInsert || + (Trigger.isUpdate && + cl_record.Last_Placement__c != Trigger.OldMap.get(cl_record.Id).Last_Placement__c))) { + x = 1; +}`}, + } + for _, tt := range tests { + input := antlr.NewInputStream(tt.input) + lexer := parser.NewApexLexer(input) + stream := antlr.NewCommonTokenStream(lexer, antlr.TokenDefaultChannel) + + p := parser.NewApexParser(stream) + p.RemoveErrorListeners() + + v := NewFormatVisitor(stream) + out, ok := v.visitRule(p.Statement()).(string) + if !ok { + t.Errorf("Unexpected result parsing apex") + } + if out != tt.output { + t.Errorf("unexpected format. expected:\n%s;\ngot:\n%s\n", tt.output, out) + } + } + +} + +func TestCompilationUnit(t *testing.T) { + if testing.Verbose() { + log.SetLevel(log.DebugLevel) + + } + tests := + []struct { + input string + output string + }{ + { + `private class T1Exception extends Exception {}`, + `private class T1Exception extends Exception {}`}, + } + for _, tt := range tests { + input := antlr.NewInputStream(tt.input) + lexer := parser.NewApexLexer(input) + stream := antlr.NewCommonTokenStream(lexer, antlr.TokenDefaultChannel) + + p := parser.NewApexParser(stream) + p.RemoveErrorListeners() + + v := NewFormatVisitor(stream) + out, ok := v.visitRule(p.CompilationUnit()).(string) + if !ok { + t.Errorf("Unexpected result parsing apex") + } + if out != tt.output { + t.Errorf("unexpected format. expected:\n%s;\ngot:\n%s\n", tt.output, out) + } + } +} diff --git a/formatter/formatter.go b/formatter/formatter.go index 8d3920d..abc803a 100644 --- a/formatter/formatter.go +++ b/formatter/formatter.go @@ -70,7 +70,7 @@ func (f *Formatter) Format() error { p.AddErrorListener(&errorListener{filename: f.filename}) // p.AddErrorListener(antlr.NewDiagnosticErrorListener(false)) - v := NewVisitor(stream) + v := NewFormatVisitor(stream) out, ok := v.visitRule(p.CompilationUnit()).(string) if !ok { return fmt.Errorf("Unexpected result parsing apex") diff --git a/formatter/indent.go b/formatter/indent.go new file mode 100644 index 0000000..4aaac74 --- /dev/null +++ b/formatter/indent.go @@ -0,0 +1,52 @@ +package formatter + +import ( + "github.com/antlr4-go/antlr/v4" + "github.com/octoberswimmer/apexfmt/parser" +) + +type IndentVisitor struct { + parser.BaseApexParserVisitor +} + +func NewIndentVisitor() *IndentVisitor { + return &IndentVisitor{} +} + +func (v *IndentVisitor) visitRule(node antlr.RuleNode) interface{} { + result := node.Accept(v) + if r, ok := result.(int); ok { + return r + } + return 0 +} + +func (v *IndentVisitor) VisitExpressionList(ctx *parser.ExpressionListContext) interface{} { + indent := 0 + for _, p := range ctx.AllExpression() { + n := v.visitRule(p).(int) + if n > indent { + indent = n + } + } + return indent +} + +func (v *IndentVisitor) VisitDotExpression(ctx *parser.DotExpressionContext) interface{} { + switch { + case ctx.DotMethodCall() != nil: + switch ctx.Expression().(type) { + case *parser.PrimaryExpressionContext: + return 0 + case *parser.NewInstanceExpressionContext: + return 2 + default: + return v.visitRule(ctx.Expression()) + } + } + return 0 +} + +func (v *IndentVisitor) VisitNewInstanceExpression(ctx *parser.NewInstanceExpressionContext) interface{} { + return 2 +} diff --git a/formatter/visitor.go b/formatter/visitor.go index 69d4796..62f9625 100644 --- a/formatter/visitor.go +++ b/formatter/visitor.go @@ -9,22 +9,23 @@ import ( "github.com/octoberswimmer/apexfmt/parser" ) -type Visitor struct { +type FormatVisitor struct { tokens *antlr.CommonTokenStream commentsOutput map[int]struct{} newlinesOutput map[int]struct{} parser.BaseApexParserVisitor + wrap bool } -func NewVisitor(tokens *antlr.CommonTokenStream) *Visitor { - return &Visitor{ +func NewFormatVisitor(tokens *antlr.CommonTokenStream) *FormatVisitor { + return &FormatVisitor{ tokens: tokens, commentsOutput: make(map[int]struct{}), newlinesOutput: make(map[int]struct{}), } } -func (v *Visitor) visitRule(node antlr.RuleNode) interface{} { +func (v *FormatVisitor) visitRule(node antlr.RuleNode) interface{} { start := node.(antlr.ParserRuleContext).GetStart() beforeWhitespace := v.tokens.GetHiddenTokensToLeft(start.GetTokenIndex(), 2) beforeComments := v.tokens.GetHiddenTokensToLeft(start.GetTokenIndex(), 3) @@ -36,7 +37,7 @@ func (v *Visitor) visitRule(node antlr.RuleNode) interface{} { comments := []string{} for _, c := range beforeComments { if _, seen := v.commentsOutput[c.GetTokenIndex()]; !seen { - comments = append(comments, removeLeadingTabs(c.GetText())) + comments = append(comments, cleanWhitespace(c.GetText())) v.commentsOutput[c.GetTokenIndex()] = struct{}{} } } @@ -61,7 +62,7 @@ func (v *Visitor) visitRule(node antlr.RuleNode) interface{} { return result } -func (v *Visitor) Modifiers(ctxs []parser.IModifierContext) string { +func (v *FormatVisitor) Modifiers(ctxs []parser.IModifierContext) string { mods := []string{} annotations := []string{} for _, m := range ctxs { @@ -83,7 +84,11 @@ func (v *Visitor) Modifiers(ctxs []parser.IModifierContext) string { return m.String() } -func indent(text string) string { +func (v *FormatVisitor) indent(text string) string { + return v.indentTo(text, 1) +} + +func (v *FormatVisitor) indentTo(text string, indents int) string { var indentedText strings.Builder scanner := bufio.NewScanner(strings.NewReader(text)) isFirstLine := true @@ -95,7 +100,7 @@ func indent(text string) string { indentedText.WriteString("\n") } if scanner.Text() != "" { - indentedText.WriteString("\t" + scanner.Text()) + indentedText.WriteString(strings.Repeat("\t", indents) + scanner.Text()) } else { indentedText.WriteString(scanner.Text()) } @@ -104,20 +109,30 @@ func indent(text string) string { return indentedText.String() } -func removeLeadingTabs(input string) string { +// Remove leading tabs +func cleanWhitespace(input string) string { lines := strings.Split(input, "\n") for i, line := range lines { - tabs := 0 - for j := 0; j < len(line); j++ { - if line[j] == '\t' { - tabs++ - } else { - break - } - } - lines[i] = line[tabs:] + lines[i] = strings.TrimRight(strings.TrimLeft(line, "\t"), " \t") } return strings.Join(lines, "\n") } + +func restoreWrap(v *FormatVisitor, reset bool) *FormatVisitor { + v.wrap = reset + return v +} + +func wrap(v *FormatVisitor) (*FormatVisitor, bool) { + old := v.wrap + v.wrap = true + return v, old +} + +func unwrap(v *FormatVisitor) (*FormatVisitor, bool) { + old := v.wrap + v.wrap = false + return v, old +} diff --git a/formatter/visitors.go b/formatter/visitors.go index f806b49..6312837 100644 --- a/formatter/visitors.go +++ b/formatter/visitors.go @@ -6,9 +6,10 @@ import ( "github.com/antlr4-go/antlr/v4" "github.com/octoberswimmer/apexfmt/parser" + log "github.com/sirupsen/logrus" ) -func (v *Visitor) VisitCompilationUnit(ctx *parser.CompilationUnitContext) interface{} { +func (v *FormatVisitor) VisitCompilationUnit(ctx *parser.CompilationUnitContext) interface{} { if trigger := ctx.TriggerUnit(); trigger != nil { return v.visitRule(trigger) } @@ -31,7 +32,7 @@ func (v *Visitor) VisitCompilationUnit(ctx *parser.CompilationUnitContext) inter return "" } -func (v *Visitor) VisitClassDeclaration(ctx *parser.ClassDeclarationContext) interface{} { +func (v *FormatVisitor) VisitClassDeclaration(ctx *parser.ClassDeclarationContext) interface{} { var class strings.Builder class.WriteString(fmt.Sprintf("class %s", v.visitRule(ctx.Id()))) if ctx.EXTENDS() != nil { @@ -40,11 +41,15 @@ func (v *Visitor) VisitClassDeclaration(ctx *parser.ClassDeclarationContext) int if ctx.IMPLEMENTS() != nil { class.WriteString(fmt.Sprintf(" implements %s", v.visitRule(ctx.TypeList()))) } - class.WriteString(fmt.Sprintf(" {\n%s\n}", indent(v.visitRule(ctx.ClassBody()).(string)))) + if ctx.ClassBody().GetText() == "{}" { + class.WriteString(" {}") + } else { + class.WriteString(fmt.Sprintf(" {\n%s\n}", v.indent(v.visitRule(ctx.ClassBody()).(string)))) + } return class.String() } -func (v *Visitor) VisitTriggerUnit(ctx *parser.TriggerUnitContext) interface{} { +func (v *FormatVisitor) VisitTriggerUnit(ctx *parser.TriggerUnitContext) interface{} { triggerCases := []string{} for _, t := range ctx.AllTriggerCase() { triggerCases = append(triggerCases, v.visitRule(t).(string)) @@ -54,23 +59,23 @@ func (v *Visitor) VisitTriggerUnit(ctx *parser.TriggerUnitContext) interface{} { v.visitRule(ctx.TriggerBlock())) } -func (v *Visitor) VisitTriggerBlock(ctx *parser.TriggerBlockContext) interface{} { +func (v *FormatVisitor) VisitTriggerBlock(ctx *parser.TriggerBlockContext) interface{} { statements := []string{} for _, stmt := range ctx.AllTriggerStatement() { statements = append(statements, v.visitRule(stmt).(string)) } - return fmt.Sprintf("{\n%s\n}", indent(strings.Join(statements, "\n"))) + return fmt.Sprintf("{\n%s\n}", v.indent(strings.Join(statements, "\n"))) } -func (v *Visitor) VisitTriggerStatement(ctx *parser.TriggerStatementContext) interface{} { +func (v *FormatVisitor) VisitTriggerStatement(ctx *parser.TriggerStatementContext) interface{} { return v.visitRule(ctx.GetChild(0).(antlr.RuleNode)) } -func (v *Visitor) VisitTriggerCase(ctx *parser.TriggerCaseContext) interface{} { +func (v *FormatVisitor) VisitTriggerCase(ctx *parser.TriggerCaseContext) interface{} { return fmt.Sprintf("%s %s", ctx.GetChild(0).(antlr.TerminalNode).GetText(), ctx.GetChild(1).(antlr.TerminalNode).GetText()) } -func (v *Visitor) VisitEnumDeclaration(ctx *parser.EnumDeclarationContext) interface{} { +func (v *FormatVisitor) VisitEnumDeclaration(ctx *parser.EnumDeclarationContext) interface{} { enumConstants := "" if ctx.EnumConstants() != nil { enumConstants = v.visitRule(ctx.EnumConstants()).(string) @@ -78,7 +83,7 @@ func (v *Visitor) VisitEnumDeclaration(ctx *parser.EnumDeclarationContext) inter return fmt.Sprintf("enum %s { %s }", v.visitRule(ctx.Id()), enumConstants) } -func (v *Visitor) VisitEnumConstants(ctx *parser.EnumConstantsContext) interface{} { +func (v *FormatVisitor) VisitEnumConstants(ctx *parser.EnumConstantsContext) interface{} { ids := []string{} for _, t := range ctx.AllId() { ids = append(ids, t.GetText()) @@ -86,15 +91,15 @@ func (v *Visitor) VisitEnumConstants(ctx *parser.EnumConstantsContext) interface return strings.Join(ids, ", ") } -func (v *Visitor) VisitInterfaceDeclaration(ctx *parser.InterfaceDeclarationContext) interface{} { +func (v *FormatVisitor) VisitInterfaceDeclaration(ctx *parser.InterfaceDeclarationContext) interface{} { extends := "" if ctx.EXTENDS() != nil { extends = fmt.Sprintf(" extends %s ", v.visitRule(ctx.TypeList())) } - return fmt.Sprintf("interface %s%s {\n%s\n}", ctx.Id().GetText(), extends, indent(v.visitRule(ctx.InterfaceBody()).(string))) + return fmt.Sprintf("interface %s%s {\n%s\n}", ctx.Id().GetText(), extends, v.indent(v.visitRule(ctx.InterfaceBody()).(string))) } -func (v *Visitor) VisitInterfaceBody(ctx *parser.InterfaceBodyContext) interface{} { +func (v *FormatVisitor) VisitInterfaceBody(ctx *parser.InterfaceBodyContext) interface{} { declarations := []string{} for _, d := range ctx.AllInterfaceMethodDeclaration() { declarations = append(declarations, v.visitRule(d).(string)) @@ -102,7 +107,7 @@ func (v *Visitor) VisitInterfaceBody(ctx *parser.InterfaceBodyContext) interface return strings.Join(declarations, "\n") } -func (v *Visitor) VisitClassBody(ctx *parser.ClassBodyContext) interface{} { +func (v *FormatVisitor) VisitClassBody(ctx *parser.ClassBodyContext) interface{} { var cb []string for _, b := range ctx.AllClassBodyDeclaration() { cb = append(cb, v.visitRule(b).(string)) @@ -110,7 +115,7 @@ func (v *Visitor) VisitClassBody(ctx *parser.ClassBodyContext) interface{} { return strings.Join(cb, "\n") } -func (v *Visitor) VisitClassBodyDeclaration(ctx *parser.ClassBodyDeclarationContext) interface{} { +func (v *FormatVisitor) VisitClassBodyDeclaration(ctx *parser.ClassBodyDeclarationContext) interface{} { switch { case ctx.SEMI() != nil: return ";" @@ -126,11 +131,11 @@ func (v *Visitor) VisitClassBodyDeclaration(ctx *parser.ClassBodyDeclarationCont return "" } -func (v *Visitor) VisitMemberDeclaration(ctx *parser.MemberDeclarationContext) interface{} { +func (v *FormatVisitor) VisitMemberDeclaration(ctx *parser.MemberDeclarationContext) interface{} { return v.visitRule(ctx.GetChild(0).(antlr.RuleNode)) } -func (v *Visitor) VisitInterfaceMethodDeclaration(ctx *parser.InterfaceMethodDeclarationContext) interface{} { +func (v *FormatVisitor) VisitInterfaceMethodDeclaration(ctx *parser.InterfaceMethodDeclarationContext) interface{} { returnType := "void" if ctx.TypeRef() != nil { returnType = v.visitRule(ctx.TypeRef()).(string) @@ -138,11 +143,11 @@ func (v *Visitor) VisitInterfaceMethodDeclaration(ctx *parser.InterfaceMethodDec return fmt.Sprintf("%s%s %s%s;", v.Modifiers(ctx.AllModifier()), returnType, ctx.Id().GetText(), v.visitRule(ctx.FormalParameters())) } -func (v *Visitor) VisitFieldDeclaration(ctx *parser.FieldDeclarationContext) interface{} { +func (v *FormatVisitor) VisitFieldDeclaration(ctx *parser.FieldDeclarationContext) interface{} { return fmt.Sprintf("%s %s;", v.visitRule(ctx.TypeRef()), v.visitRule(ctx.VariableDeclarators())) } -func (v *Visitor) VisitPropertyDeclaration(ctx *parser.PropertyDeclarationContext) interface{} { +func (v *FormatVisitor) VisitPropertyDeclaration(ctx *parser.PropertyDeclarationContext) interface{} { propertyBlocks := []string{} if ctx.AllPropertyBlock() != nil { for _, p := range ctx.AllPropertyBlock() { @@ -154,10 +159,10 @@ func (v *Visitor) VisitPropertyDeclaration(ctx *parser.PropertyDeclarationContex return fmt.Sprintf("%s %s {%s}", v.visitRule(ctx.TypeRef()), ctx.Id().GetText(), strings.Join(propertyBlocks, " ")) } sep := "\n" - return fmt.Sprintf("%s %s {%s%s%s}", v.visitRule(ctx.TypeRef()), ctx.Id().GetText(), sep, indent(strings.Join(propertyBlocks, sep)), sep) + return fmt.Sprintf("%s %s {%s%s%s}", v.visitRule(ctx.TypeRef()), ctx.Id().GetText(), sep, v.indent(strings.Join(propertyBlocks, sep)), sep) } -func (v *Visitor) VisitPropertyBlock(ctx *parser.PropertyBlockContext) interface{} { +func (v *FormatVisitor) VisitPropertyBlock(ctx *parser.PropertyBlockContext) interface{} { if ctx.Getter() != nil { return fmt.Sprintf("%s%s", v.Modifiers(ctx.AllModifier()), v.visitRule(ctx.Getter())) } else { @@ -165,7 +170,7 @@ func (v *Visitor) VisitPropertyBlock(ctx *parser.PropertyBlockContext) interface } } -func (v *Visitor) VisitGetter(ctx *parser.GetterContext) interface{} { +func (v *FormatVisitor) VisitGetter(ctx *parser.GetterContext) interface{} { if ctx.SEMI() != nil { return "get;" } else { @@ -173,7 +178,7 @@ func (v *Visitor) VisitGetter(ctx *parser.GetterContext) interface{} { } } -func (v *Visitor) VisitSetter(ctx *parser.SetterContext) interface{} { +func (v *FormatVisitor) VisitSetter(ctx *parser.SetterContext) interface{} { if ctx.SEMI() != nil { return "set;" } else { @@ -181,43 +186,35 @@ func (v *Visitor) VisitSetter(ctx *parser.SetterContext) interface{} { } } -func (v *Visitor) VisitConstructorDeclaration(ctx *parser.ConstructorDeclarationContext) interface{} { +func (v *FormatVisitor) VisitConstructorDeclaration(ctx *parser.ConstructorDeclarationContext) interface{} { return fmt.Sprintf("%s%s %s", v.visitRule(ctx.QualifiedName()), v.visitRule(ctx.FormalParameters()), v.visitRule(ctx.Block()).(string)) } -func (v *Visitor) VisitBlock(ctx *parser.BlockContext) interface{} { +func (v *FormatVisitor) VisitBlock(ctx *parser.BlockContext) interface{} { statements := []string{} for _, stmt := range ctx.AllStatement() { statements = append(statements, v.visitRule(stmt).(string)) } - return fmt.Sprintf("{\n%s\n}", indent(strings.Join(statements, "\n"))) + return fmt.Sprintf("{\n%s\n}", v.indent(strings.Join(statements, "\n"))) } -func (v *Visitor) VisitStatement(ctx *parser.StatementContext) interface{} { - /* - if ctx.GetChild(0) == nil { - return "NIL STATEMENT?" + ctx.GetText() - } - if errNode, ok := ctx.GetChild(0).(antlr.ErrorNode); ok { - return fmt.Sprintf("ERROR: %+v", errNode) - } - */ +func (v *FormatVisitor) VisitStatement(ctx *parser.StatementContext) interface{} { child := ctx.GetChild(0).(antlr.RuleNode) return v.visitRule(child) } -func (v *Visitor) VisitBlockMemberDeclaration(ctx *parser.BlockMemberDeclarationContext) interface{} { +func (v *FormatVisitor) VisitBlockMemberDeclaration(ctx *parser.BlockMemberDeclarationContext) interface{} { return fmt.Sprintf("%s%s", v.Modifiers(ctx.AllModifier()), v.visitRule(ctx.MemberDeclaration())) } -func (v *Visitor) VisitIfStatement(ctx *parser.IfStatementContext) interface{} { +func (v *FormatVisitor) VisitIfStatement(ctx *parser.IfStatementContext) interface{} { var out strings.Builder if block := ctx.Statement(0).Block(); block != nil { out.WriteString(fmt.Sprintf("if %s %s", v.visitRule(ctx.ParExpression()), v.visitRule(ctx.Statement(0)))) } else { out.WriteString(fmt.Sprintf("if %s {\n%s\n}", v.visitRule(ctx.ParExpression()), - indent(v.visitRule(ctx.Statement(0)).(string)))) + v.indent(v.visitRule(ctx.Statement(0)).(string)))) } if ctx.ELSE() != nil { if block := ctx.Statement(1).Block(); block != nil { @@ -225,13 +222,13 @@ func (v *Visitor) VisitIfStatement(ctx *parser.IfStatementContext) interface{} { } else if ifStatement := ctx.Statement(1).IfStatement(); ifStatement != nil { out.WriteString(fmt.Sprintf(" else %s", v.visitRule(ifStatement))) } else { - out.WriteString(fmt.Sprintf(" else {\n%s}", indent(v.visitRule(ctx.Statement(1)).(string)))) + out.WriteString(fmt.Sprintf(" else {\n%s}", v.indent(v.visitRule(ctx.Statement(1)).(string)))) } } return out.String() } -func (v *Visitor) VisitWhileStatement(ctx *parser.WhileStatementContext) interface{} { +func (v *FormatVisitor) VisitWhileStatement(ctx *parser.WhileStatementContext) interface{} { if s := ctx.Statement; s == nil { return fmt.Sprintf("while %s;", v.visitRule(ctx.ParExpression())) } @@ -242,31 +239,31 @@ func (v *Visitor) VisitWhileStatement(ctx *parser.WhileStatementContext) interfa } } -func (v *Visitor) VisitForStatement(ctx *parser.ForStatementContext) interface{} { +func (v *FormatVisitor) VisitForStatement(ctx *parser.ForStatementContext) interface{} { if statement := ctx.Statement(); statement != nil { if statement.Block() != nil { return fmt.Sprintf("for (%s) %s", v.visitRule(ctx.ForControl()), v.visitRule(ctx.Statement())) } else { - return fmt.Sprintf("for (%s) {\n%s}\n", v.visitRule(ctx.ForControl()), indent(v.visitRule(ctx.Statement()).(string))) + return fmt.Sprintf("for (%s) {\n%s}\n", v.visitRule(ctx.ForControl()), v.indent(v.visitRule(ctx.Statement()).(string))) } } else { return fmt.Sprintf("for (%s);", v.visitRule(ctx.ForControl())) } } -func (v *Visitor) VisitSwitchStatement(ctx *parser.SwitchStatementContext) interface{} { +func (v *FormatVisitor) VisitSwitchStatement(ctx *parser.SwitchStatementContext) interface{} { when := []string{} for _, w := range ctx.AllWhenControl() { when = append(when, v.visitRule(w).(string)) } - return fmt.Sprintf("switch on %s {\n%s}", v.visitRule(ctx.Expression()), indent(strings.Join(when, "\n"))) + return fmt.Sprintf("switch on %s {\n%s}", v.visitRule(ctx.Expression()), v.indent(strings.Join(when, "\n"))) } -func (v *Visitor) VisitWhenControl(ctx *parser.WhenControlContext) interface{} { +func (v *FormatVisitor) VisitWhenControl(ctx *parser.WhenControlContext) interface{} { return fmt.Sprintf("when %s %s", v.visitRule(ctx.WhenValue()), v.visitRule(ctx.Block())) } -func (v *Visitor) VisitWhenValue(ctx *parser.WhenValueContext) interface{} { +func (v *FormatVisitor) VisitWhenValue(ctx *parser.WhenValueContext) interface{} { switch { case ctx.ELSE() != nil: return "else" @@ -281,7 +278,7 @@ func (v *Visitor) VisitWhenValue(ctx *parser.WhenValueContext) interface{} { } } -func (v *Visitor) VisitWhenLiteral(ctx *parser.WhenLiteralContext) interface{} { +func (v *FormatVisitor) VisitWhenLiteral(ctx *parser.WhenLiteralContext) interface{} { if w := ctx.WhenLiteral(); w != nil { return fmt.Sprintf("(%s)", v.visitRule(w)) } @@ -291,7 +288,7 @@ func (v *Visitor) VisitWhenLiteral(ctx *parser.WhenLiteralContext) interface{} { return ctx.GetText() } -func (v *Visitor) VisitTryStatement(ctx *parser.TryStatementContext) interface{} { +func (v *FormatVisitor) VisitTryStatement(ctx *parser.TryStatementContext) interface{} { if len(ctx.AllCatchClause()) > 0 { catchClauses := []string{} for _, c := range ctx.AllCatchClause() { @@ -307,7 +304,7 @@ func (v *Visitor) VisitTryStatement(ctx *parser.TryStatementContext) interface{} } } -func (v *Visitor) VisitCatchClause(ctx *parser.CatchClauseContext) interface{} { +func (v *FormatVisitor) VisitCatchClause(ctx *parser.CatchClauseContext) interface{} { return fmt.Sprintf("catch (%s%s %s) %s", v.Modifiers(ctx.AllModifier()), v.visitRule(ctx.QualifiedName()), @@ -315,15 +312,15 @@ func (v *Visitor) VisitCatchClause(ctx *parser.CatchClauseContext) interface{} { v.visitRule(ctx.Block())) } -func (v *Visitor) VisitFinallyBlock(ctx *parser.FinallyBlockContext) interface{} { +func (v *FormatVisitor) VisitFinallyBlock(ctx *parser.FinallyBlockContext) interface{} { return fmt.Sprintf("finally %s", v.visitRule(ctx.Block())) } -func (v *Visitor) VisitThrowStatement(ctx *parser.ThrowStatementContext) interface{} { +func (v *FormatVisitor) VisitThrowStatement(ctx *parser.ThrowStatementContext) interface{} { return fmt.Sprintf("throw %s;", v.visitRule(ctx.Expression())) } -func (v *Visitor) VisitRunAsStatement(ctx *parser.RunAsStatementContext) interface{} { +func (v *FormatVisitor) VisitRunAsStatement(ctx *parser.RunAsStatementContext) interface{} { expressionList := "" if e := ctx.ExpressionList(); e != nil { expressionList = v.visitRule(e).(string) @@ -331,7 +328,7 @@ func (v *Visitor) VisitRunAsStatement(ctx *parser.RunAsStatementContext) interfa return fmt.Sprintf("System.runAs(%s) %s", expressionList, v.visitRule(ctx.Block())) } -func (v *Visitor) VisitForControl(ctx *parser.ForControlContext) interface{} { +func (v *FormatVisitor) VisitForControl(ctx *parser.ForControlContext) interface{} { if enhancedForControl := ctx.EnhancedForControl(); enhancedForControl != nil { return v.visitRule(enhancedForControl) } @@ -350,113 +347,120 @@ func (v *Visitor) VisitForControl(ctx *parser.ForControlContext) interface{} { return init.String() } -func (v *Visitor) VisitEnhancedForControl(ctx *parser.EnhancedForControlContext) interface{} { - return fmt.Sprintf("%s %s : %s", v.visitRule(ctx.TypeRef()), v.visitRule(ctx.Id()), v.visitRule(ctx.Expression())) +func (v *FormatVisitor) VisitEnhancedForControl(ctx *parser.EnhancedForControlContext) interface{} { + var out strings.Builder + out.WriteString(fmt.Sprintf("%s %s : ", v.visitRule(ctx.TypeRef()), v.visitRule(ctx.Id()))) + out.WriteString(v.visitRule(ctx.Expression()).(string)) + return out.String() } -func (v *Visitor) VisitForInit(ctx *parser.ForInitContext) interface{} { +func (v *FormatVisitor) VisitForInit(ctx *parser.ForInitContext) interface{} { return v.visitRule(ctx.GetChild(0).(antlr.RuleNode)) } -func (v *Visitor) VisitContinueStatement(ctx *parser.ContinueStatementContext) interface{} { +func (v *FormatVisitor) VisitContinueStatement(ctx *parser.ContinueStatementContext) interface{} { return "continue;" } -func (v *Visitor) VisitBreakStatement(ctx *parser.BreakStatementContext) interface{} { +func (v *FormatVisitor) VisitBreakStatement(ctx *parser.BreakStatementContext) interface{} { return "break;" } -func (v *Visitor) VisitForUpdate(ctx *parser.ForUpdateContext) interface{} { +func (v *FormatVisitor) VisitForUpdate(ctx *parser.ForUpdateContext) interface{} { return v.visitRule(ctx.ExpressionList()) } -func (v *Visitor) VisitLocalVariableDeclarationStatement(ctx *parser.LocalVariableDeclarationStatementContext) interface{} { +func (v *FormatVisitor) VisitLocalVariableDeclarationStatement(ctx *parser.LocalVariableDeclarationStatementContext) interface{} { return fmt.Sprintf("%s;", v.visitRule(ctx.LocalVariableDeclaration())) } -func (v *Visitor) VisitInsertStatement(ctx *parser.InsertStatementContext) interface{} { +func (v *FormatVisitor) VisitInsertStatement(ctx *parser.InsertStatementContext) interface{} { return fmt.Sprintf("insert %s;", v.visitRule(ctx.Expression())) } -func (v *Visitor) VisitUpdateStatement(ctx *parser.UpdateStatementContext) interface{} { +func (v *FormatVisitor) VisitUpdateStatement(ctx *parser.UpdateStatementContext) interface{} { return fmt.Sprintf("update %s;", v.visitRule(ctx.Expression())) } -func (v *Visitor) VisitUpsertStatement(ctx *parser.UpsertStatementContext) interface{} { +func (v *FormatVisitor) VisitUpsertStatement(ctx *parser.UpsertStatementContext) interface{} { return fmt.Sprintf("upsert %s;", v.visitRule(ctx.Expression())) } -func (v *Visitor) VisitMergeStatement(ctx *parser.MergeStatementContext) interface{} { +func (v *FormatVisitor) VisitMergeStatement(ctx *parser.MergeStatementContext) interface{} { return fmt.Sprintf("merge %s %s;", v.visitRule(ctx.Expression(0)), v.visitRule(ctx.Expression(1))) } -func (v *Visitor) VisitDeleteStatement(ctx *parser.DeleteStatementContext) interface{} { +func (v *FormatVisitor) VisitDeleteStatement(ctx *parser.DeleteStatementContext) interface{} { return fmt.Sprintf("delete %s;", v.visitRule(ctx.Expression())) } -func (v *Visitor) VisitUndeleteStatement(ctx *parser.UndeleteStatementContext) interface{} { +func (v *FormatVisitor) VisitUndeleteStatement(ctx *parser.UndeleteStatementContext) interface{} { return fmt.Sprintf("undelete %s;", v.visitRule(ctx.Expression())) } -func (v *Visitor) VisitLocalVariableDeclaration(ctx *parser.LocalVariableDeclarationContext) interface{} { +func (v *FormatVisitor) VisitLocalVariableDeclaration(ctx *parser.LocalVariableDeclarationContext) interface{} { return fmt.Sprintf("%s%s %s", v.Modifiers(ctx.AllModifier()), v.visitRule(ctx.TypeRef()), v.visitRule(ctx.VariableDeclarators())) } -func (v *Visitor) VisitReturnStatement(ctx *parser.ReturnStatementContext) interface{} { +func (v *FormatVisitor) VisitReturnStatement(ctx *parser.ReturnStatementContext) interface{} { if e := ctx.Expression(); e != nil { return fmt.Sprintf("return %s;", v.visitRule(e)) } return "return;" } -func (v *Visitor) VisitParExpression(ctx *parser.ParExpressionContext) interface{} { +func (v *FormatVisitor) VisitParExpression(ctx *parser.ParExpressionContext) interface{} { return fmt.Sprintf("(%s)", v.visitRule(ctx.Expression())) } -func (v *Visitor) VisitExpressionStatement(ctx *parser.ExpressionStatementContext) interface{} { +func (v *FormatVisitor) VisitExpressionStatement(ctx *parser.ExpressionStatementContext) interface{} { return fmt.Sprintf("%s;", v.visitRule(ctx.Expression())) } -func (v *Visitor) VisitAssignExpression(ctx *parser.AssignExpressionContext) interface{} { +func (v *FormatVisitor) VisitAssignExpression(ctx *parser.AssignExpressionContext) interface{} { assignmentToken := ctx.GetChild(1).(antlr.TerminalNode) return fmt.Sprintf("%s %s %s", v.visitRule(ctx.Expression(0)), assignmentToken.GetText(), v.visitRule(ctx.Expression(1))) } -func (v *Visitor) VisitCondExpression(ctx *parser.CondExpressionContext) interface{} { +func (v *FormatVisitor) VisitCondExpression(ctx *parser.CondExpressionContext) interface{} { return fmt.Sprintf("%s ? %s : %s", v.visitRule(ctx.Expression(0)), v.visitRule(ctx.Expression(1)), v.visitRule(ctx.Expression(2))) } -func (v *Visitor) VisitLogAndExpression(ctx *parser.LogAndExpressionContext) interface{} { - // TODO: Wrap long expressions - sep := " " - if len(ctx.Expression(0).GetText()) > 60 { - sep = "\n\t" +func (v *FormatVisitor) VisitLogAndExpression(ctx *parser.LogAndExpressionContext) interface{} { + i := NewChainVisitor() + if i.visitRule(ctx.Expression(0)).(int)+i.visitRule(ctx.Expression(1)).(int) > 2 { + defer restoreWrap(wrap(v)) + } + if v.wrap { + return fmt.Sprintf("%s &&\n\t%s", v.visitRule(ctx.Expression(0)), v.visitRule(ctx.Expression(1))) } - return fmt.Sprintf("%s &&%s%s", v.visitRule(ctx.Expression(0)), sep, v.visitRule(ctx.Expression(1))) + return fmt.Sprintf("%s && %s", v.visitRule(ctx.Expression(0)), v.visitRule(ctx.Expression(1))) } -func (v *Visitor) VisitLogOrExpression(ctx *parser.LogOrExpressionContext) interface{} { - // TODO: Wrap long expressions - sep := " " - if len(ctx.Expression(0).GetText()) > 60 { - sep = "\n\t" +func (v *FormatVisitor) VisitLogOrExpression(ctx *parser.LogOrExpressionContext) interface{} { + i := NewChainVisitor() + if i.visitRule(ctx.Expression(0)).(int)+i.visitRule(ctx.Expression(1)).(int) > 2 { + defer restoreWrap(wrap(v)) + } + if v.wrap { + return fmt.Sprintf("%s ||\n\t%s", v.visitRule(ctx.Expression(0)), v.visitRule(ctx.Expression(1))) } - return fmt.Sprintf("%s ||%s%s", v.visitRule(ctx.Expression(0)), sep, v.visitRule(ctx.Expression(1))) + return fmt.Sprintf("%s || %s", v.visitRule(ctx.Expression(0)), v.visitRule(ctx.Expression(1))) } -func (v *Visitor) VisitBitAndExpression(ctx *parser.BitAndExpressionContext) interface{} { +func (v *FormatVisitor) VisitBitAndExpression(ctx *parser.BitAndExpressionContext) interface{} { return fmt.Sprintf("%s & %s", v.visitRule(ctx.Expression(0)), v.visitRule(ctx.Expression(1))) } -func (v *Visitor) VisitBitOrExpression(ctx *parser.BitOrExpressionContext) interface{} { +func (v *FormatVisitor) VisitBitOrExpression(ctx *parser.BitOrExpressionContext) interface{} { return fmt.Sprintf("%s | %s", v.visitRule(ctx.Expression(0)), v.visitRule(ctx.Expression(1))) } -func (v *Visitor) VisitBitNotExpression(ctx *parser.BitNotExpressionContext) interface{} { +func (v *FormatVisitor) VisitBitNotExpression(ctx *parser.BitNotExpressionContext) interface{} { return fmt.Sprintf("%s ^ %s", v.visitRule(ctx.Expression(0)), v.visitRule(ctx.Expression(1))) } -func (v *Visitor) VisitBitExpression(ctx *parser.BitExpressionContext) interface{} { +func (v *FormatVisitor) VisitBitExpression(ctx *parser.BitExpressionContext) interface{} { ops := []string{} for _, o := range ctx.AllGT() { ops = append(ops, o.GetText()) @@ -467,52 +471,75 @@ func (v *Visitor) VisitBitExpression(ctx *parser.BitExpressionContext) interface return strings.Join(ops, "") } -func (v *Visitor) VisitArth1Expression(ctx *parser.Arth1ExpressionContext) interface{} { +func (v *FormatVisitor) VisitArth1Expression(ctx *parser.Arth1ExpressionContext) interface{} { return fmt.Sprintf("%s %s %s", v.visitRule(ctx.Expression(0)), ctx.GetChild(1).(antlr.TerminalNode).GetText(), v.visitRule(ctx.Expression(1))) } -func (v *Visitor) VisitArth2Expression(ctx *parser.Arth2ExpressionContext) interface{} { +func (v *FormatVisitor) VisitArth2Expression(ctx *parser.Arth2ExpressionContext) interface{} { sep := " " - if len(ctx.Expression(0).GetText()) > 40 { + if v.wrap { + log.Debug(fmt.Sprintf("visitor says to wrap %T in VisitArth2Expression", ctx.Expression(1))) sep = "\n\t" + log.Debug("not wrapping individual expressions") + defer restoreWrap(unwrap(v)) } return fmt.Sprintf("%s %s%s%s", v.visitRule(ctx.Expression(0)), ctx.GetChild(1).(antlr.TerminalNode).GetText(), sep, v.visitRule(ctx.Expression(1))) } -func (v *Visitor) VisitNegExpression(ctx *parser.NegExpressionContext) interface{} { +func (v *FormatVisitor) VisitNegExpression(ctx *parser.NegExpressionContext) interface{} { return fmt.Sprintf("%s%s", ctx.GetChild(0).(antlr.TerminalNode).GetText(), v.visitRule(ctx.Expression())) } -func (v *Visitor) VisitPreOpExpression(ctx *parser.PreOpExpressionContext) interface{} { +func (v *FormatVisitor) VisitPreOpExpression(ctx *parser.PreOpExpressionContext) interface{} { return fmt.Sprintf("%s%s", ctx.GetChild(0).(antlr.TerminalNode).GetText(), v.visitRule(ctx.Expression())) } -func (v *Visitor) VisitPostOpExpression(ctx *parser.PostOpExpressionContext) interface{} { +func (v *FormatVisitor) VisitPostOpExpression(ctx *parser.PostOpExpressionContext) interface{} { return fmt.Sprintf("%s%s", v.visitRule(ctx.Expression()), ctx.GetChild(1).(antlr.TerminalNode).GetText()) } -func (v *Visitor) VisitSubExpression(ctx *parser.SubExpressionContext) interface{} { +func (v *FormatVisitor) VisitSubExpression(ctx *parser.SubExpressionContext) interface{} { return fmt.Sprintf("(%s)", v.visitRule(ctx.Expression())) } -func (v *Visitor) VisitCastExpression(ctx *parser.CastExpressionContext) interface{} { +func (v *FormatVisitor) VisitCastExpression(ctx *parser.CastExpressionContext) interface{} { return fmt.Sprintf("(%s)%s", v.visitRule(ctx.TypeRef()), v.visitRule(ctx.Expression())) } -func (v *Visitor) VisitNewInstanceExpression(ctx *parser.NewInstanceExpressionContext) interface{} { +func (v *FormatVisitor) VisitNewInstanceExpression(ctx *parser.NewInstanceExpressionContext) interface{} { return fmt.Sprintf("new %s", v.visitRule(ctx.Creator())) } -func (v *Visitor) VisitArrayExpression(ctx *parser.ArrayExpressionContext) interface{} { +func (v *FormatVisitor) VisitArrayExpression(ctx *parser.ArrayExpressionContext) interface{} { return fmt.Sprintf("%s[%s]", v.visitRule(ctx.Expression(0)), v.visitRule(ctx.Expression(1))) } -func (v *Visitor) VisitDotExpression(ctx *parser.DotExpressionContext) interface{} { +func (v *FormatVisitor) VisitDotExpression(ctx *parser.DotExpressionContext) interface{} { + i := NewChainVisitor() + depth := i.visitRule(ctx.Expression()).(int) + log.Debug(fmt.Sprintf("depth is %d: %s", depth, ctx.GetText())) + if depth > 1 { + defer restoreWrap(wrap(v)) + } expr := v.visitRule(ctx.Expression()) dot := ctx.GetChild(1).(antlr.TerminalNode).GetText() - switch { case ctx.DotMethodCall() != nil: + i := NewIndentVisitor() + depth := i.visitRule(ctx.Expression()).(int) + if v.wrap { + if depth == 0 { + depth = 1 + } + switch ctx.Expression().(type) { + case *parser.PrimaryExpressionContext: + log.Debug(fmt.Sprintf("NOT wrapping after between %q (%T)", expr, ctx.Expression())) + default: + log.Debug(fmt.Sprintf("Wrapping in between %q (%T) and %q", expr, ctx.Expression(), ctx.DotMethodCall().GetText())) + return expr.(string) + "\n" + v.indentTo(fmt.Sprintf("%s%s", dot, v.visitRule(ctx.DotMethodCall())), depth) + } + } + return fmt.Sprintf("%s%s%s", expr, dot, v.visitRule(ctx.DotMethodCall())) case ctx.AnyId() != nil: return fmt.Sprintf("%s%s%s", expr, dot, v.visitRule(ctx.AnyId())) @@ -520,7 +547,11 @@ func (v *Visitor) VisitDotExpression(ctx *parser.DotExpressionContext) interface return "" } -func (v *Visitor) VisitDotMethodCall(ctx *parser.DotMethodCallContext) interface{} { +func (v *FormatVisitor) VisitDotMethodCall(ctx *parser.DotMethodCallContext) interface{} { + if v.wrap { + log.Debug(fmt.Sprintf("Visitor says to wrap in VisitDotMethodCall; not wrapping individual expressions: %s", ctx.GetText())) + defer restoreWrap(unwrap(v)) + } expressionList := "" if l := ctx.ExpressionList(); l != nil { expressionList = v.visitRule(l).(string) @@ -528,33 +559,46 @@ func (v *Visitor) VisitDotMethodCall(ctx *parser.DotMethodCallContext) interface return fmt.Sprintf("%s(%s)", v.visitRule(ctx.AnyId()), expressionList) } -func (v *Visitor) VisitExpressionList(ctx *parser.ExpressionListContext) interface{} { +func (v *FormatVisitor) VisitExpressionList(ctx *parser.ExpressionListContext) interface{} { + wrap := v.wrap || (len(ctx.GetText()) > 40 && len(ctx.AllExpression()) > 3) || len(ctx.GetText()) > 150 + expressions := []string{} - for _, p := range ctx.AllExpression() { - expressions = append(expressions, v.visitRule(p).(string)) + for i, p := range ctx.AllExpression() { + // We want to indent method argument expressions, but not new instance arguments + switch p.(type) { + case *parser.AssignExpressionContext: + expressions = append(expressions, v.visitRule(p).(string)) + default: + if wrap && i > 0 { + expressions = append(expressions, v.indent(v.visitRule(p).(string))) + } else { + expressions = append(expressions, v.visitRule(p).(string)) + } + } } - if len(ctx.GetText()) > 90 { + + if wrap { return strings.Join(expressions, ",\n") } return strings.Join(expressions, ", ") } -func (v *Visitor) VisitAnyId(ctx *parser.AnyIdContext) interface{} { +func (v *FormatVisitor) VisitAnyId(ctx *parser.AnyIdContext) interface{} { return ctx.GetText() } -func (v *Visitor) VisitPrimaryExpression(ctx *parser.PrimaryExpressionContext) interface{} { +func (v *FormatVisitor) VisitPrimaryExpression(ctx *parser.PrimaryExpressionContext) interface{} { switch e := ctx.Primary().(type) { case *parser.ThisPrimaryContext: return "this" case *parser.SuperPrimaryContext: return "super" case *parser.LiteralPrimaryContext: - return e.GetText() + return v.visitRule(e) case *parser.TypeRefPrimaryContext: return v.visitRule(e) case *parser.IdPrimaryContext: - return e.GetText() + return v.visitRule(e) case *parser.SoqlPrimaryContext: return v.visitRule(e) case *parser.SoslPrimaryContext: @@ -564,15 +608,27 @@ func (v *Visitor) VisitPrimaryExpression(ctx *parser.PrimaryExpressionContext) i } } -func (v *Visitor) VisitMethodCallExpression(ctx *parser.MethodCallExpressionContext) interface{} { +func (v *FormatVisitor) VisitIdPrimary(ctx *parser.IdPrimaryContext) interface{} { + return v.visitRule(ctx.Id()) +} + +func (v *FormatVisitor) VisitLiteralPrimary(ctx *parser.LiteralPrimaryContext) interface{} { + return v.visitRule(ctx.Literal()) +} + +func (v *FormatVisitor) VisitLiteral(ctx *parser.LiteralContext) interface{} { + return ctx.GetText() +} + +func (v *FormatVisitor) VisitMethodCallExpression(ctx *parser.MethodCallExpressionContext) interface{} { return v.visitRule(ctx.MethodCall()) } -func (v *Visitor) VisitMethodCall(ctx *parser.MethodCallContext) interface{} { +func (v *FormatVisitor) VisitMethodCall(ctx *parser.MethodCallContext) interface{} { var f string switch e := ctx.GetChild(0).(type) { case *parser.IdContext: - f = e.GetText() + f = v.visitRule(e).(string) case antlr.TerminalNode: f = strings.ToLower(e.GetText()) } @@ -583,59 +639,85 @@ func (v *Visitor) VisitMethodCall(ctx *parser.MethodCallContext) interface{} { return fmt.Sprintf("%s(%s)", f, expressionList) } -func (v *Visitor) VisitSoslPrimary(ctx *parser.SoslPrimaryContext) interface{} { +func (v *FormatVisitor) VisitSoslPrimary(ctx *parser.SoslPrimaryContext) interface{} { return v.visitRule(ctx.SoslLiteral()) } -func (v *Visitor) VisitSoqlPrimary(ctx *parser.SoqlPrimaryContext) interface{} { +func (v *FormatVisitor) VisitSoqlPrimary(ctx *parser.SoqlPrimaryContext) interface{} { return v.visitRule(ctx.SoqlLiteral()) } -func (v *Visitor) VisitSoqlLiteral(ctx *parser.SoqlLiteralContext) interface{} { - return fmt.Sprintf("[\n%s\n]", indent(v.visitRule(ctx.Query()).(string))) +func (v *FormatVisitor) VisitSoqlLiteral(ctx *parser.SoqlLiteralContext) interface{} { + // Check whether we should wrap this SOQL Query based on query complexity + i := NewChainVisitor() + n := i.visitRule(ctx.Query()).(int) + if n > 3 { + defer restoreWrap(wrap(v)) + return fmt.Sprintf("[\n%s\n]", v.indent(v.visitRule(ctx.Query()).(string))) + } + return fmt.Sprintf("[%s]", v.visitRule(ctx.Query())) } -func (v *Visitor) VisitQuery(ctx *parser.QueryContext) interface{} { +func (v *FormatVisitor) VisitQuery(ctx *parser.QueryContext) interface{} { + sep := " " + indent := 0 + if v.wrap { + sep = "\n" + indent = 1 + } var query strings.Builder - query.WriteString(fmt.Sprintf("SELECT\n%s\nFROM\n%s", - indent(v.visitRule(ctx.SelectList()).(string)), - indent(v.visitRule(ctx.FromNameList()).(string)))) + query.WriteString("SELECT") + query.WriteString(sep) + query.WriteString(v.indentTo(v.visitRule(ctx.SelectList()).(string), indent)) + query.WriteString(sep) + query.WriteString("FROM") + query.WriteString(sep) + query.WriteString(v.indentTo(v.visitRule(ctx.FromNameList()).(string), indent)) if scope := ctx.UsingScope(); scope != nil { - query.WriteString(fmt.Sprintf("\n%s", v.visitRule(scope).(string))) + query.WriteString(sep) + query.WriteString(fmt.Sprintf("%s", v.visitRule(scope).(string))) } if where := ctx.WhereClause(); where != nil { - query.WriteString(fmt.Sprintf("\n%s", v.visitRule(where).(string))) + query.WriteString(sep) + query.WriteString(v.visitRule(where).(string)) } if groupBy := ctx.GroupByClause(); groupBy != nil { - query.WriteString(fmt.Sprintf("\n%s", v.visitRule(groupBy).(string))) + query.WriteString(sep) + query.WriteString(v.visitRule(groupBy).(string)) } if orderBy := ctx.OrderByClause(); orderBy != nil { - query.WriteString(fmt.Sprintf("\n%s", v.visitRule(orderBy).(string))) + query.WriteString(sep) + query.WriteString(v.visitRule(orderBy).(string)) } if limit := ctx.LimitClause(); limit != nil { - query.WriteString(fmt.Sprintf("\n%s", v.visitRule(limit).(string))) + query.WriteString(sep) + query.WriteString(v.visitRule(limit).(string)) } if offset := ctx.OffsetClause(); offset != nil { - query.WriteString(fmt.Sprintf("\n%s", v.visitRule(offset).(string))) + query.WriteString(sep) + query.WriteString(v.visitRule(offset).(string)) } if ctx.OffsetClause() != nil { - query.WriteString("\nALL ROWS") + query.WriteString(sep) + query.WriteString("ALL ROWS") } forClauses := v.visitRule(ctx.ForClauses()) if forClauses != "" { - query.WriteString(fmt.Sprintf("\n%s", forClauses)) + query.WriteString(sep) + query.WriteString(forClauses.(string)) } if update := ctx.UpdateList(); update != nil { - query.WriteString(fmt.Sprintf("\nUPDATE %s", v.visitRule(update).(string))) + query.WriteString(sep) + query.WriteString(fmt.Sprintf("UPDATE %s", v.visitRule(update).(string))) } return query.String() } -func (v *Visitor) VisitSubQuery(ctx *parser.SubQueryContext) interface{} { +func (v *FormatVisitor) VisitSubQuery(ctx *parser.SubQueryContext) interface{} { var query strings.Builder query.WriteString(fmt.Sprintf("SELECT\n%s\nFROM\n%s", - indent(v.visitRule(ctx.SubFieldList()).(string)), - indent(v.visitRule(ctx.FromNameList()).(string)), + v.indent(v.visitRule(ctx.SubFieldList()).(string)), + v.indent(v.visitRule(ctx.FromNameList()).(string)), )) if where := ctx.WhereClause(); where != nil { query.WriteString(fmt.Sprintf("\n%s", v.visitRule(where).(string))) @@ -656,7 +738,7 @@ func (v *Visitor) VisitSubQuery(ctx *parser.SubQueryContext) interface{} { return query.String() } -func (v *Visitor) VisitFromNameList(ctx *parser.FromNameListContext) interface{} { +func (v *FormatVisitor) VisitFromNameList(ctx *parser.FromNameListContext) interface{} { fieldNames := []string{} for _, p := range ctx.AllFieldNameAlias() { fieldNames = append(fieldNames, v.visitRule(p).(string)) @@ -664,7 +746,7 @@ func (v *Visitor) VisitFromNameList(ctx *parser.FromNameListContext) interface{} return strings.Join(fieldNames, ",\n") } -func (v *Visitor) VisitUpdateList(ctx *parser.UpdateListContext) interface{} { +func (v *FormatVisitor) VisitUpdateList(ctx *parser.UpdateListContext) interface{} { updateList := "" if u := ctx.UpdateList(); u != nil { updateList = fmt.Sprintf(", %s", v.visitRule(u).(string)) @@ -672,7 +754,7 @@ func (v *Visitor) VisitUpdateList(ctx *parser.UpdateListContext) interface{} { return fmt.Sprintf("%s%s", ctx.UpdateType().GetText(), updateList) } -func (v *Visitor) VisitFieldNameAlias(ctx *parser.FieldNameAliasContext) interface{} { +func (v *FormatVisitor) VisitFieldNameAlias(ctx *parser.FieldNameAliasContext) interface{} { soqlId := "" if s := ctx.SoqlId(); s != nil { soqlId = " " + s.GetText() @@ -680,7 +762,7 @@ func (v *Visitor) VisitFieldNameAlias(ctx *parser.FieldNameAliasContext) interfa return fmt.Sprintf("%s%s", v.visitRule(ctx.FieldName()), soqlId) } -func (v *Visitor) VisitSelectList(ctx *parser.SelectListContext) interface{} { +func (v *FormatVisitor) VisitSelectList(ctx *parser.SelectListContext) interface{} { selectEntries := []string{} for _, p := range ctx.AllSelectEntry() { selectEntries = append(selectEntries, v.visitRule(p).(string)) @@ -688,7 +770,7 @@ func (v *Visitor) VisitSelectList(ctx *parser.SelectListContext) interface{} { return strings.Join(selectEntries, ",\n") } -func (v *Visitor) VisitSubFieldList(ctx *parser.SubFieldListContext) interface{} { +func (v *FormatVisitor) VisitSubFieldList(ctx *parser.SubFieldListContext) interface{} { selectEntries := []string{} for _, p := range ctx.AllSubFieldEntry() { selectEntries = append(selectEntries, v.visitRule(p).(string)) @@ -696,7 +778,7 @@ func (v *Visitor) VisitSubFieldList(ctx *parser.SubFieldListContext) interface{} return strings.Join(selectEntries, ",\n") } -func (v *Visitor) VisitSelectEntry(ctx *parser.SelectEntryContext) interface{} { +func (v *FormatVisitor) VisitSelectEntry(ctx *parser.SelectEntryContext) interface{} { soqlId := "" if s := ctx.SoqlId(); s != nil { soqlId = " " + s.GetText() @@ -714,7 +796,7 @@ func (v *Visitor) VisitSelectEntry(ctx *parser.SelectEntryContext) interface{} { panic("Unexpected selectEntry") } -func (v *Visitor) VisitSubFieldEntry(ctx *parser.SubFieldEntryContext) interface{} { +func (v *FormatVisitor) VisitSubFieldEntry(ctx *parser.SubFieldEntryContext) interface{} { soqlId := "" if s := ctx.SoqlId(); s != nil { soqlId = " " + s.GetText() @@ -730,7 +812,7 @@ func (v *Visitor) VisitSubFieldEntry(ctx *parser.SubFieldEntryContext) interface panic("Unexpected selectEntry") } -func (v *Visitor) VisitFieldName(ctx *parser.FieldNameContext) interface{} { +func (v *FormatVisitor) VisitFieldName(ctx *parser.FieldNameContext) interface{} { ids := []string{} for _, t := range ctx.AllSoqlId() { ids = append(ids, t.GetText()) @@ -738,7 +820,7 @@ func (v *Visitor) VisitFieldName(ctx *parser.FieldNameContext) interface{} { return strings.Join(ids, ".") } -func (v *Visitor) VisitFieldNameList(ctx *parser.FieldNameListContext) interface{} { +func (v *FormatVisitor) VisitFieldNameList(ctx *parser.FieldNameListContext) interface{} { fieldNames := []string{} for _, p := range ctx.AllFieldName() { fieldNames = append(fieldNames, v.visitRule(p).(string)) @@ -746,7 +828,7 @@ func (v *Visitor) VisitFieldNameList(ctx *parser.FieldNameListContext) interface return strings.Join(fieldNames, ",\n") } -func (v *Visitor) VisitTypeOf(ctx *parser.TypeOfContext) interface{} { +func (v *FormatVisitor) VisitTypeOf(ctx *parser.TypeOfContext) interface{} { whenClauses := []string{} for _, w := range ctx.AllWhenClause() { whenClauses = append(whenClauses, v.visitRule(w).(string)) @@ -763,7 +845,7 @@ func (v *Visitor) VisitTypeOf(ctx *parser.TypeOfContext) interface{} { ) } -func (v *Visitor) VisitForClauses(ctx *parser.ForClausesContext) interface{} { +func (v *FormatVisitor) VisitForClauses(ctx *parser.ForClausesContext) interface{} { forClauses := []string{} for _, f := range ctx.AllForClause() { forClauses = append(forClauses, v.visitRule(f).(string)) @@ -771,33 +853,57 @@ func (v *Visitor) VisitForClauses(ctx *parser.ForClausesContext) interface{} { return strings.Join(forClauses, " ") } -func (v *Visitor) VisitForClause(ctx *parser.ForClauseContext) interface{} { +func (v *FormatVisitor) VisitForClause(ctx *parser.ForClauseContext) interface{} { return fmt.Sprintf("FOR %s", ctx.GetChild(1).(antlr.TerminalNode).GetText()) } -func (v *Visitor) VisitWhenClause(ctx *parser.WhenClauseContext) interface{} { - return fmt.Sprintf("WHEN\n%s\nTHEN\n%s", indent(v.visitRule(ctx.FieldName()).(string)), indent(v.visitRule(ctx.FieldNameList()).(string))) -} - -func (v *Visitor) VisitWhereClause(ctx *parser.WhereClauseContext) interface{} { - return fmt.Sprintf("WHERE\n%s", indent(v.visitRule(ctx.LogicalExpression()).(string))) +func (v *FormatVisitor) VisitWhenClause(ctx *parser.WhenClauseContext) interface{} { + sep := " " + indent := 0 + if v.wrap { + sep = "\n" + indent = 1 + } + var clause strings.Builder + clause.WriteString("WHEN") + clause.WriteString(sep) + clause.WriteString(v.indentTo(v.visitRule(ctx.FieldName()).(string), indent)) + clause.WriteString(sep) + clause.WriteString("THEN") + clause.WriteString(sep) + clause.WriteString(v.indentTo(v.visitRule(ctx.FieldNameList()).(string), indent)) + return clause.String() +} + +func (v *FormatVisitor) VisitWhereClause(ctx *parser.WhereClauseContext) interface{} { + sep := " " + indent := 0 + if v.wrap { + sep = "\n" + indent = 1 + } + var clause strings.Builder + clause.WriteString("WHERE") + clause.WriteString(sep) + clause.WriteString(v.indentTo(v.visitRule(ctx.LogicalExpression()).(string), indent)) + return clause.String() } -func (v *Visitor) VisitLimitClause(ctx *parser.LimitClauseContext) interface{} { +func (v *FormatVisitor) VisitLimitClause(ctx *parser.LimitClauseContext) interface{} { if e := ctx.BoundExpression(); e != nil { return fmt.Sprintf("LIMIT %s", v.visitRule(ctx.BoundExpression())) } return fmt.Sprintf("LIMIT %s", ctx.IntegerLiteral().GetText()) } -func (v *Visitor) VisitOffsetClause(ctx *parser.OffsetClauseContext) interface{} { +func (v *FormatVisitor) VisitOffsetClause(ctx *parser.OffsetClauseContext) interface{} { if e := ctx.BoundExpression(); e != nil { return fmt.Sprintf("OFFSET %s", v.visitRule(ctx.BoundExpression())) } return fmt.Sprintf("OFFSET %s", ctx.IntegerLiteral().GetText()) } -func (v *Visitor) VisitLogicalExpression(ctx *parser.LogicalExpressionContext) interface{} { +func (v *FormatVisitor) VisitLogicalExpression(ctx *parser.LogicalExpressionContext) interface{} { switch { case ctx.NOT() != nil: return fmt.Sprintf("NOT %s", v.visitRule(ctx.ConditionalExpression(0))) @@ -819,7 +925,7 @@ func (v *Visitor) VisitLogicalExpression(ctx *parser.LogicalExpressionContext) i } } -func (v *Visitor) VisitConditionalExpression(ctx *parser.ConditionalExpressionContext) interface{} { +func (v *FormatVisitor) VisitConditionalExpression(ctx *parser.ConditionalExpressionContext) interface{} { switch { case ctx.LogicalExpression() != nil: return fmt.Sprintf("(%s)", v.visitRule(ctx.LogicalExpression())) @@ -829,7 +935,7 @@ func (v *Visitor) VisitConditionalExpression(ctx *parser.ConditionalExpressionCo panic("Unexpected conditionalExpression") } -func (v *Visitor) VisitFieldExpression(ctx *parser.FieldExpressionContext) interface{} { +func (v *FormatVisitor) VisitFieldExpression(ctx *parser.FieldExpressionContext) interface{} { switch { case ctx.FieldName() != nil: // TODO: Format IN/NOT IN @@ -840,14 +946,14 @@ func (v *Visitor) VisitFieldExpression(ctx *parser.FieldExpressionContext) inter panic("Unexpected fieldExpression") } -func (v *Visitor) VisitComparisonOperator(ctx *parser.ComparisonOperatorContext) interface{} { +func (v *FormatVisitor) VisitComparisonOperator(ctx *parser.ComparisonOperatorContext) interface{} { if ctx.NOT() != nil { return "NOT IN" } return ctx.GetText() } -func (v *Visitor) VisitSoqlFunction(ctx *parser.SoqlFunctionContext) interface{} { +func (v *FormatVisitor) VisitSoqlFunction(ctx *parser.SoqlFunctionContext) interface{} { param := "" switch { case ctx.COUNT() != nil: @@ -864,77 +970,77 @@ func (v *Visitor) VisitSoqlFunction(ctx *parser.SoqlFunctionContext) interface{} return fmt.Sprintf("%s(%s)", ctx.GetChild(0).(antlr.TerminalNode).GetText(), param) } -func (v *Visitor) VisitSoqlFieldsParameter(ctx *parser.SoqlFieldsParameterContext) interface{} { +func (v *FormatVisitor) VisitSoqlFieldsParameter(ctx *parser.SoqlFieldsParameterContext) interface{} { return ctx.GetText() } -func (v *Visitor) VisitDateFieldName(ctx *parser.DateFieldNameContext) interface{} { +func (v *FormatVisitor) VisitDateFieldName(ctx *parser.DateFieldNameContext) interface{} { if ctx.CONVERT_TIMEZONE() != nil { return fmt.Sprintf("CONVERT_TIMEZONE(%s)", v.visitRule(ctx.FieldName())) } return v.visitRule(ctx.FieldName()) } -func (v *Visitor) VisitNullValue(ctx *parser.NullValueContext) interface{} { +func (v *FormatVisitor) VisitNullValue(ctx *parser.NullValueContext) interface{} { return "null" } -func (v *Visitor) VisitBooleanLiteralValue(ctx *parser.BooleanLiteralValueContext) interface{} { +func (v *FormatVisitor) VisitBooleanLiteralValue(ctx *parser.BooleanLiteralValueContext) interface{} { return strings.ToLower(ctx.GetText()) } -func (v *Visitor) VisitSignedNumberValue(ctx *parser.SignedNumberValueContext) interface{} { +func (v *FormatVisitor) VisitSignedNumberValue(ctx *parser.SignedNumberValueContext) interface{} { return ctx.GetText() } -func (v *Visitor) VisitStringLiteralValue(ctx *parser.StringLiteralValueContext) interface{} { +func (v *FormatVisitor) VisitStringLiteralValue(ctx *parser.StringLiteralValueContext) interface{} { return ctx.GetText() } -func (v *Visitor) VisitDateLiteralValue(ctx *parser.DateLiteralValueContext) interface{} { +func (v *FormatVisitor) VisitDateLiteralValue(ctx *parser.DateLiteralValueContext) interface{} { return ctx.GetText() } -func (v *Visitor) VisitDateTimeLiteralValue(ctx *parser.DateTimeLiteralValueContext) interface{} { +func (v *FormatVisitor) VisitDateTimeLiteralValue(ctx *parser.DateTimeLiteralValueContext) interface{} { return ctx.GetText() } -func (v *Visitor) VisitDateFormulaValue(ctx *parser.DateFormulaValueContext) interface{} { +func (v *FormatVisitor) VisitDateFormulaValue(ctx *parser.DateFormulaValueContext) interface{} { return v.visitRule(ctx.DateFormula()) } -func (v *Visitor) VisitCurrencyValueValue(ctx *parser.CurrencyValueValueContext) interface{} { +func (v *FormatVisitor) VisitCurrencyValueValue(ctx *parser.CurrencyValueValueContext) interface{} { return ctx.GetText() } -func (v *Visitor) VisitSubQueryValue(ctx *parser.SubQueryValueContext) interface{} { +func (v *FormatVisitor) VisitSubQueryValue(ctx *parser.SubQueryValueContext) interface{} { return fmt.Sprintf("(%s)", v.visitRule(ctx.SubQuery())) } -func (v *Visitor) VisitValueListValue(ctx *parser.ValueListValueContext) interface{} { +func (v *FormatVisitor) VisitValueListValue(ctx *parser.ValueListValueContext) interface{} { return v.visitRule(ctx.ValueList()) } -func (v *Visitor) VisitBoundExpressionValue(ctx *parser.BoundExpressionValueContext) interface{} { +func (v *FormatVisitor) VisitBoundExpressionValue(ctx *parser.BoundExpressionValueContext) interface{} { return v.visitRule(ctx.BoundExpression()) } -func (v *Visitor) VisitDateFormula(ctx *parser.DateFormulaContext) interface{} { +func (v *FormatVisitor) VisitDateFormula(ctx *parser.DateFormulaContext) interface{} { if ctx.SignedInteger() != nil { return fmt.Sprintf("%s:%s", ctx.GetChild(0).(antlr.TerminalNode).GetText(), v.visitRule(ctx.SignedInteger())) } return ctx.GetChild(0).(antlr.TerminalNode).GetText() } -func (v *Visitor) VisitSignedInteger(ctx *parser.SignedIntegerContext) interface{} { +func (v *FormatVisitor) VisitSignedInteger(ctx *parser.SignedIntegerContext) interface{} { return ctx.GetText() } -func (v *Visitor) VisitSignedNumber(ctx *parser.SignedNumberContext) interface{} { +func (v *FormatVisitor) VisitSignedNumber(ctx *parser.SignedNumberContext) interface{} { return ctx.GetText() } -func (v *Visitor) VisitValueList(ctx *parser.ValueListContext) interface{} { +func (v *FormatVisitor) VisitValueList(ctx *parser.ValueListContext) interface{} { values := []string{} for _, i := range ctx.AllValue() { values = append(values, v.visitRule(i).(string)) @@ -942,7 +1048,7 @@ func (v *Visitor) VisitValueList(ctx *parser.ValueListContext) interface{} { return fmt.Sprintf("(%s)", strings.Join(values, ", ")) } -func (v *Visitor) VisitGroupByClause(ctx *parser.GroupByClauseContext) interface{} { +func (v *FormatVisitor) VisitGroupByClause(ctx *parser.GroupByClauseContext) interface{} { fieldNames := []string{} for _, i := range ctx.AllFieldName() { fieldNames = append(fieldNames, v.visitRule(i).(string)) @@ -961,15 +1067,15 @@ func (v *Visitor) VisitGroupByClause(ctx *parser.GroupByClauseContext) interface } } -func (v *Visitor) VisitUsingScope(ctx *parser.UsingScopeContext) interface{} { +func (v *FormatVisitor) VisitUsingScope(ctx *parser.UsingScopeContext) interface{} { return fmt.Sprintf("USING SCOPE %s", ctx.SoqlId().GetText()) } -func (v *Visitor) VisitOrderByClause(ctx *parser.OrderByClauseContext) interface{} { - return fmt.Sprintf("ORDER BY\n%s", indent(v.visitRule(ctx.FieldOrderList()).(string))) +func (v *FormatVisitor) VisitOrderByClause(ctx *parser.OrderByClauseContext) interface{} { + return fmt.Sprintf("ORDER BY\n%s", v.indent(v.visitRule(ctx.FieldOrderList()).(string))) } -func (v *Visitor) VisitFieldOrderList(ctx *parser.FieldOrderListContext) interface{} { +func (v *FormatVisitor) VisitFieldOrderList(ctx *parser.FieldOrderListContext) interface{} { fields := []string{} for _, i := range ctx.AllFieldOrder() { fields = append(fields, v.visitRule(i).(string)) @@ -977,7 +1083,7 @@ func (v *Visitor) VisitFieldOrderList(ctx *parser.FieldOrderListContext) interfa return strings.Join(fields, ", ") } -func (v *Visitor) VisitFieldOrder(ctx *parser.FieldOrderContext) interface{} { +func (v *FormatVisitor) VisitFieldOrder(ctx *parser.FieldOrderContext) interface{} { var field strings.Builder if f := ctx.FieldName(); f != nil { field.WriteString(v.visitRule(f).(string)) @@ -1000,15 +1106,15 @@ func (v *Visitor) VisitFieldOrder(ctx *parser.FieldOrderContext) interface{} { return field.String() } -func (v *Visitor) VisitBoundExpression(ctx *parser.BoundExpressionContext) interface{} { +func (v *FormatVisitor) VisitBoundExpression(ctx *parser.BoundExpressionContext) interface{} { return fmt.Sprintf(":%s", v.visitRule(ctx.Expression())) } -func (v *Visitor) VisitCreator(ctx *parser.CreatorContext) interface{} { +func (v *FormatVisitor) VisitCreator(ctx *parser.CreatorContext) interface{} { return fmt.Sprintf("%s%s", v.visitRule(ctx.CreatedName()), v.visitRule(ctx.GetChild(1).(antlr.RuleNode))) } -func (v *Visitor) VisitCreatedName(ctx *parser.CreatedNameContext) interface{} { +func (v *FormatVisitor) VisitCreatedName(ctx *parser.CreatedNameContext) interface{} { namePairs := []string{} for _, i := range ctx.AllIdCreatedNamePair() { namePairs = append(namePairs, v.visitRule(i).(string)) @@ -1016,26 +1122,26 @@ func (v *Visitor) VisitCreatedName(ctx *parser.CreatedNameContext) interface{} { return strings.Join(namePairs, ".") } -func (v *Visitor) VisitIdCreatedNamePair(ctx *parser.IdCreatedNamePairContext) interface{} { +func (v *FormatVisitor) VisitIdCreatedNamePair(ctx *parser.IdCreatedNamePairContext) interface{} { if typeList := ctx.TypeList(); typeList != nil { return fmt.Sprintf("%s<%s>", v.visitRule(ctx.AnyId()), v.visitRule(typeList)) } return v.visitRule(ctx.AnyId()) } -func (v *Visitor) VisitNoRest(ctx *parser.NoRestContext) interface{} { +func (v *FormatVisitor) VisitNoRest(ctx *parser.NoRestContext) interface{} { return "{}" } -func (v *Visitor) VisitId(ctx *parser.IdContext) interface{} { +func (v *FormatVisitor) VisitId(ctx *parser.IdContext) interface{} { return ctx.GetText() } -func (v *Visitor) VisitClassCreatorRest(ctx *parser.ClassCreatorRestContext) interface{} { +func (v *FormatVisitor) VisitClassCreatorRest(ctx *parser.ClassCreatorRestContext) interface{} { return v.visitRule(ctx.Arguments()) } -func (v *Visitor) VisitArrayCreatorRest(ctx *parser.ArrayCreatorRestContext) interface{} { +func (v *FormatVisitor) VisitArrayCreatorRest(ctx *parser.ArrayCreatorRestContext) interface{} { if expression := ctx.Expression(); expression != nil { return fmt.Sprintf("[ %s ]", v.visitRule(expression)) } else if arrayInitializer := ctx.ArrayInitializer(); arrayInitializer != nil { @@ -1044,33 +1150,33 @@ func (v *Visitor) VisitArrayCreatorRest(ctx *parser.ArrayCreatorRestContext) int return "[]" } -func (v *Visitor) VisitMapCreatorRest(ctx *parser.MapCreatorRestContext) interface{} { +func (v *FormatVisitor) VisitMapCreatorRest(ctx *parser.MapCreatorRestContext) interface{} { pairs := []string{} for _, i := range ctx.AllMapCreatorRestPair() { pairs = append(pairs, v.visitRule(i).(string)) } - if len(ctx.GetText()) > 80 { - return fmt.Sprintf("{\n%s\n}", indent(strings.Join(pairs, ",\n"))) + if len(pairs) > 1 { + return fmt.Sprintf("{\n%s\n}", v.indent(strings.Join(pairs, ",\n"))) } - return fmt.Sprintf("{%s}", strings.Join(pairs, ", ")) + return fmt.Sprintf("{ %s }", strings.Join(pairs, ", ")) } -func (v *Visitor) VisitMapCreatorRestPair(ctx *parser.MapCreatorRestPairContext) interface{} { +func (v *FormatVisitor) VisitMapCreatorRestPair(ctx *parser.MapCreatorRestPairContext) interface{} { return fmt.Sprintf("%s => %s", v.visitRule(ctx.Expression(0)), v.visitRule(ctx.Expression(1))) } -func (v *Visitor) VisitSetCreatorRest(ctx *parser.SetCreatorRestContext) interface{} { +func (v *FormatVisitor) VisitSetCreatorRest(ctx *parser.SetCreatorRestContext) interface{} { expressions := []string{} for _, i := range ctx.AllExpression() { expressions = append(expressions, v.visitRule(i).(string)) } if len(ctx.GetText()) > 80 { - return fmt.Sprintf("{\n%s\n}", indent(strings.Join(expressions, ",\n"))) + return fmt.Sprintf("{\n%s\n}", v.indent(strings.Join(expressions, ",\n"))) } return fmt.Sprintf("{ %s }", strings.Join(expressions, ", ")) } -func (v *Visitor) VisitArrayInitializer(ctx *parser.ArrayInitializerContext) interface{} { +func (v *FormatVisitor) VisitArrayInitializer(ctx *parser.ArrayInitializerContext) interface{} { expressions := []string{} for _, i := range ctx.AllExpression() { expressions = append(expressions, v.visitRule(i).(string)) @@ -1078,18 +1184,23 @@ func (v *Visitor) VisitArrayInitializer(ctx *parser.ArrayInitializerContext) int return fmt.Sprintf("{ %s }", strings.Join(expressions, ", ")) } -func (v *Visitor) VisitArguments(ctx *parser.ArgumentsContext) interface{} { +// Class instance arguments, e.g. (Name = 'Acme', BillingCity = 'Los Angeles') in Account(Name = 'Acme', BillingCity = 'Los Angeles') +func (v *FormatVisitor) VisitArguments(ctx *parser.ArgumentsContext) interface{} { expressionList := ctx.ExpressionList() if expressionList == nil { return "()" } + if v.wrap { + log.Debug("Visitor says to wrap in VisitArguments") + } if len(expressionList.GetText()) > 40 { - return fmt.Sprintf("(\n%s\n)", indent(v.visitRule(expressionList).(string))) + defer restoreWrap(wrap(v)) + return fmt.Sprintf("(\n%s\n)", v.indent(v.visitRule(expressionList).(string))) } return fmt.Sprintf("(%s)", v.visitRule(expressionList)) } -func (v *Visitor) VisitCmpExpression(ctx *parser.CmpExpressionContext) interface{} { +func (v *FormatVisitor) VisitCmpExpression(ctx *parser.CmpExpressionContext) interface{} { cmpToken := ctx.GetChild(1).(antlr.TerminalNode).GetText() if ctx.ASSIGN() != nil { cmpToken += "=" @@ -1097,16 +1208,17 @@ func (v *Visitor) VisitCmpExpression(ctx *parser.CmpExpressionContext) interface return fmt.Sprintf("%s %s %s", v.visitRule(ctx.Expression(0)), cmpToken, v.visitRule(ctx.Expression(1))) } -func (v *Visitor) VisitEqualityExpression(ctx *parser.EqualityExpressionContext) interface{} { +func (v *FormatVisitor) VisitEqualityExpression(ctx *parser.EqualityExpressionContext) interface{} { + defer restoreWrap(unwrap(v)) cmpToken := ctx.GetChild(1).(antlr.TerminalNode).GetText() return fmt.Sprintf("%s %s %s", v.visitRule(ctx.Expression(0)), cmpToken, v.visitRule(ctx.Expression(1))) } -func (v *Visitor) VisitInstanceOfExpression(ctx *parser.InstanceOfExpressionContext) interface{} { +func (v *FormatVisitor) VisitInstanceOfExpression(ctx *parser.InstanceOfExpressionContext) interface{} { return fmt.Sprintf("%s instanceof %s", v.visitRule(ctx.Expression()), v.visitRule(ctx.TypeRef())) } -func (v *Visitor) VisitTypeList(ctx *parser.TypeListContext) interface{} { +func (v *FormatVisitor) VisitTypeList(ctx *parser.TypeListContext) interface{} { types := []string{} for _, p := range ctx.AllTypeRef() { types = append(types, v.visitRule(p).(string)) @@ -1118,7 +1230,7 @@ func (v *Visitor) VisitTypeList(ctx *parser.TypeListContext) interface{} { return strings.Join(types, sep) } -func (v *Visitor) VisitFormalParameters(ctx *parser.FormalParametersContext) interface{} { +func (v *FormatVisitor) VisitFormalParameters(ctx *parser.FormalParametersContext) interface{} { params := []string{} list := ctx.FormalParameterList() if list == nil { @@ -1131,7 +1243,7 @@ func (v *Visitor) VisitFormalParameters(ctx *parser.FormalParametersContext) int return val } -func (v *Visitor) VisitAnnotation(ctx *parser.AnnotationContext) interface{} { +func (v *FormatVisitor) VisitAnnotation(ctx *parser.AnnotationContext) interface{} { args := "" if ctx.LPAREN() != nil { vals := "" @@ -1145,7 +1257,7 @@ func (v *Visitor) VisitAnnotation(ctx *parser.AnnotationContext) interface{} { return fmt.Sprintf("@%s%s", v.visitRule(ctx.QualifiedName()), args) } -func (v *Visitor) VisitElementValuePairs(ctx *parser.ElementValuePairsContext) interface{} { +func (v *FormatVisitor) VisitElementValuePairs(ctx *parser.ElementValuePairsContext) interface{} { pairs := []string{v.visitRule(ctx.ElementValuePair()).(string)} for _, p := range ctx.AllDelimitedElementValuePair() { pairs = append(pairs, v.visitRule(p).(string)) @@ -1153,7 +1265,7 @@ func (v *Visitor) VisitElementValuePairs(ctx *parser.ElementValuePairsContext) i return strings.Join(pairs, "") } -func (v *Visitor) VisitDelimitedElementValuePair(ctx *parser.DelimitedElementValuePairContext) interface{} { +func (v *FormatVisitor) VisitDelimitedElementValuePair(ctx *parser.DelimitedElementValuePairContext) interface{} { delimiter := " " if ctx.COMMA() != nil { delimiter = ", " @@ -1161,15 +1273,15 @@ func (v *Visitor) VisitDelimitedElementValuePair(ctx *parser.DelimitedElementVal return fmt.Sprintf("%s%s", delimiter, v.visitRule(ctx.ElementValuePair())) } -func (v *Visitor) VisitElementValuePair(ctx *parser.ElementValuePairContext) interface{} { +func (v *FormatVisitor) VisitElementValuePair(ctx *parser.ElementValuePairContext) interface{} { return fmt.Sprintf("%s = %s", v.visitRule(ctx.Id()), v.visitRule(ctx.ElementValue())) } -func (v *Visitor) VisitElementValue(ctx *parser.ElementValueContext) interface{} { +func (v *FormatVisitor) VisitElementValue(ctx *parser.ElementValueContext) interface{} { return v.visitRule(ctx.GetChild(0).(antlr.RuleNode)) } -func (v *Visitor) VisitElementValueArrayInitializer(ctx *parser.ElementValueArrayInitializerContext) interface{} { +func (v *FormatVisitor) VisitElementValueArrayInitializer(ctx *parser.ElementValueArrayInitializerContext) interface{} { values := []string{} for _, val := range ctx.AllElementValue() { values = append(values, v.visitRule(val).(string)) @@ -1181,11 +1293,11 @@ func (v *Visitor) VisitElementValueArrayInitializer(ctx *parser.ElementValueArra return fmt.Sprintf("(%s%s)", strings.Join(values, ", "), trailingComma) } -func (v *Visitor) VisitFormalParameter(ctx *parser.FormalParameterContext) interface{} { +func (v *FormatVisitor) VisitFormalParameter(ctx *parser.FormalParameterContext) interface{} { return fmt.Sprintf("%s%s %s", v.Modifiers(ctx.AllModifier()), v.visitRule(ctx.TypeRef()), ctx.Id().GetText()) } -func (v *Visitor) VisitQualifiedName(ctx *parser.QualifiedNameContext) interface{} { +func (v *FormatVisitor) VisitQualifiedName(ctx *parser.QualifiedNameContext) interface{} { ids := []string{} for _, i := range ctx.AllId() { ids = append(ids, i.GetText()) @@ -1193,7 +1305,7 @@ func (v *Visitor) VisitQualifiedName(ctx *parser.QualifiedNameContext) interface return strings.Join(ids, ".") } -func (v *Visitor) VisitVariableDeclarators(ctx *parser.VariableDeclaratorsContext) interface{} { +func (v *FormatVisitor) VisitVariableDeclarators(ctx *parser.VariableDeclaratorsContext) interface{} { vars := []string{} for _, vd := range ctx.AllVariableDeclarator() { vars = append(vars, v.visitRule(vd).(string)) @@ -1201,15 +1313,18 @@ func (v *Visitor) VisitVariableDeclarators(ctx *parser.VariableDeclaratorsContex return strings.Join(vars, ", ") } -func (v *Visitor) VisitVariableDeclarator(ctx *parser.VariableDeclaratorContext) interface{} { +func (v *FormatVisitor) VisitVariableDeclarator(ctx *parser.VariableDeclaratorContext) interface{} { decl := ctx.Id().GetText() - if ctx.Expression() != nil { - decl = fmt.Sprintf("%s = %s", decl, v.visitRule(ctx.Expression())) + if ctx.Expression() == nil { + return decl + } + if v.wrap { + return fmt.Sprintf("%s =%s", decl, v.visitRule(ctx.Expression())) } - return decl + return fmt.Sprintf("%s = %s", decl, v.visitRule(ctx.Expression())) } -func (v *Visitor) VisitMethodDeclaration(ctx *parser.MethodDeclarationContext) interface{} { +func (v *FormatVisitor) VisitMethodDeclaration(ctx *parser.MethodDeclarationContext) interface{} { returnType := "void" if ctx.TypeRef() != nil { returnType = v.visitRule(ctx.TypeRef()).(string) @@ -1223,11 +1338,11 @@ func (v *Visitor) VisitMethodDeclaration(ctx *parser.MethodDeclarationContext) i body) } -func (v *Visitor) VisitTypeRefPrimary(ctx *parser.TypeRefPrimaryContext) interface{} { +func (v *FormatVisitor) VisitTypeRefPrimary(ctx *parser.TypeRefPrimaryContext) interface{} { return fmt.Sprintf("%s.class", v.visitRule(ctx.TypeRef())) } -func (v *Visitor) VisitTypeRef(ctx *parser.TypeRefContext) interface{} { +func (v *FormatVisitor) VisitTypeRef(ctx *parser.TypeRefContext) interface{} { typeNames := []string{} for _, t := range ctx.AllTypeName() { typeNames = append(typeNames, v.visitRule(t).(string)) @@ -1237,7 +1352,7 @@ func (v *Visitor) VisitTypeRef(ctx *parser.TypeRefContext) interface{} { return val } -func (v *Visitor) VisitTypeName(ctx *parser.TypeNameContext) interface{} { +func (v *FormatVisitor) VisitTypeName(ctx *parser.TypeNameContext) interface{} { typeName := "" if id := ctx.Id(); id != nil { typeName = v.visitRule(id).(string) @@ -1251,20 +1366,20 @@ func (v *Visitor) VisitTypeName(ctx *parser.TypeNameContext) interface{} { return fmt.Sprintf("%s%s", typeName, typeArguments) } -func (v *Visitor) VisitTypeArguments(ctx *parser.TypeArgumentsContext) interface{} { +func (v *FormatVisitor) VisitTypeArguments(ctx *parser.TypeArgumentsContext) interface{} { return fmt.Sprintf("<%s>", v.visitRule(ctx.TypeList())) } -func (v *Visitor) VisitSoslLiteral(ctx *parser.SoslLiteralContext) interface{} { +func (v *FormatVisitor) VisitSoslLiteral(ctx *parser.SoslLiteralContext) interface{} { if ctx.BoundExpression() != nil { return fmt.Sprintf("[\n%s]", - indent(fmt.Sprintf("FIND\n%s%s", indent(v.visitRule(ctx.BoundExpression()).(string)), v.visitRule(ctx.SoslClauses()))), + v.indent(fmt.Sprintf("FIND\n%s%s", v.indent(v.visitRule(ctx.BoundExpression()).(string)), v.visitRule(ctx.SoslClauses()))), ) } return fmt.Sprintf("%s %s ]", ctx.GetChild(0).(antlr.TerminalNode).GetText(), v.visitRule(ctx.SoslClauses())) } -func (v *Visitor) VisitSoslClauses(ctx *parser.SoslClausesContext) interface{} { +func (v *FormatVisitor) VisitSoslClauses(ctx *parser.SoslClausesContext) interface{} { var clauses strings.Builder if i := ctx.InSearchGroup(); i != nil { clauses.WriteString(fmt.Sprintf("\n%s", v.visitRule(i))) @@ -1302,19 +1417,19 @@ func (v *Visitor) VisitSoslClauses(ctx *parser.SoslClausesContext) interface{} { return clauses.String() } -func (v *Visitor) VisitInSearchGroup(ctx *parser.InSearchGroupContext) interface{} { +func (v *FormatVisitor) VisitInSearchGroup(ctx *parser.InSearchGroupContext) interface{} { return fmt.Sprintf("IN %s", v.visitRule(ctx.SearchGroup())) } -func (v *Visitor) VisitSearchGroup(ctx *parser.SearchGroupContext) interface{} { +func (v *FormatVisitor) VisitSearchGroup(ctx *parser.SearchGroupContext) interface{} { return fmt.Sprintf("%s FIELDS", strings.ToUpper(ctx.GetChild(0).(antlr.TerminalNode).GetText())) } -func (v *Visitor) VisitReturningFieldSpecList(ctx *parser.ReturningFieldSpecListContext) interface{} { +func (v *FormatVisitor) VisitReturningFieldSpecList(ctx *parser.ReturningFieldSpecListContext) interface{} { return fmt.Sprintf("RETURNING %s", v.visitRule(ctx.FieldSpecList())) } -func (v *Visitor) VisitFieldSpecList(ctx *parser.FieldSpecListContext) interface{} { +func (v *FormatVisitor) VisitFieldSpecList(ctx *parser.FieldSpecListContext) interface{} { list := []string{v.visitRule(ctx.FieldSpec()).(string)} for _, f := range ctx.AllFieldSpecList() { list = append(list, v.visitRule(f).(string)) @@ -1322,18 +1437,18 @@ func (v *Visitor) VisitFieldSpecList(ctx *parser.FieldSpecListContext) interface return strings.Join(list, ",\n") } -func (v *Visitor) VisitFieldSpec(ctx *parser.FieldSpecContext) interface{} { +func (v *FormatVisitor) VisitFieldSpec(ctx *parser.FieldSpecContext) interface{} { if ctx.FieldSpecClauses() == nil { return v.visitRule(ctx.SoslId()) } return fmt.Sprintf("%s%s", v.visitRule(ctx.SoslId()), v.visitRule(ctx.FieldSpecClauses())) } -func (v *Visitor) VisitFieldSpecClauses(ctx *parser.FieldSpecClausesContext) interface{} { +func (v *FormatVisitor) VisitFieldSpecClauses(ctx *parser.FieldSpecClausesContext) interface{} { var clauses strings.Builder - clauses.WriteString(fmt.Sprintf("(\n%s", indent(v.visitRule(ctx.FieldList()).(string)))) + clauses.WriteString(fmt.Sprintf("(\n%s", v.indent(v.visitRule(ctx.FieldList()).(string)))) if i := ctx.LogicalExpression(); i != nil { - clauses.WriteString(fmt.Sprintf("\nWHERE\n%s", indent(v.visitRule(i).(string)))) + clauses.WriteString(fmt.Sprintf("\nWHERE\n%s", v.indent(v.visitRule(i).(string)))) } if i := ctx.SoslId(); i != nil { clauses.WriteString(fmt.Sprintf("\nUSING LISTVIEW = %s", v.visitRule(i))) @@ -1351,7 +1466,7 @@ func (v *Visitor) VisitFieldSpecClauses(ctx *parser.FieldSpecClausesContext) int return clauses.String() } -func (v *Visitor) VisitFieldList(ctx *parser.FieldListContext) interface{} { +func (v *FormatVisitor) VisitFieldList(ctx *parser.FieldListContext) interface{} { list := []string{v.visitRule(ctx.SoslId()).(string)} for _, f := range ctx.AllFieldList() { list = append(list, v.visitRule(f).(string)) @@ -1359,7 +1474,7 @@ func (v *Visitor) VisitFieldList(ctx *parser.FieldListContext) interface{} { return strings.Join(list, ",\n") } -func (v *Visitor) VisitSoslId(ctx *parser.SoslIdContext) interface{} { +func (v *FormatVisitor) VisitSoslId(ctx *parser.SoslIdContext) interface{} { list := []string{v.visitRule(ctx.Id()).(string)} for _, f := range ctx.AllSoslId() { list = append(list, v.visitRule(f).(string)) diff --git a/go.mod b/go.mod index f9c931c..9207569 100644 --- a/go.mod +++ b/go.mod @@ -4,6 +4,7 @@ go 1.21.4 require ( github.com/antlr4-go/antlr/v4 v4.13.0 + github.com/sirupsen/logrus v1.9.3 github.com/spf13/cobra v1.8.0 ) @@ -14,6 +15,7 @@ require ( github.com/russross/blackfriday/v2 v2.1.0 // indirect github.com/spf13/pflag v1.0.5 // indirect golang.org/x/exp v0.0.0-20230515195305-f3d0a9c9a5cc // indirect + golang.org/x/sys v0.1.0 // indirect gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 89a60ef..13f6dcf 100644 --- a/go.sum +++ b/go.sum @@ -2,6 +2,9 @@ github.com/antlr4-go/antlr/v4 v4.13.0 h1:lxCg3LAv+EUK6t1i0y1V6/SLeUi0eKEKdhQAlS8 github.com/antlr4-go/antlr/v4 v4.13.0/go.mod h1:pfChB/xh/Unjila75QW7+VU4TSnWnnk9UTnmpPaOR2g= github.com/cpuguy83/go-md2man/v2 v2.0.3 h1:qMCsGGgs+MAzDFyp9LpAe1Lqy/fY/qCovCm0qnXZOBM= github.com/cpuguy83/go-md2man/v2 v2.0.3/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= @@ -9,16 +12,27 @@ github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORN github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/russross/blackfriday/v2 v2.1.0 h1:JIOH55/0cWyOuilr9/qlrm0BSXldqnqwMsf35Ld67mk= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= +github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= +github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= github.com/spf13/cobra v1.8.0 h1:7aJaZx1B85qltLMc546zn58BxxfZdR/W22ej9CFoEf0= github.com/spf13/cobra v1.8.0/go.mod h1:WXLWApfZ71AjXPya3WOlMsY9yMs7YeiHhFVlvLyhcho= github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= golang.org/x/exp v0.0.0-20230515195305-f3d0a9c9a5cc h1:mCRnTeVUjcrhlRmO0VK8a6k6Rrf6TF9htwo2pJVSjIU= golang.org/x/exp v0.0.0-20230515195305-f3d0a9c9a5cc/go.mod h1:V1LtkGg67GoY2N1AnLN78QLrzxkLyJw7RJb1gzOOz9w= +golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.1.0 h1:kunALQeHf1/185U1i0GOB/fy1IPRDDpuoOOqRReG57U= +golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=