diff --git a/packages/go/cypher/models/pgsql/format/format.go b/packages/go/cypher/models/pgsql/format/format.go index 0b09b19c4..86454300a 100644 --- a/packages/go/cypher/models/pgsql/format/format.go +++ b/packages/go/cypher/models/pgsql/format/format.go @@ -429,7 +429,17 @@ func formatNode(builder *OutputBuilder, rootExpr pgsql.SyntaxNode) error { exprStack = append(exprStack, *typedNextExpr) case pgsql.TypeCast: - switch typedNextExpr.Expression.(type) { + switch typedCastedExpr := typedNextExpr.Expression.(type) { + case *pgsql.BinaryExpression: + if typedCastedExpr.Operator == pgsql.OperatorJSONTextField && typedNextExpr.CastType == pgsql.Text { + // Avoid formatting property lookups wrapped in text type casts + exprStack = append(exprStack, typedNextExpr.Expression) + } else { + exprStack = append(exprStack, pgsql.FormattingLiteral(typedNextExpr.CastType), pgsql.FormattingLiteral(")::")) + exprStack = append(exprStack, typedNextExpr.Expression) + exprStack = append(exprStack, pgsql.FormattingLiteral("(")) + } + case pgsql.Parenthetical: // Avoid formatting type-casted parenthetical statements as (('test'))::text - this should instead look like ('test')::text exprStack = append(exprStack, pgsql.FormattingLiteral(typedNextExpr.CastType), pgsql.FormattingLiteral("::")) diff --git a/packages/go/cypher/models/pgsql/operators.go b/packages/go/cypher/models/pgsql/operators.go index 6034ad4b3..c260e715f 100644 --- a/packages/go/cypher/models/pgsql/operators.go +++ b/packages/go/cypher/models/pgsql/operators.go @@ -71,6 +71,13 @@ func OperatorIsPropertyLookup(operator Expression) bool { ) } +func OperatorIsComparator(operator Expression) bool { + return OperatorIsIn(operator, + OperatorEquals, OperatorNotEquals, OperatorGreaterThan, OperatorGreaterThanOrEqualTo, OperatorLessThan, + OperatorLessThanOrEqualTo, OperatorArrayOverlap, OperatorLike, OperatorILike, OperatorPGArrayOverlap, + OperatorRegexMatch, OperatorSimilarTo) +} + const ( UnsetOperator Operator = "" OperatorUnion Operator = "union" @@ -107,6 +114,7 @@ const ( OperatorCypherStartsWith Operator = "starts with" OperatorCypherContains Operator = "contains" OperatorCypherEndsWith Operator = "ends with" + OperatorCypherAdd Operator = "+" OperatorPropertyLookup Operator = "property_lookup" OperatorKindAssignment Operator = "kind_assignment" diff --git a/packages/go/cypher/models/pgsql/pgtypes.go b/packages/go/cypher/models/pgsql/pgtypes.go index c50d8656b..1d726a169 100644 --- a/packages/go/cypher/models/pgsql/pgtypes.go +++ b/packages/go/cypher/models/pgsql/pgtypes.go @@ -26,7 +26,6 @@ import ( var ( ErrNoAvailableArrayDataType = errors.New("data type has no direct array representation") - ErrNonArrayDataType = errors.New("data type is not an array type") ) const ( @@ -67,34 +66,42 @@ func (s DataType) NodeType() string { } const ( - UnsetDataType DataType = "" - UnknownDataType DataType = "UNKNOWN" - Reference DataType = "REFERENCE" - Null DataType = "NULL" - NodeComposite DataType = "nodecomposite" - NodeCompositeArray DataType = "nodecomposite[]" - EdgeComposite DataType = "edgecomposite" - EdgeCompositeArray DataType = "edgecomposite[]" - PathComposite DataType = "pathcomposite" - Int DataType = "int" - IntArray DataType = "int[]" - Int2 DataType = "int2" - Int2Array DataType = "int2[]" - Int4 DataType = "int4" - Int4Array DataType = "int4[]" - Int8 DataType = "int8" - Int8Array DataType = "int8[]" - Float4 DataType = "float4" - Float4Array DataType = "float4[]" - Float8 DataType = "float8" - Float8Array DataType = "float8[]" - Boolean DataType = "bool" - Text DataType = "text" - TextArray DataType = "text[]" - JSONB DataType = "jsonb" - JSONBArray DataType = "jsonb[]" - Numeric DataType = "numeric" - NumericArray DataType = "numeric[]" + // UnsetDataType represents a DataType that has not been visited by any logic. It is the default, zero-value for + // the DataType type. + UnsetDataType DataType = "" + + // UnknownDataType represents a DataType that has been visited by type inference logic but remains unknowable. + UnknownDataType DataType = "unknown" + + Null DataType = "null" + Any DataType = "any" + NodeComposite DataType = "nodecomposite" + EdgeComposite DataType = "edgecomposite" + PathComposite DataType = "pathcomposite" + Int DataType = "int" + Int2 DataType = "int2" + Int4 DataType = "int4" + Int8 DataType = "int8" + Float4 DataType = "float4" + Float8 DataType = "float8" + Boolean DataType = "bool" + Text DataType = "text" + JSONB DataType = "jsonb" + Numeric DataType = "numeric" + + AnyArray DataType = "any[]" + NodeCompositeArray DataType = "nodecomposite[]" + EdgeCompositeArray DataType = "edgecomposite[]" + IntArray DataType = "int[]" + Int2Array DataType = "int2[]" + Int4Array DataType = "int4[]" + Int8Array DataType = "int8[]" + Float4Array DataType = "float4[]" + Float8Array DataType = "float8[]" + TextArray DataType = "text[]" + JSONBArray DataType = "jsonb[]" + NumericArray DataType = "numeric[]" + Date DataType = "date" TimeWithTimeZone DataType = "time with time zone" TimeWithoutTimeZone DataType = "time without time zone" @@ -121,160 +128,182 @@ func (s DataType) IsKnown() bool { } } -// TODO: operator, while unused, is part of a refactor for this function to make it operator aware -func (s DataType) Compatible(other DataType, operator Operator) (DataType, bool) { - if s == other { - return s, true - } +func (s DataType) IsComparable(other DataType, operator Operator) bool { + switch operator { + case OperatorPGArrayOverlap, OperatorArrayOverlap: + if !s.IsArrayType() || !other.IsArrayType() { + return false + } - if other == UnknownDataType { - // Assume unknown data types will offload type matching to the DB - return s, true - } + return s == other - switch s { - case UnknownDataType: - // Assume unknown data types will offload type matching to the DB - return other, true + case OperatorEquals, OperatorNotEquals, OperatorGreaterThan, OperatorGreaterThanOrEqualTo, OperatorLessThan, OperatorLessThanOrEqualTo: + switch s { + case NodeComposite, EdgeComposite, PathComposite, JSONB, AnyArray, Text, Boolean, + IntArray, Int8Array, Int4Array, Int2Array, Float8Array, Float4Array, NumericArray, TextArray, + Date, TimeWithTimeZone, TimeWithoutTimeZone, Interval, TimestampWithTimeZone, TimestampWithoutTimeZone: + return other == s - case Text: - return Text, true + case Int, Int8, Int4, Int2: + switch other { + case Int, Int8, Int4, Int2, Float8, Float4, Numeric: + return true - case Float4: - switch other { - case Float8: - return Float8, true + default: + return false + } - case Float4Array: - return Float4, true + case Float8, Float4, Numeric: + switch other { + case Int, Int8, Int4, Int2, Float8, Float4, Numeric: + return true - case Float8Array: - return Float8, true + default: + return false + } - case Text: - return Text, true + default: + return false } - case Float8: - switch other { - case Float4: - return Float8, true - - case Float4Array, Float8Array: - return Float8, true - + case OperatorLike, OperatorILike, OperatorSimilarTo, OperatorRegexMatch: + switch s { case Text: - return Text, true + return other == s + default: + return false } - case Numeric: - switch other { - case Float4, Float8, Int2, Int4, Int8: - return Numeric, true + default: + return false + } +} - case Float4Array, Float8Array, NumericArray: - return Numeric, true +// CoerceToSupertype attempts to take the super of the type s and the type other +func (s DataType) CoerceToSupertype(other DataType) (DataType, bool) { + switch other { + case UnknownDataType: + // If the other data type is unknown then assume this data type as the super type + return s, true + } - case Text: - return Text, true - } + switch s { + case UnknownDataType: + // If this data type is unknown then assume the other type presented as the super type + return other, true case Int2: switch other { case Int2: - return Int2, true - - case Int4: - return Int4, true - - case Int8: - return Int8, true + return s, true - case Int2Array: - return Int2, true - - case Int4Array: - return Int4, true - - case Int8Array: - return Int8, true - - case Text: - return Text, true + case Int, Int8, Int4: + return other, true } case Int4: switch other { - case Int2, Int4: - return Int4, true - - case Int8: - return Int8, true - - case Int2Array, Int4Array: - return Int4, true - - case Int8Array: - return Int8, true + case Int4, Int2: + return s, true - case Text: - return Text, true + case Int, Int8: + return other, true } case Int8: switch other { - case Int2, Int4, Int8: - return Int8, true - - case Int2Array, Int4Array, Int8Array: - return Int8, true - - case Text: - return Text, true + case Int, Int8, Int4, Int2: + return s, true } case Int: switch other { - case Int2, Int4, Int: - return Int, true + case Int: + return s, true case Int8: - return Int8, true - - case Text: - return Text, true + return other, true } - case Int2Array: + case Float4: switch other { - case Int2Array, Int4Array, Int8Array: + case Float4, Float8, Numeric: return other, true } - case Int4Array: + case Float8: switch other { - case Int4Array, Int8Array: + case Float4: + return s, true + + case Float8, Numeric: return other, true } - case Float4Array: + case Numeric: switch other { - case Float4Array, Float8Array: + case Float4, Float8, Int8, Int, Int4, Int2: + return s, true + + case Numeric: return other, true } } - return UnsetDataType, false + // Otherwise unable to identify a super type + return UnknownDataType, false } -func (s DataType) TextConvertable() bool { - switch s { - case TimestampWithoutTimeZone, TimestampWithTimeZone, TimeWithoutTimeZone, TimeWithTimeZone, Date, Text: - return true +func (s DataType) OperatorResultType(other DataType, operator Operator) (DataType, bool) { + if OperatorIsComparator(operator) && s.IsComparable(other, operator) { + return Boolean, true + } - default: - return false + // Validate all other supported operators for result type inference + switch operator { + case OperatorAnd, OperatorOr: + return Boolean, true + + case OperatorAdd, OperatorSubtract, OperatorMultiply, OperatorDivide: + if s == other { + return s, true + } + + if supertype, validSupertype := s.CoerceToSupertype(other); validSupertype { + return supertype, true + } + + case OperatorConcatenate: + // Array types may only concatenate if their base types match + if s.IsArrayType() { + return s, s == other || s.ArrayBaseType() == other + } + + if other.IsArrayType() { + return other, s == other || s == other.ArrayBaseType() + } + + switch s { + case UnknownDataType: + // Overwrite the unknown data type here and assume that it will resolve correctly + return other, true + + case Text: + switch other { + case UnknownDataType: + // Overwrite the unknown data type here and assume that it will resolve to text + return s, true + + default: + return s, s == other + } + + default: + return UnknownDataType, false + } } + + return UnknownDataType, false } func (s DataType) MatchesOneOf(others ...DataType) bool { @@ -289,24 +318,14 @@ func (s DataType) MatchesOneOf(others ...DataType) bool { func (s DataType) IsArrayType() bool { switch s { - case Int2Array, Int4Array, Int8Array, Float4Array, Float8Array, TextArray, JSONBArray, NodeCompositeArray, EdgeCompositeArray, NumericArray: + case Int2Array, Int4Array, Int8Array, IntArray, Float4Array, Float8Array, TextArray, JSONBArray, + NodeCompositeArray, EdgeCompositeArray, NumericArray: return true } return false } -func (s DataType) ToUpdateResultType() (DataType, error) { - switch s { - case NodeComposite: - return s, nil - case EdgeComposite: - return s, nil - default: - return UnsetDataType, fmt.Errorf("data type %s has no update result representation", s) - } -} - func (s DataType) ToArrayType() (DataType, error) { switch s { case Int2, Int2Array: @@ -315,6 +334,16 @@ func (s DataType) ToArrayType() (DataType, error) { return Int4Array, nil case Int8, Int8Array: return Int8Array, nil + case Int, IntArray: + return IntArray, nil + case Any, AnyArray: + return AnyArray, nil + case JSONB, JSONBArray: + return JSONBArray, nil + case NodeComposite, NodeCompositeArray: + return NodeCompositeArray, nil + case EdgeComposite, EdgeCompositeArray: + return EdgeCompositeArray, nil case Float4, Float4Array: return Float4Array, nil case Float8, Float8Array: @@ -328,24 +357,32 @@ func (s DataType) ToArrayType() (DataType, error) { } } -func (s DataType) ArrayBaseType() (DataType, error) { +func (s DataType) ArrayBaseType() DataType { switch s { case Int2Array: - return Int2, nil + return Int2 case Int4Array: - return Int4, nil + return Int4 case Int8Array: - return Int8, nil + return Int8 case Float4Array: - return Float4, nil + return Float4 case Float8Array: - return Float8, nil + return Float8 case TextArray: - return Text, nil + return Text case NumericArray: - return Numeric, nil + return Numeric + case JSONBArray: + return JSONB + case AnyArray: + return Any + case NodeCompositeArray: + return NodeComposite + case EdgeCompositeArray: + return EdgeComposite default: - return s, nil + return s } } diff --git a/packages/go/cypher/models/pgsql/pytypes_test.go b/packages/go/cypher/models/pgsql/pytypes_test.go index 2ab0a6c15..06cf2696c 100644 --- a/packages/go/cypher/models/pgsql/pytypes_test.go +++ b/packages/go/cypher/models/pgsql/pytypes_test.go @@ -24,6 +24,220 @@ import ( "github.com/stretchr/testify/require" ) +func TestDataType_CoerceToSupertype(t *testing.T) { + testCases := []struct { + LeftTypes []DataType + RightTypes []DataType + Expected DataType + ExpectRightType bool + }{{ + LeftTypes: []DataType{UnknownDataType}, + RightTypes: []DataType{Int}, + Expected: Int, + }, { + LeftTypes: []DataType{Int}, + RightTypes: []DataType{UnknownDataType}, + Expected: Int, + }, { + LeftTypes: []DataType{Int8}, + RightTypes: []DataType{Int2, Int4, Int, Int8}, + Expected: Int8, + }, { + LeftTypes: []DataType{Int4}, + RightTypes: []DataType{Int2, Int4}, + Expected: Int4, + }, { + LeftTypes: []DataType{Int4}, + RightTypes: []DataType{Int}, + Expected: Int, + }, { + LeftTypes: []DataType{Int4}, + RightTypes: []DataType{Int8}, + Expected: Int8, + }, { + LeftTypes: []DataType{Int2}, + RightTypes: []DataType{Int2, Int4, Int, Int8}, + ExpectRightType: true, + }, { + LeftTypes: []DataType{Int}, + RightTypes: []DataType{Int, Int8}, + ExpectRightType: true, + }, { + LeftTypes: []DataType{Float4}, + RightTypes: []DataType{Float4, Float8, Numeric}, + ExpectRightType: true, + }, { + LeftTypes: []DataType{Float8}, + RightTypes: []DataType{Float4}, + Expected: Float8, + }, { + LeftTypes: []DataType{Float8}, + RightTypes: []DataType{Float8, Numeric}, + ExpectRightType: true, + }, { + LeftTypes: []DataType{Numeric}, + RightTypes: []DataType{Numeric, Float8, Float4, Int8, Int, Int4, Int2}, + Expected: Numeric, + }} + + for _, testCase := range testCases { + for _, leftType := range testCase.LeftTypes { + for _, rightType := range testCase.RightTypes { + superType, coerced := leftType.CoerceToSupertype(rightType) + + if !coerced { + t.Fatalf("coercing left type %s to right type %s failed", leftType, rightType) + } + + if testCase.ExpectRightType { + require.Equalf(t, rightType, superType, "expected type %s does not match super type %s", rightType, superType) + } else { + require.Equalf(t, testCase.Expected, superType, "expected type %s does not match super type %s", testCase.Expected, superType) + } + + } + } + } +} + +func TestDataType_Comparable(t *testing.T) { + testCases := []struct { + LeftTypes []DataType + Operators []Operator + RightTypes []DataType + Expected bool + }{ + // Supported comparisons + { + LeftTypes: []DataType{Int, Int8, Int4, Int2}, + Operators: []Operator{OperatorEquals, OperatorNotEquals, OperatorGreaterThan, OperatorGreaterThanOrEqualTo, OperatorLessThan, OperatorLessThanOrEqualTo}, + RightTypes: []DataType{Int, Int8, Int4, Int2, Float8, Float4, Numeric}, + Expected: true, + }, + { + LeftTypes: []DataType{Float8, Float4, Numeric}, + Operators: []Operator{OperatorEquals, OperatorNotEquals, OperatorGreaterThan, OperatorGreaterThanOrEqualTo, OperatorLessThan, OperatorLessThanOrEqualTo}, + RightTypes: []DataType{Int, Int8, Int4, Int2, Float8, Float4, Numeric}, + Expected: true, + }, + { + LeftTypes: []DataType{NodeComposite}, + Operators: []Operator{OperatorEquals, OperatorNotEquals, OperatorGreaterThan, OperatorGreaterThanOrEqualTo, OperatorLessThan, OperatorLessThanOrEqualTo}, + RightTypes: []DataType{NodeComposite}, + Expected: true, + }, + { + LeftTypes: []DataType{EdgeComposite}, + Operators: []Operator{OperatorEquals, OperatorNotEquals, OperatorGreaterThan, OperatorGreaterThanOrEqualTo, OperatorLessThan, OperatorLessThanOrEqualTo}, + RightTypes: []DataType{EdgeComposite}, + Expected: true, + }, + { + LeftTypes: []DataType{PathComposite}, + Operators: []Operator{OperatorEquals, OperatorNotEquals, OperatorGreaterThan, OperatorGreaterThanOrEqualTo, OperatorLessThan, OperatorLessThanOrEqualTo}, + RightTypes: []DataType{PathComposite}, + Expected: true, + }, + { + LeftTypes: []DataType{JSONB}, + Operators: []Operator{OperatorEquals, OperatorNotEquals, OperatorGreaterThan, OperatorGreaterThanOrEqualTo, OperatorLessThan, OperatorLessThanOrEqualTo}, + RightTypes: []DataType{JSONB}, + Expected: true, + }, + { + LeftTypes: []DataType{AnyArray}, + Operators: []Operator{OperatorEquals, OperatorNotEquals, OperatorGreaterThan, OperatorGreaterThanOrEqualTo, OperatorLessThan, OperatorLessThanOrEqualTo}, + RightTypes: []DataType{AnyArray}, + Expected: true, + }, + { + LeftTypes: []DataType{Text}, + Operators: []Operator{OperatorEquals, OperatorNotEquals, OperatorGreaterThan, OperatorGreaterThanOrEqualTo, OperatorLessThan, OperatorLessThanOrEqualTo}, + RightTypes: []DataType{Text}, + Expected: true, + }, + { + LeftTypes: []DataType{Boolean}, + Operators: []Operator{OperatorEquals, OperatorNotEquals, OperatorGreaterThan, OperatorGreaterThanOrEqualTo, OperatorLessThan, OperatorLessThanOrEqualTo}, + RightTypes: []DataType{Boolean}, + Expected: true, + }, + + // Right hand unknown types should not be comparable against any left hand int type + { + LeftTypes: []DataType{Int}, + Operators: []Operator{OperatorEquals, OperatorNotEquals, OperatorGreaterThan, OperatorGreaterThanOrEqualTo, OperatorLessThan, OperatorLessThanOrEqualTo}, + RightTypes: []DataType{UnknownDataType}, + Expected: false, + }, + + // Right hand unknown types should not be comparable against any left hand float type + { + LeftTypes: []DataType{Float8}, + Operators: []Operator{OperatorEquals, OperatorNotEquals, OperatorGreaterThan, OperatorGreaterThanOrEqualTo, OperatorLessThan, OperatorLessThanOrEqualTo}, + RightTypes: []DataType{UnknownDataType}, + Expected: false, + }, + + // Left hand unknown types should not be comparable against any right hand type + { + LeftTypes: []DataType{UnknownDataType}, + Operators: []Operator{OperatorEquals, OperatorNotEquals, OperatorGreaterThan, OperatorGreaterThanOrEqualTo, OperatorLessThan, OperatorLessThanOrEqualTo}, + RightTypes: []DataType{Int}, + Expected: false, + }, + + // Validate text operations + { + LeftTypes: []DataType{Text}, + Operators: []Operator{OperatorLike, OperatorILike, OperatorSimilarTo, OperatorRegexMatch}, + RightTypes: []DataType{Text}, + Expected: true, + }, + + // Text operations on non-text types should fail + { + LeftTypes: []DataType{Int}, + Operators: []Operator{OperatorLike, OperatorILike, OperatorSimilarTo, OperatorRegexMatch}, + RightTypes: []DataType{Int}, + Expected: false, + }, + + // Array types may use the overlap operator but only if their base types match + { + LeftTypes: []DataType{IntArray}, + Operators: []Operator{OperatorPGArrayOverlap}, + RightTypes: []DataType{IntArray}, + Expected: true, + }, + { + LeftTypes: []DataType{IntArray}, + Operators: []Operator{OperatorPGArrayOverlap}, + RightTypes: []DataType{Int}, + Expected: false, + }, + + // Catch all for any unsupported operator + { + LeftTypes: []DataType{Int}, + Operators: []Operator{"Unsupported Operator Class"}, + RightTypes: []DataType{Int}, + Expected: false, + }, + } + + for idx, testCase := range testCases { + for _, leftType := range testCase.LeftTypes { + for _, operator := range testCase.Operators { + for _, rightType := range testCase.RightTypes { + result := leftType.IsComparable(rightType, operator) + require.Equalf(t, testCase.Expected, result, "failed test case %d: %+v, %+v", idx, testCase.LeftTypes, testCase.RightTypes) + } + } + } + } +} + func TestValueToDataType(t *testing.T) { testCases := []struct { Value any diff --git a/packages/go/cypher/models/pgsql/test/translation_cases/nodes.sql b/packages/go/cypher/models/pgsql/test/translation_cases/nodes.sql index 02181bfc8..955ed0b24 100644 --- a/packages/go/cypher/models/pgsql/test/translation_cases/nodes.sql +++ b/packages/go/cypher/models/pgsql/test/translation_cases/nodes.sql @@ -108,15 +108,15 @@ with s0 as (select (n0.id, n0.kind_ids, n0.properties)::nodecomposite as n0 s1 as (select s0.n0 as n0, (n1.id, n1.kind_ids, n1.properties)::nodecomposite as n1 from s0, node n1 - where n1.kind_ids operator (pg_catalog.&&) array [2]::int2[] and ((s0.n0).properties -> 'selected')::bool - or (s0.n0).properties -> 'tid' = n1.properties -> 'tid' and (n1.properties -> 'enabled')::bool) + where n1.kind_ids operator (pg_catalog.&&) array [2]::int2[] and ((s0.n0).properties ->> 'selected')::bool + or (s0.n0).properties -> 'tid' = n1.properties -> 'tid' and (n1.properties ->> 'enabled')::bool) select s1.n0 as s, s1.n1 as e from s1; -- case: match (s) where s.value + 2 / 3 > 10 return s with s0 as (select (n0.id, n0.kind_ids, n0.properties)::nodecomposite as n0 from node n0 - where (n0.properties -> 'value')::int8 + 2 / 3 > 10) + where (n0.properties ->> 'value')::int8 + 2 / 3 > 10) select s0.n0 as s from s0; @@ -172,7 +172,7 @@ with s0 as (select (n0.id, n0.kind_ids, n0.properties)::nodecomposite as n0 s1 as (select s0.n0 as n0, (n1.id, n1.kind_ids, n1.properties)::nodecomposite as n1 from s0, node n1 - where (n1.properties -> 'other')::int8 = 1234) + where (n1.properties ->> 'other')::int8 = 1234) select s1.n0 as s from s1; @@ -182,7 +182,7 @@ with s0 as (select (n0.id, n0.kind_ids, n0.properties)::nodecomposite as n0 from from s0, node n1 where (s0.n0).properties ->> 'name' = '1234' - or (n1.properties -> 'other')::int8 = 1234) + or (n1.properties ->> 'other')::int8 = 1234) select s1.n0 as s from s1; @@ -281,12 +281,12 @@ from s0; -- case: match (s) return s.value + 1 with s0 as (select (n0.id, n0.kind_ids, n0.properties)::nodecomposite as n0 from node n0) -select ((s0.n0).properties -> 'value')::int8 + 1 +select ((s0.n0).properties ->> 'value')::int8 + 1 from s0; -- case: match (s) return (s.value + 1) / 3 with s0 as (select (n0.id, n0.kind_ids, n0.properties)::nodecomposite as n0 from node n0) -select (((s0.n0).properties -> 'value')::int8 + 1) / 3 +select (((s0.n0).properties ->> 'value')::int8 + 1) / 3 from s0; -- case: match (s) where id(s) in [1, 2, 3, 4] return s @@ -503,7 +503,7 @@ from s0; with s0 as (select (n0.id, n0.kind_ids, n0.properties)::nodecomposite as n0 from node n0 where n0.kind_ids operator (pg_catalog.&&) array [1]::int2[] - and (n0.properties -> 'functionallevel')::text = any + and n0.properties ->> 'functionallevel' = any (array ['2008 R2', '2012', '2008', '2003', '2003 Interim', '2000 Mixed/Native']::text[])) select s0.n0 as n from s0; @@ -512,7 +512,7 @@ from s0; with s0 as (select (n0.id, n0.kind_ids, n0.properties)::nodecomposite as n0 from node n0 where n0.kind_ids operator (pg_catalog.&&) array [1]::int2[] - and (n0.properties -> 'value')::int8 = any (array [1, 2, 3, 4]::int8[])) + and (n0.properties ->> 'value')::int8 = any (array [1, 2, 3, 4]::int8[])) select s0.n0 as n from s0; @@ -520,9 +520,9 @@ from s0; with s0 as (select (n0.id, n0.kind_ids, n0.properties)::nodecomposite as n0 from node n0 where n0.kind_ids operator (pg_catalog.&&) array [1]::int2[] - and (n0.properties -> 'pwdlastset')::numeric < + and (n0.properties ->> 'pwdlastset')::numeric < (extract(epoch from now()::timestamp with time zone)::numeric - (365 * 86400)) - and not (n0.properties -> 'pwdlastset')::float8 = any (array [- 1, 0]::float8[])) + and not (n0.properties ->> 'pwdlastset')::float8 = any (array [- 1, 0]::float8[])) select s0.n0 as u from s0 limit 100; @@ -531,9 +531,9 @@ limit 100; with s0 as (select (n0.id, n0.kind_ids, n0.properties)::nodecomposite as n0 from node n0 where n0.kind_ids operator (pg_catalog.&&) array [1]::int2[] - and (n0.properties -> 'pwdlastset')::numeric < + and (n0.properties ->> 'pwdlastset')::numeric < (extract(epoch from now()::timestamp with time zone)::numeric * 1000 - (365 * 86400000)) - and not (n0.properties -> 'pwdlastset')::float8 = any (array [- 1, 0]::float8[])) + and not (n0.properties ->> 'pwdlastset')::float8 = any (array [- 1, 0]::float8[])) select s0.n0 as u from s0 limit 100; @@ -574,7 +574,7 @@ from s0; with s0 as (select (n0.id, n0.kind_ids, n0.properties)::nodecomposite as n0 from node n0 where n0.kind_ids operator (pg_catalog.&&) array [1]::int2[] - and coalesce((n0.properties -> 'a')::int8, (n0.properties -> 'b')::int8, 1)::int8 = 1) + and coalesce((n0.properties ->> 'a')::int8, (n0.properties ->> 'b')::int8, 1)::int8 = 1) select s0.n0 as n from s0; @@ -582,7 +582,7 @@ from s0; with s0 as (select (n0.id, n0.kind_ids, n0.properties)::nodecomposite as n0 from node n0 where n0.kind_ids operator (pg_catalog.&&) array [1]::int2[] - and coalesce(n0.properties -> 'a', n0.properties -> 'b')::int8 = 1) + and coalesce(n0.properties ->> 'a', n0.properties ->> 'b')::int8 = 1) select s0.n0 as n from s0; @@ -590,6 +590,47 @@ from s0; with s0 as (select (n0.id, n0.kind_ids, n0.properties)::nodecomposite as n0 from node n0 where n0.kind_ids operator (pg_catalog.&&) array [1]::int2[] - and 1 = coalesce(n0.properties -> 'a', n0.properties -> 'b')::int8) + and 1 = coalesce(n0.properties ->> 'a', n0.properties ->> 'b')::int8) select s0.n0 as n from s0; + +-- case: match (u:NodeKind1) where u.hasspn = true and u.enabled = true and not '-502' ends with u.objectid and not coalesce(u.gmsa, false) = true and not coalesce(u.msa, false) = true return u limit 10 +with s0 as (select (n0.id, n0.kind_ids, n0.properties)::nodecomposite as n0 + from node n0 + where n0.kind_ids operator (pg_catalog.&&) array [1]::int2[] + and (n0.properties ->> 'hasspn')::bool = true + and (n0.properties ->> 'enabled')::bool = true + and not '-502' like coalesce(n0.properties ->> 'objectid', '')::text + and not coalesce((n0.properties ->> 'gmsa')::bool, false)::bool = true + and not coalesce((n0.properties ->> 'msa')::bool, false)::bool = true) +select s0.n0 as u +from s0 +limit 10; + +-- case: match (n:NodeKind1) where coalesce(n.name, '') = coalesce(n.migrated_name, '') return n +with s0 as (select (n0.id, n0.kind_ids, n0.properties)::nodecomposite as n0 + from node n0 + where n0.kind_ids operator (pg_catalog.&&) array [1]::int2[] + and coalesce(n0.properties ->> 'name', '')::text = coalesce(n0.properties ->> 'migrated_name', '')::text) +select s0.n0 as n +from s0; + +-- case: match (n:NodeKind1) where '1' in n.array_prop + ['1', '2'] return n +with s0 as (select (n0.id, n0.kind_ids, n0.properties)::nodecomposite as n0 + from node n0 + where n0.kind_ids operator (pg_catalog.&&) array [1]::int2[] + and '1' = any (jsonb_to_text_array(n0.properties -> 'array_prop')::text[] || array ['1', '2']::text[])) +select s0.n0 as n +from s0; + +-- case: match p=(:NodeKind1)-[r]->(:NodeKind1) where r.isacl return p limit 100 +with s0 as (select (n0.id, n0.kind_ids, n0.properties)::nodecomposite as n0, + (e0.id, e0.start_id, e0.end_id, e0.kind_id, e0.properties)::edgecomposite as e0, + (n1.id, n1.kind_ids, n1.properties)::nodecomposite as n1 + from edge e0 + join node n0 on n0.kind_ids operator (pg_catalog.&&) array [1]::int2[] and n0.id = e0.start_id + join node n1 on n1.kind_ids operator (pg_catalog.&&) array [1]::int2[] and n1.id = e0.end_id + where (e0.properties ->> 'isacl')::bool) +select edges_to_path(variadic array [(s0.e0).id]::int8[])::pathcomposite as p +from s0 +limit 100; diff --git a/packages/go/cypher/models/pgsql/test/translation_cases/pattern_binding.sql b/packages/go/cypher/models/pgsql/test/translation_cases/pattern_binding.sql index adda9c1be..5ed4432eb 100644 --- a/packages/go/cypher/models/pgsql/test/translation_cases/pattern_binding.sql +++ b/packages/go/cypher/models/pgsql/test/translation_cases/pattern_binding.sql @@ -90,7 +90,7 @@ with s0 as (select (n0.id, n0.kind_ids, n0.properties)::nodecomposite (n2.id, n2.kind_ids, n2.properties)::nodecomposite as n2 from s0, edge e1 - join node n2 on (n2.properties -> 'is_target')::bool and n2.id = e1.start_id + join node n2 on (n2.properties ->> 'is_target')::bool and n2.id = e1.start_id where (s0.n1).id = e1.end_id) select edges_to_path(variadic array [(s1.e0).id, (s1.e1).id]::int8[])::pathcomposite as p from s1; @@ -238,3 +238,30 @@ with s0 as (select (n0.id, n0.kind_ids, n0.properties)::nodecomposite select edges_to_path(variadic array [(s0.e0).id]::int8[])::pathcomposite as p from s0 limit 1000; + +-- case: match p = (:NodeKind1)-[:EdgeKind1|EdgeKind2]->(e:NodeKind2)-[:EdgeKind2]->(:NodeKind1) where 'a' in e.values or 'b' in e.values or size(e.values) = 0 return p +with s0 as (select (n0.id, n0.kind_ids, n0.properties)::nodecomposite as n0, + (e0.id, e0.start_id, e0.end_id, e0.kind_id, e0.properties)::edgecomposite as e0, + (n1.id, n1.kind_ids, n1.properties)::nodecomposite as n1 + from edge e0 + join node n0 on n0.kind_ids operator (pg_catalog.&&) array [1]::int2[] and n0.id = e0.start_id + join node n1 on n1.kind_ids operator (pg_catalog.&&) array [2]::int2[] and + 'a' = any (jsonb_to_text_array(n1.properties -> 'values')::text[]) or + 'b' = any (jsonb_to_text_array(n1.properties -> 'values')::text[]) or + jsonb_array_length(n1.properties -> 'values')::int = 0 and n1.id = e0.end_id + where e0.kind_id = any (array [3, 4]::int2[])), + s1 as (select s0.e0 as e0, + s0.n0 as n0, + s0.n1 as n1, + (e1.id, e1.start_id, e1.end_id, e1.kind_id, e1.properties)::edgecomposite as e1, + (n2.id, n2.kind_ids, n2.properties)::nodecomposite as n2 + from s0, + edge e1 + join node n2 on n2.kind_ids operator (pg_catalog.&&) array [1]::int2[] and n2.id = e1.end_id + where e1.kind_id = any (array [4]::int2[]) + and (s0.n1).id = e1.start_id) +select edges_to_path(variadic array [(s1.e0).id, (s1.e1).id]::int8[])::pathcomposite as p +from s1; + +-- todo: the case below covers untyped array literals but has not yet been fixed +-- case: match p = (:NodeKind1)-[:EdgeKind1|EdgeKind2]->(e:NodeKind2)-[:EdgeKind2]->(:NodeKind1) where (e.a = [] or 'a' in e.a) and (e.b = 0 or e.b = 1) return p diff --git a/packages/go/cypher/models/pgsql/test/translation_cases/stepwise_traversal.sql b/packages/go/cypher/models/pgsql/test/translation_cases/stepwise_traversal.sql index 00089b8f4..86347ee1e 100644 --- a/packages/go/cypher/models/pgsql/test/translation_cases/stepwise_traversal.sql +++ b/packages/go/cypher/models/pgsql/test/translation_cases/stepwise_traversal.sql @@ -92,7 +92,7 @@ with s0 as (select (n0.id, n0.kind_ids, n0.properties)::nodecomposite from edge e0 join node n0 on n0.properties ->> 'name' = '123' and n0.id = e0.start_id join node n1 on n1.kind_ids operator (pg_catalog.&&) array [1]::int2[] and n1.id = e0.end_id - where not (e0.properties -> 'property')::bool) + where not (e0.properties ->> 'property')::bool) select s0.n0 as s, s0.e0 as r, s0.n1 as e from s0; @@ -103,7 +103,18 @@ with s0 as (select (n0.id, n0.kind_ids, n0.properties)::nodecomposite from edge e0 join node n0 on n0.id = e0.start_id join node n1 on n1.id = e0.end_id - where (e0.properties -> 'value')::int8 = 42) + where (e0.properties ->> 'value')::int8 = 42) +select s0.e0 as r +from s0; + +-- case: match ()-[r]->() where r.bool_prop return r +with s0 as (select (n0.id, n0.kind_ids, n0.properties)::nodecomposite as n0, + (e0.id, e0.start_id, e0.end_id, e0.kind_id, e0.properties)::edgecomposite as e0, + (n1.id, n1.kind_ids, n1.properties)::nodecomposite as n1 + from edge e0 + join node n0 on n0.id = e0.start_id + join node n1 on n1.id = e0.end_id + where (e0.properties ->> 'bool_prop')::bool) select s0.e0 as r from s0; @@ -133,7 +144,7 @@ from s0; -- case: match (f), (s)-[r]->(e) where not f.bool_field and s.name = '123' and e.name = '321' return f, s, r, e with s0 as (select (n0.id, n0.kind_ids, n0.properties)::nodecomposite as n0 from node n0 - where not (n0.properties -> 'bool_field')::bool), + where not (n0.properties ->> 'bool_field')::bool), s1 as (select s0.n0 as n0, (n1.id, n1.kind_ids, n1.properties)::nodecomposite as n1, (e0.id, e0.start_id, e0.end_id, e0.kind_id, e0.properties)::edgecomposite as e0, @@ -292,3 +303,35 @@ with s0 as (select (n0.id, n0.kind_ids, n0.properties)::nodecomposite select s0.e0 as r from s0 limit 1; + +-- case: match (n1)-[]->(n2) where n1 <> n2 return n2 +with s0 as (select (n0.id, n0.kind_ids, n0.properties)::nodecomposite as n0, + (e0.id, e0.start_id, e0.end_id, e0.kind_id, e0.properties)::edgecomposite as e0, + (n1.id, n1.kind_ids, n1.properties)::nodecomposite as n1 + from edge e0 + join node n0 on n0.id = e0.start_id + join node n1 on n1.id = e0.end_id + where n0.id <> n1.id) +select s0.n1 as n2 +from s0; + +-- case: match ()-[r]->()-[e]->(n) where r <> e return n +with s0 as (select (n0.id, n0.kind_ids, n0.properties)::nodecomposite as n0, + (e0.id, e0.start_id, e0.end_id, e0.kind_id, e0.properties)::edgecomposite as e0, + (n1.id, n1.kind_ids, n1.properties)::nodecomposite as n1 + from edge e0 + join node n0 on n0.id = e0.start_id + join node n1 on n1.id = e0.end_id), + s1 as (select s0.e0 as e0, + s0.n0 as n0, + s0.n1 as n1, + (e1.id, e1.start_id, e1.end_id, e1.kind_id, e1.properties)::edgecomposite as e1, + (n2.id, n2.kind_ids, n2.properties)::nodecomposite as n2 + from s0, + edge e1 + join node n2 on n2.id = e1.end_id + where (s0.n1).id = e1.start_id + and (s0.e0).id <> e1.id) +select s1.n2 as n +from s1; + diff --git a/packages/go/cypher/models/pgsql/translate/expression.go b/packages/go/cypher/models/pgsql/translate/expression.go index d3071a474..92b04fadd 100644 --- a/packages/go/cypher/models/pgsql/translate/expression.go +++ b/packages/go/cypher/models/pgsql/translate/expression.go @@ -97,6 +97,9 @@ func applyUnaryExpressionTypeHints(expression *pgsql.UnaryExpression) error { func rewritePropertyLookupOperator(propertyLookup *pgsql.BinaryExpression, dataType pgsql.DataType) pgsql.Expression { if dataType.IsArrayType() { + // Ensure that array conversions use JSONB + propertyLookup.Operator = pgsql.OperatorJSONField + return pgsql.FunctionCall{ Function: pgsql.FunctionJSONBToTextArray, Parameters: []pgsql.Expression{propertyLookup}, @@ -114,11 +117,11 @@ func rewritePropertyLookupOperator(propertyLookup *pgsql.BinaryExpression, dataT return pgsql.NewTypeCast(propertyLookup, dataType) case pgsql.UnknownDataType: - propertyLookup.Operator = pgsql.OperatorJSONField + propertyLookup.Operator = pgsql.OperatorJSONTextField return propertyLookup default: - propertyLookup.Operator = pgsql.OperatorJSONField + propertyLookup.Operator = pgsql.OperatorJSONTextField return pgsql.NewTypeCast(propertyLookup, dataType) } } @@ -139,7 +142,7 @@ func inferBinaryExpressionType(expression *pgsql.BinaryExpression) (pgsql.DataTy if isLeftHinted { if isRightHinted { - if higherLevelHint, matchesOrConverts := leftHint.Compatible(rightHint, expression.Operator); !matchesOrConverts { + if higherLevelHint, matchesOrConverts := leftHint.OperatorResultType(rightHint, expression.Operator); !matchesOrConverts { return pgsql.UnsetDataType, fmt.Errorf("left and right operands for binary expression \"%s\" are not compatible: %s != %s", expression.Operator, leftHint, rightHint) } else { return higherLevelHint, nil @@ -149,7 +152,7 @@ func inferBinaryExpressionType(expression *pgsql.BinaryExpression) (pgsql.DataTy } else if inferredRightHint == pgsql.UnknownDataType { // Assume the right side is convertable and return the left operand hint return leftHint, nil - } else if upcastHint, matchesOrConverts := leftHint.Compatible(inferredRightHint, expression.Operator); !matchesOrConverts { + } else if upcastHint, matchesOrConverts := leftHint.OperatorResultType(inferredRightHint, expression.Operator); !matchesOrConverts { return pgsql.UnsetDataType, fmt.Errorf("left and right operands for binary expression \"%s\" are not compatible: %s != %s", expression.Operator, leftHint, inferredRightHint) } else { return upcastHint, nil @@ -161,7 +164,7 @@ func inferBinaryExpressionType(expression *pgsql.BinaryExpression) (pgsql.DataTy } else if inferredLeftHint == pgsql.UnknownDataType { // Assume the right side is convertable and return the left operand hint return rightHint, nil - } else if upcastHint, matchesOrConverts := rightHint.Compatible(inferredLeftHint, expression.Operator); !matchesOrConverts { + } else if upcastHint, matchesOrConverts := rightHint.OperatorResultType(inferredLeftHint, expression.Operator); !matchesOrConverts { return pgsql.UnsetDataType, fmt.Errorf("left and right operands for binary expression \"%s\" are not compatible: %s != %s", expression.Operator, rightHint, inferredLeftHint) } else { return upcastHint, nil @@ -189,7 +192,7 @@ func inferBinaryExpressionType(expression *pgsql.BinaryExpression) (pgsql.DataTy // Unable to infer any type information, this may be resolved elsewhere so this is not explicitly // an error condition return pgsql.UnknownDataType, nil - } else if higherLevelHint, matchesOrConverts := inferredLeftHint.Compatible(inferredRightHint, expression.Operator); !matchesOrConverts { + } else if higherLevelHint, matchesOrConverts := inferredLeftHint.OperatorResultType(inferredRightHint, expression.Operator); !matchesOrConverts { return pgsql.UnsetDataType, fmt.Errorf("left and right operands for binary expression \"%s\" are not compatible: %s != %s", expression.Operator, inferredLeftHint, inferredRightHint) } else { return higherLevelHint, nil @@ -201,7 +204,6 @@ func inferBinaryExpressionType(expression *pgsql.BinaryExpression) (pgsql.DataTy func InferExpressionType(expression pgsql.Expression) (pgsql.DataType, error) { switch typedExpression := expression.(type) { case pgsql.Identifier, pgsql.CompoundExpression: - // TODO: Type inference may be aided by searching the bound scope for a data type return pgsql.UnknownDataType, nil case pgsql.CompoundIdentifier: @@ -211,6 +213,7 @@ func InferExpressionType(expression pgsql.Expression) (pgsql.DataType, error) { // Infer type information for well known column names switch typedExpression[1] { +// TODO: Graph ID should be int2 case pgsql.ColumnGraphID, pgsql.ColumnID, pgsql.ColumnStartID, pgsql.ColumnEndID: return pgsql.Int8, nil @@ -218,7 +221,7 @@ func InferExpressionType(expression pgsql.Expression) (pgsql.DataType, error) { return pgsql.Int2, nil case pgsql.ColumnKindIDs: - return pgsql.Int4Array, nil + return pgsql.Int2Array, nil case pgsql.ColumnProperties: return pgsql.JSONB, nil @@ -232,7 +235,11 @@ func InferExpressionType(expression pgsql.Expression) (pgsql.DataType, error) { case *pgsql.BinaryExpression: switch typedExpression.Operator { - case pgsql.OperatorPropertyLookup, pgsql.OperatorJSONField, pgsql.OperatorJSONTextField: + case pgsql.OperatorJSONTextField: + // Text field lookups could be text or an unknown lookup - reduce it to an unknown type + return pgsql.UnknownDataType, nil + + case pgsql.OperatorPropertyLookup, pgsql.OperatorJSONField: // This is unknown, not unset meaning that it can be re-cast by future inference inspections return pgsql.UnknownDataType, nil @@ -276,11 +283,7 @@ func TypeCastExpression(expression pgsql.Expression, dataType pgsql.DataType) (p if lookupRequiresElementType(dataType, propertyLookup.Operator, propertyLookup.ROperand) { // Take the base type of the array type hint: in - if arrayBaseType, err := dataType.ArrayBaseType(); err != nil { - return nil, err - } else { - lookupTypeHint = arrayBaseType - } + lookupTypeHint = dataType.ArrayBaseType() } return rewritePropertyLookupOperator(propertyLookup, lookupTypeHint), nil @@ -295,8 +298,11 @@ func rewritePropertyLookupOperands(expression *pgsql.BinaryExpression) error { rightPropertyLookup, hasRightPropertyLookup = asPropertyLookup(expression.ROperand) ) - // Don't rewrite direct property comparisons + // Ensure that direct property comparisons prefer JSONB - JSONB if hasLeftPropertyLookup && hasRightPropertyLookup { + leftPropertyLookup.Operator = pgsql.OperatorJSONField + rightPropertyLookup.Operator = pgsql.OperatorJSONField + return nil } @@ -304,21 +310,13 @@ func rewritePropertyLookupOperands(expression *pgsql.BinaryExpression) error { // This check exists here to prevent from overwriting a property lookup that's part of a in // binary expression. This may want for better ergonomics in the future if anyExpression, isAnyExpression := expression.ROperand.(pgsql.AnyExpression); isAnyExpression { - if arrayBaseType, err := anyExpression.CastType.ArrayBaseType(); err != nil { - return err - } else { - expression.LOperand = rewritePropertyLookupOperator(leftPropertyLookup, arrayBaseType) - } + expression.LOperand = rewritePropertyLookupOperator(leftPropertyLookup, anyExpression.CastType.ArrayBaseType()) } else if rOperandTypeHint, err := InferExpressionType(expression.ROperand); err != nil { return err } else { switch expression.Operator { case pgsql.OperatorIn: - if arrayBaseType, err := rOperandTypeHint.ArrayBaseType(); err != nil { - return err - } else { - expression.LOperand = rewritePropertyLookupOperator(leftPropertyLookup, arrayBaseType) - } + expression.LOperand = rewritePropertyLookupOperator(leftPropertyLookup, rOperandTypeHint.ArrayBaseType()) case pgsql.OperatorCypherStartsWith, pgsql.OperatorCypherEndsWith, pgsql.OperatorCypherContains, pgsql.OperatorRegexMatch: expression.LOperand = rewritePropertyLookupOperator(leftPropertyLookup, pgsql.Text) @@ -353,29 +351,98 @@ func rewritePropertyLookupOperands(expression *pgsql.BinaryExpression) error { return nil } -func applyTypeFunctionCallTypeHints(expression *pgsql.BinaryExpression) error { +func newFunctionCallComparatorError(functionCall pgsql.FunctionCall, operator pgsql.Operator, comparisonType pgsql.DataType) error { + switch functionCall.Function { + case pgsql.FunctionCoalesce: + // This is a specific error statement for coalesce statements. These statements have ill-defined + // type conversion semantics in Cypher. As such, exposing the type specificity of coalesce to the + // user as a distinct error will help reduce the surprise of running on a non-Neo4j substrate. + return fmt.Errorf("coalesce has type %s but is being compared against type %s - ensure that all arguments in the coalesce function match the type of the other side of the comparison", functionCall.CastType, comparisonType) + default: + return fmt.Errorf("function call has return signature of type %s but is being compared using operator %s against type %s", functionCall.CastType, operator, comparisonType) + } +} + +func applyTypeFunctionLikeTypeHints(expression *pgsql.BinaryExpression) error { switch typedLOperand := expression.LOperand.(type) { - case pgsql.FunctionCall: - if !typedLOperand.CastType.IsKnown() { - if rOperandTypeHint, err := InferExpressionType(expression.ROperand); err != nil { - return err + case pgsql.AnyExpression: + if rOperandTypeHint, err := InferExpressionType(expression.ROperand); err != nil { + return err + } else { + // In an any-expression where the type of the any-expression is unknown, attempt to infer it + if !typedLOperand.CastType.IsKnown() { + if rOperandArrayTypeHint, err := rOperandTypeHint.ToArrayType(); err != nil { + return err + } else { + typedLOperand.CastType = rOperandArrayTypeHint + expression.LOperand = typedLOperand + } + } else if !rOperandTypeHint.IsKnown() { + expression.ROperand = pgsql.NewTypeCast(expression.ROperand, typedLOperand.CastType.ArrayBaseType()) } else { + // Validate against the array base type of the any-expression + lOperandBaseType := typedLOperand.CastType.ArrayBaseType() + + if !lOperandBaseType.IsComparable(rOperandTypeHint, expression.Operator) { + return fmt.Errorf("function call has return signature of type %s but is being compared using operator %s against type %s", typedLOperand.CastType, expression.Operator, rOperandTypeHint) + } + } + } + + case pgsql.FunctionCall: + if rOperandTypeHint, err := InferExpressionType(expression.ROperand); err != nil { + return err + } else { + if !typedLOperand.CastType.IsKnown() { typedLOperand.CastType = rOperandTypeHint expression.LOperand = typedLOperand } + + if pgsql.OperatorIsComparator(expression.Operator) && !typedLOperand.CastType.IsComparable(rOperandTypeHint, expression.Operator) { + return newFunctionCallComparatorError(typedLOperand, expression.Operator, rOperandTypeHint) + } } } switch typedROperand := expression.ROperand.(type) { - case pgsql.FunctionCall: - if !typedROperand.CastType.IsKnown() { - if lOperandTypeHint, err := InferExpressionType(expression.LOperand); err != nil { - return err + case pgsql.AnyExpression: + if lOperandTypeHint, err := InferExpressionType(expression.LOperand); err != nil { + return err + } else { + // In an any-expression where the type of the any-expression is unknown, attempt to infer it + if !typedROperand.CastType.IsKnown() { + if rOperandArrayTypeHint, err := lOperandTypeHint.ToArrayType(); err != nil { + return err + } else { + typedROperand.CastType = rOperandArrayTypeHint + expression.LOperand = typedROperand + } + } else if !lOperandTypeHint.IsKnown() { + expression.LOperand = pgsql.NewTypeCast(expression.LOperand, typedROperand.CastType.ArrayBaseType()) } else { + // Validate against the array base type of the any-expression + rOperandBaseType := typedROperand.CastType.ArrayBaseType() + + if !rOperandBaseType.IsComparable(lOperandTypeHint, expression.Operator) { + return fmt.Errorf("function call has return signature of type %s but is being compared using operator %s against type %s", typedROperand.CastType, expression.Operator, lOperandTypeHint) + } + } + } + + case pgsql.FunctionCall: + if lOperandTypeHint, err := InferExpressionType(expression.LOperand); err != nil { + return err + } else { + if !typedROperand.CastType.IsKnown() { typedROperand.CastType = lOperandTypeHint expression.ROperand = typedROperand } + + if pgsql.OperatorIsComparator(expression.Operator) && !typedROperand.CastType.IsComparable(lOperandTypeHint, expression.Operator) { + return newFunctionCallComparatorError(typedROperand, expression.Operator, lOperandTypeHint) + } } + } return nil @@ -385,7 +452,7 @@ func applyBinaryExpressionTypeHints(expression *pgsql.BinaryExpression) error { switch expression.Operator { case pgsql.OperatorPropertyLookup: // Don't directly hint property lookups but replace the operator with the JSON operator - expression.Operator = pgsql.OperatorJSONField + expression.Operator = pgsql.OperatorJSONTextField return nil } @@ -393,7 +460,7 @@ func applyBinaryExpressionTypeHints(expression *pgsql.BinaryExpression) error { return err } - return applyTypeFunctionCallTypeHints(expression) + return applyTypeFunctionLikeTypeHints(expression) } type Builder struct { @@ -548,28 +615,29 @@ func (s *ExpressionTreeTranslator) Peek() pgsql.Expression { return s.treeBuilder.Peek() } -func (s *ExpressionTreeTranslator) NumConstraints() int { - return len(s.IdentifierConstraints.Constraints) -} - func (s *ExpressionTreeTranslator) Pop() (pgsql.Expression, error) { return s.treeBuilder.Pop() } -func (s *ExpressionTreeTranslator) popOperandAsConstraint() error { - if operand, err := s.Pop(); err != nil { +func (s *ExpressionTreeTranslator) popExpressionAsConstraint() error { + if nextExpression, err := s.Pop(); err != nil { return err - } else if identifierDeps, err := ExtractSyntaxNodeReferences(operand); err != nil { + } else if identifierDeps, err := ExtractSyntaxNodeReferences(nextExpression); err != nil { return err } else { - return s.Constrain(identifierDeps, operand) + if propertyLookup, isPropertyLookup := asPropertyLookup(nextExpression); isPropertyLookup { + // If this is a bare property lookup rewrite it with the intended type of boolean + nextExpression = rewritePropertyLookupOperator(propertyLookup, pgsql.Boolean) + } + + return s.Constrain(identifierDeps, nextExpression) } } -func (s *ExpressionTreeTranslator) ConstrainRemainingOperands() error { +func (s *ExpressionTreeTranslator) PopRemainingExpressionsAsConstraints() error { // Pull the right operand only if one exists for !s.treeBuilder.IsEmpty() { - if err := s.popOperandAsConstraint(); err != nil { + if err := s.popExpressionAsConstraint(); err != nil { return err } } @@ -587,25 +655,25 @@ func (s *ExpressionTreeTranslator) ConstrainDisjointOperandPair() error { return err } else if rightDependencies, err := ExtractSyntaxNodeReferences(rightOperand); err != nil { return err + } else if s.treeBuilder.IsEmpty() { + // If the tree builder is empty then this operand is at the top of the disjunction chain + return s.Constrain(rightDependencies, rightOperand) } else if leftOperand, err := s.treeBuilder.Pop(); err != nil { return err - } else if leftDependencies, err := ExtractSyntaxNodeReferences(leftOperand); err != nil { - return err } else { - var ( - combinedDependencies = leftDependencies.Copy().MergeSet(rightDependencies) - projectionConstraint = pgsql.NewBinaryExpression( - leftOperand, - pgsql.OperatorOr, - rightOperand, - ) + newOrExpression := pgsql.NewBinaryExpression( + leftOperand, + pgsql.OperatorOr, + rightOperand, ) - if err := applyBinaryExpressionTypeHints(projectionConstraint); err != nil { + if err := applyBinaryExpressionTypeHints(newOrExpression); err != nil { return err } - return s.Constrain(combinedDependencies, projectionConstraint) + // This operation may not be complete; push it back on the stack + s.Push(newOrExpression) + return nil } } @@ -615,7 +683,7 @@ func (s *ExpressionTreeTranslator) ConstrainConjoinedOperandPair() error { return fmt.Errorf("expected at least one operand for constraint extraction") } - if err := s.popOperandAsConstraint(); err != nil { + if err := s.popExpressionAsConstraint(); err != nil { return err } @@ -675,7 +743,7 @@ func (s *ExpressionTreeTranslator) PopPushBinaryExpression(scope *Scope, operato newExpression.ROperand = pgsql.CompoundIdentifier{typedROperand, pgsql.ColumnID} case pgsql.PathComposite: - return fmt.Errorf("invalid comparison for path identifier %s", typedLOperand) + return fmt.Errorf("comparison for path identifiers is unsupported") } } } @@ -683,6 +751,36 @@ func (s *ExpressionTreeTranslator) PopPushBinaryExpression(scope *Scope, operato } switch operator { + case pgsql.OperatorCypherAdd: + isConcatenationOperation := func(lOperandType, rOperandType pgsql.DataType) bool { + // Any use of an array type automatically assumes concatenation + if lOperandType.IsArrayType() || rOperandType.IsArrayType() { + return true + } + + switch lOperandType { + case pgsql.Text: + switch rOperandType { + case pgsql.Text: + return true + } + } + + return false + } + + // In the case of the use of the cypher `+` operator we must attempt to disambiguate if the intent + // is to concatenate or to perform an addition + if lOperandType, err := InferExpressionType(newExpression.LOperand); err != nil { + return err + } else if rOperandType, err := InferExpressionType(newExpression.ROperand); err != nil { + return err + } else if isConcatenationOperation(lOperandType, rOperandType) { + newExpression.Operator = pgsql.OperatorConcatenate + } + + s.Push(newExpression) + case pgsql.OperatorCypherContains: newExpression.Operator = pgsql.OperatorLike @@ -938,6 +1036,9 @@ func (s *ExpressionTreeTranslator) PopPushBinaryExpression(scope *Scope, operato } else if leftArrayHint, err := leftHint.ToArrayType(); err != nil { return err } else { + // Ensure the lookup uses the JSONB type + propertyLookup.Operator = pgsql.OperatorJSONField + newExpression.ROperand = pgsql.NewAnyExpression( pgsql.FunctionCall{ Function: pgsql.FunctionJSONBToTextArray, diff --git a/packages/go/cypher/models/pgsql/translate/expression_test.go b/packages/go/cypher/models/pgsql/translate/expression_test.go index 1e514ce38..ca6af0da6 100644 --- a/packages/go/cypher/models/pgsql/translate/expression_test.go +++ b/packages/go/cypher/models/pgsql/translate/expression_test.go @@ -111,6 +111,7 @@ func TestInferExpressionType(t *testing.T) { ), ), }, { + Exclusive: true, ExpectedType: pgsql.Int4, Expression: pgsql.NewBinaryExpression( pgsql.NewPropertyLookup( @@ -239,7 +240,7 @@ func TestExpressionTreeTranslator(t *testing.T) { treeTranslator.PopPushOperator(scope, pgsql.OperatorAnd) // Assign remaining operands as constraints - treeTranslator.ConstrainRemainingOperands() + treeTranslator.PopRemainingExpressionsAsConstraints() // Pull out the 'a' constraint aIdentifier := pgsql.AsIdentifierSet("a") diff --git a/packages/go/cypher/models/pgsql/translate/translation.go b/packages/go/cypher/models/pgsql/translate/translation.go index 7b109e120..c5fbac294 100644 --- a/packages/go/cypher/models/pgsql/translate/translation.go +++ b/packages/go/cypher/models/pgsql/translate/translation.go @@ -252,6 +252,7 @@ func (s *Translator) translateCoalesceFunction(functionInvocation *cypher.Functi // Find and validate types of the arguments for _, argument := range arguments { + // Properties have no type information and should be skipped if argumentType, err := InferExpressionType(argument); err != nil { return err } else if argumentType.IsKnown() { @@ -345,9 +346,9 @@ func (s *Translator) translateProjectionItem(scope *Scope, projectionItem *cyphe } case *pgsql.BinaryExpression: - if typedSelectItem.Operator == pgsql.OperatorPropertyLookup { - // TODO: This probably belongs somewhere else - typedSelectItem.Operator = pgsql.OperatorJSONField + if propertyLookup, isPropertyLookup := asPropertyLookup(typedSelectItem); isPropertyLookup { + // Ensure that projections maintain the raw JSONB type of the field + propertyLookup.Operator = pgsql.OperatorJSONField } } diff --git a/packages/go/cypher/models/pgsql/translate/translator.go b/packages/go/cypher/models/pgsql/translate/translator.go index 8c15f52f9..701a0372a 100644 --- a/packages/go/cypher/models/pgsql/translate/translator.go +++ b/packages/go/cypher/models/pgsql/translate/translator.go @@ -446,6 +446,11 @@ func (s *Translator) Exit(expression cypher.SyntaxNode) { } else if err := RewriteExpressionIdentifiers(lookupExpression, s.query.Scope.CurrentFrameBinding().Identifier, s.query.Scope.Visible()); err != nil { s.SetError(err) } else { + if propertyLookup, isPropertyLookup := asPropertyLookup(lookupExpression); isPropertyLookup { + // If sorting, use the raw type of the JSONB field + propertyLookup.Operator = pgsql.OperatorJSONField + } + s.query.CurrentOrderBy().Expression = lookupExpression } @@ -606,7 +611,10 @@ func (s *Translator) Exit(expression cypher.SyntaxNode) { } else { var functionCall pgsql.FunctionCall - if _, isPropertyLookup := asPropertyLookup(argument); isPropertyLookup { + if propertyLookup, isPropertyLookup := asPropertyLookup(argument); isPropertyLookup { + // Ensure that the JSONB array length function receives the JSONB type + propertyLookup.Operator = pgsql.OperatorJSONField + functionCall = pgsql.FunctionCall{ Function: pgsql.FunctionJSONBArrayLength, Parameters: []pgsql.Expression{argument}, @@ -746,7 +754,7 @@ func (s *Translator) Exit(expression cypher.SyntaxNode) { s.exitState(StateTranslatingWhere) // Assign the last operands as identifier set constraints - if err := s.treeTranslator.ConstrainRemainingOperands(); err != nil { + if err := s.treeTranslator.PopRemainingExpressionsAsConstraints(); err != nil { s.SetError(err) } diff --git a/packages/go/cypher/models/pgsql/translate/update.go b/packages/go/cypher/models/pgsql/translate/update.go index 6f5fc89ce..db1dfb699 100644 --- a/packages/go/cypher/models/pgsql/translate/update.go +++ b/packages/go/cypher/models/pgsql/translate/update.go @@ -145,9 +145,15 @@ func (s *Translator) buildUpdates(scope *Scope) error { } for _, propertyAssignment := range identifierMutation.PropertyAssignments.Values() { + if propertyLookup, isPropertyLookup := asPropertyLookup(propertyAssignment.ValueExpression); isPropertyLookup { + // Ensure that property lookups in JSONB build functions use the JSONB field type + propertyLookup.Operator = pgsql.OperatorJSONField + } + jsonObjectFunction.Parameters = append(jsonObjectFunction.Parameters, pgsql.NewLiteral(propertyAssignment.Field, pgsql.Text), - propertyAssignment.ValueExpression) + propertyAssignment.ValueExpression, + ) } propertyAssignments = models.ValueOptional(jsonObjectFunction.AsExpression())