Skip to content

Commit

Permalink
Improve Wrapping/Indenting
Browse files Browse the repository at this point in the history
Add separate visitors to help figure out when we need to wrap and how
deep we need to indent.

Don't wrap simple SOQL queries.

Don't wrap many very short arguments.

Ident lines after wrapped logical operators.

Indent wrapped method arguments, but not assignment expressions in
constructor calls.

Collapse class blocks with no members.

Fix wrapping around equality statements. fix formatting of map
initializers.

Rename reformatting Visitor to FormatVisitor.

Fix preservation of comments before literals and ids.

Add tests.
  • Loading branch information
cwarden committed Dec 8, 2023
1 parent 386610e commit 56cf4a4
Show file tree
Hide file tree
Showing 9 changed files with 838 additions and 253 deletions.
140 changes: 140 additions & 0 deletions formatter/chain.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
package formatter

import (
"fmt"

"github.com/antlr4-go/antlr/v4"
"github.com/octoberswimmer/apexfmt/parser"
log "github.com/sirupsen/logrus"
)

type ChainVisitor struct {
parser.BaseApexParserVisitor
}

func NewChainVisitor() *ChainVisitor {
return &ChainVisitor{}
}

func (v *ChainVisitor) visitRule(node antlr.RuleNode) interface{} {
result := node.Accept(v)
if r, ok := result.(int); ok {
return r
}
if result == nil {
log.Debug(fmt.Sprintf("missing ChainVisitor function for %T", node))
}
return 0
}

func (v *ChainVisitor) VisitStatement(ctx *parser.StatementContext) interface{} {
child := ctx.GetChild(0).(antlr.RuleNode)
return v.visitRule(child)
}

func (v *ChainVisitor) VisitExpressionStatement(ctx *parser.ExpressionStatementContext) interface{} {
return v.visitRule(ctx.Expression())
}

func (v *ChainVisitor) VisitEqualityExpression(ctx *parser.EqualityExpressionContext) interface{} {
return 1 + v.visitRule(ctx.Expression(0)).(int) + v.visitRule(ctx.Expression(1)).(int)
}

func (v *ChainVisitor) VisitPrimaryExpression(ctx *parser.PrimaryExpressionContext) interface{} {
switch e := ctx.Primary().(type) {
case *parser.ThisPrimaryContext:
return 0
case *parser.SuperPrimaryContext:
return 0
case *parser.LiteralPrimaryContext:
return 0
case *parser.TypeRefPrimaryContext:
return 0
case *parser.IdPrimaryContext:
return 0
case *parser.SoqlPrimaryContext:
return 1
case *parser.SoslPrimaryContext:
return 1
default:
return fmt.Sprintf("UNHANDLED PRIMARY EXPRESSION: %T %s", e, e.GetText())
}
}

func (v *ChainVisitor) VisitDotExpression(ctx *parser.DotExpressionContext) interface{} {
if ctx.DotMethodCall() != nil {
return 1 + v.visitRule(ctx.Expression()).(int)
}
return v.visitRule(ctx.Expression())
}

func (v *ChainVisitor) VisitLogAndExpression(ctx *parser.LogAndExpressionContext) interface{} {
return 1 + v.visitRule(ctx.Expression(0)).(int) + v.visitRule(ctx.Expression(1)).(int)
}

func (v *ChainVisitor) VisitLogOrExpression(ctx *parser.LogOrExpressionContext) interface{} {
return 1 + v.visitRule(ctx.Expression(0)).(int) + v.visitRule(ctx.Expression(1)).(int)
}

func (v *ChainVisitor) VisitSubExpression(ctx *parser.SubExpressionContext) interface{} {
return 1 + v.visitRule(ctx.Expression()).(int)
}

