Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
115 changes: 115 additions & 0 deletions cypher/models/pgsql/test/pattern_predicate_shape_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
package test

import (
"context"
"strings"
"testing"

"github.com/specterops/dawgs/cypher/frontend"
"github.com/specterops/dawgs/cypher/models/cypher"
"github.com/specterops/dawgs/cypher/models/pgsql/translate"
"github.com/specterops/dawgs/graph"
"github.com/specterops/dawgs/query"
)

func normalizeSQL(sqlQuery string) string {
return strings.Join(strings.Fields(strings.ToLower(sqlQuery)), " ")
}

func assertInboundPatternPredicateShape(t *testing.T, sqlQuery string) {
t.Helper()

normalized := normalizeSQL(sqlQuery)

requiredFragments := []string{
"select count(*) > 0 from s1",
"from edge e0 join node n1",
"(s0.n0).id = e0.end_id",
"n1.kind_ids operator (pg_catalog.@>) array [1]::int2[]",
"e0.kind_id = any (array [3]::int2[])",
}

for _, fragment := range requiredFragments {
if !strings.Contains(normalized, fragment) {
t.Fatalf("expected SQL to contain fragment %q but it did not:\n%s", fragment, sqlQuery)
}
}

forbiddenFragments := []string{
"from s0 join edge e0 on (s0.n0).id = e0.end_id",
}

for _, fragment := range forbiddenFragments {
if strings.Contains(normalized, fragment) {
t.Fatalf("expected SQL to avoid fragment %q but it was present:\n%s", fragment, sqlQuery)
}
}
}

func buildInboundNodeKind1EdgeKind1PatternPredicate(symbol string) *cypher.PatternPredicate {
patternPredicate := cypher.NewPatternPredicate()

patternPredicate.AddElement(&cypher.NodePattern{
Variable: cypher.NewVariableWithSymbol(symbol),
})

patternPredicate.AddElement(&cypher.RelationshipPattern{
Kinds: graph.Kinds{EdgeKind1},
Direction: graph.DirectionInbound,
})

patternPredicate.AddElement(&cypher.NodePattern{
Kinds: graph.Kinds{NodeKind1},
})

return patternPredicate
}

func TestTranslate_PatternPredicateInboundShape_CypherFrontend(t *testing.T) {
regularQuery, err := frontend.ParseCypher(
frontend.NewContext(),
"match (g:NodeKind2) where not ((g)<-[:EdgeKind1]-(:NodeKind1)) return g",
)
if err != nil {
t.Fatalf("failed to parse cypher query: %v", err)
}

translatedQuery, err := translate.Translate(context.Background(), regularQuery, newKindMapper(), nil)
if err != nil {
t.Fatalf("failed to translate cypher query: %v", err)
}

formattedQuery, err := translate.Translated(translatedQuery)
if err != nil {
t.Fatalf("failed to format translated SQL query: %v", err)
}

assertInboundPatternPredicateShape(t, formattedQuery)
}

