Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
17 commits
Select commit Hold shift + click to select a range
a5c4df7
fix(pgsql): preserve WITH identifier aliases across shortest path agg…
zinic May 11, 2026
6d0be96
fix(pgsql): handle zero-hop variable-length expansions - BP-2518
zinic May 11, 2026
31dde31
fix(pgsql): infer temporal arithmetic for relationship property compa…
zinic May 11, 2026
ea126b4
feat(pgsql): support path list functions in expansion predicates - BE…
zinic May 11, 2026
2d979e3
fix(pgsql): infer date interval arithmetic as timestamp
zinic May 11, 2026
33a9e35
fix(pgsql): avoid duplicate paths in zero-hop expansions
zinic May 11, 2026
f0d1b0c
fix(pgsql): reject unsupported property lookups directly
zinic May 11, 2026
f5163a9
fix(pgsql): infer binary array expressions in quantifiers
zinic May 11, 2026
9e1b8ef
fix(pgsql): rewrite value-form array slices in renamer
zinic May 11, 2026
166b0b2
test(integration): cover enforced head relationship path filtering
zinic May 11, 2026
4beaafa
fix(pgsql): guard zero-hop expansion and resolve path aliases
zinic May 12, 2026
0d1de62
fix (neo4j): rewrite temporal expressions to restore neo4j time-date …
zinic May 12, 2026
c83def2
fix(pgsql): rewrite array slices in projections and order clauses
zinic May 12, 2026
97032ae
fix (pgsql): remove divide by zero guard in shortest paths harness fo…
zinic May 12, 2026
ed7ba88
fix (neo4j): remove rewrites from hot-path for queries that do not ne…
zinic May 12, 2026
c02addd
fix (pgsql): pg 16 specific root/terminal id filter throws an overly …
zinic May 12, 2026
1366719
fix(pgsql): preserve primer branch for zero-depth expansions
zinic May 12, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions cypher/models/cypher/functions.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@ const (
ToIntegerFunction = "toint"
ToIntegerAliasFunction = "tointeger"
ListSizeFunction = "size"
HeadFunction = "head"
TailFunction = "tail"
NodesFunction = "nodes"
RelationshipsFunction = "relationships"
CoalesceFunction = "coalesce"
CollectFunction = "collect"
SumFunction = "sum"
Expand Down
71 changes: 71 additions & 0 deletions cypher/models/pgsql/format/format.go
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,48 @@ func formatLiteral(builder *OutputBuilder, literal pgsql.Literal) error {
return formatValue(builder, literal.Value)
}

func formatCase(builder *OutputBuilder, caseExpr pgsql.Case) error {
if len(caseExpr.Conditions) != len(caseExpr.Then) {
return fmt.Errorf("case expression has %d conditions and %d then expressions", len(caseExpr.Conditions), len(caseExpr.Then))
}

builder.Write("case")

if caseExpr.Operand != nil {
builder.Write(" ")

if err := formatNode(builder, caseExpr.Operand); err != nil {
return err
}
}

for idx, condition := range caseExpr.Conditions {
builder.Write(" when ")

if err := formatNode(builder, condition); err != nil {
return err
}

builder.Write(" then ")

if err := formatNode(builder, caseExpr.Then[idx]); err != nil {
return err
}
}

if caseExpr.Else != nil {
builder.Write(" else ")

if err := formatNode(builder, caseExpr.Else); err != nil {
return err
}
}

builder.Write(" end")

return nil
}

func formatNode(builder *OutputBuilder, rootExpr pgsql.SyntaxNode) error {
exprStack := []pgsql.SyntaxNode{
rootExpr,
Expand Down Expand Up @@ -184,6 +226,16 @@ func formatNode(builder *OutputBuilder, rootExpr pgsql.SyntaxNode) error {
return err
}

case *pgsql.Case:
if err := formatCase(builder, *typedNextExpr); err != nil {
return err
}

case pgsql.Case:
if err := formatCase(builder, typedNextExpr); err != nil {
return err
}

case *pgsql.Materialized:
if typedNextExpr.Materialized {
exprStack = append(exprStack, pgsql.FormattingLiteral("materialized"))
Expand Down Expand Up @@ -434,6 +486,25 @@ func formatNode(builder *OutputBuilder, rootExpr pgsql.SyntaxNode) error {
case *pgsql.ArrayIndex:
exprStack = append(exprStack, *typedNextExpr)

case pgsql.ArraySlice:
exprStack = append(exprStack, pgsql.FormattingLiteral("]"))

if typedNextExpr.Upper != nil {
exprStack = append(exprStack, typedNextExpr.Upper)
}

exprStack = append(exprStack, pgsql.FormattingLiteral(":"))

if typedNextExpr.Lower != nil {
exprStack = append(exprStack, typedNextExpr.Lower)
}

exprStack = append(exprStack, pgsql.FormattingLiteral("["))
exprStack = append(exprStack, typedNextExpr.Expression)

case *pgsql.ArraySlice:
exprStack = append(exprStack, *typedNextExpr)

case pgsql.TypeCast:
switch typedCastedExpr := typedNextExpr.Expression.(type) {
case *pgsql.BinaryExpression:
Expand Down
25 changes: 25 additions & 0 deletions cypher/models/pgsql/format/format_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,31 @@ func TestFormat_TypeCastedParenthetical(t *testing.T) {
require.Equal(t, "('str')::text", formattedQuery)
}

func TestFormat_Case(t *testing.T) {
formattedQuery, err := format.Expression(pgsql.Case{
Conditions: []pgsql.Expression{
pgsql.NewBinaryExpression(
pgsql.CompoundIdentifier{"s0", "root_id"},
pgsql.OperatorNotEquals,
pgsql.CompoundIdentifier{"s0", "next_id"},
),
},
Then: []pgsql.Expression{
pgsql.NewLiteral(true, pgsql.Boolean),
},
Else: pgsql.FunctionCall{
Function: "shortest_path_self_endpoint_error",
Parameters: []pgsql.Expression{
pgsql.CompoundIdentifier{"s0", "root_id"},
pgsql.CompoundIdentifier{"s0", "next_id"},
},
},
}, format.NewOutputBuilder())

require.NoError(t, err)
require.Equal(t, "case when s0.root_id != s0.next_id then true else shortest_path_self_endpoint_error(s0.root_id, s0.next_id) end", formattedQuery)
}

func TestFormat_SelectDistinct(t *testing.T) {
formattedQuery, err := format.Statement(pgsql.Query{
Body: pgsql.Select{
Expand Down
83 changes: 42 additions & 41 deletions cypher/models/pgsql/functions.go
Original file line number Diff line number Diff line change
@@ -1,47 +1,48 @@
package pgsql

const (
FunctionUnidirectionalASPHarness Identifier = "unidirectional_asp_harness"
FunctionUnidirectionalSPHarness Identifier = "unidirectional_sp_harness"
FunctionBidirectionalASPHarness Identifier = "bidirectional_asp_harness"
FunctionBidirectionalSPHarness Identifier = "bidirectional_sp_harness"
FunctionIntArrayUnique Identifier = "uniq"
FunctionIntArraySort Identifier = "sort"
FunctionJSONBToTextArray Identifier = "jsonb_to_text_array"
FunctionJSONBArrayElementsText Identifier = "jsonb_array_elements_text"
FunctionJSONBBuildObject Identifier = "jsonb_build_object"
FunctionJSONBArrayLength Identifier = "jsonb_array_length"
FunctionToJSONB Identifier = "to_jsonb"
FunctionCypherContains Identifier = "cypher_contains"
FunctionCypherStartsWith Identifier = "cypher_starts_with"
FunctionCypherEndsWith Identifier = "cypher_ends_with"
FunctionArrayLength Identifier = "array_length"
FunctionCardinality Identifier = "cardinality"
FunctionArrayAggregate Identifier = "array_agg"
FunctionArrayRemove Identifier = "array_remove"
FunctionMin Identifier = "min"
FunctionMax Identifier = "max"
FunctionSum Identifier = "sum"
FunctionAvg Identifier = "avg"
FunctionLocalTimestamp Identifier = "localtimestamp"
FunctionLocalTime Identifier = "localtime"
FunctionCurrentTime Identifier = "current_time"
FunctionCurrentDate Identifier = "current_date"
FunctionNow Identifier = "now"
FunctionToLower Identifier = "lower"
FunctionToUpper Identifier = "upper"
FunctionCoalesce Identifier = "coalesce"
FunctionReplace Identifier = "replace"
FunctionUnnest Identifier = "unnest"
FunctionNextValue Identifier = "nextval"
FunctionPGGetSerialSequence Identifier = "pg_get_serial_sequence"
FunctionJSONBSet Identifier = "jsonb_set"
FunctionCount Identifier = "count"
FunctionStringToArray Identifier = "string_to_array"
FunctionEdgesToPath Identifier = "edges_to_path"
FunctionOrderedEdgesToPath Identifier = "ordered_edges_to_path"
FunctionNodesToPath Identifier = "nodes_to_path"
FunctionExtract Identifier = "extract"
FunctionUnidirectionalASPHarness Identifier = "unidirectional_asp_harness"
FunctionUnidirectionalSPHarness Identifier = "unidirectional_sp_harness"
FunctionBidirectionalASPHarness Identifier = "bidirectional_asp_harness"
FunctionBidirectionalSPHarness Identifier = "bidirectional_sp_harness"
FunctionShortestPathSelfEndpointError Identifier = "shortest_path_self_endpoint_error"
FunctionIntArrayUnique Identifier = "uniq"
FunctionIntArraySort Identifier = "sort"
FunctionJSONBToTextArray Identifier = "jsonb_to_text_array"
FunctionJSONBArrayElementsText Identifier = "jsonb_array_elements_text"
FunctionJSONBBuildObject Identifier = "jsonb_build_object"
FunctionJSONBArrayLength Identifier = "jsonb_array_length"
FunctionToJSONB Identifier = "to_jsonb"
FunctionCypherContains Identifier = "cypher_contains"
FunctionCypherStartsWith Identifier = "cypher_starts_with"
FunctionCypherEndsWith Identifier = "cypher_ends_with"
FunctionArrayLength Identifier = "array_length"
FunctionCardinality Identifier = "cardinality"
FunctionArrayAggregate Identifier = "array_agg"
FunctionArrayRemove Identifier = "array_remove"
FunctionMin Identifier = "min"
FunctionMax Identifier = "max"
FunctionSum Identifier = "sum"
FunctionAvg Identifier = "avg"
FunctionLocalTimestamp Identifier = "localtimestamp"
FunctionLocalTime Identifier = "localtime"
FunctionCurrentTime Identifier = "current_time"
FunctionCurrentDate Identifier = "current_date"
FunctionNow Identifier = "now"
FunctionToLower Identifier = "lower"
FunctionToUpper Identifier = "upper"
FunctionCoalesce Identifier = "coalesce"
FunctionReplace Identifier = "replace"
FunctionUnnest Identifier = "unnest"
FunctionNextValue Identifier = "nextval"
FunctionPGGetSerialSequence Identifier = "pg_get_serial_sequence"
FunctionJSONBSet Identifier = "jsonb_set"
FunctionCount Identifier = "count"
FunctionStringToArray Identifier = "string_to_array"
FunctionEdgesToPath Identifier = "edges_to_path"
FunctionOrderedEdgesToPath Identifier = "ordered_edges_to_path"
FunctionNodesToPath Identifier = "nodes_to_path"
FunctionExtract Identifier = "extract"
)

func IsAggregateFunction(function Identifier) bool {
Expand Down
52 changes: 52 additions & 0 deletions cypher/models/pgsql/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,18 @@ type Case struct {
Else Expression
}

func (s Case) NodeType() string {
return "case"
}

func (s Case) AsExpression() Expression {
return s
}

func (s Case) AsSelectItem() SelectItem {
return s
}

// InExpression represents a contains operation against a list of evaluated expressions:
// m.identifier in (val1, val2, ...)
type InExpression struct {
Expand Down Expand Up @@ -145,6 +157,10 @@ func (s TypeCast) AsExpression() Expression {
return s
}

func (s TypeCast) AsSelectItem() SelectItem {
return s
}

func (s TypeCast) TypeHint() DataType {
return s.CastType
}
Expand Down Expand Up @@ -617,6 +633,7 @@ func (s Identifier) Matches(others ...Identifier) bool {
type ArrayIndex struct {
Expression Expression
Indexes []Expression
CastType DataType
}

func (s ArrayIndex) NodeType() string {
Expand All @@ -627,6 +644,41 @@ func (s ArrayIndex) AsExpression() Expression {
return s
}

func (s ArrayIndex) TypeHint() DataType {
if s.CastType == UnsetDataType {
return UnknownDataType
}

return s.CastType
}

type ArraySlice struct {
Expression Expression
Lower Expression
Upper Expression
CastType DataType
}

func (s ArraySlice) NodeType() string {
return "array_slice"
}

func (s ArraySlice) AsExpression() Expression {
return s
}

func (s ArraySlice) AsSelectItem() SelectItem {
return s
}

func (s ArraySlice) TypeHint() DataType {
if s.CastType == UnsetDataType {
return UnknownDataType
}

return s.CastType
}

type RowColumnReference struct {
Identifier Expression
Column Identifier
Expand Down
36 changes: 36 additions & 0 deletions cypher/models/pgsql/pgtypes.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ const (
ColumnGraphID Identifier = "graph_id"
ColumnStartID Identifier = "start_id"
ColumnEndID Identifier = "end_id"
ColumnNodes Identifier = "nodes"
ColumnEdges Identifier = "edges"
)

var (
Expand Down Expand Up @@ -114,6 +116,24 @@ func (s DataType) IsKnown() bool {
}
}

func (s DataType) IsTemporalType() bool {
switch s {
case Date, TimeWithTimeZone, TimeWithoutTimeZone, TimestampWithTimeZone, TimestampWithoutTimeZone:
return true

default:
return false
}
}
Comment thread
coderabbitai[bot] marked this conversation as resolved.

func (s DataType) TemporalIntervalArithmeticResultType() DataType {
if s == Date {
return TimestampWithoutTimeZone
}

return s
}

func (s DataType) IsComparable(other DataType, operator Operator) bool {
switch operator {
case OperatorPGArrayOverlap, OperatorArrayOverlap, OperatorPGArrayLHSContainsRHS:
Expand Down Expand Up @@ -263,6 +283,22 @@ func (s DataType) OperatorResultType(other DataType, operator Operator) (DataTyp
}

// Other special cases for arithmetic
switch operator {
case OperatorAdd:
if s.IsTemporalType() && other == Interval {
return s.TemporalIntervalArithmeticResultType(), true
}

if s == Interval && other.IsTemporalType() {
return other.TemporalIntervalArithmeticResultType(), true
}

case OperatorSubtract:
if s.IsTemporalType() && other == Interval {
return s.TemporalIntervalArithmeticResultType(), true
}
}

switch s {
case Date:
switch other {
Expand Down
Loading
Loading