func (v *ChainVisitor) VisitQuery(ctx *parser.QueryContext) interface{} {
score := v.visitRule(ctx.SelectList()).(int) + v.visitRule(ctx.FromNameList()).(int)
if scope := ctx.UsingScope(); scope != nil {
score++
}
if where := ctx.WhereClause(); where != nil {
score += v.visitRule(where).(int)
}
if groupBy := ctx.GroupByClause(); groupBy != nil {
score += v.visitRule(groupBy).(int)
}
if orderBy := ctx.OrderByClause(); orderBy != nil {
score += v.visitRule(orderBy).(int)
}
if limit := ctx.LimitClause(); limit != nil {
score++
}
if offset := ctx.OffsetClause(); offset != nil {
score++
}
score += v.visitRule(ctx.ForClauses()).(int)
if update := ctx.UpdateList(); update != nil {
score++
}
return score
}

func (v *ChainVisitor) VisitSelectList(ctx *parser.SelectListContext) interface{} {
score := 0
for _, p := range ctx.AllSelectEntry() {
score += v.visitRule(p).(int)
}
return score
}

func (v *ChainVisitor) VisitSelectEntry(ctx *parser.SelectEntryContext) interface{} {
if ctx.SubQuery() != nil {
// estimate complexity; probably close enough
return 5
}
return 1
}

func (v *ChainVisitor) VisitFromNameList(ctx *parser.FromNameListContext) interface{} {
return len(ctx.AllFieldNameAlias())
}

func (v *ChainVisitor) VisitWhereClause(ctx *parser.WhereClauseContext) interface{} {
return v.visitRule(ctx.LogicalExpression())
}

func (v *ChainVisitor) VisitLogicalExpression(ctx *parser.LogicalExpressionContext) interface{} {
return len(ctx.AllSOQLOR()) + len(ctx.AllSOQLAND()) + 1
}

func (v *ChainVisitor) VisitForClauses(ctx *parser.ForClausesContext) interface{} {
return len(ctx.AllForClause())
}
43 changes: 43 additions & 0 deletions formatter/chain_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
package formatter

import (
"testing"

"github.com/antlr4-go/antlr/v4"
"github.com/octoberswimmer/apexfmt/parser"

log "github.com/sirupsen/logrus"
)

func TestChain(t *testing.T) {
if testing.Verbose() {
log.SetLevel(log.DebugLevel)

}
tests :=
[]struct {
input string
output int
}{
{`Schema.SObjectType.Account.getRecordTypeInfosByDeveloperName().get('Business').getRecordTypeId()`, 3},
{`Fixtures.Contact(account).put(Contact.RecordTypeId, Schema.SObjectType.Contact.getRecordTypeInfosByDeveloperName().get('Person').getRecordTypeId()).put(Contact.My_Lookup__c, newRecord[0].Id).save()`, 4},
}

for _, tt := range tests {
input := antlr.NewInputStream(tt.input)
lexer := parser.NewApexLexer(input)
stream := antlr.NewCommonTokenStream(lexer, antlr.TokenDefaultChannel)

p := parser.NewApexParser(stream)
p.RemoveErrorListeners()

v := NewChainVisitor()
out, ok := v.visitRule(p.Statement()).(int)
if !ok {
t.Errorf("Unexpected result parsing apex")
}
if out != tt.output {
t.Errorf("unexpected result. expected: %d; got: %d", tt.output, out)
}
}
}
204 changes: 204 additions & 0 deletions formatter/format_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,204 @@
package formatter

import (
"testing"

"github.com/antlr4-go/antlr/v4"
"github.com/octoberswimmer/apexfmt/parser"

log "github.com/sirupsen/logrus"
)