func TestTranslate_PatternPredicateInboundShape_GraphFrontend(t *testing.T) {
builder := query.NewBuilderWithCriteria(
query.Where(query.And(
query.Kind(query.Node(), NodeKind2),
query.Not(buildInboundNodeKind1EdgeKind1PatternPredicate(query.NodeSymbol)),
)),
query.Returning(query.Node()),
)

rawQuery, err := builder.Build(false)
if err != nil {
t.Fatalf("failed to build graph frontend query: %v", err)
}

translatedQuery, err := translate.Translate(context.Background(), rawQuery, newKindMapper(), nil)
if err != nil {
t.Fatalf("failed to translate graph frontend query: %v", err)
}

formattedQuery, err := translate.Translated(translatedQuery)
if err != nil {
t.Fatalf("failed to format translated SQL query: %v", err)
}

assertInboundPatternPredicateShape(t, formattedQuery)
}
6 changes: 3 additions & 3 deletions cypher/models/pgsql/test/translation_cases/nodes.sql
Original file line number Diff line number Diff line change
Expand Up @@ -184,13 +184,13 @@ with s0 as (select (n0.id, n0.kind_ids, n0.properties)::nodecomposite as n0 from
with s0 as (select (n0.id, n0.kind_ids, n0.properties)::nodecomposite as n0 from node n0) select s0.n0 as s from s0 where (not exists (select 1 from edge e0 where e0.start_id = (s0.n0).id or e0.end_id = (s0.n0).id));

-- case: match (s) where not (s)-[]->()-[]->() return s
with s0 as (select (n0.id, n0.kind_ids, n0.properties)::nodecomposite as n0 from node n0) select s0.n0 as s from s0 where (not (with s1 as (select (e0.id, e0.start_id, e0.end_id, e0.kind_id, e0.properties)::edgecomposite as e0, s0.n0 as n0, (n1.id, n1.kind_ids, n1.properties)::nodecomposite as n1 from s0 join edge e0 on (s0.n0).id = e0.start_id join node n1 on n1.id = e0.end_id), s2 as (select s1.e0 as e0, (e1.id, e1.start_id, e1.end_id, e1.kind_id, e1.properties)::edgecomposite as e1, s1.n0 as n0, s1.n1 as n1, (n2.id, n2.kind_ids, n2.properties)::nodecomposite as n2 from s1 join edge e1 on (s1.n1).id = e1.start_id join node n2 on n2.id = e1.end_id) select count(*) > 0 from s2));
with s0 as (select (n0.id, n0.kind_ids, n0.properties)::nodecomposite as n0 from node n0) select s0.n0 as s from s0 where (not (with s1 as (select (e0.id, e0.start_id, e0.end_id, e0.kind_id, e0.properties)::edgecomposite as e0, s0.n0 as n0, (n1.id, n1.kind_ids, n1.properties)::nodecomposite as n1 from edge e0 join node n1 on n1.id = e0.end_id where (s0.n0).id = e0.start_id), s2 as (select s1.e0 as e0, (e1.id, e1.start_id, e1.end_id, e1.kind_id, e1.properties)::edgecomposite as e1, s1.n0 as n0, s1.n1 as n1, (n2.id, n2.kind_ids, n2.properties)::nodecomposite as n2 from s1 join edge e1 on (s1.n1).id = e1.start_id join node n2 on n2.id = e1.end_id) select count(*) > 0 from s2));

-- case: match (s) where not (s)-[{prop: 'a'}]-({name: 'n3'}) return s
with s0 as (select (n0.id, n0.kind_ids, n0.properties)::nodecomposite as n0 from node n0) select s0.n0 as s from s0 where (not (with s1 as (select (e0.id, e0.start_id, e0.end_id, e0.kind_id, e0.properties)::edgecomposite as e0, s0.n0 as n0, (n1.id, n1.kind_ids, n1.properties)::nodecomposite as n1 from s0 join edge e0 on ((s0.n0).id = e0.end_id or (s0.n0).id = e0.start_id) join node n1 on (n1.properties ->> 'name') = 'n3' and (n1.id = e0.end_id or n1.id = e0.start_id) where ((s0.n0).id <> n1.id) and (e0.properties ->> 'prop') = 'a') select count(*) > 0 from s1));

-- case: match (s) where not (s)<-[{prop: 'a'}]-({name: 'n3'}) return s
with s0 as (select (n0.id, n0.kind_ids, n0.properties)::nodecomposite as n0 from node n0) select s0.n0 as s from s0 where (not (with s1 as (select (e0.id, e0.start_id, e0.end_id, e0.kind_id, e0.properties)::edgecomposite as e0, s0.n0 as n0, (n1.id, n1.kind_ids, n1.properties)::nodecomposite as n1 from s0 join edge e0 on (s0.n0).id = e0.end_id join node n1 on (n1.properties ->> 'name') = 'n3' and n1.id = e0.start_id where (e0.properties ->> 'prop') = 'a') select count(*) > 0 from s1));
with s0 as (select (n0.id, n0.kind_ids, n0.properties)::nodecomposite as n0 from node n0) select s0.n0 as s from s0 where (not (with s1 as (select (e0.id, e0.start_id, e0.end_id, e0.kind_id, e0.properties)::edgecomposite as e0, s0.n0 as n0, (n1.id, n1.kind_ids, n1.properties)::nodecomposite as n1 from edge e0 join node n1 on (n1.properties ->> 'name') = 'n3' and n1.id = e0.start_id where (e0.properties ->> 'prop') = 'a' and (s0.n0).id = e0.end_id) select count(*) > 0 from s1));

-- case: match (n:NodeKind1) where n.distinguishedname = toUpper('admin') return n
with s0 as (select (n0.id, n0.kind_ids, n0.properties)::nodecomposite as n0 from node n0 where ((n0.properties ->> 'distinguishedname') = upper('admin')::text) and n0.kind_ids operator (pg_catalog.@>) array [1]::int2[]) select s0.n0 as n from s0;
Expand All @@ -205,7 +205,7 @@ with s0 as (select (n0.id, n0.kind_ids, n0.properties)::nodecomposite as n0 from
with s0 as (select (n0.id, n0.kind_ids, n0.properties)::nodecomposite as n0 from node n0 where ((n0.properties ->> 'distinguishedname') like '%' || upper('admin')::text) and n0.kind_ids operator (pg_catalog.@>) array [1]::int2[]) select s0.n0 as n from s0;

-- case: match (s) where not (s)-[{prop: 'a'}]->({name: 'n3'}) return s
with s0 as (select (n0.id, n0.kind_ids, n0.properties)::nodecomposite as n0 from node n0) select s0.n0 as s from s0 where (not (with s1 as (select (e0.id, e0.start_id, e0.end_id, e0.kind_id, e0.properties)::edgecomposite as e0, s0.n0 as n0, (n1.id, n1.kind_ids, n1.properties)::nodecomposite as n1 from s0 join edge e0 on (s0.n0).id = e0.start_id join node n1 on (n1.properties ->> 'name') = 'n3' and n1.id = e0.end_id where (e0.properties ->> 'prop') = 'a') select count(*) > 0 from s1));
with s0 as (select (n0.id, n0.kind_ids, n0.properties)::nodecomposite as n0 from node n0) select s0.n0 as s from s0 where (not (with s1 as (select (e0.id, e0.start_id, e0.end_id, e0.kind_id, e0.properties)::edgecomposite as e0, s0.n0 as n0, (n1.id, n1.kind_ids, n1.properties)::nodecomposite as n1 from edge e0 join node n1 on (n1.properties ->> 'name') = 'n3' and n1.id = e0.end_id where (e0.properties ->> 'prop') = 'a' and (s0.n0).id = e0.start_id) select count(*) > 0 from s1));

-- case: match (s) where not (s)-[]-() return id(s)
with s0 as (select (n0.id, n0.kind_ids, n0.properties)::nodecomposite as n0 from node n0) select (s0.n0).id from s0 where (not exists (select 1 from edge e0 where e0.start_id = (s0.n0).id or e0.end_id = (s0.n0).id));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ with s0 as (select (e0.id, e0.start_id, e0.end_id, e0.kind_id, e0.properties)::e
with s0 as (select (e0.id, e0.start_id, e0.end_id, e0.kind_id, e0.properties)::edgecomposite as e0, (n0.id, n0.kind_ids, n0.properties)::nodecomposite as n0, (n1.id, n1.kind_ids, n1.properties)::nodecomposite as n1 from edge e0 join node n0 on n0.kind_ids operator (pg_catalog.@>) array [1]::int2[] and n0.id = e0.start_id join node n1 on n1.kind_ids operator (pg_catalog.@>) array [2]::int2[] and n1.id = e0.end_id where e0.kind_id = any (array [3, 4]::int2[])) select ((s0.n0).properties -> 'name'), ((s0.n1).properties -> 'name') from s0;

-- case: match (s)-[r:EdgeKind1]->() where (s)-[r {prop: 'a'}]->() return s
with s0 as (select (e0.id, e0.start_id, e0.end_id, e0.kind_id, e0.properties)::edgecomposite as e0, (n0.id, n0.kind_ids, n0.properties)::nodecomposite as n0, (n1.id, n1.kind_ids, n1.properties)::nodecomposite as n1 from edge e0 join node n0 on n0.id = e0.start_id join node n1 on n1.id = e0.end_id where (e0.properties ->> 'prop') = 'a' and e0.kind_id = any (array [3]::int2[])) select s0.n0 as s from s0 where ((with s1 as (select s0.e0 as e0, s0.n0 as n0, s0.n1 as n1, (n2.id, n2.kind_ids, n2.properties)::nodecomposite as n2 from s0 join edge e0 on (s0.n0).id = (s0.e0).start_id join node n2 on n2.id = (s0.e0).end_id) select count(*) > 0 from s1));
with s0 as (select (e0.id, e0.start_id, e0.end_id, e0.kind_id, e0.properties)::edgecomposite as e0, (n0.id, n0.kind_ids, n0.properties)::nodecomposite as n0, (n1.id, n1.kind_ids, n1.properties)::nodecomposite as n1 from edge e0 join node n0 on n0.id = e0.start_id join node n1 on n1.id = e0.end_id where (e0.properties ->> 'prop') = 'a' and e0.kind_id = any (array [3]::int2[])) select s0.n0 as s from s0 where ((with s1 as (select s0.e0 as e0, s0.n0 as n0, s0.n1 as n1, (n2.id, n2.kind_ids, n2.properties)::nodecomposite as n2 from edge e0 join node n2 on n2.id = (s0.e0).end_id where (s0.n0).id = (s0.e0).start_id) select count(*) > 0 from s1));

-- case: match (s)-[r:EdgeKind1]->(e) where not (s.system_tags contains 'admin_tier_0') and id(e) = 1 return id(s), labels(s), id(r), type(r)
with s0 as (select (e0.id, e0.start_id, e0.end_id, e0.kind_id, e0.properties)::edgecomposite as e0, (n0.id, n0.kind_ids, n0.properties)::nodecomposite as n0, (n1.id, n1.kind_ids, n1.properties)::nodecomposite as n1 from edge e0 join node n1 on (n1.id = 1) and n1.id = e0.end_id join node n0 on (not (coalesce((n0.properties ->> 'system_tags'), '')::text like '%admin\_tier\_0%')) and n0.id = e0.start_id where e0.kind_id = any (array [3]::int2[])) select (s0.n0).id, (s0.n0).kind_ids, (s0.e0).id, (s0.e0).kind_id from s0;
Expand All @@ -102,4 +102,3 @@ with s0 as (select (e0.id, e0.start_id, e0.end_id, e0.kind_id, e0.properties)::e

-- case: match (s:NodeKind1:NodeKind2)-[r:EdgeKind1|EdgeKind2]->(e:NodeKind2:NodeKind1) return s.name, e.name
with s0 as (select (e0.id, e0.start_id, e0.end_id, e0.kind_id, e0.properties)::edgecomposite as e0, (n0.id, n0.kind_ids, n0.properties)::nodecomposite as n0, (n1.id, n1.kind_ids, n1.properties)::nodecomposite as n1 from edge e0 join node n0 on n0.kind_ids operator (pg_catalog.@>) array [1, 2]::int2[] and n0.id = e0.start_id join node n1 on n1.kind_ids operator (pg_catalog.@>) array [2, 1]::int2[] and n1.id = e0.end_id where e0.kind_id = any (array [3, 4]::int2[])) select ((s0.n0).properties -> 'name'), ((s0.n1).properties -> 'name') from s0;

14 changes: 13 additions & 1 deletion cypher/models/pgsql/translate/predicate.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,8 @@ func (s *Translator) translatePatternPredicate() error {
return nil
}

// buildPatternPredicates is used by translateMatch to resolve deferred pattern predicate
// futures collected for the current MATCH/OPTIONAL MATCH query part's WHERE expressions
func (s *Translator) buildPatternPredicates() error {
for _, predicateFuture := range s.query.CurrentPart().patternPredicates {
var (
Expand Down Expand Up @@ -142,7 +144,17 @@ func (s *Translator) buildPatternPredicates() error {
})
}
} else {
if traversalStepQuery, err := s.buildTraversalPatternRoot(traversalStep.Frame, traversalStep); err != nil {
var (
traversalStepQuery pgsql.Query
err error
)
if traversalStep.Direction != graph.DirectionBoth && (traversalStep.LeftNodeBound || traversalStep.RightNodeBound) {
traversalStepQuery, err = s.buildTraversalPatternRootWithOuterCorrelation(traversalStep.Frame, traversalStep)
} else {
traversalStepQuery, err = s.buildTraversalPatternRoot(traversalStep.Frame, traversalStep)
}

if err != nil {
return err
} else {
subQuery.AddCTE(pgsql.CommonTableExpression{
Expand Down
93 changes: 92 additions & 1 deletion cypher/models/pgsql/translate/traversal.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,8 @@ func (s *Translator) buildDirectionlessTraversalPatternRoot(traversalStep *Trave
},
JoinOperator: pgsql.JoinOperator{
JoinType: pgsql.JoinTypeInner,
Constraint: pgsql.OptionalAnd(rightJoinLocal, traversalStep.RightNodeJoinCondition)},
Constraint: pgsql.OptionalAnd(rightJoinLocal, traversalStep.RightNodeJoinCondition),
},
}},
})

Expand Down Expand Up @@ -138,6 +139,96 @@ func (s *Translator) buildDirectionlessTraversalPatternRoot(traversalStep *Trave
}, nil
}

// buildTraversalPatternRootWithOuterCorrelation constructs a traversal pattern root, preserving the correlation to
// the outer query part's context
func (s *Translator) buildTraversalPatternRootWithOuterCorrelation(partFrame *Frame, traversalStep *TraversalStep) (pgsql.Query, error) {
if traversalStep.Direction == graph.DirectionBoth {
return s.buildDirectionlessTraversalPatternRoot(traversalStep)
}

var (
// Partition right-node constraints: only locally-scoped terms go into JOIN ON.
// Constraints that reference comma-connected CTEs (e.g. s0.i0 from a prior WITH)
// must remain in WHERE — they are out of scope inside an explicit JOIN chain.
rightJoinLocal, rightJoinExternal = partitionConstraintByLocality(
traversalStep.RightNodeConstraints,
pgsql.AsIdentifierSet(traversalStep.RightNode.Identifier, traversalStep.Edge.Identifier),
)

nextSelect = pgsql.Select{
Projection: traversalStep.Projection,
}
)

if traversalStep.LeftNodeBound {
nextSelect.From = append(nextSelect.From, pgsql.FromClause{
Source: pgsql.TableReference{
Name: pgsql.CompoundIdentifier{pgsql.TableEdge},
Binding: models.OptionalValue(traversalStep.Edge.Identifier),
},
Joins: []pgsql.Join{{
Table: pgsql.TableReference{
Name: pgsql.CompoundIdentifier{pgsql.TableNode},
Binding: models.OptionalValue(traversalStep.RightNode.Identifier),
},
JoinOperator: pgsql.JoinOperator{
JoinType: pgsql.JoinTypeInner,
Constraint: pgsql.OptionalAnd(rightJoinLocal, traversalStep.RightNodeJoinCondition),
},
}},
})

nextSelect.Where = pgsql.OptionalAnd(traversalStep.LeftNodeConstraints, nextSelect.Where)
nextSelect.Where = pgsql.OptionalAnd(traversalStep.LeftNodeJoinCondition, nextSelect.Where)
nextSelect.Where = pgsql.OptionalAnd(traversalStep.EdgeConstraints.Expression, nextSelect.Where)
nextSelect.Where = pgsql.OptionalAnd(rightJoinExternal, nextSelect.Where)

return pgsql.Query{
Body: nextSelect,
}, nil
} else if traversalStep.RightNodeBound {
// Right node was already materialized in a previous frame.
//
// We have to promote that frame to the explicit JOIN root so that RightNodeJoinCondition can reference
// it in the ON clause. PostgreSQL forbids referencing a comma-joined table inside a subsequent
// explicit JOIN's ON clause.
leftJoinLocal, leftJoinExternal := partitionConstraintByLocality(
traversalStep.LeftNodeConstraints,
pgsql.AsIdentifierSet(traversalStep.LeftNode.Identifier, traversalStep.Edge.Identifier),
)

nextSelect.From = append(nextSelect.From, pgsql.FromClause{
Source: pgsql.TableReference{
Name: pgsql.CompoundIdentifier{pgsql.TableEdge},
Binding: models.OptionalValue(traversalStep.Edge.Identifier),
},
Joins: []pgsql.Join{{
Table: pgsql.TableReference{
Name: pgsql.CompoundIdentifier{pgsql.TableNode},
Binding: models.OptionalValue(traversalStep.LeftNode.Identifier),
},
JoinOperator: pgsql.JoinOperator{
JoinType: pgsql.JoinTypeInner,
Constraint: pgsql.OptionalAnd(leftJoinLocal, traversalStep.LeftNodeJoinCondition),
},
}},
})

nextSelect.Where = pgsql.OptionalAnd(rightJoinLocal, nextSelect.Where)
nextSelect.Where = pgsql.OptionalAnd(traversalStep.RightNodeJoinCondition, nextSelect.Where)
nextSelect.Where = pgsql.OptionalAnd(leftJoinExternal, nextSelect.Where)
nextSelect.Where = pgsql.OptionalAnd(traversalStep.EdgeConstraints.Expression, nextSelect.Where)
nextSelect.Where = pgsql.OptionalAnd(rightJoinExternal, nextSelect.Where)

return pgsql.Query{
Body: nextSelect,
}, nil
} else {
// There is nothing to do to preserve outer bounds correlation - do the unbound traversal step
return s.buildTraversalPatternRoot(partFrame, traversalStep)
}
}

func (s *Translator) buildTraversalPatternRoot(partFrame *Frame, traversalStep *TraversalStep) (pgsql.Query, error) {
if traversalStep.Direction == graph.DirectionBoth {
return s.buildDirectionlessTraversalPatternRoot(traversalStep)
Expand Down
34 changes: 34 additions & 0 deletions integration/testdata/bed6695.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
{
"graph": {
"nodes": [
{
"id": "key_admins_empty",
"kinds": ["NodeKind1"],
"properties": {
"name": "BED-6695 KEY ADMINS EMPTY"
}
},
{
"id": "key_admins_membered",
"kinds": ["NodeKind1"],
"properties": {
"name": "BED-6695 KEY ADMINS MEMBERED"
}
},
{
"id": "member_user",
"kinds": ["NodeKind2"],
"properties": {
"name": "BED-6695 MEMBER USER"
}
}
],
"edges": [
{
"start_id": "member_user",
"end_id": "key_admins_membered",
"kind": "EdgeKind1"
}
]
}
}
Loading
Loading