diff --git a/cypher/models/cypher/format/format.go b/cypher/models/cypher/format/format.go index 495cf806..f2f62c15 100644 --- a/cypher/models/cypher/format/format.go +++ b/cypher/models/cypher/format/format.go @@ -296,7 +296,7 @@ func (s Emitter) formatProjection(output io.Writer, projection *cypher.Projectio } func (s Emitter) formatReturn(output io.Writer, returnClause *cypher.Return) error { - if _, err := io.WriteString(output, " return "); err != nil { + if _, err := io.WriteString(output, "return "); err != nil { return err } @@ -1095,6 +1095,12 @@ func (s Emitter) formatSinglePartQuery(writer io.Writer, singlePartQuery *cypher } if singlePartQuery.Return != nil { + if len(singlePartQuery.ReadingClauses) > 0 || len(singlePartQuery.UpdatingClauses) > 0 { + if _, err := io.WriteString(writer, " "); err != nil { + return err + } + } + return s.formatReturn(writer, singlePartQuery.Return) } diff --git a/cypher/models/pgsql/test/query_test.go b/cypher/models/pgsql/test/query_test.go index f43979bc..58fbade9 100644 --- a/cypher/models/pgsql/test/query_test.go +++ b/cypher/models/pgsql/test/query_test.go @@ -5,12 +5,11 @@ import ( "slices" "testing" - "github.com/specterops/dawgs/cypher/models/cypher" "github.com/specterops/dawgs/cypher/models/pgsql" "github.com/specterops/dawgs/cypher/models/pgsql/translate" "github.com/specterops/dawgs/cypher/models/walk" "github.com/specterops/dawgs/graph" - "github.com/specterops/dawgs/query" + v2 "github.com/specterops/dawgs/query/v2" ) var ( @@ -24,19 +23,18 @@ var ( func TestQuery_KindGeneratesInclusiveKindMatcher(t *testing.T) { mapper := newKindMapper() - queries := []*cypher.Where{ - query.Where(query.KindIn(query.Node(), NodeKind1)), - query.Where(query.Kind(query.Node(), NodeKind2)), + queries := []v2.QueryBuilder{ + v2.New().Where(v2.KindIn(v2.Node(), NodeKind1)).Return(v2.Node()), + v2.New().Where(v2.Kind(v2.Node(), NodeKind2)).Return(v2.Node()), } - for _, nodeQuery := range queries { - builder := query.NewBuilderWithCriteria(nodeQuery) - builtQuery, err := builder.Build(false) + for _, queryBuilder := range queries { + builtQuery, err := queryBuilder.Build() if err != nil { t.Errorf("could not build query: %v", err) } - translatedQuery, err := translate.Translate(context.Background(), builtQuery, mapper, nil) + translatedQuery, err := translate.Translate(context.Background(), builtQuery.Query, mapper, builtQuery.Parameters) if err != nil { t.Errorf("could not translate query: %#v: %v", builtQuery, err) } @@ -47,7 +45,7 @@ func TestQuery_KindGeneratesInclusiveKindMatcher(t *testing.T) { switch leftTyped := typedNode.LOperand.(type) { case pgsql.CompoundIdentifier: if slices.Equal(leftTyped, pgsql.AsCompoundIdentifier("n0", "kind_ids")) && typedNode.Operator != pgsql.OperatorPGArrayOverlap { - t.Errorf("query did not generate an array overlap operator (&&): %#v", nodeQuery) + t.Errorf("query did not generate an array overlap operator (&&): %#v", builtQuery) } } } diff --git a/cypher/models/pgsql/translate/translator.go b/cypher/models/pgsql/translate/translator.go index 3b528cad..2b09f95e 100644 --- a/cypher/models/pgsql/translate/translator.go +++ b/cypher/models/pgsql/translate/translator.go @@ -222,6 +222,9 @@ func (s *Translator) Exit(expression cypher.SyntaxNode) { s.SetError(err) } + case *cypher.Create: + s.SetErrorf("pgsql translator does not support create clauses") + case *cypher.Delete: if err := s.translateDelete(s.scope, typedExpression); err != nil { s.SetError(err) diff --git a/query/neo4j/neo4j.go b/query/neo4j/neo4j.go index 0689f74f..7299d004 100644 --- a/query/neo4j/neo4j.go +++ b/query/neo4j/neo4j.go @@ -53,6 +53,14 @@ func (s *QueryBuilder) rewriteParameters() error { return nil } +func hasPreparedMatchPattern(readingClause *cypher.ReadingClause) bool { + if readingClause == nil || readingClause.Match == nil { + return false + } + + return len(readingClause.Match.Pattern) > 0 +} + func (s *QueryBuilder) Apply(criteria graph.Criteria) { switch typedCriteria := criteria.(type) { case *cypher.Where: @@ -201,6 +209,10 @@ func (s *QueryBuilder) prepareMatch() error { return ErrAmbiguousQueryVariables } + if firstReadingClause := query.GetFirstReadingClause(s.query); hasPreparedMatchPattern(firstReadingClause) { + return nil + } + if singleNodeBound && !creatingSingleNode { patternPart.AddPatternElements(&cypher.NodePattern{ Variable: cypher.NewVariableWithSymbol(query.NodeSymbol), diff --git a/query/v2/backend_test.go b/query/v2/backend_test.go new file mode 100644 index 00000000..ba64ce36 --- /dev/null +++ b/query/v2/backend_test.go @@ -0,0 +1,283 @@ +package v2_test + +import ( + "context" + "strings" + "testing" + + "github.com/specterops/dawgs/cypher/models/pgsql/translate" + "github.com/specterops/dawgs/drivers/pg/pgutil" + "github.com/specterops/dawgs/graph" + "github.com/specterops/dawgs/query/neo4j" + v2 "github.com/specterops/dawgs/query/v2" + "github.com/stretchr/testify/require" +) + +func testKindMapper(kinds ...graph.Kind) *pgutil.InMemoryKindMapper { + mapper := pgutil.NewInMemoryKindMapper() + + for _, kind := range kinds { + mapper.Put(kind) + } + + return mapper +} + +func TestBackendParityNeo4jPrepare(t *testing.T) { + cases := map[string]struct { + builder v2.QueryBuilder + expectedCypher string + expectedParams map[string]any + }{ + "node read": { + builder: v2.New().Where( + v2.Node().Kinds().Has(graph.StringKind("User")), + v2.Node().Property("name").Contains("admin"), + ).Return( + v2.Node(), + ).OrderBy( + v2.Node().Property("name"), + ), + expectedCypher: "match (n) where n:User and n.name contains $p0 return n order by n.name asc", + expectedParams: map[string]any{"p0": "admin"}, + }, + "relationship read": { + builder: v2.New().Where( + v2.Relationship().Kind().Is(graph.StringKind("MemberOf")), + v2.Start().ID().Equals(1), + ).Return( + v2.Start().ID(), + v2.Relationship().ID(), + v2.End().ID(), + ), + expectedCypher: "match (s)-[r:MemberOf]->(e) where id(s) = $p0 return id(s), id(r), id(e)", + expectedParams: map[string]any{"p0": 1}, + }, + "shortest path": { + builder: v2.New().WithShortestPaths().Where( + v2.Relationship().Kind().Is(graph.StringKind("MemberOf")), + v2.Start().ID().Equals(1), + v2.End().ID().Equals(2), + ).Return( + v2.Path(), + ), + expectedCypher: "match p = shortestPath((s)-[r:MemberOf*]->(e)) where id(s) = $p0 and id(e) = $p1 return p", + expectedParams: map[string]any{"p0": 1, "p1": 2}, + }, + "all shortest paths": { + builder: v2.New().WithAllShortestPaths().Where( + v2.Relationship().Kind().Is(graph.StringKind("MemberOf")), + v2.Start().ID().Equals(1), + v2.End().ID().Equals(2), + ).Return( + v2.Path(), + ), + expectedCypher: "match p = allShortestPaths((s)-[r:MemberOf*]->(e)) where id(s) = $p0 and id(e) = $p1 return p", + expectedParams: map[string]any{"p0": 1, "p1": 2}, + }, + "create node": { + builder: v2.New().Create( + v2.NodePattern(graph.Kinds{graph.StringKind("User")}, v2.NamedParameter("props", map[string]any{"name": "u"})), + ).Return( + v2.Node().ID(), + ), + expectedCypher: "create (n:User $p0) return id(n)", + expectedParams: map[string]any{"p0": map[string]any{"name": "u"}}, + }, + "update node": { + builder: v2.New().Where( + v2.Node().ID().Equals(1), + ).Update( + v2.SetProperty(v2.Node().Property("name"), "updated"), + ), + expectedCypher: "match (n) where id(n) = $p0 set n.name = $p1", + expectedParams: map[string]any{"p0": 1, "p1": "updated"}, + }, + "delete relationship": { + builder: v2.New().Where( + v2.Relationship().ID().Equals(1), + ).Delete( + v2.Relationship(), + ), + expectedCypher: "match ()-[r]->() where id(r) = $p0 delete r", + expectedParams: map[string]any{"p0": 1}, + }, + "delete node": { + builder: v2.New().Where( + v2.Node().ID().Equals(1), + ).Delete( + v2.Node(), + ), + expectedCypher: "match (n) where id(n) = $p0 detach delete n", + expectedParams: map[string]any{"p0": 1}, + }, + } + + for name, testCase := range cases { + t.Run(name, func(t *testing.T) { + preparedQuery, err := testCase.builder.Build() + require.NoError(t, err) + + queryBuilder := neo4j.NewQueryBuilder(preparedQuery.Query) + require.NoError(t, queryBuilder.Prepare()) + + rendered, err := queryBuilder.Render() + require.NoError(t, err) + require.Equal(t, testCase.expectedCypher, rendered) + require.Equal(t, testCase.expectedParams, queryBuilder.Parameters) + }) + } +} + +func TestBackendParityPGTranslate(t *testing.T) { + userKind := graph.StringKind("User") + edgeKind := graph.StringKind("MemberOf") + mapper := testKindMapper(userKind, edgeKind) + + cases := map[string]struct { + builder v2.QueryBuilder + expectedSQL string + expectedParams map[string]any + }{ + "node read": { + builder: v2.New().Where( + v2.Node().Kinds().Has(userKind), + v2.Node().Property("name").Contains("admin"), + ).Return( + v2.Node().ID(), + v2.Node().Kinds(), + ), + expectedSQL: "with s0 as (select (n0.id, n0.kind_ids, n0.properties)::nodecomposite as n0 from node n0 where (n0.kind_ids operator (pg_catalog.&&) array [1]::int2[] and (n0.properties ->> 'name') like '%' || @pi0::text || '%')) select (s0.n0).id, (s0.n0).kind_ids from s0;", + expectedParams: map[string]any{"p0": "admin", "pi0": "admin"}, + }, + "relationship read": { + builder: v2.New().Where( + v2.Relationship().Kind().Is(edgeKind), + v2.Start().ID().Equals(1), + ).Return( + v2.Start().ID(), + v2.Relationship().ID(), + v2.End().ID(), + ), + expectedSQL: "with s0 as (select (e0.id, e0.start_id, e0.end_id, e0.kind_id, e0.properties)::edgecomposite as e0, (n0.id, n0.kind_ids, n0.properties)::nodecomposite as n0, (n1.id, n1.kind_ids, n1.properties)::nodecomposite as n1 from edge e0 join node n0 on (n0.id = @pi0::int8) and n0.id = e0.start_id join node n1 on n1.id = e0.end_id where e0.kind_id = any (array [2]::int2[])) select (s0.n0).id, (s0.e0).id, (s0.n1).id from s0;", + expectedParams: map[string]any{"p0": 1, "pi0": 1}, + }, + "update node": { + builder: v2.New().Where( + v2.Node().ID().Equals(1), + ).Update( + v2.SetProperty(v2.Node().Property("name"), "updated"), + ), + expectedSQL: "with s0 as (select (n0.id, n0.kind_ids, n0.properties)::nodecomposite as n0 from node n0 where (n0.id = @pi0::int8)), s1 as (update node n1 set properties = n1.properties || jsonb_build_object('name', @pi1::text)::jsonb from s0 where (s0.n0).id = n1.id returning (n1.id, n1.kind_ids, n1.properties)::nodecomposite as n0) select 1;", + expectedParams: map[string]any{"p0": 1, "p1": "updated", "pi0": 1, "pi1": "updated"}, + }, + "delete relationship": { + builder: v2.New().Where( + v2.Relationship().ID().Equals(1), + ).Delete( + v2.Relationship(), + ), + expectedSQL: "with s0 as (select (e0.id, e0.start_id, e0.end_id, e0.kind_id, e0.properties)::edgecomposite as e0, (n0.id, n0.kind_ids, n0.properties)::nodecomposite as n0, (n1.id, n1.kind_ids, n1.properties)::nodecomposite as n1 from edge e0 join node n0 on n0.id = e0.start_id join node n1 on n1.id = e0.end_id where (e0.id = @pi0::int8)), s1 as (delete from edge e1 using s0 where (s0.e0).id = e1.id) select 1;", + expectedParams: map[string]any{"p0": 1, "pi0": 1}, + }, + "delete node": { + builder: v2.New().Where( + v2.Node().ID().Equals(1), + ).Delete( + v2.Node(), + ), + expectedSQL: "with s0 as (select (n0.id, n0.kind_ids, n0.properties)::nodecomposite as n0 from node n0 where (n0.id = @pi0::int8)), s1 as (delete from node n1 using s0 where (s0.n0).id = n1.id) select 1;", + expectedParams: map[string]any{"p0": 1, "pi0": 1}, + }, + } + + for name, testCase := range cases { + t.Run(name, func(t *testing.T) { + preparedQuery, err := testCase.builder.Build() + require.NoError(t, err) + + translation, err := translate.Translate(context.Background(), preparedQuery.Query, mapper, preparedQuery.Parameters) + require.NoError(t, err) + + sql, err := translate.Translated(translation) + require.NoError(t, err) + require.Equal(t, testCase.expectedSQL, sql) + require.Equal(t, testCase.expectedParams, translation.Parameters) + }) + } +} + +func TestBackendParityPGTranslateShortestPaths(t *testing.T) { + edgeKind := graph.StringKind("MemberOf") + mapper := testKindMapper(edgeKind) + + cases := map[string]struct { + builder v2.QueryBuilder + expectedHarness string + }{ + "shortest path": { + builder: v2.New().WithShortestPaths().Where( + v2.Relationship().Kind().Is(edgeKind), + v2.Start().ID().Equals(1), + v2.End().ID().Equals(2), + ).Return( + v2.Path(), + ), + expectedHarness: "unidirectional_sp_harness", + }, + "all shortest paths": { + builder: v2.New().WithAllShortestPaths().Where( + v2.Relationship().Kind().Is(edgeKind), + v2.Start().ID().Equals(1), + v2.End().ID().Equals(2), + ).Return( + v2.Path(), + ), + expectedHarness: "bidirectional_asp_harness", + }, + } + + for name, testCase := range cases { + t.Run(name, func(t *testing.T) { + preparedQuery, err := testCase.builder.Build() + require.NoError(t, err) + + translation, err := translate.Translate(context.Background(), preparedQuery.Query, mapper, preparedQuery.Parameters) + require.NoError(t, err) + + sql, err := translate.Translated(translation) + require.NoError(t, err) + require.Contains(t, sql, testCase.expectedHarness) + require.Contains(t, sql, "edges_to_path") + require.Equal(t, 1, translation.Parameters["p0"]) + require.Equal(t, 2, translation.Parameters["p1"]) + + serializedHarnessQueryHasKindConstraint := false + for _, parameterValue := range translation.Parameters { + if serializedQuery, typeOK := parameterValue.(string); typeOK && strings.Contains(serializedQuery, "array [1]::int2[]") { + serializedHarnessQueryHasKindConstraint = true + break + } + } + require.True(t, serializedHarnessQueryHasKindConstraint, "expected serialized shortest-path harness query to contain edge kind constraint: %#v", translation.Parameters) + }) + } +} + +func TestBackendParityPGCreateUnsupported(t *testing.T) { + edgeKind := graph.StringKind("MemberOf") + mapper := testKindMapper(edgeKind) + + preparedQuery, err := v2.New().Where( + v2.Start().ID().Equals(1), + v2.End().ID().Equals(2), + ).Create( + v2.RelationshipPattern(edgeKind, nil, graph.DirectionOutbound), + ).Return( + v2.Relationship().ID(), + ).Build() + require.NoError(t, err) + + _, err = translate.Translate(context.Background(), preparedQuery.Query, mapper, preparedQuery.Parameters) + require.ErrorContains(t, err, "pgsql translator does not support create clauses") +} diff --git a/query/v2/compat.go b/query/v2/compat.go new file mode 100644 index 00000000..fb6be7b3 --- /dev/null +++ b/query/v2/compat.go @@ -0,0 +1,303 @@ +package v2 + +import ( + "fmt" + "strings" + "time" + + "github.com/specterops/dawgs/cypher/models/cypher" + "github.com/specterops/dawgs/graph" +) + +func Variable(name string) *cypher.Variable { + return cypher.NewVariableWithSymbol(name) +} + +func Identity(reference any) *cypher.FunctionInvocation { + return cypher.NewSimpleFunctionInvocation(cypher.IdentityFunction, expressionOrError(reference)) +} + +func NodeID() *cypher.FunctionInvocation { + return Identity(Identifiers.Node()) +} + +func RelationshipID() *cypher.FunctionInvocation { + return Identity(Identifiers.Relationship()) +} + +func StartID() *cypher.FunctionInvocation { + return Identity(Identifiers.Start()) +} + +func EndID() *cypher.FunctionInvocation { + return Identity(Identifiers.End()) +} + +func Count(reference any) *cypher.FunctionInvocation { + return cypher.NewSimpleFunctionInvocation(cypher.CountFunction, expressionOrError(reference)) +} + +func CountDistinct(reference any) *cypher.FunctionInvocation { + return &cypher.FunctionInvocation{ + Name: cypher.CountFunction, + Distinct: true, + Arguments: []cypher.Expression{expressionOrError(reference)}, + } +} + +func Size(expression any) *cypher.FunctionInvocation { + return cypher.NewSimpleFunctionInvocation(cypher.ListSizeFunction, expressionOrError(expression)) +} + +func KindsOf(reference any) *cypher.FunctionInvocation { + if scopedReference, typeOK := reference.(scopedExpression); typeOK { + if variable, typeOK := scopedReference.qualifier().(*cypher.Variable); !typeOK { + return invalidExpression(fmt.Errorf("expected variable reference, got %T", scopedReference.qualifier())) + } else if expression, err := kindProjectionExpression(scopedReference.roleName(), variable); err != nil { + return invalidExpression(err) + } else if invocation, typeOK := expression.(*cypher.FunctionInvocation); !typeOK { + return invalidExpression(fmt.Errorf("expected kind projection function, got %T", expression)) + } else { + return invocation + } + } + + expression := expressionOrError(reference) + + switch typedExpression := expression.(type) { + case *cypher.Variable: + switch typedExpression.Symbol { + case Identifiers.node, Identifiers.start, Identifiers.end: + return cypher.NewSimpleFunctionInvocation(cypher.NodeLabelsFunction, typedExpression) + + case Identifiers.relationship: + return cypher.NewSimpleFunctionInvocation(cypher.EdgeTypeFunction, typedExpression) + } + } + + return cypher.NewSimpleFunctionInvocation(cypher.NodeLabelsFunction, expression) +} + +func Kind(reference any, kinds ...graph.Kind) *cypher.KindMatcher { + return &cypher.KindMatcher{ + Reference: expressionOrError(reference), + Kinds: kinds, + } +} + +func KindIn(reference any, kinds ...graph.Kind) *cypher.KindMatcher { + return Kind(reference, kinds...) +} + +func AddKind(reference any, kind graph.Kind) *cypher.SetItem { + return AddKinds(reference, graph.Kinds{kind}) +} + +func AddKinds(reference any, kinds graph.Kinds) *cypher.SetItem { + return cypher.NewSetItem(expressionOrError(reference), cypher.OperatorLabelAssignment, kinds) +} + +func DeleteKind(reference any, kind graph.Kind) *cypher.RemoveItem { + return DeleteKinds(reference, graph.Kinds{kind}) +} + +func DeleteKinds(reference any, kinds graph.Kinds) *cypher.RemoveItem { + return cypher.RemoveKindsByMatcher(cypher.NewKindMatcher(expressionOrError(reference), kinds, false)) +} + +func SetProperty(reference any, value any) *cypher.SetItem { + return cypher.NewSetItem(expressionOrError(reference), cypher.OperatorAssignment, valueExpression(value)) +} + +func SetProperties(reference any, properties map[string]any) *cypher.Set { + set := &cypher.Set{} + + for _, key := range sortedPropertyKeys(properties) { + set.Items = append(set.Items, cypher.NewSetItem( + propertyLookupOrError(reference, key), + cypher.OperatorAssignment, + valueExpression(properties[key]), + )) + } + + return set +} + +func DeleteProperty(reference any) *cypher.RemoveItem { + return cypher.RemoveProperty(expressionOrError(reference)) +} + +func DeleteProperties(reference any, propertyNames ...string) *cypher.Remove { + remove := &cypher.Remove{} + + for _, propertyName := range propertyNames { + remove.Items = append(remove.Items, cypher.RemoveProperty(propertyLookupOrError(reference, propertyName))) + } + + return remove +} + +func NodePattern(kinds graph.Kinds, properties cypher.Expression) *cypher.NodePattern { + return &cypher.NodePattern{ + Variable: Identifiers.Node(), + Kinds: kinds, + Properties: properties, + } +} + +func StartNodePattern(kinds graph.Kinds, properties cypher.Expression) *cypher.NodePattern { + return &cypher.NodePattern{ + Variable: Identifiers.Start(), + Kinds: kinds, + Properties: properties, + } +} + +func EndNodePattern(kinds graph.Kinds, properties cypher.Expression) *cypher.NodePattern { + return &cypher.NodePattern{ + Variable: Identifiers.End(), + Kinds: kinds, + Properties: properties, + } +} + +func RelationshipPattern(kind graph.Kind, properties cypher.Expression, direction graph.Direction) *cypher.RelationshipPattern { + return &cypher.RelationshipPattern{ + Variable: Identifiers.Relationship(), + Kinds: graph.Kinds{kind}, + Direction: direction, + Properties: properties, + } +} + +func Equals(reference any, value any) cypher.Expression { + return cypher.NewComparison(expressionOrError(reference), cypher.OperatorEquals, valueExpression(value)) +} + +func GreaterThan(reference any, value any) cypher.Expression { + return cypher.NewComparison(expressionOrError(reference), cypher.OperatorGreaterThan, valueExpression(value)) +} + +func After(reference any, value any) cypher.Expression { + return GreaterThan(reference, value) +} + +func GreaterThanOrEqualTo(reference any, value any) cypher.Expression { + return cypher.NewComparison(expressionOrError(reference), cypher.OperatorGreaterThanOrEqualTo, valueExpression(value)) +} + +func GreaterThanOrEquals(reference any, value any) cypher.Expression { + return GreaterThanOrEqualTo(reference, value) +} + +func LessThan(reference any, value any) cypher.Expression { + return cypher.NewComparison(expressionOrError(reference), cypher.OperatorLessThan, valueExpression(value)) +} + +func LessThanGraphQuery(reference any, other any) cypher.Expression { + return LessThan(reference, other) +} + +func Before(reference any, value time.Time) cypher.Expression { + return LessThan(reference, value) +} + +func BeforeGraphQuery(reference any, other any) cypher.Expression { + return LessThan(reference, other) +} + +func LessThanOrEqualTo(reference any, value any) cypher.Expression { + return cypher.NewComparison(expressionOrError(reference), cypher.OperatorLessThanOrEqualTo, valueExpression(value)) +} + +func LessThanOrEquals(reference any, value any) cypher.Expression { + return LessThanOrEqualTo(reference, value) +} + +func In(reference any, value any) cypher.Expression { + return cypher.NewComparison(expressionOrError(reference), cypher.OperatorIn, valueExpression(value)) +} + +func InInverted(reference any, value any) cypher.Expression { + return cypher.NewComparison(valueExpression(value), cypher.OperatorIn, expressionOrError(reference)) +} + +func InIDs(reference any, ids ...graph.ID) cypher.Expression { + expression := expressionOrError(reference) + + if variable, typeOK := expression.(*cypher.Variable); typeOK { + expression = Identity(variable) + } + + return cypher.NewComparison(expression, cypher.OperatorIn, Parameter(ids)) +} + +func StringContains(reference any, value string) cypher.Expression { + return cypher.NewComparison(expressionOrError(reference), cypher.OperatorContains, Parameter(value)) +} + +func StringStartsWith(reference any, value string) cypher.Expression { + return cypher.NewComparison(expressionOrError(reference), cypher.OperatorStartsWith, Parameter(value)) +} + +func StringEndsWith(reference any, value string) cypher.Expression { + return cypher.NewComparison(expressionOrError(reference), cypher.OperatorEndsWith, Parameter(value)) +} + +func CaseInsensitiveStringContains(reference any, value string) cypher.Expression { + return cypher.NewComparison( + cypher.NewSimpleFunctionInvocation("toLower", expressionOrError(reference)), + cypher.OperatorContains, + Parameter(strings.ToLower(value)), + ) +} + +func CaseInsensitiveStringStartsWith(reference any, value string) cypher.Expression { + return cypher.NewComparison( + cypher.NewSimpleFunctionInvocation("toLower", expressionOrError(reference)), + cypher.OperatorStartsWith, + Parameter(strings.ToLower(value)), + ) +} + +func CaseInsensitiveStringEndsWith(reference any, value string) cypher.Expression { + return cypher.NewComparison( + cypher.NewSimpleFunctionInvocation("toLower", expressionOrError(reference)), + cypher.OperatorEndsWith, + Parameter(strings.ToLower(value)), + ) +} + +func Exists(reference any) cypher.Expression { + return IsNotNull(reference) +} + +func IsNull(reference any) cypher.Expression { + return cypher.NewComparison(expressionOrError(reference), cypher.OperatorIs, Literal(nil)) +} + +func IsNotNull(reference any) cypher.Expression { + return cypher.NewComparison(expressionOrError(reference), cypher.OperatorIsNot, Literal(nil)) +} + +func HasRelationships(reference any) *cypher.PatternPredicate { + patternPredicate := cypher.NewPatternPredicate() + + if variable, err := variableReference(reference); err != nil { + patternPredicate.AddElement(&cypher.NodePattern{ + Properties: invalidExpression(err), + }) + } else { + patternPredicate.AddElement(&cypher.NodePattern{ + Variable: cypher.NewVariableWithSymbol(variable.Symbol), + }) + } + + patternPredicate.AddElement(&cypher.RelationshipPattern{ + Direction: graph.DirectionBoth, + }) + + patternPredicate.AddElement(&cypher.NodePattern{}) + + return patternPredicate +} diff --git a/query/v2/doc.go b/query/v2/doc.go new file mode 100644 index 00000000..a21a5004 --- /dev/null +++ b/query/v2/doc.go @@ -0,0 +1,6 @@ +// Package v2 contains the experimental fluent Cypher query builder. +// +// It is intentionally isolated from the stable query package so callers can +// opt in without pulling the current graph query APIs through a compatibility +// layer. +package v2 diff --git a/query/v2/query.go b/query/v2/query.go new file mode 100644 index 00000000..a388140c --- /dev/null +++ b/query/v2/query.go @@ -0,0 +1,1220 @@ +package v2 + +import ( + "errors" + "fmt" + + "github.com/specterops/dawgs/cypher/models/cypher" + "github.com/specterops/dawgs/graph" +) + +type runtimeIdentifiers struct { + path string + node string + start string + relationship string + end string +} + +func (s runtimeIdentifiers) Path() *cypher.Variable { + return cypher.NewVariableWithSymbol(s.path) +} + +func (s runtimeIdentifiers) Node() *cypher.Variable { + return cypher.NewVariableWithSymbol(s.node) +} + +func (s runtimeIdentifiers) Start() *cypher.Variable { + return cypher.NewVariableWithSymbol(s.start) +} + +func (s runtimeIdentifiers) Relationship() *cypher.Variable { + return cypher.NewVariableWithSymbol(s.relationship) +} + +func (s runtimeIdentifiers) End() *cypher.Variable { + return cypher.NewVariableWithSymbol(s.end) +} + +var Identifiers = runtimeIdentifiers{ + path: "p", + node: "n", + start: "s", + relationship: "r", + end: "e", +} + +type Scope struct { + identifiers runtimeIdentifiers + errors []error +} + +func DefaultScope() Scope { + return Scope{ + identifiers: Identifiers, + } +} + +func NewScope(path, node, start, relationship, end string) Scope { + identifiers := runtimeIdentifiers{ + path: path, + node: node, + start: start, + relationship: relationship, + end: end, + } + + return Scope{ + identifiers: identifiers, + errors: validateRuntimeIdentifiers(identifiers), + } +} + +func validateRuntimeIdentifiers(identifiers runtimeIdentifiers) []error { + aliases := []struct { + role string + value string + }{ + {role: "path", value: identifiers.path}, + {role: "node", value: identifiers.node}, + {role: "start", value: identifiers.start}, + {role: "relationship", value: identifiers.relationship}, + {role: "end", value: identifiers.end}, + } + + var ( + errs []error + seen = map[string]string{} + ) + + for _, alias := range aliases { + if err := validateCypherSymbol(alias.value, "scope alias "+alias.role); err != nil { + errs = append(errs, err) + continue + } + + if existingRole, exists := seen[alias.value]; exists { + errs = append(errs, fmt.Errorf("scope aliases %s and %s both use %q", existingRole, alias.role, alias.value)) + } else { + seen[alias.value] = alias.role + } + } + + return errs +} + +func (s Scope) New() QueryBuilder { + return newBuilder(s.identifiers, s.errors...) +} + +func (s Scope) Node() NodeContinuation { + return &entity[NodeContinuation]{ + identifier: s.identifiers.Node(), + role: Identifiers.node, + } +} + +func (s Scope) Path() PathContinuation { + return &entity[PathContinuation]{ + identifier: s.identifiers.Path(), + role: Identifiers.path, + } +} + +func (s Scope) Start() NodeContinuation { + return &entity[NodeContinuation]{ + identifier: s.identifiers.Start(), + role: Identifiers.start, + } +} + +func (s Scope) Relationship() RelationshipContinuation { + return &entity[RelationshipContinuation]{ + identifier: s.identifiers.Relationship(), + role: Identifiers.relationship, + } +} + +func (s Scope) End() NodeContinuation { + return &entity[NodeContinuation]{ + identifier: s.identifiers.End(), + role: Identifiers.end, + } +} + +func Literal(value any) *cypher.Literal { + if value == nil { + return cypher.NewLiteral(nil, true) + } + + if strValue, typeOK := value.(string); typeOK { + return cypher.NewStringLiteral(strValue) + } + + return cypher.NewLiteral(value, false) +} + +func Parameter(value any) *cypher.Parameter { + if parameter, typeOK := value.(*cypher.Parameter); typeOK { + return parameter + } + + return &cypher.Parameter{ + Value: value, + } +} + +func NamedParameter(symbol string, value any) *cypher.Parameter { + return cypher.NewParameter(symbol, value) +} + +func valueExpression(value any) cypher.Expression { + switch typedValue := value.(type) { + case *cypher.Parameter: + return typedValue + case *cypher.Literal: + return typedValue + case *cypher.Variable: + return typedValue + case *cypher.PropertyLookup: + return typedValue + case *cypher.FunctionInvocation: + return typedValue + case *cypher.Parenthetical: + return typedValue + case *cypher.Comparison: + return typedValue + case *cypher.Negation: + return typedValue + case *cypher.Conjunction: + return typedValue + case *cypher.Disjunction: + return typedValue + case *cypher.ExclusiveDisjunction: + return typedValue + case *cypher.KindMatcher: + return typedValue + case *cypher.ListLiteral: + return typedValue + case cypher.MapLiteral: + return typedValue + case *cypher.PatternPredicate: + return typedValue + case *cypher.ArithmeticExpression: + return typedValue + case *cypher.UnaryAddOrSubtractExpression: + return typedValue + case *cypher.FilterExpression: + return typedValue + case *cypher.IDInCollection: + return typedValue + default: + return Parameter(value) + } +} + +func joinedExpressionList(operator cypher.Operator, operands []cypher.SyntaxNode) cypher.SyntaxNode { + expressionList := &cypher.Comparison{} + + if len(operands) > 0 { + expressionList.Left = operands[0] + + for _, operand := range operands[1:] { + expressionList.NewPartialComparison(operator, operand) + } + } + + return expressionList +} + +func Not(operand cypher.Expression) cypher.Expression { + return cypher.NewNegation(operand) +} + +func And(operands ...cypher.SyntaxNode) cypher.SyntaxNode { + return joinedExpressionList(cypher.OperatorAnd, operands) +} + +func Or(operands ...cypher.SyntaxNode) cypher.SyntaxNode { + return joinedExpressionList(cypher.OperatorOr, operands) +} + +type SortDirection int + +const ( + SortAscending SortDirection = iota + SortDescending +) + +func Asc(expression any) *cypher.SortItem { + return Order(expression, SortAscending) +} + +func Desc(expression any) *cypher.SortItem { + return Order(expression, SortDescending) +} + +func Order(expression any, direction SortDirection) *cypher.SortItem { + return &cypher.SortItem{ + Ascending: direction != SortDescending, + Expression: expressionOrError(expression), + } +} + +func As(expression any, alias string) *cypher.ProjectionItem { + return &cypher.ProjectionItem{ + Expression: expressionOrError(expression), + Alias: cypher.NewVariableWithSymbol(alias), + } +} + +func Node() NodeContinuation { + return DefaultScope().Node() +} + +func Path() PathContinuation { + return DefaultScope().Path() +} + +func Start() NodeContinuation { + return DefaultScope().Start() +} + +func Relationship() RelationshipContinuation { + return DefaultScope().Relationship() +} + +func End() NodeContinuation { + return DefaultScope().End() +} + +type QualifiedExpression interface { + qualifier() cypher.Expression +} + +type scopedExpression interface { + QualifiedExpression + + roleName() string +} + +type EntityContinuation interface { + QualifiedExpression + + Count() cypher.Expression + ID() IdentityContinuation + Property(name string) PropertyContinuation +} + +type KindContinuation interface { + Is(kind graph.Kind) cypher.Expression + IsOneOf(kinds graph.Kinds) cypher.Expression +} + +type KindsContinuation interface { + Has(kind graph.Kind) cypher.Expression + HasOneOf(kinds graph.Kinds) cypher.Expression + Add(kinds graph.Kinds) cypher.Expression + Remove(kinds graph.Kinds) cypher.Expression +} + +type Comparable interface { + In(value any) cypher.Expression + Contains(value any) cypher.Expression + StartsWith(value any) cypher.Expression + EndsWith(value any) cypher.Expression + Equals(value any) cypher.Expression + GreaterThan(value any) cypher.Expression + GreaterThanOrEqualTo(value any) cypher.Expression + LessThan(value any) cypher.Expression + LessThanOrEqualTo(value any) cypher.Expression + IsNull() cypher.Expression + IsNotNull() cypher.Expression +} + +type PropertyContinuation interface { + QualifiedExpression + Comparable + + Set(value any) *cypher.SetItem + Remove() *cypher.RemoveItem +} + +type IdentityContinuation interface { + QualifiedExpression + Comparable +} + +type comparisonContinuation struct { + qualifierExpression cypher.Expression +} + +func (s *comparisonContinuation) qualifier() cypher.Expression { + return s.qualifierExpression +} + +func (s *comparisonContinuation) asComparison(operator cypher.Operator, rOperand any) cypher.Expression { + return cypher.NewComparison( + s.qualifier(), + operator, + valueExpression(rOperand), + ) +} + +func (s *comparisonContinuation) In(value any) cypher.Expression { + return s.asComparison(cypher.OperatorIn, value) +} + +func (s *comparisonContinuation) Contains(value any) cypher.Expression { + return s.asComparison(cypher.OperatorContains, value) +} + +func (s *comparisonContinuation) StartsWith(value any) cypher.Expression { + return s.asComparison(cypher.OperatorStartsWith, value) +} + +func (s *comparisonContinuation) EndsWith(value any) cypher.Expression { + return s.asComparison(cypher.OperatorEndsWith, value) +} + +func (s *comparisonContinuation) Equals(value any) cypher.Expression { + return s.asComparison(cypher.OperatorEquals, value) +} + +func (s *comparisonContinuation) GreaterThan(value any) cypher.Expression { + return s.asComparison(cypher.OperatorGreaterThan, value) +} + +func (s *comparisonContinuation) GreaterThanOrEqualTo(value any) cypher.Expression { + return s.asComparison(cypher.OperatorGreaterThanOrEqualTo, value) +} + +func (s *comparisonContinuation) LessThan(value any) cypher.Expression { + return s.asComparison(cypher.OperatorLessThan, value) +} + +func (s *comparisonContinuation) LessThanOrEqualTo(value any) cypher.Expression { + return s.asComparison(cypher.OperatorLessThanOrEqualTo, value) +} + +func (s *comparisonContinuation) IsNull() cypher.Expression { + return cypher.NewComparison(s.qualifier(), cypher.OperatorIs, Literal(nil)) +} + +func (s *comparisonContinuation) IsNotNull() cypher.Expression { + return cypher.NewComparison(s.qualifier(), cypher.OperatorIsNot, Literal(nil)) +} + +type propertyContinuation struct { + comparisonContinuation +} + +func (s *propertyContinuation) Set(value any) *cypher.SetItem { + return cypher.NewSetItem( + s.qualifier(), + cypher.OperatorAssignment, + valueExpression(value), + ) +} + +func (s *propertyContinuation) Remove() *cypher.RemoveItem { + return cypher.RemoveProperty(s.qualifier()) +} + +type entity[T any] struct { + identifier *cypher.Variable + role string +} + +func (s *entity[T]) Kind() KindContinuation { + return kindContinuation{ + identifier: s.identifier, + role: s.role, + } +} + +func (s *entity[T]) Kinds() KindsContinuation { + return kindsContinuation{ + identifier: s.identifier, + role: s.role, + } +} + +func (s *entity[T]) Count() cypher.Expression { + return cypher.NewSimpleFunctionInvocation(cypher.CountFunction, s.identifier) +} + +func (s *entity[T]) SetProperties(properties map[string]any) cypher.Expression { + set := &cypher.Set{} + + for _, key := range sortedPropertyKeys(properties) { + set.Items = append(set.Items, s.Property(key).Set(properties[key])) + } + + return set +} + +func (s *entity[T]) RemoveProperties(properties []string) cypher.Expression { + remove := &cypher.Remove{} + + for _, key := range properties { + remove.Items = append(remove.Items, s.Property(key).Remove()) + } + + return remove +} + +func (s *entity[T]) RelationshipPattern(kind graph.Kind, properties cypher.Expression, direction graph.Direction) cypher.Expression { + return &cypher.RelationshipPattern{ + Variable: s.identifier, + Kinds: graph.Kinds{kind}, + Direction: direction, + Properties: properties, + } +} + +func (s *entity[T]) NodePattern(kinds graph.Kinds, properties cypher.Expression) cypher.Expression { + return &cypher.NodePattern{ + Variable: s.identifier, + Kinds: kinds, + Properties: properties, + } +} + +func (s *entity[T]) qualifier() cypher.Expression { + return s.identifier +} + +func (s *entity[T]) roleName() string { + return s.role +} + +func (s *entity[T]) ID() IdentityContinuation { + return &comparisonContinuation{ + qualifierExpression: &cypher.FunctionInvocation{ + Distinct: false, + Name: cypher.IdentityFunction, + Arguments: []cypher.Expression{s.identifier}, + }, + } +} + +func (s *entity[T]) Property(propertyName string) PropertyContinuation { + return &propertyContinuation{ + comparisonContinuation: comparisonContinuation{ + qualifierExpression: cypher.NewPropertyLookup(s.identifier.Symbol, propertyName), + }, + } +} + +type kindContinuation struct { + identifier *cypher.Variable + role string +} + +func (s kindContinuation) qualifier() cypher.Expression { + return s.identifier +} + +func (s kindContinuation) roleName() string { + return s.role +} + +func (s kindContinuation) Is(kind graph.Kind) cypher.Expression { + return s.IsOneOf(graph.Kinds{kind}) +} + +func (s kindContinuation) IsOneOf(kinds graph.Kinds) cypher.Expression { + return &cypher.KindMatcher{ + Reference: s.identifier, + Kinds: kinds, + } +} + +type kindsContinuation struct { + identifier *cypher.Variable + role string +} + +func (s kindsContinuation) qualifier() cypher.Expression { + return s.identifier +} + +func (s kindsContinuation) roleName() string { + return s.role +} + +func (s kindsContinuation) Has(kind graph.Kind) cypher.Expression { + return s.HasOneOf(graph.Kinds{kind}) +} + +func (s kindsContinuation) HasOneOf(kinds graph.Kinds) cypher.Expression { + return &cypher.KindMatcher{ + Reference: s.identifier, + Kinds: kinds, + } +} + +func (s kindsContinuation) Add(kinds graph.Kinds) cypher.Expression { + return cypher.NewSetItem( + s.identifier, + cypher.OperatorLabelAssignment, + kinds, + ) +} + +func (s kindsContinuation) Remove(kinds graph.Kinds) cypher.Expression { + return cypher.RemoveKindsByMatcher(cypher.NewKindMatcher(s.identifier, kinds, false)) +} + +type PathContinuation interface { + QualifiedExpression + + Count() cypher.Expression +} + +type RelationshipContinuation interface { + EntityContinuation + + RelationshipPattern(kind graph.Kind, properties cypher.Expression, direction graph.Direction) cypher.Expression + + Kind() KindContinuation + SetProperties(properties map[string]any) cypher.Expression + RemoveProperties(properties []string) cypher.Expression +} + +type NodeContinuation interface { + EntityContinuation + + NodePattern(kinds graph.Kinds, properties cypher.Expression) cypher.Expression + + Kinds() KindsContinuation + SetProperties(properties map[string]any) cypher.Expression + RemoveProperties(properties []string) cypher.Expression +} + +type QueryBuilder interface { + Where(constraints ...cypher.SyntaxNode) QueryBuilder + OrderBy(sortItems ...any) QueryBuilder + Skip(offset int) QueryBuilder + Limit(limit int) QueryBuilder + Return(projections ...any) QueryBuilder + ReturnDistinct(projections ...any) QueryBuilder + Update(updatingClauses ...any) QueryBuilder + Create(creationClauses ...any) QueryBuilder + Delete(expressions ...any) QueryBuilder + WithShortestPaths() QueryBuilder + WithAllShortestPaths() QueryBuilder + WithRelationshipDirection(direction graph.Direction) QueryBuilder + Build() (*PreparedQuery, error) +} + +type updatingClauseKind int + +const ( + updatingClauseSet updatingClauseKind = iota + updatingClauseRemove + updatingClauseDelete + updatingClauseCreate +) + +type pendingUpdatingClause struct { + kind updatingClauseKind + creates []any + setItems []*cypher.SetItem + removeItems []*cypher.RemoveItem + deleteItems []cypher.Expression + detach bool +} + +type builder struct { + errors []error + constraints []cypher.SyntaxNode + sortItems []any + projections []any + distinct bool + identifiers runtimeIdentifiers + updatingClauses []pendingUpdatingClause + creates []any + setItems []*cypher.SetItem + removeItems []*cypher.RemoveItem + deleteItems []cypher.Expression + detachDelete bool + relationshipDirection graph.Direction + shortestPathQuery bool + allShorestPathsQuery bool + skip *int + limit *int +} + +func New() QueryBuilder { + return DefaultScope().New() +} + +func newBuilder(identifiers runtimeIdentifiers, errs ...error) QueryBuilder { + return &builder{ + identifiers: identifiers, + errors: append([]error(nil), errs...), + relationshipDirection: graph.DirectionOutbound, + } +} + +func (s *builder) WithShortestPaths() QueryBuilder { + s.shortestPathQuery = true + return s +} + +func (s *builder) WithAllShortestPaths() QueryBuilder { + s.allShorestPathsQuery = true + return s +} + +func (s *builder) WithRelationshipDirection(direction graph.Direction) QueryBuilder { + if err := validateRelationshipDirection(direction); err != nil { + s.trackError(err) + } else { + s.relationshipDirection = direction + } + + return s +} + +func (s *builder) OrderBy(sortItems ...any) QueryBuilder { + s.sortItems = append(s.sortItems, sortItems...) + return s +} + +func (s *builder) Skip(skip int) QueryBuilder { + s.skip = &skip + return s +} + +func (s *builder) Limit(limit int) QueryBuilder { + s.limit = &limit + return s +} + +func (s *builder) Return(projections ...any) QueryBuilder { + s.projections = append(s.projections, projections...) + return s +} + +func (s *builder) ReturnDistinct(projections ...any) QueryBuilder { + s.distinct = true + s.projections = append(s.projections, projections...) + return s +} + +func (s *builder) appendSetItems(items ...*cypher.SetItem) { + if len(items) == 0 { + return + } + + lastClauseIdx := len(s.updatingClauses) - 1 + if lastClauseIdx >= 0 && s.updatingClauses[lastClauseIdx].kind == updatingClauseSet { + s.updatingClauses[lastClauseIdx].setItems = append(s.updatingClauses[lastClauseIdx].setItems, items...) + } else { + s.updatingClauses = append(s.updatingClauses, pendingUpdatingClause{ + kind: updatingClauseSet, + setItems: items, + }) + } +} + +func (s *builder) appendRemoveItems(items ...*cypher.RemoveItem) { + if len(items) == 0 { + return + } + + lastClauseIdx := len(s.updatingClauses) - 1 + if lastClauseIdx >= 0 && s.updatingClauses[lastClauseIdx].kind == updatingClauseRemove { + s.updatingClauses[lastClauseIdx].removeItems = append(s.updatingClauses[lastClauseIdx].removeItems, items...) + } else { + s.updatingClauses = append(s.updatingClauses, pendingUpdatingClause{ + kind: updatingClauseRemove, + removeItems: items, + }) + } +} + +func (s *builder) appendDeleteItems(detach bool, items ...cypher.Expression) { + if len(items) == 0 { + return + } + + lastClauseIdx := len(s.updatingClauses) - 1 + if lastClauseIdx >= 0 && s.updatingClauses[lastClauseIdx].kind == updatingClauseDelete { + s.updatingClauses[lastClauseIdx].detach = s.updatingClauses[lastClauseIdx].detach || detach + s.updatingClauses[lastClauseIdx].deleteItems = append(s.updatingClauses[lastClauseIdx].deleteItems, items...) + } else { + s.updatingClauses = append(s.updatingClauses, pendingUpdatingClause{ + kind: updatingClauseDelete, + deleteItems: items, + detach: detach, + }) + } +} + +func (s *builder) Create(creationClauses ...any) QueryBuilder { + s.creates = append(s.creates, creationClauses...) + + if len(creationClauses) > 0 { + s.updatingClauses = append(s.updatingClauses, pendingUpdatingClause{ + kind: updatingClauseCreate, + creates: creationClauses, + }) + } + + return s +} + +func (s *builder) Update(updates ...any) QueryBuilder { + for _, nextUpdate := range updates { + switch typedNextUpdate := nextUpdate.(type) { + case *cypher.Set: + if setItems, err := setItemsFromSet(typedNextUpdate); err != nil { + s.trackError(err) + } else { + s.setItems = append(s.setItems, setItems...) + s.appendSetItems(setItems...) + } + + case *cypher.SetItem: + if setItem, err := setItemFromValue(typedNextUpdate); err != nil { + s.trackError(err) + } else { + s.setItems = append(s.setItems, setItem) + s.appendSetItems(setItem) + } + + case *cypher.Remove: + if removeItems, err := removeItemsFromRemove(typedNextUpdate); err != nil { + s.trackError(err) + } else { + s.removeItems = append(s.removeItems, removeItems...) + s.appendRemoveItems(removeItems...) + } + + case *cypher.RemoveItem: + if removeItem, err := removeItemFromValue(typedNextUpdate); err != nil { + s.trackError(err) + } else { + s.removeItems = append(s.removeItems, removeItem) + s.appendRemoveItems(removeItem) + } + + default: + s.trackError(fmt.Errorf("unknown update type: %T", nextUpdate)) + } + } + + return s +} + +func (s *builder) Delete(deleteItems ...any) QueryBuilder { + var pendingDeleteItems []cypher.Expression + pendingDetachDelete := false + + for _, nextDelete := range deleteItems { + switch typedNextUpdate := nextDelete.(type) { + case QualifiedExpression: + qualifier := typedNextUpdate.qualifier() + if err := validateExpressionValue(qualifier, "delete expression"); err != nil { + s.trackError(err) + continue + } + + if isDetachDeleteQualifier(qualifier, s.identifiers) { + s.detachDelete = true + pendingDetachDelete = true + } + + s.deleteItems = append(s.deleteItems, qualifier) + pendingDeleteItems = append(pendingDeleteItems, qualifier) + + case *cypher.Variable: + if err := validateExpressionValue(typedNextUpdate, "delete expression"); err != nil { + s.trackError(err) + continue + } + + switch typedNextUpdate.Symbol { + case s.identifiers.node, s.identifiers.start, s.identifiers.end: + s.detachDelete = true + pendingDetachDelete = true + } + + s.deleteItems = append(s.deleteItems, typedNextUpdate) + pendingDeleteItems = append(pendingDeleteItems, typedNextUpdate) + + default: + s.trackError(fmt.Errorf("unknown delete type: %T", nextDelete)) + } + } + + s.appendDeleteItems(pendingDetachDelete, pendingDeleteItems...) + return s +} + +func (s *builder) trackError(err error) { + s.errors = append(s.errors, err) +} + +func (s *builder) Where(constraints ...cypher.SyntaxNode) QueryBuilder { + s.constraints = append(s.constraints, constraints...) + return s +} + +func buildCreates(singlePartQuery *cypher.SinglePartQuery, identifiers runtimeIdentifiers, creates []any) error { + if len(creates) == 0 { + return nil + } + + var ( + pattern = &cypher.PatternPart{} + createClause = &cypher.Create{ + Unique: false, + Pattern: []*cypher.PatternPart{pattern}, + } + ) + + for _, nextCreate := range creates { + switch typedNextCreate := nextCreate.(type) { + case QualifiedExpression: + switch typedExpression := typedNextCreate.qualifier().(type) { + case *cypher.Variable: + if typedExpression == nil { + return fmt.Errorf("invalid variable reference for create: ") + } + + switch typedExpression.Symbol { + case identifiers.node, identifiers.start, identifiers.end: + pattern.AddPatternElements(&cypher.NodePattern{ + Variable: cypher.NewVariableWithSymbol(typedExpression.Symbol), + }) + + default: + return fmt.Errorf("invalid variable reference for create: %s", typedExpression.Symbol) + } + + default: + return fmt.Errorf("invalid qualified expression for create: %T", typedExpression) + } + + case *cypher.NodePattern: + if err := validateNodePattern(typedNextCreate); err != nil { + return err + } + + pattern.AddPatternElements(typedNextCreate) + + case *cypher.RelationshipPattern: + if err := validateRelationshipPattern(typedNextCreate); err != nil { + return err + } + + pattern.AddPatternElements(&cypher.NodePattern{ + Variable: identifiers.Start(), + }) + + pattern.AddPatternElements(typedNextCreate) + + pattern.AddPatternElements(&cypher.NodePattern{ + Variable: identifiers.End(), + }) + + default: + return fmt.Errorf("invalid type for create: %T", nextCreate) + } + } + + singlePartQuery.UpdatingClauses = append(singlePartQuery.UpdatingClauses, cypher.NewUpdatingClause(createClause)) + return nil +} + +func (s *builder) buildUpdatingClauses(singlePartQuery *cypher.SinglePartQuery) error { + for _, updatingClause := range s.updatingClauses { + switch updatingClause.kind { + case updatingClauseSet: + singlePartQuery.UpdatingClauses = append(singlePartQuery.UpdatingClauses, cypher.NewUpdatingClause( + cypher.NewSet(updatingClause.setItems), + )) + + case updatingClauseRemove: + singlePartQuery.UpdatingClauses = append(singlePartQuery.UpdatingClauses, cypher.NewUpdatingClause( + cypher.NewRemove(updatingClause.removeItems), + )) + + case updatingClauseDelete: + singlePartQuery.UpdatingClauses = append(singlePartQuery.UpdatingClauses, cypher.NewUpdatingClause( + cypher.NewDelete( + updatingClause.detach, + updatingClause.deleteItems, + ), + )) + + case updatingClauseCreate: + if err := buildCreates(singlePartQuery, s.identifiers, updatingClause.creates); err != nil { + return err + } + } + } + + return nil +} + +func (s *builder) buildProjectionOrder() (*cypher.Order, error) { + var orderByNode *cypher.Order + + if len(s.sortItems) > 0 { + orderByNode = &cypher.Order{} + + for _, untypedSortItem := range s.sortItems { + switch typedSortItem := untypedSortItem.(type) { + case *cypher.Order: + if sortItems, err := sortItemsFromOrder(typedSortItem); err != nil { + return nil, err + } else { + orderByNode.Items = append(orderByNode.Items, sortItems...) + } + + case *cypher.SortItem: + if sortItem, err := sortItemFromValue(typedSortItem); err != nil { + return nil, err + } else { + orderByNode.Items = append(orderByNode.Items, sortItem) + } + + default: + if sortItem, err := sortItemFromValue(typedSortItem); err != nil { + return nil, err + } else { + orderByNode.Items = append(orderByNode.Items, sortItem) + } + } + } + } + + return orderByNode, nil +} + +func appendProjectionOrder(projection *cypher.Projection, sortItems ...*cypher.SortItem) { + if len(sortItems) == 0 { + return + } + + if projection.Order == nil { + projection.Order = &cypher.Order{} + } + + projection.Order.Items = append(projection.Order.Items, sortItems...) +} + +func applyReturnProjection(projection *cypher.Projection, returnClause *cypher.Return) error { + if projectionItems, err := projectionItemsFromReturn(returnClause); err != nil { + return err + } else { + projection.Distinct = projection.Distinct || returnClause.Projection.Distinct + projection.All = projection.All || returnClause.Projection.All + + for _, projectionItem := range projectionItems { + projection.AddItem(projectionItem) + } + } + + if returnClause.Projection.Order != nil { + if sortItems, err := sortItemsFromOrder(returnClause.Projection.Order); err != nil { + return err + } else { + appendProjectionOrder(projection, sortItems...) + } + } + + if returnClause.Projection.Skip != nil { + projection.Skip = returnClause.Projection.Skip + } + + if returnClause.Projection.Limit != nil { + projection.Limit = returnClause.Projection.Limit + } + + return nil +} + +func (s *builder) buildProjection(singlePartQuery *cypher.SinglePartQuery) error { + var ( + hasProjectedItems = len(s.projections) > 0 + hasSkip = s.skip != nil && *s.skip > 0 + hasLimit = s.limit != nil && *s.limit > 0 + requiresProjection = hasProjectedItems || hasSkip || hasLimit + ) + + if requiresProjection { + if !hasProjectedItems { + return fmt.Errorf("query expected projected items") + } + + projection := singlePartQuery.NewProjection(s.distinct) + + for _, nextProjection := range s.projections { + switch typedNextProjection := nextProjection.(type) { + case *cypher.Return: + if err := applyReturnProjection(projection, typedNextProjection); err != nil { + return err + } + + default: + if projectionItem, err := projectionItemFromValue(typedNextProjection); err != nil { + return err + } else { + projection.AddItem(projectionItem) + } + } + } + + if s.skip != nil && *s.skip > 0 { + projection.Skip = cypher.NewSkip(*s.skip) + } + + if s.limit != nil && *s.limit > 0 { + projection.Limit = cypher.NewLimit(*s.limit) + } + + if projectionOrder, err := s.buildProjectionOrder(); err != nil { + return err + } else if projectionOrder != nil { + appendProjectionOrder(projection, projectionOrder.Items...) + } + } + + return nil +} + +type PreparedQuery struct { + Query *cypher.RegularQuery + Parameters map[string]any +} + +func (s *builder) hasActions() bool { + return len(s.projections) > 0 || len(s.setItems) > 0 || len(s.removeItems) > 0 || len(s.creates) > 0 || len(s.deleteItems) > 0 +} + +func (s *builder) Build() (*PreparedQuery, error) { + if len(s.errors) > 0 { + return nil, errors.Join(s.errors...) + } + + if !s.hasActions() { + return nil, fmt.Errorf("query has no action specified") + } + + if err := collectModelErrorsFromKnownValues(s.constraints, s.setItems, s.removeItems, s.deleteItems, s.projections, s.sortItems); err != nil { + return nil, err + } + + var ( + regularQuery, singlePartQuery = cypher.NewRegularQueryWithSingleQuery() + match = &cypher.Match{} + readIdentifiers = newIdentifierSet() + relationshipKinds graph.Kinds + ) + + createScope, err := collectCreateScope(s.identifiers, s.creates...) + if err != nil { + return nil, err + } + + if err := s.buildUpdatingClauses(singlePartQuery); err != nil { + return nil, err + } + + if err := s.buildProjection(singlePartQuery); err != nil { + return nil, err + } + + if len(s.constraints) > 0 { + var ( + whereClause = match.NewWhere() + constraints = &cypher.Comparison{} + ) + + for _, nextConstraint := range s.constraints { + if err := collectModelErrorsFromKnownValues(nextConstraint); err != nil { + return nil, err + } + + switch typedNextConstraint := nextConstraint.(type) { + case *cypher.KindMatcher: + if identifier, typeOK := typedNextConstraint.Reference.(*cypher.Variable); !typeOK { + return nil, fmt.Errorf("expected type *cypher.Variable, got %T", typedNextConstraint) + } else if identifier.Symbol == s.identifiers.relationship { + relationshipKinds = relationshipKinds.Add(typedNextConstraint.Kinds...) + readIdentifiers.Add(s.identifiers.relationship) + continue + } + } + + if constraints.Left == nil { + constraints.Left = nextConstraint + } else { + constraints.NewPartialComparison(cypher.OperatorAnd, nextConstraint) + } + } + + if constraints.Left != nil { + whereClause.Add(constraints) + + if err := readIdentifiers.CollectFromExpression(whereClause); err != nil { + return nil, err + } + } + } + + actionIdentifiers, err := collectIdentifiersFromValues(s.setItems, s.removeItems, s.deleteItems, s.projections, s.sortItems) + if err != nil { + return nil, err + } + + actionIdentifiers.Remove(createScope.identifiers) + + matchIdentifiers := readIdentifiers.Clone() + matchIdentifiers.Or(actionIdentifiers) + + if len(s.constraints) > 0 || matchIdentifiers.Len() > 0 { + if isNodePattern(matchIdentifiers, s.identifiers) { + if err := prepareNodePattern(match, matchIdentifiers, s.identifiers); err != nil { + return nil, err + } + } else if createScope.createsRelationship && !matchIdentifiers.Contains(s.identifiers.relationship) { + if err := prepareCreateRelationshipMatch(match, matchIdentifiers, s.identifiers); err != nil { + return nil, err + } + } else if isRelationshipPattern(matchIdentifiers, s.identifiers) { + if err := prepareRelationshipPattern(match, matchIdentifiers, s.identifiers, relationshipKinds, s.relationshipDirection, s.shortestPathQuery, s.allShorestPathsQuery); err != nil { + return nil, err + } + } else { + return nil, fmt.Errorf("query has no node and relationship query identifiers specified") + } + } + + if len(match.Pattern) > 0 { + newReadingClause := cypher.NewReadingClause() + newReadingClause.Match = match + + singlePartQuery.ReadingClauses = append(singlePartQuery.ReadingClauses, newReadingClause) + } + + if err := collectModelErrors(regularQuery); err != nil { + return nil, err + } + + if parameters, err := materializeParameters(regularQuery); err != nil { + return nil, err + } else { + return &PreparedQuery{ + Query: regularQuery, + Parameters: parameters, + }, nil + } +} diff --git a/query/v2/query_test.go b/query/v2/query_test.go new file mode 100644 index 00000000..da3e80cd --- /dev/null +++ b/query/v2/query_test.go @@ -0,0 +1,517 @@ +package v2_test + +import ( + "testing" + + "github.com/specterops/dawgs/cypher/models/cypher" + "github.com/specterops/dawgs/cypher/models/cypher/format" + "github.com/specterops/dawgs/graph" + v2 "github.com/specterops/dawgs/query/v2" + "github.com/stretchr/testify/require" +) + +func renderPrepared(t *testing.T, preparedQuery *v2.PreparedQuery) string { + t.Helper() + + cypherQueryStr, err := format.RegularQuery(preparedQuery.Query, false) + require.NoError(t, err) + + return cypherQueryStr +} + +func TestQuery(t *testing.T) { + preparedQuery, err := v2.New().Where( + v2.Not(v2.Relationship().Kind().Is(graph.StringKind("test"))), + v2.Not(v2.Relationship().Kind().IsOneOf(graph.Kinds{graph.StringKind("A"), graph.StringKind("B")})), + v2.Relationship().Property("rel_prop").LessThanOrEqualTo(1234), + v2.Relationship().Property("other_prop").Equals(5678), + v2.Start().Kinds().HasOneOf(graph.Kinds{graph.StringKind("test")}), + ).Update( + v2.Start().Property("this_prop").Set(1234), + v2.End().Kinds().Remove(graph.Kinds{graph.StringKind("A"), graph.StringKind("B")}), + ).Delete( + v2.Start(), + ).Return( + v2.Relationship(), + v2.Start().Property("node_prop"), + ).Skip(10).Limit(10).Build() + require.NoError(t, err) + + cypherQueryStr, err := format.RegularQuery(preparedQuery.Query, false) + require.NoError(t, err) + + require.Equal(t, "match (s)-[r]->(e) where not r:test and not (r:A or r:B) and r.rel_prop <= $p0 and r.other_prop = $p1 and s:test set s.this_prop = $p2 remove e:A:B detach delete s return r, s.node_prop skip 10 limit 10", cypherQueryStr) + require.Equal(t, map[string]any{ + "p0": 1234, + "p1": 5678, + "p2": 1234, + }, preparedQuery.Parameters) + + preparedQuery, err = v2.New().Create( + v2.Node().NodePattern(graph.Kinds{graph.StringKind("A")}, cypher.NewParameter("props", map[string]any{})), + ).Build() + + require.NoError(t, err) + + cypherQueryStr, err = format.RegularQuery(preparedQuery.Query, false) + require.NoError(t, err) + + require.Equal(t, "create (n:A $props)", cypherQueryStr) + require.Equal(t, map[string]any{ + "props": map[string]any{}, + }, preparedQuery.Parameters) +} + +func TestCreateRelationshipWithMatchedEndpoints(t *testing.T) { + preparedQuery, err := v2.New().Where( + v2.Start().ID().Equals(1), + v2.End().ID().Equals(2), + ).Create( + v2.Relationship().RelationshipPattern(graph.StringKind("A"), v2.NamedParameter("props", map[string]any{"name": "rel"}), graph.DirectionOutbound), + ).Return( + v2.Relationship().ID(), + ).Build() + require.NoError(t, err) + + require.Equal(t, "match (s), (e) where id(s) = $p0 and id(e) = $p1 create (s)-[r:A $props]->(e) return id(r)", renderPrepared(t, preparedQuery)) + require.Equal(t, map[string]any{ + "p0": 1, + "p1": 2, + "props": map[string]any{"name": "rel"}, + }, preparedQuery.Parameters) +} + +func TestCreateNodeReturnDoesNotCreateMatch(t *testing.T) { + preparedQuery, err := v2.New().Create( + v2.Node().NodePattern(graph.Kinds{graph.StringKind("A")}, v2.NamedParameter("props", map[string]any{"name": "node"})), + ).Return( + v2.Node().ID(), + ).Build() + require.NoError(t, err) + + require.Equal(t, "create (n:A $props) return id(n)", renderPrepared(t, preparedQuery)) + require.Equal(t, map[string]any{ + "props": map[string]any{"name": "node"}, + }, preparedQuery.Parameters) +} + +func TestInvalidCreateQualifiedExpressionReturnsError(t *testing.T) { + _, err := v2.New().Create(v2.Node().Property("name")).Build() + require.ErrorContains(t, err, "invalid qualified expression for create: *cypher.PropertyLookup") +} + +func TestUpdatingClausesPreserveFluentOrder(t *testing.T) { + preparedQuery, err := v2.New().Create( + v2.NodePattern(graph.Kinds{graph.StringKind("User")}, nil), + ).Update( + v2.SetProperty(v2.Node().Property("name"), "created"), + ).Return( + v2.Node().Property("name"), + ).Build() + require.NoError(t, err) + + require.Equal(t, "create (n:User) set n.name = $p0 return n.name", renderPrepared(t, preparedQuery)) + require.Equal(t, map[string]any{ + "p0": "created", + }, preparedQuery.Parameters) + + preparedQuery, err = v2.New().Where( + v2.Node().ID().Equals(1), + ).Update( + v2.DeleteProperties(v2.Node(), "old"), + ).Update( + v2.SetProperty(v2.Node().Property("new"), "value"), + ).Return( + v2.Node(), + ).Build() + require.NoError(t, err) + + require.Equal(t, "match (n) where id(n) = $p0 remove n.old set n.new = $p1 return n", renderPrepared(t, preparedQuery)) + require.Equal(t, map[string]any{ + "p0": 1, + "p1": "value", + }, preparedQuery.Parameters) +} + +func TestScopedRelationshipPatternControls(t *testing.T) { + scope := v2.NewScope("path", "person", "source", "edge", "target") + + preparedQuery, err := scope.New().WithRelationshipDirection(graph.DirectionInbound).Where( + scope.Relationship().Kind().Is(graph.StringKind("MemberOf")), + scope.Start().ID().Equals(1), + ).Return( + scope.Relationship().Kind(), + scope.End().Kinds(), + ).Build() + require.NoError(t, err) + + require.Equal(t, "match (source)<-[edge:MemberOf]-(target) where id(source) = $p0 return type(edge), labels(target)", renderPrepared(t, preparedQuery)) + require.Equal(t, map[string]any{ + "p0": 1, + }, preparedQuery.Parameters) +} + +func TestScopedKindsOfCompatibilityHelper(t *testing.T) { + scope := v2.NewScope("path", "person", "source", "edge", "target") + + preparedQuery, err := scope.New().Return( + v2.KindsOf(scope.Relationship()), + v2.KindsOf(scope.End()), + ).Build() + require.NoError(t, err) + + require.Equal(t, "match ()-[edge]->(target) return type(edge), labels(target)", renderPrepared(t, preparedQuery)) +} + +func TestInvalidScopeAliasesReturnBuildErrors(t *testing.T) { + emptyAliasScope := v2.NewScope("", "node", "start", "relationship", "end") + _, err := emptyAliasScope.New().Return(emptyAliasScope.Node()).Build() + require.ErrorContains(t, err, "scope alias path is empty") + + duplicateAliasScope := v2.NewScope("path", "node", "node", "relationship", "end") + _, err = duplicateAliasScope.New().Return(duplicateAliasScope.Start()).Build() + require.ErrorContains(t, err, `scope aliases node and start both use "node"`) + + invalidAliasScope := v2.NewScope("path", "bad name", "start", "relationship", "end") + _, err = invalidAliasScope.New().Return(invalidAliasScope.Node()).Build() + require.ErrorContains(t, err, `scope alias node has invalid symbol "bad name"`) +} + +func TestInvalidRelationshipDirectionReturnsError(t *testing.T) { + _, err := v2.New().WithRelationshipDirection(graph.Direction(99)).Return(v2.Relationship()).Build() + require.ErrorContains(t, err, "unsupported relationship direction: invalid") + + _, err = v2.New().WithRelationshipDirection(graph.DirectionBoth).Return(v2.Relationship()).Build() + require.ErrorContains(t, err, "unsupported relationship direction: both") +} + +func TestShortestPathControls(t *testing.T) { + preparedQuery, err := v2.New().WithShortestPaths().Where( + v2.Relationship().Kind().Is(graph.StringKind("MemberOf")), + v2.Start().ID().Equals(1), + v2.End().ID().Equals(2), + ).Return( + v2.Path(), + ).Build() + require.NoError(t, err) + require.Equal(t, "match p = shortestPath((s)-[r:MemberOf*]->(e)) where id(s) = $p0 and id(e) = $p1 return p", renderPrepared(t, preparedQuery)) + require.Equal(t, map[string]any{ + "p0": 1, + "p1": 2, + }, preparedQuery.Parameters) + + preparedQuery, err = v2.New().WithAllShortestPaths().Where( + v2.Relationship().Kind().Is(graph.StringKind("MemberOf")), + v2.Start().ID().Equals(1), + v2.End().ID().Equals(2), + ).Return( + v2.Path(), + ).Build() + require.NoError(t, err) + require.Equal(t, "match p = allShortestPaths((s)-[r:MemberOf*]->(e)) where id(s) = $p0 and id(e) = $p1 return p", renderPrepared(t, preparedQuery)) + + _, err = v2.New().WithShortestPaths().WithAllShortestPaths().Where( + v2.Start().ID().Equals(1), + v2.End().ID().Equals(2), + ).Return( + v2.Path(), + ).Build() + require.ErrorContains(t, err, "query is requesting both all shortest paths and shortest paths") +} + +func TestMixedNodeAndRelationshipIdentifiersReturnError(t *testing.T) { + _, err := v2.New().Where( + v2.Node().ID().Equals(1), + v2.Relationship().ID().Equals(2), + ).Return( + v2.Node(), + ).Build() + require.ErrorContains(t, err, "query mixes node and relationship query identifiers") +} + +func TestInvalidExplicitRelationshipPatternDirectionReturnsError(t *testing.T) { + _, err := v2.New().Create( + v2.RelationshipPattern(graph.StringKind("Edge"), nil, graph.DirectionBoth), + ).Build() + require.ErrorContains(t, err, "unsupported relationship direction: both") + + _, err = v2.New().Create( + v2.Relationship().RelationshipPattern(graph.StringKind("Edge"), nil, graph.Direction(99)), + ).Build() + require.ErrorContains(t, err, "unsupported relationship direction: invalid") +} + +func TestProjectionAndOrderHelpers(t *testing.T) { + preparedQuery, err := v2.New().ReturnDistinct( + v2.As(v2.Node().ID(), "node_id"), + ).OrderBy( + v2.Node().Property("name"), + v2.Desc(v2.Node().ID()), + ).Build() + require.NoError(t, err) + + require.Equal(t, "match (n) return distinct id(n) as node_id order by n.name asc, id(n) desc", renderPrepared(t, preparedQuery)) +} + +func TestProjectionAliasDoesNotCreateMatchInference(t *testing.T) { + preparedQuery, err := v2.New().Return( + v2.As(v2.Literal(1), "one"), + ).Build() + require.NoError(t, err) + + require.Equal(t, "return 1 as one", renderPrepared(t, preparedQuery)) + require.Empty(t, preparedQuery.Parameters) +} + +func TestInvalidProjectionAliasReturnsBuildError(t *testing.T) { + _, err := v2.New().Return(v2.As(v2.Literal(1), "bad alias")).Build() + require.ErrorContains(t, err, `projection alias has invalid symbol "bad alias"`) + + _, err = v2.New().Return(&cypher.ProjectionItem{ + Expression: v2.Literal(1), + Alias: cypher.NewVariableWithSymbol("1bad"), + }).Build() + require.ErrorContains(t, err, `projection alias has invalid symbol "1bad"`) +} + +func TestUnsupportedOrderByTypeReturnsError(t *testing.T) { + _, err := v2.New().Return(v2.Node()).OrderBy(123).Build() + require.ErrorContains(t, err, "unsupported expression type: int") +} + +func TestRawProjectionAndOrderInputsAreValidated(t *testing.T) { + _, err := v2.New().Return(&cypher.Return{}).Build() + require.ErrorContains(t, err, "return clause has nil projection") + + returnClause := cypher.NewReturn() + returnClause.NewProjection(false).Items = append(returnClause.Projection.Items, &cypher.ProjectionItem{}) + _, err = v2.New().Return(returnClause).Build() + require.ErrorContains(t, err, "projection item has nil expression") + + _, err = v2.New().Return(v2.Node()).OrderBy(&cypher.SortItem{}).Build() + require.ErrorContains(t, err, "sort item has nil expression") + + _, err = v2.New().Return(v2.Node()).OrderBy(&cypher.Order{ + Items: []*cypher.SortItem{{}}, + }).Build() + require.ErrorContains(t, err, "sort item has nil expression") +} + +func TestRawProjectionAndOrderInputsAreNormalized(t *testing.T) { + returnClause := cypher.NewReturn() + returnClause.NewProjection(false).Items = append(returnClause.Projection.Items, v2.Node().ID()) + + preparedQuery, err := v2.New().Return(returnClause).OrderBy(&cypher.Order{ + Items: []*cypher.SortItem{ + v2.Desc(v2.Node().Property("name")), + }, + }).Build() + require.NoError(t, err) + + require.Equal(t, "match (n) return id(n) order by n.name desc", renderPrepared(t, preparedQuery)) +} + +func TestRawReturnInputPreservesProjectionMetadata(t *testing.T) { + returnClause := cypher.NewReturn() + projection := returnClause.NewProjection(true) + projection.Items = append(projection.Items, v2.Node().ID()) + projection.Order = &cypher.Order{ + Items: []*cypher.SortItem{ + v2.Desc(v2.Node().Property("name")), + }, + } + projection.Skip = cypher.NewSkip(5) + projection.Limit = cypher.NewLimit(10) + + preparedQuery, err := v2.New().Return(returnClause).Build() + require.NoError(t, err) + + require.Equal(t, "match (n) return distinct id(n) order by n.name desc skip 5 limit 10", renderPrepared(t, preparedQuery)) +} + +func TestRawReturnInputMergesWithBuilderProjectionControls(t *testing.T) { + returnClause := cypher.NewReturn() + projection := returnClause.NewProjection(true) + projection.Items = append(projection.Items, v2.Node().ID()) + projection.Order = &cypher.Order{ + Items: []*cypher.SortItem{ + v2.Desc(v2.Node().Property("name")), + }, + } + projection.Skip = cypher.NewSkip(5) + projection.Limit = cypher.NewLimit(10) + + preparedQuery, err := v2.New().Return(returnClause).OrderBy( + v2.Asc(v2.Node().Property("created_at")), + ).Skip(15).Limit(20).Build() + require.NoError(t, err) + + require.Equal(t, "match (n) return distinct id(n) order by n.name desc, n.created_at asc skip 15 limit 20", renderPrepared(t, preparedQuery)) +} + +func TestRawUpdatingInputsAreValidated(t *testing.T) { + var setClause *cypher.Set + _, err := v2.New().Update(setClause).Build() + require.ErrorContains(t, err, "set clause is nil") + + _, err = v2.New().Update(&cypher.Set{Items: []*cypher.SetItem{nil}}).Build() + require.ErrorContains(t, err, "set item is nil") + + _, err = v2.New().Update(&cypher.SetItem{}).Build() + require.ErrorContains(t, err, "set item left has nil expression") + + var removeClause *cypher.Remove + _, err = v2.New().Update(removeClause).Build() + require.ErrorContains(t, err, "remove clause is nil") + + _, err = v2.New().Update(&cypher.Remove{Items: []*cypher.RemoveItem{nil}}).Build() + require.ErrorContains(t, err, "remove item is nil") + + _, err = v2.New().Update(&cypher.RemoveItem{}).Build() + require.ErrorContains(t, err, "remove item has no target") + + var deleteVariable *cypher.Variable + _, err = v2.New().Delete(deleteVariable).Build() + require.ErrorContains(t, err, "delete expression has nil expression") + + var nodePattern *cypher.NodePattern + _, err = v2.New().Create(nodePattern).Build() + require.ErrorContains(t, err, "node pattern is nil") + + var relationshipPattern *cypher.RelationshipPattern + _, err = v2.New().Create(relationshipPattern).Build() + require.ErrorContains(t, err, "relationship pattern is nil") +} + +func TestInvalidHelperInputsReturnBuildErrors(t *testing.T) { + cases := map[string]struct { + builder v2.QueryBuilder + err string + }{ + "aliased projection": { + builder: v2.New().Return(v2.As(123, "bad")), + err: "unsupported expression type: int", + }, + "sort item": { + builder: v2.New().Return(v2.Node()).OrderBy(v2.Desc(123)), + err: "unsupported expression type: int", + }, + "set properties": { + builder: v2.New().Update(v2.SetProperties(123, map[string]any{"name": "bad"})), + err: "unsupported expression type: int", + }, + "delete properties": { + builder: v2.New().Update(v2.DeleteProperties(123, "name")), + err: "unsupported expression type: int", + }, + "pattern predicate": { + builder: v2.New().Where(v2.HasRelationships(123)).Return(v2.Node()), + err: "unsupported expression type: int", + }, + } + + for name, testCase := range cases { + t.Run(name, func(t *testing.T) { + _, err := testCase.builder.Build() + require.ErrorContains(t, err, testCase.err) + }) + } +} + +func TestNamedParameterMaterialization(t *testing.T) { + preparedQuery, err := v2.New().Where( + v2.Node().Property("first").Equals("auto"), + v2.Node().Property("second").Equals(v2.NamedParameter("p0", "named")), + ).Return( + v2.Node(), + ).Build() + require.NoError(t, err) + + require.Equal(t, "match (n) where n.first = $p1 and n.second = $p0 return n", renderPrepared(t, preparedQuery)) + require.Equal(t, map[string]any{ + "p0": "named", + "p1": "auto", + }, preparedQuery.Parameters) + + _, err = v2.New().Where( + v2.Node().Property("name").Equals(v2.NamedParameter("bad name", "value")), + ).Return( + v2.Node(), + ).Build() + require.ErrorContains(t, err, `parameter has invalid symbol "bad name"`) + + _, err = v2.New().Where( + v2.Node().Property("first").Equals(v2.NamedParameter("same", "first")), + v2.Node().Property("second").Equals(v2.NamedParameter("same", "second")), + ).Return( + v2.Node(), + ).Build() + require.ErrorContains(t, err, "parameter same is bound to multiple values") +} + +func TestCompatibilityHelpers(t *testing.T) { + preparedQuery, err := v2.New().Where( + v2.And( + v2.InIDs(v2.NodeID(), 1, 2), + v2.KindIn(v2.Node(), graph.StringKind("User")), + v2.CaseInsensitiveStringContains(v2.Node().Property("name"), "ADMIN"), + v2.IsNotNull(v2.Node().Property("enabled")), + ), + ).Return( + v2.CountDistinct(v2.Node()), + v2.KindsOf(v2.Node()), + ).Build() + require.NoError(t, err) + + require.Equal(t, "match (n) where id(n) in $p0 and n:User and toLower(n.name) contains $p1 and n.enabled is not null return count(distinct n), labels(n)", renderPrepared(t, preparedQuery)) + require.Equal(t, map[string]any{ + "p0": []graph.ID{1, 2}, + "p1": "admin", + }, preparedQuery.Parameters) +} + +func TestUpdateCompatibilityHelpers(t *testing.T) { + preparedQuery, err := v2.New().Where( + v2.Node().ID().Equals(1), + ).Update( + v2.AddKind(v2.Node(), graph.StringKind("Enabled")), + v2.SetProperties(v2.Node(), map[string]any{"name": "updated"}), + ).Build() + require.NoError(t, err) + + require.Equal(t, "match (n) where id(n) = $p0 set n:Enabled, n.name = $p1", renderPrepared(t, preparedQuery)) + require.Equal(t, map[string]any{ + "p0": 1, + "p1": "updated", + }, preparedQuery.Parameters) +} + +func TestSetPropertiesSortsKeys(t *testing.T) { + properties := map[string]any{ + "zeta": 3, + "alpha": 1, + "mid": 2, + } + + preparedQuery, err := v2.New().Update( + v2.SetProperties(v2.Node(), properties), + ).Build() + require.NoError(t, err) + require.Equal(t, "match (n) set n.alpha = $p0, n.mid = $p1, n.zeta = $p2", renderPrepared(t, preparedQuery)) + require.Equal(t, map[string]any{ + "p0": 1, + "p1": 2, + "p2": 3, + }, preparedQuery.Parameters) + + preparedQuery, err = v2.New().Update( + v2.Node().SetProperties(properties), + ).Build() + require.NoError(t, err) + require.Equal(t, "match (n) set n.alpha = $p0, n.mid = $p1, n.zeta = $p2", renderPrepared(t, preparedQuery)) + require.Equal(t, map[string]any{ + "p0": 1, + "p1": 2, + "p2": 3, + }, preparedQuery.Parameters) +} diff --git a/query/v2/util.go b/query/v2/util.go new file mode 100644 index 00000000..ab003dc1 --- /dev/null +++ b/query/v2/util.go @@ -0,0 +1,1108 @@ +package v2 + +import ( + "errors" + "fmt" + "reflect" + "sort" + "strconv" + "strings" + + "github.com/specterops/dawgs/cypher/models/cypher" + "github.com/specterops/dawgs/cypher/models/walk" + "github.com/specterops/dawgs/graph" +) + +func isNodePattern(seen *identifierSet, identifiers runtimeIdentifiers) bool { + return seen.Contains(identifiers.node) +} + +func isRelationshipPattern(seen *identifierSet, identifiers runtimeIdentifiers) bool { + var ( + hasStart = seen.Contains(identifiers.start) + hasRelationship = seen.Contains(identifiers.relationship) + hasEnd = seen.Contains(identifiers.end) + ) + + return hasStart || hasRelationship || hasEnd +} + +func prepareNodePattern(match *cypher.Match, seen *identifierSet, identifiers runtimeIdentifiers) error { + if isRelationshipPattern(seen, identifiers) { + return fmt.Errorf("query mixes node and relationship query identifiers") + } + + match.NewPatternPart().AddPatternElements(&cypher.NodePattern{ + Variable: identifiers.Node(), + }) + + return nil +} + +func validateRelationshipDirection(direction graph.Direction) error { + switch direction { + case graph.DirectionInbound, graph.DirectionOutbound: + return nil + default: + return fmt.Errorf("unsupported relationship direction: %s", direction) + } +} + +func prepareRelationshipPattern(match *cypher.Match, seen *identifierSet, identifiers runtimeIdentifiers, relationshipKinds graph.Kinds, direction graph.Direction, shortestPaths, allShortestPaths bool) error { + if shortestPaths && allShortestPaths { + return errors.New("query is requesting both all shortest paths and shortest paths") + } + + if err := validateRelationshipDirection(direction); err != nil { + return err + } + + var ( + newPatternPart = match.NewPatternPart() + startNodeSeen = seen.Contains(identifiers.start) + relationshipSeen = seen.Contains(identifiers.relationship) + endNodeSeen = seen.Contains(identifiers.end) + ) + + newPatternPart.ShortestPathPattern = shortestPaths + newPatternPart.AllShortestPathsPattern = allShortestPaths + + if startNodeSeen { + newPatternPart.AddPatternElements(&cypher.NodePattern{ + Variable: identifiers.Start(), + }) + } else { + newPatternPart.AddPatternElements(&cypher.NodePattern{}) + } + + relationshipPattern := &cypher.RelationshipPattern{ + Kinds: relationshipKinds, + Direction: direction, + } + + if relationshipSeen { + relationshipPattern.Variable = identifiers.Relationship() + } + + if shortestPaths || allShortestPaths { + newPatternPart.Variable = identifiers.Path() + relationshipPattern.Range = &cypher.PatternRange{} + } + + newPatternPart.AddPatternElements(relationshipPattern) + + if endNodeSeen { + newPatternPart.AddPatternElements(&cypher.NodePattern{ + Variable: identifiers.End(), + }) + } else { + newPatternPart.AddPatternElements(&cypher.NodePattern{}) + } + + return nil +} + +func prepareCreateRelationshipMatch(match *cypher.Match, seen *identifierSet, identifiers runtimeIdentifiers) error { + if seen.Contains(identifiers.start) { + match.NewPatternPart().AddPatternElements(&cypher.NodePattern{ + Variable: identifiers.Start(), + }) + } + + if seen.Contains(identifiers.end) { + match.NewPatternPart().AddPatternElements(&cypher.NodePattern{ + Variable: identifiers.End(), + }) + } + + return nil +} + +func isDetachDeleteQualifier(qualifier cypher.Expression, identifiers runtimeIdentifiers) bool { + variable, typeOK := qualifier.(*cypher.Variable) + if !typeOK || variable == nil { + return false + } + + switch variable.Symbol { + case identifiers.node, identifiers.start, identifiers.end: + return true + default: + return false + } +} + +func kindProjectionExpression(role string, identifier *cypher.Variable) (cypher.Expression, error) { + switch role { + case Identifiers.node, Identifiers.start, Identifiers.end: + return cypher.NewSimpleFunctionInvocation(cypher.NodeLabelsFunction, identifier), nil + + case Identifiers.relationship: + return cypher.NewSimpleFunctionInvocation(cypher.EdgeTypeFunction, identifier), nil + + default: + return nil, fmt.Errorf("invalid kind projection reference: %s", identifier.Symbol) + } +} + +func invalidExpression(err error) *cypher.FunctionInvocation { + return cypher.WithErrors(cypher.NewSimpleFunctionInvocation("__invalid_expression__"), err) +} + +func isNilPointer(value any) bool { + if value == nil { + return true + } + + reflectValue := reflect.ValueOf(value) + return reflectValue.Kind() == reflect.Ptr && reflectValue.IsNil() +} + +func expressionOrError(value any) cypher.Expression { + if expression, err := projectionExpression(value); err != nil { + return invalidExpression(err) + } else { + return expression + } +} + +func variableReference(value any) (*cypher.Variable, error) { + expression, err := projectionExpression(value) + if err != nil { + return nil, err + } + + if variable, typeOK := expression.(*cypher.Variable); !typeOK { + return nil, fmt.Errorf("expected variable reference, got %T", expression) + } else { + return variable, nil + } +} + +func propertyLookupOrError(reference any, propertyName string) cypher.Expression { + if variable, err := variableReference(reference); err != nil { + return invalidExpression(err) + } else { + return cypher.NewPropertyLookup(variable.Symbol, propertyName) + } +} + +func sortedPropertyKeys(properties map[string]any) []string { + keys := make([]string, 0, len(properties)) + + for key := range properties { + keys = append(keys, key) + } + + sort.Strings(keys) + return keys +} + +func isCypherSymbolStart(char byte) bool { + return char == '_' || (char >= 'A' && char <= 'Z') || (char >= 'a' && char <= 'z') +} + +func isCypherSymbolPart(char byte) bool { + return isCypherSymbolStart(char) || (char >= '0' && char <= '9') +} + +func validateCypherSymbol(symbol, context string) error { + if strings.TrimSpace(symbol) == "" { + return fmt.Errorf("%s is empty", context) + } + + if !isCypherSymbolStart(symbol[0]) { + return fmt.Errorf("%s has invalid symbol %q", context, symbol) + } + + for idx := 1; idx < len(symbol); idx++ { + if !isCypherSymbolPart(symbol[idx]) { + return fmt.Errorf("%s has invalid symbol %q", context, symbol) + } + } + + return nil +} + +func projectionExpression(value any) (cypher.Expression, error) { + if isNilPointer(value) { + return nil, fmt.Errorf("expression is nil: %T", value) + } + + switch typedValue := value.(type) { + case kindContinuation: + return kindProjectionExpression(typedValue.role, typedValue.identifier) + + case kindsContinuation: + return kindProjectionExpression(typedValue.role, typedValue.identifier) + + case QualifiedExpression: + return typedValue.qualifier(), nil + + case *cypher.ProjectionItem: + if typedValue.Expression == nil { + return nil, fmt.Errorf("projection item has nil expression") + } + + return typedValue.Expression, nil + + case *cypher.Parameter: + return typedValue, nil + + case *cypher.Literal: + return typedValue, nil + + case *cypher.Variable: + return typedValue, nil + + case *cypher.PropertyLookup: + return typedValue, nil + + case *cypher.FunctionInvocation: + return typedValue, nil + + case *cypher.Parenthetical: + return typedValue, nil + + case *cypher.Comparison: + return typedValue, nil + + case *cypher.Negation: + return typedValue, nil + + case *cypher.Conjunction: + return typedValue, nil + + case *cypher.Disjunction: + return typedValue, nil + + case *cypher.ExclusiveDisjunction: + return typedValue, nil + + case *cypher.KindMatcher: + return typedValue, nil + + case *cypher.ListLiteral: + return typedValue, nil + + case cypher.MapLiteral: + return typedValue, nil + + case *cypher.PatternPredicate: + return typedValue, nil + + case *cypher.ArithmeticExpression: + return typedValue, nil + + case *cypher.UnaryAddOrSubtractExpression: + return typedValue, nil + + case *cypher.FilterExpression: + return typedValue, nil + + case *cypher.IDInCollection: + return typedValue, nil + + default: + return nil, fmt.Errorf("unsupported expression type: %T", value) + } +} + +func validateExpressionValue(expression cypher.Expression, context string) error { + if isNilPointer(expression) { + return fmt.Errorf("%s has nil expression", context) + } + + return collectModelErrors(expression) +} + +func validateAssignmentOperator(operator cypher.AssignmentOperator) error { + switch operator { + case cypher.OperatorAssignment, cypher.OperatorAdditionAssignment, cypher.OperatorLabelAssignment: + return nil + default: + return fmt.Errorf("unsupported set item operator: %s", operator) + } +} + +func setItemFromValue(setItem *cypher.SetItem) (*cypher.SetItem, error) { + if setItem == nil { + return nil, fmt.Errorf("set item is nil") + } + + if err := validateExpressionValue(setItem.Left, "set item left"); err != nil { + return nil, err + } + + if err := validateAssignmentOperator(setItem.Operator); err != nil { + return nil, err + } + + if err := validateExpressionValue(setItem.Right, "set item right"); err != nil { + return nil, err + } + + if err := collectModelErrors(setItem); err != nil { + return nil, err + } + + return setItem, nil +} + +func setItemsFromSet(setClause *cypher.Set) ([]*cypher.SetItem, error) { + if setClause == nil { + return nil, fmt.Errorf("set clause is nil") + } + + setItems := make([]*cypher.SetItem, 0, len(setClause.Items)) + + for _, setItem := range setClause.Items { + if normalizedSetItem, err := setItemFromValue(setItem); err != nil { + return nil, err + } else { + setItems = append(setItems, normalizedSetItem) + } + } + + return setItems, nil +} + +func removeItemFromValue(removeItem *cypher.RemoveItem) (*cypher.RemoveItem, error) { + if removeItem == nil { + return nil, fmt.Errorf("remove item is nil") + } + + hasKindMatcher := removeItem.KindMatcher != nil + hasProperty := !isNilPointer(removeItem.Property) + + switch { + case hasKindMatcher && hasProperty: + return nil, fmt.Errorf("remove item has multiple targets") + + case hasKindMatcher: + if err := collectModelErrors(removeItem.KindMatcher); err != nil { + return nil, err + } + + case hasProperty: + if err := validateExpressionValue(removeItem.Property, "remove item property"); err != nil { + return nil, err + } + + default: + return nil, fmt.Errorf("remove item has no target") + } + + if err := collectModelErrors(removeItem); err != nil { + return nil, err + } + + return removeItem, nil +} + +func removeItemsFromRemove(removeClause *cypher.Remove) ([]*cypher.RemoveItem, error) { + if removeClause == nil { + return nil, fmt.Errorf("remove clause is nil") + } + + removeItems := make([]*cypher.RemoveItem, 0, len(removeClause.Items)) + + for _, removeItem := range removeClause.Items { + if normalizedRemoveItem, err := removeItemFromValue(removeItem); err != nil { + return nil, err + } else { + removeItems = append(removeItems, normalizedRemoveItem) + } + } + + return removeItems, nil +} + +func validateNodePattern(nodePattern *cypher.NodePattern) error { + if nodePattern == nil { + return fmt.Errorf("node pattern is nil") + } + + return collectModelErrors(nodePattern) +} + +func validateRelationshipPattern(relationshipPattern *cypher.RelationshipPattern) error { + if relationshipPattern == nil { + return fmt.Errorf("relationship pattern is nil") + } + + if err := validateRelationshipDirection(relationshipPattern.Direction); err != nil { + return err + } + + return collectModelErrors(relationshipPattern) +} + +func projectionItemFromValue(value any) (*cypher.ProjectionItem, error) { + if projectionItem, typeOK := value.(*cypher.ProjectionItem); typeOK { + if projectionItem == nil { + return nil, fmt.Errorf("projection item is nil") + } + + if err := validateExpressionValue(projectionItem.Expression, "projection item"); err != nil { + return nil, err + } + + if projectionItem.Alias != nil { + if err := validateCypherSymbol(projectionItem.Alias.Symbol, "projection alias"); err != nil { + return nil, err + } + } + + if err := collectModelErrors(projectionItem); err != nil { + return nil, err + } + + return projectionItem, nil + } + + if expression, err := projectionExpression(value); err != nil { + return nil, err + } else { + return cypher.NewProjectionItemWithExpr(expression), nil + } +} + +func sortItemFromValue(value any) (*cypher.SortItem, error) { + if sortItem, typeOK := value.(*cypher.SortItem); typeOK { + if sortItem == nil { + return nil, fmt.Errorf("sort item is nil") + } + + if err := validateExpressionValue(sortItem.Expression, "sort item"); err != nil { + return nil, err + } + + if err := collectModelErrors(sortItem); err != nil { + return nil, err + } + + return sortItem, nil + } + + if expression, err := projectionExpression(value); err != nil { + return nil, err + } else { + return &cypher.SortItem{ + Ascending: true, + Expression: expression, + }, nil + } +} + +func projectionItemsFromReturn(returnClause *cypher.Return) ([]*cypher.ProjectionItem, error) { + if returnClause == nil { + return nil, fmt.Errorf("return clause is nil") + } + + if returnClause.Projection == nil { + return nil, fmt.Errorf("return clause has nil projection") + } + + if err := validateProjectionMetadata(returnClause.Projection); err != nil { + return nil, err + } + + projectionItems := make([]*cypher.ProjectionItem, 0, len(returnClause.Projection.Items)) + + for _, returnItem := range returnClause.Projection.Items { + if projectionItem, err := projectionItemFromValue(returnItem); err != nil { + return nil, err + } else { + projectionItems = append(projectionItems, projectionItem) + } + } + + return projectionItems, nil +} + +func validateProjectionMetadata(projection *cypher.Projection) error { + if projection.Order != nil { + if _, err := sortItemsFromOrder(projection.Order); err != nil { + return err + } + } + + if projection.Skip != nil { + if err := validateExpressionValue(projection.Skip.Value, "projection skip"); err != nil { + return err + } + } + + if projection.Limit != nil { + if err := validateExpressionValue(projection.Limit.Value, "projection limit"); err != nil { + return err + } + } + + return nil +} + +func sortItemsFromOrder(order *cypher.Order) ([]*cypher.SortItem, error) { + if order == nil { + return nil, fmt.Errorf("order is nil") + } + + sortItems := make([]*cypher.SortItem, 0, len(order.Items)) + + for _, sortItem := range order.Items { + if normalizedSortItem, err := sortItemFromValue(sortItem); err != nil { + return nil, err + } else { + sortItems = append(sortItems, normalizedSortItem) + } + } + + return sortItems, nil +} + +type identifierSet struct { + identifiers map[string]struct{} +} + +func newIdentifierSet() *identifierSet { + return &identifierSet{ + identifiers: map[string]struct{}{}, + } +} + +func (s *identifierSet) Add(identifier string) { + s.identifiers[identifier] = struct{}{} +} + +func (s *identifierSet) Len() int { + return len(s.identifiers) +} + +func (s *identifierSet) Clone() *identifierSet { + clone := newIdentifierSet() + clone.Or(s) + return clone +} + +func (s *identifierSet) Or(other *identifierSet) { + for otherIdentifier := range other.identifiers { + s.identifiers[otherIdentifier] = struct{}{} + } +} + +func (s *identifierSet) Remove(other *identifierSet) { + for otherIdentifier := range other.identifiers { + delete(s.identifiers, otherIdentifier) + } +} + +func (s *identifierSet) Contains(identifier string) bool { + _, containsIdentifier := s.identifiers[identifier] + return containsIdentifier +} + +func (s *identifierSet) CollectFromExpression(expr cypher.Expression) error { + if exprIdentifiers, err := extractCypherIdentifiers(expr); err != nil { + return err + } else { + s.Or(exprIdentifiers) + return nil + } +} + +func (s *identifierSet) CollectFromValue(value any) error { + switch typedValue := value.(type) { + case nil: + return nil + + case QualifiedExpression: + return s.CollectFromExpression(typedValue.qualifier()) + + case *cypher.Return: + if projectionItems, err := projectionItemsFromReturn(typedValue); err != nil { + return err + } else { + for _, projectionItem := range projectionItems { + if err := s.CollectFromExpression(projectionItem); err != nil { + return err + } + } + } + + case *cypher.Order: + if sortItems, err := sortItemsFromOrder(typedValue); err != nil { + return err + } else { + for _, sortItem := range sortItems { + if err := s.CollectFromExpression(sortItem); err != nil { + return err + } + } + } + + case *cypher.SortItem: + if sortItem, err := sortItemFromValue(typedValue); err != nil { + return err + } else { + return s.CollectFromExpression(sortItem) + } + + case *cypher.Set: + if setItems, err := setItemsFromSet(typedValue); err != nil { + return err + } else { + for _, setItem := range setItems { + if err := s.CollectFromExpression(setItem); err != nil { + return err + } + } + } + + case *cypher.SetItem: + if setItem, err := setItemFromValue(typedValue); err != nil { + return err + } else { + return s.CollectFromExpression(setItem) + } + + case *cypher.Remove: + if removeItems, err := removeItemsFromRemove(typedValue); err != nil { + return err + } else { + for _, removeItem := range removeItems { + if err := s.CollectFromValue(removeItem); err != nil { + return err + } + } + } + + case *cypher.RemoveItem: + if removeItem, err := removeItemFromValue(typedValue); err != nil { + return err + } else if removeItem.KindMatcher != nil { + return s.CollectFromExpression(removeItem.KindMatcher) + } else { + return s.CollectFromExpression(removeItem) + } + + case *cypher.NodePattern: + if err := validateNodePattern(typedValue); err != nil { + return err + } + + return s.CollectFromExpression(typedValue) + + case *cypher.RelationshipPattern: + if err := validateRelationshipPattern(typedValue); err != nil { + return err + } + + return s.CollectFromExpression(typedValue) + + case *cypher.Variable: + return s.CollectFromExpression(typedValue) + + case *cypher.FunctionInvocation: + return s.CollectFromExpression(typedValue) + + case *cypher.PropertyLookup: + return s.CollectFromExpression(typedValue) + + case []any: + for _, item := range typedValue { + if err := s.CollectFromValue(item); err != nil { + return err + } + } + + case []cypher.SyntaxNode: + for _, item := range typedValue { + if err := s.CollectFromValue(item); err != nil { + return err + } + } + + case []cypher.Expression: + for _, item := range typedValue { + if err := s.CollectFromValue(item); err != nil { + return err + } + } + + case []*cypher.SetItem: + for _, item := range typedValue { + if err := s.CollectFromValue(item); err != nil { + return err + } + } + + case []*cypher.RemoveItem: + for _, item := range typedValue { + if err := s.CollectFromValue(item); err != nil { + return err + } + } + + default: + return nil + } + + return nil +} + +func collectIdentifiersFromValues(values ...any) (*identifierSet, error) { + identifiers := newIdentifierSet() + + for _, value := range values { + if err := identifiers.CollectFromValue(value); err != nil { + return nil, err + } + } + + return identifiers, nil +} + +type createScope struct { + identifiers *identifierSet + createsRelationship bool +} + +func collectCreateScope(identifiers runtimeIdentifiers, values ...any) (*createScope, error) { + scope := &createScope{ + identifiers: newIdentifierSet(), + } + + for _, value := range values { + switch typedValue := value.(type) { + case *cypher.RelationshipPattern: + if err := validateRelationshipPattern(typedValue); err != nil { + return nil, err + } + + scope.createsRelationship = true + scope.identifiers.Add(identifiers.start) + scope.identifiers.Add(identifiers.end) + + if typedValue.Variable != nil { + scope.identifiers.Add(typedValue.Variable.Symbol) + } + + default: + if err := scope.identifiers.CollectFromValue(value); err != nil { + return nil, err + } + } + } + + return scope, nil +} + +type identifierExtractor struct { + walk.Visitor[cypher.SyntaxNode] + + seen *identifierSet +} + +func newIdentifierExtractor() *identifierExtractor { + return &identifierExtractor{ + Visitor: walk.NewVisitor[cypher.SyntaxNode](), + seen: newIdentifierSet(), + } +} + +func (s *identifierExtractor) Enter(node cypher.SyntaxNode) { + switch typedNode := node.(type) { + case *cypher.Variable: + s.seen.Add(typedNode.Symbol) + + case *cypher.NodePattern: + if typedNode.Variable != nil { + s.seen.Add(typedNode.Variable.Symbol) + } + + case *cypher.RelationshipPattern: + if typedNode.Variable != nil { + s.seen.Add(typedNode.Variable.Symbol) + } + + case *cypher.PatternPart: + if typedNode.Variable != nil { + s.seen.Add(typedNode.Variable.Symbol) + } + } +} + +func extractCypherIdentifiers(expression cypher.Expression) (*identifierSet, error) { + var ( + identifierExtractorVisitor = newIdentifierExtractor() + err = walk.Cypher(expression, identifierExtractorVisitor) + ) + + return identifierExtractorVisitor.seen, err +} + +func collectModelErrors(node cypher.SyntaxNode) error { + var modelErrors []error + + if err := walk.Cypher(node, walk.NewSimpleVisitor[cypher.SyntaxNode](func(node cypher.SyntaxNode, _ walk.VisitorHandler) { + if errorNode, typeOK := node.(cypher.Fallible); typeOK { + modelErrors = append(modelErrors, errorNode.Errors()...) + } + })); err != nil { + modelErrors = append(modelErrors, err) + } + + return errors.Join(modelErrors...) +} + +func collectModelErrorsFromKnownValues(values ...any) error { + var modelErrors []error + + for _, value := range values { + switch typedValue := value.(type) { + case nil: + continue + + case []any: + if err := collectModelErrorsFromKnownValues(typedValue...); err != nil { + modelErrors = append(modelErrors, err) + } + + case []cypher.SyntaxNode: + if err := collectModelErrorsFromKnownValues(anySlice(typedValue)...); err != nil { + modelErrors = append(modelErrors, err) + } + + case []cypher.Expression: + if err := collectModelErrorsFromKnownValues(anySlice(typedValue)...); err != nil { + modelErrors = append(modelErrors, err) + } + + case []*cypher.SetItem: + if err := collectModelErrorsFromKnownValues(anySlice(typedValue)...); err != nil { + modelErrors = append(modelErrors, err) + } + + case []*cypher.RemoveItem: + if err := collectModelErrorsFromKnownValues(anySlice(typedValue)...); err != nil { + modelErrors = append(modelErrors, err) + } + + case []*cypher.ProjectionItem: + if err := collectModelErrorsFromKnownValues(anySlice(typedValue)...); err != nil { + modelErrors = append(modelErrors, err) + } + + case []*cypher.SortItem: + if err := collectModelErrorsFromKnownValues(anySlice(typedValue)...); err != nil { + modelErrors = append(modelErrors, err) + } + + case *cypher.NodePattern: + if err := validateNodePattern(typedValue); err != nil { + modelErrors = append(modelErrors, err) + } + + case *cypher.Order: + if _, err := sortItemsFromOrder(typedValue); err != nil { + modelErrors = append(modelErrors, err) + } else if err := collectModelErrors(typedValue); err != nil { + modelErrors = append(modelErrors, err) + } + + case *cypher.ProjectionItem: + if _, err := projectionItemFromValue(typedValue); err != nil { + modelErrors = append(modelErrors, err) + } else if err := collectModelErrors(typedValue); err != nil { + modelErrors = append(modelErrors, err) + } + + case *cypher.Return: + if _, err := projectionItemsFromReturn(typedValue); err != nil { + modelErrors = append(modelErrors, err) + } + + case *cypher.RelationshipPattern: + if err := validateRelationshipPattern(typedValue); err != nil { + modelErrors = append(modelErrors, err) + } + + case *cypher.Remove: + if _, err := removeItemsFromRemove(typedValue); err != nil { + modelErrors = append(modelErrors, err) + } else if err := collectModelErrors(typedValue); err != nil { + modelErrors = append(modelErrors, err) + } + + case *cypher.RemoveItem: + if _, err := removeItemFromValue(typedValue); err != nil { + modelErrors = append(modelErrors, err) + } else if err := collectModelErrors(typedValue); err != nil { + modelErrors = append(modelErrors, err) + } + + case *cypher.Set: + if _, err := setItemsFromSet(typedValue); err != nil { + modelErrors = append(modelErrors, err) + } else if err := collectModelErrors(typedValue); err != nil { + modelErrors = append(modelErrors, err) + } + + case *cypher.SetItem: + if _, err := setItemFromValue(typedValue); err != nil { + modelErrors = append(modelErrors, err) + } else if err := collectModelErrors(typedValue); err != nil { + modelErrors = append(modelErrors, err) + } + + case *cypher.SortItem: + if _, err := sortItemFromValue(typedValue); err != nil { + modelErrors = append(modelErrors, err) + } else if err := collectModelErrors(typedValue); err != nil { + modelErrors = append(modelErrors, err) + } + + case *cypher.ArithmeticExpression, + *cypher.Comparison, + *cypher.Conjunction, + *cypher.Create, + *cypher.Delete, + *cypher.Disjunction, + *cypher.ExclusiveDisjunction, + *cypher.FilterExpression, + *cypher.FunctionInvocation, + *cypher.IDInCollection, + *cypher.KindMatcher, + *cypher.ListLiteral, + *cypher.Negation, + *cypher.Parenthetical, + *cypher.PatternPredicate, + *cypher.PropertyLookup, + *cypher.UnaryAddOrSubtractExpression, + *cypher.UpdatingClause, + *cypher.Variable: + if err := collectModelErrors(typedValue); err != nil { + modelErrors = append(modelErrors, err) + } + } + } + + return errors.Join(modelErrors...) +} + +func anySlice[T any](values []T) []any { + items := make([]any, len(values)) + + for idx, value := range values { + items[idx] = value + } + + return items +} + +type parameterMaterializer struct { + walk.Visitor[cypher.SyntaxNode] + + parameters map[string]any + nextIndex int +} + +func newParameterMaterializer(parameters map[string]any) *parameterMaterializer { + materializedParameters := map[string]any{} + + for symbol, value := range parameters { + materializedParameters[symbol] = value + } + + return ¶meterMaterializer{ + Visitor: walk.NewVisitor[cypher.SyntaxNode](), + parameters: materializedParameters, + } +} + +func (s *parameterMaterializer) nextSymbol() string { + for { + symbol := "p" + strconv.Itoa(s.nextIndex) + s.nextIndex++ + + if _, taken := s.parameters[symbol]; !taken { + return symbol + } + } +} + +func (s *parameterMaterializer) Enter(node cypher.SyntaxNode) { + parameter, typeOK := node.(*cypher.Parameter) + if !typeOK { + return + } + + if parameter.Symbol == "" { + parameter.Symbol = s.nextSymbol() + } + + if existingValue, exists := s.parameters[parameter.Symbol]; exists && !reflect.DeepEqual(existingValue, parameter.Value) { + s.SetErrorf("parameter %s is bound to multiple values", parameter.Symbol) + return + } + + s.parameters[parameter.Symbol] = parameter.Value +} + +type namedParameterCollector struct { + walk.Visitor[cypher.SyntaxNode] + + parameters map[string]any +} + +func newNamedParameterCollector() *namedParameterCollector { + return &namedParameterCollector{ + Visitor: walk.NewVisitor[cypher.SyntaxNode](), + parameters: map[string]any{}, + } +} + +func (s *namedParameterCollector) Enter(node cypher.SyntaxNode) { + parameter, typeOK := node.(*cypher.Parameter) + if !typeOK || parameter.Symbol == "" { + return + } + + if err := validateCypherSymbol(parameter.Symbol, "parameter"); err != nil { + s.SetError(err) + return + } + + if existingValue, exists := s.parameters[parameter.Symbol]; exists && !reflect.DeepEqual(existingValue, parameter.Value) { + s.SetErrorf("parameter %s is bound to multiple values", parameter.Symbol) + return + } + + s.parameters[parameter.Symbol] = parameter.Value +} + +func collectNamedParameters(query *cypher.RegularQuery) (map[string]any, error) { + collector := newNamedParameterCollector() + + if err := walk.Cypher(query, collector); err != nil { + return nil, err + } + + return collector.parameters, nil +} + +func materializeParameters(query *cypher.RegularQuery) (map[string]any, error) { + namedParameters, err := collectNamedParameters(query) + if err != nil { + return nil, err + } + + materializer := newParameterMaterializer(namedParameters) + + if err := walk.Cypher(query, materializer); err != nil { + return nil, err + } + + return materializer.parameters, nil +}