func TestStatement(t *testing.T) {
if testing.Verbose() {
log.SetLevel(log.DebugLevel)

}
tests :=
[]struct {
input string
output string
}{
{
`Account a = new Account(Name='Acme', RecordTypeId=myRecordTypeId, BillingCity='Los Angeles', BillingState = 'CA');`,
`Account a = new Account(
Name = 'Acme',
RecordTypeId = myRecordTypeId,
BillingCity = 'Los Angeles',
BillingState = 'CA'
);`},

{
`System.runAs(u) {
facility = Fixtures.account().put(Schema.Account.RecordTypeId, facilityRecordType).put(Schema.Account.HealthCloudGA__SourceSystemId__c, '0001')
.put(Schema.Account.Patient_Owner__c, u.Id)
.save();
}`, `System.runAs(u) {
facility = Fixtures.account()
.put(Schema.Account.RecordTypeId, facilityRecordType)
.put(Schema.Account.HealthCloudGA__SourceSystemId__c, '0001')
.put(Schema.Account.Patient_Owner__c, u.Id)
.save();
}`},

{

`System.assertEquals(UserInfo.getUserId(), [SELECT OwnerId FROM Account WHERE Id = :person.Id].OwnerId, 'Account should be owned by correct user');`,
`System.assertEquals(UserInfo.getUserId(), [SELECT OwnerId FROM Account WHERE Id = :person.Id].OwnerId, 'Account should be owned by correct user');`},
{
`System.assert(lsr[0].getErrors()[0].getMessage().contains(constants.ERR_MSG_NO_CLIENT_DEMOGRAPHICS), 'error message');`,
`System.assert(lsr[0].getErrors()[0].getMessage().contains(constants.ERR_MSG_NO_CLIENT_DEMOGRAPHICS), 'error message');`,
},

{
`RecordType referralType = [ SELECT Id FROM RecordType WHERE SobjectType = 'Contact' AND DeveloperName = 'Referral_Contact' ];`,
`RecordType referralType = [
SELECT
Id
FROM
RecordType
WHERE
SobjectType = 'Contact' AND
DeveloperName = 'Referral_Contact'
];`},

{
`update [SELECT Id FROM Territory_Coverage__c WHERE Named_Account__c IN :accountIds];`,
`update [SELECT Id FROM Territory_Coverage__c WHERE Named_Account__c IN :accountIds];`,
},

{
`for (Referral__c ref : [SELECT Summary_Name__c, Name FROM Referral__c WHERE Id IN :referralIdSet]) {
System.assertEquals(ref.Name, ref.Summary_Name__c);
}`,
`for (Referral__c ref : [
SELECT
Summary_Name__c,
Name
FROM
Referral__c
WHERE
Id IN :referralIdSet
]) {
System.assertEquals(ref.Name, ref.Summary_Name__c);
}`},

{
`for (Referral__c ref : [SELECT Name FROM Referral__c WHERE Id IN :referralIdSet]) {
System.assertEquals('test', ref.Summary_Name__c);
}`,
`for (Referral__c ref : [SELECT Name FROM Referral__c WHERE Id IN :referralIdSet]) {
System.assertEquals('test', ref.Summary_Name__c);
}`},

{
`if (!r.isSuccess()) {
throw new BenefitCheckNotificationException(
'Failed to send Benefit Check notification. First error: ' +
r.getErrors()[0].getMessage()
);
}`,
`if (!r.isSuccess()) {
throw new BenefitCheckNotificationException(
'Failed to send Benefit Check notification. First error: ' +
r.getErrors()[0].getMessage()
);
}`},

{
`return new List<CountryZip>{
new CountryZip( new Territory_Zip_Lookup__c( Id = zip.Id, Name = zip.Name, City__c = zip.City__c, State_2_Letter_Code__c = zip.State_2_Letter_Code__c, Country__c = zip.Country__c))
};`,
`return new List<CountryZip>{
new CountryZip(
new Territory_Zip_Lookup__c(
Id = zip.Id,
Name = zip.Name,
City__c = zip.City__c,
State_2_Letter_Code__c = zip.State_2_Letter_Code__c,
Country__c = zip.Country__c
)
)
};`},

{
`Psychological__c psyc = Fixtures.psychological(inq).put(Psychological__c.RecordTypeId, Schema.SObjectType.Psychological__c.getRecordTypeInfosByDeveloperName().get('ICD_10').getRecordTypeId()).put(Psychological__c.Diagnosis_Lookup__c, newDiagnosis[0].Id).save();`,
`Psychological__c psyc = Fixtures.psychological(inq)
.put(Psychological__c.RecordTypeId, Schema.SObjectType.Psychological__c
.getRecordTypeInfosByDeveloperName()
.get('ICD_10')
.getRecordTypeId())
.put(Psychological__c.Diagnosis_Lookup__c, newDiagnosis[0].Id)
.save();`},

{
`this.assertPassed(Assert.isNumericallyInner(2, '~0.5', 2.4, null));`,
`this.assertPassed(Assert.isNumericallyInner(2, '~0.5', 2.4, null));`},

{
`assertFailed(Assert.consistOfInner(new List<Object>{ a, b }, new List<Object>{ b, c }, 'doodle'), 'doodle: expected (1, 2) to consist of (2, 3)\nextra elements:\n\t(1)\nmissing elements:\n\t(3)');`,
`assertFailed(Assert.consistOfInner(new List<Object>{ a, b }, new List<Object>{ b, c }, 'doodle'),
'doodle: expected (1, 2) to consist of (2, 3)\nextra elements:\n\t(1)\nmissing elements:\n\t(3)');`},

{
`if (cl_record.Last_Placement__c == true &&
(Trigger.isInsert || (Trigger.isUpdate && cl_record.Last_Placement__c != Trigger.OldMap.get(cl_record.Id).Last_Placement__c))) {x=1;}`,
`if (cl_record.Last_Placement__c == true &&
(Trigger.isInsert ||
(Trigger.isUpdate &&
cl_record.Last_Placement__c != Trigger.OldMap.get(cl_record.Id).Last_Placement__c))) {
x = 1;
}`},
}
for _, tt := range tests {
input := antlr.NewInputStream(tt.input)
lexer := parser.NewApexLexer(input)
stream := antlr.NewCommonTokenStream(lexer, antlr.TokenDefaultChannel)

p := parser.NewApexParser(stream)
p.RemoveErrorListeners()

v := NewFormatVisitor(stream)
out, ok := v.visitRule(p.Statement()).(string)
if !ok {
t.Errorf("Unexpected result parsing apex")
}
if out != tt.output {
t.Errorf("unexpected format. expected:\n%s;\ngot:\n%s\n", tt.output, out)
}
}

}

func TestCompilationUnit(t *testing.T) {
if testing.Verbose() {
log.SetLevel(log.DebugLevel)

}
tests :=
[]struct {
input string
output string
}{
{
`private class T1Exception extends Exception {}`,
`private class T1Exception extends Exception {}`},
}
for _, tt := range tests {
input := antlr.NewInputStream(tt.input)
lexer := parser.NewApexLexer(input)
stream := antlr.NewCommonTokenStream(lexer, antlr.TokenDefaultChannel)

p := parser.NewApexParser(stream)
p.RemoveErrorListeners()

v := NewFormatVisitor(stream)
out, ok := v.visitRule(p.CompilationUnit()).(string)
if !ok {
t.Errorf("Unexpected result parsing apex")
}
if out != tt.output {
t.Errorf("unexpected format. expected:\n%s;\ngot:\n%s\n", tt.output, out)
}
}
}
2 changes: 1 addition & 1 deletion formatter/formatter.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ func (f *Formatter) Format() error {
p.AddErrorListener(&errorListener{filename: f.filename})
// p.AddErrorListener(antlr.NewDiagnosticErrorListener(false))

v := NewVisitor(stream)
v := NewFormatVisitor(stream)
out, ok := v.visitRule(p.CompilationUnit()).(string)
if !ok {
return fmt.Errorf("Unexpected result parsing apex")
Expand Down
Loading

0 comments on commit 56cf4a4

Please sign in to comment.