From 249a9ed642ca9004929978b661c5da0715cc7a99 Mon Sep 17 00:00:00 2001 From: John Hopper Date: Wed, 6 May 2026 23:27:44 -0700 Subject: [PATCH 01/55] feat (cypher): port v2 query builder --- query/v2/doc.go | 6 + query/v2/query.go | 771 +++++++++++++++++++++++++++++++++++++++++ query/v2/query_test.go | 46 +++ query/v2/util.go | 174 ++++++++++ 4 files changed, 997 insertions(+) create mode 100644 query/v2/doc.go create mode 100644 query/v2/query.go create mode 100644 query/v2/query_test.go create mode 100644 query/v2/util.go diff --git a/query/v2/doc.go b/query/v2/doc.go new file mode 100644 index 00000000..a21a5004 --- /dev/null +++ b/query/v2/doc.go @@ -0,0 +1,6 @@ +// Package v2 contains the experimental fluent Cypher query builder. +// +// It is intentionally isolated from the stable query package so callers can +// opt in without pulling the current graph query APIs through a compatibility +// layer. +package v2 diff --git a/query/v2/query.go b/query/v2/query.go new file mode 100644 index 00000000..75eb0d44 --- /dev/null +++ b/query/v2/query.go @@ -0,0 +1,771 @@ +package v2 + +import ( + "errors" + "fmt" + + "github.com/specterops/dawgs/cypher/models/cypher" + "github.com/specterops/dawgs/graph" +) + +type runtimeIdentifiers struct { + path string + node string + start string + relationship string + end string +} + +func (s runtimeIdentifiers) Path() *cypher.Variable { + return cypher.NewVariableWithSymbol(s.path) +} + +func (s runtimeIdentifiers) Node() *cypher.Variable { + return cypher.NewVariableWithSymbol(s.node) +} + +func (s runtimeIdentifiers) Start() *cypher.Variable { + return cypher.NewVariableWithSymbol(s.start) +} + +func (s runtimeIdentifiers) Relationship() *cypher.Variable { + return cypher.NewVariableWithSymbol(s.relationship) +} + +func (s runtimeIdentifiers) End() *cypher.Variable { + return cypher.NewVariableWithSymbol(s.end) +} + +var Identifiers = runtimeIdentifiers{ + path: "p", + node: "n", + start: "s", + relationship: "r", + end: "e", 
+} + +func newLiteral(value any) *cypher.Literal { + if value == nil { + return cypher.NewLiteral(nil, true) + } + + if strValue, typeOK := value.(string); typeOK { + return cypher.NewStringLiteral(strValue) + } + + return cypher.NewLiteral(value, false) +} + +func joinedExpressionList(operator cypher.Operator, operands []cypher.SyntaxNode) cypher.SyntaxNode { + expressionList := &cypher.Comparison{} + + if len(operands) > 0 { + expressionList.Left = operands[0] + + for _, operand := range operands[1:] { + expressionList.NewPartialComparison(operator, operand) + } + } + + return expressionList +} + +func Not(operand cypher.Expression) cypher.Expression { + return cypher.NewNegation(operand) +} + +func And(operands ...cypher.SyntaxNode) cypher.SyntaxNode { + return joinedExpressionList(cypher.OperatorAnd, operands) +} + +func Or(operands ...cypher.SyntaxNode) cypher.SyntaxNode { + return joinedExpressionList(cypher.OperatorOr, operands) +} + +func Node() NodeContinuation { + return &entity[NodeContinuation]{ + identifier: Identifiers.Node(), + } +} + +func Path() PathContinuation { + return &entity[PathContinuation]{ + identifier: Identifiers.Path(), + } +} + +func Start() NodeContinuation { + return &entity[NodeContinuation]{ + identifier: Identifiers.Start(), + } +} + +func Relationship() RelationshipContinuation { + return &entity[RelationshipContinuation]{ + identifier: Identifiers.Relationship(), + } +} + +func End() NodeContinuation { + return &entity[NodeContinuation]{ + identifier: Identifiers.End(), + } +} + +type QualifiedExpression interface { + qualifier() cypher.Expression +} + +type EntityContinuation interface { + QualifiedExpression + + Count() cypher.Expression + ID() IdentityContinuation + Property(name string) PropertyContinuation +} + +type KindContinuation interface { + Is(kind graph.Kind) cypher.Expression + IsOneOf(kinds graph.Kinds) cypher.Expression +} + +type KindsContinuation interface { + Has(kind graph.Kind) cypher.Expression + 
HasOneOf(kinds graph.Kinds) cypher.Expression + Add(kinds graph.Kinds) cypher.Expression + Remove(kinds graph.Kinds) cypher.Expression +} + +type Comparable interface { + In(value any) cypher.Expression + Contains(value any) cypher.Expression + Equals(value any) cypher.Expression + GreaterThan(value any) cypher.Expression + GreaterThanOrEqualTo(value any) cypher.Expression + LessThan(value any) cypher.Expression + LessThanOrEqualTo(value any) cypher.Expression +} + +type PropertyContinuation interface { + QualifiedExpression + Comparable + + Set(value any) *cypher.SetItem + Remove() *cypher.RemoveItem +} + +type IdentityContinuation interface { + QualifiedExpression + Comparable +} + +type comparisonContinuation struct { + qualifierExpression cypher.Expression +} + +func (s *comparisonContinuation) qualifier() cypher.Expression { + return s.qualifierExpression +} + +func (s *comparisonContinuation) asComparison(operator cypher.Operator, rOperand any) cypher.Expression { + return cypher.NewComparison( + s.qualifier(), + operator, + newLiteral(rOperand), + ) +} + +func (s *comparisonContinuation) In(value any) cypher.Expression { + return s.asComparison(cypher.OperatorIn, value) +} + +func (s *comparisonContinuation) Contains(value any) cypher.Expression { + return s.asComparison(cypher.OperatorContains, value) +} + +func (s *comparisonContinuation) Equals(value any) cypher.Expression { + return s.asComparison(cypher.OperatorEquals, value) +} + +func (s *comparisonContinuation) GreaterThan(value any) cypher.Expression { + return s.asComparison(cypher.OperatorGreaterThan, value) +} + +func (s *comparisonContinuation) GreaterThanOrEqualTo(value any) cypher.Expression { + return s.asComparison(cypher.OperatorGreaterThanOrEqualTo, value) +} + +func (s *comparisonContinuation) LessThan(value any) cypher.Expression { + return s.asComparison(cypher.OperatorLessThan, value) +} + +func (s *comparisonContinuation) LessThanOrEqualTo(value any) cypher.Expression { + return 
s.asComparison(cypher.OperatorLessThanOrEqualTo, value) +} + +type propertyContinuation struct { + comparisonContinuation +} + +func (s *propertyContinuation) Set(value any) *cypher.SetItem { + return cypher.NewSetItem( + s.qualifier(), + cypher.OperatorAssignment, + newLiteral(value), + ) +} + +func (s *propertyContinuation) Remove() *cypher.RemoveItem { + return cypher.RemoveProperty(s.qualifier()) +} + +type entity[T any] struct { + identifier *cypher.Variable +} + +func (s *entity[T]) Kind() KindContinuation { + return kindContinuation{ + identifier: s.identifier, + } +} + +func (s *entity[T]) Kinds() KindsContinuation { + return kindsContinuation{ + identifier: s.identifier, + } +} + +func (s *entity[T]) Count() cypher.Expression { + return cypher.NewSimpleFunctionInvocation(cypher.CountFunction, s.identifier) +} + +func (s *entity[T]) SetProperties(properties map[string]any) cypher.Expression { + set := &cypher.Set{} + + for key, value := range properties { + set.Items = append(set.Items, s.Property(key).Set(value)) + } + + return set +} + +func (s *entity[T]) RemoveProperties(properties []string) cypher.Expression { + remove := &cypher.Remove{} + + for _, key := range properties { + remove.Items = append(remove.Items, s.Property(key).Remove()) + } + + return remove +} + +func (s *entity[T]) RelationshipPattern(kind graph.Kind, properties cypher.Expression, direction graph.Direction) cypher.Expression { + return &cypher.RelationshipPattern{ + Variable: s.identifier, + Kinds: graph.Kinds{kind}, + Direction: direction, + Properties: properties, + } +} + +func (s *entity[T]) NodePattern(kinds graph.Kinds, properties cypher.Expression) cypher.Expression { + return &cypher.NodePattern{ + Variable: s.identifier, + Kinds: kinds, + Properties: properties, + } +} + +func (s *entity[T]) qualifier() cypher.Expression { + return s.identifier +} + +func (s *entity[T]) ID() IdentityContinuation { + return &comparisonContinuation{ + qualifierExpression: 
&cypher.FunctionInvocation{ + Distinct: false, + Name: cypher.IdentityFunction, + Arguments: []cypher.Expression{s.identifier}, + }, + } +} + +func (s *entity[T]) Property(propertyName string) PropertyContinuation { + return &propertyContinuation{ + comparisonContinuation: comparisonContinuation{ + qualifierExpression: cypher.NewPropertyLookup(s.identifier.Symbol, propertyName), + }, + } +} + +type kindContinuation struct { + identifier *cypher.Variable +} + +func (s kindContinuation) Is(kind graph.Kind) cypher.Expression { + return s.IsOneOf(graph.Kinds{kind}) +} + +func (s kindContinuation) IsOneOf(kinds graph.Kinds) cypher.Expression { + return &cypher.KindMatcher{ + Reference: s.identifier, + Kinds: kinds, + } +} + +type kindsContinuation struct { + identifier *cypher.Variable +} + +func (s kindsContinuation) Has(kind graph.Kind) cypher.Expression { + return s.HasOneOf(graph.Kinds{kind}) +} + +func (s kindsContinuation) HasOneOf(kinds graph.Kinds) cypher.Expression { + return &cypher.KindMatcher{ + Reference: s.identifier, + Kinds: kinds, + } +} + +func (s kindsContinuation) Add(kinds graph.Kinds) cypher.Expression { + return cypher.NewSetItem( + s.identifier, + cypher.OperatorLabelAssignment, + kinds, + ) +} + +func (s kindsContinuation) Remove(kinds graph.Kinds) cypher.Expression { + return cypher.RemoveKindsByMatcher(cypher.NewKindMatcher(s.identifier, kinds, false)) +} + +type PathContinuation interface { + QualifiedExpression + + Count() cypher.Expression +} + +type RelationshipContinuation interface { + EntityContinuation + + RelationshipPattern(kind graph.Kind, properties cypher.Expression, direction graph.Direction) cypher.Expression + + Kind() KindContinuation + SetProperties(properties map[string]any) cypher.Expression + RemoveProperties(properties []string) cypher.Expression +} + +type NodeContinuation interface { + EntityContinuation + + NodePattern(kinds graph.Kinds, properties cypher.Expression) cypher.Expression + + Kinds() KindsContinuation + 
SetProperties(properties map[string]any) cypher.Expression + RemoveProperties(properties []string) cypher.Expression +} + +type QueryBuilder interface { + Where(constraints ...cypher.SyntaxNode) QueryBuilder + OrderBy(sortItems ...cypher.SyntaxNode) QueryBuilder + Skip(offset int) QueryBuilder + Limit(limit int) QueryBuilder + Return(projections ...any) QueryBuilder + Update(updatingClauses ...any) QueryBuilder + Create(creationClauses ...any) QueryBuilder + Delete(expressions ...any) QueryBuilder + WithShortestPaths() QueryBuilder + WithAllShortestPaths() QueryBuilder + Build() (*PreparedQuery, error) +} + +type builder struct { + errors []error + constraints []cypher.SyntaxNode + sortItems []cypher.SyntaxNode + projections []any + creates []any + setItems []*cypher.SetItem + removeItems []*cypher.RemoveItem + deleteItems []cypher.Expression + detachDelete bool + shortestPathQuery bool + allShorestPathsQuery bool + skip *int + limit *int +} + +func New() QueryBuilder { + return &builder{} +} + +func (s *builder) WithShortestPaths() QueryBuilder { + s.shortestPathQuery = true + return s +} + +func (s *builder) WithAllShortestPaths() QueryBuilder { + s.allShorestPathsQuery = true + return s +} + +func (s *builder) OrderBy(sortItems ...cypher.SyntaxNode) QueryBuilder { + s.sortItems = append(s.sortItems, sortItems...) + return s +} + +func (s *builder) Skip(skip int) QueryBuilder { + s.skip = &skip + return s +} + +func (s *builder) Limit(limit int) QueryBuilder { + s.limit = &limit + return s +} + +func (s *builder) Return(projections ...any) QueryBuilder { + s.projections = append(s.projections, projections...) + return s +} + +func (s *builder) Create(creationClauses ...any) QueryBuilder { + s.creates = append(s.creates, creationClauses...) 
+ return s +} + +func (s *builder) Update(updates ...any) QueryBuilder { + for _, nextUpdate := range updates { + switch typedNextUpdate := nextUpdate.(type) { + case *cypher.Set: + s.setItems = append(s.setItems, typedNextUpdate.Items...) + + case *cypher.SetItem: + s.setItems = append(s.setItems, typedNextUpdate) + + case *cypher.Remove: + s.removeItems = append(s.removeItems, typedNextUpdate.Items...) + + case *cypher.RemoveItem: + s.removeItems = append(s.removeItems, typedNextUpdate) + + default: + s.trackError(fmt.Errorf("unknown update type: %T", nextUpdate)) + } + } + + return s +} + +func (s *builder) Delete(deleteItems ...any) QueryBuilder { + for _, nextDelete := range deleteItems { + switch typedNextUpdate := nextDelete.(type) { + case QualifiedExpression: + qualifier := typedNextUpdate.qualifier() + + switch qualifier { + case Identifiers.node, Identifiers.start, Identifiers.end: + s.detachDelete = true + } + + s.deleteItems = append(s.deleteItems, qualifier) + + case *cypher.Variable: + switch typedNextUpdate.Symbol { + case Identifiers.node, Identifiers.start, Identifiers.end: + s.detachDelete = true + } + + s.deleteItems = append(s.deleteItems, typedNextUpdate) + + default: + s.trackError(fmt.Errorf("unknown delete type: %T", nextDelete)) + } + } + + return s +} + +func (s *builder) trackError(err error) { + s.errors = append(s.errors, err) +} + +func (s *builder) Where(constraints ...cypher.SyntaxNode) QueryBuilder { + s.constraints = append(s.constraints, constraints...) 
+ return s +} + +func (s *builder) buildCreates(singlePartQuery *cypher.SinglePartQuery) error { + if len(s.creates) == 0 { + return nil + } + + var ( + pattern = &cypher.PatternPart{} + createClause = &cypher.Create{ + Unique: false, + Pattern: []*cypher.PatternPart{pattern}, + } + ) + + for _, nextCreate := range s.creates { + switch typedNextCreate := nextCreate.(type) { + case QualifiedExpression: + switch typedExpression := typedNextCreate.qualifier().(type) { + case *cypher.Variable: + switch typedExpression.Symbol { + case Identifiers.node, Identifiers.start, Identifiers.end: + pattern.AddPatternElements(&cypher.NodePattern{ + Variable: cypher.NewVariableWithSymbol(typedExpression.Symbol), + }) + + default: + return fmt.Errorf("invalid variable reference for create: %s", typedExpression.Symbol) + } + } + + case *cypher.NodePattern: + pattern.AddPatternElements(typedNextCreate) + + case *cypher.RelationshipPattern: + pattern.AddPatternElements(&cypher.NodePattern{ + Variable: cypher.NewVariableWithSymbol(Identifiers.start), + }) + + pattern.AddPatternElements(typedNextCreate) + + pattern.AddPatternElements(&cypher.NodePattern{ + Variable: cypher.NewVariableWithSymbol(Identifiers.end), + }) + + default: + return fmt.Errorf("invalid type for create: %T", nextCreate) + } + } + + singlePartQuery.UpdatingClauses = append(singlePartQuery.UpdatingClauses, cypher.NewUpdatingClause(createClause)) + return nil +} + +func (s *builder) buildUpdatingClauses(singlePartQuery *cypher.SinglePartQuery) error { + if len(s.setItems) > 0 { + singlePartQuery.UpdatingClauses = append(singlePartQuery.UpdatingClauses, cypher.NewUpdatingClause( + cypher.NewSet(s.setItems), + )) + } + + if len(s.removeItems) > 0 { + singlePartQuery.UpdatingClauses = append(singlePartQuery.UpdatingClauses, cypher.NewUpdatingClause( + cypher.NewRemove(s.removeItems), + )) + } + + if len(s.deleteItems) > 0 { + singlePartQuery.UpdatingClauses = append(singlePartQuery.UpdatingClauses, 
cypher.NewUpdatingClause( + cypher.NewDelete( + s.detachDelete, + s.deleteItems, + ), + )) + } + + return s.buildCreates(singlePartQuery) +} + +func (s *builder) buildProjectionOrder() (*cypher.Order, error) { + var orderByNode *cypher.Order + + if len(s.sortItems) > 0 { + orderByNode = &cypher.Order{} + + for _, untypedSortItem := range s.sortItems { + switch typedSortItem := untypedSortItem.(type) { + case *cypher.Order: + for _, sortItem := range typedSortItem.Items { + orderByNode.Items = append(orderByNode.Items, sortItem) + } + + case *cypher.SortItem: + orderByNode.Items = append(orderByNode.Items, typedSortItem) + } + } + } + + return orderByNode, nil +} + +func (s *builder) buildProjection(singlePartQuery *cypher.SinglePartQuery) error { + var ( + hasProjectedItems = len(s.projections) > 0 + hasSkip = s.skip != nil && *s.skip > 0 + hasLimit = s.limit != nil && *s.limit > 0 + requiresProjection = hasProjectedItems || hasSkip || hasLimit + ) + + if requiresProjection { + if !hasProjectedItems { + return fmt.Errorf("query expected projected items") + } + + projection := singlePartQuery.NewProjection(false) + + for _, nextProjection := range s.projections { + switch typedNextProjection := nextProjection.(type) { + case *cypher.Return: + for _, returnItem := range typedNextProjection.Projection.Items { + if typedReturnItem, typeOK := returnItem.(*cypher.ProjectionItem); !typeOK { + return fmt.Errorf("invalid type for return: %T", returnItem) + } else { + projection.AddItem(typedReturnItem) + } + } + + case QualifiedExpression: + projection.AddItem(cypher.NewProjectionItemWithExpr(typedNextProjection.qualifier())) + + case kindContinuation: + var kindExpr cypher.Expression + + switch typedNextProjection.identifier.Symbol { + case Identifiers.node, Identifiers.start, Identifiers.end: + kindExpr = cypher.NewSimpleFunctionInvocation(cypher.NodeLabelsFunction, typedNextProjection.identifier) + + case Identifiers.relationship: + kindExpr = 
cypher.NewSimpleFunctionInvocation(cypher.EdgeTypeFunction, typedNextProjection.identifier) + } + + projection.AddItem(cypher.NewProjectionItemWithExpr(kindExpr)) + + case kindsContinuation: + var kindExpr cypher.Expression + + switch typedNextProjection.identifier.Symbol { + case Identifiers.node, Identifiers.start, Identifiers.end: + kindExpr = cypher.NewSimpleFunctionInvocation(cypher.NodeLabelsFunction, typedNextProjection.identifier) + + case Identifiers.relationship: + kindExpr = cypher.NewSimpleFunctionInvocation(cypher.EdgeTypeFunction, typedNextProjection.identifier) + } + + projection.AddItem(cypher.NewProjectionItemWithExpr(kindExpr)) + + default: + projection.AddItem(cypher.NewProjectionItemWithExpr(typedNextProjection)) + } + } + + if s.skip != nil && *s.skip > 0 { + projection.Skip = cypher.NewSkip(*s.skip) + } + + if s.limit != nil && *s.limit > 0 { + projection.Limit = cypher.NewLimit(*s.limit) + } + + if projectionOrder, err := s.buildProjectionOrder(); err != nil { + return err + } else if projectionOrder != nil { + projection.Order = projectionOrder + } + } + + return nil +} + +type PreparedQuery struct { + Query *cypher.RegularQuery + Parameters map[string]any +} + +func (s *builder) hasActions() bool { + return len(s.projections) > 0 || len(s.setItems) > 0 || len(s.removeItems) > 0 || len(s.creates) > 0 || len(s.deleteItems) > 0 +} + +func (s *builder) Build() (*PreparedQuery, error) { + if len(s.errors) > 0 { + return nil, errors.Join(s.errors...) 
+ } + + if !s.hasActions() { + return nil, fmt.Errorf("query has no action specified") + } + + var ( + regularQuery, singlePartQuery = cypher.NewRegularQueryWithSingleQuery() + match = &cypher.Match{} + seenIdentifiers = newIdentifierSet() + relationshipKinds graph.Kinds + ) + + if err := s.buildUpdatingClauses(singlePartQuery); err != nil { + return nil, err + } + + if err := s.buildProjection(singlePartQuery); err != nil { + return nil, err + } + + if len(s.constraints) > 0 { + var ( + whereClause = match.NewWhere() + constraints = &cypher.Comparison{} + ) + + for _, nextConstraint := range s.constraints { + switch typedNextConstraint := nextConstraint.(type) { + case *cypher.KindMatcher: + if identifier, typeOK := typedNextConstraint.Reference.(*cypher.Variable); !typeOK { + return nil, fmt.Errorf("expected type *cypher.Variable, got %T", typedNextConstraint) + } else if identifier.Symbol == Identifiers.relationship { + relationshipKinds = relationshipKinds.Add(typedNextConstraint.Kinds...) 
+ continue + } + } + + if constraints.Left == nil { + constraints.Left = nextConstraint + } else { + constraints.NewPartialComparison(cypher.OperatorAnd, nextConstraint) + } + } + + if constraints.Left != nil { + whereClause.Add(constraints) + + if err := seenIdentifiers.CollectFromExpression(whereClause); err != nil { + return nil, err + } + } + } + + if err := seenIdentifiers.CollectFromExpression(singlePartQuery); err != nil { + return nil, err + } + + if len(s.constraints) > 0 || len(s.creates) == 0 { + if isNodePattern(seenIdentifiers) { + if err := prepareNodePattern(match, seenIdentifiers); err != nil { + return nil, err + } + } else if isRelationshipPattern(seenIdentifiers) { + if err := prepareRelationshipPattern(match, seenIdentifiers, relationshipKinds, s.shortestPathQuery, s.allShorestPathsQuery); err != nil { + return nil, err + } + } else { + return nil, fmt.Errorf("query has no node and relationship query identifiers specified") + } + } + + if len(match.Pattern) > 0 { + newReadingClause := cypher.NewReadingClause() + newReadingClause.Match = match + + singlePartQuery.ReadingClauses = append(singlePartQuery.ReadingClauses, newReadingClause) + } + + return &PreparedQuery{ + Query: regularQuery, + Parameters: map[string]any{}, + }, nil +} diff --git a/query/v2/query_test.go b/query/v2/query_test.go new file mode 100644 index 00000000..76e793f9 --- /dev/null +++ b/query/v2/query_test.go @@ -0,0 +1,46 @@ +package v2_test + +import ( + "testing" + + "github.com/specterops/dawgs/cypher/models/cypher" + "github.com/specterops/dawgs/cypher/models/cypher/format" + "github.com/specterops/dawgs/graph" + v2 "github.com/specterops/dawgs/query/v2" + "github.com/stretchr/testify/require" +) + +func TestQuery(t *testing.T) { + preparedQuery, err := v2.New().Where( + v2.Not(v2.Relationship().Kind().Is(graph.StringKind("test"))), + v2.Not(v2.Relationship().Kind().IsOneOf(graph.Kinds{graph.StringKind("A"), graph.StringKind("B")})), + 
v2.Relationship().Property("rel_prop").LessThanOrEqualTo(1234), + v2.Relationship().Property("other_prop").Equals(5678), + v2.Start().Kinds().HasOneOf(graph.Kinds{graph.StringKind("test")}), + ).Update( + v2.Start().Property("this_prop").Set(1234), + v2.End().Kinds().Remove(graph.Kinds{graph.StringKind("A"), graph.StringKind("B")}), + ).Delete( + v2.Start(), + ).Return( + v2.Relationship(), + v2.Start().Property("node_prop"), + ).Skip(10).Limit(10).Build() + require.NoError(t, err) + + cypherQueryStr, err := format.RegularQuery(preparedQuery.Query, false) + require.NoError(t, err) + + require.Equal(t, "match (s)-[r]->() where not r:test and not (r:A or r:B) and r.rel_prop <= 1234 and r.other_prop = 5678 and s:test set s.this_prop = 1234 remove e:A:B delete s return r, s.node_prop skip 10 limit 10", cypherQueryStr) + + preparedQuery, err = v2.New().Create( + v2.Node().NodePattern(graph.Kinds{graph.StringKind("A")}, cypher.NewParameter("props", map[string]any{})), + ).Build() + + require.NoError(t, err) + + cypherQueryStr, err = format.RegularQuery(preparedQuery.Query, false) + require.NoError(t, err) + + require.Equal(t, "create (n:A $props)", cypherQueryStr) +} diff --git a/query/v2/util.go b/query/v2/util.go new file mode 100644 index 00000000..d8bdaaba --- /dev/null +++ b/query/v2/util.go @@ -0,0 +1,174 @@ +package v2 + +import ( + "errors" + "fmt" + + "github.com/specterops/dawgs/cypher/models/cypher" + "github.com/specterops/dawgs/cypher/models/walk" + "github.com/specterops/dawgs/graph" +) + +func isNodePattern(seen *identifierSet) bool { + return seen.Contains(Identifiers.node) +} + +func isRelationshipPattern(seen *identifierSet) bool { + var ( + hasStart = seen.Contains(Identifiers.start) + hasRelationship = seen.Contains(Identifiers.relationship) + hasEnd = seen.Contains(Identifiers.end) + ) + + return hasStart || hasRelationship || hasEnd +} + +func prepareNodePattern(match *cypher.Match, seen *identifierSet) error { + if isRelationshipPattern(seen) { + 
return fmt.Errorf("query mixes node and relationship query identifiers") + } + + match.NewPatternPart().AddPatternElements(&cypher.NodePattern{ + Variable: Identifiers.Node(), + }) + + return nil +} + +func prepareRelationshipPattern(match *cypher.Match, seen *identifierSet, relationshipKinds graph.Kinds, shortestPaths, allShortestPaths bool) error { + if shortestPaths && allShortestPaths { + return errors.New("query is requesting both all shortest paths and shortest paths") + } + + var ( + newPatternPart = match.NewPatternPart() + startNodeSeen = seen.Contains(Identifiers.start) + relationshipSeen = seen.Contains(Identifiers.relationship) + endNodeSeen = seen.Contains(Identifiers.end) + ) + + newPatternPart.ShortestPathPattern = shortestPaths + newPatternPart.AllShortestPathsPattern = allShortestPaths + + if startNodeSeen { + newPatternPart.AddPatternElements(&cypher.NodePattern{ + Variable: Identifiers.Start(), + }) + } else { + newPatternPart.AddPatternElements(&cypher.NodePattern{}) + } + + relationshipPattern := &cypher.RelationshipPattern{ + Kinds: relationshipKinds, + Direction: graph.DirectionOutbound, + } + + if relationshipSeen { + relationshipPattern.Variable = Identifiers.Relationship() + } + + if shortestPaths || allShortestPaths { + newPatternPart.Variable = Identifiers.Path() + relationshipPattern.Range = &cypher.PatternRange{} + } + + newPatternPart.AddPatternElements(relationshipPattern) + + if endNodeSeen { + newPatternPart.AddPatternElements(&cypher.NodePattern{ + Variable: Identifiers.End(), + }) + } else { + newPatternPart.AddPatternElements(&cypher.NodePattern{}) + } + + return nil +} + +type identifierSet struct { + identifiers map[string]struct{} +} + +func newIdentifierSet() *identifierSet { + return &identifierSet{ + identifiers: map[string]struct{}{}, + } +} + +func (s *identifierSet) Add(identifier string) { + s.identifiers[identifier] = struct{}{} +} + +func (s *identifierSet) Or(other *identifierSet) { + for otherIdentifier := range 
other.identifiers { + s.identifiers[otherIdentifier] = struct{}{} + } +} + +func (s *identifierSet) Contains(identifier string) bool { + _, containsIdentifier := s.identifiers[identifier] + return containsIdentifier +} + +func (s *identifierSet) CollectFromExpression(expr cypher.Expression) error { + if exprIdentifiers, err := extractCypherIdentifiers(expr); err != nil { + return err + } else { + s.Or(exprIdentifiers) + return nil + } +} + +type identifierExtractor struct { + walk.Visitor[cypher.SyntaxNode] + + seen *identifierSet + + inDelete bool + inUpdate bool + inCreate bool + inWhere bool +} + +func newIdentifierExtractor() *identifierExtractor { + return &identifierExtractor{ + Visitor: walk.NewVisitor[cypher.SyntaxNode](), + seen: newIdentifierSet(), + } +} + +func (s *identifierExtractor) Enter(node cypher.SyntaxNode) { + switch typedNode := node.(type) { + case *cypher.Variable: + s.seen.Add(typedNode.Symbol) + + case *cypher.NodePattern: + if typedNode.Variable != nil { + s.seen.Add(typedNode.Variable.Symbol) + } + + case *cypher.RelationshipPattern: + if typedNode.Variable != nil { + s.seen.Add(typedNode.Variable.Symbol) + } + + case *cypher.PatternPart: + if typedNode.Variable != nil { + s.seen.Add(typedNode.Variable.Symbol) + } + + case *cypher.ProjectionItem: + if typedNode.Alias != nil { + s.seen.Add(typedNode.Alias.Symbol) + } + } +} + +func extractCypherIdentifiers(expression cypher.Expression) (*identifierSet, error) { + var ( + identifierExtractorVisitor = newIdentifierExtractor() + err = walk.Cypher(expression, identifierExtractorVisitor) + ) + + return identifierExtractorVisitor.seen, err +} From db03b7bc33761fa209b3aca2f07da2183efc4674 Mon Sep 17 00:00:00 2001 From: John Hopper Date: Wed, 6 May 2026 23:32:30 -0700 Subject: [PATCH 02/55] feat(query/v2): materialize builder parameters --- query/v2/query.go | 77 ++++++++++++++++++++++++++++++++++++++---- query/v2/query_test.go | 10 +++++- query/v2/util.go | 55 ++++++++++++++++++++++++++++++ 3 
files changed, 134 insertions(+), 8 deletions(-) diff --git a/query/v2/query.go b/query/v2/query.go index 75eb0d44..5710c9bd 100644 --- a/query/v2/query.go +++ b/query/v2/query.go @@ -44,7 +44,7 @@ var Identifiers = runtimeIdentifiers{ end: "e", } -func newLiteral(value any) *cypher.Literal { +func Literal(value any) *cypher.Literal { if value == nil { return cypher.NewLiteral(nil, true) } @@ -56,6 +56,65 @@ func newLiteral(value any) *cypher.Literal { return cypher.NewLiteral(value, false) } +func Parameter(value any) *cypher.Parameter { + if parameter, typeOK := value.(*cypher.Parameter); typeOK { + return parameter + } + + return &cypher.Parameter{ + Value: value, + } +} + +func NamedParameter(symbol string, value any) *cypher.Parameter { + return cypher.NewParameter(symbol, value) +} + +func valueExpression(value any) cypher.Expression { + switch typedValue := value.(type) { + case *cypher.Parameter: + return typedValue + case *cypher.Literal: + return typedValue + case *cypher.Variable: + return typedValue + case *cypher.PropertyLookup: + return typedValue + case *cypher.FunctionInvocation: + return typedValue + case *cypher.Parenthetical: + return typedValue + case *cypher.Comparison: + return typedValue + case *cypher.Negation: + return typedValue + case *cypher.Conjunction: + return typedValue + case *cypher.Disjunction: + return typedValue + case *cypher.ExclusiveDisjunction: + return typedValue + case *cypher.KindMatcher: + return typedValue + case *cypher.ListLiteral: + return typedValue + case cypher.MapLiteral: + return typedValue + case *cypher.PatternPredicate: + return typedValue + case *cypher.ArithmeticExpression: + return typedValue + case *cypher.UnaryAddOrSubtractExpression: + return typedValue + case *cypher.FilterExpression: + return typedValue + case *cypher.IDInCollection: + return typedValue + default: + return Parameter(value) + } +} + func joinedExpressionList(operator cypher.Operator, operands []cypher.SyntaxNode) cypher.SyntaxNode { 
expressionList := &cypher.Comparison{} @@ -171,7 +230,7 @@ func (s *comparisonContinuation) asComparison(operator cypher.Operator, rOperand return cypher.NewComparison( s.qualifier(), operator, - newLiteral(rOperand), + valueExpression(rOperand), ) } @@ -211,7 +270,7 @@ func (s *propertyContinuation) Set(value any) *cypher.SetItem { return cypher.NewSetItem( s.qualifier(), cypher.OperatorAssignment, - newLiteral(value), + valueExpression(value), ) } @@ -764,8 +823,12 @@ func (s *builder) Build() (*PreparedQuery, error) { singlePartQuery.ReadingClauses = append(singlePartQuery.ReadingClauses, newReadingClause) } - return &PreparedQuery{ - Query: regularQuery, - Parameters: map[string]any{}, - }, nil + if parameters, err := materializeParameters(regularQuery); err != nil { + return nil, err + } else { + return &PreparedQuery{ + Query: regularQuery, + Parameters: parameters, + }, nil + } } diff --git a/query/v2/query_test.go b/query/v2/query_test.go index 76e793f9..d464f30c 100644 --- a/query/v2/query_test.go +++ b/query/v2/query_test.go @@ -31,7 +31,12 @@ func TestQuery(t *testing.T) { cypherQueryStr, err := format.RegularQuery(preparedQuery.Query, false) require.NoError(t, err) - require.Equal(t, "match (s)-[r]->() where not r:test and not (r:A or r:B) and r.rel_prop <= 1234 and r.other_prop = 5678 and s:test set s.this_prop = 1234 remove e:A:B delete s return r, s.node_prop skip 10 limit 10", cypherQueryStr) + require.Equal(t, "match (s)-[r]->() where not r:test and not (r:A or r:B) and r.rel_prop <= $p0 and r.other_prop = $p1 and s:test set s.this_prop = $p2 remove e:A:B delete s return r, s.node_prop skip 10 limit 10", cypherQueryStr) + require.Equal(t, map[string]any{ + "p0": 1234, + "p1": 5678, + "p2": 1234, + }, preparedQuery.Parameters) preparedQuery, err = v2.New().Create( v2.Node().NodePattern(graph.Kinds{graph.StringKind("A")}, cypher.NewParameter("props", map[string]any{})), @@ -43,4 +48,7 @@ func TestQuery(t *testing.T) { require.NoError(t, err) 
require.Equal(t, "create (n:A $props)", cypherQueryStr) + require.Equal(t, map[string]any{ + "props": map[string]any{}, + }, preparedQuery.Parameters) } diff --git a/query/v2/util.go b/query/v2/util.go index d8bdaaba..6fe35983 100644 --- a/query/v2/util.go +++ b/query/v2/util.go @@ -3,6 +3,8 @@ package v2 import ( "errors" "fmt" + "reflect" + "strconv" "github.com/specterops/dawgs/cypher/models/cypher" "github.com/specterops/dawgs/cypher/models/walk" @@ -172,3 +174,56 @@ func extractCypherIdentifiers(expression cypher.Expression) (*identifierSet, err return identifierExtractorVisitor.seen, err } + +type parameterMaterializer struct { + walk.Visitor[cypher.SyntaxNode] + + parameters map[string]any + nextIndex int +} + +func newParameterMaterializer() *parameterMaterializer { + return &parameterMaterializer{ + Visitor: walk.NewVisitor[cypher.SyntaxNode](), + parameters: map[string]any{}, + } +} + +func (s *parameterMaterializer) nextSymbol() string { + for { + symbol := "p" + strconv.Itoa(s.nextIndex) + s.nextIndex++ + + if _, taken := s.parameters[symbol]; !taken { + return symbol + } + } +} + +func (s *parameterMaterializer) Enter(node cypher.SyntaxNode) { + parameter, typeOK := node.(*cypher.Parameter) + if !typeOK { + return + } + + if parameter.Symbol == "" { + parameter.Symbol = s.nextSymbol() + } + + if existingValue, exists := s.parameters[parameter.Symbol]; exists && !reflect.DeepEqual(existingValue, parameter.Value) { + s.SetErrorf("parameter %s is bound to multiple values", parameter.Symbol) + return + } + + s.parameters[parameter.Symbol] = parameter.Value +} + +func materializeParameters(query *cypher.RegularQuery) (map[string]any, error) { + materializer := newParameterMaterializer() + + if err := walk.Cypher(query, materializer); err != nil { + return nil, err + } + + return materializer.parameters, nil +} From f47a5c23310ae15b17bd324db07519854c701e65 Mon Sep 17 00:00:00 2001 From: John Hopper Date: Wed, 6 May 2026 23:34:20 -0700 Subject: [PATCH 03/55] 
fix(query/v2): infer match patterns by query scope --- query/v2/query.go | 32 ++++++-- query/v2/query_test.go | 42 ++++++++++ query/v2/util.go | 169 +++++++++++++++++++++++++++++++++++++++++ 3 files changed, 235 insertions(+), 8 deletions(-) diff --git a/query/v2/query.go b/query/v2/query.go index 5710c9bd..9e913b0c 100644 --- a/query/v2/query.go +++ b/query/v2/query.go @@ -753,10 +753,15 @@ func (s *builder) Build() (*PreparedQuery, error) { var ( regularQuery, singlePartQuery = cypher.NewRegularQueryWithSingleQuery() match = &cypher.Match{} - seenIdentifiers = newIdentifierSet() + readIdentifiers = newIdentifierSet() relationshipKinds graph.Kinds ) + createScope, err := collectCreateScope(s.creates...) + if err != nil { + return nil, err + } + if err := s.buildUpdatingClauses(singlePartQuery); err != nil { return nil, err } @@ -778,6 +783,7 @@ func (s *builder) Build() (*PreparedQuery, error) { return nil, fmt.Errorf("expected type *cypher.Variable, got %T", typedNextConstraint) } else if identifier.Symbol == Identifiers.relationship { relationshipKinds = relationshipKinds.Add(typedNextConstraint.Kinds...) 
+ readIdentifiers.Add(Identifiers.relationship) continue } } @@ -792,23 +798,33 @@ func (s *builder) Build() (*PreparedQuery, error) { if constraints.Left != nil { whereClause.Add(constraints) - if err := seenIdentifiers.CollectFromExpression(whereClause); err != nil { + if err := readIdentifiers.CollectFromExpression(whereClause); err != nil { return nil, err } } } - if err := seenIdentifiers.CollectFromExpression(singlePartQuery); err != nil { + actionIdentifiers, err := collectIdentifiersFromValues(s.setItems, s.removeItems, s.deleteItems, s.projections, s.sortItems) + if err != nil { return nil, err } - if len(s.constraints) > 0 || len(s.creates) == 0 { - if isNodePattern(seenIdentifiers) { - if err := prepareNodePattern(match, seenIdentifiers); err != nil { + actionIdentifiers.Remove(createScope.identifiers) + + matchIdentifiers := readIdentifiers.Clone() + matchIdentifiers.Or(actionIdentifiers) + + if len(s.constraints) > 0 || len(s.creates) == 0 || matchIdentifiers.Len() > 0 { + if isNodePattern(matchIdentifiers) { + if err := prepareNodePattern(match, matchIdentifiers); err != nil { + return nil, err + } + } else if createScope.createsRelationship && !matchIdentifiers.Contains(Identifiers.relationship) { + if err := prepareCreateRelationshipMatch(match, matchIdentifiers); err != nil { return nil, err } - } else if isRelationshipPattern(seenIdentifiers) { - if err := prepareRelationshipPattern(match, seenIdentifiers, relationshipKinds, s.shortestPathQuery, s.allShorestPathsQuery); err != nil { + } else if isRelationshipPattern(matchIdentifiers) { + if err := prepareRelationshipPattern(match, matchIdentifiers, relationshipKinds, s.shortestPathQuery, s.allShorestPathsQuery); err != nil { return nil, err } } else { diff --git a/query/v2/query_test.go b/query/v2/query_test.go index d464f30c..743be443 100644 --- a/query/v2/query_test.go +++ b/query/v2/query_test.go @@ -10,6 +10,15 @@ import ( "github.com/stretchr/testify/require" ) +func renderPrepared(t 
*testing.T, preparedQuery *v2.PreparedQuery) string { + t.Helper() + + cypherQueryStr, err := format.RegularQuery(preparedQuery.Query, false) + require.NoError(t, err) + + return cypherQueryStr +} + func TestQuery(t *testing.T) { preparedQuery, err := v2.New().Where( v2.Not(v2.Relationship().Kind().Is(graph.StringKind("test"))), @@ -52,3 +61,36 @@ func TestQuery(t *testing.T) { "props": map[string]any{}, }, preparedQuery.Parameters) } + +func TestCreateRelationshipWithMatchedEndpoints(t *testing.T) { + preparedQuery, err := v2.New().Where( + v2.Start().ID().Equals(1), + v2.End().ID().Equals(2), + ).Create( + v2.Relationship().RelationshipPattern(graph.StringKind("A"), v2.NamedParameter("props", map[string]any{"name": "rel"}), graph.DirectionOutbound), + ).Return( + v2.Relationship().ID(), + ).Build() + require.NoError(t, err) + + require.Equal(t, "match (s), (e) where id(s) = $p0 and id(e) = $p1 create (s)-[r:A $props]->(e) return id(r)", renderPrepared(t, preparedQuery)) + require.Equal(t, map[string]any{ + "p0": 1, + "p1": 2, + "props": map[string]any{"name": "rel"}, + }, preparedQuery.Parameters) +} + +func TestCreateNodeReturnDoesNotCreateMatch(t *testing.T) { + preparedQuery, err := v2.New().Create( + v2.Node().NodePattern(graph.Kinds{graph.StringKind("A")}, v2.NamedParameter("props", map[string]any{"name": "node"})), + ).Return( + v2.Node().ID(), + ).Build() + require.NoError(t, err) + + require.Equal(t, "create (n:A $props) return id(n)", renderPrepared(t, preparedQuery)) + require.Equal(t, map[string]any{ + "props": map[string]any{"name": "node"}, + }, preparedQuery.Parameters) +} diff --git a/query/v2/util.go b/query/v2/util.go index 6fe35983..75949e65 100644 --- a/query/v2/util.go +++ b/query/v2/util.go @@ -87,6 +87,22 @@ func prepareRelationshipPattern(match *cypher.Match, seen *identifierSet, relati return nil } +func prepareCreateRelationshipMatch(match *cypher.Match, seen *identifierSet) error { + if seen.Contains(Identifiers.start) { + 
match.NewPatternPart().AddPatternElements(&cypher.NodePattern{ + Variable: Identifiers.Start(), + }) + } + + if seen.Contains(Identifiers.end) { + match.NewPatternPart().AddPatternElements(&cypher.NodePattern{ + Variable: Identifiers.End(), + }) + } + + return nil +} + type identifierSet struct { identifiers map[string]struct{} } @@ -101,12 +117,28 @@ func (s *identifierSet) Add(identifier string) { s.identifiers[identifier] = struct{}{} } +func (s *identifierSet) Len() int { + return len(s.identifiers) +} + +func (s *identifierSet) Clone() *identifierSet { + clone := newIdentifierSet() + clone.Or(s) + return clone +} + func (s *identifierSet) Or(other *identifierSet) { for otherIdentifier := range other.identifiers { s.identifiers[otherIdentifier] = struct{}{} } } +func (s *identifierSet) Remove(other *identifierSet) { + for otherIdentifier := range other.identifiers { + delete(s.identifiers, otherIdentifier) + } +} + func (s *identifierSet) Contains(identifier string) bool { _, containsIdentifier := s.identifiers[identifier] return containsIdentifier @@ -121,6 +153,143 @@ func (s *identifierSet) CollectFromExpression(expr cypher.Expression) error { } } +func (s *identifierSet) CollectFromValue(value any) error { + switch typedValue := value.(type) { + case nil: + return nil + + case QualifiedExpression: + return s.CollectFromExpression(typedValue.qualifier()) + + case kindContinuation: + s.Add(typedValue.identifier.Symbol) + return nil + + case kindsContinuation: + s.Add(typedValue.identifier.Symbol) + return nil + + case *cypher.Return: + return s.CollectFromExpression(typedValue) + + case *cypher.Order: + return s.CollectFromExpression(typedValue) + + case *cypher.SortItem: + return s.CollectFromExpression(typedValue) + + case *cypher.Set: + return s.CollectFromExpression(typedValue) + + case *cypher.SetItem: + return s.CollectFromExpression(typedValue) + + case *cypher.Remove: + return s.CollectFromExpression(typedValue) + + case *cypher.RemoveItem: + return 
s.CollectFromExpression(typedValue) + + case *cypher.NodePattern: + return s.CollectFromExpression(typedValue) + + case *cypher.RelationshipPattern: + return s.CollectFromExpression(typedValue) + + case *cypher.Variable: + return s.CollectFromExpression(typedValue) + + case *cypher.FunctionInvocation: + return s.CollectFromExpression(typedValue) + + case *cypher.PropertyLookup: + return s.CollectFromExpression(typedValue) + + case []any: + for _, item := range typedValue { + if err := s.CollectFromValue(item); err != nil { + return err + } + } + + case []cypher.SyntaxNode: + for _, item := range typedValue { + if err := s.CollectFromValue(item); err != nil { + return err + } + } + + case []cypher.Expression: + for _, item := range typedValue { + if err := s.CollectFromValue(item); err != nil { + return err + } + } + + case []*cypher.SetItem: + for _, item := range typedValue { + if err := s.CollectFromValue(item); err != nil { + return err + } + } + + case []*cypher.RemoveItem: + for _, item := range typedValue { + if err := s.CollectFromValue(item); err != nil { + return err + } + } + + default: + return nil + } + + return nil +} + +func collectIdentifiersFromValues(values ...any) (*identifierSet, error) { + identifiers := newIdentifierSet() + + for _, value := range values { + if err := identifiers.CollectFromValue(value); err != nil { + return nil, err + } + } + + return identifiers, nil +} + +type createScope struct { + identifiers *identifierSet + createsRelationship bool +} + +func collectCreateScope(values ...any) (*createScope, error) { + scope := &createScope{ + identifiers: newIdentifierSet(), + } + + for _, value := range values { + switch typedValue := value.(type) { + case *cypher.RelationshipPattern: + scope.createsRelationship = true + scope.identifiers.Add(Identifiers.start) + scope.identifiers.Add(Identifiers.end) + + if typedValue.Variable != nil { + scope.identifiers.Add(typedValue.Variable.Symbol) + } + + default: + if err := 
scope.identifiers.CollectFromValue(value); err != nil { + return nil, err + } + } + } + + return scope, nil +} + type identifierExtractor struct { walk.Visitor[cypher.SyntaxNode] From f656bc09f603443fcdbda3ccd067ac13140a3adb Mon Sep 17 00:00:00 2001 From: John Hopper Date: Wed, 6 May 2026 23:35:54 -0700 Subject: [PATCH 04/55] feat(query/v2): add typed projection and order helpers --- query/v2/query.go | 94 +++++++++++++++++------------ query/v2/query_test.go | 19 +++++- query/v2/util.go | 130 +++++++++++++++++++++++++++++++++++++++++ 3 files changed, 206 insertions(+), 37 deletions(-) diff --git a/query/v2/query.go b/query/v2/query.go index 9e913b0c..8ad8fdc0 100644 --- a/query/v2/query.go +++ b/query/v2/query.go @@ -141,6 +141,39 @@ func Or(operands ...cypher.SyntaxNode) cypher.SyntaxNode { return joinedExpressionList(cypher.OperatorOr, operands) } +type SortDirection int + +const ( + SortAscending SortDirection = iota + SortDescending +) + +func Asc(expression any) *cypher.SortItem { + return Order(expression, SortAscending) +} + +func Desc(expression any) *cypher.SortItem { + return Order(expression, SortDescending) +} + +func Order(expression any, direction SortDirection) *cypher.SortItem { + sortExpression, _ := projectionExpression(expression) + + return &cypher.SortItem{ + Ascending: direction != SortDescending, + Expression: sortExpression, + } +} + +func As(expression any, alias string) *cypher.ProjectionItem { + projectionExpression, _ := projectionExpression(expression) + + return &cypher.ProjectionItem{ + Expression: projectionExpression, + Alias: cypher.NewVariableWithSymbol(alias), + } +} + func Node() NodeContinuation { return &entity[NodeContinuation]{ identifier: Identifiers.Node(), @@ -427,10 +460,11 @@ type NodeContinuation interface { type QueryBuilder interface { Where(constraints ...cypher.SyntaxNode) QueryBuilder - OrderBy(sortItems ...cypher.SyntaxNode) QueryBuilder + OrderBy(sortItems ...any) QueryBuilder Skip(offset int) QueryBuilder 
Limit(limit int) QueryBuilder Return(projections ...any) QueryBuilder + ReturnDistinct(projections ...any) QueryBuilder Update(updatingClauses ...any) QueryBuilder Create(creationClauses ...any) QueryBuilder Delete(expressions ...any) QueryBuilder @@ -442,8 +476,9 @@ type QueryBuilder interface { type builder struct { errors []error constraints []cypher.SyntaxNode - sortItems []cypher.SyntaxNode + sortItems []any projections []any + distinct bool creates []any setItems []*cypher.SetItem removeItems []*cypher.RemoveItem @@ -469,7 +504,7 @@ func (s *builder) WithAllShortestPaths() QueryBuilder { return s } -func (s *builder) OrderBy(sortItems ...cypher.SyntaxNode) QueryBuilder { +func (s *builder) OrderBy(sortItems ...any) QueryBuilder { s.sortItems = append(s.sortItems, sortItems...) return s } @@ -489,6 +524,12 @@ func (s *builder) Return(projections ...any) QueryBuilder { return s } +func (s *builder) ReturnDistinct(projections ...any) QueryBuilder { + s.distinct = true + s.projections = append(s.projections, projections...) + return s +} + func (s *builder) Create(creationClauses ...any) QueryBuilder { s.creates = append(s.creates, creationClauses...) 
return s @@ -523,8 +564,7 @@ func (s *builder) Delete(deleteItems ...any) QueryBuilder { case QualifiedExpression: qualifier := typedNextUpdate.qualifier() - switch qualifier { - case Identifiers.node, Identifiers.start, Identifiers.end: + if isDetachDeleteQualifier(qualifier) { s.detachDelete = true } @@ -647,6 +687,13 @@ func (s *builder) buildProjectionOrder() (*cypher.Order, error) { case *cypher.SortItem: orderByNode.Items = append(orderByNode.Items, typedSortItem) + + default: + if sortItem, err := sortItemFromValue(typedSortItem); err != nil { + return nil, err + } else { + orderByNode.Items = append(orderByNode.Items, sortItem) + } } } } @@ -667,7 +714,7 @@ func (s *builder) buildProjection(singlePartQuery *cypher.SinglePartQuery) error return fmt.Errorf("query expected projected items") } - projection := singlePartQuery.NewProjection(false) + projection := singlePartQuery.NewProjection(s.distinct) for _, nextProjection := range s.projections { switch typedNextProjection := nextProjection.(type) { @@ -680,37 +727,12 @@ func (s *builder) buildProjection(singlePartQuery *cypher.SinglePartQuery) error } } - case QualifiedExpression: - projection.AddItem(cypher.NewProjectionItemWithExpr(typedNextProjection.qualifier())) - - case kindContinuation: - var kindExpr cypher.Expression - - switch typedNextProjection.identifier.Symbol { - case Identifiers.node, Identifiers.start, Identifiers.end: - kindExpr = cypher.NewSimpleFunctionInvocation(cypher.NodeLabelsFunction, typedNextProjection.identifier) - - case Identifiers.relationship: - kindExpr = cypher.NewSimpleFunctionInvocation(cypher.EdgeTypeFunction, typedNextProjection.identifier) - } - - projection.AddItem(cypher.NewProjectionItemWithExpr(kindExpr)) - - case kindsContinuation: - var kindExpr cypher.Expression - - switch typedNextProjection.identifier.Symbol { - case Identifiers.node, Identifiers.start, Identifiers.end: - kindExpr = cypher.NewSimpleFunctionInvocation(cypher.NodeLabelsFunction, 
typedNextProjection.identifier) - - case Identifiers.relationship: - kindExpr = cypher.NewSimpleFunctionInvocation(cypher.EdgeTypeFunction, typedNextProjection.identifier) - } - - projection.AddItem(cypher.NewProjectionItemWithExpr(kindExpr)) - default: - projection.AddItem(cypher.NewProjectionItemWithExpr(typedNextProjection)) + if projectionItem, err := projectionItemFromValue(typedNextProjection); err != nil { + return err + } else { + projection.AddItem(projectionItem) + } } } diff --git a/query/v2/query_test.go b/query/v2/query_test.go index 743be443..5b58e36d 100644 --- a/query/v2/query_test.go +++ b/query/v2/query_test.go @@ -40,7 +40,7 @@ func TestQuery(t *testing.T) { cypherQueryStr, err := format.RegularQuery(preparedQuery.Query, false) require.NoError(t, err) - require.Equal(t, "match (s)-[r]->() where not r:test and not (r:A or r:B) and r.rel_prop <= $p0 and r.other_prop = $p1 and s:test set s.this_prop = $p2 remove e:A:B delete s return r, s.node_prop skip 10 limit 10", cypherQueryStr) + require.Equal(t, "match (s)-[r]->() where not r:test and not (r:A or r:B) and r.rel_prop <= $p0 and r.other_prop = $p1 and s:test set s.this_prop = $p2 remove e:A:B detach delete s return r, s.node_prop skip 10 limit 10", cypherQueryStr) require.Equal(t, map[string]any{ "p0": 1234, "p1": 5678, @@ -94,3 +94,20 @@ func TestCreateNodeReturnDoesNotCreateMatch(t *testing.T) { "props": map[string]any{"name": "node"}, }, preparedQuery.Parameters) } + +func TestProjectionAndOrderHelpers(t *testing.T) { + preparedQuery, err := v2.New().ReturnDistinct( + v2.As(v2.Node().ID(), "node_id"), + ).OrderBy( + v2.Node().Property("name"), + v2.Desc(v2.Node().ID()), + ).Build() + require.NoError(t, err) + + require.Equal(t, "match (n) return distinct id(n) as node_id order by n.name asc, id(n) desc", renderPrepared(t, preparedQuery)) +} + +func TestUnsupportedOrderByTypeReturnsError(t *testing.T) { + _, err := v2.New().Return(v2.Node()).OrderBy(123).Build() + require.ErrorContains(t, err, 
"unsupported expression type: int") +} diff --git a/query/v2/util.go b/query/v2/util.go index 75949e65..d9e8fe16 100644 --- a/query/v2/util.go +++ b/query/v2/util.go @@ -103,6 +103,136 @@ func prepareCreateRelationshipMatch(match *cypher.Match, seen *identifierSet) er return nil } +func isDetachDeleteQualifier(qualifier cypher.Expression) bool { + variable, typeOK := qualifier.(*cypher.Variable) + if !typeOK { + return false + } + + switch variable.Symbol { + case Identifiers.node, Identifiers.start, Identifiers.end: + return true + default: + return false + } +} + +func kindProjectionExpression(identifier *cypher.Variable) (cypher.Expression, error) { + switch identifier.Symbol { + case Identifiers.node, Identifiers.start, Identifiers.end: + return cypher.NewSimpleFunctionInvocation(cypher.NodeLabelsFunction, identifier), nil + + case Identifiers.relationship: + return cypher.NewSimpleFunctionInvocation(cypher.EdgeTypeFunction, identifier), nil + + default: + return nil, fmt.Errorf("invalid kind projection reference: %s", identifier.Symbol) + } +} + +func projectionExpression(value any) (cypher.Expression, error) { + switch typedValue := value.(type) { + case QualifiedExpression: + return typedValue.qualifier(), nil + + case kindContinuation: + return kindProjectionExpression(typedValue.identifier) + + case kindsContinuation: + return kindProjectionExpression(typedValue.identifier) + + case *cypher.ProjectionItem: + return typedValue.Expression, nil + + case *cypher.Parameter: + return typedValue, nil + + case *cypher.Literal: + return typedValue, nil + + case *cypher.Variable: + return typedValue, nil + + case *cypher.PropertyLookup: + return typedValue, nil + + case *cypher.FunctionInvocation: + return typedValue, nil + + case *cypher.Parenthetical: + return typedValue, nil + + case *cypher.Comparison: + return typedValue, nil + + case *cypher.Negation: + return typedValue, nil + + case *cypher.Conjunction: + return typedValue, nil + + case *cypher.Disjunction: 
+ return typedValue, nil + + case *cypher.ExclusiveDisjunction: + return typedValue, nil + + case *cypher.KindMatcher: + return typedValue, nil + + case *cypher.ListLiteral: + return typedValue, nil + + case cypher.MapLiteral: + return typedValue, nil + + case *cypher.PatternPredicate: + return typedValue, nil + + case *cypher.ArithmeticExpression: + return typedValue, nil + + case *cypher.UnaryAddOrSubtractExpression: + return typedValue, nil + + case *cypher.FilterExpression: + return typedValue, nil + + case *cypher.IDInCollection: + return typedValue, nil + + default: + return nil, fmt.Errorf("unsupported expression type: %T", value) + } +} + +func projectionItemFromValue(value any) (*cypher.ProjectionItem, error) { + if projectionItem, typeOK := value.(*cypher.ProjectionItem); typeOK { + return projectionItem, nil + } + + if expression, err := projectionExpression(value); err != nil { + return nil, err + } else { + return cypher.NewProjectionItemWithExpr(expression), nil + } +} + +func sortItemFromValue(value any) (*cypher.SortItem, error) { + if sortItem, typeOK := value.(*cypher.SortItem); typeOK { + return sortItem, nil + } + + if expression, err := projectionExpression(value); err != nil { + return nil, err + } else { + return &cypher.SortItem{ + Ascending: true, + Expression: expression, + }, nil + } +} + type identifierSet struct { identifiers map[string]struct{} } From 39aa3fc624a14b35133e5b01a6c745e4ce9ccfe4 Mon Sep 17 00:00:00 2001 From: John Hopper Date: Wed, 6 May 2026 23:37:30 -0700 Subject: [PATCH 05/55] feat(query/v2): add query helper parity --- query/v2/compat.go | 320 +++++++++++++++++++++++++++++++++++++++++ query/v2/query.go | 20 +++ query/v2/query_test.go | 37 +++++ 3 files changed, 377 insertions(+) create mode 100644 query/v2/compat.go diff --git a/query/v2/compat.go b/query/v2/compat.go new file mode 100644 index 00000000..bee8cf19 --- /dev/null +++ b/query/v2/compat.go @@ -0,0 +1,320 @@ +package v2 + +import ( + "strings" + "time" + + 
"github.com/specterops/dawgs/cypher/models/cypher" + "github.com/specterops/dawgs/graph" +) + +func Variable(name string) *cypher.Variable { + return cypher.NewVariableWithSymbol(name) +} + +func Identity(reference any) *cypher.FunctionInvocation { + expression, _ := projectionExpression(reference) + + return cypher.NewSimpleFunctionInvocation(cypher.IdentityFunction, expression) +} + +func NodeID() *cypher.FunctionInvocation { + return Identity(Identifiers.Node()) +} + +func RelationshipID() *cypher.FunctionInvocation { + return Identity(Identifiers.Relationship()) +} + +func StartID() *cypher.FunctionInvocation { + return Identity(Identifiers.Start()) +} + +func EndID() *cypher.FunctionInvocation { + return Identity(Identifiers.End()) +} + +func Count(reference any) *cypher.FunctionInvocation { + expression, _ := projectionExpression(reference) + return cypher.NewSimpleFunctionInvocation(cypher.CountFunction, expression) +} + +func CountDistinct(reference any) *cypher.FunctionInvocation { + expression, _ := projectionExpression(reference) + + return &cypher.FunctionInvocation{ + Name: cypher.CountFunction, + Distinct: true, + Arguments: []cypher.Expression{expression}, + } +} + +func Size(expression any) *cypher.FunctionInvocation { + expr, _ := projectionExpression(expression) + return cypher.NewSimpleFunctionInvocation(cypher.ListSizeFunction, expr) +} + +func KindsOf(reference any) *cypher.FunctionInvocation { + expression, _ := projectionExpression(reference) + + switch typedExpression := expression.(type) { + case *cypher.Variable: + switch typedExpression.Symbol { + case Identifiers.node, Identifiers.start, Identifiers.end: + return cypher.NewSimpleFunctionInvocation(cypher.NodeLabelsFunction, typedExpression) + + case Identifiers.relationship: + return cypher.NewSimpleFunctionInvocation(cypher.EdgeTypeFunction, typedExpression) + } + } + + return cypher.NewSimpleFunctionInvocation(cypher.NodeLabelsFunction, expression) +} + +func Kind(reference any, kinds 
...graph.Kind) *cypher.KindMatcher { + expression, _ := projectionExpression(reference) + + return &cypher.KindMatcher{ + Reference: expression, + Kinds: kinds, + } +} + +func KindIn(reference any, kinds ...graph.Kind) *cypher.KindMatcher { + return Kind(reference, kinds...) +} + +func AddKind(reference any, kind graph.Kind) *cypher.SetItem { + return AddKinds(reference, graph.Kinds{kind}) +} + +func AddKinds(reference any, kinds graph.Kinds) *cypher.SetItem { + expression, _ := projectionExpression(reference) + return cypher.NewSetItem(expression, cypher.OperatorLabelAssignment, kinds) +} + +func DeleteKind(reference any, kind graph.Kind) *cypher.RemoveItem { + return DeleteKinds(reference, graph.Kinds{kind}) +} + +func DeleteKinds(reference any, kinds graph.Kinds) *cypher.RemoveItem { + expression, _ := projectionExpression(reference) + return cypher.RemoveKindsByMatcher(cypher.NewKindMatcher(expression, kinds, false)) +} + +func SetProperty(reference any, value any) *cypher.SetItem { + expression, _ := projectionExpression(reference) + return cypher.NewSetItem(expression, cypher.OperatorAssignment, valueExpression(value)) +} + +func SetProperties(reference any, properties map[string]any) *cypher.Set { + set := &cypher.Set{} + expression, _ := projectionExpression(reference) + variable, _ := expression.(*cypher.Variable) + + for key, value := range properties { + set.Items = append(set.Items, cypher.NewSetItem( + cypher.NewPropertyLookup(variable.Symbol, key), + cypher.OperatorAssignment, + valueExpression(value), + )) + } + + return set +} + +func DeleteProperty(reference any) *cypher.RemoveItem { + expression, _ := projectionExpression(reference) + return cypher.RemoveProperty(expression) +} + +func DeleteProperties(reference any, propertyNames ...string) *cypher.Remove { + remove := &cypher.Remove{} + expression, _ := projectionExpression(reference) + variable, _ := expression.(*cypher.Variable) + + for _, propertyName := range propertyNames { + remove.Items = 
append(remove.Items, cypher.RemoveProperty(cypher.NewPropertyLookup(variable.Symbol, propertyName))) + } + + return remove +} + +func NodePattern(kinds graph.Kinds, properties cypher.Expression) *cypher.NodePattern { + return &cypher.NodePattern{ + Variable: Identifiers.Node(), + Kinds: kinds, + Properties: properties, + } +} + +func StartNodePattern(kinds graph.Kinds, properties cypher.Expression) *cypher.NodePattern { + return &cypher.NodePattern{ + Variable: Identifiers.Start(), + Kinds: kinds, + Properties: properties, + } +} + +func EndNodePattern(kinds graph.Kinds, properties cypher.Expression) *cypher.NodePattern { + return &cypher.NodePattern{ + Variable: Identifiers.End(), + Kinds: kinds, + Properties: properties, + } +} + +func RelationshipPattern(kind graph.Kind, properties cypher.Expression, direction graph.Direction) *cypher.RelationshipPattern { + return &cypher.RelationshipPattern{ + Variable: Identifiers.Relationship(), + Kinds: graph.Kinds{kind}, + Direction: direction, + Properties: properties, + } +} + +func Equals(reference any, value any) cypher.Expression { + expression, _ := projectionExpression(reference) + return cypher.NewComparison(expression, cypher.OperatorEquals, valueExpression(value)) +} + +func GreaterThan(reference any, value any) cypher.Expression { + expression, _ := projectionExpression(reference) + return cypher.NewComparison(expression, cypher.OperatorGreaterThan, valueExpression(value)) +} + +func After(reference any, value any) cypher.Expression { + return GreaterThan(reference, value) +} + +func GreaterThanOrEqualTo(reference any, value any) cypher.Expression { + expression, _ := projectionExpression(reference) + return cypher.NewComparison(expression, cypher.OperatorGreaterThanOrEqualTo, valueExpression(value)) +} + +func GreaterThanOrEquals(reference any, value any) cypher.Expression { + return GreaterThanOrEqualTo(reference, value) +} + +func LessThan(reference any, value any) cypher.Expression { + expression, _ := 
projectionExpression(reference) + return cypher.NewComparison(expression, cypher.OperatorLessThan, valueExpression(value)) +} + +func LessThanGraphQuery(reference any, other any) cypher.Expression { + return LessThan(reference, other) +} + +func Before(reference any, value time.Time) cypher.Expression { + return LessThan(reference, value) +} + +func BeforeGraphQuery(reference any, other any) cypher.Expression { + return LessThan(reference, other) +} + +func LessThanOrEqualTo(reference any, value any) cypher.Expression { + expression, _ := projectionExpression(reference) + return cypher.NewComparison(expression, cypher.OperatorLessThanOrEqualTo, valueExpression(value)) +} + +func LessThanOrEquals(reference any, value any) cypher.Expression { + return LessThanOrEqualTo(reference, value) +} + +func In(reference any, value any) cypher.Expression { + expression, _ := projectionExpression(reference) + return cypher.NewComparison(expression, cypher.OperatorIn, valueExpression(value)) +} + +func InInverted(reference any, value any) cypher.Expression { + expression, _ := projectionExpression(reference) + return cypher.NewComparison(valueExpression(value), cypher.OperatorIn, expression) +} + +func InIDs(reference any, ids ...graph.ID) cypher.Expression { + expression, _ := projectionExpression(reference) + + if variable, typeOK := expression.(*cypher.Variable); typeOK { + expression = Identity(variable) + } + + return cypher.NewComparison(expression, cypher.OperatorIn, Parameter(ids)) +} + +func StringContains(reference any, value string) cypher.Expression { + expression, _ := projectionExpression(reference) + return cypher.NewComparison(expression, cypher.OperatorContains, Parameter(value)) +} + +func StringStartsWith(reference any, value string) cypher.Expression { + expression, _ := projectionExpression(reference) + return cypher.NewComparison(expression, cypher.OperatorStartsWith, Parameter(value)) +} + +func StringEndsWith(reference any, value string) cypher.Expression 
{ + expression, _ := projectionExpression(reference) + return cypher.NewComparison(expression, cypher.OperatorEndsWith, Parameter(value)) +} + +func CaseInsensitiveStringContains(reference any, value string) cypher.Expression { + expression, _ := projectionExpression(reference) + + return cypher.NewComparison( + cypher.NewSimpleFunctionInvocation("toLower", expression), + cypher.OperatorContains, + Parameter(strings.ToLower(value)), + ) +} + +func CaseInsensitiveStringStartsWith(reference any, value string) cypher.Expression { + expression, _ := projectionExpression(reference) + + return cypher.NewComparison( + cypher.NewSimpleFunctionInvocation("toLower", expression), + cypher.OperatorStartsWith, + Parameter(strings.ToLower(value)), + ) +} + +func CaseInsensitiveStringEndsWith(reference any, value string) cypher.Expression { + expression, _ := projectionExpression(reference) + + return cypher.NewComparison( + cypher.NewSimpleFunctionInvocation("toLower", expression), + cypher.OperatorEndsWith, + Parameter(strings.ToLower(value)), + ) +} + +func Exists(reference any) cypher.Expression { + return IsNotNull(reference) +} + +func IsNull(reference any) cypher.Expression { + expression, _ := projectionExpression(reference) + return cypher.NewComparison(expression, cypher.OperatorIs, Literal(nil)) +} + +func IsNotNull(reference any) cypher.Expression { + expression, _ := projectionExpression(reference) + return cypher.NewComparison(expression, cypher.OperatorIsNot, Literal(nil)) +} + +func HasRelationships(reference any) *cypher.PatternPredicate { + expression, _ := projectionExpression(reference) + variable, _ := expression.(*cypher.Variable) + + patternPredicate := cypher.NewPatternPredicate() + patternPredicate.AddElement(&cypher.NodePattern{ + Variable: cypher.NewVariableWithSymbol(variable.Symbol), + }) + + patternPredicate.AddElement(&cypher.RelationshipPattern{ + Direction: graph.DirectionBoth, + }) + + patternPredicate.AddElement(&cypher.NodePattern{}) + + return 
patternPredicate +} diff --git a/query/v2/query.go b/query/v2/query.go index 8ad8fdc0..923f0737 100644 --- a/query/v2/query.go +++ b/query/v2/query.go @@ -231,11 +231,15 @@ type KindsContinuation interface { type Comparable interface { In(value any) cypher.Expression Contains(value any) cypher.Expression + StartsWith(value any) cypher.Expression + EndsWith(value any) cypher.Expression Equals(value any) cypher.Expression GreaterThan(value any) cypher.Expression GreaterThanOrEqualTo(value any) cypher.Expression LessThan(value any) cypher.Expression LessThanOrEqualTo(value any) cypher.Expression + IsNull() cypher.Expression + IsNotNull() cypher.Expression } type PropertyContinuation interface { @@ -275,6 +279,14 @@ func (s *comparisonContinuation) Contains(value any) cypher.Expression { return s.asComparison(cypher.OperatorContains, value) } +func (s *comparisonContinuation) StartsWith(value any) cypher.Expression { + return s.asComparison(cypher.OperatorStartsWith, value) +} + +func (s *comparisonContinuation) EndsWith(value any) cypher.Expression { + return s.asComparison(cypher.OperatorEndsWith, value) +} + func (s *comparisonContinuation) Equals(value any) cypher.Expression { return s.asComparison(cypher.OperatorEquals, value) } @@ -295,6 +307,14 @@ func (s *comparisonContinuation) LessThanOrEqualTo(value any) cypher.Expression return s.asComparison(cypher.OperatorLessThanOrEqualTo, value) } +func (s *comparisonContinuation) IsNull() cypher.Expression { + return cypher.NewComparison(s.qualifier(), cypher.OperatorIs, Literal(nil)) +} + +func (s *comparisonContinuation) IsNotNull() cypher.Expression { + return cypher.NewComparison(s.qualifier(), cypher.OperatorIsNot, Literal(nil)) +} + type propertyContinuation struct { comparisonContinuation } diff --git a/query/v2/query_test.go b/query/v2/query_test.go index 5b58e36d..f764f10b 100644 --- a/query/v2/query_test.go +++ b/query/v2/query_test.go @@ -111,3 +111,40 @@ func TestUnsupportedOrderByTypeReturnsError(t 
*testing.T) { _, err := v2.New().Return(v2.Node()).OrderBy(123).Build() require.ErrorContains(t, err, "unsupported expression type: int") } + +func TestCompatibilityHelpers(t *testing.T) { + preparedQuery, err := v2.New().Where( + v2.And( + v2.InIDs(v2.NodeID(), 1, 2), + v2.KindIn(v2.Node(), graph.StringKind("User")), + v2.CaseInsensitiveStringContains(v2.Node().Property("name"), "ADMIN"), + v2.IsNotNull(v2.Node().Property("enabled")), + ), + ).Return( + v2.CountDistinct(v2.Node()), + v2.KindsOf(v2.Node()), + ).Build() + require.NoError(t, err) + + require.Equal(t, "match (n) where id(n) in $p0 and n:User and toLower(n.name) contains $p1 and n.enabled is not null return count(distinct n), labels(n)", renderPrepared(t, preparedQuery)) + require.Equal(t, map[string]any{ + "p0": []graph.ID{1, 2}, + "p1": "admin", + }, preparedQuery.Parameters) +} + +func TestUpdateCompatibilityHelpers(t *testing.T) { + preparedQuery, err := v2.New().Where( + v2.Node().ID().Equals(1), + ).Update( + v2.AddKind(v2.Node(), graph.StringKind("Enabled")), + v2.SetProperties(v2.Node(), map[string]any{"name": "updated"}), + ).Build() + require.NoError(t, err) + + require.Equal(t, "match (n) where id(n) = $p0 set n:Enabled, n.name = $p1", renderPrepared(t, preparedQuery)) + require.Equal(t, map[string]any{ + "p0": 1, + "p1": "updated", + }, preparedQuery.Parameters) +} From f6488ebf559c539eb7a7efb1b32c29c4eb3854fc Mon Sep 17 00:00:00 2001 From: John Hopper Date: Wed, 6 May 2026 23:38:32 -0700 Subject: [PATCH 06/55] test(query/v2): cover backend preparation paths --- query/v2/backend_test.go | 130 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 130 insertions(+) create mode 100644 query/v2/backend_test.go diff --git a/query/v2/backend_test.go b/query/v2/backend_test.go new file mode 100644 index 00000000..48b52ea3 --- /dev/null +++ b/query/v2/backend_test.go @@ -0,0 +1,130 @@ +package v2_test + +import ( + "context" + "testing" + + 
"github.com/specterops/dawgs/cypher/models/pgsql/translate" + "github.com/specterops/dawgs/drivers/pg/pgutil" + "github.com/specterops/dawgs/graph" + "github.com/specterops/dawgs/query/neo4j" + v2 "github.com/specterops/dawgs/query/v2" + "github.com/stretchr/testify/require" +) + +func testKindMapper(kinds ...graph.Kind) *pgutil.InMemoryKindMapper { + mapper := pgutil.NewInMemoryKindMapper() + + for _, kind := range kinds { + mapper.Put(kind) + } + + return mapper +} + +func TestBackendParityNeo4jPrepare(t *testing.T) { + cases := map[string]v2.QueryBuilder{ + "node read": v2.New().Where( + v2.Node().Kinds().Has(graph.StringKind("User")), + v2.Node().Property("name").Contains("admin"), + ).Return( + v2.Node(), + ).OrderBy( + v2.Node().Property("name"), + ), + "relationship read": v2.New().Where( + v2.Relationship().Kind().Is(graph.StringKind("MemberOf")), + v2.Start().ID().Equals(1), + ).Return( + v2.Start().ID(), + v2.Relationship().ID(), + v2.End().ID(), + ), + "create node": v2.New().Create( + v2.NodePattern(graph.Kinds{graph.StringKind("User")}, v2.NamedParameter("props", map[string]any{"name": "u"})), + ).Return( + v2.Node().ID(), + ), + "update node": v2.New().Where( + v2.Node().ID().Equals(1), + ).Update( + v2.SetProperty(v2.Node().Property("name"), "updated"), + ), + "delete relationship": v2.New().Where( + v2.Relationship().ID().Equals(1), + ).Delete( + v2.Relationship(), + ), + } + + for name, builder := range cases { + t.Run(name, func(t *testing.T) { + preparedQuery, err := builder.Build() + require.NoError(t, err) + + queryBuilder := neo4j.NewQueryBuilder(preparedQuery.Query) + require.NoError(t, queryBuilder.Prepare()) + + rendered, err := queryBuilder.Render() + require.NoError(t, err) + require.NotEmpty(t, rendered) + require.NotEmpty(t, queryBuilder.Parameters) + }) + } +} + +func TestBackendParityPGTranslate(t *testing.T) { + userKind := graph.StringKind("User") + edgeKind := graph.StringKind("MemberOf") + mapper := testKindMapper(userKind, 
edgeKind) + + cases := map[string]v2.QueryBuilder{ + "node read": v2.New().Where( + v2.Node().Kinds().Has(userKind), + v2.Node().Property("name").Contains("admin"), + ).Return( + v2.Node().ID(), + v2.Node().Kinds(), + ), + "relationship read": v2.New().Where( + v2.Relationship().Kind().Is(edgeKind), + v2.Start().ID().Equals(1), + ).Return( + v2.Start().ID(), + v2.Relationship().ID(), + v2.End().ID(), + ), + "create relationship": v2.New().Where( + v2.Start().ID().Equals(1), + v2.End().ID().Equals(2), + ).Create( + v2.RelationshipPattern(edgeKind, nil, graph.DirectionOutbound), + ).Return( + v2.Relationship().ID(), + ), + "update node": v2.New().Where( + v2.Node().ID().Equals(1), + ).Update( + v2.SetProperty(v2.Node().Property("name"), "updated"), + ), + "delete relationship": v2.New().Where( + v2.Relationship().ID().Equals(1), + ).Delete( + v2.Relationship(), + ), + } + + for name, builder := range cases { + t.Run(name, func(t *testing.T) { + preparedQuery, err := builder.Build() + require.NoError(t, err) + + translation, err := translate.Translate(context.Background(), preparedQuery.Query, mapper, preparedQuery.Parameters) + require.NoError(t, err) + + sql, err := translate.Translated(translation) + require.NoError(t, err) + require.NotEmpty(t, sql) + }) + } +} From b014b7980d0591cfb5c0bb7e6c85eba8cf3673d1 Mon Sep 17 00:00:00 2001 From: John Hopper Date: Wed, 6 May 2026 23:39:06 -0700 Subject: [PATCH 07/55] test(pgsql): exercise translator with v2 builder --- cypher/models/pgsql/test/query_test.go | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/cypher/models/pgsql/test/query_test.go b/cypher/models/pgsql/test/query_test.go index 38ec85ab..abb2d388 100644 --- a/cypher/models/pgsql/test/query_test.go +++ b/cypher/models/pgsql/test/query_test.go @@ -5,12 +5,11 @@ import ( "slices" "testing" - "github.com/specterops/dawgs/cypher/models/cypher" "github.com/specterops/dawgs/cypher/models/pgsql" 
"github.com/specterops/dawgs/cypher/models/pgsql/translate" "github.com/specterops/dawgs/cypher/models/walk" "github.com/specterops/dawgs/graph" - "github.com/specterops/dawgs/query" + v2 "github.com/specterops/dawgs/query/v2" ) var ( @@ -24,19 +23,18 @@ var ( func TestQuery_KindGeneratesInclusiveKindMatcher(t *testing.T) { mapper := newKindMapper() - queries := []*cypher.Where{ - query.Where(query.KindIn(query.Node(), NodeKind1)), - query.Where(query.Kind(query.Node(), NodeKind2)), + queries := []v2.QueryBuilder{ + v2.New().Where(v2.KindIn(v2.Node(), NodeKind1)).Return(v2.Node()), + v2.New().Where(v2.Kind(v2.Node(), NodeKind2)).Return(v2.Node()), } - for _, nodeQuery := range queries { - builder := query.NewBuilderWithCriteria(nodeQuery) - builtQuery, err := builder.Build(false) + for _, queryBuilder := range queries { + builtQuery, err := queryBuilder.Build() if err != nil { t.Errorf("could not build query: %v", err) } - translatedQuery, err := translate.Translate(context.Background(), builtQuery, mapper, nil, translate.DefaultGraphID) + translatedQuery, err := translate.Translate(context.Background(), builtQuery.Query, mapper, builtQuery.Parameters, translate.DefaultGraphID) if err != nil { t.Errorf("could not translate query: %#v: %v", builtQuery, err) } @@ -47,7 +45,7 @@ func TestQuery_KindGeneratesInclusiveKindMatcher(t *testing.T) { switch leftTyped := typedNode.LOperand.(type) { case pgsql.CompoundIdentifier: if slices.Equal(leftTyped, pgsql.AsCompoundIdentifier("n0", "kind_ids")) && typedNode.Operator != pgsql.OperatorPGArrayOverlap { - t.Errorf("query did not generate an array overlap operator (&&): %#v", nodeQuery) + t.Errorf("query did not generate an array overlap operator (&&): %#v", builtQuery) } } } From b7cc7788eeca06cc63fb239604d205576fb25c05 Mon Sep 17 00:00:00 2001 From: John Hopper Date: Thu, 7 May 2026 00:44:37 -0700 Subject: [PATCH 08/55] fix(query/v2): surface helper validation errors --- query/v2/compat.go | 106 ++++++++++------------------ 
query/v2/query.go | 20 ++++-- query/v2/query_test.go | 35 ++++++++++ query/v2/util.go | 153 +++++++++++++++++++++++++++++++++++++++++ 4 files changed, 240 insertions(+), 74 deletions(-) diff --git a/query/v2/compat.go b/query/v2/compat.go index bee8cf19..521820ba 100644 --- a/query/v2/compat.go +++ b/query/v2/compat.go @@ -13,9 +13,7 @@ func Variable(name string) *cypher.Variable { } func Identity(reference any) *cypher.FunctionInvocation { - expression, _ := projectionExpression(reference) - - return cypher.NewSimpleFunctionInvocation(cypher.IdentityFunction, expression) + return cypher.NewSimpleFunctionInvocation(cypher.IdentityFunction, expressionOrError(reference)) } func NodeID() *cypher.FunctionInvocation { @@ -35,27 +33,23 @@ func EndID() *cypher.FunctionInvocation { } func Count(reference any) *cypher.FunctionInvocation { - expression, _ := projectionExpression(reference) - return cypher.NewSimpleFunctionInvocation(cypher.CountFunction, expression) + return cypher.NewSimpleFunctionInvocation(cypher.CountFunction, expressionOrError(reference)) } func CountDistinct(reference any) *cypher.FunctionInvocation { - expression, _ := projectionExpression(reference) - return &cypher.FunctionInvocation{ Name: cypher.CountFunction, Distinct: true, - Arguments: []cypher.Expression{expression}, + Arguments: []cypher.Expression{expressionOrError(reference)}, } } func Size(expression any) *cypher.FunctionInvocation { - expr, _ := projectionExpression(expression) - return cypher.NewSimpleFunctionInvocation(cypher.ListSizeFunction, expr) + return cypher.NewSimpleFunctionInvocation(cypher.ListSizeFunction, expressionOrError(expression)) } func KindsOf(reference any) *cypher.FunctionInvocation { - expression, _ := projectionExpression(reference) + expression := expressionOrError(reference) switch typedExpression := expression.(type) { case *cypher.Variable: @@ -72,10 +66,8 @@ func KindsOf(reference any) *cypher.FunctionInvocation { } func Kind(reference any, kinds 
...graph.Kind) *cypher.KindMatcher { - expression, _ := projectionExpression(reference) - return &cypher.KindMatcher{ - Reference: expression, + Reference: expressionOrError(reference), Kinds: kinds, } } @@ -89,8 +81,7 @@ func AddKind(reference any, kind graph.Kind) *cypher.SetItem { } func AddKinds(reference any, kinds graph.Kinds) *cypher.SetItem { - expression, _ := projectionExpression(reference) - return cypher.NewSetItem(expression, cypher.OperatorLabelAssignment, kinds) + return cypher.NewSetItem(expressionOrError(reference), cypher.OperatorLabelAssignment, kinds) } func DeleteKind(reference any, kind graph.Kind) *cypher.RemoveItem { @@ -98,23 +89,19 @@ func DeleteKind(reference any, kind graph.Kind) *cypher.RemoveItem { } func DeleteKinds(reference any, kinds graph.Kinds) *cypher.RemoveItem { - expression, _ := projectionExpression(reference) - return cypher.RemoveKindsByMatcher(cypher.NewKindMatcher(expression, kinds, false)) + return cypher.RemoveKindsByMatcher(cypher.NewKindMatcher(expressionOrError(reference), kinds, false)) } func SetProperty(reference any, value any) *cypher.SetItem { - expression, _ := projectionExpression(reference) - return cypher.NewSetItem(expression, cypher.OperatorAssignment, valueExpression(value)) + return cypher.NewSetItem(expressionOrError(reference), cypher.OperatorAssignment, valueExpression(value)) } func SetProperties(reference any, properties map[string]any) *cypher.Set { set := &cypher.Set{} - expression, _ := projectionExpression(reference) - variable, _ := expression.(*cypher.Variable) for key, value := range properties { set.Items = append(set.Items, cypher.NewSetItem( - cypher.NewPropertyLookup(variable.Symbol, key), + propertyLookupOrError(reference, key), cypher.OperatorAssignment, valueExpression(value), )) @@ -124,17 +111,14 @@ func SetProperties(reference any, properties map[string]any) *cypher.Set { } func DeleteProperty(reference any) *cypher.RemoveItem { - expression, _ := projectionExpression(reference) - 
return cypher.RemoveProperty(expression) + return cypher.RemoveProperty(expressionOrError(reference)) } func DeleteProperties(reference any, propertyNames ...string) *cypher.Remove { remove := &cypher.Remove{} - expression, _ := projectionExpression(reference) - variable, _ := expression.(*cypher.Variable) for _, propertyName := range propertyNames { - remove.Items = append(remove.Items, cypher.RemoveProperty(cypher.NewPropertyLookup(variable.Symbol, propertyName))) + remove.Items = append(remove.Items, cypher.RemoveProperty(propertyLookupOrError(reference, propertyName))) } return remove @@ -174,13 +158,11 @@ func RelationshipPattern(kind graph.Kind, properties cypher.Expression, directio } func Equals(reference any, value any) cypher.Expression { - expression, _ := projectionExpression(reference) - return cypher.NewComparison(expression, cypher.OperatorEquals, valueExpression(value)) + return cypher.NewComparison(expressionOrError(reference), cypher.OperatorEquals, valueExpression(value)) } func GreaterThan(reference any, value any) cypher.Expression { - expression, _ := projectionExpression(reference) - return cypher.NewComparison(expression, cypher.OperatorGreaterThan, valueExpression(value)) + return cypher.NewComparison(expressionOrError(reference), cypher.OperatorGreaterThan, valueExpression(value)) } func After(reference any, value any) cypher.Expression { @@ -188,8 +170,7 @@ func After(reference any, value any) cypher.Expression { } func GreaterThanOrEqualTo(reference any, value any) cypher.Expression { - expression, _ := projectionExpression(reference) - return cypher.NewComparison(expression, cypher.OperatorGreaterThanOrEqualTo, valueExpression(value)) + return cypher.NewComparison(expressionOrError(reference), cypher.OperatorGreaterThanOrEqualTo, valueExpression(value)) } func GreaterThanOrEquals(reference any, value any) cypher.Expression { @@ -197,8 +178,7 @@ func GreaterThanOrEquals(reference any, value any) cypher.Expression { } func 
LessThan(reference any, value any) cypher.Expression { - expression, _ := projectionExpression(reference) - return cypher.NewComparison(expression, cypher.OperatorLessThan, valueExpression(value)) + return cypher.NewComparison(expressionOrError(reference), cypher.OperatorLessThan, valueExpression(value)) } func LessThanGraphQuery(reference any, other any) cypher.Expression { @@ -214,8 +194,7 @@ func BeforeGraphQuery(reference any, other any) cypher.Expression { } func LessThanOrEqualTo(reference any, value any) cypher.Expression { - expression, _ := projectionExpression(reference) - return cypher.NewComparison(expression, cypher.OperatorLessThanOrEqualTo, valueExpression(value)) + return cypher.NewComparison(expressionOrError(reference), cypher.OperatorLessThanOrEqualTo, valueExpression(value)) } func LessThanOrEquals(reference any, value any) cypher.Expression { @@ -223,17 +202,15 @@ func LessThanOrEquals(reference any, value any) cypher.Expression { } func In(reference any, value any) cypher.Expression { - expression, _ := projectionExpression(reference) - return cypher.NewComparison(expression, cypher.OperatorIn, valueExpression(value)) + return cypher.NewComparison(expressionOrError(reference), cypher.OperatorIn, valueExpression(value)) } func InInverted(reference any, value any) cypher.Expression { - expression, _ := projectionExpression(reference) - return cypher.NewComparison(valueExpression(value), cypher.OperatorIn, expression) + return cypher.NewComparison(valueExpression(value), cypher.OperatorIn, expressionOrError(reference)) } func InIDs(reference any, ids ...graph.ID) cypher.Expression { - expression, _ := projectionExpression(reference) + expression := expressionOrError(reference) if variable, typeOK := expression.(*cypher.Variable); typeOK { expression = Identity(variable) @@ -243,45 +220,36 @@ func InIDs(reference any, ids ...graph.ID) cypher.Expression { } func StringContains(reference any, value string) cypher.Expression { - expression, _ := 
projectionExpression(reference) - return cypher.NewComparison(expression, cypher.OperatorContains, Parameter(value)) + return cypher.NewComparison(expressionOrError(reference), cypher.OperatorContains, Parameter(value)) } func StringStartsWith(reference any, value string) cypher.Expression { - expression, _ := projectionExpression(reference) - return cypher.NewComparison(expression, cypher.OperatorStartsWith, Parameter(value)) + return cypher.NewComparison(expressionOrError(reference), cypher.OperatorStartsWith, Parameter(value)) } func StringEndsWith(reference any, value string) cypher.Expression { - expression, _ := projectionExpression(reference) - return cypher.NewComparison(expression, cypher.OperatorEndsWith, Parameter(value)) + return cypher.NewComparison(expressionOrError(reference), cypher.OperatorEndsWith, Parameter(value)) } func CaseInsensitiveStringContains(reference any, value string) cypher.Expression { - expression, _ := projectionExpression(reference) - return cypher.NewComparison( - cypher.NewSimpleFunctionInvocation("toLower", expression), + cypher.NewSimpleFunctionInvocation("toLower", expressionOrError(reference)), cypher.OperatorContains, Parameter(strings.ToLower(value)), ) } func CaseInsensitiveStringStartsWith(reference any, value string) cypher.Expression { - expression, _ := projectionExpression(reference) - return cypher.NewComparison( - cypher.NewSimpleFunctionInvocation("toLower", expression), + cypher.NewSimpleFunctionInvocation("toLower", expressionOrError(reference)), cypher.OperatorStartsWith, Parameter(strings.ToLower(value)), ) } func CaseInsensitiveStringEndsWith(reference any, value string) cypher.Expression { - expression, _ := projectionExpression(reference) - return cypher.NewComparison( - cypher.NewSimpleFunctionInvocation("toLower", expression), + cypher.NewSimpleFunctionInvocation("toLower", expressionOrError(reference)), cypher.OperatorEndsWith, Parameter(strings.ToLower(value)), ) @@ -292,23 +260,25 @@ func 
Exists(reference any) cypher.Expression { } func IsNull(reference any) cypher.Expression { - expression, _ := projectionExpression(reference) - return cypher.NewComparison(expression, cypher.OperatorIs, Literal(nil)) + return cypher.NewComparison(expressionOrError(reference), cypher.OperatorIs, Literal(nil)) } func IsNotNull(reference any) cypher.Expression { - expression, _ := projectionExpression(reference) - return cypher.NewComparison(expression, cypher.OperatorIsNot, Literal(nil)) + return cypher.NewComparison(expressionOrError(reference), cypher.OperatorIsNot, Literal(nil)) } func HasRelationships(reference any) *cypher.PatternPredicate { - expression, _ := projectionExpression(reference) - variable, _ := expression.(*cypher.Variable) - patternPredicate := cypher.NewPatternPredicate() - patternPredicate.AddElement(&cypher.NodePattern{ - Variable: cypher.NewVariableWithSymbol(variable.Symbol), - }) + + if variable, err := variableReference(reference); err != nil { + patternPredicate.AddElement(&cypher.NodePattern{ + Properties: invalidExpression(err), + }) + } else { + patternPredicate.AddElement(&cypher.NodePattern{ + Variable: cypher.NewVariableWithSymbol(variable.Symbol), + }) + } patternPredicate.AddElement(&cypher.RelationshipPattern{ Direction: graph.DirectionBoth, diff --git a/query/v2/query.go b/query/v2/query.go index 923f0737..b51358fb 100644 --- a/query/v2/query.go +++ b/query/v2/query.go @@ -157,19 +157,15 @@ func Desc(expression any) *cypher.SortItem { } func Order(expression any, direction SortDirection) *cypher.SortItem { - sortExpression, _ := projectionExpression(expression) - return &cypher.SortItem{ Ascending: direction != SortDescending, - Expression: sortExpression, + Expression: expressionOrError(expression), } } func As(expression any, alias string) *cypher.ProjectionItem { - projectionExpression, _ := projectionExpression(expression) - return &cypher.ProjectionItem{ - Expression: projectionExpression, + Expression: 
expressionOrError(expression), Alias: cypher.NewVariableWithSymbol(alias), } } @@ -792,6 +788,10 @@ func (s *builder) Build() (*PreparedQuery, error) { return nil, fmt.Errorf("query has no action specified") } + if err := collectModelErrorsFromKnownValues(s.constraints, s.setItems, s.removeItems, s.deleteItems, s.projections, s.sortItems); err != nil { + return nil, err + } + var ( regularQuery, singlePartQuery = cypher.NewRegularQueryWithSingleQuery() match = &cypher.Match{} @@ -819,6 +819,10 @@ func (s *builder) Build() (*PreparedQuery, error) { ) for _, nextConstraint := range s.constraints { + if err := collectModelErrorsFromKnownValues(nextConstraint); err != nil { + return nil, err + } + switch typedNextConstraint := nextConstraint.(type) { case *cypher.KindMatcher: if identifier, typeOK := typedNextConstraint.Reference.(*cypher.Variable); !typeOK { @@ -881,6 +885,10 @@ func (s *builder) Build() (*PreparedQuery, error) { singlePartQuery.ReadingClauses = append(singlePartQuery.ReadingClauses, newReadingClause) } + if err := collectModelErrors(regularQuery); err != nil { + return nil, err + } + if parameters, err := materializeParameters(regularQuery); err != nil { return nil, err } else { diff --git a/query/v2/query_test.go b/query/v2/query_test.go index f764f10b..587dfd90 100644 --- a/query/v2/query_test.go +++ b/query/v2/query_test.go @@ -112,6 +112,41 @@ func TestUnsupportedOrderByTypeReturnsError(t *testing.T) { require.ErrorContains(t, err, "unsupported expression type: int") } +func TestInvalidHelperInputsReturnBuildErrors(t *testing.T) { + cases := map[string]struct { + builder v2.QueryBuilder + err string + }{ + "aliased projection": { + builder: v2.New().Return(v2.As(123, "bad")), + err: "unsupported expression type: int", + }, + "sort item": { + builder: v2.New().Return(v2.Node()).OrderBy(v2.Desc(123)), + err: "unsupported expression type: int", + }, + "set properties": { + builder: v2.New().Update(v2.SetProperties(123, map[string]any{"name": 
"bad"})), + err: "unsupported expression type: int", + }, + "delete properties": { + builder: v2.New().Update(v2.DeleteProperties(123, "name")), + err: "unsupported expression type: int", + }, + "pattern predicate": { + builder: v2.New().Where(v2.HasRelationships(123)).Return(v2.Node()), + err: "unsupported expression type: int", + }, + } + + for name, testCase := range cases { + t.Run(name, func(t *testing.T) { + _, err := testCase.builder.Build() + require.ErrorContains(t, err, testCase.err) + }) + } +} + func TestCompatibilityHelpers(t *testing.T) { preparedQuery, err := v2.New().Where( v2.And( diff --git a/query/v2/util.go b/query/v2/util.go index d9e8fe16..72e7ed10 100644 --- a/query/v2/util.go +++ b/query/v2/util.go @@ -130,6 +130,39 @@ func kindProjectionExpression(identifier *cypher.Variable) (cypher.Expression, e } } +func invalidExpression(err error) *cypher.FunctionInvocation { + return cypher.WithErrors(cypher.NewSimpleFunctionInvocation("__invalid_expression__"), err) +} + +func expressionOrError(value any) cypher.Expression { + if expression, err := projectionExpression(value); err != nil { + return invalidExpression(err) + } else { + return expression + } +} + +func variableReference(value any) (*cypher.Variable, error) { + expression, err := projectionExpression(value) + if err != nil { + return nil, err + } + + if variable, typeOK := expression.(*cypher.Variable); !typeOK { + return nil, fmt.Errorf("expected variable reference, got %T", expression) + } else { + return variable, nil + } +} + +func propertyLookupOrError(reference any, propertyName string) cypher.Expression { + if variable, err := variableReference(reference); err != nil { + return invalidExpression(err) + } else { + return cypher.NewPropertyLookup(variable.Symbol, propertyName) + } +} + func projectionExpression(value any) (cypher.Expression, error) { switch typedValue := value.(type) { case QualifiedExpression: @@ -142,6 +175,10 @@ func projectionExpression(value any) 
(cypher.Expression, error) { return kindProjectionExpression(typedValue.identifier) case *cypher.ProjectionItem: + if typedValue.Expression == nil { + return nil, fmt.Errorf("projection item has nil expression") + } + return typedValue.Expression, nil case *cypher.Parameter: @@ -208,6 +245,14 @@ func projectionExpression(value any) (cypher.Expression, error) { func projectionItemFromValue(value any) (*cypher.ProjectionItem, error) { if projectionItem, typeOK := value.(*cypher.ProjectionItem); typeOK { + if projectionItem.Expression == nil { + return nil, fmt.Errorf("projection item has nil expression") + } + + if err := collectModelErrors(projectionItem); err != nil { + return nil, err + } + return projectionItem, nil } @@ -220,6 +265,14 @@ func projectionItemFromValue(value any) (*cypher.ProjectionItem, error) { func sortItemFromValue(value any) (*cypher.SortItem, error) { if sortItem, typeOK := value.(*cypher.SortItem); typeOK { + if sortItem.Expression == nil { + return nil, fmt.Errorf("sort item has nil expression") + } + + if err := collectModelErrors(sortItem); err != nil { + return nil, err + } + return sortItem, nil } @@ -474,6 +527,106 @@ func extractCypherIdentifiers(expression cypher.Expression) (*identifierSet, err return identifierExtractorVisitor.seen, err } +func collectModelErrors(node cypher.SyntaxNode) error { + var modelErrors []error + + if err := walk.Cypher(node, walk.NewSimpleVisitor[cypher.SyntaxNode](func(node cypher.SyntaxNode, _ walk.VisitorHandler) { + if errorNode, typeOK := node.(cypher.Fallible); typeOK { + modelErrors = append(modelErrors, errorNode.Errors()...) + } + })); err != nil { + modelErrors = append(modelErrors, err) + } + + return errors.Join(modelErrors...) 
+} + +func collectModelErrorsFromKnownValues(values ...any) error { + var modelErrors []error + + for _, value := range values { + switch typedValue := value.(type) { + case nil: + continue + + case []cypher.SyntaxNode: + if err := collectModelErrorsFromKnownValues(anySlice(typedValue)...); err != nil { + modelErrors = append(modelErrors, err) + } + + case []cypher.Expression: + if err := collectModelErrorsFromKnownValues(anySlice(typedValue)...); err != nil { + modelErrors = append(modelErrors, err) + } + + case []*cypher.SetItem: + if err := collectModelErrorsFromKnownValues(anySlice(typedValue)...); err != nil { + modelErrors = append(modelErrors, err) + } + + case []*cypher.RemoveItem: + if err := collectModelErrorsFromKnownValues(anySlice(typedValue)...); err != nil { + modelErrors = append(modelErrors, err) + } + + case []*cypher.ProjectionItem: + if err := collectModelErrorsFromKnownValues(anySlice(typedValue)...); err != nil { + modelErrors = append(modelErrors, err) + } + + case []*cypher.SortItem: + if err := collectModelErrorsFromKnownValues(anySlice(typedValue)...); err != nil { + modelErrors = append(modelErrors, err) + } + + case *cypher.ArithmeticExpression, + *cypher.Comparison, + *cypher.Conjunction, + *cypher.Create, + *cypher.Delete, + *cypher.Disjunction, + *cypher.ExclusiveDisjunction, + *cypher.FilterExpression, + *cypher.FunctionInvocation, + *cypher.IDInCollection, + *cypher.KindMatcher, + *cypher.ListLiteral, + *cypher.Negation, + *cypher.NodePattern, + *cypher.Order, + *cypher.Parenthetical, + *cypher.PatternPredicate, + *cypher.ProjectionItem, + *cypher.PropertyLookup, + *cypher.RelationshipPattern, + *cypher.Remove, + *cypher.RemoveItem, + *cypher.Return, + *cypher.Set, + *cypher.SetItem, + *cypher.SortItem, + *cypher.UnaryAddOrSubtractExpression, + *cypher.UpdatingClause, + *cypher.Variable: + if err := collectModelErrors(typedValue); err != nil { + modelErrors = append(modelErrors, err) + } + } + } + + return 
errors.Join(modelErrors...) +} + +func anySlice[T any](values []T) []any { + items := make([]any, len(values)) + + for idx, value := range values { + items[idx] = value + } + + return items +} + type parameterMaterializer struct { walk.Visitor[cypher.SyntaxNode] From e61f4a826efeb129a0921c36df7ac18187d34c1d Mon Sep 17 00:00:00 2001 From: John Hopper Date: Thu, 7 May 2026 00:45:41 -0700 Subject: [PATCH 09/55] fix(query/v2): preserve updating clause order --- query/v2/query.go | 141 ++++++++++++++++++++++++++++++++++------- query/v2/query_test.go | 33 ++++++++++ 2 files changed, 151 insertions(+), 23 deletions(-) diff --git a/query/v2/query.go b/query/v2/query.go index b51358fb..23e8a1d4 100644 --- a/query/v2/query.go +++ b/query/v2/query.go @@ -489,12 +489,31 @@ type QueryBuilder interface { Build() (*PreparedQuery, error) } +type updatingClauseKind int + +const ( + updatingClauseSet updatingClauseKind = iota + updatingClauseRemove + updatingClauseDelete + updatingClauseCreate +) + +type pendingUpdatingClause struct { + kind updatingClauseKind + creates []any + setItems []*cypher.SetItem + removeItems []*cypher.RemoveItem + deleteItems []cypher.Expression + detach bool +} + type builder struct { errors []error constraints []cypher.SyntaxNode sortItems []any projections []any distinct bool + updatingClauses []pendingUpdatingClause creates []any setItems []*cypher.SetItem removeItems []*cypher.RemoveItem @@ -546,8 +565,66 @@ func (s *builder) ReturnDistinct(projections ...any) QueryBuilder { return s } +func (s *builder) appendSetItems(items ...*cypher.SetItem) { + if len(items) == 0 { + return + } + + lastClauseIdx := len(s.updatingClauses) - 1 + if lastClauseIdx >= 0 && s.updatingClauses[lastClauseIdx].kind == updatingClauseSet { + s.updatingClauses[lastClauseIdx].setItems = append(s.updatingClauses[lastClauseIdx].setItems, items...) 
+ } else { + s.updatingClauses = append(s.updatingClauses, pendingUpdatingClause{ + kind: updatingClauseSet, + setItems: items, + }) + } +} + +func (s *builder) appendRemoveItems(items ...*cypher.RemoveItem) { + if len(items) == 0 { + return + } + + lastClauseIdx := len(s.updatingClauses) - 1 + if lastClauseIdx >= 0 && s.updatingClauses[lastClauseIdx].kind == updatingClauseRemove { + s.updatingClauses[lastClauseIdx].removeItems = append(s.updatingClauses[lastClauseIdx].removeItems, items...) + } else { + s.updatingClauses = append(s.updatingClauses, pendingUpdatingClause{ + kind: updatingClauseRemove, + removeItems: items, + }) + } +} + +func (s *builder) appendDeleteItems(detach bool, items ...cypher.Expression) { + if len(items) == 0 { + return + } + + lastClauseIdx := len(s.updatingClauses) - 1 + if lastClauseIdx >= 0 && s.updatingClauses[lastClauseIdx].kind == updatingClauseDelete { + s.updatingClauses[lastClauseIdx].detach = s.updatingClauses[lastClauseIdx].detach || detach + s.updatingClauses[lastClauseIdx].deleteItems = append(s.updatingClauses[lastClauseIdx].deleteItems, items...) + } else { + s.updatingClauses = append(s.updatingClauses, pendingUpdatingClause{ + kind: updatingClauseDelete, + deleteItems: items, + detach: detach, + }) + } +} + func (s *builder) Create(creationClauses ...any) QueryBuilder { s.creates = append(s.creates, creationClauses...) + + if len(creationClauses) > 0 { + s.updatingClauses = append(s.updatingClauses, pendingUpdatingClause{ + kind: updatingClauseCreate, + creates: creationClauses, + }) + } + return s } @@ -556,15 +633,19 @@ func (s *builder) Update(updates ...any) QueryBuilder { switch typedNextUpdate := nextUpdate.(type) { case *cypher.Set: s.setItems = append(s.setItems, typedNextUpdate.Items...) + s.appendSetItems(typedNextUpdate.Items...) 
case *cypher.SetItem: s.setItems = append(s.setItems, typedNextUpdate) + s.appendSetItems(typedNextUpdate) case *cypher.Remove: s.removeItems = append(s.removeItems, typedNextUpdate.Items...) + s.appendRemoveItems(typedNextUpdate.Items...) case *cypher.RemoveItem: s.removeItems = append(s.removeItems, typedNextUpdate) + s.appendRemoveItems(typedNextUpdate) default: s.trackError(fmt.Errorf("unknown update type: %T", nextUpdate)) @@ -575,6 +656,9 @@ func (s *builder) Update(updates ...any) QueryBuilder { } func (s *builder) Delete(deleteItems ...any) QueryBuilder { + var pendingDeleteItems []cypher.Expression + pendingDetachDelete := false + for _, nextDelete := range deleteItems { switch typedNextUpdate := nextDelete.(type) { case QualifiedExpression: @@ -582,23 +666,28 @@ func (s *builder) Delete(deleteItems ...any) QueryBuilder { if isDetachDeleteQualifier(qualifier) { s.detachDelete = true + pendingDetachDelete = true } s.deleteItems = append(s.deleteItems, qualifier) + pendingDeleteItems = append(pendingDeleteItems, qualifier) case *cypher.Variable: switch typedNextUpdate.Symbol { case Identifiers.node, Identifiers.start, Identifiers.end: s.detachDelete = true + pendingDetachDelete = true } s.deleteItems = append(s.deleteItems, typedNextUpdate) + pendingDeleteItems = append(pendingDeleteItems, typedNextUpdate) default: s.trackError(fmt.Errorf("unknown delete type: %T", nextDelete)) } } + s.appendDeleteItems(pendingDetachDelete, pendingDeleteItems...) 
return s } @@ -611,8 +700,8 @@ func (s *builder) Where(constraints ...cypher.SyntaxNode) QueryBuilder { return s } -func (s *builder) buildCreates(singlePartQuery *cypher.SinglePartQuery) error { - if len(s.creates) == 0 { +func buildCreates(singlePartQuery *cypher.SinglePartQuery, creates []any) error { + if len(creates) == 0 { return nil } @@ -624,7 +713,7 @@ func (s *builder) buildCreates(singlePartQuery *cypher.SinglePartQuery) error { } ) - for _, nextCreate := range s.creates { + for _, nextCreate := range creates { switch typedNextCreate := nextCreate.(type) { case QualifiedExpression: switch typedExpression := typedNextCreate.qualifier().(type) { @@ -664,28 +753,34 @@ func (s *builder) buildCreates(singlePartQuery *cypher.SinglePartQuery) error { } func (s *builder) buildUpdatingClauses(singlePartQuery *cypher.SinglePartQuery) error { - if len(s.setItems) > 0 { - singlePartQuery.UpdatingClauses = append(singlePartQuery.UpdatingClauses, cypher.NewUpdatingClause( - cypher.NewSet(s.setItems), - )) - } - - if len(s.removeItems) > 0 { - singlePartQuery.UpdatingClauses = append(singlePartQuery.UpdatingClauses, cypher.NewUpdatingClause( - cypher.NewRemove(s.removeItems), - )) - } - - if len(s.deleteItems) > 0 { - singlePartQuery.UpdatingClauses = append(singlePartQuery.UpdatingClauses, cypher.NewUpdatingClause( - cypher.NewDelete( - s.detachDelete, - s.deleteItems, - ), - )) + for _, updatingClause := range s.updatingClauses { + switch updatingClause.kind { + case updatingClauseSet: + singlePartQuery.UpdatingClauses = append(singlePartQuery.UpdatingClauses, cypher.NewUpdatingClause( + cypher.NewSet(updatingClause.setItems), + )) + + case updatingClauseRemove: + singlePartQuery.UpdatingClauses = append(singlePartQuery.UpdatingClauses, cypher.NewUpdatingClause( + cypher.NewRemove(updatingClause.removeItems), + )) + + case updatingClauseDelete: + singlePartQuery.UpdatingClauses = append(singlePartQuery.UpdatingClauses, cypher.NewUpdatingClause( + cypher.NewDelete( + 
updatingClause.detach, + updatingClause.deleteItems, + ), + )) + + case updatingClauseCreate: + if err := buildCreates(singlePartQuery, updatingClause.creates); err != nil { + return err + } + } } - return s.buildCreates(singlePartQuery) + return nil } func (s *builder) buildProjectionOrder() (*cypher.Order, error) { diff --git a/query/v2/query_test.go b/query/v2/query_test.go index 587dfd90..3745ec90 100644 --- a/query/v2/query_test.go +++ b/query/v2/query_test.go @@ -95,6 +95,39 @@ func TestCreateNodeReturnDoesNotCreateMatch(t *testing.T) { }, preparedQuery.Parameters) } +func TestUpdatingClausesPreserveFluentOrder(t *testing.T) { + preparedQuery, err := v2.New().Create( + v2.NodePattern(graph.Kinds{graph.StringKind("User")}, nil), + ).Update( + v2.SetProperty(v2.Node().Property("name"), "created"), + ).Return( + v2.Node().Property("name"), + ).Build() + require.NoError(t, err) + + require.Equal(t, "create (n:User) set n.name = $p0 return n.name", renderPrepared(t, preparedQuery)) + require.Equal(t, map[string]any{ + "p0": "created", + }, preparedQuery.Parameters) + + preparedQuery, err = v2.New().Where( + v2.Node().ID().Equals(1), + ).Update( + v2.DeleteProperties(v2.Node(), "old"), + ).Update( + v2.SetProperty(v2.Node().Property("new"), "value"), + ).Return( + v2.Node(), + ).Build() + require.NoError(t, err) + + require.Equal(t, "match (n) where id(n) = $p0 remove n.old set n.new = $p1 return n", renderPrepared(t, preparedQuery)) + require.Equal(t, map[string]any{ + "p0": 1, + "p1": "value", + }, preparedQuery.Parameters) +} + func TestProjectionAndOrderHelpers(t *testing.T) { preparedQuery, err := v2.New().ReturnDistinct( v2.As(v2.Node().ID(), "node_id"), From a73d49a1524333a8e664012471c3b36cd11e84a3 Mon Sep 17 00:00:00 2001 From: John Hopper Date: Thu, 7 May 2026 00:48:15 -0700 Subject: [PATCH 10/55] feat(query/v2): support scoped pattern aliases --- query/v2/query.go | 172 ++++++++++++++++++++++++++++++----------- query/v2/query_test.go | 23 ++++++ 
query/v2/util.go | 64 +++++++-------- 3 files changed, 180 insertions(+), 79 deletions(-) diff --git a/query/v2/query.go b/query/v2/query.go index 23e8a1d4..9046fa47 100644 --- a/query/v2/query.go +++ b/query/v2/query.go @@ -44,6 +44,67 @@ var Identifiers = runtimeIdentifiers{ end: "e", } +type Scope struct { + identifiers runtimeIdentifiers +} + +func DefaultScope() Scope { + return Scope{ + identifiers: Identifiers, + } +} + +func NewScope(path, node, start, relationship, end string) Scope { + return Scope{ + identifiers: runtimeIdentifiers{ + path: path, + node: node, + start: start, + relationship: relationship, + end: end, + }, + } +} + +func (s Scope) New() QueryBuilder { + return newBuilder(s.identifiers) +} + +func (s Scope) Node() NodeContinuation { + return &entity[NodeContinuation]{ + identifier: s.identifiers.Node(), + role: Identifiers.node, + } +} + +func (s Scope) Path() PathContinuation { + return &entity[PathContinuation]{ + identifier: s.identifiers.Path(), + role: Identifiers.path, + } +} + +func (s Scope) Start() NodeContinuation { + return &entity[NodeContinuation]{ + identifier: s.identifiers.Start(), + role: Identifiers.start, + } +} + +func (s Scope) Relationship() RelationshipContinuation { + return &entity[RelationshipContinuation]{ + identifier: s.identifiers.Relationship(), + role: Identifiers.relationship, + } +} + +func (s Scope) End() NodeContinuation { + return &entity[NodeContinuation]{ + identifier: s.identifiers.End(), + role: Identifiers.end, + } +} + func Literal(value any) *cypher.Literal { if value == nil { return cypher.NewLiteral(nil, true) @@ -171,33 +232,23 @@ func As(expression any, alias string) *cypher.ProjectionItem { } func Node() NodeContinuation { - return &entity[NodeContinuation]{ - identifier: Identifiers.Node(), - } + return DefaultScope().Node() } func Path() PathContinuation { - return &entity[PathContinuation]{ - identifier: Identifiers.Path(), - } + return DefaultScope().Path() } func Start() 
NodeContinuation { - return &entity[NodeContinuation]{ - identifier: Identifiers.Start(), - } + return DefaultScope().Start() } func Relationship() RelationshipContinuation { - return &entity[RelationshipContinuation]{ - identifier: Identifiers.Relationship(), - } + return DefaultScope().Relationship() } func End() NodeContinuation { - return &entity[NodeContinuation]{ - identifier: Identifiers.End(), - } + return DefaultScope().End() } type QualifiedExpression interface { @@ -329,17 +380,20 @@ func (s *propertyContinuation) Remove() *cypher.RemoveItem { type entity[T any] struct { identifier *cypher.Variable + role string } func (s *entity[T]) Kind() KindContinuation { return kindContinuation{ identifier: s.identifier, + role: s.role, } } func (s *entity[T]) Kinds() KindsContinuation { return kindsContinuation{ identifier: s.identifier, + role: s.role, } } @@ -408,6 +462,7 @@ func (s *entity[T]) Property(propertyName string) PropertyContinuation { type kindContinuation struct { identifier *cypher.Variable + role string } func (s kindContinuation) Is(kind graph.Kind) cypher.Expression { @@ -423,6 +478,7 @@ func (s kindContinuation) IsOneOf(kinds graph.Kinds) cypher.Expression { type kindsContinuation struct { identifier *cypher.Variable + role string } func (s kindsContinuation) Has(kind graph.Kind) cypher.Expression { @@ -486,6 +542,7 @@ type QueryBuilder interface { Delete(expressions ...any) QueryBuilder WithShortestPaths() QueryBuilder WithAllShortestPaths() QueryBuilder + WithRelationshipDirection(direction graph.Direction) QueryBuilder Build() (*PreparedQuery, error) } @@ -508,25 +565,34 @@ type pendingUpdatingClause struct { } type builder struct { - errors []error - constraints []cypher.SyntaxNode - sortItems []any - projections []any - distinct bool - updatingClauses []pendingUpdatingClause - creates []any - setItems []*cypher.SetItem - removeItems []*cypher.RemoveItem - deleteItems []cypher.Expression - detachDelete bool - shortestPathQuery bool - 
allShorestPathsQuery bool - skip *int - limit *int + errors []error + constraints []cypher.SyntaxNode + sortItems []any + projections []any + distinct bool + identifiers runtimeIdentifiers + updatingClauses []pendingUpdatingClause + creates []any + setItems []*cypher.SetItem + removeItems []*cypher.RemoveItem + deleteItems []cypher.Expression + detachDelete bool + relationshipDirection graph.Direction + shortestPathQuery bool + allShorestPathsQuery bool + skip *int + limit *int } func New() QueryBuilder { - return &builder{} + return DefaultScope().New() +} + +func newBuilder(identifiers runtimeIdentifiers) QueryBuilder { + return &builder{ + identifiers: identifiers, + relationshipDirection: graph.DirectionOutbound, + } } func (s *builder) WithShortestPaths() QueryBuilder { @@ -539,6 +605,18 @@ func (s *builder) WithAllShortestPaths() QueryBuilder { return s } +func (s *builder) WithRelationshipDirection(direction graph.Direction) QueryBuilder { + switch direction { + case graph.DirectionInbound, graph.DirectionOutbound, graph.DirectionBoth: + s.relationshipDirection = direction + + default: + s.trackError(fmt.Errorf("invalid relationship direction: %s", direction)) + } + + return s +} + func (s *builder) OrderBy(sortItems ...any) QueryBuilder { s.sortItems = append(s.sortItems, sortItems...) 
return s @@ -664,7 +742,7 @@ func (s *builder) Delete(deleteItems ...any) QueryBuilder { case QualifiedExpression: qualifier := typedNextUpdate.qualifier() - if isDetachDeleteQualifier(qualifier) { + if isDetachDeleteQualifier(qualifier, s.identifiers) { s.detachDelete = true pendingDetachDelete = true } @@ -674,7 +752,7 @@ func (s *builder) Delete(deleteItems ...any) QueryBuilder { case *cypher.Variable: switch typedNextUpdate.Symbol { - case Identifiers.node, Identifiers.start, Identifiers.end: + case s.identifiers.node, s.identifiers.start, s.identifiers.end: s.detachDelete = true pendingDetachDelete = true } @@ -700,7 +778,7 @@ func (s *builder) Where(constraints ...cypher.SyntaxNode) QueryBuilder { return s } -func buildCreates(singlePartQuery *cypher.SinglePartQuery, creates []any) error { +func buildCreates(singlePartQuery *cypher.SinglePartQuery, identifiers runtimeIdentifiers, creates []any) error { if len(creates) == 0 { return nil } @@ -719,7 +797,7 @@ func buildCreates(singlePartQuery *cypher.SinglePartQuery, creates []any) error switch typedExpression := typedNextCreate.qualifier().(type) { case *cypher.Variable: switch typedExpression.Symbol { - case Identifiers.node, Identifiers.start, Identifiers.end: + case identifiers.node, identifiers.start, identifiers.end: pattern.AddPatternElements(&cypher.NodePattern{ Variable: cypher.NewVariableWithSymbol(typedExpression.Symbol), }) @@ -734,13 +812,13 @@ func buildCreates(singlePartQuery *cypher.SinglePartQuery, creates []any) error case *cypher.RelationshipPattern: pattern.AddPatternElements(&cypher.NodePattern{ - Variable: cypher.NewVariableWithSymbol(Identifiers.start), + Variable: identifiers.Start(), }) pattern.AddPatternElements(typedNextCreate) pattern.AddPatternElements(&cypher.NodePattern{ - Variable: cypher.NewVariableWithSymbol(Identifiers.end), + Variable: identifiers.End(), }) default: @@ -774,7 +852,7 @@ func (s *builder) buildUpdatingClauses(singlePartQuery *cypher.SinglePartQuery) )) case 
updatingClauseCreate: - if err := buildCreates(singlePartQuery, updatingClause.creates); err != nil { + if err := buildCreates(singlePartQuery, s.identifiers, updatingClause.creates); err != nil { return err } } @@ -894,7 +972,7 @@ func (s *builder) Build() (*PreparedQuery, error) { relationshipKinds graph.Kinds ) - createScope, err := collectCreateScope(s.creates...) + createScope, err := collectCreateScope(s.identifiers, s.creates...) if err != nil { return nil, err } @@ -922,9 +1000,9 @@ func (s *builder) Build() (*PreparedQuery, error) { case *cypher.KindMatcher: if identifier, typeOK := typedNextConstraint.Reference.(*cypher.Variable); !typeOK { return nil, fmt.Errorf("expected type *cypher.Variable, got %T", typedNextConstraint) - } else if identifier.Symbol == Identifiers.relationship { + } else if identifier.Symbol == s.identifiers.relationship { relationshipKinds = relationshipKinds.Add(typedNextConstraint.Kinds...) - readIdentifiers.Add(Identifiers.relationship) + readIdentifiers.Add(s.identifiers.relationship) continue } } @@ -956,16 +1034,16 @@ func (s *builder) Build() (*PreparedQuery, error) { matchIdentifiers.Or(actionIdentifiers) if len(s.constraints) > 0 || len(s.creates) == 0 || matchIdentifiers.Len() > 0 { - if isNodePattern(matchIdentifiers) { - if err := prepareNodePattern(match, matchIdentifiers); err != nil { + if isNodePattern(matchIdentifiers, s.identifiers) { + if err := prepareNodePattern(match, matchIdentifiers, s.identifiers); err != nil { return nil, err } - } else if createScope.createsRelationship && !matchIdentifiers.Contains(Identifiers.relationship) { - if err := prepareCreateRelationshipMatch(match, matchIdentifiers); err != nil { + } else if createScope.createsRelationship && !matchIdentifiers.Contains(s.identifiers.relationship) { + if err := prepareCreateRelationshipMatch(match, matchIdentifiers, s.identifiers); err != nil { return nil, err } - } else if isRelationshipPattern(matchIdentifiers) { - if err := 
prepareRelationshipPattern(match, matchIdentifiers, relationshipKinds, s.shortestPathQuery, s.allShorestPathsQuery); err != nil { + } else if isRelationshipPattern(matchIdentifiers, s.identifiers) { + if err := prepareRelationshipPattern(match, matchIdentifiers, s.identifiers, relationshipKinds, s.relationshipDirection, s.shortestPathQuery, s.allShorestPathsQuery); err != nil { return nil, err } } else { diff --git a/query/v2/query_test.go b/query/v2/query_test.go index 3745ec90..5cfc05b8 100644 --- a/query/v2/query_test.go +++ b/query/v2/query_test.go @@ -128,6 +128,29 @@ func TestUpdatingClausesPreserveFluentOrder(t *testing.T) { }, preparedQuery.Parameters) } +func TestScopedRelationshipPatternControls(t *testing.T) { + scope := v2.NewScope("path", "person", "source", "edge", "target") + + preparedQuery, err := scope.New().WithRelationshipDirection(graph.DirectionInbound).Where( + scope.Relationship().Kind().Is(graph.StringKind("MemberOf")), + scope.Start().ID().Equals(1), + ).Return( + scope.Relationship().Kind(), + scope.End().Kinds(), + ).Build() + require.NoError(t, err) + + require.Equal(t, "match (source)<-[edge:MemberOf]-(target) where id(source) = $p0 return type(edge), labels(target)", renderPrepared(t, preparedQuery)) + require.Equal(t, map[string]any{ + "p0": 1, + }, preparedQuery.Parameters) +} + +func TestInvalidRelationshipDirectionReturnsError(t *testing.T) { + _, err := v2.New().WithRelationshipDirection(graph.Direction(99)).Return(v2.Relationship()).Build() + require.ErrorContains(t, err, "invalid relationship direction: invalid") +} + func TestProjectionAndOrderHelpers(t *testing.T) { preparedQuery, err := v2.New().ReturnDistinct( v2.As(v2.Node().ID(), "node_id"), diff --git a/query/v2/util.go b/query/v2/util.go index 72e7ed10..016b2f01 100644 --- a/query/v2/util.go +++ b/query/v2/util.go @@ -11,42 +11,42 @@ import ( "github.com/specterops/dawgs/graph" ) -func isNodePattern(seen *identifierSet) bool { - return seen.Contains(Identifiers.node) 
+func isNodePattern(seen *identifierSet, identifiers runtimeIdentifiers) bool { + return seen.Contains(identifiers.node) } -func isRelationshipPattern(seen *identifierSet) bool { +func isRelationshipPattern(seen *identifierSet, identifiers runtimeIdentifiers) bool { var ( - hasStart = seen.Contains(Identifiers.start) - hasRelationship = seen.Contains(Identifiers.relationship) - hasEnd = seen.Contains(Identifiers.end) + hasStart = seen.Contains(identifiers.start) + hasRelationship = seen.Contains(identifiers.relationship) + hasEnd = seen.Contains(identifiers.end) ) return hasStart || hasRelationship || hasEnd } -func prepareNodePattern(match *cypher.Match, seen *identifierSet) error { - if isRelationshipPattern(seen) { +func prepareNodePattern(match *cypher.Match, seen *identifierSet, identifiers runtimeIdentifiers) error { + if isRelationshipPattern(seen, identifiers) { return fmt.Errorf("query mixes node and relationship query identifiers") } match.NewPatternPart().AddPatternElements(&cypher.NodePattern{ - Variable: Identifiers.Node(), + Variable: identifiers.Node(), }) return nil } -func prepareRelationshipPattern(match *cypher.Match, seen *identifierSet, relationshipKinds graph.Kinds, shortestPaths, allShortestPaths bool) error { +func prepareRelationshipPattern(match *cypher.Match, seen *identifierSet, identifiers runtimeIdentifiers, relationshipKinds graph.Kinds, direction graph.Direction, shortestPaths, allShortestPaths bool) error { if shortestPaths && allShortestPaths { return errors.New("query is requesting both all shortest paths and shortest paths") } var ( newPatternPart = match.NewPatternPart() - startNodeSeen = seen.Contains(Identifiers.start) - relationshipSeen = seen.Contains(Identifiers.relationship) - endNodeSeen = seen.Contains(Identifiers.end) + startNodeSeen = seen.Contains(identifiers.start) + relationshipSeen = seen.Contains(identifiers.relationship) + endNodeSeen = seen.Contains(identifiers.end) ) newPatternPart.ShortestPathPattern = 
shortestPaths @@ -54,7 +54,7 @@ func prepareRelationshipPattern(match *cypher.Match, seen *identifierSet, relati if startNodeSeen { newPatternPart.AddPatternElements(&cypher.NodePattern{ - Variable: Identifiers.Start(), + Variable: identifiers.Start(), }) } else { newPatternPart.AddPatternElements(&cypher.NodePattern{}) @@ -62,15 +62,15 @@ func prepareRelationshipPattern(match *cypher.Match, seen *identifierSet, relati relationshipPattern := &cypher.RelationshipPattern{ Kinds: relationshipKinds, - Direction: graph.DirectionOutbound, + Direction: direction, } if relationshipSeen { - relationshipPattern.Variable = Identifiers.Relationship() + relationshipPattern.Variable = identifiers.Relationship() } if shortestPaths || allShortestPaths { - newPatternPart.Variable = Identifiers.Path() + newPatternPart.Variable = identifiers.Path() relationshipPattern.Range = &cypher.PatternRange{} } @@ -78,7 +78,7 @@ func prepareRelationshipPattern(match *cypher.Match, seen *identifierSet, relati if endNodeSeen { newPatternPart.AddPatternElements(&cypher.NodePattern{ - Variable: Identifiers.End(), + Variable: identifiers.End(), }) } else { newPatternPart.AddPatternElements(&cypher.NodePattern{}) @@ -87,38 +87,38 @@ func prepareRelationshipPattern(match *cypher.Match, seen *identifierSet, relati return nil } -func prepareCreateRelationshipMatch(match *cypher.Match, seen *identifierSet) error { - if seen.Contains(Identifiers.start) { +func prepareCreateRelationshipMatch(match *cypher.Match, seen *identifierSet, identifiers runtimeIdentifiers) error { + if seen.Contains(identifiers.start) { match.NewPatternPart().AddPatternElements(&cypher.NodePattern{ - Variable: Identifiers.Start(), + Variable: identifiers.Start(), }) } - if seen.Contains(Identifiers.end) { + if seen.Contains(identifiers.end) { match.NewPatternPart().AddPatternElements(&cypher.NodePattern{ - Variable: Identifiers.End(), + Variable: identifiers.End(), }) } return nil } -func isDetachDeleteQualifier(qualifier 
cypher.Expression) bool { +func isDetachDeleteQualifier(qualifier cypher.Expression, identifiers runtimeIdentifiers) bool { variable, typeOK := qualifier.(*cypher.Variable) if !typeOK { return false } switch variable.Symbol { - case Identifiers.node, Identifiers.start, Identifiers.end: + case identifiers.node, identifiers.start, identifiers.end: return true default: return false } } -func kindProjectionExpression(identifier *cypher.Variable) (cypher.Expression, error) { - switch identifier.Symbol { +func kindProjectionExpression(role string, identifier *cypher.Variable) (cypher.Expression, error) { + switch role { case Identifiers.node, Identifiers.start, Identifiers.end: return cypher.NewSimpleFunctionInvocation(cypher.NodeLabelsFunction, identifier), nil @@ -169,10 +169,10 @@ func projectionExpression(value any) (cypher.Expression, error) { return typedValue.qualifier(), nil case kindContinuation: - return kindProjectionExpression(typedValue.identifier) + return kindProjectionExpression(typedValue.role, typedValue.identifier) case kindsContinuation: - return kindProjectionExpression(typedValue.identifier) + return kindProjectionExpression(typedValue.role, typedValue.identifier) case *cypher.ProjectionItem: if typedValue.Expression == nil { @@ -447,7 +447,7 @@ type createScope struct { createsRelationship bool } -func collectCreateScope(values ...any) (*createScope, error) { +func collectCreateScope(identifiers runtimeIdentifiers, values ...any) (*createScope, error) { scope := &createScope{ identifiers: newIdentifierSet(), } @@ -456,8 +456,8 @@ func collectCreateScope(values ...any) (*createScope, error) { switch typedValue := value.(type) { case *cypher.RelationshipPattern: scope.createsRelationship = true - scope.identifiers.Add(Identifiers.start) - scope.identifiers.Add(Identifiers.end) + scope.identifiers.Add(identifiers.start) + scope.identifiers.Add(identifiers.end) if typedValue.Variable != nil { scope.identifiers.Add(typedValue.Variable.Symbol) From 
ec02419892c65f5d42c910810eafb5d18fce4c24 Mon Sep 17 00:00:00 2001 From: John Hopper Date: Thu, 7 May 2026 00:49:52 -0700 Subject: [PATCH 11/55] test(query/v2): assert backend render output --- query/v2/backend_test.go | 195 ++++++++++++++++++++++++--------------- 1 file changed, 122 insertions(+), 73 deletions(-) diff --git a/query/v2/backend_test.go b/query/v2/backend_test.go index 48b52ea3..fd9d9707 100644 --- a/query/v2/backend_test.go +++ b/query/v2/backend_test.go @@ -23,43 +23,67 @@ func testKindMapper(kinds ...graph.Kind) *pgutil.InMemoryKindMapper { } func TestBackendParityNeo4jPrepare(t *testing.T) { - cases := map[string]v2.QueryBuilder{ - "node read": v2.New().Where( - v2.Node().Kinds().Has(graph.StringKind("User")), - v2.Node().Property("name").Contains("admin"), - ).Return( - v2.Node(), - ).OrderBy( - v2.Node().Property("name"), - ), - "relationship read": v2.New().Where( - v2.Relationship().Kind().Is(graph.StringKind("MemberOf")), - v2.Start().ID().Equals(1), - ).Return( - v2.Start().ID(), - v2.Relationship().ID(), - v2.End().ID(), - ), - "create node": v2.New().Create( - v2.NodePattern(graph.Kinds{graph.StringKind("User")}, v2.NamedParameter("props", map[string]any{"name": "u"})), - ).Return( - v2.Node().ID(), - ), - "update node": v2.New().Where( - v2.Node().ID().Equals(1), - ).Update( - v2.SetProperty(v2.Node().Property("name"), "updated"), - ), - "delete relationship": v2.New().Where( - v2.Relationship().ID().Equals(1), - ).Delete( - v2.Relationship(), - ), + cases := map[string]struct { + builder v2.QueryBuilder + expectedCypher string + expectedParams map[string]any + }{ + "node read": { + builder: v2.New().Where( + v2.Node().Kinds().Has(graph.StringKind("User")), + v2.Node().Property("name").Contains("admin"), + ).Return( + v2.Node(), + ).OrderBy( + v2.Node().Property("name"), + ), + expectedCypher: "match (n) where n:User and n.name contains $p0 return n order by n.name asc", + expectedParams: map[string]any{"p0": "admin"}, + }, + 
"relationship read": { + builder: v2.New().Where( + v2.Relationship().Kind().Is(graph.StringKind("MemberOf")), + v2.Start().ID().Equals(1), + ).Return( + v2.Start().ID(), + v2.Relationship().ID(), + v2.End().ID(), + ), + expectedCypher: "match (s)-[r]->(e) where id(s) = $p0 return id(s), id(r), id(e)", + expectedParams: map[string]any{"p0": 1}, + }, + "create node": { + builder: v2.New().Create( + v2.NodePattern(graph.Kinds{graph.StringKind("User")}, v2.NamedParameter("props", map[string]any{"name": "u"})), + ).Return( + v2.Node().ID(), + ), + expectedCypher: "create (n:User $p0) return id(n)", + expectedParams: map[string]any{"p0": map[string]any{"name": "u"}}, + }, + "update node": { + builder: v2.New().Where( + v2.Node().ID().Equals(1), + ).Update( + v2.SetProperty(v2.Node().Property("name"), "updated"), + ), + expectedCypher: "match (n) where id(n) = $p0 set n.name = $p1", + expectedParams: map[string]any{"p0": 1, "p1": "updated"}, + }, + "delete relationship": { + builder: v2.New().Where( + v2.Relationship().ID().Equals(1), + ).Delete( + v2.Relationship(), + ), + expectedCypher: "match ()-[r]->() where id(r) = $p0 delete r", + expectedParams: map[string]any{"p0": 1}, + }, } - for name, builder := range cases { + for name, testCase := range cases { t.Run(name, func(t *testing.T) { - preparedQuery, err := builder.Build() + preparedQuery, err := testCase.builder.Build() require.NoError(t, err) queryBuilder := neo4j.NewQueryBuilder(preparedQuery.Query) @@ -67,8 +91,8 @@ func TestBackendParityNeo4jPrepare(t *testing.T) { rendered, err := queryBuilder.Render() require.NoError(t, err) - require.NotEmpty(t, rendered) - require.NotEmpty(t, queryBuilder.Parameters) + require.Equal(t, testCase.expectedCypher, rendered) + require.Equal(t, testCase.expectedParams, queryBuilder.Parameters) }) } } @@ -78,45 +102,69 @@ func TestBackendParityPGTranslate(t *testing.T) { edgeKind := graph.StringKind("MemberOf") mapper := testKindMapper(userKind, edgeKind) - cases := 
map[string]v2.QueryBuilder{ - "node read": v2.New().Where( - v2.Node().Kinds().Has(userKind), - v2.Node().Property("name").Contains("admin"), - ).Return( - v2.Node().ID(), - v2.Node().Kinds(), - ), - "relationship read": v2.New().Where( - v2.Relationship().Kind().Is(edgeKind), - v2.Start().ID().Equals(1), - ).Return( - v2.Start().ID(), - v2.Relationship().ID(), - v2.End().ID(), - ), - "create relationship": v2.New().Where( - v2.Start().ID().Equals(1), - v2.End().ID().Equals(2), - ).Create( - v2.RelationshipPattern(edgeKind, nil, graph.DirectionOutbound), - ).Return( - v2.Relationship().ID(), - ), - "update node": v2.New().Where( - v2.Node().ID().Equals(1), - ).Update( - v2.SetProperty(v2.Node().Property("name"), "updated"), - ), - "delete relationship": v2.New().Where( - v2.Relationship().ID().Equals(1), - ).Delete( - v2.Relationship(), - ), + cases := map[string]struct { + builder v2.QueryBuilder + expectedSQL string + expectedParams map[string]any + }{ + "node read": { + builder: v2.New().Where( + v2.Node().Kinds().Has(userKind), + v2.Node().Property("name").Contains("admin"), + ).Return( + v2.Node().ID(), + v2.Node().Kinds(), + ), + expectedSQL: "with s0 as (select (n0.id, n0.kind_ids, n0.properties)::nodecomposite as n0 from node n0 where (n0.kind_ids operator (pg_catalog.&&) array [1]::int2[] and (n0.properties ->> 'name') like '%' || @pi0::text || '%')) select (s0.n0).id, (s0.n0).kind_ids from s0;", + expectedParams: map[string]any{"p0": "admin", "pi0": "admin"}, + }, + "relationship read": { + builder: v2.New().Where( + v2.Relationship().Kind().Is(edgeKind), + v2.Start().ID().Equals(1), + ).Return( + v2.Start().ID(), + v2.Relationship().ID(), + v2.End().ID(), + ), + expectedSQL: "with s0 as (select (e0.id, e0.start_id, e0.end_id, e0.kind_id, e0.properties)::edgecomposite as e0, (n0.id, n0.kind_ids, n0.properties)::nodecomposite as n0, (n1.id, n1.kind_ids, n1.properties)::nodecomposite as n1 from edge e0 join node n0 on (n0.id = @pi0::int8) and n0.id = 
e0.start_id join node n1 on n1.id = e0.end_id where e0.kind_id = any (array [2]::int2[])) select (s0.n0).id, (s0.e0).id, (s0.n1).id from s0;", + expectedParams: map[string]any{"p0": 1, "pi0": 1}, + }, + "create relationship": { + builder: v2.New().Where( + v2.Start().ID().Equals(1), + v2.End().ID().Equals(2), + ).Create( + v2.RelationshipPattern(edgeKind, nil, graph.DirectionOutbound), + ).Return( + v2.Relationship().ID(), + ), + expectedSQL: "with s0 as (select (n0.id, n0.kind_ids, n0.properties)::nodecomposite as n0 from node n0 where (n0.id = @pi0::int8)), s1 as (select s0.n0 as n0, (n1.id, n1.kind_ids, n1.properties)::nodecomposite as n1 from s0, node n1 where (n1.id = @pi1::int8)) select e0.id from s1 where e0.kind_id = any (array [2]::int2[]);", + expectedParams: map[string]any{"p0": 1, "p1": 2, "pi0": 1, "pi1": 2}, + }, + "update node": { + builder: v2.New().Where( + v2.Node().ID().Equals(1), + ).Update( + v2.SetProperty(v2.Node().Property("name"), "updated"), + ), + expectedSQL: "with s0 as (select (n0.id, n0.kind_ids, n0.properties)::nodecomposite as n0 from node n0 where (n0.id = @pi0::int8)), s1 as (update node n1 set properties = n1.properties || jsonb_build_object('name', @pi1::text)::jsonb from s0 where (s0.n0).id = n1.id returning (n1.id, n1.kind_ids, n1.properties)::nodecomposite as n0) select 1;", + expectedParams: map[string]any{"p0": 1, "p1": "updated", "pi0": 1, "pi1": "updated"}, + }, + "delete relationship": { + builder: v2.New().Where( + v2.Relationship().ID().Equals(1), + ).Delete( + v2.Relationship(), + ), + expectedSQL: "with s0 as (select (e0.id, e0.start_id, e0.end_id, e0.kind_id, e0.properties)::edgecomposite as e0, (n0.id, n0.kind_ids, n0.properties)::nodecomposite as n0, (n1.id, n1.kind_ids, n1.properties)::nodecomposite as n1 from edge e0 join node n0 on n0.id = e0.start_id join node n1 on n1.id = e0.end_id where (e0.id = @pi0::int8)), s1 as (delete from edge e1 using s0 where (s0.e0).id = e1.id) select 1;", + expectedParams: 
map[string]any{"p0": 1, "pi0": 1}, + }, } - for name, builder := range cases { + for name, testCase := range cases { t.Run(name, func(t *testing.T) { - preparedQuery, err := builder.Build() + preparedQuery, err := testCase.builder.Build() require.NoError(t, err) translation, err := translate.Translate(context.Background(), preparedQuery.Query, mapper, preparedQuery.Parameters) @@ -124,7 +172,8 @@ func TestBackendParityPGTranslate(t *testing.T) { sql, err := translate.Translated(translation) require.NoError(t, err) - require.NotEmpty(t, sql) + require.Equal(t, testCase.expectedSQL, sql) + require.Equal(t, testCase.expectedParams, translation.Parameters) }) } } From 318555dcaf4f8962d70e9d159663c6c705dcb6b1 Mon Sep 17 00:00:00 2001 From: John Hopper Date: Thu, 7 May 2026 01:01:11 -0700 Subject: [PATCH 12/55] fix(pgsql): reject unsupported create translation --- cypher/models/pgsql/translate/translator.go | 3 +++ query/v2/backend_test.go | 30 ++++++++++++--------- 2 files changed, 21 insertions(+), 12 deletions(-) diff --git a/cypher/models/pgsql/translate/translator.go b/cypher/models/pgsql/translate/translator.go index 993248b1..03d18b48 100644 --- a/cypher/models/pgsql/translate/translator.go +++ b/cypher/models/pgsql/translate/translator.go @@ -284,6 +284,9 @@ func (s *Translator) Exit(expression cypher.SyntaxNode) { s.SetError(err) } + case *cypher.Create: + s.SetErrorf("pgsql translator does not support create clauses") + case *cypher.Delete: if err := s.translateDelete(s.scope, typedExpression); err != nil { s.SetError(err) diff --git a/query/v2/backend_test.go b/query/v2/backend_test.go index fd9d9707..de504a7a 100644 --- a/query/v2/backend_test.go +++ b/query/v2/backend_test.go @@ -130,18 +130,6 @@ func TestBackendParityPGTranslate(t *testing.T) { expectedSQL: "with s0 as (select (e0.id, e0.start_id, e0.end_id, e0.kind_id, e0.properties)::edgecomposite as e0, (n0.id, n0.kind_ids, n0.properties)::nodecomposite as n0, (n1.id, n1.kind_ids, 
n1.properties)::nodecomposite as n1 from edge e0 join node n0 on (n0.id = @pi0::int8) and n0.id = e0.start_id join node n1 on n1.id = e0.end_id where e0.kind_id = any (array [2]::int2[])) select (s0.n0).id, (s0.e0).id, (s0.n1).id from s0;", expectedParams: map[string]any{"p0": 1, "pi0": 1}, }, - "create relationship": { - builder: v2.New().Where( - v2.Start().ID().Equals(1), - v2.End().ID().Equals(2), - ).Create( - v2.RelationshipPattern(edgeKind, nil, graph.DirectionOutbound), - ).Return( - v2.Relationship().ID(), - ), - expectedSQL: "with s0 as (select (n0.id, n0.kind_ids, n0.properties)::nodecomposite as n0 from node n0 where (n0.id = @pi0::int8)), s1 as (select s0.n0 as n0, (n1.id, n1.kind_ids, n1.properties)::nodecomposite as n1 from s0, node n1 where (n1.id = @pi1::int8)) select e0.id from s1 where e0.kind_id = any (array [2]::int2[]);", - expectedParams: map[string]any{"p0": 1, "p1": 2, "pi0": 1, "pi1": 2}, - }, "update node": { builder: v2.New().Where( v2.Node().ID().Equals(1), @@ -177,3 +165,21 @@ func TestBackendParityPGTranslate(t *testing.T) { }) } } + +func TestBackendParityPGCreateUnsupported(t *testing.T) { + edgeKind := graph.StringKind("MemberOf") + mapper := testKindMapper(edgeKind) + + preparedQuery, err := v2.New().Where( + v2.Start().ID().Equals(1), + v2.End().ID().Equals(2), + ).Create( + v2.RelationshipPattern(edgeKind, nil, graph.DirectionOutbound), + ).Return( + v2.Relationship().ID(), + ).Build() + require.NoError(t, err) + + _, err = translate.Translate(context.Background(), preparedQuery.Query, mapper, preparedQuery.Parameters) + require.ErrorContains(t, err, "pgsql translator does not support create clauses") +} From 0f92ca8b385f691631244eaa4f92a7d6e123710b Mon Sep 17 00:00:00 2001 From: John Hopper Date: Thu, 7 May 2026 01:01:36 -0700 Subject: [PATCH 13/55] fix(query/v2): reject unsupported relationship directions --- query/v2/query.go | 4 ++-- query/v2/query_test.go | 5 ++++- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git 
a/query/v2/query.go b/query/v2/query.go index 9046fa47..b4b86071 100644 --- a/query/v2/query.go +++ b/query/v2/query.go @@ -607,11 +607,11 @@ func (s *builder) WithAllShortestPaths() QueryBuilder { func (s *builder) WithRelationshipDirection(direction graph.Direction) QueryBuilder { switch direction { - case graph.DirectionInbound, graph.DirectionOutbound, graph.DirectionBoth: + case graph.DirectionInbound, graph.DirectionOutbound: s.relationshipDirection = direction default: - s.trackError(fmt.Errorf("invalid relationship direction: %s", direction)) + s.trackError(fmt.Errorf("unsupported relationship direction: %s", direction)) } return s diff --git a/query/v2/query_test.go b/query/v2/query_test.go index 5cfc05b8..2fd644dc 100644 --- a/query/v2/query_test.go +++ b/query/v2/query_test.go @@ -148,7 +148,10 @@ func TestScopedRelationshipPatternControls(t *testing.T) { func TestInvalidRelationshipDirectionReturnsError(t *testing.T) { _, err := v2.New().WithRelationshipDirection(graph.Direction(99)).Return(v2.Relationship()).Build() - require.ErrorContains(t, err, "invalid relationship direction: invalid") + require.ErrorContains(t, err, "unsupported relationship direction: invalid") + + _, err = v2.New().WithRelationshipDirection(graph.DirectionBoth).Return(v2.Relationship()).Build() + require.ErrorContains(t, err, "unsupported relationship direction: both") } func TestProjectionAndOrderHelpers(t *testing.T) { From 48b4aac6d8fceb22fd2332bbe9a6cd362e0178c2 Mon Sep 17 00:00:00 2001 From: John Hopper Date: Thu, 7 May 2026 01:01:59 -0700 Subject: [PATCH 14/55] fix(query/v2): validate create qualified expressions --- query/v2/query.go | 3 +++ query/v2/query_test.go | 5 +++++ 2 files changed, 8 insertions(+) diff --git a/query/v2/query.go b/query/v2/query.go index b4b86071..ddca3da4 100644 --- a/query/v2/query.go +++ b/query/v2/query.go @@ -805,6 +805,9 @@ func buildCreates(singlePartQuery *cypher.SinglePartQuery, identifiers runtimeId default: return fmt.Errorf("invalid 
variable reference for create: %s", typedExpression.Symbol) } + + default: + return fmt.Errorf("invalid qualified expression for create: %T", typedExpression) } case *cypher.NodePattern: diff --git a/query/v2/query_test.go b/query/v2/query_test.go index 2fd644dc..b1d40494 100644 --- a/query/v2/query_test.go +++ b/query/v2/query_test.go @@ -95,6 +95,11 @@ func TestCreateNodeReturnDoesNotCreateMatch(t *testing.T) { }, preparedQuery.Parameters) } +func TestInvalidCreateQualifiedExpressionReturnsError(t *testing.T) { + _, err := v2.New().Create(v2.Node().Property("name")).Build() + require.ErrorContains(t, err, "invalid qualified expression for create: *cypher.PropertyLookup") +} + func TestUpdatingClausesPreserveFluentOrder(t *testing.T) { preparedQuery, err := v2.New().Create( v2.NodePattern(graph.Kinds{graph.StringKind("User")}, nil), From 3ad2e56f4d84ffec5f9d70567919549c48c22751 Mon Sep 17 00:00:00 2001 From: John Hopper Date: Thu, 7 May 2026 01:02:55 -0700 Subject: [PATCH 15/55] fix(query/v2): make kind projections scope aware --- query/v2/compat.go | 13 +++++++++++++ query/v2/query.go | 26 ++++++++++++++++++++++++++ query/v2/query_test.go | 12 ++++++++++++ query/v2/util.go | 6 +++--- 4 files changed, 54 insertions(+), 3 deletions(-) diff --git a/query/v2/compat.go b/query/v2/compat.go index 521820ba..1e941816 100644 --- a/query/v2/compat.go +++ b/query/v2/compat.go @@ -1,6 +1,7 @@ package v2 import ( + "fmt" "strings" "time" @@ -49,6 +50,18 @@ func Size(expression any) *cypher.FunctionInvocation { } func KindsOf(reference any) *cypher.FunctionInvocation { + if scopedReference, typeOK := reference.(scopedExpression); typeOK { + if variable, typeOK := scopedReference.qualifier().(*cypher.Variable); !typeOK { + return invalidExpression(fmt.Errorf("expected variable reference, got %T", scopedReference.qualifier())) + } else if expression, err := kindProjectionExpression(scopedReference.roleName(), variable); err != nil { + return invalidExpression(err) + } else if 
invocation, typeOK := expression.(*cypher.FunctionInvocation); !typeOK { + return invalidExpression(fmt.Errorf("expected kind projection function, got %T", expression)) + } else { + return invocation + } + } + expression := expressionOrError(reference) switch typedExpression := expression.(type) { diff --git a/query/v2/query.go b/query/v2/query.go index ddca3da4..2e94fe2b 100644 --- a/query/v2/query.go +++ b/query/v2/query.go @@ -255,6 +255,12 @@ type QualifiedExpression interface { qualifier() cypher.Expression } +type scopedExpression interface { + QualifiedExpression + + roleName() string +} + type EntityContinuation interface { QualifiedExpression @@ -442,6 +448,10 @@ func (s *entity[T]) qualifier() cypher.Expression { return s.identifier } +func (s *entity[T]) roleName() string { + return s.role +} + func (s *entity[T]) ID() IdentityContinuation { return &comparisonContinuation{ qualifierExpression: &cypher.FunctionInvocation{ @@ -465,6 +475,14 @@ type kindContinuation struct { role string } +func (s kindContinuation) qualifier() cypher.Expression { + return s.identifier +} + +func (s kindContinuation) roleName() string { + return s.role +} + func (s kindContinuation) Is(kind graph.Kind) cypher.Expression { return s.IsOneOf(graph.Kinds{kind}) } @@ -481,6 +499,14 @@ type kindsContinuation struct { role string } +func (s kindsContinuation) qualifier() cypher.Expression { + return s.identifier +} + +func (s kindsContinuation) roleName() string { + return s.role +} + func (s kindsContinuation) Has(kind graph.Kind) cypher.Expression { return s.HasOneOf(graph.Kinds{kind}) } diff --git a/query/v2/query_test.go b/query/v2/query_test.go index b1d40494..9b04b927 100644 --- a/query/v2/query_test.go +++ b/query/v2/query_test.go @@ -151,6 +151,18 @@ func TestScopedRelationshipPatternControls(t *testing.T) { }, preparedQuery.Parameters) } +func TestScopedKindsOfCompatibilityHelper(t *testing.T) { + scope := v2.NewScope("path", "person", "source", "edge", "target") + + 
preparedQuery, err := scope.New().Return( + v2.KindsOf(scope.Relationship()), + v2.KindsOf(scope.End()), + ).Build() + require.NoError(t, err) + + require.Equal(t, "match ()-[edge]->(target) return type(edge), labels(target)", renderPrepared(t, preparedQuery)) +} + func TestInvalidRelationshipDirectionReturnsError(t *testing.T) { _, err := v2.New().WithRelationshipDirection(graph.Direction(99)).Return(v2.Relationship()).Build() require.ErrorContains(t, err, "unsupported relationship direction: invalid") diff --git a/query/v2/util.go b/query/v2/util.go index 016b2f01..c5ea4857 100644 --- a/query/v2/util.go +++ b/query/v2/util.go @@ -165,15 +165,15 @@ func propertyLookupOrError(reference any, propertyName string) cypher.Expression func projectionExpression(value any) (cypher.Expression, error) { switch typedValue := value.(type) { - case QualifiedExpression: - return typedValue.qualifier(), nil - case kindContinuation: return kindProjectionExpression(typedValue.role, typedValue.identifier) case kindsContinuation: return kindProjectionExpression(typedValue.role, typedValue.identifier) + case QualifiedExpression: + return typedValue.qualifier(), nil + case *cypher.ProjectionItem: if typedValue.Expression == nil { return nil, fmt.Errorf("projection item has nil expression") From 2894033db786eb38a712972cb7bc7c59fb22c128 Mon Sep 17 00:00:00 2001 From: John Hopper Date: Thu, 7 May 2026 01:04:07 -0700 Subject: [PATCH 16/55] fix(query/v2): ignore projection aliases for match inference --- cypher/models/cypher/format/format.go | 8 +++++++- query/v2/query.go | 2 +- query/v2/query_test.go | 10 ++++++++++ query/v2/util.go | 5 ----- 4 files changed, 18 insertions(+), 7 deletions(-) diff --git a/cypher/models/cypher/format/format.go b/cypher/models/cypher/format/format.go index 495cf806..f2f62c15 100644 --- a/cypher/models/cypher/format/format.go +++ b/cypher/models/cypher/format/format.go @@ -296,7 +296,7 @@ func (s Emitter) formatProjection(output io.Writer, projection 
*cypher.Projectio } func (s Emitter) formatReturn(output io.Writer, returnClause *cypher.Return) error { - if _, err := io.WriteString(output, " return "); err != nil { + if _, err := io.WriteString(output, "return "); err != nil { return err } @@ -1095,6 +1095,12 @@ func (s Emitter) formatSinglePartQuery(writer io.Writer, singlePartQuery *cypher } if singlePartQuery.Return != nil { + if len(singlePartQuery.ReadingClauses) > 0 || len(singlePartQuery.UpdatingClauses) > 0 { + if _, err := io.WriteString(writer, " "); err != nil { + return err + } + } + return s.formatReturn(writer, singlePartQuery.Return) } diff --git a/query/v2/query.go b/query/v2/query.go index 2e94fe2b..9e833f00 100644 --- a/query/v2/query.go +++ b/query/v2/query.go @@ -1062,7 +1062,7 @@ func (s *builder) Build() (*PreparedQuery, error) { matchIdentifiers := readIdentifiers.Clone() matchIdentifiers.Or(actionIdentifiers) - if len(s.constraints) > 0 || len(s.creates) == 0 || matchIdentifiers.Len() > 0 { + if len(s.constraints) > 0 || matchIdentifiers.Len() > 0 { if isNodePattern(matchIdentifiers, s.identifiers) { if err := prepareNodePattern(match, matchIdentifiers, s.identifiers); err != nil { return nil, err diff --git a/query/v2/query_test.go b/query/v2/query_test.go index 9b04b927..35a3d35d 100644 --- a/query/v2/query_test.go +++ b/query/v2/query_test.go @@ -183,6 +183,16 @@ func TestProjectionAndOrderHelpers(t *testing.T) { require.Equal(t, "match (n) return distinct id(n) as node_id order by n.name asc, id(n) desc", renderPrepared(t, preparedQuery)) } +func TestProjectionAliasDoesNotCreateMatchInference(t *testing.T) { + preparedQuery, err := v2.New().Return( + v2.As(v2.Literal(1), "one"), + ).Build() + require.NoError(t, err) + + require.Equal(t, "return 1 as one", renderPrepared(t, preparedQuery)) + require.Empty(t, preparedQuery.Parameters) +} + func TestUnsupportedOrderByTypeReturnsError(t *testing.T) { _, err := v2.New().Return(v2.Node()).OrderBy(123).Build() require.ErrorContains(t, err, 
"unsupported expression type: int") diff --git a/query/v2/util.go b/query/v2/util.go index c5ea4857..d3ccdd99 100644 --- a/query/v2/util.go +++ b/query/v2/util.go @@ -510,11 +510,6 @@ func (s *identifierExtractor) Enter(node cypher.SyntaxNode) { if typedNode.Variable != nil { s.seen.Add(typedNode.Variable.Symbol) } - - case *cypher.ProjectionItem: - if typedNode.Alias != nil { - s.seen.Add(typedNode.Alias.Symbol) - } } } From 961e60125749a9ab74dcbe910dddd5b980ec000f Mon Sep 17 00:00:00 2001 From: John Hopper Date: Thu, 7 May 2026 01:14:38 -0700 Subject: [PATCH 17/55] fix(query/v2): validate explicit relationship directions --- query/v2/query.go | 12 +++++++----- query/v2/query_test.go | 12 ++++++++++++ query/v2/util.go | 13 +++++++++++++ 3 files changed, 32 insertions(+), 5 deletions(-) diff --git a/query/v2/query.go b/query/v2/query.go index 9e833f00..201e0e66 100644 --- a/query/v2/query.go +++ b/query/v2/query.go @@ -632,12 +632,10 @@ func (s *builder) WithAllShortestPaths() QueryBuilder { } func (s *builder) WithRelationshipDirection(direction graph.Direction) QueryBuilder { - switch direction { - case graph.DirectionInbound, graph.DirectionOutbound: + if err := validateRelationshipDirection(direction); err != nil { + s.trackError(err) + } else { s.relationshipDirection = direction - - default: - s.trackError(fmt.Errorf("unsupported relationship direction: %s", direction)) } return s @@ -840,6 +838,10 @@ func buildCreates(singlePartQuery *cypher.SinglePartQuery, identifiers runtimeId pattern.AddPatternElements(typedNextCreate) case *cypher.RelationshipPattern: + if err := validateRelationshipDirection(typedNextCreate.Direction); err != nil { + return err + } + pattern.AddPatternElements(&cypher.NodePattern{ Variable: identifiers.Start(), }) diff --git a/query/v2/query_test.go b/query/v2/query_test.go index 35a3d35d..8a324b90 100644 --- a/query/v2/query_test.go +++ b/query/v2/query_test.go @@ -171,6 +171,18 @@ func TestInvalidRelationshipDirectionReturnsError(t 
*testing.T) { require.ErrorContains(t, err, "unsupported relationship direction: both") } +func TestInvalidExplicitRelationshipPatternDirectionReturnsError(t *testing.T) { + _, err := v2.New().Create( + v2.RelationshipPattern(graph.StringKind("Edge"), nil, graph.DirectionBoth), + ).Build() + require.ErrorContains(t, err, "unsupported relationship direction: both") + + _, err = v2.New().Create( + v2.Relationship().RelationshipPattern(graph.StringKind("Edge"), nil, graph.Direction(99)), + ).Build() + require.ErrorContains(t, err, "unsupported relationship direction: invalid") +} + func TestProjectionAndOrderHelpers(t *testing.T) { preparedQuery, err := v2.New().ReturnDistinct( v2.As(v2.Node().ID(), "node_id"), diff --git a/query/v2/util.go b/query/v2/util.go index d3ccdd99..d68afb98 100644 --- a/query/v2/util.go +++ b/query/v2/util.go @@ -37,11 +37,24 @@ func prepareNodePattern(match *cypher.Match, seen *identifierSet, identifiers ru return nil } +func validateRelationshipDirection(direction graph.Direction) error { + switch direction { + case graph.DirectionInbound, graph.DirectionOutbound: + return nil + default: + return fmt.Errorf("unsupported relationship direction: %s", direction) + } +} + func prepareRelationshipPattern(match *cypher.Match, seen *identifierSet, identifiers runtimeIdentifiers, relationshipKinds graph.Kinds, direction graph.Direction, shortestPaths, allShortestPaths bool) error { if shortestPaths && allShortestPaths { return errors.New("query is requesting both all shortest paths and shortest paths") } + if err := validateRelationshipDirection(direction); err != nil { + return err + } + var ( newPatternPart = match.NewPatternPart() startNodeSeen = seen.Contains(identifiers.start) From 0eb694507264c8e559b8124a7d60b79a052ab32d Mon Sep 17 00:00:00 2001 From: John Hopper Date: Thu, 7 May 2026 01:15:26 -0700 Subject: [PATCH 18/55] fix(query/v2): validate scope aliases --- query/v2/query.go | 57 +++++++++++++++++++++++++++++++++++------- 
query/v2/query_test.go | 10 ++++++++ 2 files changed, 58 insertions(+), 9 deletions(-) diff --git a/query/v2/query.go b/query/v2/query.go index 201e0e66..968d11ec 100644 --- a/query/v2/query.go +++ b/query/v2/query.go @@ -3,6 +3,7 @@ package v2 import ( "errors" "fmt" + "strings" "github.com/specterops/dawgs/cypher/models/cypher" "github.com/specterops/dawgs/graph" @@ -46,6 +47,7 @@ var Identifiers = runtimeIdentifiers{ type Scope struct { identifiers runtimeIdentifiers + errors []error } func DefaultScope() Scope { @@ -55,19 +57,55 @@ func DefaultScope() Scope { } func NewScope(path, node, start, relationship, end string) Scope { + identifiers := runtimeIdentifiers{ + path: path, + node: node, + start: start, + relationship: relationship, + end: end, + } + return Scope{ - identifiers: runtimeIdentifiers{ - path: path, - node: node, - start: start, - relationship: relationship, - end: end, - }, + identifiers: identifiers, + errors: validateRuntimeIdentifiers(identifiers), + } +} + +func validateRuntimeIdentifiers(identifiers runtimeIdentifiers) []error { + aliases := []struct { + role string + value string + }{ + {role: "path", value: identifiers.path}, + {role: "node", value: identifiers.node}, + {role: "start", value: identifiers.start}, + {role: "relationship", value: identifiers.relationship}, + {role: "end", value: identifiers.end}, + } + + var ( + errs []error + seen = map[string]string{} + ) + + for _, alias := range aliases { + if strings.TrimSpace(alias.value) == "" { + errs = append(errs, fmt.Errorf("scope alias %s is empty", alias.role)) + continue + } + + if existingRole, exists := seen[alias.value]; exists { + errs = append(errs, fmt.Errorf("scope aliases %s and %s both use %q", existingRole, alias.role, alias.value)) + } else { + seen[alias.value] = alias.role + } } + + return errs } func (s Scope) New() QueryBuilder { - return newBuilder(s.identifiers) + return newBuilder(s.identifiers, s.errors...) 
} func (s Scope) Node() NodeContinuation { @@ -614,9 +652,10 @@ func New() QueryBuilder { return DefaultScope().New() } -func newBuilder(identifiers runtimeIdentifiers) QueryBuilder { +func newBuilder(identifiers runtimeIdentifiers, errs ...error) QueryBuilder { return &builder{ identifiers: identifiers, + errors: append([]error(nil), errs...), relationshipDirection: graph.DirectionOutbound, } } diff --git a/query/v2/query_test.go b/query/v2/query_test.go index 8a324b90..d5a53990 100644 --- a/query/v2/query_test.go +++ b/query/v2/query_test.go @@ -163,6 +163,16 @@ func TestScopedKindsOfCompatibilityHelper(t *testing.T) { require.Equal(t, "match ()-[edge]->(target) return type(edge), labels(target)", renderPrepared(t, preparedQuery)) } +func TestInvalidScopeAliasesReturnBuildErrors(t *testing.T) { + emptyAliasScope := v2.NewScope("", "node", "start", "relationship", "end") + _, err := emptyAliasScope.New().Return(emptyAliasScope.Node()).Build() + require.ErrorContains(t, err, "scope alias path is empty") + + duplicateAliasScope := v2.NewScope("path", "node", "node", "relationship", "end") + _, err = duplicateAliasScope.New().Return(duplicateAliasScope.Start()).Build() + require.ErrorContains(t, err, `scope aliases node and start both use "node"`) +} + func TestInvalidRelationshipDirectionReturnsError(t *testing.T) { _, err := v2.New().WithRelationshipDirection(graph.Direction(99)).Return(v2.Relationship()).Build() require.ErrorContains(t, err, "unsupported relationship direction: invalid") From ddf68537aaa19ecc30644b4b5228179afe90b52b Mon Sep 17 00:00:00 2001 From: John Hopper Date: Thu, 7 May 2026 01:18:16 -0700 Subject: [PATCH 19/55] fix(query/v2): validate raw projection inputs --- query/v2/query.go | 22 ++++--- query/v2/query_test.go | 32 ++++++++++ query/v2/util.go | 138 +++++++++++++++++++++++++++++++++++++---- 3 files changed, 173 insertions(+), 19 deletions(-) diff --git a/query/v2/query.go b/query/v2/query.go index 968d11ec..b96f072c 100644 --- 
a/query/v2/query.go +++ b/query/v2/query.go @@ -940,12 +940,18 @@ func (s *builder) buildProjectionOrder() (*cypher.Order, error) { for _, untypedSortItem := range s.sortItems { switch typedSortItem := untypedSortItem.(type) { case *cypher.Order: - for _, sortItem := range typedSortItem.Items { - orderByNode.Items = append(orderByNode.Items, sortItem) + if sortItems, err := sortItemsFromOrder(typedSortItem); err != nil { + return nil, err + } else { + orderByNode.Items = append(orderByNode.Items, sortItems...) } case *cypher.SortItem: - orderByNode.Items = append(orderByNode.Items, typedSortItem) + if sortItem, err := sortItemFromValue(typedSortItem); err != nil { + return nil, err + } else { + orderByNode.Items = append(orderByNode.Items, sortItem) + } default: if sortItem, err := sortItemFromValue(typedSortItem); err != nil { @@ -978,11 +984,11 @@ func (s *builder) buildProjection(singlePartQuery *cypher.SinglePartQuery) error for _, nextProjection := range s.projections { switch typedNextProjection := nextProjection.(type) { case *cypher.Return: - for _, returnItem := range typedNextProjection.Projection.Items { - if typedReturnItem, typeOK := returnItem.(*cypher.ProjectionItem); !typeOK { - return fmt.Errorf("invalid type for return: %T", returnItem) - } else { - projection.AddItem(typedReturnItem) + if projectionItems, err := projectionItemsFromReturn(typedNextProjection); err != nil { + return err + } else { + for _, projectionItem := range projectionItems { + projection.AddItem(projectionItem) } } diff --git a/query/v2/query_test.go b/query/v2/query_test.go index d5a53990..dcabf76e 100644 --- a/query/v2/query_test.go +++ b/query/v2/query_test.go @@ -220,6 +220,38 @@ func TestUnsupportedOrderByTypeReturnsError(t *testing.T) { require.ErrorContains(t, err, "unsupported expression type: int") } +func TestRawProjectionAndOrderInputsAreValidated(t *testing.T) { + _, err := v2.New().Return(&cypher.Return{}).Build() + require.ErrorContains(t, err, "return clause 
has nil projection") + + returnClause := cypher.NewReturn() + returnClause.NewProjection(false).Items = append(returnClause.Projection.Items, &cypher.ProjectionItem{}) + _, err = v2.New().Return(returnClause).Build() + require.ErrorContains(t, err, "projection item has nil expression") + + _, err = v2.New().Return(v2.Node()).OrderBy(&cypher.SortItem{}).Build() + require.ErrorContains(t, err, "sort item has nil expression") + + _, err = v2.New().Return(v2.Node()).OrderBy(&cypher.Order{ + Items: []*cypher.SortItem{{}}, + }).Build() + require.ErrorContains(t, err, "sort item has nil expression") +} + +func TestRawProjectionAndOrderInputsAreNormalized(t *testing.T) { + returnClause := cypher.NewReturn() + returnClause.NewProjection(false).Items = append(returnClause.Projection.Items, v2.Node().ID()) + + preparedQuery, err := v2.New().Return(returnClause).OrderBy(&cypher.Order{ + Items: []*cypher.SortItem{ + v2.Desc(v2.Node().Property("name")), + }, + }).Build() + require.NoError(t, err) + + require.Equal(t, "match (n) return id(n) order by n.name desc", renderPrepared(t, preparedQuery)) +} + func TestInvalidHelperInputsReturnBuildErrors(t *testing.T) { cases := map[string]struct { builder v2.QueryBuilder diff --git a/query/v2/util.go b/query/v2/util.go index d68afb98..06cc0ffc 100644 --- a/query/v2/util.go +++ b/query/v2/util.go @@ -147,6 +147,15 @@ func invalidExpression(err error) *cypher.FunctionInvocation { return cypher.WithErrors(cypher.NewSimpleFunctionInvocation("__invalid_expression__"), err) } +func isNilPointer(value any) bool { + if value == nil { + return true + } + + reflectValue := reflect.ValueOf(value) + return reflectValue.Kind() == reflect.Ptr && reflectValue.IsNil() +} + func expressionOrError(value any) cypher.Expression { if expression, err := projectionExpression(value); err != nil { return invalidExpression(err) @@ -177,6 +186,10 @@ func propertyLookupOrError(reference any, propertyName string) cypher.Expression } func projectionExpression(value 
any) (cypher.Expression, error) { + if isNilPointer(value) { + return nil, fmt.Errorf("expression is nil: %T", value) + } + switch typedValue := value.(type) { case kindContinuation: return kindProjectionExpression(typedValue.role, typedValue.identifier) @@ -256,10 +269,22 @@ func projectionExpression(value any) (cypher.Expression, error) { } } +func validateExpressionValue(expression cypher.Expression, context string) error { + if isNilPointer(expression) { + return fmt.Errorf("%s has nil expression", context) + } + + return collectModelErrors(expression) +} + func projectionItemFromValue(value any) (*cypher.ProjectionItem, error) { if projectionItem, typeOK := value.(*cypher.ProjectionItem); typeOK { - if projectionItem.Expression == nil { - return nil, fmt.Errorf("projection item has nil expression") + if projectionItem == nil { + return nil, fmt.Errorf("projection item is nil") + } + + if err := validateExpressionValue(projectionItem.Expression, "projection item"); err != nil { + return nil, err } if err := collectModelErrors(projectionItem); err != nil { @@ -278,8 +303,12 @@ func projectionItemFromValue(value any) (*cypher.ProjectionItem, error) { func sortItemFromValue(value any) (*cypher.SortItem, error) { if sortItem, typeOK := value.(*cypher.SortItem); typeOK { - if sortItem.Expression == nil { - return nil, fmt.Errorf("sort item has nil expression") + if sortItem == nil { + return nil, fmt.Errorf("sort item is nil") + } + + if err := validateExpressionValue(sortItem.Expression, "sort item"); err != nil { + return nil, err } if err := collectModelErrors(sortItem); err != nil { @@ -299,6 +328,46 @@ func sortItemFromValue(value any) (*cypher.SortItem, error) { } } +func projectionItemsFromReturn(returnClause *cypher.Return) ([]*cypher.ProjectionItem, error) { + if returnClause == nil { + return nil, fmt.Errorf("return clause is nil") + } + + if returnClause.Projection == nil { + return nil, fmt.Errorf("return clause has nil projection") + } + + 
projectionItems := make([]*cypher.ProjectionItem, 0, len(returnClause.Projection.Items)) + + for _, returnItem := range returnClause.Projection.Items { + if projectionItem, err := projectionItemFromValue(returnItem); err != nil { + return nil, err + } else { + projectionItems = append(projectionItems, projectionItem) + } + } + + return projectionItems, nil +} + +func sortItemsFromOrder(order *cypher.Order) ([]*cypher.SortItem, error) { + if order == nil { + return nil, fmt.Errorf("order is nil") + } + + sortItems := make([]*cypher.SortItem, 0, len(order.Items)) + + for _, sortItem := range order.Items { + if normalizedSortItem, err := sortItemFromValue(sortItem); err != nil { + return nil, err + } else { + sortItems = append(sortItems, normalizedSortItem) + } + } + + return sortItems, nil +} + type identifierSet struct { identifiers map[string]struct{} } @@ -366,13 +435,33 @@ func (s *identifierSet) CollectFromValue(value any) error { return nil case *cypher.Return: - return s.CollectFromExpression(typedValue) + if projectionItems, err := projectionItemsFromReturn(typedValue); err != nil { + return err + } else { + for _, projectionItem := range projectionItems { + if err := s.CollectFromExpression(projectionItem); err != nil { + return err + } + } + } case *cypher.Order: - return s.CollectFromExpression(typedValue) + if sortItems, err := sortItemsFromOrder(typedValue); err != nil { + return err + } else { + for _, sortItem := range sortItems { + if err := s.CollectFromExpression(sortItem); err != nil { + return err + } + } + } case *cypher.SortItem: - return s.CollectFromExpression(typedValue) + if sortItem, err := sortItemFromValue(typedValue); err != nil { + return err + } else { + return s.CollectFromExpression(sortItem) + } case *cypher.Set: return s.CollectFromExpression(typedValue) @@ -557,6 +646,11 @@ func collectModelErrorsFromKnownValues(values ...any) error { case nil: continue + case []any: + if err := collectModelErrorsFromKnownValues(typedValue...); 
err != nil { + modelErrors = append(modelErrors, err) + } + case []cypher.SyntaxNode: if err := collectModelErrorsFromKnownValues(anySlice(typedValue)...); err != nil { modelErrors = append(modelErrors, err) @@ -587,6 +681,32 @@ func collectModelErrorsFromKnownValues(values ...any) error { modelErrors = append(modelErrors, err) } + case *cypher.Order: + if _, err := sortItemsFromOrder(typedValue); err != nil { + modelErrors = append(modelErrors, err) + } else if err := collectModelErrors(typedValue); err != nil { + modelErrors = append(modelErrors, err) + } + + case *cypher.ProjectionItem: + if _, err := projectionItemFromValue(typedValue); err != nil { + modelErrors = append(modelErrors, err) + } else if err := collectModelErrors(typedValue); err != nil { + modelErrors = append(modelErrors, err) + } + + case *cypher.Return: + if _, err := projectionItemsFromReturn(typedValue); err != nil { + modelErrors = append(modelErrors, err) + } + + case *cypher.SortItem: + if _, err := sortItemFromValue(typedValue); err != nil { + modelErrors = append(modelErrors, err) + } else if err := collectModelErrors(typedValue); err != nil { + modelErrors = append(modelErrors, err) + } + case *cypher.ArithmeticExpression, *cypher.Comparison, *cypher.Conjunction, @@ -601,18 +721,14 @@ func collectModelErrorsFromKnownValues(values ...any) error { *cypher.ListLiteral, *cypher.Negation, *cypher.NodePattern, - *cypher.Order, *cypher.Parenthetical, *cypher.PatternPredicate, - *cypher.ProjectionItem, *cypher.PropertyLookup, *cypher.RelationshipPattern, *cypher.Remove, *cypher.RemoveItem, - *cypher.Return, *cypher.Set, *cypher.SetItem, - *cypher.SortItem, *cypher.UnaryAddOrSubtractExpression, *cypher.UpdatingClause, *cypher.Variable: From e333c6407867669dec884ae150280035e1359764 Mon Sep 17 00:00:00 2001 From: John Hopper Date: Thu, 7 May 2026 01:18:36 -0700 Subject: [PATCH 20/55] chore(query/v2): remove unused extractor state --- query/v2/util.go | 5 ----- 1 file changed, 5 deletions(-) diff 
--git a/query/v2/util.go b/query/v2/util.go index 06cc0ffc..9d5572ec 100644 --- a/query/v2/util.go +++ b/query/v2/util.go @@ -579,11 +579,6 @@ type identifierExtractor struct { walk.Visitor[cypher.SyntaxNode] seen *identifierSet - - inDelete bool - inUpdate bool - inCreate bool - inWhere bool } func newIdentifierExtractor() *identifierExtractor { From 0eb36b3d5b9709db6bfac2dd314be15e8d052119 Mon Sep 17 00:00:00 2001 From: John Hopper Date: Thu, 7 May 2026 08:29:21 -0700 Subject: [PATCH 21/55] fix(query/v2): sort property update keys --- query/v2/compat.go | 4 ++-- query/v2/query.go | 4 ++-- query/v2/query_test.go | 30 ++++++++++++++++++++++++++++++ query/v2/util.go | 12 ++++++++++++ 4 files changed, 46 insertions(+), 4 deletions(-) diff --git a/query/v2/compat.go b/query/v2/compat.go index 1e941816..fb6be7b3 100644 --- a/query/v2/compat.go +++ b/query/v2/compat.go @@ -112,11 +112,11 @@ func SetProperty(reference any, value any) *cypher.SetItem { func SetProperties(reference any, properties map[string]any) *cypher.Set { set := &cypher.Set{} - for key, value := range properties { + for _, key := range sortedPropertyKeys(properties) { set.Items = append(set.Items, cypher.NewSetItem( propertyLookupOrError(reference, key), cypher.OperatorAssignment, - valueExpression(value), + valueExpression(properties[key]), )) } diff --git a/query/v2/query.go b/query/v2/query.go index b96f072c..a833e658 100644 --- a/query/v2/query.go +++ b/query/v2/query.go @@ -448,8 +448,8 @@ func (s *entity[T]) Count() cypher.Expression { func (s *entity[T]) SetProperties(properties map[string]any) cypher.Expression { set := &cypher.Set{} - for key, value := range properties { - set.Items = append(set.Items, s.Property(key).Set(value)) + for _, key := range sortedPropertyKeys(properties) { + set.Items = append(set.Items, s.Property(key).Set(properties[key])) } return set diff --git a/query/v2/query_test.go b/query/v2/query_test.go index dcabf76e..78f4746c 100644 --- a/query/v2/query_test.go +++ 
b/query/v2/query_test.go @@ -323,3 +323,33 @@ func TestUpdateCompatibilityHelpers(t *testing.T) { "p1": "updated", }, preparedQuery.Parameters) } + +func TestSetPropertiesSortsKeys(t *testing.T) { + properties := map[string]any{ + "zeta": 3, + "alpha": 1, + "mid": 2, + } + + preparedQuery, err := v2.New().Update( + v2.SetProperties(v2.Node(), properties), + ).Build() + require.NoError(t, err) + require.Equal(t, "match (n) set n.alpha = $p0, n.mid = $p1, n.zeta = $p2", renderPrepared(t, preparedQuery)) + require.Equal(t, map[string]any{ + "p0": 1, + "p1": 2, + "p2": 3, + }, preparedQuery.Parameters) + + preparedQuery, err = v2.New().Update( + v2.Node().SetProperties(properties), + ).Build() + require.NoError(t, err) + require.Equal(t, "match (n) set n.alpha = $p0, n.mid = $p1, n.zeta = $p2", renderPrepared(t, preparedQuery)) + require.Equal(t, map[string]any{ + "p0": 1, + "p1": 2, + "p2": 3, + }, preparedQuery.Parameters) +} diff --git a/query/v2/util.go b/query/v2/util.go index 9d5572ec..d7b045eb 100644 --- a/query/v2/util.go +++ b/query/v2/util.go @@ -4,6 +4,7 @@ import ( "errors" "fmt" "reflect" + "sort" "strconv" "github.com/specterops/dawgs/cypher/models/cypher" @@ -185,6 +186,17 @@ func propertyLookupOrError(reference any, propertyName string) cypher.Expression } } +func sortedPropertyKeys(properties map[string]any) []string { + keys := make([]string, 0, len(properties)) + + for key := range properties { + keys = append(keys, key) + } + + sort.Strings(keys) + return keys +} + func projectionExpression(value any) (cypher.Expression, error) { if isNilPointer(value) { return nil, fmt.Errorf("expression is nil: %T", value) From 8676196d5df980ad9214ad1202943abab9739bd0 Mon Sep 17 00:00:00 2001 From: John Hopper Date: Thu, 7 May 2026 08:32:06 -0700 Subject: [PATCH 22/55] fix(query/v2): validate raw mutation inputs --- query/v2/query.go | 51 ++++++++-- query/v2/query_test.go | 36 ++++++- query/v2/util.go | 214 ++++++++++++++++++++++++++++++++++++++--- 3 files 
changed, 280 insertions(+), 21 deletions(-) diff --git a/query/v2/query.go b/query/v2/query.go index a833e658..af7e5822 100644 --- a/query/v2/query.go +++ b/query/v2/query.go @@ -773,20 +773,36 @@ func (s *builder) Update(updates ...any) QueryBuilder { for _, nextUpdate := range updates { switch typedNextUpdate := nextUpdate.(type) { case *cypher.Set: - s.setItems = append(s.setItems, typedNextUpdate.Items...) - s.appendSetItems(typedNextUpdate.Items...) + if setItems, err := setItemsFromSet(typedNextUpdate); err != nil { + s.trackError(err) + } else { + s.setItems = append(s.setItems, setItems...) + s.appendSetItems(setItems...) + } case *cypher.SetItem: - s.setItems = append(s.setItems, typedNextUpdate) - s.appendSetItems(typedNextUpdate) + if setItem, err := setItemFromValue(typedNextUpdate); err != nil { + s.trackError(err) + } else { + s.setItems = append(s.setItems, setItem) + s.appendSetItems(setItem) + } case *cypher.Remove: - s.removeItems = append(s.removeItems, typedNextUpdate.Items...) - s.appendRemoveItems(typedNextUpdate.Items...) + if removeItems, err := removeItemsFromRemove(typedNextUpdate); err != nil { + s.trackError(err) + } else { + s.removeItems = append(s.removeItems, removeItems...) + s.appendRemoveItems(removeItems...) 
+ } case *cypher.RemoveItem: - s.removeItems = append(s.removeItems, typedNextUpdate) - s.appendRemoveItems(typedNextUpdate) + if removeItem, err := removeItemFromValue(typedNextUpdate); err != nil { + s.trackError(err) + } else { + s.removeItems = append(s.removeItems, removeItem) + s.appendRemoveItems(removeItem) + } default: s.trackError(fmt.Errorf("unknown update type: %T", nextUpdate)) @@ -804,6 +820,10 @@ func (s *builder) Delete(deleteItems ...any) QueryBuilder { switch typedNextUpdate := nextDelete.(type) { case QualifiedExpression: qualifier := typedNextUpdate.qualifier() + if err := validateExpressionValue(qualifier, "delete expression"); err != nil { + s.trackError(err) + continue + } if isDetachDeleteQualifier(qualifier, s.identifiers) { s.detachDelete = true @@ -814,6 +834,11 @@ func (s *builder) Delete(deleteItems ...any) QueryBuilder { pendingDeleteItems = append(pendingDeleteItems, qualifier) case *cypher.Variable: + if err := validateExpressionValue(typedNextUpdate, "delete expression"); err != nil { + s.trackError(err) + continue + } + switch typedNextUpdate.Symbol { case s.identifiers.node, s.identifiers.start, s.identifiers.end: s.detachDelete = true @@ -859,6 +884,10 @@ func buildCreates(singlePartQuery *cypher.SinglePartQuery, identifiers runtimeId case QualifiedExpression: switch typedExpression := typedNextCreate.qualifier().(type) { case *cypher.Variable: + if typedExpression == nil { + return fmt.Errorf("invalid variable reference for create: ") + } + switch typedExpression.Symbol { case identifiers.node, identifiers.start, identifiers.end: pattern.AddPatternElements(&cypher.NodePattern{ @@ -874,10 +903,14 @@ func buildCreates(singlePartQuery *cypher.SinglePartQuery, identifiers runtimeId } case *cypher.NodePattern: + if err := validateNodePattern(typedNextCreate); err != nil { + return err + } + pattern.AddPatternElements(typedNextCreate) case *cypher.RelationshipPattern: - if err := 
validateRelationshipDirection(typedNextCreate.Direction); err != nil { + if err := validateRelationshipPattern(typedNextCreate); err != nil { return err } diff --git a/query/v2/query_test.go b/query/v2/query_test.go index 78f4746c..3a765f3a 100644 --- a/query/v2/query_test.go +++ b/query/v2/query_test.go @@ -40,7 +40,7 @@ func TestQuery(t *testing.T) { cypherQueryStr, err := format.RegularQuery(preparedQuery.Query, false) require.NoError(t, err) - require.Equal(t, "match (s)-[r]->() where not r:test and not (r:A or r:B) and r.rel_prop <= $p0 and r.other_prop = $p1 and s:test set s.this_prop = $p2 remove e:A:B detach delete s return r, s.node_prop skip 10 limit 10", cypherQueryStr) + require.Equal(t, "match (s)-[r]->(e) where not r:test and not (r:A or r:B) and r.rel_prop <= $p0 and r.other_prop = $p1 and s:test set s.this_prop = $p2 remove e:A:B detach delete s return r, s.node_prop skip 10 limit 10", cypherQueryStr) require.Equal(t, map[string]any{ "p0": 1234, "p1": 5678, @@ -252,6 +252,40 @@ func TestRawProjectionAndOrderInputsAreNormalized(t *testing.T) { require.Equal(t, "match (n) return id(n) order by n.name desc", renderPrepared(t, preparedQuery)) } +func TestRawUpdatingInputsAreValidated(t *testing.T) { + var setClause *cypher.Set + _, err := v2.New().Update(setClause).Build() + require.ErrorContains(t, err, "set clause is nil") + + _, err = v2.New().Update(&cypher.Set{Items: []*cypher.SetItem{nil}}).Build() + require.ErrorContains(t, err, "set item is nil") + + _, err = v2.New().Update(&cypher.SetItem{}).Build() + require.ErrorContains(t, err, "set item left has nil expression") + + var removeClause *cypher.Remove + _, err = v2.New().Update(removeClause).Build() + require.ErrorContains(t, err, "remove clause is nil") + + _, err = v2.New().Update(&cypher.Remove{Items: []*cypher.RemoveItem{nil}}).Build() + require.ErrorContains(t, err, "remove item is nil") + + _, err = v2.New().Update(&cypher.RemoveItem{}).Build() + require.ErrorContains(t, err, "remove 
item has no target") + + var deleteVariable *cypher.Variable + _, err = v2.New().Delete(deleteVariable).Build() + require.ErrorContains(t, err, "delete expression has nil expression") + + var nodePattern *cypher.NodePattern + _, err = v2.New().Create(nodePattern).Build() + require.ErrorContains(t, err, "node pattern is nil") + + var relationshipPattern *cypher.RelationshipPattern + _, err = v2.New().Create(relationshipPattern).Build() + require.ErrorContains(t, err, "relationship pattern is nil") +} + func TestInvalidHelperInputsReturnBuildErrors(t *testing.T) { cases := map[string]struct { builder v2.QueryBuilder diff --git a/query/v2/util.go b/query/v2/util.go index d7b045eb..dac77202 100644 --- a/query/v2/util.go +++ b/query/v2/util.go @@ -119,7 +119,7 @@ func prepareCreateRelationshipMatch(match *cypher.Match, seen *identifierSet, id func isDetachDeleteQualifier(qualifier cypher.Expression, identifiers runtimeIdentifiers) bool { variable, typeOK := qualifier.(*cypher.Variable) - if !typeOK { + if !typeOK || variable == nil { return false } @@ -289,6 +289,128 @@ func validateExpressionValue(expression cypher.Expression, context string) error return collectModelErrors(expression) } +func validateAssignmentOperator(operator cypher.AssignmentOperator) error { + switch operator { + case cypher.OperatorAssignment, cypher.OperatorAdditionAssignment, cypher.OperatorLabelAssignment: + return nil + default: + return fmt.Errorf("unsupported set item operator: %s", operator) + } +} + +func setItemFromValue(setItem *cypher.SetItem) (*cypher.SetItem, error) { + if setItem == nil { + return nil, fmt.Errorf("set item is nil") + } + + if err := validateExpressionValue(setItem.Left, "set item left"); err != nil { + return nil, err + } + + if err := validateAssignmentOperator(setItem.Operator); err != nil { + return nil, err + } + + if err := validateExpressionValue(setItem.Right, "set item right"); err != nil { + return nil, err + } + + if err := collectModelErrors(setItem); err 
!= nil { + return nil, err + } + + return setItem, nil +} + +func setItemsFromSet(setClause *cypher.Set) ([]*cypher.SetItem, error) { + if setClause == nil { + return nil, fmt.Errorf("set clause is nil") + } + + setItems := make([]*cypher.SetItem, 0, len(setClause.Items)) + + for _, setItem := range setClause.Items { + if normalizedSetItem, err := setItemFromValue(setItem); err != nil { + return nil, err + } else { + setItems = append(setItems, normalizedSetItem) + } + } + + return setItems, nil +} + +func removeItemFromValue(removeItem *cypher.RemoveItem) (*cypher.RemoveItem, error) { + if removeItem == nil { + return nil, fmt.Errorf("remove item is nil") + } + + hasKindMatcher := removeItem.KindMatcher != nil + hasProperty := !isNilPointer(removeItem.Property) + + switch { + case hasKindMatcher && hasProperty: + return nil, fmt.Errorf("remove item has multiple targets") + + case hasKindMatcher: + if err := collectModelErrors(removeItem.KindMatcher); err != nil { + return nil, err + } + + case hasProperty: + if err := validateExpressionValue(removeItem.Property, "remove item property"); err != nil { + return nil, err + } + + default: + return nil, fmt.Errorf("remove item has no target") + } + + if err := collectModelErrors(removeItem); err != nil { + return nil, err + } + + return removeItem, nil +} + +func removeItemsFromRemove(removeClause *cypher.Remove) ([]*cypher.RemoveItem, error) { + if removeClause == nil { + return nil, fmt.Errorf("remove clause is nil") + } + + removeItems := make([]*cypher.RemoveItem, 0, len(removeClause.Items)) + + for _, removeItem := range removeClause.Items { + if normalizedRemoveItem, err := removeItemFromValue(removeItem); err != nil { + return nil, err + } else { + removeItems = append(removeItems, normalizedRemoveItem) + } + } + + return removeItems, nil +} + +func validateNodePattern(nodePattern *cypher.NodePattern) error { + if nodePattern == nil { + return fmt.Errorf("node pattern is nil") + } + + return 
collectModelErrors(nodePattern) +} + +func validateRelationshipPattern(relationshipPattern *cypher.RelationshipPattern) error { + if relationshipPattern == nil { + return fmt.Errorf("relationship pattern is nil") + } + + if err := validateRelationshipDirection(relationshipPattern.Direction); err != nil { + return err + } + + return collectModelErrors(relationshipPattern) +} + func projectionItemFromValue(value any) (*cypher.ProjectionItem, error) { if projectionItem, typeOK := value.(*cypher.ProjectionItem); typeOK { if projectionItem == nil { @@ -476,21 +598,55 @@ func (s *identifierSet) CollectFromValue(value any) error { } case *cypher.Set: - return s.CollectFromExpression(typedValue) + if setItems, err := setItemsFromSet(typedValue); err != nil { + return err + } else { + for _, setItem := range setItems { + if err := s.CollectFromExpression(setItem); err != nil { + return err + } + } + } case *cypher.SetItem: - return s.CollectFromExpression(typedValue) + if setItem, err := setItemFromValue(typedValue); err != nil { + return err + } else { + return s.CollectFromExpression(setItem) + } case *cypher.Remove: - return s.CollectFromExpression(typedValue) + if removeItems, err := removeItemsFromRemove(typedValue); err != nil { + return err + } else { + for _, removeItem := range removeItems { + if err := s.CollectFromValue(removeItem); err != nil { + return err + } + } + } case *cypher.RemoveItem: - return s.CollectFromExpression(typedValue) + if removeItem, err := removeItemFromValue(typedValue); err != nil { + return err + } else if removeItem.KindMatcher != nil { + return s.CollectFromExpression(removeItem.KindMatcher) + } else { + return s.CollectFromExpression(removeItem) + } case *cypher.NodePattern: + if err := validateNodePattern(typedValue); err != nil { + return err + } + return s.CollectFromExpression(typedValue) case *cypher.RelationshipPattern: + if err := validateRelationshipPattern(typedValue); err != nil { + return err + } + return 
s.CollectFromExpression(typedValue) case *cypher.Variable: @@ -569,6 +725,10 @@ func collectCreateScope(identifiers runtimeIdentifiers, values ...any) (*createS for _, value := range values { switch typedValue := value.(type) { case *cypher.RelationshipPattern: + if err := validateRelationshipPattern(typedValue); err != nil { + return nil, err + } + scope.createsRelationship = true scope.identifiers.Add(identifiers.start) scope.identifiers.Add(identifiers.end) @@ -688,6 +848,11 @@ func collectModelErrorsFromKnownValues(values ...any) error { modelErrors = append(modelErrors, err) } + case *cypher.NodePattern: + if err := validateNodePattern(typedValue); err != nil { + modelErrors = append(modelErrors, err) + } + case *cypher.Order: if _, err := sortItemsFromOrder(typedValue); err != nil { modelErrors = append(modelErrors, err) @@ -707,6 +872,39 @@ func collectModelErrorsFromKnownValues(values ...any) error { modelErrors = append(modelErrors, err) } + case *cypher.RelationshipPattern: + if err := validateRelationshipPattern(typedValue); err != nil { + modelErrors = append(modelErrors, err) + } + + case *cypher.Remove: + if _, err := removeItemsFromRemove(typedValue); err != nil { + modelErrors = append(modelErrors, err) + } else if err := collectModelErrors(typedValue); err != nil { + modelErrors = append(modelErrors, err) + } + + case *cypher.RemoveItem: + if _, err := removeItemFromValue(typedValue); err != nil { + modelErrors = append(modelErrors, err) + } else if err := collectModelErrors(typedValue); err != nil { + modelErrors = append(modelErrors, err) + } + + case *cypher.Set: + if _, err := setItemsFromSet(typedValue); err != nil { + modelErrors = append(modelErrors, err) + } else if err := collectModelErrors(typedValue); err != nil { + modelErrors = append(modelErrors, err) + } + + case *cypher.SetItem: + if _, err := setItemFromValue(typedValue); err != nil { + modelErrors = append(modelErrors, err) + } else if err := collectModelErrors(typedValue); err != 
nil { + modelErrors = append(modelErrors, err) + } + case *cypher.SortItem: if _, err := sortItemFromValue(typedValue); err != nil { modelErrors = append(modelErrors, err) @@ -727,15 +925,9 @@ func collectModelErrorsFromKnownValues(values ...any) error { *cypher.KindMatcher, *cypher.ListLiteral, *cypher.Negation, - *cypher.NodePattern, *cypher.Parenthetical, *cypher.PatternPredicate, *cypher.PropertyLookup, - *cypher.RelationshipPattern, - *cypher.Remove, - *cypher.RemoveItem, - *cypher.Set, - *cypher.SetItem, *cypher.UnaryAddOrSubtractExpression, *cypher.UpdatingClause, *cypher.Variable: From 00b855e6a154be4f264a234472765fdb106af2a1 Mon Sep 17 00:00:00 2001 From: John Hopper Date: Thu, 7 May 2026 08:32:55 -0700 Subject: [PATCH 23/55] fix(query/v2): validate alias symbols --- query/v2/query.go | 5 ++--- query/v2/query_test.go | 15 +++++++++++++++ query/v2/util.go | 33 +++++++++++++++++++++++++++++++++ 3 files changed, 50 insertions(+), 3 deletions(-) diff --git a/query/v2/query.go b/query/v2/query.go index af7e5822..f3451cee 100644 --- a/query/v2/query.go +++ b/query/v2/query.go @@ -3,7 +3,6 @@ package v2 import ( "errors" "fmt" - "strings" "github.com/specterops/dawgs/cypher/models/cypher" "github.com/specterops/dawgs/graph" @@ -89,8 +88,8 @@ func validateRuntimeIdentifiers(identifiers runtimeIdentifiers) []error { ) for _, alias := range aliases { - if strings.TrimSpace(alias.value) == "" { - errs = append(errs, fmt.Errorf("scope alias %s is empty", alias.role)) + if err := validateCypherSymbol(alias.value, "scope alias "+alias.role); err != nil { + errs = append(errs, err) continue } diff --git a/query/v2/query_test.go b/query/v2/query_test.go index 3a765f3a..da17ce13 100644 --- a/query/v2/query_test.go +++ b/query/v2/query_test.go @@ -171,6 +171,10 @@ func TestInvalidScopeAliasesReturnBuildErrors(t *testing.T) { duplicateAliasScope := v2.NewScope("path", "node", "node", "relationship", "end") _, err = 
duplicateAliasScope.New().Return(duplicateAliasScope.Start()).Build() require.ErrorContains(t, err, `scope aliases node and start both use "node"`) + + invalidAliasScope := v2.NewScope("path", "bad name", "start", "relationship", "end") + _, err = invalidAliasScope.New().Return(invalidAliasScope.Node()).Build() + require.ErrorContains(t, err, `scope alias node has invalid symbol "bad name"`) } func TestInvalidRelationshipDirectionReturnsError(t *testing.T) { @@ -215,6 +219,17 @@ func TestProjectionAliasDoesNotCreateMatchInference(t *testing.T) { require.Empty(t, preparedQuery.Parameters) } +func TestInvalidProjectionAliasReturnsBuildError(t *testing.T) { + _, err := v2.New().Return(v2.As(v2.Literal(1), "bad alias")).Build() + require.ErrorContains(t, err, `projection alias has invalid symbol "bad alias"`) + + _, err = v2.New().Return(&cypher.ProjectionItem{ + Expression: v2.Literal(1), + Alias: cypher.NewVariableWithSymbol("1bad"), + }).Build() + require.ErrorContains(t, err, `projection alias has invalid symbol "1bad"`) +} + func TestUnsupportedOrderByTypeReturnsError(t *testing.T) { _, err := v2.New().Return(v2.Node()).OrderBy(123).Build() require.ErrorContains(t, err, "unsupported expression type: int") diff --git a/query/v2/util.go b/query/v2/util.go index dac77202..e63c6e09 100644 --- a/query/v2/util.go +++ b/query/v2/util.go @@ -6,6 +6,7 @@ import ( "reflect" "sort" "strconv" + "strings" "github.com/specterops/dawgs/cypher/models/cypher" "github.com/specterops/dawgs/cypher/models/walk" @@ -197,6 +198,32 @@ func sortedPropertyKeys(properties map[string]any) []string { return keys } +func isCypherSymbolStart(char byte) bool { + return char == '_' || (char >= 'A' && char <= 'Z') || (char >= 'a' && char <= 'z') +} + +func isCypherSymbolPart(char byte) bool { + return isCypherSymbolStart(char) || (char >= '0' && char <= '9') +} + +func validateCypherSymbol(symbol, context string) error { + if strings.TrimSpace(symbol) == "" { + return fmt.Errorf("%s is empty", 
context) + } + + if !isCypherSymbolStart(symbol[0]) { + return fmt.Errorf("%s has invalid symbol %q", context, symbol) + } + + for idx := 1; idx < len(symbol); idx++ { + if !isCypherSymbolPart(symbol[idx]) { + return fmt.Errorf("%s has invalid symbol %q", context, symbol) + } + } + + return nil +} + func projectionExpression(value any) (cypher.Expression, error) { if isNilPointer(value) { return nil, fmt.Errorf("expression is nil: %T", value) @@ -421,6 +448,12 @@ func projectionItemFromValue(value any) (*cypher.ProjectionItem, error) { return nil, err } + if projectionItem.Alias != nil { + if err := validateCypherSymbol(projectionItem.Alias.Symbol, "projection alias"); err != nil { + return nil, err + } + } + if err := collectModelErrors(projectionItem); err != nil { return nil, err } From fa5a4f5237545860aa6bb0ca04a5dde77f365a56 Mon Sep 17 00:00:00 2001 From: John Hopper Date: Thu, 7 May 2026 20:19:17 -0700 Subject: [PATCH 24/55] fix (query/v2): named parameter fixups; preserve projection metadata; expose shortest path helper functions --- query/v2/query.go | 51 ++++++++++++++++++++++--- query/v2/query_test.go | 81 ++++++++++++++++++++++++++++++++++++++++ query/v2/util.go | 85 ++++++++++++++++++++++++++++++++++++++++-- 3 files changed, 208 insertions(+), 9 deletions(-) diff --git a/query/v2/query.go b/query/v2/query.go index f3451cee..a388140c 100644 --- a/query/v2/query.go +++ b/query/v2/query.go @@ -998,6 +998,49 @@ func (s *builder) buildProjectionOrder() (*cypher.Order, error) { return orderByNode, nil } +func appendProjectionOrder(projection *cypher.Projection, sortItems ...*cypher.SortItem) { + if len(sortItems) == 0 { + return + } + + if projection.Order == nil { + projection.Order = &cypher.Order{} + } + + projection.Order.Items = append(projection.Order.Items, sortItems...) 
+} + +func applyReturnProjection(projection *cypher.Projection, returnClause *cypher.Return) error { + if projectionItems, err := projectionItemsFromReturn(returnClause); err != nil { + return err + } else { + projection.Distinct = projection.Distinct || returnClause.Projection.Distinct + projection.All = projection.All || returnClause.Projection.All + + for _, projectionItem := range projectionItems { + projection.AddItem(projectionItem) + } + } + + if returnClause.Projection.Order != nil { + if sortItems, err := sortItemsFromOrder(returnClause.Projection.Order); err != nil { + return err + } else { + appendProjectionOrder(projection, sortItems...) + } + } + + if returnClause.Projection.Skip != nil { + projection.Skip = returnClause.Projection.Skip + } + + if returnClause.Projection.Limit != nil { + projection.Limit = returnClause.Projection.Limit + } + + return nil +} + func (s *builder) buildProjection(singlePartQuery *cypher.SinglePartQuery) error { var ( hasProjectedItems = len(s.projections) > 0 @@ -1016,12 +1059,8 @@ func (s *builder) buildProjection(singlePartQuery *cypher.SinglePartQuery) error for _, nextProjection := range s.projections { switch typedNextProjection := nextProjection.(type) { case *cypher.Return: - if projectionItems, err := projectionItemsFromReturn(typedNextProjection); err != nil { + if err := applyReturnProjection(projection, typedNextProjection); err != nil { return err - } else { - for _, projectionItem := range projectionItems { - projection.AddItem(projectionItem) - } } default: @@ -1044,7 +1083,7 @@ func (s *builder) buildProjection(singlePartQuery *cypher.SinglePartQuery) error if projectionOrder, err := s.buildProjectionOrder(); err != nil { return err } else if projectionOrder != nil { - projection.Order = projectionOrder + appendProjectionOrder(projection, projectionOrder.Items...) 
} } diff --git a/query/v2/query_test.go b/query/v2/query_test.go index da17ce13..2c906c46 100644 --- a/query/v2/query_test.go +++ b/query/v2/query_test.go @@ -185,6 +185,38 @@ func TestInvalidRelationshipDirectionReturnsError(t *testing.T) { require.ErrorContains(t, err, "unsupported relationship direction: both") } +func TestShortestPathControls(t *testing.T) { + preparedQuery, err := v2.New().WithShortestPaths().Where( + v2.Start().ID().Equals(1), + v2.End().ID().Equals(2), + ).Return( + v2.Path(), + ).Build() + require.NoError(t, err) + require.Equal(t, "match p = shortestPath((s)-[*]->(e)) where id(s) = $p0 and id(e) = $p1 return p", renderPrepared(t, preparedQuery)) + require.Equal(t, map[string]any{ + "p0": 1, + "p1": 2, + }, preparedQuery.Parameters) + + preparedQuery, err = v2.New().WithAllShortestPaths().Where( + v2.Start().ID().Equals(1), + v2.End().ID().Equals(2), + ).Return( + v2.Path(), + ).Build() + require.NoError(t, err) + require.Equal(t, "match p = allShortestPaths((s)-[*]->(e)) where id(s) = $p0 and id(e) = $p1 return p", renderPrepared(t, preparedQuery)) + + _, err = v2.New().WithShortestPaths().WithAllShortestPaths().Where( + v2.Start().ID().Equals(1), + v2.End().ID().Equals(2), + ).Return( + v2.Path(), + ).Build() + require.ErrorContains(t, err, "query is requesting both all shortest paths and shortest paths") +} + func TestInvalidExplicitRelationshipPatternDirectionReturnsError(t *testing.T) { _, err := v2.New().Create( v2.RelationshipPattern(graph.StringKind("Edge"), nil, graph.DirectionBoth), @@ -267,6 +299,24 @@ func TestRawProjectionAndOrderInputsAreNormalized(t *testing.T) { require.Equal(t, "match (n) return id(n) order by n.name desc", renderPrepared(t, preparedQuery)) } +func TestRawReturnInputPreservesProjectionMetadata(t *testing.T) { + returnClause := cypher.NewReturn() + projection := returnClause.NewProjection(true) + projection.Items = append(projection.Items, v2.Node().ID()) + projection.Order = &cypher.Order{ + Items: 
[]*cypher.SortItem{ + v2.Desc(v2.Node().Property("name")), + }, + } + projection.Skip = cypher.NewSkip(5) + projection.Limit = cypher.NewLimit(10) + + preparedQuery, err := v2.New().Return(returnClause).Build() + require.NoError(t, err) + + require.Equal(t, "match (n) return distinct id(n) order by n.name desc skip 5 limit 10", renderPrepared(t, preparedQuery)) +} + func TestRawUpdatingInputsAreValidated(t *testing.T) { var setClause *cypher.Set _, err := v2.New().Update(setClause).Build() @@ -336,6 +386,37 @@ func TestInvalidHelperInputsReturnBuildErrors(t *testing.T) { } } +func TestNamedParameterMaterialization(t *testing.T) { + preparedQuery, err := v2.New().Where( + v2.Node().Property("first").Equals("auto"), + v2.Node().Property("second").Equals(v2.NamedParameter("p0", "named")), + ).Return( + v2.Node(), + ).Build() + require.NoError(t, err) + + require.Equal(t, "match (n) where n.first = $p1 and n.second = $p0 return n", renderPrepared(t, preparedQuery)) + require.Equal(t, map[string]any{ + "p0": "named", + "p1": "auto", + }, preparedQuery.Parameters) + + _, err = v2.New().Where( + v2.Node().Property("name").Equals(v2.NamedParameter("bad name", "value")), + ).Return( + v2.Node(), + ).Build() + require.ErrorContains(t, err, `parameter has invalid symbol "bad name"`) + + _, err = v2.New().Where( + v2.Node().Property("first").Equals(v2.NamedParameter("same", "first")), + v2.Node().Property("second").Equals(v2.NamedParameter("same", "second")), + ).Return( + v2.Node(), + ).Build() + require.ErrorContains(t, err, "parameter same is bound to multiple values") +} + func TestCompatibilityHelpers(t *testing.T) { preparedQuery, err := v2.New().Where( v2.And( diff --git a/query/v2/util.go b/query/v2/util.go index e63c6e09..c15c6400 100644 --- a/query/v2/util.go +++ b/query/v2/util.go @@ -504,6 +504,10 @@ func projectionItemsFromReturn(returnClause *cypher.Return) ([]*cypher.Projectio return nil, fmt.Errorf("return clause has nil projection") } + if err := 
validateProjectionMetadata(returnClause.Projection); err != nil { + return nil, err + } + projectionItems := make([]*cypher.ProjectionItem, 0, len(returnClause.Projection.Items)) for _, returnItem := range returnClause.Projection.Items { @@ -517,6 +521,28 @@ func projectionItemsFromReturn(returnClause *cypher.Return) ([]*cypher.Projectio return projectionItems, nil } +func validateProjectionMetadata(projection *cypher.Projection) error { + if projection.Order != nil { + if _, err := sortItemsFromOrder(projection.Order); err != nil { + return err + } + } + + if projection.Skip != nil { + if err := validateExpressionValue(projection.Skip.Value, "projection skip"); err != nil { + return err + } + } + + if projection.Limit != nil { + if err := validateExpressionValue(projection.Limit.Value, "projection limit"); err != nil { + return err + } + } + + return nil +} + func sortItemsFromOrder(order *cypher.Order) ([]*cypher.SortItem, error) { if order == nil { return nil, fmt.Errorf("order is nil") @@ -990,10 +1016,16 @@ type parameterMaterializer struct { nextIndex int } -func newParameterMaterializer() *parameterMaterializer { +func newParameterMaterializer(parameters map[string]any) *parameterMaterializer { + materializedParameters := map[string]any{} + + for symbol, value := range parameters { + materializedParameters[symbol] = value + } + return &parameterMaterializer{ Visitor: walk.NewVisitor[cypher.SyntaxNode](), - parameters: map[string]any{}, + parameters: materializedParameters, } } @@ -1026,8 +1058,55 @@ func (s *parameterMaterializer) Enter(node cypher.SyntaxNode) { s.parameters[parameter.Symbol] = parameter.Value } +type namedParameterCollector struct { + walk.Visitor[cypher.SyntaxNode] + + parameters map[string]any +} + +func newNamedParameterCollector() *namedParameterCollector { + return &namedParameterCollector{ + Visitor: walk.NewVisitor[cypher.SyntaxNode](), + parameters: map[string]any{}, + } +} + +func (s *namedParameterCollector) Enter(node 
cypher.SyntaxNode) { + parameter, typeOK := node.(*cypher.Parameter) + if !typeOK || parameter.Symbol == "" { + return + } + + if err := validateCypherSymbol(parameter.Symbol, "parameter"); err != nil { + s.SetError(err) + return + } + + if existingValue, exists := s.parameters[parameter.Symbol]; exists && !reflect.DeepEqual(existingValue, parameter.Value) { + s.SetErrorf("parameter %s is bound to multiple values", parameter.Symbol) + return + } + + s.parameters[parameter.Symbol] = parameter.Value +} + +func collectNamedParameters(query *cypher.RegularQuery) (map[string]any, error) { + collector := newNamedParameterCollector() + + if err := walk.Cypher(query, collector); err != nil { + return nil, err + } + + return collector.parameters, nil +} + func materializeParameters(query *cypher.RegularQuery) (map[string]any, error) { - materializer := newParameterMaterializer() + namedParameters, err := collectNamedParameters(query) + if err != nil { + return nil, err + } + + materializer := newParameterMaterializer(namedParameters) if err := walk.Cypher(query, materializer); err != nil { return nil, err From 6134427479d269af90d79a630bddf1fa782b8534 Mon Sep 17 00:00:00 2001 From: John Hopper Date: Fri, 8 May 2026 16:18:14 -0700 Subject: [PATCH 25/55] fix (query/v2): clean up some poorly supported neo4j constructs --- cypher/models/pgsql/translate/translator.go | 3 - query/neo4j/neo4j.go | 12 ++ query/v2/backend_test.go | 126 ++++++++++++++++++-- query/v2/query_test.go | 36 +++++- query/v2/util.go | 8 -- 5 files changed, 161 insertions(+), 24 deletions(-) diff --git a/cypher/models/pgsql/translate/translator.go b/cypher/models/pgsql/translate/translator.go index 03d18b48..993248b1 100644 --- a/cypher/models/pgsql/translate/translator.go +++ b/cypher/models/pgsql/translate/translator.go @@ -284,9 +284,6 @@ func (s *Translator) Exit(expression cypher.SyntaxNode) { s.SetError(err) } - case *cypher.Create: - s.SetErrorf("pgsql translator does not support create clauses") - case 
*cypher.Delete: if err := s.translateDelete(s.scope, typedExpression); err != nil { s.SetError(err) diff --git a/query/neo4j/neo4j.go b/query/neo4j/neo4j.go index 0689f74f..7299d004 100644 --- a/query/neo4j/neo4j.go +++ b/query/neo4j/neo4j.go @@ -53,6 +53,14 @@ func (s *QueryBuilder) rewriteParameters() error { return nil } +func hasPreparedMatchPattern(readingClause *cypher.ReadingClause) bool { + if readingClause == nil || readingClause.Match == nil { + return false + } + + return len(readingClause.Match.Pattern) > 0 +} + func (s *QueryBuilder) Apply(criteria graph.Criteria) { switch typedCriteria := criteria.(type) { case *cypher.Where: @@ -201,6 +209,10 @@ func (s *QueryBuilder) prepareMatch() error { return ErrAmbiguousQueryVariables } + if firstReadingClause := query.GetFirstReadingClause(s.query); hasPreparedMatchPattern(firstReadingClause) { + return nil + } + if singleNodeBound && !creatingSingleNode { patternPart.AddPatternElements(&cypher.NodePattern{ Variable: cypher.NewVariableWithSymbol(query.NodeSymbol), diff --git a/query/v2/backend_test.go b/query/v2/backend_test.go index de504a7a..6c454694 100644 --- a/query/v2/backend_test.go +++ b/query/v2/backend_test.go @@ -2,6 +2,7 @@ package v2_test import ( "context" + "strings" "testing" "github.com/specterops/dawgs/cypher/models/pgsql/translate" @@ -49,9 +50,31 @@ func TestBackendParityNeo4jPrepare(t *testing.T) { v2.Relationship().ID(), v2.End().ID(), ), - expectedCypher: "match (s)-[r]->(e) where id(s) = $p0 return id(s), id(r), id(e)", + expectedCypher: "match (s)-[r:MemberOf]->(e) where id(s) = $p0 return id(s), id(r), id(e)", expectedParams: map[string]any{"p0": 1}, }, + "shortest path": { + builder: v2.New().WithShortestPaths().Where( + v2.Relationship().Kind().Is(graph.StringKind("MemberOf")), + v2.Start().ID().Equals(1), + v2.End().ID().Equals(2), + ).Return( + v2.Path(), + ), + expectedCypher: "match p = shortestPath((s)-[r:MemberOf*]->(e)) where id(s) = $p0 and id(e) = $p1 return p", + 
expectedParams: map[string]any{"p0": 1, "p1": 2}, + }, + "all shortest paths": { + builder: v2.New().WithAllShortestPaths().Where( + v2.Relationship().Kind().Is(graph.StringKind("MemberOf")), + v2.Start().ID().Equals(1), + v2.End().ID().Equals(2), + ).Return( + v2.Path(), + ), + expectedCypher: "match p = allShortestPaths((s)-[r:MemberOf*]->(e)) where id(s) = $p0 and id(e) = $p1 return p", + expectedParams: map[string]any{"p0": 1, "p1": 2}, + }, "create node": { builder: v2.New().Create( v2.NodePattern(graph.Kinds{graph.StringKind("User")}, v2.NamedParameter("props", map[string]any{"name": "u"})), @@ -79,6 +102,15 @@ func TestBackendParityNeo4jPrepare(t *testing.T) { expectedCypher: "match ()-[r]->() where id(r) = $p0 delete r", expectedParams: map[string]any{"p0": 1}, }, + "delete node": { + builder: v2.New().Where( + v2.Node().ID().Equals(1), + ).Delete( + v2.Node(), + ), + expectedCypher: "match (n) where id(n) = $p0 detach delete n", + expectedParams: map[string]any{"p0": 1}, + }, } for name, testCase := range cases { @@ -115,8 +147,8 @@ func TestBackendParityPGTranslate(t *testing.T) { v2.Node().ID(), v2.Node().Kinds(), ), - expectedSQL: "with s0 as (select (n0.id, n0.kind_ids, n0.properties)::nodecomposite as n0 from node n0 where (n0.kind_ids operator (pg_catalog.&&) array [1]::int2[] and (n0.properties ->> 'name') like '%' || @pi0::text || '%')) select (s0.n0).id, (s0.n0).kind_ids from s0;", - expectedParams: map[string]any{"p0": "admin", "pi0": "admin"}, + expectedSQL: "with s0 as (select (n0.id, n0.kind_ids, n0.properties)::nodecomposite as n0 from node n0 where (n0.kind_ids operator (pg_catalog.&&) array [1]::int2[] and cypher_contains((n0.properties ->> 'name'), (@pi0::text)::text)::bool)) select (s0.n0).id, (s0.n0).kind_ids from s0;", + expectedParams: map[string]any{"pi0": "admin"}, }, "relationship read": { builder: v2.New().Where( @@ -128,7 +160,7 @@ func TestBackendParityPGTranslate(t *testing.T) { v2.End().ID(), ), expectedSQL: "with s0 as (select 
(e0.id, e0.start_id, e0.end_id, e0.kind_id, e0.properties)::edgecomposite as e0, (n0.id, n0.kind_ids, n0.properties)::nodecomposite as n0, (n1.id, n1.kind_ids, n1.properties)::nodecomposite as n1 from edge e0 join node n0 on (n0.id = @pi0::int8) and n0.id = e0.start_id join node n1 on n1.id = e0.end_id where e0.kind_id = any (array [2]::int2[])) select (s0.n0).id, (s0.e0).id, (s0.n1).id from s0;", - expectedParams: map[string]any{"p0": 1, "pi0": 1}, + expectedParams: map[string]any{"pi0": 1}, }, "update node": { builder: v2.New().Where( @@ -137,7 +169,7 @@ func TestBackendParityPGTranslate(t *testing.T) { v2.SetProperty(v2.Node().Property("name"), "updated"), ), expectedSQL: "with s0 as (select (n0.id, n0.kind_ids, n0.properties)::nodecomposite as n0 from node n0 where (n0.id = @pi0::int8)), s1 as (update node n1 set properties = n1.properties || jsonb_build_object('name', @pi1::text)::jsonb from s0 where (s0.n0).id = n1.id returning (n1.id, n1.kind_ids, n1.properties)::nodecomposite as n0) select 1;", - expectedParams: map[string]any{"p0": 1, "p1": "updated", "pi0": 1, "pi1": "updated"}, + expectedParams: map[string]any{"pi0": 1, "pi1": "updated"}, }, "delete relationship": { builder: v2.New().Where( @@ -145,8 +177,17 @@ func TestBackendParityPGTranslate(t *testing.T) { ).Delete( v2.Relationship(), ), - expectedSQL: "with s0 as (select (e0.id, e0.start_id, e0.end_id, e0.kind_id, e0.properties)::edgecomposite as e0, (n0.id, n0.kind_ids, n0.properties)::nodecomposite as n0, (n1.id, n1.kind_ids, n1.properties)::nodecomposite as n1 from edge e0 join node n0 on n0.id = e0.start_id join node n1 on n1.id = e0.end_id where (e0.id = @pi0::int8)), s1 as (delete from edge e1 using s0 where (s0.e0).id = e1.id) select 1;", - expectedParams: map[string]any{"p0": 1, "pi0": 1}, + expectedSQL: "with s0 as (select (e0.id, e0.start_id, e0.end_id, e0.kind_id, e0.properties)::edgecomposite as e0 from edge e0 join node n0 on n0.id = e0.start_id join node n1 on n1.id = e0.end_id where 
(e0.id = @pi0::int8)), s1 as (delete from edge e1 using s0 where (s0.e0).id = e1.id) select 1;", + expectedParams: map[string]any{"pi0": 1}, + }, + "delete node": { + builder: v2.New().Where( + v2.Node().ID().Equals(1), + ).Delete( + v2.Node(), + ), + expectedSQL: "with s0 as (select (n0.id, n0.kind_ids, n0.properties)::nodecomposite as n0 from node n0 where (n0.id = @pi0::int8)), s1 as (delete from node n1 using s0 where (s0.n0).id = n1.id) select 1;", + expectedParams: map[string]any{"pi0": 1}, }, } @@ -155,7 +196,7 @@ func TestBackendParityPGTranslate(t *testing.T) { preparedQuery, err := testCase.builder.Build() require.NoError(t, err) - translation, err := translate.Translate(context.Background(), preparedQuery.Query, mapper, preparedQuery.Parameters) + translation, err := translate.Translate(context.Background(), preparedQuery.Query, mapper, preparedQuery.Parameters, translate.DefaultGraphID) require.NoError(t, err) sql, err := translate.Translated(translation) @@ -166,7 +207,64 @@ func TestBackendParityPGTranslate(t *testing.T) { } } -func TestBackendParityPGCreateUnsupported(t *testing.T) { +func TestBackendParityPGTranslateShortestPaths(t *testing.T) { + edgeKind := graph.StringKind("MemberOf") + mapper := testKindMapper(edgeKind) + + cases := map[string]struct { + builder v2.QueryBuilder + expectedHarness string + }{ + "shortest path": { + builder: v2.New().WithShortestPaths().Where( + v2.Relationship().Kind().Is(edgeKind), + v2.Start().ID().Equals(1), + v2.End().ID().Equals(2), + ).Return( + v2.Path(), + ), + expectedHarness: "bidirectional_sp_harness", + }, + "all shortest paths": { + builder: v2.New().WithAllShortestPaths().Where( + v2.Relationship().Kind().Is(edgeKind), + v2.Start().ID().Equals(1), + v2.End().ID().Equals(2), + ).Return( + v2.Path(), + ), + expectedHarness: "bidirectional_asp_harness", + }, + } + + for name, testCase := range cases { + t.Run(name, func(t *testing.T) { + preparedQuery, err := testCase.builder.Build() + 
require.NoError(t, err) + + translation, err := translate.Translate(context.Background(), preparedQuery.Query, mapper, preparedQuery.Parameters, translate.DefaultGraphID) + require.NoError(t, err) + + sql, err := translate.Translated(translation) + require.NoError(t, err) + require.Contains(t, sql, testCase.expectedHarness) + require.Contains(t, sql, "ordered_edges_to_path") + require.Contains(t, sql, "n0.id = 1") + require.Contains(t, sql, "n1.id = 2") + + serializedHarnessQueryHasKindConstraint := false + for _, parameterValue := range translation.Parameters { + if serializedQuery, typeOK := parameterValue.(string); typeOK && strings.Contains(serializedQuery, "array [1]::int2[]") { + serializedHarnessQueryHasKindConstraint = true + break + } + } + require.True(t, serializedHarnessQueryHasKindConstraint, "expected serialized shortest-path harness query to contain edge kind constraint: %#v", translation.Parameters) + }) + } +} + +func TestBackendParityPGCreate(t *testing.T) { edgeKind := graph.StringKind("MemberOf") mapper := testKindMapper(edgeKind) @@ -180,6 +278,12 @@ func TestBackendParityPGCreateUnsupported(t *testing.T) { ).Build() require.NoError(t, err) - _, err = translate.Translate(context.Background(), preparedQuery.Query, mapper, preparedQuery.Parameters) - require.ErrorContains(t, err, "pgsql translator does not support create clauses") + translation, err := translate.Translate(context.Background(), preparedQuery.Query, mapper, preparedQuery.Parameters, translate.DefaultGraphID) + require.NoError(t, err) + + sql, err := translate.Translated(translation) + require.NoError(t, err) + require.Contains(t, sql, "insert into edge") + require.Contains(t, sql, "graph_id") + require.Contains(t, sql, "kind_id") } diff --git a/query/v2/query_test.go b/query/v2/query_test.go index 2c906c46..da3e80cd 100644 --- a/query/v2/query_test.go +++ b/query/v2/query_test.go @@ -187,26 +187,28 @@ func TestInvalidRelationshipDirectionReturnsError(t *testing.T) { func 
TestShortestPathControls(t *testing.T) { preparedQuery, err := v2.New().WithShortestPaths().Where( + v2.Relationship().Kind().Is(graph.StringKind("MemberOf")), v2.Start().ID().Equals(1), v2.End().ID().Equals(2), ).Return( v2.Path(), ).Build() require.NoError(t, err) - require.Equal(t, "match p = shortestPath((s)-[*]->(e)) where id(s) = $p0 and id(e) = $p1 return p", renderPrepared(t, preparedQuery)) + require.Equal(t, "match p = shortestPath((s)-[r:MemberOf*]->(e)) where id(s) = $p0 and id(e) = $p1 return p", renderPrepared(t, preparedQuery)) require.Equal(t, map[string]any{ "p0": 1, "p1": 2, }, preparedQuery.Parameters) preparedQuery, err = v2.New().WithAllShortestPaths().Where( + v2.Relationship().Kind().Is(graph.StringKind("MemberOf")), v2.Start().ID().Equals(1), v2.End().ID().Equals(2), ).Return( v2.Path(), ).Build() require.NoError(t, err) - require.Equal(t, "match p = allShortestPaths((s)-[*]->(e)) where id(s) = $p0 and id(e) = $p1 return p", renderPrepared(t, preparedQuery)) + require.Equal(t, "match p = allShortestPaths((s)-[r:MemberOf*]->(e)) where id(s) = $p0 and id(e) = $p1 return p", renderPrepared(t, preparedQuery)) _, err = v2.New().WithShortestPaths().WithAllShortestPaths().Where( v2.Start().ID().Equals(1), @@ -217,6 +219,16 @@ func TestShortestPathControls(t *testing.T) { require.ErrorContains(t, err, "query is requesting both all shortest paths and shortest paths") } +func TestMixedNodeAndRelationshipIdentifiersReturnError(t *testing.T) { + _, err := v2.New().Where( + v2.Node().ID().Equals(1), + v2.Relationship().ID().Equals(2), + ).Return( + v2.Node(), + ).Build() + require.ErrorContains(t, err, "query mixes node and relationship query identifiers") +} + func TestInvalidExplicitRelationshipPatternDirectionReturnsError(t *testing.T) { _, err := v2.New().Create( v2.RelationshipPattern(graph.StringKind("Edge"), nil, graph.DirectionBoth), @@ -317,6 +329,26 @@ func TestRawReturnInputPreservesProjectionMetadata(t *testing.T) { require.Equal(t, "match 
(n) return distinct id(n) order by n.name desc skip 5 limit 10", renderPrepared(t, preparedQuery)) } +func TestRawReturnInputMergesWithBuilderProjectionControls(t *testing.T) { + returnClause := cypher.NewReturn() + projection := returnClause.NewProjection(true) + projection.Items = append(projection.Items, v2.Node().ID()) + projection.Order = &cypher.Order{ + Items: []*cypher.SortItem{ + v2.Desc(v2.Node().Property("name")), + }, + } + projection.Skip = cypher.NewSkip(5) + projection.Limit = cypher.NewLimit(10) + + preparedQuery, err := v2.New().Return(returnClause).OrderBy( + v2.Asc(v2.Node().Property("created_at")), + ).Skip(15).Limit(20).Build() + require.NoError(t, err) + + require.Equal(t, "match (n) return distinct id(n) order by n.name desc, n.created_at asc skip 15 limit 20", renderPrepared(t, preparedQuery)) +} + func TestRawUpdatingInputsAreValidated(t *testing.T) { var setClause *cypher.Set _, err := v2.New().Update(setClause).Build() diff --git a/query/v2/util.go b/query/v2/util.go index c15c6400..ab003dc1 100644 --- a/query/v2/util.go +++ b/query/v2/util.go @@ -619,14 +619,6 @@ func (s *identifierSet) CollectFromValue(value any) error { case QualifiedExpression: return s.CollectFromExpression(typedValue.qualifier()) - case kindContinuation: - s.Add(typedValue.identifier.Symbol) - return nil - - case kindsContinuation: - s.Add(typedValue.identifier.Symbol) - return nil - case *cypher.Return: if projectionItems, err := projectionItemsFromReturn(typedValue); err != nil { return err From d36a214b3681cd1880c95526a5a059ca11e5435b Mon Sep 17 00:00:00 2001 From: John Hopper Date: Sat, 9 May 2026 18:34:35 -0700 Subject: [PATCH 26/55] fix(query/v2): copy caller AST before build --- query/v2/query.go | 23 +++++++++++++---------- query/v2/query_test.go | 35 +++++++++++++++++++++++++++++++++++ query/v2/util.go | 40 ++++++++++++++++++++++++++++++++++------ 3 files changed, 82 insertions(+), 16 deletions(-) diff --git a/query/v2/query.go b/query/v2/query.go index 
a388140c..19ceb3c7 100644 --- a/query/v2/query.go +++ b/query/v2/query.go @@ -829,8 +829,9 @@ func (s *builder) Delete(deleteItems ...any) QueryBuilder { pendingDetachDelete = true } - s.deleteItems = append(s.deleteItems, qualifier) - pendingDeleteItems = append(pendingDeleteItems, qualifier) + deleteItem := copyExpression(qualifier) + s.deleteItems = append(s.deleteItems, deleteItem) + pendingDeleteItems = append(pendingDeleteItems, deleteItem) case *cypher.Variable: if err := validateExpressionValue(typedNextUpdate, "delete expression"); err != nil { @@ -844,8 +845,9 @@ func (s *builder) Delete(deleteItems ...any) QueryBuilder { pendingDetachDelete = true } - s.deleteItems = append(s.deleteItems, typedNextUpdate) - pendingDeleteItems = append(pendingDeleteItems, typedNextUpdate) + deleteItem := copyExpression(typedNextUpdate) + s.deleteItems = append(s.deleteItems, deleteItem) + pendingDeleteItems = append(pendingDeleteItems, deleteItem) default: s.trackError(fmt.Errorf("unknown delete type: %T", nextDelete)) @@ -906,7 +908,7 @@ func buildCreates(singlePartQuery *cypher.SinglePartQuery, identifiers runtimeId return err } - pattern.AddPatternElements(typedNextCreate) + pattern.AddPatternElements(cypher.Copy(typedNextCreate)) case *cypher.RelationshipPattern: if err := validateRelationshipPattern(typedNextCreate); err != nil { @@ -917,7 +919,7 @@ func buildCreates(singlePartQuery *cypher.SinglePartQuery, identifiers runtimeId Variable: identifiers.Start(), }) - pattern.AddPatternElements(typedNextCreate) + pattern.AddPatternElements(cypher.Copy(typedNextCreate)) pattern.AddPatternElements(&cypher.NodePattern{ Variable: identifiers.End(), @@ -1031,11 +1033,11 @@ func applyReturnProjection(projection *cypher.Projection, returnClause *cypher.R } if returnClause.Projection.Skip != nil { - projection.Skip = returnClause.Projection.Skip + projection.Skip = copySkip(returnClause.Projection.Skip) } if returnClause.Projection.Limit != nil { - projection.Limit = 
returnClause.Projection.Limit + projection.Limit = copyLimit(returnClause.Projection.Limit) } return nil @@ -1154,10 +1156,11 @@ func (s *builder) Build() (*PreparedQuery, error) { } } + constraintCopy := cypher.Copy(nextConstraint) if constraints.Left == nil { - constraints.Left = nextConstraint + constraints.Left = constraintCopy } else { - constraints.NewPartialComparison(cypher.OperatorAnd, nextConstraint) + constraints.NewPartialComparison(cypher.OperatorAnd, constraintCopy) } } diff --git a/query/v2/query_test.go b/query/v2/query_test.go index da3e80cd..fa5bf7e6 100644 --- a/query/v2/query_test.go +++ b/query/v2/query_test.go @@ -449,6 +449,41 @@ func TestNamedParameterMaterialization(t *testing.T) { require.ErrorContains(t, err, "parameter same is bound to multiple values") } +func TestBuildDoesNotMutateCallerOwnedAST(t *testing.T) { + constraint := v2.Node().Property("name").Equals("alice") + constraintParameter := constraint.(*cypher.Comparison).FirstPartial().Right.(*cypher.Parameter) + + preparedQuery, err := v2.New().Where(constraint).Return(v2.Node()).Build() + require.NoError(t, err) + require.Equal(t, map[string]any{"p0": "alice"}, preparedQuery.Parameters) + require.Empty(t, constraintParameter.Symbol) + + setItem := v2.SetProperty(v2.Node().Property("status"), "active") + setParameter := setItem.Right.(*cypher.Parameter) + + preparedQuery, err = v2.New().Where(v2.Node().ID().Equals(1)).Update(setItem).Build() + require.NoError(t, err) + require.Equal(t, map[string]any{"p0": 1, "p1": "active"}, preparedQuery.Parameters) + require.Empty(t, setParameter.Symbol) + + createPattern := v2.NodePattern(graph.Kinds{graph.StringKind("User")}, v2.Parameter(map[string]any{"name": "node"})) + createParameter := createPattern.Properties.(*cypher.Parameter) + + preparedQuery, err = v2.New().Create(createPattern).Build() + require.NoError(t, err) + require.Equal(t, map[string]any{"p0": map[string]any{"name": "node"}}, preparedQuery.Parameters) + require.Empty(t, 
createParameter.Symbol) + + rawReturn := cypher.NewReturn() + rawReturn.NewProjection(false).AddItem(cypher.NewProjectionItemWithExpr(v2.Parameter("projected"))) + rawReturnParameter := rawReturn.Projection.Items[0].(*cypher.ProjectionItem).Expression.(*cypher.Parameter) + + preparedQuery, err = v2.New().Return(rawReturn).Build() + require.NoError(t, err) + require.Equal(t, map[string]any{"p0": "projected"}, preparedQuery.Parameters) + require.Empty(t, rawReturnParameter.Symbol) +} + func TestCompatibilityHelpers(t *testing.T) { preparedQuery, err := v2.New().Where( v2.And( diff --git a/query/v2/util.go b/query/v2/util.go index ab003dc1..1801a489 100644 --- a/query/v2/util.go +++ b/query/v2/util.go @@ -308,6 +308,34 @@ func projectionExpression(value any) (cypher.Expression, error) { } } +func copyExpression(expression cypher.Expression) cypher.Expression { + return cypher.Copy(expression) +} + +func copyProjectionItem(item *cypher.ProjectionItem) *cypher.ProjectionItem { + return cypher.Copy(item) +} + +func copySortItem(item *cypher.SortItem) *cypher.SortItem { + return cypher.Copy(item) +} + +func copySetItem(item *cypher.SetItem) *cypher.SetItem { + return cypher.Copy(item) +} + +func copyRemoveItem(item *cypher.RemoveItem) *cypher.RemoveItem { + return cypher.Copy(item) +} + +func copySkip(skip *cypher.Skip) *cypher.Skip { + return cypher.Copy(skip) +} + +func copyLimit(limit *cypher.Limit) *cypher.Limit { + return cypher.Copy(limit) +} + func validateExpressionValue(expression cypher.Expression, context string) error { if isNilPointer(expression) { return fmt.Errorf("%s has nil expression", context) @@ -346,7 +374,7 @@ func setItemFromValue(setItem *cypher.SetItem) (*cypher.SetItem, error) { return nil, err } - return setItem, nil + return copySetItem(setItem), nil } func setItemsFromSet(setClause *cypher.Set) ([]*cypher.SetItem, error) { @@ -397,7 +425,7 @@ func removeItemFromValue(removeItem *cypher.RemoveItem) (*cypher.RemoveItem, err return nil, err } - 
return removeItem, nil + return copyRemoveItem(removeItem), nil } func removeItemsFromRemove(removeClause *cypher.Remove) ([]*cypher.RemoveItem, error) { @@ -458,13 +486,13 @@ func projectionItemFromValue(value any) (*cypher.ProjectionItem, error) { return nil, err } - return projectionItem, nil + return copyProjectionItem(projectionItem), nil } if expression, err := projectionExpression(value); err != nil { return nil, err } else { - return cypher.NewProjectionItemWithExpr(expression), nil + return cypher.NewProjectionItemWithExpr(copyExpression(expression)), nil } } @@ -482,7 +510,7 @@ func sortItemFromValue(value any) (*cypher.SortItem, error) { return nil, err } - return sortItem, nil + return copySortItem(sortItem), nil } if expression, err := projectionExpression(value); err != nil { @@ -490,7 +518,7 @@ func sortItemFromValue(value any) (*cypher.SortItem, error) { } else { return &cypher.SortItem{ Ascending: true, - Expression: expression, + Expression: copyExpression(expression), }, nil } } From 5ad5ad3687122afac023a30ac34ce5e1ac591582 Mon Sep 17 00:00:00 2001 From: John Hopper Date: Sat, 9 May 2026 18:35:27 -0700 Subject: [PATCH 27/55] fix(query/v2): support explicit create endpoints --- query/v2/query.go | 49 ++++++++++++++++++++++++++++++++++++------ query/v2/query_test.go | 21 ++++++++++++++++++ 2 files changed, 63 insertions(+), 7 deletions(-) diff --git a/query/v2/query.go b/query/v2/query.go index 19ceb3c7..e45ccf7e 100644 --- a/query/v2/query.go +++ b/query/v2/query.go @@ -867,6 +867,37 @@ func (s *builder) Where(constraints ...cypher.SyntaxNode) QueryBuilder { return s } +func patternEndsWithNodePattern(pattern *cypher.PatternPart) bool { + numElements := len(pattern.PatternElements) + if numElements == 0 { + return false + } + + return pattern.PatternElements[numElements-1].IsNodePattern() +} + +func isCreateNodeValue(value any, identifiers runtimeIdentifiers) bool { + switch typedValue := value.(type) { + case QualifiedExpression: + if variable, 
typeOK := typedValue.qualifier().(*cypher.Variable); typeOK { + switch variable.Symbol { + case identifiers.node, identifiers.start, identifiers.end: + return true + } + } + + case *cypher.NodePattern: + return typedValue != nil + } + + return false +} + +func nextCreateValueIsNode(creates []any, idx int, identifiers runtimeIdentifiers) bool { + nextIdx := idx + 1 + return nextIdx < len(creates) && isCreateNodeValue(creates[nextIdx], identifiers) +} + func buildCreates(singlePartQuery *cypher.SinglePartQuery, identifiers runtimeIdentifiers, creates []any) error { if len(creates) == 0 { return nil @@ -880,7 +911,7 @@ func buildCreates(singlePartQuery *cypher.SinglePartQuery, identifiers runtimeId } ) - for _, nextCreate := range creates { + for idx, nextCreate := range creates { switch typedNextCreate := nextCreate.(type) { case QualifiedExpression: switch typedExpression := typedNextCreate.qualifier().(type) { @@ -915,15 +946,19 @@ func buildCreates(singlePartQuery *cypher.SinglePartQuery, identifiers runtimeId return err } - pattern.AddPatternElements(&cypher.NodePattern{ - Variable: identifiers.Start(), - }) + if !patternEndsWithNodePattern(pattern) { + pattern.AddPatternElements(&cypher.NodePattern{ + Variable: identifiers.Start(), + }) + } pattern.AddPatternElements(cypher.Copy(typedNextCreate)) - pattern.AddPatternElements(&cypher.NodePattern{ - Variable: identifiers.End(), - }) + if !nextCreateValueIsNode(creates, idx, identifiers) { + pattern.AddPatternElements(&cypher.NodePattern{ + Variable: identifiers.End(), + }) + } default: return fmt.Errorf("invalid type for create: %T", nextCreate) diff --git a/query/v2/query_test.go b/query/v2/query_test.go index fa5bf7e6..acaccc6b 100644 --- a/query/v2/query_test.go +++ b/query/v2/query_test.go @@ -81,6 +81,27 @@ func TestCreateRelationshipWithMatchedEndpoints(t *testing.T) { }, preparedQuery.Parameters) } +func TestCreateRelationshipWithExplicitEndpoints(t *testing.T) { + preparedQuery, err := v2.New().Where( + 
v2.Start().ID().Equals(1), + v2.End().ID().Equals(2), + ).Create( + v2.Start(), + v2.RelationshipPattern(graph.StringKind("A"), v2.NamedParameter("props", map[string]any{"name": "rel"}), graph.DirectionOutbound), + v2.End(), + ).Return( + v2.Relationship().ID(), + ).Build() + require.NoError(t, err) + + require.Equal(t, "match (s), (e) where id(s) = $p0 and id(e) = $p1 create (s)-[r:A $props]->(e) return id(r)", renderPrepared(t, preparedQuery)) + require.Equal(t, map[string]any{ + "p0": 1, + "p1": 2, + "props": map[string]any{"name": "rel"}, + }, preparedQuery.Parameters) +} + func TestCreateNodeReturnDoesNotCreateMatch(t *testing.T) { preparedQuery, err := v2.New().Create( v2.Node().NodePattern(graph.Kinds{graph.StringKind("A")}, v2.NamedParameter("props", map[string]any{"name": "node"})), From 2d82eb8ca46c9321227f5a3ac60a7147b8e4e77e Mon Sep 17 00:00:00 2001 From: John Hopper Date: Sat, 9 May 2026 18:35:58 -0700 Subject: [PATCH 28/55] fix(query/v2): validate pagination bounds --- query/v2/query.go | 18 ++++++++++++++---- query/v2/query_test.go | 12 ++++++++++++ 2 files changed, 26 insertions(+), 4 deletions(-) diff --git a/query/v2/query.go b/query/v2/query.go index e45ccf7e..f4e32d84 100644 --- a/query/v2/query.go +++ b/query/v2/query.go @@ -685,11 +685,21 @@ func (s *builder) OrderBy(sortItems ...any) QueryBuilder { } func (s *builder) Skip(skip int) QueryBuilder { + if skip < 0 { + s.trackError(fmt.Errorf("skip must be non-negative: %d", skip)) + return s + } + s.skip = &skip return s } func (s *builder) Limit(limit int) QueryBuilder { + if limit < 0 { + s.trackError(fmt.Errorf("limit must be non-negative: %d", limit)) + return s + } + s.limit = &limit return s } @@ -1081,8 +1091,8 @@ func applyReturnProjection(projection *cypher.Projection, returnClause *cypher.R func (s *builder) buildProjection(singlePartQuery *cypher.SinglePartQuery) error { var ( hasProjectedItems = len(s.projections) > 0 - hasSkip = s.skip != nil && *s.skip > 0 - hasLimit = s.limit != 
nil && *s.limit > 0 + hasSkip = s.skip != nil + hasLimit = s.limit != nil requiresProjection = hasProjectedItems || hasSkip || hasLimit ) @@ -1109,11 +1119,11 @@ func (s *builder) buildProjection(singlePartQuery *cypher.SinglePartQuery) error } } - if s.skip != nil && *s.skip > 0 { + if s.skip != nil { projection.Skip = cypher.NewSkip(*s.skip) } - if s.limit != nil && *s.limit > 0 { + if s.limit != nil { projection.Limit = cypher.NewLimit(*s.limit) } diff --git a/query/v2/query_test.go b/query/v2/query_test.go index acaccc6b..1c302480 100644 --- a/query/v2/query_test.go +++ b/query/v2/query_test.go @@ -274,6 +274,18 @@ func TestProjectionAndOrderHelpers(t *testing.T) { require.Equal(t, "match (n) return distinct id(n) as node_id order by n.name asc, id(n) desc", renderPrepared(t, preparedQuery)) } +func TestPaginationZeroValuesAndNegativeValidation(t *testing.T) { + preparedQuery, err := v2.New().Return(v2.Node()).Skip(0).Limit(0).Build() + require.NoError(t, err) + require.Equal(t, "match (n) return n skip 0 limit 0", renderPrepared(t, preparedQuery)) + + _, err = v2.New().Return(v2.Node()).Skip(-1).Build() + require.ErrorContains(t, err, "skip must be non-negative: -1") + + _, err = v2.New().Return(v2.Node()).Limit(-1).Build() + require.ErrorContains(t, err, "limit must be non-negative: -1") +} + func TestProjectionAliasDoesNotCreateMatchInference(t *testing.T) { preparedQuery, err := v2.New().Return( v2.As(v2.Literal(1), "one"), From 9192bed6f9e6dc209361854a90aff722d07a6702 Mon Sep 17 00:00:00 2001 From: John Hopper Date: Sat, 9 May 2026 18:36:20 -0700 Subject: [PATCH 29/55] fix(query/v2): reject empty logical helpers --- query/v2/query.go | 4 ++++ query/v2/query_test.go | 8 ++++++++ 2 files changed, 12 insertions(+) diff --git a/query/v2/query.go b/query/v2/query.go index f4e32d84..5b9960f9 100644 --- a/query/v2/query.go +++ b/query/v2/query.go @@ -214,6 +214,10 @@ func valueExpression(value any) cypher.Expression { } func joinedExpressionList(operator 
cypher.Operator, operands []cypher.SyntaxNode) cypher.SyntaxNode { + if len(operands) == 0 { + return invalidExpression(fmt.Errorf("%s requires at least one operand", operator)) + } + expressionList := &cypher.Comparison{} if len(operands) > 0 { diff --git a/query/v2/query_test.go b/query/v2/query_test.go index 1c302480..6f0e3b39 100644 --- a/query/v2/query_test.go +++ b/query/v2/query_test.go @@ -250,6 +250,14 @@ func TestMixedNodeAndRelationshipIdentifiersReturnError(t *testing.T) { require.ErrorContains(t, err, "query mixes node and relationship query identifiers") } +func TestEmptyLogicalHelpersReturnBuildErrors(t *testing.T) { + _, err := v2.New().Where(v2.And()).Return(v2.Node()).Build() + require.ErrorContains(t, err, "and requires at least one operand") + + _, err = v2.New().Where(v2.Or()).Return(v2.Node()).Build() + require.ErrorContains(t, err, "or requires at least one operand") +} + func TestInvalidExplicitRelationshipPatternDirectionReturnsError(t *testing.T) { _, err := v2.New().Create( v2.RelationshipPattern(graph.StringKind("Edge"), nil, graph.DirectionBoth), From ecbd423b1146417c904c1ea91f597b2908fa541b Mon Sep 17 00:00:00 2001 From: John Hopper Date: Sat, 9 May 2026 18:36:47 -0700 Subject: [PATCH 30/55] fix(query/v2): require relationship shortest paths --- query/v2/query.go | 8 ++++++++ query/v2/query_test.go | 6 ++++++ 2 files changed, 14 insertions(+) diff --git a/query/v2/query.go b/query/v2/query.go index 5b9960f9..f6849b4b 100644 --- a/query/v2/query.go +++ b/query/v2/query.go @@ -1150,6 +1150,10 @@ func (s *builder) hasActions() bool { return len(s.projections) > 0 || len(s.setItems) > 0 || len(s.removeItems) > 0 || len(s.creates) > 0 || len(s.deleteItems) > 0 } +func (s *builder) wantsShortestPathPattern() bool { + return s.shortestPathQuery || s.allShorestPathsQuery +} + func (s *builder) Build() (*PreparedQuery, error) { if len(s.errors) > 0 { return nil, errors.Join(s.errors...) 
@@ -1232,6 +1236,10 @@ func (s *builder) Build() (*PreparedQuery, error) { matchIdentifiers := readIdentifiers.Clone() matchIdentifiers.Or(actionIdentifiers) + if s.wantsShortestPathPattern() && !isRelationshipPattern(matchIdentifiers, s.identifiers) { + return nil, fmt.Errorf("shortest path query requires relationship query identifiers") + } + if len(s.constraints) > 0 || matchIdentifiers.Len() > 0 { if isNodePattern(matchIdentifiers, s.identifiers) { if err := prepareNodePattern(match, matchIdentifiers, s.identifiers); err != nil { diff --git a/query/v2/query_test.go b/query/v2/query_test.go index 6f0e3b39..2d09eec2 100644 --- a/query/v2/query_test.go +++ b/query/v2/query_test.go @@ -238,6 +238,12 @@ func TestShortestPathControls(t *testing.T) { v2.Path(), ).Build() require.ErrorContains(t, err, "query is requesting both all shortest paths and shortest paths") + + _, err = v2.New().WithShortestPaths().Return(v2.Node()).Build() + require.ErrorContains(t, err, "shortest path query requires relationship query identifiers") + + _, err = v2.New().WithAllShortestPaths().Return(v2.As(v2.Literal(1), "one")).Build() + require.ErrorContains(t, err, "shortest path query requires relationship query identifiers") } func TestMixedNodeAndRelationshipIdentifiersReturnError(t *testing.T) { From b3a26cc353610e55e670796e93c7c78edda8f11f Mon Sep 17 00:00:00 2001 From: John Hopper Date: Sat, 9 May 2026 18:38:01 -0700 Subject: [PATCH 31/55] test(query/v2): cover legacy query parity --- query/v2/legacy_parity_test.go | 184 +++++++++++++++++++++++++++++++++ 1 file changed, 184 insertions(+) create mode 100644 query/v2/legacy_parity_test.go diff --git a/query/v2/legacy_parity_test.go b/query/v2/legacy_parity_test.go new file mode 100644 index 00000000..0a6f35bd --- /dev/null +++ b/query/v2/legacy_parity_test.go @@ -0,0 +1,184 @@ +package v2_test + +import ( + "testing" + + "github.com/specterops/dawgs/cypher/models/cypher" + "github.com/specterops/dawgs/graph" + legacyquery 
"github.com/specterops/dawgs/query" + "github.com/specterops/dawgs/query/neo4j" + v2 "github.com/specterops/dawgs/query/v2" + "github.com/stretchr/testify/require" +) + +func renderNeo4jQuery(t *testing.T, regularQuery *cypher.RegularQuery, prepareAllShortestPaths bool) (string, map[string]any) { + t.Helper() + + queryBuilder := neo4j.NewQueryBuilder(regularQuery) + + if prepareAllShortestPaths { + require.NoError(t, queryBuilder.PrepareAllShortestPaths()) + } else { + require.NoError(t, queryBuilder.Prepare()) + } + + rendered, err := queryBuilder.Render() + require.NoError(t, err) + + return rendered, queryBuilder.Parameters +} + +func assertLegacyNeo4jParity(t *testing.T, legacyQuery *cypher.RegularQuery, v2Builder v2.QueryBuilder, prepareLegacyAllShortestPaths bool) { + t.Helper() + + preparedQuery, err := v2Builder.Build() + require.NoError(t, err) + + legacyRendered, legacyParameters := renderNeo4jQuery(t, legacyQuery, prepareLegacyAllShortestPaths) + v2Rendered, v2Parameters := renderNeo4jQuery(t, preparedQuery.Query, false) + + require.Equal(t, legacyRendered, v2Rendered) + require.Equal(t, legacyParameters, v2Parameters) +} + +func TestLegacyNeo4jParity(t *testing.T) { + userKind := graph.StringKind("User") + edgeKind := graph.StringKind("MemberOf") + + t.Run("node count by kind", func(t *testing.T) { + assertLegacyNeo4jParity(t, + legacyquery.SinglePartQuery( + legacyquery.Where( + legacyquery.KindIn(legacyquery.Node(), userKind), + ), + legacyquery.Returning( + legacyquery.Count(legacyquery.Node()), + ), + ), + v2.New().Where( + v2.Node().Kinds().Has(userKind), + ).Return( + v2.Node().Count(), + ), + false, + ) + }) + + t.Run("node read with pagination", func(t *testing.T) { + assertLegacyNeo4jParity(t, + legacyquery.SinglePartQuery( + legacyquery.Where( + legacyquery.And( + legacyquery.StringContains(legacyquery.NodeProperty("name"), "admin"), + legacyquery.IsNotNull(legacyquery.NodeProperty("enabled")), + ), + ), + legacyquery.Returning( + 
legacyquery.Node(), + legacyquery.OrderBy(legacyquery.Order(legacyquery.NodeProperty("name"), legacyquery.Ascending())), + legacyquery.Offset(0), + legacyquery.Limit(0), + ), + ), + v2.New().Where( + v2.And( + v2.Node().Property("name").Contains("admin"), + v2.Node().Property("enabled").IsNotNull(), + ), + ).Return( + v2.Node(), + ).OrderBy( + v2.Asc(v2.Node().Property("name")), + ).Skip(0).Limit(0), + false, + ) + }) + + t.Run("relationship read", func(t *testing.T) { + assertLegacyNeo4jParity(t, + legacyquery.SinglePartQuery( + legacyquery.Where( + legacyquery.And( + legacyquery.KindIn(legacyquery.Relationship(), edgeKind), + legacyquery.Equals(legacyquery.StartID(), 1), + legacyquery.Equals(legacyquery.EndID(), 2), + ), + ), + legacyquery.Returning( + legacyquery.StartID(), + legacyquery.RelationshipID(), + legacyquery.EndID(), + ), + ), + v2.New().Where( + v2.Relationship().Kind().Is(edgeKind), + v2.Start().ID().Equals(1), + v2.End().ID().Equals(2), + ).Return( + v2.Start().ID(), + v2.Relationship().ID(), + v2.End().ID(), + ), + false, + ) + }) + + t.Run("create relationship with matched endpoints", func(t *testing.T) { + properties := map[string]any{"name": "rel"} + + assertLegacyNeo4jParity(t, + legacyquery.SinglePartQuery( + legacyquery.Where( + legacyquery.And( + legacyquery.Equals(legacyquery.StartID(), 1), + legacyquery.Equals(legacyquery.EndID(), 2), + ), + ), + legacyquery.Create( + legacyquery.Start(), + legacyquery.RelationshipPattern(edgeKind, legacyquery.Parameter(properties), graph.DirectionOutbound), + legacyquery.End(), + ), + legacyquery.Returning( + legacyquery.RelationshipID(), + ), + ), + v2.New().Where( + v2.Start().ID().Equals(1), + v2.End().ID().Equals(2), + ).Create( + v2.Start(), + v2.RelationshipPattern(edgeKind, v2.Parameter(properties), graph.DirectionOutbound), + v2.End(), + ).Return( + v2.Relationship().ID(), + ), + false, + ) + }) + + t.Run("all shortest paths", func(t *testing.T) { + assertLegacyNeo4jParity(t, + 
legacyquery.SinglePartQuery( + legacyquery.Where( + legacyquery.And( + legacyquery.KindIn(legacyquery.Relationship(), edgeKind), + legacyquery.Equals(legacyquery.StartID(), 1), + legacyquery.Equals(legacyquery.EndID(), 2), + ), + ), + legacyquery.Returning( + legacyquery.Path(), + ), + ), + v2.New().WithAllShortestPaths().Where( + v2.Relationship().Kind().Is(edgeKind), + v2.Start().ID().Equals(1), + v2.End().ID().Equals(2), + ).Return( + v2.Path(), + ), + true, + ) + }) +} From 1749b2957e19156cdb7b733287b1a16b08a93b51 Mon Sep 17 00:00:00 2001 From: John Hopper Date: Sat, 9 May 2026 19:25:45 -0700 Subject: [PATCH 32/55] Fix query v2 aliased projection inference --- query/v2/query_test.go | 9 +++++++++ query/v2/util.go | 7 +++++++ 2 files changed, 16 insertions(+) diff --git a/query/v2/query_test.go b/query/v2/query_test.go index 2d09eec2..3b9b603c 100644 --- a/query/v2/query_test.go +++ b/query/v2/query_test.go @@ -310,6 +310,15 @@ func TestProjectionAliasDoesNotCreateMatchInference(t *testing.T) { require.Empty(t, preparedQuery.Parameters) } +func TestAliasedProjectionCreatesMatchInference(t *testing.T) { + preparedQuery, err := v2.New().Return( + v2.As(v2.Node().ID(), "node_id"), + ).Build() + require.NoError(t, err) + + require.Equal(t, "match (n) return id(n) as node_id", renderPrepared(t, preparedQuery)) +} + func TestInvalidProjectionAliasReturnsBuildError(t *testing.T) { _, err := v2.New().Return(v2.As(v2.Literal(1), "bad alias")).Build() require.ErrorContains(t, err, `projection alias has invalid symbol "bad alias"`) diff --git a/query/v2/util.go b/query/v2/util.go index 1801a489..444f3a99 100644 --- a/query/v2/util.go +++ b/query/v2/util.go @@ -647,6 +647,13 @@ func (s *identifierSet) CollectFromValue(value any) error { case QualifiedExpression: return s.CollectFromExpression(typedValue.qualifier()) + case *cypher.ProjectionItem: + if projectionItem, err := projectionItemFromValue(typedValue); err != nil { + return err + } else { + return 
s.CollectFromExpression(projectionItem) + } + case *cypher.Return: if projectionItems, err := projectionItemsFromReturn(typedValue); err != nil { return err From e6a1999bf0807f1528b684c4350cb3bba9dd3e7c Mon Sep 17 00:00:00 2001 From: John Hopper Date: Sat, 9 May 2026 19:26:24 -0700 Subject: [PATCH 33/55] Fix query v2 qualified value expressions --- query/v2/query.go | 6 ++++++ query/v2/query_test.go | 19 +++++++++++++++++++ 2 files changed, 25 insertions(+) diff --git a/query/v2/query.go b/query/v2/query.go index f6849b4b..98d9a07f 100644 --- a/query/v2/query.go +++ b/query/v2/query.go @@ -208,6 +208,12 @@ func valueExpression(value any) cypher.Expression { return typedValue case *cypher.IDInCollection: return typedValue + case QualifiedExpression: + if expression, err := projectionExpression(typedValue); err != nil { + return invalidExpression(err) + } else { + return expression + } default: return Parameter(value) } diff --git a/query/v2/query_test.go b/query/v2/query_test.go index 3b9b603c..aa7879ef 100644 --- a/query/v2/query_test.go +++ b/query/v2/query_test.go @@ -505,6 +505,25 @@ func TestNamedParameterMaterialization(t *testing.T) { require.ErrorContains(t, err, "parameter same is bound to multiple values") } +func TestQualifiedExpressionValuesUseProjectionSemantics(t *testing.T) { + preparedQuery, err := v2.New().Where( + v2.Node().Property("copy").Equals(v2.Node().Property("source")), + v2.Node().Property("kinds").Equals(v2.Node().Kinds()), + ).Return( + v2.Node(), + ).Build() + require.NoError(t, err) + require.Equal(t, "match (n) where n.copy = n.source and n.kinds = labels(n) return n", renderPrepared(t, preparedQuery)) + + preparedQuery, err = v2.New().Where( + v2.Relationship().Property("kind").Equals(v2.Relationship().Kind()), + ).Return( + v2.Relationship(), + ).Build() + require.NoError(t, err) + require.Equal(t, "match ()-[r]->() where r.kind = type(r) return r", renderPrepared(t, preparedQuery)) +} + func TestBuildDoesNotMutateCallerOwnedAST(t 
*testing.T) { constraint := v2.Node().Property("name").Equals("alice") constraintParameter := constraint.(*cypher.Comparison).FirstPartial().Right.(*cypher.Parameter) From 638104615078129933d566aa11ddf467fb188499 Mon Sep 17 00:00:00 2001 From: John Hopper Date: Sat, 9 May 2026 19:27:36 -0700 Subject: [PATCH 34/55] Validate query v2 delete targets --- query/v2/query.go | 55 +++++++++++++++++++++++++++++++++++------- query/v2/query_test.go | 16 ++++++++++++ 2 files changed, 62 insertions(+), 9 deletions(-) diff --git a/query/v2/query.go b/query/v2/query.go index 98d9a07f..e9f9b7f1 100644 --- a/query/v2/query.go +++ b/query/v2/query.go @@ -308,6 +308,12 @@ type scopedExpression interface { roleName() string } +type deleteTarget interface { + QualifiedExpression + + deleteTarget() +} + type EntityContinuation interface { QualifiedExpression @@ -495,6 +501,8 @@ func (s *entity[T]) qualifier() cypher.Expression { return s.identifier } +func (s *entity[T]) deleteTarget() {} + func (s *entity[T]) roleName() string { return s.role } @@ -837,38 +845,54 @@ func (s *builder) Delete(deleteItems ...any) QueryBuilder { for _, nextDelete := range deleteItems { switch typedNextUpdate := nextDelete.(type) { - case QualifiedExpression: - qualifier := typedNextUpdate.qualifier() - if err := validateExpressionValue(qualifier, "delete expression"); err != nil { + case deleteTarget: + if isNilPointer(typedNextUpdate) { + s.trackError(fmt.Errorf("delete target is nil")) + continue + } + + deleteItem, detach, err := deleteItemFromExpression(typedNextUpdate.qualifier(), s.identifiers) + if err != nil { s.trackError(err) continue } - if isDetachDeleteQualifier(qualifier, s.identifiers) { + if detach { s.detachDelete = true pendingDetachDelete = true } - deleteItem := copyExpression(qualifier) s.deleteItems = append(s.deleteItems, deleteItem) pendingDeleteItems = append(pendingDeleteItems, deleteItem) case *cypher.Variable: - if err := validateExpressionValue(typedNextUpdate, "delete 
expression"); err != nil { + deleteItem, detach, err := deleteItemFromExpression(typedNextUpdate, s.identifiers) + if err != nil { s.trackError(err) continue } - switch typedNextUpdate.Symbol { - case s.identifiers.node, s.identifiers.start, s.identifiers.end: + if detach { s.detachDelete = true pendingDetachDelete = true } - deleteItem := copyExpression(typedNextUpdate) s.deleteItems = append(s.deleteItems, deleteItem) pendingDeleteItems = append(pendingDeleteItems, deleteItem) + case QualifiedExpression: + if isNilPointer(typedNextUpdate) { + s.trackError(fmt.Errorf("delete target is nil")) + continue + } + + if err := validateExpressionValue(typedNextUpdate.qualifier(), "delete expression"); err != nil { + s.trackError(err) + continue + } + + s.trackError(fmt.Errorf("delete target must be an entity, path, or variable; got %T", nextDelete)) + default: s.trackError(fmt.Errorf("unknown delete type: %T", nextDelete)) } @@ -878,6 +902,19 @@ func (s *builder) Delete(deleteItems ...any) QueryBuilder { return s } +func deleteItemFromExpression(expression cypher.Expression, identifiers runtimeIdentifiers) (cypher.Expression, bool, error) { + if err := validateExpressionValue(expression, "delete expression"); err != nil { + return nil, false, err + } + + variable, typeOK := expression.(*cypher.Variable) + if !typeOK || variable == nil { + return nil, false, fmt.Errorf("delete target must resolve to a variable, got %T", expression) + } + + return copyExpression(variable), isDetachDeleteQualifier(variable, identifiers), nil +} + func (s *builder) trackError(err error) { s.errors = append(s.errors, err) } diff --git a/query/v2/query_test.go b/query/v2/query_test.go index aa7879ef..5d679bc3 100644 --- a/query/v2/query_test.go +++ b/query/v2/query_test.go @@ -439,6 +439,22 @@ func TestRawUpdatingInputsAreValidated(t *testing.T) { require.ErrorContains(t, err, "relationship pattern is nil") } +func TestDeleteRejectsNonTargetQualifiedExpressions(t *testing.T) { + cases := 
map[string]any{ + "property": v2.Node().Property("name"), + "id": v2.Node().ID(), + "kinds": v2.Node().Kinds(), + "kind": v2.Relationship().Kind(), + } + + for name, target := range cases { + t.Run(name, func(t *testing.T) { + _, err := v2.New().Delete(target).Build() + require.ErrorContains(t, err, "delete target must be an entity, path, or variable") + }) + } +} + func TestInvalidHelperInputsReturnBuildErrors(t *testing.T) { cases := map[string]struct { builder v2.QueryBuilder From a7db069bc24f5dbcab3f584c82d0d6f89d5e33f1 Mon Sep 17 00:00:00 2001 From: John Hopper Date: Sat, 9 May 2026 19:37:58 -0700 Subject: [PATCH 35/55] Preserve query v2 relationship kind conjunctions --- query/v2/query.go | 28 ++++++++++++++++++++++++---- query/v2/query_test.go | 12 ++++++++++++ 2 files changed, 36 insertions(+), 4 deletions(-) diff --git a/query/v2/query.go b/query/v2/query.go index e9f9b7f1..44b5d277 100644 --- a/query/v2/query.go +++ b/query/v2/query.go @@ -1184,6 +1184,22 @@ func (s *builder) buildProjection(singlePartQuery *cypher.SinglePartQuery) error return nil } +func countRelationshipKindMatchers(constraints []cypher.SyntaxNode, identifiers runtimeIdentifiers) (int, error) { + var count int + + for _, nextConstraint := range constraints { + if kindMatcher, typeOK := nextConstraint.(*cypher.KindMatcher); typeOK { + if identifier, typeOK := kindMatcher.Reference.(*cypher.Variable); !typeOK { + return 0, fmt.Errorf("expected type *cypher.Variable, got %T", kindMatcher.Reference) + } else if identifier.Symbol == identifiers.relationship { + count++ + } + } + } + + return count, nil +} + type PreparedQuery struct { Query *cypher.RegularQuery Parameters map[string]any @@ -1232,9 +1248,13 @@ func (s *builder) Build() (*PreparedQuery, error) { if len(s.constraints) > 0 { var ( - whereClause = match.NewWhere() - constraints = &cypher.Comparison{} + whereClause = match.NewWhere() + constraints = &cypher.Comparison{} + numRelationshipKindMatchers, err = 
countRelationshipKindMatchers(s.constraints, s.identifiers) ) + if err != nil { + return nil, err + } for _, nextConstraint := range s.constraints { if err := collectModelErrorsFromKnownValues(nextConstraint); err != nil { @@ -1244,8 +1264,8 @@ func (s *builder) Build() (*PreparedQuery, error) { switch typedNextConstraint := nextConstraint.(type) { case *cypher.KindMatcher: if identifier, typeOK := typedNextConstraint.Reference.(*cypher.Variable); !typeOK { - return nil, fmt.Errorf("expected type *cypher.Variable, got %T", typedNextConstraint) - } else if identifier.Symbol == s.identifiers.relationship { + return nil, fmt.Errorf("expected type *cypher.Variable, got %T", typedNextConstraint.Reference) + } else if identifier.Symbol == s.identifiers.relationship && numRelationshipKindMatchers == 1 { relationshipKinds = relationshipKinds.Add(typedNextConstraint.Kinds...) readIdentifiers.Add(s.identifiers.relationship) continue diff --git a/query/v2/query_test.go b/query/v2/query_test.go index 5d679bc3..1a0f711e 100644 --- a/query/v2/query_test.go +++ b/query/v2/query_test.go @@ -256,6 +256,18 @@ func TestMixedNodeAndRelationshipIdentifiersReturnError(t *testing.T) { require.ErrorContains(t, err, "query mixes node and relationship query identifiers") } +func TestMultipleRelationshipKindMatchersRemainConjunctive(t *testing.T) { + preparedQuery, err := v2.New().Where( + v2.Relationship().Kind().Is(graph.StringKind("A")), + v2.Relationship().Kind().Is(graph.StringKind("B")), + ).Return( + v2.Relationship(), + ).Build() + require.NoError(t, err) + + require.Equal(t, "match ()-[r]->() where r:A and r:B return r", renderPrepared(t, preparedQuery)) +} + func TestEmptyLogicalHelpersReturnBuildErrors(t *testing.T) { _, err := v2.New().Where(v2.And()).Return(v2.Node()).Build() require.ErrorContains(t, err, "and requires at least one operand") From 72e33ee9e2f1fd6484d9fdcb6d4f572f5de9c138 Mon Sep 17 00:00:00 2001 From: John Hopper Date: Sat, 9 May 2026 19:38:24 -0700 Subject: 
[PATCH 36/55] Infer query v2 raw return metadata identifiers --- query/v2/query_test.go | 16 ++++++++++++++++ query/v2/util.go | 30 ++++++++++++++++++++++++++++++ 2 files changed, 46 insertions(+) diff --git a/query/v2/query_test.go b/query/v2/query_test.go index 1a0f711e..a1fb6b69 100644 --- a/query/v2/query_test.go +++ b/query/v2/query_test.go @@ -397,6 +397,22 @@ func TestRawReturnInputPreservesProjectionMetadata(t *testing.T) { require.Equal(t, "match (n) return distinct id(n) order by n.name desc skip 5 limit 10", renderPrepared(t, preparedQuery)) } +func TestRawReturnMetadataCreatesMatchInference(t *testing.T) { + returnClause := cypher.NewReturn() + projection := returnClause.NewProjection(false) + projection.Items = append(projection.Items, v2.Literal(1)) + projection.Order = &cypher.Order{ + Items: []*cypher.SortItem{ + v2.Desc(v2.Node().Property("name")), + }, + } + + preparedQuery, err := v2.New().Return(returnClause).Build() + require.NoError(t, err) + + require.Equal(t, "match (n) return 1 order by n.name desc", renderPrepared(t, preparedQuery)) +} + func TestRawReturnInputMergesWithBuilderProjectionControls(t *testing.T) { returnClause := cypher.NewReturn() projection := returnClause.NewProjection(true) diff --git a/query/v2/util.go b/query/v2/util.go index 444f3a99..fb4c360c 100644 --- a/query/v2/util.go +++ b/query/v2/util.go @@ -665,6 +665,10 @@ func (s *identifierSet) CollectFromValue(value any) error { } } + if err := s.CollectFromProjectionMetadata(typedValue.Projection); err != nil { + return err + } + case *cypher.Order: if sortItems, err := sortItemsFromOrder(typedValue); err != nil { return err @@ -786,6 +790,32 @@ func (s *identifierSet) CollectFromValue(value any) error { return nil } +func (s *identifierSet) CollectFromProjectionMetadata(projection *cypher.Projection) error { + if projection == nil { + return nil + } + + if projection.Order != nil { + if err := s.CollectFromValue(projection.Order); err != nil { + return err + } + } + + if 
projection.Skip != nil { + if err := s.CollectFromExpression(projection.Skip.Value); err != nil { + return err + } + } + + if projection.Limit != nil { + if err := s.CollectFromExpression(projection.Limit.Value); err != nil { + return err + } + } + + return nil +} + func collectIdentifiersFromValues(values ...any) (*identifierSet, error) { identifiers := newIdentifierSet() From 5aabfb647e286120e0cfa12e0019b31183d910c8 Mon Sep 17 00:00:00 2001 From: John Hopper Date: Sat, 9 May 2026 19:38:44 -0700 Subject: [PATCH 37/55] Reject query v2 path delete targets --- query/v2/query.go | 6 +++++- query/v2/query_test.go | 10 +++++++++- 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/query/v2/query.go b/query/v2/query.go index 44b5d277..5438f91e 100644 --- a/query/v2/query.go +++ b/query/v2/query.go @@ -891,7 +891,7 @@ func (s *builder) Delete(deleteItems ...any) QueryBuilder { continue } - s.trackError(fmt.Errorf("delete target must be an entity, path, or variable; got %T", nextDelete)) + s.trackError(fmt.Errorf("delete target must be a node, relationship, or variable; got %T", nextDelete)) default: s.trackError(fmt.Errorf("unknown delete type: %T", nextDelete)) @@ -912,6 +912,10 @@ func deleteItemFromExpression(expression cypher.Expression, identifiers runtimeI return nil, false, fmt.Errorf("delete target must resolve to a variable, got %T", expression) } + if variable.Symbol == identifiers.path { + return nil, false, fmt.Errorf("delete target must be a node or relationship variable, got path variable %q", variable.Symbol) + } + return copyExpression(variable), isDetachDeleteQualifier(variable, identifiers), nil } diff --git a/query/v2/query_test.go b/query/v2/query_test.go index a1fb6b69..20169362 100644 --- a/query/v2/query_test.go +++ b/query/v2/query_test.go @@ -478,11 +478,19 @@ func TestDeleteRejectsNonTargetQualifiedExpressions(t *testing.T) { for name, target := range cases { t.Run(name, func(t *testing.T) { _, err := v2.New().Delete(target).Build() - 
require.ErrorContains(t, err, "delete target must be an entity, path, or variable") + require.ErrorContains(t, err, "delete target must be a node, relationship, or variable") }) } } +func TestDeleteRejectsPathTargets(t *testing.T) { + _, err := v2.New().Delete(v2.Path()).Build() + require.ErrorContains(t, err, `delete target must be a node or relationship variable, got path variable "p"`) + + _, err = v2.New().Delete(v2.Variable("p")).Build() + require.ErrorContains(t, err, `delete target must be a node or relationship variable, got path variable "p"`) +} + func TestInvalidHelperInputsReturnBuildErrors(t *testing.T) { cases := map[string]struct { builder v2.QueryBuilder From 994a1fc8ad4219c98724cd3711ac16e2d18d32e6 Mon Sep 17 00:00:00 2001 From: John Hopper Date: Sat, 9 May 2026 19:46:49 -0700 Subject: [PATCH 38/55] Address query v2 review feedback --- cypher/models/pgsql/test/query_test.go | 4 ++-- query/v2/query.go | 8 ++++++++ query/v2/query_test.go | 16 ++++++++++++---- 3 files changed, 22 insertions(+), 6 deletions(-) diff --git a/cypher/models/pgsql/test/query_test.go b/cypher/models/pgsql/test/query_test.go index abb2d388..ac39ca9f 100644 --- a/cypher/models/pgsql/test/query_test.go +++ b/cypher/models/pgsql/test/query_test.go @@ -31,12 +31,12 @@ func TestQuery_KindGeneratesInclusiveKindMatcher(t *testing.T) { for _, queryBuilder := range queries { builtQuery, err := queryBuilder.Build() if err != nil { - t.Errorf("could not build query: %v", err) + t.Fatalf("could not build query: %v", err) } translatedQuery, err := translate.Translate(context.Background(), builtQuery.Query, mapper, builtQuery.Parameters, translate.DefaultGraphID) if err != nil { - t.Errorf("could not translate query: %#v: %v", builtQuery, err) + t.Fatalf("could not translate query: %#v: %v", builtQuery, err) } walk.PgSQL(translatedQuery.Statement, walk.NewSimpleVisitor(func(node pgsql.SyntaxNode, visitorHandler walk.VisitorHandler) { diff --git a/query/v2/query.go b/query/v2/query.go index 
5438f91e..ea5b7628 100644 --- a/query/v2/query.go +++ b/query/v2/query.go @@ -880,6 +880,14 @@ func (s *builder) Delete(deleteItems ...any) QueryBuilder { s.deleteItems = append(s.deleteItems, deleteItem) pendingDeleteItems = append(pendingDeleteItems, deleteItem) + case *cypher.PropertyLookup: + if err := validateExpressionValue(typedNextUpdate, "delete expression"); err != nil { + s.trackError(err) + continue + } + + s.trackError(fmt.Errorf("delete target must be a node, relationship, or variable; use remove for properties")) + case QualifiedExpression: if isNilPointer(typedNextUpdate) { s.trackError(fmt.Errorf("delete target is nil")) diff --git a/query/v2/query_test.go b/query/v2/query_test.go index 20169362..edd8cc06 100644 --- a/query/v2/query_test.go +++ b/query/v2/query_test.go @@ -329,6 +329,13 @@ func TestAliasedProjectionCreatesMatchInference(t *testing.T) { require.NoError(t, err) require.Equal(t, "match (n) return id(n) as node_id", renderPrepared(t, preparedQuery)) + + preparedQuery, err = v2.New().Return( + v2.As(v2.Node(), "alias"), + ).Build() + require.NoError(t, err) + + require.Equal(t, "match (n) return n as alias", renderPrepared(t, preparedQuery)) } func TestInvalidProjectionAliasReturnsBuildError(t *testing.T) { @@ -469,10 +476,11 @@ func TestRawUpdatingInputsAreValidated(t *testing.T) { func TestDeleteRejectsNonTargetQualifiedExpressions(t *testing.T) { cases := map[string]any{ - "property": v2.Node().Property("name"), - "id": v2.Node().ID(), - "kinds": v2.Node().Kinds(), - "kind": v2.Relationship().Kind(), + "property continuation": v2.Node().Property("name"), + "raw property lookup": cypher.NewPropertyLookup("n", "name"), + "id": v2.Node().ID(), + "kinds": v2.Node().Kinds(), + "kind": v2.Relationship().Kind(), } for name, target := range cases { From ad2fe5552af8d9452c11873d90d1dcaf9bcbe990 Mon Sep 17 00:00:00 2001 From: John Hopper Date: Sat, 9 May 2026 19:55:33 -0700 Subject: [PATCH 39/55] Reject unbound query v2 identifiers --- 
query/v2/query.go | 4 ++ query/v2/query_test.go | 51 ++++++++++++++++++++++++ query/v2/util.go | 88 +++++++++++++++++++++++++++++++++++++++++- 3 files changed, 141 insertions(+), 2 deletions(-) diff --git a/query/v2/query.go b/query/v2/query.go index ea5b7628..858132ce 100644 --- a/query/v2/query.go +++ b/query/v2/query.go @@ -1311,6 +1311,10 @@ func (s *builder) Build() (*PreparedQuery, error) { matchIdentifiers := readIdentifiers.Clone() matchIdentifiers.Or(actionIdentifiers) + if err := validateKnownIdentifiers(matchIdentifiers, s.identifiers); err != nil { + return nil, err + } + if s.wantsShortestPathPattern() && !isRelationshipPattern(matchIdentifiers, s.identifiers) { return nil, fmt.Errorf("shortest path query requires relationship query identifiers") } diff --git a/query/v2/query_test.go b/query/v2/query_test.go index edd8cc06..3ef4e7b4 100644 --- a/query/v2/query_test.go +++ b/query/v2/query_test.go @@ -256,6 +256,57 @@ func TestMixedNodeAndRelationshipIdentifiersReturnError(t *testing.T) { require.ErrorContains(t, err, "query mixes node and relationship query identifiers") } +func TestRawIdentifiersMustBeKnownToScope(t *testing.T) { + cases := map[string]v2.QueryBuilder{ + "delete": v2.New().Where( + v2.Node().ID().Equals(1), + ).Delete( + v2.Variable("x"), + ), + "projection": v2.New().Return( + v2.Node(), + v2.Variable("x"), + ), + "sort": v2.New().Return( + v2.Node(), + ).OrderBy( + v2.Asc(v2.Variable("x")), + ), + } + + for name, builder := range cases { + t.Run(name, func(t *testing.T) { + _, err := builder.Build() + require.ErrorContains(t, err, `query contains unknown identifier "x"`) + }) + } +} + +func TestPathIdentifierRequiresShortestPathMatch(t *testing.T) { + _, err := v2.New().Return( + v2.Node(), + v2.Path(), + ).Build() + require.ErrorContains(t, err, `query contains unbound identifier "p"`) + + _, err = v2.New().Return( + v2.Relationship(), + v2.Path(), + ).Build() + require.ErrorContains(t, err, `query contains unbound identifier "p"`) 
+} + +func TestCreatedRawIdentifiersDoNotRequireMatch(t *testing.T) { + preparedQuery, err := v2.New().Create(&cypher.NodePattern{ + Variable: v2.Variable("created"), + }).Return( + v2.Variable("created"), + ).Build() + require.NoError(t, err) + + require.Equal(t, "create (created) return created", renderPrepared(t, preparedQuery)) +} + func TestMultipleRelationshipKindMatchersRemainConjunctive(t *testing.T) { preparedQuery, err := v2.New().Where( v2.Relationship().Kind().Is(graph.StringKind("A")), diff --git a/query/v2/util.go b/query/v2/util.go index fb4c360c..ac51c63f 100644 --- a/query/v2/util.go +++ b/query/v2/util.go @@ -27,11 +27,63 @@ func isRelationshipPattern(seen *identifierSet, identifiers runtimeIdentifiers) return hasStart || hasRelationship || hasEnd } +func runtimeIdentifierSet(identifiers runtimeIdentifiers) *identifierSet { + return newIdentifierSet( + identifiers.path, + identifiers.node, + identifiers.start, + identifiers.relationship, + identifiers.end, + ) +} + +func nodePatternIdentifierSet(identifiers runtimeIdentifiers) *identifierSet { + return newIdentifierSet(identifiers.node) +} + +func relationshipPatternIdentifierSet(identifiers runtimeIdentifiers, includePath bool) *identifierSet { + allowedIdentifiers := newIdentifierSet( + identifiers.start, + identifiers.relationship, + identifiers.end, + ) + + if includePath { + allowedIdentifiers.Add(identifiers.path) + } + + return allowedIdentifiers +} + +func createRelationshipMatchIdentifierSet(identifiers runtimeIdentifiers) *identifierSet { + return newIdentifierSet(identifiers.start, identifiers.end) +} + +func validateKnownIdentifiers(seen *identifierSet, identifiers runtimeIdentifiers) error { + if identifier, hasIdentifier := seen.FirstOutside(runtimeIdentifierSet(identifiers)); hasIdentifier { + return fmt.Errorf("query contains unknown identifier %q", identifier) + } + + return nil +} + +func validateBoundIdentifiers(seen, bound *identifierSet) error { + if identifier, hasIdentifier 
:= seen.FirstOutside(bound); hasIdentifier { + return fmt.Errorf("query contains unbound identifier %q", identifier) + } + + return nil +} + func prepareNodePattern(match *cypher.Match, seen *identifierSet, identifiers runtimeIdentifiers) error { if isRelationshipPattern(seen, identifiers) { return fmt.Errorf("query mixes node and relationship query identifiers") } + if err := validateBoundIdentifiers(seen, nodePatternIdentifierSet(identifiers)); err != nil { + return err + } + match.NewPatternPart().AddPatternElements(&cypher.NodePattern{ Variable: identifiers.Node(), }) @@ -57,6 +109,10 @@ func prepareRelationshipPattern(match *cypher.Match, seen *identifierSet, identi return err } + if err := validateBoundIdentifiers(seen, relationshipPatternIdentifierSet(identifiers, shortestPaths || allShortestPaths)); err != nil { + return err + } + var ( newPatternPart = match.NewPatternPart() startNodeSeen = seen.Contains(identifiers.start) @@ -103,6 +159,10 @@ func prepareRelationshipPattern(match *cypher.Match, seen *identifierSet, identi } func prepareCreateRelationshipMatch(match *cypher.Match, seen *identifierSet, identifiers runtimeIdentifiers) error { + if err := validateBoundIdentifiers(seen, createRelationshipMatchIdentifierSet(identifiers)); err != nil { + return err + } + if seen.Contains(identifiers.start) { match.NewPatternPart().AddPatternElements(&cypher.NodePattern{ Variable: identifiers.Start(), @@ -593,10 +653,16 @@ type identifierSet struct { identifiers map[string]struct{} } -func newIdentifierSet() *identifierSet { - return &identifierSet{ +func newIdentifierSet(identifiers ...string) *identifierSet { + set := &identifierSet{ identifiers: map[string]struct{}{}, } + + for _, identifier := range identifiers { + set.Add(identifier) + } + + return set } func (s *identifierSet) Add(identifier string) { @@ -630,6 +696,24 @@ func (s *identifierSet) Contains(identifier string) bool { return containsIdentifier } +func (s *identifierSet) FirstOutside(allowed 
*identifierSet) (string, bool) { + var identifiers []string + + for identifier := range s.identifiers { + if !allowed.Contains(identifier) { + identifiers = append(identifiers, identifier) + } + } + + sort.Strings(identifiers) + + if len(identifiers) == 0 { + return "", false + } + + return identifiers[0], true +} + func (s *identifierSet) CollectFromExpression(expr cypher.Expression) error { if exprIdentifiers, err := extractCypherIdentifiers(expr); err != nil { return err From 292af2db3c54d4776b1eb9c1373c5ff295b4c42a Mon Sep 17 00:00:00 2001 From: John Hopper Date: Sat, 9 May 2026 20:22:10 -0700 Subject: [PATCH 40/55] Add query v2 recursive traversal depth --- query/v2/backend_test.go | 70 ++++++++++++++++++++++++++ query/v2/query.go | 103 +++++++++++++++++++++++++++++++++++++-- query/v2/query_test.go | 100 +++++++++++++++++++++++++++++++++++++ query/v2/util.go | 12 +++-- 4 files changed, 279 insertions(+), 6 deletions(-) diff --git a/query/v2/backend_test.go b/query/v2/backend_test.go index 6c454694..03ce4787 100644 --- a/query/v2/backend_test.go +++ b/query/v2/backend_test.go @@ -75,6 +75,17 @@ func TestBackendParityNeo4jPrepare(t *testing.T) { expectedCypher: "match p = allShortestPaths((s)-[r:MemberOf*]->(e)) where id(s) = $p0 and id(e) = $p1 return p", expectedParams: map[string]any{"p0": 1, "p1": 2}, }, + "recursive traversal": { + builder: v2.New().WithTraversalDepth(v2.DepthRange(1, 2)).Where( + v2.Relationship().Kind().Is(graph.StringKind("MemberOf")), + v2.Start().ID().Equals(1), + ).Return( + v2.Path(), + v2.End().ID(), + ), + expectedCypher: "match p = (s)-[r:MemberOf*1..2]->(e) where id(s) = $p0 return p, id(e)", + expectedParams: map[string]any{"p0": 1}, + }, "create node": { builder: v2.New().Create( v2.NodePattern(graph.Kinds{graph.StringKind("User")}, v2.NamedParameter("props", map[string]any{"name": "u"})), @@ -129,6 +140,65 @@ func TestBackendParityNeo4jPrepare(t *testing.T) { } } +func TestBackendParityPGTranslateTraversalDepth(t *testing.T) { 
+ edgeKind := graph.StringKind("MemberOf") + mapper := testKindMapper(edgeKind) + + cases := map[string]struct { + builder v2.QueryBuilder + expectedSQLContains []string + }{ + "path": { + builder: v2.New().WithTraversalDepth(v2.DepthRange(1, 2)).Where( + v2.Relationship().Kind().Is(edgeKind), + v2.Start().ID().Equals(1), + ).Return( + v2.Path(), + ), + expectedSQLContains: []string{ + "with recursive", + "ordered_edges_to_path", + "n0.id = @pi0::int8", + "e0.kind_id = any (array [1]::int2[])", + "depth < 2", + }, + }, + "endpoints": { + builder: v2.New().WithTraversalDepth(v2.DepthRange(1, 2)).Where( + v2.Relationship().Kind().Is(edgeKind), + v2.Start().ID().Equals(1), + ).Return( + v2.Start().ID(), + v2.End().ID(), + ), + expectedSQLContains: []string{ + "with recursive", + "n0.id = @pi0::int8", + "e0.kind_id = any (array [1]::int2[])", + "depth < 2", + "select (s0.n0).id, (s0.n1).id from s0", + }, + }, + } + + for name, testCase := range cases { + t.Run(name, func(t *testing.T) { + preparedQuery, err := testCase.builder.Build() + require.NoError(t, err) + + translation, err := translate.Translate(context.Background(), preparedQuery.Query, mapper, preparedQuery.Parameters, translate.DefaultGraphID) + require.NoError(t, err) + + sql, err := translate.Translated(translation) + require.NoError(t, err) + + for _, expected := range testCase.expectedSQLContains { + require.Contains(t, sql, expected) + } + }) + } +} + func TestBackendParityPGTranslate(t *testing.T) { userKind := graph.StringKind("User") edgeKind := graph.StringKind("MemberOf") diff --git a/query/v2/query.go b/query/v2/query.go index 858132ce..880205a1 100644 --- a/query/v2/query.go +++ b/query/v2/query.go @@ -16,6 +16,68 @@ type runtimeIdentifiers struct { end string } +type TraversalDepth struct { + patternRange *cypher.PatternRange + err error +} + +func traversalDepthBound(value int64) *int64 { + return &value +} + +func newTraversalDepth(start, end *int64) TraversalDepth { + if start != nil && 
*start < 0 { + return TraversalDepth{ + err: fmt.Errorf("traversal depth minimum must be non-negative: %d", *start), + } + } + + if end != nil && *end < 0 { + return TraversalDepth{ + err: fmt.Errorf("traversal depth maximum must be non-negative: %d", *end), + } + } + + if start != nil && end != nil && *end < *start { + return TraversalDepth{ + err: fmt.Errorf("traversal depth maximum %d is less than minimum %d", *end, *start), + } + } + + return TraversalDepth{ + patternRange: cypher.NewPatternRange(start, end), + } +} + +func (s TraversalDepth) rangePattern() *cypher.PatternRange { + if s.patternRange == nil { + return &cypher.PatternRange{} + } + + return cypher.Copy(s.patternRange) +} + +func AnyDepth() TraversalDepth { + return newTraversalDepth(nil, nil) +} + +func MinDepth(min int64) TraversalDepth { + return newTraversalDepth(traversalDepthBound(min), nil) +} + +func MaxDepth(max int64) TraversalDepth { + return newTraversalDepth(nil, traversalDepthBound(max)) +} + +func DepthRange(min, max int64) TraversalDepth { + return newTraversalDepth(traversalDepthBound(min), traversalDepthBound(max)) +} + +func ExactDepth(depth int64) TraversalDepth { + depthBound := traversalDepthBound(depth) + return newTraversalDepth(depthBound, depthBound) +} + func (s runtimeIdentifiers) Path() *cypher.Variable { return cypher.NewVariableWithSymbol(s.path) } @@ -623,6 +685,7 @@ type QueryBuilder interface { Delete(expressions ...any) QueryBuilder WithShortestPaths() QueryBuilder WithAllShortestPaths() QueryBuilder + WithTraversalDepth(depth TraversalDepth) QueryBuilder WithRelationshipDirection(direction graph.Direction) QueryBuilder Build() (*PreparedQuery, error) } @@ -659,6 +722,7 @@ type builder struct { deleteItems []cypher.Expression detachDelete bool relationshipDirection graph.Direction + traversalDepth *cypher.PatternRange shortestPathQuery bool allShorestPathsQuery bool skip *int @@ -687,6 +751,16 @@ func (s *builder) WithAllShortestPaths() QueryBuilder { return s } 
+func (s *builder) WithTraversalDepth(depth TraversalDepth) QueryBuilder { + if depth.err != nil { + s.trackError(depth.err) + } else { + s.traversalDepth = depth.rangePattern() + } + + return s +} + func (s *builder) WithRelationshipDirection(direction graph.Direction) QueryBuilder { if err := validateRelationshipDirection(direction); err != nil { s.trackError(err) @@ -1225,6 +1299,14 @@ func (s *builder) wantsShortestPathPattern() bool { return s.shortestPathQuery || s.allShorestPathsQuery } +func (s *builder) wantsTraversalPattern() bool { + return s.traversalDepth != nil +} + +func (s *builder) usesRangedRelationshipPattern() bool { + return s.wantsTraversalPattern() || s.wantsShortestPathPattern() +} + func (s *builder) Build() (*PreparedQuery, error) { if len(s.errors) > 0 { return nil, errors.Join(s.errors...) @@ -1295,9 +1377,16 @@ func (s *builder) Build() (*PreparedQuery, error) { if constraints.Left != nil { whereClause.Add(constraints) - if err := readIdentifiers.CollectFromExpression(whereClause); err != nil { + whereIdentifiers := newIdentifierSet() + if err := whereIdentifiers.CollectFromExpression(whereClause); err != nil { return nil, err } + + if s.usesRangedRelationshipPattern() && whereIdentifiers.Contains(s.identifiers.relationship) { + return nil, fmt.Errorf("ranged relationship patterns only support top-level relationship kind constraints") + } + + readIdentifiers.Or(whereIdentifiers) } } @@ -1308,6 +1397,10 @@ func (s *builder) Build() (*PreparedQuery, error) { actionIdentifiers.Remove(createScope.identifiers) + if s.usesRangedRelationshipPattern() && actionIdentifiers.Contains(s.identifiers.relationship) { + return nil, fmt.Errorf("ranged relationship patterns do not support relationship projections or mutations; return the path instead") + } + matchIdentifiers := readIdentifiers.Clone() matchIdentifiers.Or(actionIdentifiers) @@ -1315,6 +1408,10 @@ func (s *builder) Build() (*PreparedQuery, error) { return nil, err } + if 
s.wantsTraversalPattern() && !isRelationshipPattern(matchIdentifiers, s.identifiers) && !matchIdentifiers.Contains(s.identifiers.path) { + return nil, fmt.Errorf("recursive traversal query requires relationship query identifiers") + } + if s.wantsShortestPathPattern() && !isRelationshipPattern(matchIdentifiers, s.identifiers) { return nil, fmt.Errorf("shortest path query requires relationship query identifiers") } @@ -1328,8 +1425,8 @@ func (s *builder) Build() (*PreparedQuery, error) { if err := prepareCreateRelationshipMatch(match, matchIdentifiers, s.identifiers); err != nil { return nil, err } - } else if isRelationshipPattern(matchIdentifiers, s.identifiers) { - if err := prepareRelationshipPattern(match, matchIdentifiers, s.identifiers, relationshipKinds, s.relationshipDirection, s.shortestPathQuery, s.allShorestPathsQuery); err != nil { + } else if isRelationshipPattern(matchIdentifiers, s.identifiers) || (s.wantsTraversalPattern() && matchIdentifiers.Contains(s.identifiers.path)) { + if err := prepareRelationshipPattern(match, matchIdentifiers, s.identifiers, relationshipKinds, s.traversalDepth, s.relationshipDirection, s.shortestPathQuery, s.allShorestPathsQuery); err != nil { return nil, err } } else { diff --git a/query/v2/query_test.go b/query/v2/query_test.go index 3ef4e7b4..7e31e1a4 100644 --- a/query/v2/query_test.go +++ b/query/v2/query_test.go @@ -231,6 +231,16 @@ func TestShortestPathControls(t *testing.T) { require.NoError(t, err) require.Equal(t, "match p = allShortestPaths((s)-[r:MemberOf*]->(e)) where id(s) = $p0 and id(e) = $p1 return p", renderPrepared(t, preparedQuery)) + preparedQuery, err = v2.New().WithShortestPaths().WithTraversalDepth(v2.MinDepth(1)).Where( + v2.Relationship().Kind().Is(graph.StringKind("MemberOf")), + v2.Start().ID().Equals(1), + v2.End().ID().Equals(2), + ).Return( + v2.Path(), + ).Build() + require.NoError(t, err) + require.Equal(t, "match p = shortestPath((s)-[r:MemberOf*1..]->(e)) where id(s) = $p0 and id(e) = $p1 
return p", renderPrepared(t, preparedQuery)) + _, err = v2.New().WithShortestPaths().WithAllShortestPaths().Where( v2.Start().ID().Equals(1), v2.End().ID().Equals(2), @@ -246,6 +256,96 @@ func TestShortestPathControls(t *testing.T) { require.ErrorContains(t, err, "shortest path query requires relationship query identifiers") } +func TestTraversalDepthControls(t *testing.T) { + cases := map[string]struct { + builder v2.QueryBuilder + expectedCypher string + expectedParams map[string]any + }{ + "any depth": { + builder: v2.New().WithTraversalDepth(v2.AnyDepth()).Return(v2.End()), + expectedCypher: "match ()-[*]->(e) return e", + }, + "minimum depth": { + builder: v2.New().WithTraversalDepth(v2.MinDepth(1)).Return(v2.End()), + expectedCypher: "match ()-[*1..]->(e) return e", + }, + "maximum depth": { + builder: v2.New().WithTraversalDepth(v2.MaxDepth(5)).Return(v2.End()), + expectedCypher: "match ()-[*..5]->(e) return e", + }, + "depth range": { + builder: v2.New().WithTraversalDepth(v2.DepthRange(1, 5)).Where( + v2.Relationship().Kind().IsOneOf(graph.Kinds{graph.StringKind("KindA"), graph.StringKind("KindB")}), + v2.Start().ID().Equals(1), + v2.End().Kinds().Has(graph.StringKind("User")), + ).Return( + v2.Path(), + v2.End(), + ), + expectedCypher: "match p = (s)-[r:KindA|KindB*1..5]->(e) where id(s) = $p0 and e:User return p, e", + expectedParams: map[string]any{"p0": 1}, + }, + "exact depth": { + builder: v2.New().WithTraversalDepth(v2.ExactDepth(3)).Return(v2.End()), + expectedCypher: "match ()-[*3..3]->(e) return e", + }, + "inbound depth range": { + builder: v2.New().WithTraversalDepth(v2.DepthRange(2, 5)).WithRelationshipDirection(graph.DirectionInbound).Return(v2.Start()), + expectedCypher: "match (s)<-[*2..5]-() return s", + }, + "path only": { + builder: v2.New().WithTraversalDepth(v2.AnyDepth()).Return(v2.Path()), + expectedCypher: "match p = ()-[*]->() return p", + }, + } + + for name, testCase := range cases { + t.Run(name, func(t *testing.T) { + 
preparedQuery, err := testCase.builder.Build() + require.NoError(t, err) + require.Equal(t, testCase.expectedCypher, renderPrepared(t, preparedQuery)) + + if testCase.expectedParams == nil { + require.Empty(t, preparedQuery.Parameters) + } else { + require.Equal(t, testCase.expectedParams, preparedQuery.Parameters) + } + }) + } +} + +func TestInvalidTraversalDepthControls(t *testing.T) { + _, err := v2.New().WithTraversalDepth(v2.MinDepth(-1)).Return(v2.End()).Build() + require.ErrorContains(t, err, "traversal depth minimum must be non-negative: -1") + + _, err = v2.New().WithTraversalDepth(v2.MaxDepth(-1)).Return(v2.End()).Build() + require.ErrorContains(t, err, "traversal depth maximum must be non-negative: -1") + + _, err = v2.New().WithTraversalDepth(v2.DepthRange(3, 1)).Return(v2.End()).Build() + require.ErrorContains(t, err, "traversal depth maximum 1 is less than minimum 3") + + _, err = v2.New().WithTraversalDepth(v2.AnyDepth()).Return(v2.Node()).Build() + require.ErrorContains(t, err, "recursive traversal query requires relationship query identifiers") + + _, err = v2.New().WithTraversalDepth(v2.AnyDepth()).Where( + v2.Relationship().Property("enabled").Equals(true), + ).Return( + v2.End(), + ).Build() + require.ErrorContains(t, err, "ranged relationship patterns only support top-level relationship kind constraints") + + _, err = v2.New().WithTraversalDepth(v2.AnyDepth()).Return(v2.Relationship()).Build() + require.ErrorContains(t, err, "ranged relationship patterns do not support relationship projections or mutations") + + _, err = v2.New().WithTraversalDepth(v2.AnyDepth()).Where( + v2.Start().ID().Equals(1), + ).Delete( + v2.Relationship(), + ).Build() + require.ErrorContains(t, err, "ranged relationship patterns do not support relationship projections or mutations") +} + func TestMixedNodeAndRelationshipIdentifiersReturnError(t *testing.T) { _, err := v2.New().Where( v2.Node().ID().Equals(1), diff --git a/query/v2/util.go b/query/v2/util.go index 
ac51c63f..a477e46b 100644 --- a/query/v2/util.go +++ b/query/v2/util.go @@ -100,7 +100,7 @@ func validateRelationshipDirection(direction graph.Direction) error { } } -func prepareRelationshipPattern(match *cypher.Match, seen *identifierSet, identifiers runtimeIdentifiers, relationshipKinds graph.Kinds, direction graph.Direction, shortestPaths, allShortestPaths bool) error { +func prepareRelationshipPattern(match *cypher.Match, seen *identifierSet, identifiers runtimeIdentifiers, relationshipKinds graph.Kinds, relationshipRange *cypher.PatternRange, direction graph.Direction, shortestPaths, allShortestPaths bool) error { if shortestPaths && allShortestPaths { return errors.New("query is requesting both all shortest paths and shortest paths") } @@ -109,7 +109,8 @@ func prepareRelationshipPattern(match *cypher.Match, seen *identifierSet, identi return err } - if err := validateBoundIdentifiers(seen, relationshipPatternIdentifierSet(identifiers, shortestPaths || allShortestPaths)); err != nil { + hasRangedRelationshipPattern := relationshipRange != nil || shortestPaths || allShortestPaths + if err := validateBoundIdentifiers(seen, relationshipPatternIdentifierSet(identifiers, hasRangedRelationshipPattern)); err != nil { return err } @@ -140,8 +141,13 @@ func prepareRelationshipPattern(match *cypher.Match, seen *identifierSet, identi relationshipPattern.Variable = identifiers.Relationship() } - if shortestPaths || allShortestPaths { + if shortestPaths || allShortestPaths || (relationshipRange != nil && seen.Contains(identifiers.path)) { newPatternPart.Variable = identifiers.Path() + } + + if relationshipRange != nil { + relationshipPattern.Range = cypher.Copy(relationshipRange) + } else if shortestPaths || allShortestPaths { relationshipPattern.Range = &cypher.PatternRange{} } From dcebd3a28703087c7cca154811c19a8c0baf75f6 Mon Sep 17 00:00:00 2001 From: John Hopper Date: Sat, 9 May 2026 21:15:38 -0700 Subject: [PATCH 41/55] test: cover v2 logical precedence --- 
query/v2/legacy_parity_test.go | 29 ++++++++++++++++++++ query/v2/query_test.go | 50 ++++++++++++++++++++++++++++++++++ 2 files changed, 79 insertions(+) diff --git a/query/v2/legacy_parity_test.go b/query/v2/legacy_parity_test.go index 0a6f35bd..e8fe1bee 100644 --- a/query/v2/legacy_parity_test.go +++ b/query/v2/legacy_parity_test.go @@ -94,6 +94,35 @@ func TestLegacyNeo4jParity(t *testing.T) { ) }) + t.Run("node read with or and adjacent predicate", func(t *testing.T) { + assertLegacyNeo4jParity(t, + legacyquery.SinglePartQuery( + legacyquery.Where( + legacyquery.And( + legacyquery.Or( + legacyquery.Equals(legacyquery.NodeProperty("name"), "alice"), + legacyquery.Equals(legacyquery.NodeProperty("name"), "bob"), + ), + legacyquery.IsNotNull(legacyquery.NodeProperty("enabled")), + ), + ), + legacyquery.Returning( + legacyquery.Node(), + ), + ), + v2.New().Where( + v2.Or( + v2.Node().Property("name").Equals("alice"), + v2.Node().Property("name").Equals("bob"), + ), + v2.Node().Property("enabled").IsNotNull(), + ).Return( + v2.Node(), + ), + false, + ) + }) + t.Run("relationship read", func(t *testing.T) { assertLegacyNeo4jParity(t, legacyquery.SinglePartQuery( diff --git a/query/v2/query_test.go b/query/v2/query_test.go index 7e31e1a4..47b9fd05 100644 --- a/query/v2/query_test.go +++ b/query/v2/query_test.go @@ -116,6 +116,56 @@ func TestCreateNodeReturnDoesNotCreateMatch(t *testing.T) { }, preparedQuery.Parameters) } +func TestLogicalHelpersPreservePrecedence(t *testing.T) { + a := v2.Node().Property("a").Equals("a") + b := v2.Node().Property("b").Equals("b") + c := v2.Node().Property("c").Equals("c") + + testCases := []struct { + name string + builder v2.QueryBuilder + expected string + }{ + { + name: "or is parenthesized in isolation", + builder: v2.New().Where( + v2.Or(a, b), + ).Return(v2.Node()), + expected: "match (n) where (n.a = $p0 or n.b = $p1) return n", + }, + { + name: "or is parenthesized when where and-chains constraints", + builder: v2.New().Where( 
+ v2.Or(a, b), + c, + ).Return(v2.Node()), + expected: "match (n) where (n.a = $p0 or n.b = $p1) and n.c = $p2 return n", + }, + { + name: "nested or is parenthesized inside and", + builder: v2.New().Where( + v2.And(a, v2.Or(b, c)), + ).Return(v2.Node()), + expected: "match (n) where n.a = $p0 and (n.b = $p1 or n.c = $p2) return n", + }, + { + name: "not wraps or", + builder: v2.New().Where( + v2.Not(v2.Or(a, b)), + ).Return(v2.Node()), + expected: "match (n) where not (n.a = $p0 or n.b = $p1) return n", + }, + } + + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + preparedQuery, err := testCase.builder.Build() + require.NoError(t, err) + require.Equal(t, testCase.expected, renderPrepared(t, preparedQuery)) + }) + } +} + func TestInvalidCreateQualifiedExpressionReturnsError(t *testing.T) { _, err := v2.New().Create(v2.Node().Property("name")).Build() require.ErrorContains(t, err, "invalid qualified expression for create: *cypher.PropertyLookup") From fb2c5a3bd62d2c6cb4c9f9a2ae960f87f923e845 Mon Sep 17 00:00:00 2001 From: John Hopper Date: Sat, 9 May 2026 21:17:02 -0700 Subject: [PATCH 42/55] fix: preserve v2 logical precedence --- query/v2/query.go | 86 +++++++++++++++++++++++++++++++++++++---------- 1 file changed, 68 insertions(+), 18 deletions(-) diff --git a/query/v2/query.go b/query/v2/query.go index 880205a1..45aa735b 100644 --- a/query/v2/query.go +++ b/query/v2/query.go @@ -281,34 +281,88 @@ func valueExpression(value any) cypher.Expression { } } -func joinedExpressionList(operator cypher.Operator, operands []cypher.SyntaxNode) cypher.SyntaxNode { +func joinedExpressionList(operator cypher.Operator, operands []cypher.SyntaxNode) ([]cypher.Expression, cypher.SyntaxNode) { if len(operands) == 0 { - return invalidExpression(fmt.Errorf("%s requires at least one operand", operator)) + return nil, invalidExpression(fmt.Errorf("%s requires at least one operand", operator)) } - expressionList := &cypher.Comparison{} + expressions 
:= make([]cypher.Expression, len(operands)) + for idx, operand := range operands { + expressions[idx] = operand + } + + return expressions, nil +} + +func comparisonHasLogicalOperator(comparison *cypher.Comparison) bool { + if comparison == nil { + return false + } + + for _, partial := range comparison.Partials { + switch partial.Operator { + case cypher.OperatorAnd, cypher.OperatorOr: + return true + } + } + + return false +} + +func parenthesizeDisjunctiveExpression(expression cypher.Expression) cypher.Expression { + switch typedExpression := expression.(type) { + case *cypher.Parenthetical: + return typedExpression + case *cypher.Disjunction, *cypher.ExclusiveDisjunction: + return cypher.NewParenthetical(typedExpression) + case *cypher.Comparison: + if comparisonHasLogicalOperator(typedExpression) { + return cypher.NewParenthetical(typedExpression) + } + } - if len(operands) > 0 { - expressionList.Left = operands[0] + return expression +} - for _, operand := range operands[1:] { - expressionList.NewPartialComparison(operator, operand) +func parenthesizeLogicalExpression(expression cypher.Expression) cypher.Expression { + switch typedExpression := expression.(type) { + case *cypher.Parenthetical: + return typedExpression + case *cypher.Conjunction, *cypher.Disjunction, *cypher.ExclusiveDisjunction: + return cypher.NewParenthetical(typedExpression) + case *cypher.Comparison: + if comparisonHasLogicalOperator(typedExpression) { + return cypher.NewParenthetical(typedExpression) } } - return expressionList + return expression } func Not(operand cypher.Expression) cypher.Expression { - return cypher.NewNegation(operand) + return cypher.NewNegation(parenthesizeLogicalExpression(operand)) } func And(operands ...cypher.SyntaxNode) cypher.SyntaxNode { - return joinedExpressionList(cypher.OperatorAnd, operands) + expressions, errExpression := joinedExpressionList(cypher.OperatorAnd, operands) + if errExpression != nil { + return errExpression + } + + for idx, expression 
:= range expressions { + expressions[idx] = parenthesizeDisjunctiveExpression(expression) + } + + return cypher.NewConjunction(expressions...) } func Or(operands ...cypher.SyntaxNode) cypher.SyntaxNode { - return joinedExpressionList(cypher.OperatorOr, operands) + expressions, errExpression := joinedExpressionList(cypher.OperatorOr, operands) + if errExpression != nil { + return errExpression + } + + return cypher.NewParenthetical(cypher.NewDisjunction(expressions...)) } type SortDirection int @@ -1343,7 +1397,7 @@ func (s *builder) Build() (*PreparedQuery, error) { if len(s.constraints) > 0 { var ( whereClause = match.NewWhere() - constraints = &cypher.Comparison{} + constraints = cypher.NewConjunction() numRelationshipKindMatchers, err = countRelationshipKindMatchers(s.constraints, s.identifiers) ) if err != nil { @@ -1367,14 +1421,10 @@ func (s *builder) Build() (*PreparedQuery, error) { } constraintCopy := cypher.Copy(nextConstraint) - if constraints.Left == nil { - constraints.Left = constraintCopy - } else { - constraints.NewPartialComparison(cypher.OperatorAnd, constraintCopy) - } + constraints.Add(parenthesizeDisjunctiveExpression(constraintCopy)) } - if constraints.Left != nil { + if constraints.Len() > 0 { whereClause.Add(constraints) whereIdentifiers := newIdentifierSet() From fa5ede4d0212b50cca16e3950acd0b4fd6f34a77 Mon Sep 17 00:00:00 2001 From: John Hopper Date: Sat, 9 May 2026 21:17:24 -0700 Subject: [PATCH 43/55] chore: streamline v2 build validation --- query/v2/query.go | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/query/v2/query.go b/query/v2/query.go index 45aa735b..6d0cfdc1 100644 --- a/query/v2/query.go +++ b/query/v2/query.go @@ -1370,7 +1370,7 @@ func (s *builder) Build() (*PreparedQuery, error) { return nil, fmt.Errorf("query has no action specified") } - if err := collectModelErrorsFromKnownValues(s.constraints, s.setItems, s.removeItems, s.deleteItems, s.projections, s.sortItems); err != nil { + if err := 
collectModelErrorsFromKnownValues(s.constraints, s.creates, s.setItems, s.removeItems, s.deleteItems, s.projections, s.sortItems); err != nil { return nil, err } @@ -1405,10 +1405,6 @@ func (s *builder) Build() (*PreparedQuery, error) { } for _, nextConstraint := range s.constraints { - if err := collectModelErrorsFromKnownValues(nextConstraint); err != nil { - return nil, err - } - switch typedNextConstraint := nextConstraint.(type) { case *cypher.KindMatcher: if identifier, typeOK := typedNextConstraint.Reference.(*cypher.Variable); !typeOK { From c79ed9c0bd30cf4f5614ccd94946c784d14f89c6 Mon Sep 17 00:00:00 2001 From: John Hopper Date: Sat, 9 May 2026 21:18:08 -0700 Subject: [PATCH 44/55] refactor: return concrete v2 mutation types --- query/v2/query.go | 20 ++++++++++---------- query/v2/query_test.go | 22 ++++++++++++++++++++++ 2 files changed, 32 insertions(+), 10 deletions(-) diff --git a/query/v2/query.go b/query/v2/query.go index 6d0cfdc1..c82f99b6 100644 --- a/query/v2/query.go +++ b/query/v2/query.go @@ -446,8 +446,8 @@ type KindContinuation interface { type KindsContinuation interface { Has(kind graph.Kind) cypher.Expression HasOneOf(kinds graph.Kinds) cypher.Expression - Add(kinds graph.Kinds) cypher.Expression - Remove(kinds graph.Kinds) cypher.Expression + Add(kinds graph.Kinds) *cypher.SetItem + Remove(kinds graph.Kinds) *cypher.RemoveItem } type Comparable interface { @@ -576,7 +576,7 @@ func (s *entity[T]) Count() cypher.Expression { return cypher.NewSimpleFunctionInvocation(cypher.CountFunction, s.identifier) } -func (s *entity[T]) SetProperties(properties map[string]any) cypher.Expression { +func (s *entity[T]) SetProperties(properties map[string]any) *cypher.Set { set := &cypher.Set{} for _, key := range sortedPropertyKeys(properties) { @@ -586,7 +586,7 @@ func (s *entity[T]) SetProperties(properties map[string]any) cypher.Expression { return set } -func (s *entity[T]) RemoveProperties(properties []string) cypher.Expression { +func (s *entity[T]) 
RemoveProperties(properties []string) *cypher.Remove { remove := &cypher.Remove{} for _, key := range properties { @@ -689,7 +689,7 @@ func (s kindsContinuation) HasOneOf(kinds graph.Kinds) cypher.Expression { } } -func (s kindsContinuation) Add(kinds graph.Kinds) cypher.Expression { +func (s kindsContinuation) Add(kinds graph.Kinds) *cypher.SetItem { return cypher.NewSetItem( s.identifier, cypher.OperatorLabelAssignment, @@ -697,7 +697,7 @@ func (s kindsContinuation) Add(kinds graph.Kinds) cypher.Expression { ) } -func (s kindsContinuation) Remove(kinds graph.Kinds) cypher.Expression { +func (s kindsContinuation) Remove(kinds graph.Kinds) *cypher.RemoveItem { return cypher.RemoveKindsByMatcher(cypher.NewKindMatcher(s.identifier, kinds, false)) } @@ -713,8 +713,8 @@ type RelationshipContinuation interface { RelationshipPattern(kind graph.Kind, properties cypher.Expression, direction graph.Direction) cypher.Expression Kind() KindContinuation - SetProperties(properties map[string]any) cypher.Expression - RemoveProperties(properties []string) cypher.Expression + SetProperties(properties map[string]any) *cypher.Set + RemoveProperties(properties []string) *cypher.Remove } type NodeContinuation interface { @@ -723,8 +723,8 @@ type NodeContinuation interface { NodePattern(kinds graph.Kinds, properties cypher.Expression) cypher.Expression Kinds() KindsContinuation - SetProperties(properties map[string]any) cypher.Expression - RemoveProperties(properties []string) cypher.Expression + SetProperties(properties map[string]any) *cypher.Set + RemoveProperties(properties []string) *cypher.Remove } type QueryBuilder interface { diff --git a/query/v2/query_test.go b/query/v2/query_test.go index 47b9fd05..75d74d93 100644 --- a/query/v2/query_test.go +++ b/query/v2/query_test.go @@ -857,6 +857,28 @@ func TestUpdateCompatibilityHelpers(t *testing.T) { }, preparedQuery.Parameters) } +func TestFluentMutationHelpersReturnConcreteMutationTypes(t *testing.T) { + kinds := 
graph.Kinds{graph.StringKind("Enabled")} + + var addItem *cypher.SetItem = v2.Node().Kinds().Add(kinds) + require.NotNil(t, addItem) + + var removeItem *cypher.RemoveItem = v2.Node().Kinds().Remove(kinds) + require.NotNil(t, removeItem) + + var nodeSet *cypher.Set = v2.Node().SetProperties(map[string]any{"name": "updated"}) + require.Len(t, nodeSet.Items, 1) + + var nodeRemove *cypher.Remove = v2.Node().RemoveProperties([]string{"stale"}) + require.Len(t, nodeRemove.Items, 1) + + var relationshipSet *cypher.Set = v2.Relationship().SetProperties(map[string]any{"name": "updated"}) + require.Len(t, relationshipSet.Items, 1) + + var relationshipRemove *cypher.Remove = v2.Relationship().RemoveProperties([]string{"stale"}) + require.Len(t, relationshipRemove.Items, 1) +} + func TestSetPropertiesSortsKeys(t *testing.T) { properties := map[string]any{ "zeta": 3, From 62cd5bc9721f31f9cf6542f3c4fa2cefe296a50f Mon Sep 17 00:00:00 2001 From: John Hopper Date: Sat, 9 May 2026 21:19:12 -0700 Subject: [PATCH 45/55] fix: accept unicode cypher symbols --- query/v2/query_test.go | 16 ++++++++++++++++ query/v2/util.go | 20 +++++++++++++------- 2 files changed, 29 insertions(+), 7 deletions(-) diff --git a/query/v2/query_test.go b/query/v2/query_test.go index 75d74d93..47c65762 100644 --- a/query/v2/query_test.go +++ b/query/v2/query_test.go @@ -248,6 +248,22 @@ func TestInvalidScopeAliasesReturnBuildErrors(t *testing.T) { require.ErrorContains(t, err, `scope alias node has invalid symbol "bad name"`) } +func TestUnicodeCypherSymbols(t *testing.T) { + scope := v2.NewScope("路径", "节点", "起点", "关系", "终点") + + preparedQuery, err := scope.New().Where( + scope.Node().Property("name").Equals(v2.NamedParameter("名字", "alice")), + ).Return( + v2.As(scope.Node().ID(), "标识"), + ).Build() + require.NoError(t, err) + + require.Equal(t, "match (节点) where 节点.name = $名字 return id(节点) as 标识", renderPrepared(t, preparedQuery)) + require.Equal(t, map[string]any{ + "名字": "alice", + }, 
preparedQuery.Parameters) +} + func TestInvalidRelationshipDirectionReturnsError(t *testing.T) { _, err := v2.New().WithRelationshipDirection(graph.Direction(99)).Return(v2.Relationship()).Build() require.ErrorContains(t, err, "unsupported relationship direction: invalid") diff --git a/query/v2/util.go b/query/v2/util.go index a477e46b..9a0c8b33 100644 --- a/query/v2/util.go +++ b/query/v2/util.go @@ -7,6 +7,8 @@ import ( "sort" "strconv" "strings" + "unicode" + "unicode/utf8" "github.com/specterops/dawgs/cypher/models/cypher" "github.com/specterops/dawgs/cypher/models/walk" @@ -264,12 +266,12 @@ func sortedPropertyKeys(properties map[string]any) []string { return keys } -func isCypherSymbolStart(char byte) bool { - return char == '_' || (char >= 'A' && char <= 'Z') || (char >= 'a' && char <= 'z') +func isCypherSymbolStart(char rune) bool { + return char == '_' || unicode.IsLetter(char) || unicode.In(char, unicode.Nl, unicode.Pc) } -func isCypherSymbolPart(char byte) bool { - return isCypherSymbolStart(char) || (char >= '0' && char <= '9') +func isCypherSymbolPart(char rune) bool { + return isCypherSymbolStart(char) || unicode.IsDigit(char) || unicode.In(char, unicode.Mark, unicode.Sc) } func validateCypherSymbol(symbol, context string) error { @@ -277,12 +279,16 @@ func validateCypherSymbol(symbol, context string) error { return fmt.Errorf("%s is empty", context) } - if !isCypherSymbolStart(symbol[0]) { + if !utf8.ValidString(symbol) { return fmt.Errorf("%s has invalid symbol %q", context, symbol) } - for idx := 1; idx < len(symbol); idx++ { - if !isCypherSymbolPart(symbol[idx]) { + for idx, char := range symbol { + if idx == 0 { + if !isCypherSymbolStart(char) { + return fmt.Errorf("%s has invalid symbol %q", context, symbol) + } + } else if !isCypherSymbolPart(char) { return fmt.Errorf("%s has invalid symbol %q", context, symbol) } } From 54bae60e9f01bb2a47ee3dbab24117bf2142ccc9 Mon Sep 17 00:00:00 2001 From: John Hopper Date: Sat, 9 May 2026 21:19:58 -0700 
Subject: [PATCH 46/55] chore: harden v2 helper edges --- query/v2/query.go | 3 +++ query/v2/util.go | 8 ++++++++ query/v2/util_internal_test.go | 20 ++++++++++++++++++++ 3 files changed, 31 insertions(+) create mode 100644 query/v2/util_internal_test.go diff --git a/query/v2/query.go b/query/v2/query.go index c82f99b6..5db2d09e 100644 --- a/query/v2/query.go +++ b/query/v2/query.go @@ -78,6 +78,7 @@ func ExactDepth(depth int64) TraversalDepth { return newTraversalDepth(depthBound, depthBound) } +// Accessors return fresh variables; compare symbols rather than pointer identity. func (s runtimeIdentifiers) Path() *cypher.Variable { return cypher.NewVariableWithSymbol(s.path) } @@ -731,6 +732,7 @@ type QueryBuilder interface { Where(constraints ...cypher.SyntaxNode) QueryBuilder OrderBy(sortItems ...any) QueryBuilder Skip(offset int) QueryBuilder + // Limit accepts zero, which renders LIMIT 0 and returns an empty result set. Limit(limit int) QueryBuilder Return(projections ...any) QueryBuilder ReturnDistinct(projections ...any) QueryBuilder @@ -898,6 +900,7 @@ func (s *builder) appendDeleteItems(detach bool, items ...cypher.Expression) { return } + // Consecutive deletes share one clause; any node delete makes the whole clause DETACH DELETE. 
lastClauseIdx := len(s.updatingClauses) - 1 if lastClauseIdx >= 0 && s.updatingClauses[lastClauseIdx].kind == updatingClauseDelete { s.updatingClauses[lastClauseIdx].detach = s.updatingClauses[lastClauseIdx].detach || detach diff --git a/query/v2/util.go b/query/v2/util.go index 9a0c8b33..434bdfb2 100644 --- a/query/v2/util.go +++ b/query/v2/util.go @@ -692,12 +692,20 @@ func (s *identifierSet) Clone() *identifierSet { } func (s *identifierSet) Or(other *identifierSet) { + if s == nil || other == nil { + return + } + for otherIdentifier := range other.identifiers { s.identifiers[otherIdentifier] = struct{}{} } } func (s *identifierSet) Remove(other *identifierSet) { + if s == nil || other == nil { + return + } + for otherIdentifier := range other.identifiers { delete(s.identifiers, otherIdentifier) } diff --git a/query/v2/util_internal_test.go b/query/v2/util_internal_test.go new file mode 100644 index 00000000..1ae62236 --- /dev/null +++ b/query/v2/util_internal_test.go @@ -0,0 +1,20 @@ +package v2 + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestIdentifierSetOrAndRemoveNilSafe(t *testing.T) { + var nilSet *identifierSet + + nilSet.Or(newIdentifierSet("ignored")) + nilSet.Remove(newIdentifierSet("ignored")) + + set := newIdentifierSet("kept") + set.Or(nil) + set.Remove(nil) + + require.True(t, set.Contains("kept")) +} From 0d0a059e1b03ef51f8a7d0f98996f5ba027c64f6 Mon Sep 17 00:00:00 2001 From: John Hopper Date: Sat, 9 May 2026 21:29:28 -0700 Subject: [PATCH 47/55] fix: preserve and grouping inside v2 or --- query/v2/query.go | 4 ++++ query/v2/query_test.go | 7 +++++++ 2 files changed, 11 insertions(+) diff --git a/query/v2/query.go b/query/v2/query.go index 5db2d09e..cbe03869 100644 --- a/query/v2/query.go +++ b/query/v2/query.go @@ -363,6 +363,10 @@ func Or(operands ...cypher.SyntaxNode) cypher.SyntaxNode { return errExpression } + for idx, expression := range expressions { + expressions[idx] = 
parenthesizeLogicalExpression(expression) + } + return cypher.NewParenthetical(cypher.NewDisjunction(expressions...)) } diff --git a/query/v2/query_test.go b/query/v2/query_test.go index 47c65762..81eae4e5 100644 --- a/query/v2/query_test.go +++ b/query/v2/query_test.go @@ -148,6 +148,13 @@ func TestLogicalHelpersPreservePrecedence(t *testing.T) { ).Return(v2.Node()), expected: "match (n) where n.a = $p0 and (n.b = $p1 or n.c = $p2) return n", }, + { + name: "nested and is parenthesized inside or", + builder: v2.New().Where( + v2.Or(v2.And(a, b), c), + ).Return(v2.Node()), + expected: "match (n) where ((n.a = $p0 and n.b = $p1) or n.c = $p2) return n", + }, { name: "not wraps or", builder: v2.New().Where( From 67e1e811c8319256e8945dad7d94dbb3bd97881f Mon Sep 17 00:00:00 2001 From: John Hopper Date: Sat, 9 May 2026 21:35:38 -0700 Subject: [PATCH 48/55] fix: support undirected relationship patterns --- cypher/models/cypher/format/format.go | 10 +---- cypher/models/cypher/format/format_test.go | 50 ++++++++++++++++++++++ cypher/test/cases/mutation_tests.json | 2 +- cypher/test/cases/positive_tests.json | 10 ++--- query/neo4j/neo4j_test.go | 10 ++--- query/v2/query_test.go | 20 ++++++--- query/v2/util.go | 2 +- 7 files changed, 78 insertions(+), 26 deletions(-) diff --git a/cypher/models/cypher/format/format.go b/cypher/models/cypher/format/format.go index f2f62c15..a05d7b1a 100644 --- a/cypher/models/cypher/format/format.go +++ b/cypher/models/cypher/format/format.go @@ -80,14 +80,11 @@ func (s Emitter) formatNodePattern(output io.Writer, nodePattern *cypher.NodePat func (s Emitter) formatRelationshipPattern(output io.Writer, relationshipPattern *cypher.RelationshipPattern) error { switch relationshipPattern.Direction { - case graph.DirectionOutbound: + case graph.DirectionOutbound, graph.DirectionBoth: if _, err := io.WriteString(output, "-["); err != nil { return err } - case graph.DirectionBoth: - fallthrough - case graph.DirectionInbound: if _, err := 
io.WriteString(output, "<-["); err != nil { return err @@ -147,14 +144,11 @@ func (s Emitter) formatRelationshipPattern(output io.Writer, relationshipPattern } switch relationshipPattern.Direction { - case graph.DirectionInbound: + case graph.DirectionInbound, graph.DirectionBoth: if _, err := io.WriteString(output, "]-"); err != nil { return err } - case graph.DirectionBoth: - fallthrough - case graph.DirectionOutbound: if _, err := io.WriteString(output, "]->"); err != nil { return err diff --git a/cypher/models/cypher/format/format_test.go b/cypher/models/cypher/format/format_test.go index 327f65d4..8a02d2e9 100644 --- a/cypher/models/cypher/format/format_test.go +++ b/cypher/models/cypher/format/format_test.go @@ -6,6 +6,7 @@ import ( "github.com/specterops/dawgs/cypher/models/cypher" "github.com/specterops/dawgs/cypher/models/cypher/format" + "github.com/specterops/dawgs/graph" "github.com/specterops/dawgs/cypher/frontend" "github.com/stretchr/testify/require" @@ -27,6 +28,55 @@ func TestCypherEmitter_StripLiterals(t *testing.T) { require.Equal(t, "match (n {value: $STRIPPED}) where n.other = $STRIPPED and n.number = $STRIPPED return n.name, n", buffer.String()) } +func TestCypherEmitter_RelationshipDirections(t *testing.T) { + testCases := []struct { + name string + direction graph.Direction + expected string + }{ + { + name: "outbound", + direction: graph.DirectionOutbound, + expected: "match (a)-[r]->(b) return r", + }, + { + name: "inbound", + direction: graph.DirectionInbound, + expected: "match (a)<-[r]-(b) return r", + }, + { + name: "both", + direction: graph.DirectionBoth, + expected: "match (a)-[r]-(b) return r", + }, + } + + for _, testCase := range testCases { + t.Run(testCase.name, func(t *testing.T) { + regularQuery, singlePartQuery := cypher.NewRegularQueryWithSingleQuery() + match := singlePartQuery.NewReadingClause().NewMatch(false) + match.NewPatternPart().AddPatternElements( + &cypher.NodePattern{ + Variable: 
cypher.NewVariableWithSymbol("a"), + }, + &cypher.RelationshipPattern{ + Variable: cypher.NewVariableWithSymbol("r"), + Direction: testCase.direction, + }, + &cypher.NodePattern{ + Variable: cypher.NewVariableWithSymbol("b"), + }, + ) + + singlePartQuery.NewProjection(false).AddItem(cypher.NewProjectionItemWithExpr(cypher.NewVariableWithSymbol("r"))) + + rendered, err := format.RegularQuery(regularQuery, false) + require.NoError(t, err) + require.Equal(t, testCase.expected, rendered) + }) + } +} + func TestCypherEmitter_HappyPath(t *testing.T) { test.LoadFixture(t, test.MutationTestCases).Run(t) test.LoadFixture(t, test.PositiveTestCases).Run(t) diff --git a/cypher/test/cases/mutation_tests.json b/cypher/test/cases/mutation_tests.json index 7893439b..dc73b031 100644 --- a/cypher/test/cases/mutation_tests.json +++ b/cypher/test/cases/mutation_tests.json @@ -4,7 +4,7 @@ "name": "Multipart query with mutation", "type": "string_match", "details": { - "query": "match (s:Ship {name: 'Nebuchadnezzar'}) with s as ship merge p = (c:Crew {name: 'Neo'})\u003c-[:CrewOf]-\u003e(ship) set c.title = 'The One' return p", + "query": "match (s:Ship {name: 'Nebuchadnezzar'}) with s as ship merge p = (c:Crew {name: 'Neo'})-[:CrewOf]-(ship) set c.title = 'The One' return p", "fitness": 7 } }, diff --git a/cypher/test/cases/positive_tests.json b/cypher/test/cases/positive_tests.json index b3941c04..cedfccdc 100644 --- a/cypher/test/cases/positive_tests.json +++ b/cypher/test/cases/positive_tests.json @@ -189,7 +189,7 @@ "name": "Specify bi-directional relationship", "type": "string_match", "details": { - "query": "match (p:Person)\u003c-[]-\u003e(m:Movie) return m", + "query": "match (p:Person)-[]-(m:Movie) return m", "fitness": 0 } }, @@ -437,7 +437,7 @@ "name": "built-in shortestPaths()", "type": "string_match", "details": { - "query": "match p = shortestPath((p1:Person)\u003c-[*]-\u003e(p2:Person)) where p1.name = 'tom' and p2.name = 'jerry' return p", + "query": "match p = 
shortestPath((p1:Person)-[*]-(p2:Person)) where p1.name = 'tom' and p2.name = 'jerry' return p", "fitness": 17 } }, @@ -453,7 +453,7 @@ "name": "Find nodes with relationships", "type": "string_match", "details": { - "query": "match (b) where (b)\u003c-[]-\u003e() return b", + "query": "match (b) where (b)-[]-() return b", "fitness": -4 } }, @@ -461,7 +461,7 @@ "name": "Find nodes with no relationships", "type": "string_match", "details": { - "query": "match (b) where not ((b)\u003c-[]-\u003e()) return b", + "query": "match (b) where not ((b)-[]-()) return b", "fitness": -5 } }, @@ -898,4 +898,4 @@ } } ] -} \ No newline at end of file +} diff --git a/query/neo4j/neo4j_test.go b/query/neo4j/neo4j_test.go index 2efc3435..05117347 100644 --- a/query/neo4j/neo4j_test.go +++ b/query/neo4j/neo4j_test.go @@ -422,7 +422,7 @@ func TestQueryBuilder_Render(t *testing.T) { query.Returning( query.Node(), ), - ), "match (n) where (n)<-[]->() return n")) + ), "match (n) where (n)-[]-() return n")) t.Run("Node has Relationships Order by Node Item", assertQueryResult(query.SinglePartQuery( query.Where( @@ -436,7 +436,7 @@ func TestQueryBuilder_Render(t *testing.T) { query.OrderBy( query.Order(query.NodeProperty("value"), query.Ascending()), ), - ), "match (n) where (n)<-[]->() return n order by n.value asc")) + ), "match (n) where (n)-[]-() return n order by n.value asc")) t.Run("Node has Relationships Order by Node Item", assertQueryResult(query.SinglePartQuery( query.Where( @@ -451,7 +451,7 @@ func TestQueryBuilder_Render(t *testing.T) { query.Order(query.NodeProperty("value_1"), query.Ascending()), query.Order(query.NodeProperty("value_2"), query.Descending()), ), - ), "match (n) where (n)<-[]->() return n order by n.value_1 asc, n.value_2 desc")) + ), "match (n) where (n)-[]-() return n order by n.value_1 asc, n.value_2 desc")) t.Run("Node has Relationships Order by Node Item with Limit and Offset", assertQueryResult(query.SinglePartQuery( query.Where( @@ -469,7 +469,7 @@ func 
TestQueryBuilder_Render(t *testing.T) { query.Limit(10), query.Offset(20), - ), "match (n) where (n)<-[]->() return n order by n.value_1 asc, n.value_2 desc skip 20 limit 10")) + ), "match (n) where (n)-[]-() return n order by n.value_1 asc, n.value_2 desc skip 20 limit 10")) t.Run("Node has no Relationships", assertQueryResult(query.SinglePartQuery( query.Where( @@ -479,7 +479,7 @@ func TestQueryBuilder_Render(t *testing.T) { query.Returning( query.Node(), ), - ), "match (n) where not ((n)<-[]->()) return n")) + ), "match (n) where not ((n)-[]-()) return n")) t.Run("Node Datetime Before", assertQueryResult(query.SinglePartQuery( query.Where( diff --git a/query/v2/query_test.go b/query/v2/query_test.go index 81eae4e5..2f46dbd7 100644 --- a/query/v2/query_test.go +++ b/query/v2/query_test.go @@ -274,9 +274,13 @@ func TestUnicodeCypherSymbols(t *testing.T) { func TestInvalidRelationshipDirectionReturnsError(t *testing.T) { _, err := v2.New().WithRelationshipDirection(graph.Direction(99)).Return(v2.Relationship()).Build() require.ErrorContains(t, err, "unsupported relationship direction: invalid") +} + +func TestRelationshipDirectionBoth(t *testing.T) { + preparedQuery, err := v2.New().WithRelationshipDirection(graph.DirectionBoth).Return(v2.Relationship()).Build() + require.NoError(t, err) - _, err = v2.New().WithRelationshipDirection(graph.DirectionBoth).Return(v2.Relationship()).Build() - require.ErrorContains(t, err, "unsupported relationship direction: both") + require.Equal(t, "match ()-[r]-() return r", renderPrepared(t, preparedQuery)) } func TestShortestPathControls(t *testing.T) { @@ -500,13 +504,17 @@ func TestEmptyLogicalHelpersReturnBuildErrors(t *testing.T) { require.ErrorContains(t, err, "or requires at least one operand") } -func TestInvalidExplicitRelationshipPatternDirectionReturnsError(t *testing.T) { - _, err := v2.New().Create( +func TestExplicitRelationshipPatternDirectionBoth(t *testing.T) { + preparedQuery, err := v2.New().Create( 
v2.RelationshipPattern(graph.StringKind("Edge"), nil, graph.DirectionBoth), ).Build() - require.ErrorContains(t, err, "unsupported relationship direction: both") + require.NoError(t, err) + + require.Equal(t, "create (s)-[r:Edge]-(e)", renderPrepared(t, preparedQuery)) +} - _, err = v2.New().Create( +func TestInvalidExplicitRelationshipPatternDirectionReturnsError(t *testing.T) { + _, err := v2.New().Create( v2.Relationship().RelationshipPattern(graph.StringKind("Edge"), nil, graph.Direction(99)), ).Build() require.ErrorContains(t, err, "unsupported relationship direction: invalid") diff --git a/query/v2/util.go b/query/v2/util.go index 434bdfb2..03b5fb10 100644 --- a/query/v2/util.go +++ b/query/v2/util.go @@ -95,7 +95,7 @@ func prepareNodePattern(match *cypher.Match, seen *identifierSet, identifiers ru func validateRelationshipDirection(direction graph.Direction) error { switch direction { - case graph.DirectionInbound, graph.DirectionOutbound: + case graph.DirectionInbound, graph.DirectionOutbound, graph.DirectionBoth: return nil default: return fmt.Errorf("unsupported relationship direction: %s", direction) From e44205e2db8238518b493d3c1fa77a46926acb90 Mon Sep 17 00:00:00 2001 From: John Hopper Date: Sat, 9 May 2026 21:55:14 -0700 Subject: [PATCH 49/55] fix: validate v2 sort and split creates --- query/v2/query.go | 59 +++++++++++++++++++++++++++++++++++++++--- query/v2/query_test.go | 44 +++++++++++++++++++++++++++++++ 2 files changed, 99 insertions(+), 4 deletions(-) diff --git a/query/v2/query.go b/query/v2/query.go index cbe03869..66ecb179 100644 --- a/query/v2/query.go +++ b/query/v2/query.go @@ -385,10 +385,24 @@ func Desc(expression any) *cypher.SortItem { return Order(expression, SortDescending) } +func validateSortDirection(direction SortDirection) error { + switch direction { + case SortAscending, SortDescending: + return nil + default: + return fmt.Errorf("unsupported sort direction: %d", direction) + } +} + func Order(expression any, direction 
SortDirection) *cypher.SortItem { + expressionValue := expressionOrError(expression) + if err := validateSortDirection(direction); err != nil { + expressionValue = invalidExpression(err) + } + return &cypher.SortItem{ Ascending: direction != SortDescending, - Expression: expressionOrError(expression), + Expression: expressionValue, } } @@ -1097,25 +1111,57 @@ func isCreateNodeValue(value any, identifiers runtimeIdentifiers) bool { return false } +func isCreateRelationshipValue(value any) bool { + _, typeOK := value.(*cypher.RelationshipPattern) + return typeOK +} + func nextCreateValueIsNode(creates []any, idx int, identifiers runtimeIdentifiers) bool { nextIdx := idx + 1 return nextIdx < len(creates) && isCreateNodeValue(creates[nextIdx], identifiers) } +func newCreatePatternPart(createClause *cypher.Create) *cypher.PatternPart { + pattern := &cypher.PatternPart{} + createClause.Pattern = append(createClause.Pattern, pattern) + return pattern +} + +func createPatternHasElements(pattern *cypher.PatternPart) bool { + return pattern != nil && len(pattern.PatternElements) > 0 +} + +func shouldStartNewCreatePattern(pattern *cypher.PatternPart, nextCreate any, patternClosed bool, identifiers runtimeIdentifiers) bool { + if !createPatternHasElements(pattern) { + return false + } + + if isCreateNodeValue(nextCreate, identifiers) && patternEndsWithNodePattern(pattern) { + return true + } + + return patternClosed && isCreateRelationshipValue(nextCreate) +} + func buildCreates(singlePartQuery *cypher.SinglePartQuery, identifiers runtimeIdentifiers, creates []any) error { if len(creates) == 0 { return nil } var ( - pattern = &cypher.PatternPart{} createClause = &cypher.Create{ - Unique: false, - Pattern: []*cypher.PatternPart{pattern}, + Unique: false, } + pattern = newCreatePatternPart(createClause) + patternClosed bool ) for idx, nextCreate := range creates { + if shouldStartNewCreatePattern(pattern, nextCreate, patternClosed, identifiers) { + pattern = 
newCreatePatternPart(createClause) + patternClosed = false + } + switch typedNextCreate := nextCreate.(type) { case QualifiedExpression: switch typedExpression := typedNextCreate.qualifier().(type) { @@ -1129,6 +1175,7 @@ func buildCreates(singlePartQuery *cypher.SinglePartQuery, identifiers runtimeId pattern.AddPatternElements(&cypher.NodePattern{ Variable: cypher.NewVariableWithSymbol(typedExpression.Symbol), }) + patternClosed = false default: return fmt.Errorf("invalid variable reference for create: %s", typedExpression.Symbol) @@ -1144,6 +1191,7 @@ func buildCreates(singlePartQuery *cypher.SinglePartQuery, identifiers runtimeId } pattern.AddPatternElements(cypher.Copy(typedNextCreate)) + patternClosed = false case *cypher.RelationshipPattern: if err := validateRelationshipPattern(typedNextCreate); err != nil { @@ -1162,6 +1210,9 @@ func buildCreates(singlePartQuery *cypher.SinglePartQuery, identifiers runtimeId pattern.AddPatternElements(&cypher.NodePattern{ Variable: identifiers.End(), }) + patternClosed = true + } else { + patternClosed = false } default: diff --git a/query/v2/query_test.go b/query/v2/query_test.go index 2f46dbd7..61ea0d32 100644 --- a/query/v2/query_test.go +++ b/query/v2/query_test.go @@ -19,6 +19,21 @@ func renderPrepared(t *testing.T, preparedQuery *v2.PreparedQuery) string { return cypherQueryStr } +func firstCreateClause(t *testing.T, preparedQuery *v2.PreparedQuery) *cypher.Create { + t.Helper() + + updatingClauses := preparedQuery.Query.SingleQuery.SinglePartQuery.UpdatingClauses + require.NotEmpty(t, updatingClauses) + + updatingClause, typeOK := updatingClauses[0].(*cypher.UpdatingClause) + require.True(t, typeOK) + + createClause, typeOK := updatingClause.Clause.(*cypher.Create) + require.True(t, typeOK) + + return createClause +} + func TestQuery(t *testing.T) { preparedQuery, err := v2.New().Where( v2.Not(v2.Relationship().Kind().Is(graph.StringKind("test"))), @@ -102,6 +117,28 @@ func 
TestCreateRelationshipWithExplicitEndpoints(t *testing.T) { }, preparedQuery.Parameters) } +func TestCreateSplitsDisjointNodePatterns(t *testing.T) { + preparedQuery, err := v2.New().Create( + v2.NodePattern(graph.Kinds{graph.StringKind("A")}, nil), + v2.NodePattern(graph.Kinds{graph.StringKind("B")}, nil), + ).Build() + require.NoError(t, err) + + require.Equal(t, "create (n:A), (n:B)", renderPrepared(t, preparedQuery)) + require.Len(t, firstCreateClause(t, preparedQuery).Pattern, 2) +} + +func TestCreateSplitsBackToBackRelationshipPatterns(t *testing.T) { + preparedQuery, err := v2.New().Create( + v2.RelationshipPattern(graph.StringKind("A"), nil, graph.DirectionOutbound), + v2.RelationshipPattern(graph.StringKind("B"), nil, graph.DirectionOutbound), + ).Build() + require.NoError(t, err) + + require.Equal(t, "create (s)-[r:A]->(e), (s)-[r:B]->(e)", renderPrepared(t, preparedQuery)) + require.Len(t, firstCreateClause(t, preparedQuery).Pattern, 2) +} + func TestCreateNodeReturnDoesNotCreateMatch(t *testing.T) { preparedQuery, err := v2.New().Create( v2.Node().NodePattern(graph.Kinds{graph.StringKind("A")}, v2.NamedParameter("props", map[string]any{"name": "node"})), @@ -532,6 +569,13 @@ func TestProjectionAndOrderHelpers(t *testing.T) { require.Equal(t, "match (n) return distinct id(n) as node_id order by n.name asc, id(n) desc", renderPrepared(t, preparedQuery)) } +func TestInvalidSortDirectionReturnsError(t *testing.T) { + _, err := v2.New().Return(v2.Node()).OrderBy( + v2.Order(v2.Node().Property("name"), v2.SortDirection(99)), + ).Build() + require.ErrorContains(t, err, "unsupported sort direction: 99") +} + func TestPaginationZeroValuesAndNegativeValidation(t *testing.T) { preparedQuery, err := v2.New().Return(v2.Node()).Skip(0).Limit(0).Build() require.NoError(t, err) From a89d4812b106daf65a49c78816c10006ac28358e Mon Sep 17 00:00:00 2001 From: John Hopper Date: Mon, 11 May 2026 08:07:33 -0700 Subject: [PATCH 50/55] feat: add benchmark report formats --- 
cmd/benchmark/README.md | 9 ++- cmd/benchmark/main.go | 14 ++++- cmd/benchmark/report.go | 111 +++++++++++++++++++++++++++++++++++ cmd/benchmark/report_test.go | 98 +++++++++++++++++++++++++++++++ cmd/benchmark/runner.go | 2 + 5 files changed, 230 insertions(+), 4 deletions(-) create mode 100644 cmd/benchmark/report_test.go diff --git a/cmd/benchmark/README.md b/cmd/benchmark/README.md index 5b07018a..de9ca76b 100644 --- a/cmd/benchmark/README.md +++ b/cmd/benchmark/README.md @@ -19,6 +19,9 @@ go run ./cmd/benchmark -driver neo4j -connection "neo4j://neo4j:password@localho # Save to file go run ./cmd/benchmark -connection "..." -output report.md + +# Emit benchfmt for benchstat +go run ./cmd/benchmark -connection "..." -format benchfmt -output report.bench ``` ## Flags @@ -31,7 +34,11 @@ go run ./cmd/benchmark -connection "..." -output report.md | `-dataset` | | Run only this dataset | | `-local-dataset` | | Add a local dataset to the default set | | `-dataset-dir` | `integration/testdata` | Path to testdata directory | -| `-output` | stdout | Markdown output file | +| `-format` | `markdown` | Output format (`markdown`, `json`, `benchfmt`) | +| `-output` | stdout | Output file | + +Use `-format benchfmt` when comparing scenario timings with `benchstat`. Each timed scenario iteration is emitted as a +separate `ns/op` sample so two benchmark runs can be compared directly. 
## Example: Neo4j on local/phantom diff --git a/cmd/benchmark/main.go b/cmd/benchmark/main.go index dcf73f03..8337ab98 100644 --- a/cmd/benchmark/main.go +++ b/cmd/benchmark/main.go @@ -41,6 +41,7 @@ func main() { connStr = flag.String("connection", "", "database connection string (or CONNECTION_STRING)") iterations = flag.Int("iterations", 10, "timed iterations per scenario") output = flag.String("output", "", "markdown output file (default: stdout)") + format = flag.String("format", reportFormatMarkdown, "output format (markdown, json, benchfmt)") datasetDir = flag.String("dataset-dir", "integration/testdata", "path to testdata directory") localDataset = flag.String("local-dataset", "", "additional local dataset (e.g. local/phantom)") onlyDataset = flag.String("dataset", "", "run only this dataset (e.g. diamond, local/phantom)") @@ -49,6 +50,13 @@ func main() { flag.Parse() + if *iterations < 1 { + fatal("iterations must be at least 1") + } + if !isReportFormat(*format) { + fatal("unsupported output format %q", *format) + } + conn := *connStr if conn == "" { conn = os.Getenv("CONNECTION_STRING") @@ -153,7 +161,7 @@ func main() { } } - // Write markdown + // Write report var mdOut *os.File if *output != "" { var err error @@ -166,8 +174,8 @@ func main() { mdOut = os.Stdout } - if err := writeMarkdown(mdOut, report); err != nil { - fatal("failed to write markdown: %v", err) + if err := writeReport(mdOut, report, *format); err != nil { + fatal("failed to write report: %v", err) } if *output != "" { diff --git a/cmd/benchmark/report.go b/cmd/benchmark/report.go index a440c8ba..0f67cc60 100644 --- a/cmd/benchmark/report.go +++ b/cmd/benchmark/report.go @@ -17,9 +17,19 @@ package main import ( + "encoding/json" "fmt" "io" + "runtime" + "strings" "time" + "unicode" +) + +const ( + reportFormatBenchfmt = "benchfmt" + reportFormatJSON = "json" + reportFormatMarkdown = "markdown" ) // Report holds all benchmark results and metadata. 
@@ -31,6 +41,36 @@ type Report struct { Results []Result } +func writeReport(w io.Writer, r Report, format string) error { + if !isReportFormat(format) { + return fmt.Errorf("unsupported output format %q", format) + } + + switch format { + case reportFormatBenchfmt: + return writeBenchfmt(w, r) + case reportFormatJSON: + return writeJSON(w, r) + default: + return writeMarkdown(w, r) + } +} + +func isReportFormat(format string) bool { + switch format { + case reportFormatBenchfmt, reportFormatJSON, reportFormatMarkdown: + return true + default: + return false + } +} + +func writeJSON(w io.Writer, r Report) error { + encoder := json.NewEncoder(w) + encoder.SetIndent("", " ") + return encoder.Encode(r) +} + func writeMarkdown(w io.Writer, r Report) error { fmt.Fprintf(w, "# Benchmarks — %s @ %s (%s, %d iterations)\n\n", r.Driver, r.GitRef, r.Date, r.Iterations) fmt.Fprintf(w, "| Query | Dataset | Median | P95 | Max |\n") @@ -55,6 +95,77 @@ func writeMarkdown(w io.Writer, r Report) error { return nil } +func writeBenchfmt(w io.Writer, r Report) error { + goos := runtime.GOOS + goarch := runtime.GOARCH + procs := runtime.GOMAXPROCS(0) + + fmt.Fprintf(w, "goos: %s\n", goos) + fmt.Fprintf(w, "goarch: %s\n", goarch) + fmt.Fprintf(w, "pkg: github.com/specterops/dawgs/cmd/benchmark\n") + + for _, res := range r.Results { + benchName := benchName(r.Driver, res) + + for _, sample := range res.Samples { + fmt.Fprintf(w, "%s-%d\t1\t%d ns/op\n", benchName, procs, sample.Nanoseconds()) + } + } + + return nil +} + +func benchName(driver string, res Result) string { + parts := []string{ + "BenchmarkDawgsIntegration", + sanitizeBenchNamePart(driver), + sanitizeBenchNamePart(res.Dataset), + sanitizeBenchNamePart(res.Section), + sanitizeBenchNamePart(res.Label), + } + + return strings.Join(parts, "/") +} + +func sanitizeBenchNamePart(value string) string { + var builder strings.Builder + lastUnderscore := false + + for _, char := range value { + switch { + case char == '/' || char == 
'-' || char == '_': + if char == '_' { + if !lastUnderscore { + builder.WriteRune(char) + } + lastUnderscore = true + } else { + builder.WriteRune(char) + lastUnderscore = false + } + case unicode.IsLetter(char) || unicode.IsDigit(char): + builder.WriteRune(char) + lastUnderscore = false + case unicode.IsSpace(char): + if !lastUnderscore { + builder.WriteByte('_') + } + lastUnderscore = true + default: + if !lastUnderscore { + builder.WriteByte('_') + } + lastUnderscore = true + } + } + + if builder.Len() == 0 { + return "unknown" + } + + return builder.String() +} + func fmtDuration(d time.Duration) string { ms := float64(d.Microseconds()) / 1000.0 if ms < 1 { diff --git a/cmd/benchmark/report_test.go b/cmd/benchmark/report_test.go new file mode 100644 index 00000000..225ab6d0 --- /dev/null +++ b/cmd/benchmark/report_test.go @@ -0,0 +1,98 @@ +// Copyright 2026 Specter Ops, Inc. +// +// Licensed under the Apache License, Version 2.0 +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +// SPDX-License-Identifier: Apache-2.0 + +package main + +import ( + "bytes" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestWriteReportRejectsUnknownFormat(t *testing.T) { + err := writeReport(&bytes.Buffer{}, Report{}, "xml") + require.ErrorContains(t, err, "unsupported output format") +} + +func TestWriteJSON(t *testing.T) { + report := testReport() + var out bytes.Buffer + + require.NoError(t, writeReport(&out, report, reportFormatJSON)) + + require.Contains(t, out.String(), `"Driver": "pg"`) + require.Contains(t, out.String(), `"Samples": [`) + require.Contains(t, out.String(), `1000000`) +} + +func TestWriteBenchfmt(t *testing.T) { + report := testReport() + var out bytes.Buffer + + require.NoError(t, writeReport(&out, report, reportFormatBenchfmt)) + + output := out.String() + require.Contains(t, output, "goos: ") + require.Contains(t, output, "goarch: ") + require.Contains(t, output, "pkg: github.com/specterops/dawgs/cmd/benchmark") + require.Contains(t, output, "BenchmarkDawgsIntegration/pg/base/Match_Nodes/base-") + require.Contains(t, output, "\t1\t1000000 ns/op") + require.Contains(t, output, "\t1\t2000000 ns/op") +} + +func TestSanitizeBenchNamePart(t *testing.T) { + require.Equal(t, "Shortest_Paths", sanitizeBenchNamePart("Shortest Paths")) + require.Equal(t, "n1_-_n3", sanitizeBenchNamePart("n1 -> n3")) + require.Equal(t, "local/phantom", sanitizeBenchNamePart("local/phantom")) + require.Equal(t, "unknown", sanitizeBenchNamePart("")) +} + +func TestWriteMarkdownOmitsSamples(t *testing.T) { + report := testReport() + var out bytes.Buffer + + require.NoError(t, writeReport(&out, report, reportFormatMarkdown)) + + output := out.String() + require.Contains(t, output, "| Match Nodes | base | 2.0ms | 2.0ms | 2.0ms |") + require.False(t, strings.Contains(output, "1000000")) +} + +func testReport() Report { + return Report{ + Driver: "pg", + GitRef: "abcdef0", + Date: "2026-05-11", + Iterations: 2, + Results: 
[]Result{{ + Section: "Match Nodes", + Dataset: "base", + Label: "base", + Stats: Stats{ + Median: 2 * time.Millisecond, + P95: 2 * time.Millisecond, + Max: 2 * time.Millisecond, + }, + Samples: []time.Duration{ + time.Millisecond, + 2 * time.Millisecond, + }, + }}, + } +} diff --git a/cmd/benchmark/runner.go b/cmd/benchmark/runner.go index 52772d20..ba2d5709 100644 --- a/cmd/benchmark/runner.go +++ b/cmd/benchmark/runner.go @@ -37,6 +37,7 @@ type Result struct { Dataset string Label string Stats Stats + Samples []time.Duration } // runScenario executes a scenario N times and returns timing stats. @@ -61,6 +62,7 @@ func runScenario(ctx context.Context, db graph.Database, s Scenario, iterations Dataset: s.Dataset, Label: s.Label, Stats: computeStats(durations), + Samples: durations, }, nil } From 22baa58d281bdabc84d9602c5cd2adc76ee74fa4 Mon Sep 17 00:00:00 2001 From: John Hopper Date: Mon, 11 May 2026 08:15:35 -0700 Subject: [PATCH 51/55] feat: add benchmark diff harness --- .gitignore | 3 + Makefile | 15 ++ README.md | 23 +++ cmd/benchdiff/README.md | 49 ++++++ cmd/benchdiff/benchfmt.go | 217 ++++++++++++++++++++++++++ cmd/benchdiff/benchfmt_test.go | 101 ++++++++++++ cmd/benchdiff/command.go | 91 +++++++++++ cmd/benchdiff/compare.go | 276 +++++++++++++++++++++++++++++++++ cmd/benchdiff/main.go | 180 +++++++++++++++++++++ cmd/benchdiff/markdown.go | 76 +++++++++ cmd/benchdiff/report.go | 123 +++++++++++++++ cmd/benchdiff/run.go | 259 +++++++++++++++++++++++++++++++ 12 files changed, 1413 insertions(+) create mode 100644 cmd/benchdiff/README.md create mode 100644 cmd/benchdiff/benchfmt.go create mode 100644 cmd/benchdiff/benchfmt_test.go create mode 100644 cmd/benchdiff/command.go create mode 100644 cmd/benchdiff/compare.go create mode 100644 cmd/benchdiff/main.go create mode 100644 cmd/benchdiff/markdown.go create mode 100644 cmd/benchdiff/report.go create mode 100644 cmd/benchdiff/run.go diff --git a/.gitignore b/.gitignore index cc490c22..71483163 100644 --- 
a/.gitignore +++ b/.gitignore @@ -6,3 +6,6 @@ # Local integration test datasets integration/testdata/local/ + +# Local benchmark comparison output +.bench/ diff --git a/Makefile b/Makefile index 336ff233..b24ecd80 100644 --- a/Makefile +++ b/Makefile @@ -3,6 +3,12 @@ THIS_FILE := $(lastword $(MAKEFILE_LIST)) # Go configuration GO_CMD ?= go CGO_ENABLED ?= 0 +BENCH ?= . +BENCH_COUNT ?= 10 +BENCH_TIME ?= 1s +BENCH_BASE ?= main +BENCH_TARGET ?= HEAD +BENCH_KIND ?= all # Main packages to test/build MAIN_PACKAGES := $(shell $(GO_CMD) list ./...) @@ -45,6 +51,14 @@ test_integration: @echo "Running all integration tests..." @$(GO_CMD) test -tags 'manual_integration integration' -race -cover -count=1 -p=1 -parallel=1 $(MAIN_PACKAGES) +test_bench: + @echo "Running benchmarks..." + @$(GO_CMD) test -run '^$$' -bench '$(BENCH)' -benchmem -count=$(BENCH_COUNT) -benchtime=$(BENCH_TIME) $(MAIN_PACKAGES) + +bench_diff: + @echo "Running benchmark diff..." + @$(GO_CMD) run ./cmd/benchdiff --base '$(BENCH_BASE)' --target '$(BENCH_TARGET)' --kind '$(BENCH_KIND)' --bench '$(BENCH)' --bench-count '$(BENCH_COUNT)' --benchtime '$(BENCH_TIME)' $(BENCHDIFF_ARGS) + test_neo4j: @echo "Running Neo4j integration tests..." @$(GO_CMD) test -tags integration -race -cover -count=1 -p=1 -parallel=1 $(MAIN_PACKAGES) @@ -96,6 +110,7 @@ help: @echo " test_all - Run all tests including integration tests" @echo " test_integration - Run all integration tests" @echo " test_bench - Run benchmark test" + @echo " bench_diff - Compare benchmarks between commits" @echo " test_neo4j - Run Neo4j integration tests" @echo " test_pg - Run PostgreSQL integration tests" @echo " test_update - Update test cases" diff --git a/README.md b/README.md index baf9a5a2..3d0724ae 100644 --- a/README.md +++ b/README.md @@ -55,3 +55,26 @@ export CONNECTION_STRING="neo4j://neo4j:weneedbetterpasswords@localhost:7687" ``` Use `make test` for unit tests only and `make test_integration` for integration tests only. 
+ +### Benchmarking + +Run the package benchmark suite with: + +```bash +make test_bench +``` + +Use `cmd/benchdiff` to compare benchmarks between two committed refs without changing the active worktree: + +```bash +go run ./cmd/benchdiff -base main -target HEAD -kind unit +``` + +For integration benchmark comparisons, provide the same `CONNECTION_STRING` used by integration tests: + +```bash +export CONNECTION_STRING="postgresql://dawgs:weneedbetterpasswords@localhost:65432/dawgs" +go run ./cmd/benchdiff -base main -target HEAD -kind all -driver pg -fail-regression 10% +``` + +The harness writes raw outputs and a Markdown report under `.bench/runs/` by default. diff --git a/cmd/benchdiff/README.md b/cmd/benchdiff/README.md new file mode 100644 index 00000000..82cb2734 --- /dev/null +++ b/cmd/benchdiff/README.md @@ -0,0 +1,49 @@ +# Benchdiff + +Compares the existing benchmark suites between two committed git refs without changing the active worktree. + +## Usage + +```bash +# Unit Go benchmarks only +go run ./cmd/benchdiff -base main -target HEAD -kind unit + +# Unit and integration benchmarks +export CONNECTION_STRING="postgresql://dawgs:weneedbetterpasswords@localhost:65432/dawgs" +go run ./cmd/benchdiff -base main -target HEAD -kind all -driver pg + +# Fail if a benchmark median regresses by more than 10% +go run ./cmd/benchdiff -base main -target HEAD -kind unit -fail-regression 10% +``` + +`benchdiff` creates detached worktrees under `.bench/`, runs each selected benchmark suite, writes raw output, and +produces a Markdown report. Worktrees are removed by default after the run; pass `-keep-worktrees` to preserve them. 
+ +## Flags + +| Flag | Default | Description | +|------|---------|-------------| +| `-base` | `main` | Base git ref | +| `-target` | `HEAD` | Target git ref | +| `-kind` | `all` | Benchmark kind (`all`, `unit`, `integration`) | +| `-packages` | `./...` | Package list for Go benchmarks | +| `-bench` | `.` | Go benchmark regexp | +| `-bench-count` | `10` | Go benchmark repetition count | +| `-benchtime` | `1s` | Go benchmark benchtime | +| `-driver` | `pg` | Integration benchmark database driver | +| `-connection` | | Integration connection string (or `CONNECTION_STRING`) | +| `-dataset` | | Run only this integration dataset | +| `-local-dataset` | | Add a local integration dataset | +| `-dataset-dir` | `integration/testdata` | Integration testdata directory | +| `-integration-iterations` | `10` | Timed iterations per integration scenario | +| `-out` | `.bench/runs/<base>..<target>-<timestamp>` | Output directory | +| `-benchstat` | `auto` | `benchstat` command, `auto`, or `none` | +| `-fail-regression` | `0` | Median regression percentage that fails the command | +| `-keep-worktrees` | `false` | Preserve temporary worktrees | + +If `benchstat` is not on `PATH` and `-benchstat auto` is used, the harness falls back to +`go run golang.org/x/perf/cmd/benchstat@latest`. + +Integration comparisons use native `cmd/benchmark -format benchfmt` when both refs support it. If either ref predates +that flag, the harness runs both refs in Markdown compatibility mode and compares each scenario's median as a single +`ns/op` sample. diff --git a/cmd/benchdiff/benchfmt.go b/cmd/benchdiff/benchfmt.go new file mode 100644 index 00000000..818cb1a3 --- /dev/null +++ b/cmd/benchdiff/benchfmt.go @@ -0,0 +1,217 @@ +// Copyright 2026 Specter Ops, Inc. +// +// Licensed under the Apache License, Version 2.0 +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +package main + +import ( + "bufio" + "bytes" + "fmt" + "io" + "math" + "os" + "regexp" + "runtime" + "sort" + "strconv" + "strings" + "time" + "unicode" +) + +var benchmarkLinePattern = regexp.MustCompile(`^(Benchmark\S+)\s+\d+\s+([0-9]+(?:\.[0-9]+)?)\s+ns/op\b`) + +type benchmarkSamples map[string][]float64 + +type regression struct { + Name string + BaseMedianNS float64 + TargetMedianNS float64 + Percent float64 +} + +func parseBenchfmtNS(data []byte) benchmarkSamples { + samples := benchmarkSamples{} + scanner := bufio.NewScanner(bytes.NewReader(data)) + + for scanner.Scan() { + matches := benchmarkLinePattern.FindStringSubmatch(scanner.Text()) + if len(matches) != 3 { + continue + } + + ns, err := strconv.ParseFloat(matches[2], 64) + if err != nil { + continue + } + + samples[matches[1]] = append(samples[matches[1]], ns) + } + + return samples +} + +func parseBenchfmtNSFile(path string) (benchmarkSamples, error) { + data, err := os.ReadFile(path) + if err != nil { + return nil, err + } + + return parseBenchfmtNS(data), nil +} + +func findRegressions(base, target benchmarkSamples, threshold float64) []regression { + if threshold <= 0 { + return nil + } + + var regressions []regression + for name, baseValues := range base { + targetValues := target[name] + if len(baseValues) == 0 || len(targetValues) == 0 { + continue + } + + baseMedian := median(baseValues) + targetMedian := median(targetValues) + if baseMedian <= 0 { + continue + } + + percent := ((targetMedian - baseMedian) / baseMedian) * 100 + 
if percent > threshold { + regressions = append(regressions, regression{ + Name: name, + BaseMedianNS: baseMedian, + TargetMedianNS: targetMedian, + Percent: percent, + }) + } + } + + sort.Slice(regressions, func(i, j int) bool { + return regressions[i].Percent > regressions[j].Percent + }) + + return regressions +} + +func median(values []float64) float64 { + sorted := append([]float64(nil), values...) + sort.Float64s(sorted) + + mid := len(sorted) / 2 + if len(sorted)%2 == 0 { + return (sorted[mid-1] + sorted[mid]) / 2 + } + + return sorted[mid] +} + +func writeIntegrationBenchfmt(w io.Writer, driver string, rows []markdownBenchmarkRow) error { + fmt.Fprintf(w, "goos: %s\n", runtime.GOOS) + fmt.Fprintf(w, "goarch: %s\n", runtime.GOARCH) + fmt.Fprintln(w, "pkg: github.com/specterops/dawgs/cmd/benchmark") + + procs := runtime.GOMAXPROCS(0) + for _, row := range rows { + fmt.Fprintf(w, "%s-%d\t1\t%d ns/op\n", integrationBenchmarkName(driver, row.Dataset, row.Query), procs, row.Median.Nanoseconds()) + } + + return nil +} + +func integrationBenchmarkName(driver, dataset, query string) string { + return strings.Join([]string{ + "BenchmarkDawgsIntegration", + sanitizeBenchNamePart(driver), + sanitizeBenchNamePart(dataset), + sanitizeBenchNamePart(query), + }, "/") +} + +func sanitizeBenchNamePart(value string) string { + var builder strings.Builder + lastUnderscore := false + + for _, char := range value { + switch { + case char == '/' || char == '-' || char == '_': + if char == '_' { + if !lastUnderscore { + builder.WriteRune(char) + } + lastUnderscore = true + } else { + builder.WriteRune(char) + lastUnderscore = false + } + case unicode.IsLetter(char) || unicode.IsDigit(char): + builder.WriteRune(char) + lastUnderscore = false + case unicode.IsSpace(char): + if !lastUnderscore { + builder.WriteByte('_') + } + lastUnderscore = true + default: + if !lastUnderscore { + builder.WriteByte('_') + } + lastUnderscore = true + } + } + + if builder.Len() == 0 { + return 
"unknown" + } + + return builder.String() +} + +func parseBenchmarkDuration(value string) (time.Duration, error) { + trimmed := strings.TrimSpace(value) + if trimmed == "" || trimmed == "-" { + return 0, fmt.Errorf("empty benchmark duration") + } + + unitStart := len(trimmed) + for idx, char := range trimmed { + if (char < '0' || char > '9') && char != '.' { + unitStart = idx + break + } + } + + number, err := strconv.ParseFloat(strings.TrimSpace(trimmed[:unitStart]), 64) + if err != nil { + return 0, err + } + + unit := strings.TrimSpace(trimmed[unitStart:]) + switch unit { + case "ns": + return time.Duration(math.Round(number)), nil + case "us": + return time.Duration(math.Round(number * float64(time.Microsecond))), nil + case "ms": + return time.Duration(math.Round(number * float64(time.Millisecond))), nil + case "s": + return time.Duration(math.Round(number * float64(time.Second))), nil + default: + return 0, fmt.Errorf("unsupported benchmark duration unit %q", unit) + } +} diff --git a/cmd/benchdiff/benchfmt_test.go b/cmd/benchdiff/benchfmt_test.go new file mode 100644 index 00000000..1e37b91e --- /dev/null +++ b/cmd/benchdiff/benchfmt_test.go @@ -0,0 +1,101 @@ +// Copyright 2026 Specter Ops, Inc. +// +// Licensed under the Apache License, Version 2.0 +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +// SPDX-License-Identifier: Apache-2.0 + +package main + +import ( + "bytes" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestParseBenchfmtNS(t *testing.T) { + samples := parseBenchfmtNS([]byte(` +goos: linux +BenchmarkThing-12 10 100.5 ns/op 1 B/op +BenchmarkThing-12 10 120 ns/op 1 B/op +BenchmarkOther/sub-12 1 200 ns/op +`)) + + require.Equal(t, []float64{100.5, 120}, samples["BenchmarkThing-12"]) + require.Equal(t, []float64{200}, samples["BenchmarkOther/sub-12"]) +} + +func TestFindRegressions(t *testing.T) { + base := benchmarkSamples{ + "BenchmarkFast-12": {100, 110, 120}, + "BenchmarkSame-12": {100, 100, 100}, + } + target := benchmarkSamples{ + "BenchmarkFast-12": {140, 150, 160}, + "BenchmarkSame-12": {105, 105, 105}, + } + + regressions := findRegressions(base, target, 10) + + require.Len(t, regressions, 1) + require.Equal(t, "BenchmarkFast-12", regressions[0].Name) + require.InDelta(t, 36.36, regressions[0].Percent, 0.01) +} + +func TestParseBenchmarkMarkdown(t *testing.T) { + rows := parseBenchmarkMarkdown([]byte(` +| Query | Dataset | Median | P95 | Max | +|-------|---------|-------:|----:|----:| +| Match Nodes | base | 0.14ms | 0.22ms | 0.31ms | +| Match Edges | base | 464ms | 604ms | 604ms | +`)) + + require.Equal(t, []markdownBenchmarkRow{ + {Query: "Match Nodes", Dataset: "base", Median: 140 * time.Microsecond}, + {Query: "Match Edges", Dataset: "base", Median: 464 * time.Millisecond}, + }, rows) +} + +func TestWriteIntegrationBenchfmt(t *testing.T) { + var out bytes.Buffer + rows := []markdownBenchmarkRow{{ + Query: "Shortest Paths / n1 -> n3", + Dataset: "base", + Median: time.Millisecond, + }} + + require.NoError(t, writeIntegrationBenchfmt(&out, "pg", rows)) + require.Contains(t, out.String(), "BenchmarkDawgsIntegration/pg/base/Shortest_Paths_/_n1_-_n3-") + require.Contains(t, out.String(), "\t1\t1000000 ns/op") +} + +func TestParseRegressionThreshold(t *testing.T) { + threshold, err := 
parseRegressionThreshold("10%") + require.NoError(t, err) + require.Equal(t, 10.0, threshold) + + threshold, err = parseRegressionThreshold("2.5") + require.NoError(t, err) + require.Equal(t, 2.5, threshold) + + _, err = parseRegressionThreshold("-1") + require.Error(t, err) +} + +func TestValidateBenchtime(t *testing.T) { + require.NoError(t, validateBenchtime("1s")) + require.NoError(t, validateBenchtime("100x")) + require.Error(t, validateBenchtime("0x")) + require.Error(t, validateBenchtime("soon")) +} diff --git a/cmd/benchdiff/command.go b/cmd/benchdiff/command.go new file mode 100644 index 00000000..eae51723 --- /dev/null +++ b/cmd/benchdiff/command.go @@ -0,0 +1,91 @@ +// Copyright 2026 Specter Ops, Inc. +// +// Licensed under the Apache License, Version 2.0 +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +package main + +import ( + "bytes" + "context" + "fmt" + "os" + "os/exec" + "strings" +) + +func gitOutput(ctx context.Context, dir string, args ...string) (string, error) { + output, err := runCommand(ctx, dir, nil, "git", args...) + if err != nil { + return "", err + } + + return strings.TrimSpace(string(output)), nil +} + +func runCommand(ctx context.Context, dir string, env []string, name string, args ...string) ([]byte, error) { + cmd := exec.CommandContext(ctx, name, args...) + cmd.Dir = dir + cmd.Env = os.Environ() + if len(env) > 0 { + cmd.Env = append(cmd.Env, env...) 
+ } + + output, err := cmd.CombinedOutput() + if err != nil { + return output, commandError{ + Name: name, + Args: args, + Err: err, + Output: output, + } + } + + return output, nil +} + +type commandError struct { + Name string + Args []string + Err error + Output []byte +} + +func (err commandError) Error() string { + var builder strings.Builder + builder.WriteString(err.Name) + if len(err.Args) > 0 { + builder.WriteByte(' ') + builder.WriteString(strings.Join(err.Args, " ")) + } + builder.WriteString(": ") + builder.WriteString(err.Err.Error()) + + output := bytes.TrimSpace(err.Output) + if len(output) > 0 { + builder.WriteString("\n") + builder.Write(outputTail(output, 4096)) + } + + return builder.String() +} + +func outputTail(output []byte, limit int) []byte { + if len(output) <= limit { + return output + } + + prefix := []byte(fmt.Sprintf("... truncated %d bytes ...\n", len(output)-limit)) + return append(prefix, output[len(output)-limit:]...) +} diff --git a/cmd/benchdiff/compare.go b/cmd/benchdiff/compare.go new file mode 100644 index 00000000..3a5c99f2 --- /dev/null +++ b/cmd/benchdiff/compare.go @@ -0,0 +1,276 @@ +// Copyright 2026 Specter Ops, Inc. +// +// Licensed under the Apache License, Version 2.0 +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +// SPDX-License-Identifier: Apache-2.0 + +package main + +import ( + "bytes" + "context" + "fmt" + "os" + "os/exec" + "path/filepath" + "strconv" + "strings" +) + +func runUnitComparison(ctx context.Context, cfg resolvedConfig, baseWorktree, targetWorktree string) (comparison, error) { + outDir := filepath.Join(cfg.OutDirAbs, "unit") + if err := os.MkdirAll(outDir, 0755); err != nil { + return comparison{}, err + } + + baseFile := filepath.Join(outDir, "base.txt") + targetFile := filepath.Join(outDir, "target.txt") + + if err := runGoBenchmarks(ctx, cfg, baseWorktree, baseFile); err != nil { + return comparison{}, err + } + if err := runGoBenchmarks(ctx, cfg, targetWorktree, targetFile); err != nil { + return comparison{}, err + } + + benchstatOutput, err := runBenchstat(ctx, cfg, baseFile, targetFile) + if err != nil { + return comparison{}, err + } + benchstatFile := filepath.Join(outDir, "benchstat.txt") + if err := os.WriteFile(benchstatFile, benchstatOutput, 0644); err != nil { + return comparison{}, err + } + + regressions, err := regressionsForFiles(baseFile, targetFile, cfg.Threshold) + if err != nil { + return comparison{}, err + } + + return comparison{ + Name: "Unit Benchmarks", + BaseFile: baseFile, + TargetFile: targetFile, + BenchstatFile: benchstatFile, + Benchstat: string(benchstatOutput), + Regressions: regressions, + }, nil +} + +func runGoBenchmarks(ctx context.Context, cfg resolvedConfig, worktree, outputPath string) error { + args := []string{ + "test", + "-run", "^$", + "-bench", cfg.Bench, + "-benchmem", + "-count", strconv.Itoa(cfg.BenchCount), + "-benchtime", cfg.Benchtime, + } + args = append(args, strings.Fields(cfg.Packages)...) + + output, err := runCommand(ctx, worktree, nil, "go", args...) 
+ if writeErr := os.WriteFile(outputPath, output, 0644); writeErr != nil { + return writeErr + } + if err != nil { + return fmt.Errorf("run Go benchmarks in %s: %w", worktree, err) + } + + return nil +} + +func runIntegrationComparison(ctx context.Context, cfg resolvedConfig, baseWorktree, targetWorktree string) (comparison, error) { + outDir := filepath.Join(cfg.OutDirAbs, "integration") + binDir := filepath.Join(cfg.OutDirAbs, "bin") + if err := os.MkdirAll(outDir, 0755); err != nil { + return comparison{}, err + } + if err := os.MkdirAll(binDir, 0755); err != nil { + return comparison{}, err + } + + baseBinary := filepath.Join(binDir, "benchmark-base") + targetBinary := filepath.Join(binDir, "benchmark-target") + if err := buildBenchmarkBinary(ctx, baseWorktree, baseBinary); err != nil { + return comparison{}, err + } + if err := buildBenchmarkBinary(ctx, targetWorktree, targetBinary); err != nil { + return comparison{}, err + } + + baseSupportsBenchfmt := benchmarkBinarySupportsFormat(ctx, baseBinary) + targetSupportsBenchfmt := benchmarkBinarySupportsFormat(ctx, targetBinary) + useNativeBenchfmt := baseSupportsBenchfmt && targetSupportsBenchfmt + + baseFile := filepath.Join(outDir, "base.bench") + targetFile := filepath.Join(outDir, "target.bench") + var notes []string + if useNativeBenchfmt { + if err := runBenchmarkBinary(ctx, cfg, baseWorktree, baseBinary, baseFile, true); err != nil { + return comparison{}, err + } + if err := runBenchmarkBinary(ctx, cfg, targetWorktree, targetBinary, targetFile, true); err != nil { + return comparison{}, err + } + notes = append(notes, "Used native benchfmt output from cmd/benchmark.") + } else { + baseMarkdown := filepath.Join(outDir, "base.md") + targetMarkdown := filepath.Join(outDir, "target.md") + + if err := runBenchmarkBinary(ctx, cfg, baseWorktree, baseBinary, baseMarkdown, false); err != nil { + return comparison{}, err + } + if err := runBenchmarkBinary(ctx, cfg, targetWorktree, targetBinary, targetMarkdown, 
false); err != nil { + return comparison{}, err + } + + if err := markdownFileToBenchfmt(baseMarkdown, baseFile, cfg.Driver); err != nil { + return comparison{}, err + } + if err := markdownFileToBenchfmt(targetMarkdown, targetFile, cfg.Driver); err != nil { + return comparison{}, err + } + + notes = append(notes, "Used Markdown compatibility mode because at least one ref does not support cmd/benchmark -format benchfmt.") + } + + benchstatOutput, err := runBenchstat(ctx, cfg, baseFile, targetFile) + if err != nil { + return comparison{}, err + } + benchstatFile := filepath.Join(outDir, "benchstat.txt") + if err := os.WriteFile(benchstatFile, benchstatOutput, 0644); err != nil { + return comparison{}, err + } + + regressions, err := regressionsForFiles(baseFile, targetFile, cfg.Threshold) + if err != nil { + return comparison{}, err + } + + return comparison{ + Name: "Integration Benchmarks", + BaseFile: baseFile, + TargetFile: targetFile, + BenchstatFile: benchstatFile, + Benchstat: string(benchstatOutput), + Notes: notes, + Regressions: regressions, + }, nil +} + +func buildBenchmarkBinary(ctx context.Context, worktree, outputPath string) error { + output, err := runCommand(ctx, worktree, nil, "go", "build", "-o", outputPath, "./cmd/benchmark") + if err != nil { + return fmt.Errorf("build cmd/benchmark in %s: %w", worktree, err) + } + if len(bytes.TrimSpace(output)) > 0 { + logPath := outputPath + ".log" + if writeErr := os.WriteFile(logPath, output, 0644); writeErr != nil { + return writeErr + } + } + + return nil +} + +func benchmarkBinarySupportsFormat(ctx context.Context, binaryPath string) bool { + output, err := runCommand(ctx, "", nil, binaryPath, "-h") + if err != nil { + return false + } + + return bytes.Contains(output, []byte("-format")) +} + +func runBenchmarkBinary(ctx context.Context, cfg resolvedConfig, worktree, binaryPath, outputPath string, benchfmt bool) error { + args := []string{ + "-driver", cfg.Driver, + "-iterations", 
strconv.Itoa(cfg.IntegrationIterations), + "-dataset-dir", cfg.DatasetDirAbs, + "-output", outputPath, + } + if benchfmt { + args = append(args, "-format", "benchfmt") + } + if cfg.Dataset != "" { + args = append(args, "-dataset", cfg.Dataset) + } + if cfg.LocalDataset != "" { + args = append(args, "-local-dataset", cfg.LocalDataset) + } + + output, err := runCommand(ctx, worktree, []string{"CONNECTION_STRING=" + cfg.Connection}, binaryPath, args...) + logPath := outputPath + ".log" + if writeErr := os.WriteFile(logPath, output, 0644); writeErr != nil { + return writeErr + } + if err != nil { + return fmt.Errorf("run integration benchmark in %s: %w", worktree, err) + } + + return nil +} + +func markdownFileToBenchfmt(markdownPath, benchfmtPath, driver string) error { + data, err := os.ReadFile(markdownPath) + if err != nil { + return err + } + + var output bytes.Buffer + if err := writeIntegrationBenchfmt(&output, driver, parseBenchmarkMarkdown(data)); err != nil { + return err + } + + return os.WriteFile(benchfmtPath, output.Bytes(), 0644) +} + +func runBenchstat(ctx context.Context, cfg resolvedConfig, baseFile, targetFile string) ([]byte, error) { + if cfg.Benchstat == "none" { + return []byte("benchstat skipped\n"), nil + } + + if cfg.Benchstat == "" || cfg.Benchstat == "auto" { + if benchstatPath, err := exec.LookPath("benchstat"); err == nil { + return runCommand(ctx, cfg.Root, nil, benchstatPath, baseFile, targetFile) + } + + return runCommand(ctx, cfg.Root, nil, "go", "run", "golang.org/x/perf/cmd/benchstat@latest", baseFile, targetFile) + } + + fields := strings.Fields(cfg.Benchstat) + if len(fields) == 0 { + return nil, fmt.Errorf("empty benchstat command") + } + + args := append(fields[1:], baseFile, targetFile) + return runCommand(ctx, cfg.Root, nil, fields[0], args...) 
+} + +func regressionsForFiles(baseFile, targetFile string, threshold float64) ([]regression, error) { + if threshold <= 0 { + return nil, nil + } + + base, err := parseBenchfmtNSFile(baseFile) + if err != nil { + return nil, err + } + target, err := parseBenchfmtNSFile(targetFile) + if err != nil { + return nil, err + } + + return findRegressions(base, target, threshold), nil +} diff --git a/cmd/benchdiff/main.go b/cmd/benchdiff/main.go new file mode 100644 index 00000000..df921705 --- /dev/null +++ b/cmd/benchdiff/main.go @@ -0,0 +1,180 @@ +// Copyright 2026 Specter Ops, Inc. +// +// Licensed under the Apache License, Version 2.0 +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +// SPDX-License-Identifier: Apache-2.0 + +package main + +import ( + "context" + "flag" + "fmt" + "os" + "strconv" + "strings" + "time" +) + +const ( + benchKindAll = "all" + benchKindIntegration = "integration" + benchKindUnit = "unit" +) + +type config struct { + BaseRef string + TargetRef string + Kind string + Packages string + Bench string + BenchCount int + Benchtime string + Driver string + Connection string + Dataset string + LocalDataset string + DatasetDir string + IntegrationIterations int + OutDir string + Benchstat string + FailRegression string + KeepWorktrees bool +} + +type resolvedConfig struct { + config + Root string + BaseSHA string + TargetSHA string + BaseShortSHA string + TargetShortSHA string + DatasetDirAbs string + OutDirAbs string + Threshold float64 +} + +func main() { + cfg, err := parseConfig(os.Args[1:]) + if err != nil { + fmt.Fprintln(os.Stderr, err) + os.Exit(2) + } + + if err := run(context.Background(), cfg); err != nil { + fmt.Fprintln(os.Stderr, err) + os.Exit(1) + } +} + +func parseConfig(args []string) (config, error) { + cfg := config{} + flags := flag.NewFlagSet("benchdiff", flag.ContinueOnError) + flags.SetOutput(os.Stderr) + + flags.StringVar(&cfg.BaseRef, "base", "main", "base git ref to benchmark") + flags.StringVar(&cfg.TargetRef, "target", "HEAD", "target git ref to benchmark") + flags.StringVar(&cfg.Kind, "kind", benchKindAll, "benchmark kind: all, unit, integration") + flags.StringVar(&cfg.Packages, "packages", "./...", "package list for Go benchmarks") + flags.StringVar(&cfg.Bench, "bench", ".", "Go benchmark regexp") + flags.IntVar(&cfg.BenchCount, "bench-count", 10, "Go benchmark repetition count") + flags.StringVar(&cfg.Benchtime, "benchtime", "1s", "Go benchmark benchtime") + flags.StringVar(&cfg.Driver, "driver", "pg", "integration benchmark database driver") + flags.StringVar(&cfg.Connection, "connection", "", "integration database connection string (or CONNECTION_STRING)") + 
flags.StringVar(&cfg.Dataset, "dataset", "", "run only this integration dataset") + flags.StringVar(&cfg.LocalDataset, "local-dataset", "", "additional local integration dataset") + flags.StringVar(&cfg.DatasetDir, "dataset-dir", "integration/testdata", "integration testdata directory") + flags.IntVar(&cfg.IntegrationIterations, "integration-iterations", 10, "timed iterations per integration scenario") + flags.StringVar(&cfg.OutDir, "out", "", "output directory") + flags.StringVar(&cfg.Benchstat, "benchstat", "auto", "benchstat command, auto, or none") + flags.StringVar(&cfg.FailRegression, "fail-regression", "0", "fail when median ns/op regression exceeds this percent, e.g. 10%") + flags.BoolVar(&cfg.KeepWorktrees, "keep-worktrees", false, "keep temporary git worktrees") + + if err := flags.Parse(args); err != nil { + return config{}, err + } + if flags.NArg() != 0 { + return config{}, fmt.Errorf("unexpected positional arguments: %s", strings.Join(flags.Args(), " ")) + } + if !isBenchKind(cfg.Kind) { + return config{}, fmt.Errorf("unsupported benchmark kind %q", cfg.Kind) + } + if cfg.BenchCount < 1 { + return config{}, fmt.Errorf("bench-count must be at least 1") + } + if cfg.IntegrationIterations < 1 { + return config{}, fmt.Errorf("integration-iterations must be at least 1") + } + if err := validateBenchtime(cfg.Benchtime); err != nil { + return config{}, err + } + if _, err := parseRegressionThreshold(cfg.FailRegression); err != nil { + return config{}, err + } + + return cfg, nil +} + +func isBenchKind(kind string) bool { + switch kind { + case benchKindAll, benchKindIntegration, benchKindUnit: + return true + default: + return false + } +} + +func (cfg config) runsUnitBenchmarks() bool { + return cfg.Kind == benchKindAll || cfg.Kind == benchKindUnit +} + +func (cfg config) runsIntegrationBenchmarks() bool { + return cfg.Kind == benchKindAll || cfg.Kind == benchKindIntegration +} + +func parseRegressionThreshold(value string) (float64, error) { + trimmed := 
strings.TrimSpace(value) + trimmed = strings.TrimSuffix(trimmed, "%") + + if trimmed == "" { + return 0, fmt.Errorf("fail-regression must be a non-negative percent") + } + + threshold, err := strconv.ParseFloat(trimmed, 64) + if err != nil { + return 0, fmt.Errorf("invalid fail-regression %q: %w", value, err) + } + if threshold < 0 { + return 0, fmt.Errorf("fail-regression must be a non-negative percent") + } + + return threshold, nil +} + +func validateBenchtime(value string) error { + trimmed := strings.TrimSpace(value) + if strings.HasSuffix(trimmed, "x") { + count, err := strconv.Atoi(strings.TrimSuffix(trimmed, "x")) + if err != nil || count < 1 { + return fmt.Errorf("invalid benchtime %q", value) + } + + return nil + } + + if _, err := time.ParseDuration(trimmed); err != nil { + return fmt.Errorf("invalid benchtime %q: %w", value, err) + } + + return nil +} diff --git a/cmd/benchdiff/markdown.go b/cmd/benchdiff/markdown.go new file mode 100644 index 00000000..b472b3b3 --- /dev/null +++ b/cmd/benchdiff/markdown.go @@ -0,0 +1,76 @@ +// Copyright 2026 Specter Ops, Inc. +// +// Licensed under the Apache License, Version 2.0 +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +// SPDX-License-Identifier: Apache-2.0 + +package main + +import ( + "bufio" + "bytes" + "strings" + "time" +) + +type markdownBenchmarkRow struct { + Query string + Dataset string + Median time.Duration +} + +func parseBenchmarkMarkdown(data []byte) []markdownBenchmarkRow { + var rows []markdownBenchmarkRow + scanner := bufio.NewScanner(bytes.NewReader(data)) + + for scanner.Scan() { + columns := splitMarkdownTableRow(scanner.Text()) + if len(columns) < 5 { + continue + } + if columns[0] == "Query" || strings.HasPrefix(columns[0], "---") { + continue + } + + median, err := parseBenchmarkDuration(columns[2]) + if err != nil { + continue + } + + rows = append(rows, markdownBenchmarkRow{ + Query: columns[0], + Dataset: columns[1], + Median: median, + }) + } + + return rows +} + +func splitMarkdownTableRow(line string) []string { + trimmed := strings.TrimSpace(line) + if !strings.HasPrefix(trimmed, "|") || !strings.HasSuffix(trimmed, "|") { + return nil + } + + trimmed = strings.TrimPrefix(trimmed, "|") + trimmed = strings.TrimSuffix(trimmed, "|") + + rawColumns := strings.Split(trimmed, "|") + columns := make([]string, 0, len(rawColumns)) + for _, column := range rawColumns { + columns = append(columns, strings.TrimSpace(column)) + } + + return columns +} diff --git a/cmd/benchdiff/report.go b/cmd/benchdiff/report.go new file mode 100644 index 00000000..43d77ec5 --- /dev/null +++ b/cmd/benchdiff/report.go @@ -0,0 +1,123 @@ +// Copyright 2026 Specter Ops, Inc. +// +// Licensed under the Apache License, Version 2.0 +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +package main + +import ( + "bytes" + "fmt" + "os" + "path/filepath" + "runtime" + "time" +) + +func writeRunReport(path string, summary runSummary) error { + var out bytes.Buffer + cfg := summary.Config + + fmt.Fprintln(&out, "# Benchmark Diff") + fmt.Fprintln(&out) + fmt.Fprintln(&out, "| Field | Value |") + fmt.Fprintln(&out, "|-------|-------|") + fmt.Fprintf(&out, "| Base | `%s` (`%s`) |\n", cfg.BaseRef, cfg.BaseShortSHA) + fmt.Fprintf(&out, "| Target | `%s` (`%s`) |\n", cfg.TargetRef, cfg.TargetShortSHA) + fmt.Fprintf(&out, "| Started | %s |\n", summary.StartedAt.UTC().Format(time.RFC3339)) + fmt.Fprintf(&out, "| Finished | %s |\n", summary.FinishedAt.UTC().Format(time.RFC3339)) + fmt.Fprintf(&out, "| Go | %s |\n", summary.GoVersion) + fmt.Fprintf(&out, "| Platform | %s/%s |\n", runtime.GOOS, runtime.GOARCH) + fmt.Fprintf(&out, "| Kind | `%s` |\n", cfg.Kind) + fmt.Fprintf(&out, "| Output | `%s` |\n", cfg.OutDirAbs) + if cfg.runsIntegrationBenchmarks() { + fmt.Fprintf(&out, "| Driver | `%s` |\n", cfg.Driver) + fmt.Fprintf(&out, "| Dataset Dir | `%s` |\n", cfg.DatasetDirAbs) + fmt.Fprintf(&out, "| Integration Iterations | %d |\n", cfg.IntegrationIterations) + } + if cfg.runsUnitBenchmarks() { + fmt.Fprintf(&out, "| Packages | `%s` |\n", cfg.Packages) + fmt.Fprintf(&out, "| Bench | `%s` |\n", cfg.Bench) + fmt.Fprintf(&out, "| Bench Count | %d |\n", cfg.BenchCount) + fmt.Fprintf(&out, "| Benchtime | `%s` |\n", cfg.Benchtime) + } + if cfg.Threshold > 0 { + fmt.Fprintf(&out, "| Regression Failure Threshold | %.2f%% |\n", cfg.Threshold) + } + fmt.Fprintln(&out) + + for _, comparison := range summary.Comparisons { + fmt.Fprintf(&out, "## %s\n\n", comparison.Name) + for _, note := range comparison.Notes { + fmt.Fprintf(&out, "- %s\n", note) + } + if len(comparison.Notes) > 0 { + fmt.Fprintln(&out) + } + + 
fmt.Fprintf(&out, "- Base raw: `%s`\n", relOrAbs(cfg.OutDirAbs, comparison.BaseFile)) + fmt.Fprintf(&out, "- Target raw: `%s`\n", relOrAbs(cfg.OutDirAbs, comparison.TargetFile)) + fmt.Fprintf(&out, "- Benchstat: `%s`\n\n", relOrAbs(cfg.OutDirAbs, comparison.BenchstatFile)) + + fmt.Fprintln(&out, "```text") + fmt.Fprint(&out, comparison.Benchstat) + if len(comparison.Benchstat) == 0 || comparison.Benchstat[len(comparison.Benchstat)-1] != '\n' { + fmt.Fprintln(&out) + } + fmt.Fprintln(&out, "```") + fmt.Fprintln(&out) + + writeRegressionSection(&out, comparison, cfg.Threshold) + } + + if err := os.MkdirAll(filepath.Dir(path), 0755); err != nil { + return err + } + + return os.WriteFile(path, out.Bytes(), 0644) +} + +func writeRegressionSection(out *bytes.Buffer, comparison comparison, threshold float64) { + if threshold <= 0 { + return + } + + fmt.Fprintf(out, "### Regressions Over %.2f%%\n\n", threshold) + if len(comparison.Regressions) == 0 { + fmt.Fprintln(out, "None.") + fmt.Fprintln(out) + return + } + + fmt.Fprintln(out, "| Benchmark | Base Median | Target Median | Change |") + fmt.Fprintln(out, "|-----------|------------:|--------------:|-------:|") + for _, regression := range comparison.Regressions { + fmt.Fprintf(out, "| `%s` | %.0f ns/op | %.0f ns/op | +%.2f%% |\n", + regression.Name, + regression.BaseMedianNS, + regression.TargetMedianNS, + regression.Percent, + ) + } + fmt.Fprintln(out) +} + +func relOrAbs(base, path string) string { + rel, err := filepath.Rel(base, path) + if err != nil || rel == "." || len(rel) >= len(path) { + return path + } + + return rel +} diff --git a/cmd/benchdiff/run.go b/cmd/benchdiff/run.go new file mode 100644 index 00000000..c8f43137 --- /dev/null +++ b/cmd/benchdiff/run.go @@ -0,0 +1,259 @@ +// Copyright 2026 Specter Ops, Inc. +// +// Licensed under the Apache License, Version 2.0 +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// SPDX-License-Identifier: Apache-2.0 + +package main + +import ( + "context" + "fmt" + "os" + "path/filepath" + "runtime" + "strings" + "time" +) + +type comparison struct { + Name string + BaseFile string + TargetFile string + BenchstatFile string + Benchstat string + Notes []string + Regressions []regression +} + +type runSummary struct { + Config resolvedConfig + GoVersion string + StartedAt time.Time + FinishedAt time.Time + Comparisons []comparison + ReportPath string +} + +func run(ctx context.Context, cfg config) error { + resolved, err := resolveConfig(ctx, cfg) + if err != nil { + return err + } + + summary := runSummary{ + Config: resolved, + GoVersion: runtime.Version(), + StartedAt: time.Now(), + } + + if err := os.MkdirAll(resolved.OutDirAbs, 0755); err != nil { + return fmt.Errorf("create output directory: %w", err) + } + + worktreeRoot := filepath.Join(resolved.OutDirAbs, "worktrees") + baseWorktree := filepath.Join(worktreeRoot, "base") + targetWorktree := filepath.Join(worktreeRoot, "target") + + fmt.Fprintf(os.Stderr, "preparing benchmark worktrees for %s and %s...\n", resolved.BaseShortSHA, resolved.TargetShortSHA) + + if err := addWorktree(ctx, resolved.Root, baseWorktree, resolved.BaseSHA); err != nil { + return err + } + removeBase := true + defer func() { + if removeBase && !resolved.KeepWorktrees { + _ = removeWorktree(context.Background(), resolved.Root, baseWorktree) + } + }() + + if err := addWorktree(ctx, resolved.Root, targetWorktree, resolved.TargetSHA); err != nil { + return err + } + removeTarget := true + 
defer func() { + if removeTarget && !resolved.KeepWorktrees { + _ = removeWorktree(context.Background(), resolved.Root, targetWorktree) + } + }() + + if resolved.runsUnitBenchmarks() { + fmt.Fprintln(os.Stderr, "running unit benchmarks...") + unitComparison, err := runUnitComparison(ctx, resolved, baseWorktree, targetWorktree) + if err != nil { + return err + } + summary.Comparisons = append(summary.Comparisons, unitComparison) + } + + if resolved.runsIntegrationBenchmarks() { + fmt.Fprintln(os.Stderr, "running integration benchmarks...") + integrationComparison, err := runIntegrationComparison(ctx, resolved, baseWorktree, targetWorktree) + if err != nil { + return err + } + summary.Comparisons = append(summary.Comparisons, integrationComparison) + } + + removeBase = false + removeTarget = false + if !resolved.KeepWorktrees { + if err := removeWorktree(ctx, resolved.Root, baseWorktree); err != nil { + return err + } + if err := removeWorktree(ctx, resolved.Root, targetWorktree); err != nil { + return err + } + } + + summary.FinishedAt = time.Now() + summary.ReportPath = filepath.Join(resolved.OutDirAbs, "report.md") + if err := writeRunReport(summary.ReportPath, summary); err != nil { + return err + } + + fmt.Fprintf(os.Stderr, "wrote benchmark diff report: %s\n", summary.ReportPath) + + if regressions := summary.regressions(); len(regressions) > 0 && resolved.Threshold > 0 { + return fmt.Errorf("%d benchmark regressions exceeded %.2f%%; see %s", len(regressions), resolved.Threshold, summary.ReportPath) + } + + return nil +} + +func resolveConfig(ctx context.Context, cfg config) (resolvedConfig, error) { + root, err := gitOutput(ctx, "", "rev-parse", "--show-toplevel") + if err != nil { + return resolvedConfig{}, err + } + + baseSHA, err := resolveCommit(ctx, root, cfg.BaseRef) + if err != nil { + return resolvedConfig{}, err + } + targetSHA, err := resolveCommit(ctx, root, cfg.TargetRef) + if err != nil { + return resolvedConfig{}, err + } + + baseShortSHA, err := 
shortCommit(ctx, root, baseSHA) + if err != nil { + return resolvedConfig{}, err + } + targetShortSHA, err := shortCommit(ctx, root, targetSHA) + if err != nil { + return resolvedConfig{}, err + } + + threshold, err := parseRegressionThreshold(cfg.FailRegression) + if err != nil { + return resolvedConfig{}, err + } + + if cfg.runsIntegrationBenchmarks() && cfg.Connection == "" { + cfg.Connection = os.Getenv("CONNECTION_STRING") + } + if cfg.runsIntegrationBenchmarks() && cfg.Connection == "" { + return resolvedConfig{}, fmt.Errorf("integration benchmarks require -connection or CONNECTION_STRING") + } + + datasetDirAbs, err := resolvePath(root, cfg.DatasetDir) + if err != nil { + return resolvedConfig{}, err + } + + outDir := cfg.OutDir + if outDir == "" { + outDir = filepath.Join(".bench", "runs", fmt.Sprintf("%s..%s-%s", baseShortSHA, targetShortSHA, time.Now().UTC().Format("20060102T150405Z"))) + } + outDirAbs, err := resolvePath(root, outDir) + if err != nil { + return resolvedConfig{}, err + } + + return resolvedConfig{ + config: cfg, + Root: root, + BaseSHA: baseSHA, + TargetSHA: targetSHA, + BaseShortSHA: baseShortSHA, + TargetShortSHA: targetShortSHA, + DatasetDirAbs: datasetDirAbs, + OutDirAbs: outDirAbs, + Threshold: threshold, + }, nil +} + +func resolvePath(root, value string) (string, error) { + if filepath.IsAbs(value) { + return filepath.Clean(value), nil + } + + return filepath.Abs(filepath.Join(root, value)) +} + +func resolveCommit(ctx context.Context, root, ref string) (string, error) { + sha, err := gitOutput(ctx, root, "rev-parse", "--verify", ref+"^{commit}") + if err != nil { + return "", fmt.Errorf("resolve git ref %q: %w", ref, err) + } + + return sha, nil +} + +func shortCommit(ctx context.Context, root, sha string) (string, error) { + shortSHA, err := gitOutput(ctx, root, "rev-parse", "--short", sha) + if err != nil { + return "", err + } + + return shortSHA, nil +} + +func addWorktree(ctx context.Context, root, path, sha string) error { + 
if _, err := os.Stat(path); err == nil { + return fmt.Errorf("worktree path already exists: %s", path) + } else if !os.IsNotExist(err) { + return err + } + + if err := os.MkdirAll(filepath.Dir(path), 0755); err != nil { + return err + } + + _, err := runCommand(ctx, root, nil, "git", "worktree", "add", "--detach", path, sha) + if err != nil { + return fmt.Errorf("add worktree %s: %w", path, err) + } + + return nil +} + +func removeWorktree(ctx context.Context, root, path string) error { + _, err := runCommand(ctx, root, nil, "git", "worktree", "remove", "--force", path) + if err != nil && !strings.Contains(err.Error(), "is not a working tree") { + return fmt.Errorf("remove worktree %s: %w", path, err) + } + + return nil +} + +func (summary runSummary) regressions() []regression { + var regressions []regression + + for _, comparison := range summary.Comparisons { + regressions = append(regressions, comparison.Regressions...) + } + + return regressions +} From ee9fdd934e24eaa9bc1c37fde098a0fa532f6da7 Mon Sep 17 00:00:00 2001 From: John Hopper Date: Mon, 11 May 2026 08:32:52 -0700 Subject: [PATCH 52/55] feat: summarize benchmark diff findings --- README.md | 3 +- cmd/benchdiff/README.md | 4 +- cmd/benchdiff/benchfmt.go | 108 ++++++++++++++++++++++++++----- cmd/benchdiff/benchfmt_test.go | 30 +++++++-- cmd/benchdiff/compare.go | 25 ++----- cmd/benchdiff/report.go | 115 ++++++++++++++++++++++++++++++++- cmd/benchdiff/report_test.go | 101 +++++++++++++++++++++++++++++ cmd/benchdiff/run.go | 1 + 8 files changed, 340 insertions(+), 47 deletions(-) create mode 100644 cmd/benchdiff/report_test.go diff --git a/README.md b/README.md index 3d0724ae..4e049fe3 100644 --- a/README.md +++ b/README.md @@ -77,4 +77,5 @@ export CONNECTION_STRING="postgresql://dawgs:weneedbetterpasswords@localhost:654 go run ./cmd/benchdiff -base main -target HEAD -kind all -driver pg -fail-regression 10% ``` -The harness writes raw outputs and a Markdown report under `.bench/runs/` by default. 
+The harness writes raw outputs and a Markdown report under `.bench/runs/` by default. The report begins with comparison +findings before listing the raw `benchstat` output for each benchmark suite. diff --git a/cmd/benchdiff/README.md b/cmd/benchdiff/README.md index 82cb2734..c898ed80 100644 --- a/cmd/benchdiff/README.md +++ b/cmd/benchdiff/README.md @@ -17,7 +17,9 @@ go run ./cmd/benchdiff -base main -target HEAD -kind unit -fail-regression 10% ``` `benchdiff` creates detached worktrees under `.bench/`, runs each selected benchmark suite, writes raw output, and -produces a Markdown report. Worktrees are removed by default after the run; pass `-keep-worktrees` to preserve them. +produces a Markdown report. The report starts with comparison findings, including median regressions, improvements, +unchanged counts, and benchmark names that only appeared in one ref. Worktrees are removed by default after the run; +pass `-keep-worktrees` to preserve them. ## Flags diff --git a/cmd/benchdiff/benchfmt.go b/cmd/benchdiff/benchfmt.go index 818cb1a3..6ec669ea 100644 --- a/cmd/benchdiff/benchfmt.go +++ b/cmd/benchdiff/benchfmt.go @@ -36,6 +36,22 @@ var benchmarkLinePattern = regexp.MustCompile(`^(Benchmark\S+)\s+\d+\s+([0-9]+(? 
type benchmarkSamples map[string][]float64 +type comparisonFindings struct { + Compared int + Regressions []benchmarkFinding + Improvements []benchmarkFinding + Unchanged int + OnlyBase []string + OnlyTarget []string +} + +type benchmarkFinding struct { + Name string + BaseMedianNS float64 + TargetMedianNS float64 + DeltaPercent float64 +} + type regression struct { Name string BaseMedianNS float64 @@ -73,15 +89,27 @@ func parseBenchfmtNSFile(path string) (benchmarkSamples, error) { return parseBenchfmtNS(data), nil } -func findRegressions(base, target benchmarkSamples, threshold float64) []regression { - if threshold <= 0 { - return nil +func summarizeFindings(base, target benchmarkSamples) comparisonFindings { + findings := comparisonFindings{} + names := map[string]struct{}{} + + for name := range base { + names[name] = struct{}{} + } + for name := range target { + names[name] = struct{}{} } - var regressions []regression - for name, baseValues := range base { + for name := range names { + baseValues := base[name] targetValues := target[name] - if len(baseValues) == 0 || len(targetValues) == 0 { + + switch { + case len(baseValues) == 0: + findings.OnlyTarget = append(findings.OnlyTarget, name) + continue + case len(targetValues) == 0: + findings.OnlyBase = append(findings.OnlyBase, name) continue } @@ -91,20 +119,68 @@ func findRegressions(base, target benchmarkSamples, threshold float64) []regress continue } - percent := ((targetMedian - baseMedian) / baseMedian) * 100 - if percent > threshold { - regressions = append(regressions, regression{ - Name: name, - BaseMedianNS: baseMedian, - TargetMedianNS: targetMedian, - Percent: percent, - }) + findings.Compared++ + deltaPercent := ((targetMedian - baseMedian) / baseMedian) * 100 + finding := benchmarkFinding{ + Name: name, + BaseMedianNS: baseMedian, + TargetMedianNS: targetMedian, + DeltaPercent: deltaPercent, + } + + switch { + case deltaPercent > 0: + findings.Regressions = append(findings.Regressions, 
finding) + case deltaPercent < 0: + findings.Improvements = append(findings.Improvements, finding) + default: + findings.Unchanged++ } } - sort.Slice(regressions, func(i, j int) bool { - return regressions[i].Percent > regressions[j].Percent + sort.Slice(findings.Regressions, func(i, j int) bool { + return findings.Regressions[i].DeltaPercent > findings.Regressions[j].DeltaPercent + }) + sort.Slice(findings.Improvements, func(i, j int) bool { + return findings.Improvements[i].DeltaPercent < findings.Improvements[j].DeltaPercent }) + sort.Strings(findings.OnlyBase) + sort.Strings(findings.OnlyTarget) + + return findings +} + +func findingsForFiles(baseFile, targetFile string) (comparisonFindings, error) { + base, err := parseBenchfmtNSFile(baseFile) + if err != nil { + return comparisonFindings{}, err + } + target, err := parseBenchfmtNSFile(targetFile) + if err != nil { + return comparisonFindings{}, err + } + + return summarizeFindings(base, target), nil +} + +func (findings comparisonFindings) regressionsOver(threshold float64) []regression { + if threshold <= 0 { + return nil + } + + var regressions []regression + for _, finding := range findings.Regressions { + if finding.DeltaPercent <= threshold { + continue + } + + regressions = append(regressions, regression{ + Name: finding.Name, + BaseMedianNS: finding.BaseMedianNS, + TargetMedianNS: finding.TargetMedianNS, + Percent: finding.DeltaPercent, + }) + } return regressions } diff --git a/cmd/benchdiff/benchfmt_test.go b/cmd/benchdiff/benchfmt_test.go index 1e37b91e..9cdb1de1 100644 --- a/cmd/benchdiff/benchfmt_test.go +++ b/cmd/benchdiff/benchfmt_test.go @@ -36,20 +36,36 @@ BenchmarkOther/sub-12 1 200 ns/op require.Equal(t, []float64{200}, samples["BenchmarkOther/sub-12"]) } -func TestFindRegressions(t *testing.T) { +func TestSummarizeFindings(t *testing.T) { base := benchmarkSamples{ - "BenchmarkFast-12": {100, 110, 120}, - "BenchmarkSame-12": {100, 100, 100}, + "BenchmarkRegression-12": {100, 110, 120}, + 
"BenchmarkImprovement-12": {200, 200, 200}, + "BenchmarkSame-12": {100, 100, 100}, + "BenchmarkOnlyBase-12": {50}, } target := benchmarkSamples{ - "BenchmarkFast-12": {140, 150, 160}, - "BenchmarkSame-12": {105, 105, 105}, + "BenchmarkRegression-12": {140, 150, 160}, + "BenchmarkImprovement-12": {100, 100, 100}, + "BenchmarkSame-12": {100, 100, 100}, + "BenchmarkOnlyTarget-12": {75}, } - regressions := findRegressions(base, target, 10) + findings := summarizeFindings(base, target) + require.Equal(t, 3, findings.Compared) + require.Equal(t, 1, findings.Unchanged) + require.Equal(t, []string{"BenchmarkOnlyBase-12"}, findings.OnlyBase) + require.Equal(t, []string{"BenchmarkOnlyTarget-12"}, findings.OnlyTarget) + require.Len(t, findings.Regressions, 1) + require.Equal(t, "BenchmarkRegression-12", findings.Regressions[0].Name) + require.InDelta(t, 36.36, findings.Regressions[0].DeltaPercent, 0.01) + require.Len(t, findings.Improvements, 1) + require.Equal(t, "BenchmarkImprovement-12", findings.Improvements[0].Name) + require.Equal(t, -50.0, findings.Improvements[0].DeltaPercent) + + regressions := findings.regressionsOver(10) require.Len(t, regressions, 1) - require.Equal(t, "BenchmarkFast-12", regressions[0].Name) + require.Equal(t, "BenchmarkRegression-12", regressions[0].Name) require.InDelta(t, 36.36, regressions[0].Percent, 0.01) } diff --git a/cmd/benchdiff/compare.go b/cmd/benchdiff/compare.go index 3a5c99f2..10792f8b 100644 --- a/cmd/benchdiff/compare.go +++ b/cmd/benchdiff/compare.go @@ -52,10 +52,11 @@ func runUnitComparison(ctx context.Context, cfg resolvedConfig, baseWorktree, ta return comparison{}, err } - regressions, err := regressionsForFiles(baseFile, targetFile, cfg.Threshold) + findings, err := findingsForFiles(baseFile, targetFile) if err != nil { return comparison{}, err } + regressions := findings.regressionsOver(cfg.Threshold) return comparison{ Name: "Unit Benchmarks", @@ -63,6 +64,7 @@ func runUnitComparison(ctx context.Context, cfg 
resolvedConfig, baseWorktree, ta TargetFile: targetFile, BenchstatFile: benchstatFile, Benchstat: string(benchstatOutput), + Findings: findings, Regressions: regressions, }, nil } @@ -153,10 +155,11 @@ func runIntegrationComparison(ctx context.Context, cfg resolvedConfig, baseWorkt return comparison{}, err } - regressions, err := regressionsForFiles(baseFile, targetFile, cfg.Threshold) + findings, err := findingsForFiles(baseFile, targetFile) if err != nil { return comparison{}, err } + regressions := findings.regressionsOver(cfg.Threshold) return comparison{ Name: "Integration Benchmarks", @@ -165,6 +168,7 @@ func runIntegrationComparison(ctx context.Context, cfg resolvedConfig, baseWorkt BenchstatFile: benchstatFile, Benchstat: string(benchstatOutput), Notes: notes, + Findings: findings, Regressions: regressions, }, nil } @@ -257,20 +261,3 @@ func runBenchstat(ctx context.Context, cfg resolvedConfig, baseFile, targetFile args := append(fields[1:], baseFile, targetFile) return runCommand(ctx, cfg.Root, nil, fields[0], args...) 
} - -func regressionsForFiles(baseFile, targetFile string, threshold float64) ([]regression, error) { - if threshold <= 0 { - return nil, nil - } - - base, err := parseBenchfmtNSFile(baseFile) - if err != nil { - return nil, err - } - target, err := parseBenchfmtNSFile(targetFile) - if err != nil { - return nil, err - } - - return findRegressions(base, target, threshold), nil -} diff --git a/cmd/benchdiff/report.go b/cmd/benchdiff/report.go index 43d77ec5..b8d17488 100644 --- a/cmd/benchdiff/report.go +++ b/cmd/benchdiff/report.go @@ -22,9 +22,12 @@ import ( "os" "path/filepath" "runtime" + "strings" "time" ) +const maxFindingRows = 10 + func writeRunReport(path string, summary runSummary) error { var out bytes.Buffer cfg := summary.Config @@ -57,6 +60,8 @@ func writeRunReport(path string, summary runSummary) error { } fmt.Fprintln(&out) + writeFindingsSummary(&out, summary) + for _, comparison := range summary.Comparisons { fmt.Fprintf(&out, "## %s\n\n", comparison.Name) for _, note := range comparison.Notes { @@ -88,6 +93,110 @@ func writeRunReport(path string, summary runSummary) error { return os.WriteFile(path, out.Bytes(), 0644) } +func writeFindingsSummary(out *bytes.Buffer, summary runSummary) { + fmt.Fprintln(out, "## Findings") + fmt.Fprintln(out) + + if len(summary.Comparisons) == 0 { + fmt.Fprintln(out, "No benchmark comparisons were run.") + fmt.Fprintln(out) + return + } + + for _, comparison := range summary.Comparisons { + findings := comparison.Findings + fmt.Fprintf(out, "### %s\n\n", comparison.Name) + fmt.Fprintf(out, "- Compared %d matching benchmark%s.\n", findings.Compared, pluralSuffix(findings.Compared)) + fmt.Fprintf(out, "- Median regressions: %d; median improvements: %d; unchanged: %d.\n", + len(findings.Regressions), + len(findings.Improvements), + findings.Unchanged, + ) + if len(findings.OnlyBase) > 0 { + fmt.Fprintf(out, "- Only in base: %s.\n", inlineBenchmarkList(findings.OnlyBase, maxFindingRows)) + } + if len(findings.OnlyTarget) 
> 0 { + fmt.Fprintf(out, "- Only in target: %s.\n", inlineBenchmarkList(findings.OnlyTarget, maxFindingRows)) + } + fmt.Fprintln(out) + + writeFindingTable(out, "Top Median Regressions", findings.Regressions, maxFindingRows) + writeFindingTable(out, "Top Median Improvements", findings.Improvements, maxFindingRows) + } +} + +func writeFindingTable(out *bytes.Buffer, title string, findings []benchmarkFinding, limit int) { + fmt.Fprintf(out, "#### %s\n\n", title) + if len(findings) == 0 { + fmt.Fprintln(out, "None.") + fmt.Fprintln(out) + return + } + + fmt.Fprintln(out, "| Benchmark | Base Median | Target Median | Change |") + fmt.Fprintln(out, "|-----------|------------:|--------------:|-------:|") + + for idx, finding := range findings { + if idx >= limit { + break + } + + fmt.Fprintf(out, "| `%s` | %s | %s | %+.2f%% |\n", + finding.Name, + formatNS(finding.BaseMedianNS), + formatNS(finding.TargetMedianNS), + finding.DeltaPercent, + ) + } + if len(findings) > limit { + fmt.Fprintf(out, "\n_%d more not shown._\n", len(findings)-limit) + } + fmt.Fprintln(out) +} + +func pluralSuffix(count int) string { + if count == 1 { + return "" + } + + return "s" +} + +func inlineBenchmarkList(names []string, limit int) string { + var builder strings.Builder + + for idx, name := range names { + if idx >= limit { + break + } + if idx > 0 { + builder.WriteString(", ") + } + builder.WriteByte('`') + builder.WriteString(name) + builder.WriteByte('`') + } + + if len(names) > limit { + fmt.Fprintf(&builder, ", and %d more", len(names)-limit) + } + + return builder.String() +} + +func formatNS(value float64) string { + switch { + case value >= float64(time.Second): + return fmt.Sprintf("%.2fs", value/float64(time.Second)) + case value >= float64(time.Millisecond): + return fmt.Sprintf("%.2fms", value/float64(time.Millisecond)) + case value >= float64(time.Microsecond): + return fmt.Sprintf("%.2fus", value/float64(time.Microsecond)) + default: + return fmt.Sprintf("%.0fns", value) + } +} 
+ func writeRegressionSection(out *bytes.Buffer, comparison comparison, threshold float64) { if threshold <= 0 { return @@ -103,10 +212,10 @@ func writeRegressionSection(out *bytes.Buffer, comparison comparison, threshold fmt.Fprintln(out, "| Benchmark | Base Median | Target Median | Change |") fmt.Fprintln(out, "|-----------|------------:|--------------:|-------:|") for _, regression := range comparison.Regressions { - fmt.Fprintf(out, "| `%s` | %.0f ns/op | %.0f ns/op | +%.2f%% |\n", + fmt.Fprintf(out, "| `%s` | %s | %s | +%.2f%% |\n", regression.Name, - regression.BaseMedianNS, - regression.TargetMedianNS, + formatNS(regression.BaseMedianNS), + formatNS(regression.TargetMedianNS), regression.Percent, ) } diff --git a/cmd/benchdiff/report_test.go b/cmd/benchdiff/report_test.go new file mode 100644 index 00000000..439d5b16 --- /dev/null +++ b/cmd/benchdiff/report_test.go @@ -0,0 +1,101 @@ +// Copyright 2026 Specter Ops, Inc. +// +// Licensed under the Apache License, Version 2.0 +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +// SPDX-License-Identifier: Apache-2.0 + +package main + +import ( + "os" + "path/filepath" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestWriteRunReportIncludesTopLevelFindings(t *testing.T) { + outputPath := filepath.Join(t.TempDir(), "report.md") + summary := runSummary{ + Config: resolvedConfig{ + config: config{ + BaseRef: "main", + TargetRef: "HEAD", + Kind: benchKindUnit, + Packages: "./...", + Bench: ".", + BenchCount: 3, + Benchtime: "1s", + FailRegression: "10%", + }, + BaseShortSHA: "abc1234", + TargetShortSHA: "def5678", + OutDirAbs: filepath.Dir(outputPath), + Threshold: 10, + }, + GoVersion: "go-test", + StartedAt: time.Date(2026, 5, 11, 1, 2, 3, 0, time.UTC), + FinishedAt: time.Date(2026, 5, 11, 1, 2, 4, 0, time.UTC), + Comparisons: []comparison{{ + Name: "Unit Benchmarks", + BaseFile: filepath.Join(filepath.Dir(outputPath), "unit", "base.txt"), + TargetFile: filepath.Join(filepath.Dir(outputPath), "unit", "target.txt"), + BenchstatFile: filepath.Join(filepath.Dir(outputPath), "unit", "benchstat.txt"), + Benchstat: "benchstat output\n", + Findings: comparisonFindings{ + Compared: 3, + Unchanged: 1, + Regressions: []benchmarkFinding{{ + Name: "BenchmarkSlow-12", + BaseMedianNS: 100, + TargetMedianNS: 150, + DeltaPercent: 50, + }}, + Improvements: []benchmarkFinding{{ + Name: "BenchmarkFast-12", + BaseMedianNS: 200, + TargetMedianNS: 100, + DeltaPercent: -50, + }}, + OnlyBase: []string{"BenchmarkRemoved-12"}, + OnlyTarget: []string{"BenchmarkAdded-12"}, + }, + Regressions: []regression{{ + Name: "BenchmarkSlow-12", + BaseMedianNS: 100, + TargetMedianNS: 150, + Percent: 50, + }}, + }}, + } + + require.NoError(t, writeRunReport(outputPath, summary)) + + report, err := os.ReadFile(outputPath) + require.NoError(t, err) + output := string(report) + + require.Contains(t, output, "## Findings") + require.Contains(t, output, "### Unit Benchmarks") + require.Contains(t, output, "- Compared 3 matching benchmarks.") + 
require.Contains(t, output, "- Median regressions: 1; median improvements: 1; unchanged: 1.") + require.Contains(t, output, "- Only in base: `BenchmarkRemoved-12`.") + require.Contains(t, output, "- Only in target: `BenchmarkAdded-12`.") + require.Contains(t, output, "#### Top Median Regressions") + require.Contains(t, output, "| `BenchmarkSlow-12` | 100ns | 150ns | +50.00% |") + require.Contains(t, output, "#### Top Median Improvements") + require.Contains(t, output, "| `BenchmarkFast-12` | 200ns | 100ns | -50.00% |") + require.Contains(t, output, "## Unit Benchmarks") + require.Contains(t, output, "benchstat output") +} diff --git a/cmd/benchdiff/run.go b/cmd/benchdiff/run.go index c8f43137..12afb66a 100644 --- a/cmd/benchdiff/run.go +++ b/cmd/benchdiff/run.go @@ -33,6 +33,7 @@ type comparison struct { BenchstatFile string Benchstat string Notes []string + Findings comparisonFindings Regressions []regression } From d97dffa938cc82c75444fed0405279d2bc38990a Mon Sep 17 00:00:00 2001 From: John Hopper Date: Mon, 11 May 2026 08:41:32 -0700 Subject: [PATCH 53/55] feat: validate benchmark scenario row counts --- cmd/benchmark/runner.go | 24 +++++++++++- cmd/benchmark/scenarios.go | 49 ++++++++++++++---------- cmd/benchmark/scenarios_test.go | 67 +++++++++++++++++++++++++++++++++ 3 files changed, 118 insertions(+), 22 deletions(-) create mode 100644 cmd/benchmark/scenarios_test.go diff --git a/cmd/benchmark/runner.go b/cmd/benchmark/runner.go index ba2d5709..62d287b7 100644 --- a/cmd/benchmark/runner.go +++ b/cmd/benchmark/runner.go @@ -18,6 +18,7 @@ package main import ( "context" + "fmt" "sort" "time" @@ -43,7 +44,7 @@ type Result struct { // runScenario executes a scenario N times and returns timing stats. func runScenario(ctx context.Context, db graph.Database, s Scenario, iterations int) (Result, error) { // Warm-up: one untimed run. 
- if err := db.ReadTransaction(ctx, s.Query); err != nil { + if err := runScenarioOnce(ctx, db, s); err != nil { return Result{}, err } @@ -51,7 +52,7 @@ func runScenario(ctx context.Context, db graph.Database, s Scenario, iterations for i := range iterations { start := time.Now() - if err := db.ReadTransaction(ctx, s.Query); err != nil { + if err := runScenarioOnce(ctx, db, s); err != nil { return Result{}, err } durations[i] = time.Since(start) @@ -66,6 +67,25 @@ func runScenario(ctx context.Context, db graph.Database, s Scenario, iterations }, nil } +func runScenarioOnce(ctx context.Context, db graph.Database, s Scenario) error { + return db.ReadTransaction(ctx, func(tx graph.Transaction) error { + rows, err := s.Query(tx) + if err != nil { + return err + } + + return validateScenarioRows(s, rows) + }) +} + +func validateScenarioRows(s Scenario, actualRows int) error { + if s.ExpectedRows == nil || *s.ExpectedRows == actualRows { + return nil + } + + return fmt.Errorf("%s/%s on %s expected %d rows, got %d", s.Section, s.Label, s.Dataset, *s.ExpectedRows, actualRows) +} + func computeStats(durations []time.Duration) Stats { sort.Slice(durations, func(i, j int) bool { return durations[i] < durations[j] }) diff --git a/cmd/benchmark/scenarios.go b/cmd/benchmark/scenarios.go index 217ae63d..091d11ee 100644 --- a/cmd/benchmark/scenarios.go +++ b/cmd/benchmark/scenarios.go @@ -25,10 +25,11 @@ import ( // Scenario defines a single benchmark query to run against a loaded dataset. type Scenario struct { - Section string // grouping key in the report (e.g. "Match Nodes") - Dataset string - Label string // human-readable row label - Query func(tx graph.Transaction) error + Section string // grouping key in the report (e.g. "Match Nodes") + Dataset string + Label string // human-readable row label + ExpectedRows *int + Query func(tx graph.Transaction) (int, error) } // defaultDatasets is the set of datasets committed to the repo. 
@@ -46,23 +47,31 @@ func scenariosForDataset(dataset string, idMap opengraph.IDMap) []Scenario { } } -func countNodes(tx graph.Transaction) error { - _, err := tx.Nodes().Count() - return err +func expectRows(rows int) *int { + return &rows } -func countEdges(tx graph.Transaction) error { - _, err := tx.Relationships().Count() - return err +func countNodes(tx graph.Transaction) (int, error) { + count, err := tx.Nodes().Count() + return int(count), err } -func cypherQuery(cypher string) func(tx graph.Transaction) error { - return func(tx graph.Transaction) error { +func countEdges(tx graph.Transaction) (int, error) { + count, err := tx.Relationships().Count() + return int(count), err +} + +func cypherQuery(cypher string) func(tx graph.Transaction) (int, error) { + return func(tx graph.Transaction) (int, error) { result := tx.Query(cypher, nil) defer result.Close() + + rows := 0 for result.Next() { + rows++ } - return result.Error() + + return rows, result.Error() } } @@ -71,22 +80,22 @@ func cypherQuery(cypher string) func(tx graph.Transaction) error { func baseScenarios(idMap opengraph.IDMap) []Scenario { ds := "base" return []Scenario{ - {Section: "Match Nodes", Dataset: ds, Label: ds, Query: countNodes}, - {Section: "Match Edges", Dataset: ds, Label: ds, Query: countEdges}, - {Section: "Shortest Paths", Dataset: ds, Label: "n1 -> n3", Query: cypherQuery(fmt.Sprintf( + {Section: "Match Nodes", Dataset: ds, Label: ds, ExpectedRows: expectRows(3), Query: countNodes}, + {Section: "Match Edges", Dataset: ds, Label: ds, ExpectedRows: expectRows(2), Query: countEdges}, + {Section: "Shortest Paths", Dataset: ds, Label: "n1 -> n3", ExpectedRows: expectRows(1), Query: cypherQuery(fmt.Sprintf( "MATCH p = allShortestPaths((s)-[*1..]->(e)) WHERE id(s) = %d AND id(e) = %d RETURN p", idMap["n1"], idMap["n3"], ))}, - {Section: "Traversal", Dataset: ds, Label: "n1", Query: cypherQuery(fmt.Sprintf( + {Section: "Traversal", Dataset: ds, Label: "n1", ExpectedRows: expectRows(2), 
Query: cypherQuery(fmt.Sprintf( "MATCH (s)-[*1..]->(e) WHERE id(s) = %d RETURN e", idMap["n1"], ))}, - {Section: "Match Return", Dataset: ds, Label: "n1", Query: cypherQuery(fmt.Sprintf( + {Section: "Match Return", Dataset: ds, Label: "n1", ExpectedRows: expectRows(1), Query: cypherQuery(fmt.Sprintf( "MATCH (s)-[]->(e) WHERE id(s) = %d RETURN e", idMap["n1"], ))}, - {Section: "Filter By Kind", Dataset: ds, Label: "NodeKind1", Query: cypherQuery("MATCH (n:NodeKind1) RETURN n")}, - {Section: "Filter By Kind", Dataset: ds, Label: "NodeKind2", Query: cypherQuery("MATCH (n:NodeKind2) RETURN n")}, + {Section: "Filter By Kind", Dataset: ds, Label: "NodeKind1", ExpectedRows: expectRows(2), Query: cypherQuery("MATCH (n:NodeKind1) RETURN n")}, + {Section: "Filter By Kind", Dataset: ds, Label: "NodeKind2", ExpectedRows: expectRows(2), Query: cypherQuery("MATCH (n:NodeKind2) RETURN n")}, } } diff --git a/cmd/benchmark/scenarios_test.go b/cmd/benchmark/scenarios_test.go new file mode 100644 index 00000000..3bd6a450 --- /dev/null +++ b/cmd/benchmark/scenarios_test.go @@ -0,0 +1,67 @@ +// Copyright 2026 Specter Ops, Inc. +// +// Licensed under the Apache License, Version 2.0 +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+// +// SPDX-License-Identifier: Apache-2.0 + +package main + +import ( + "testing" + + "github.com/specterops/dawgs/graph" + "github.com/specterops/dawgs/opengraph" + "github.com/stretchr/testify/require" +) + +func TestBaseScenariosDeclareExpectedRows(t *testing.T) { + scenarios := baseScenarios(opengraph.IDMap{ + "n1": graph.ID(1), + "n2": graph.ID(2), + "n3": graph.ID(3), + }) + + requireExpectedRows(t, scenarios, "Match Nodes", "base", 3) + requireExpectedRows(t, scenarios, "Match Edges", "base", 2) + requireExpectedRows(t, scenarios, "Shortest Paths", "n1 -> n3", 1) + requireExpectedRows(t, scenarios, "Traversal", "n1", 2) + requireExpectedRows(t, scenarios, "Match Return", "n1", 1) + requireExpectedRows(t, scenarios, "Filter By Kind", "NodeKind1", 2) + requireExpectedRows(t, scenarios, "Filter By Kind", "NodeKind2", 2) +} + +func TestValidateScenarioRows(t *testing.T) { + scenario := Scenario{ + Section: "Traversal", + Dataset: "base", + Label: "n1", + ExpectedRows: expectRows(2), + } + + require.NoError(t, validateScenarioRows(scenario, 2)) + require.ErrorContains(t, validateScenarioRows(scenario, 1), "Traversal/n1 on base expected 2 rows, got 1") +} + +func requireExpectedRows(t *testing.T, scenarios []Scenario, section, label string, expectedRows int) { + t.Helper() + + for _, scenario := range scenarios { + if scenario.Section == section && scenario.Label == label { + require.NotNil(t, scenario.ExpectedRows) + require.Equal(t, expectedRows, *scenario.ExpectedRows) + return + } + } + + require.Failf(t, "scenario not found", "%s/%s", section, label) +} From 2a6b6a8dee7cc77eaa85dc623d676afa08a06a5a Mon Sep 17 00:00:00 2001 From: John Hopper Date: Mon, 11 May 2026 08:43:10 -0700 Subject: [PATCH 54/55] feat: add traversal shape benchmarks --- README.md | 4 + cmd/benchmark/README.md | 9 ++- cmd/benchmark/scenarios.go | 64 ++++++++++++++- cmd/benchmark/scenarios_test.go | 53 ++++++++++++ integration/testdata/traversal_shapes.json | 94 ++++++++++++++++++++++ 5 
files changed, 222 insertions(+), 2 deletions(-) create mode 100644 integration/testdata/traversal_shapes.json diff --git a/README.md b/README.md index 4e049fe3..84d02f2a 100644 --- a/README.md +++ b/README.md @@ -79,3 +79,7 @@ go run ./cmd/benchdiff -base main -target HEAD -kind all -driver pg -fail-regres The harness writes raw outputs and a Markdown report under `.bench/runs/` by default. The report begins with comparison findings before listing the raw `benchstat` output for each benchmark suite. + +The integration benchmark runner includes committed `base` and `traversal_shapes` datasets by default. The traversal +shape suite checks expected result counts for chain, fanout, bounded cycle, disconnected, edge-kind-selective, and +multi-path shortest-path scenarios before recording timings. diff --git a/cmd/benchmark/README.md b/cmd/benchmark/README.md index de9ca76b..fa9362c5 100644 --- a/cmd/benchmark/README.md +++ b/cmd/benchmark/README.md @@ -5,9 +5,12 @@ Runs query scenarios against a real database and outputs a markdown timing table ## Usage ```bash -# Default dataset (base) +# Default datasets (base and traversal_shapes) go run ./cmd/benchmark -connection "postgresql://dawgs:dawgs@localhost:5432/dawgs" +# Traversal shape dataset only +go run ./cmd/benchmark -connection "..." -dataset traversal_shapes + # Local dataset (not committed to repo) go run ./cmd/benchmark -connection "..." -dataset local/phantom @@ -40,6 +43,10 @@ go run ./cmd/benchmark -connection "..." -format benchfmt -output report.bench Use `-format benchfmt` when comparing scenario timings with `benchstat`. Each timed scenario iteration is emitted as a separate `ns/op` sample so two benchmark runs can be compared directly. +The committed default datasets are `base` and `traversal_shapes`. `traversal_shapes` covers chain, fanout, bounded +cycle, disconnected, edge-kind-selective, and multi-path shortest-path traversal shapes. 
Scenarios with declared +expected row counts fail before reporting timings if a query returns the wrong result shape. + ## Example: Neo4j on local/phantom ``` diff --git a/cmd/benchmark/scenarios.go b/cmd/benchmark/scenarios.go index 091d11ee..c445d966 100644 --- a/cmd/benchmark/scenarios.go +++ b/cmd/benchmark/scenarios.go @@ -32,14 +32,18 @@ type Scenario struct { Query func(tx graph.Transaction) (int, error) } +const traversalShapesDataset = "traversal_shapes" + // defaultDatasets is the set of datasets committed to the repo. -var defaultDatasets = []string{"base"} +var defaultDatasets = []string{"base", traversalShapesDataset} // scenariosForDataset returns all benchmark scenarios for a given dataset and its loaded ID map. func scenariosForDataset(dataset string, idMap opengraph.IDMap) []Scenario { switch dataset { case "base": return baseScenarios(idMap) + case traversalShapesDataset: + return traversalShapesScenarios(idMap) case "local/phantom": return phantomScenarios(idMap) default: @@ -99,6 +103,64 @@ func baseScenarios(idMap opengraph.IDMap) []Scenario { } } +// --- Traversal shape scenarios --- + +func traversalShapesScenarios(idMap opengraph.IDMap) []Scenario { + ds := traversalShapesDataset + return []Scenario{ + {Section: "Match Nodes", Dataset: ds, Label: ds, ExpectedRows: expectRows(45), Query: countNodes}, + {Section: "Match Edges", Dataset: ds, Label: ds, ExpectedRows: expectRows(41), Query: countEdges}, + {Section: "Traversal Depth", Dataset: ds, Label: "chain depth 1", ExpectedRows: expectRows(1), Query: cypherQuery(fmt.Sprintf( + "MATCH (s)-[:ChainEdge*1..1]->(e) WHERE id(s) = %d RETURN e", + idMap["c0"], + ))}, + {Section: "Traversal Depth", Dataset: ds, Label: "chain depth 3", ExpectedRows: expectRows(3), Query: cypherQuery(fmt.Sprintf( + "MATCH (s)-[:ChainEdge*1..3]->(e) WHERE id(s) = %d RETURN e", + idMap["c0"], + ))}, + {Section: "Traversal Depth", Dataset: ds, Label: "chain depth 10", ExpectedRows: expectRows(10), Query: 
cypherQuery(fmt.Sprintf( + "MATCH (s)-[:ChainEdge*1..10]->(e) WHERE id(s) = %d RETURN e", + idMap["c0"], + ))}, + {Section: "Traversal Depth", Dataset: ds, Label: "fanout depth 1", ExpectedRows: expectRows(3), Query: cypherQuery(fmt.Sprintf( + "MATCH (s)-[:FanoutEdge*1..1]->(e) WHERE id(s) = %d RETURN e", + idMap["f0"], + ))}, + {Section: "Traversal Depth", Dataset: ds, Label: "fanout depth 2", ExpectedRows: expectRows(9), Query: cypherQuery(fmt.Sprintf( + "MATCH (s)-[:FanoutEdge*1..2]->(e) WHERE id(s) = %d RETURN e", + idMap["f0"], + ))}, + {Section: "Traversal Depth", Dataset: ds, Label: "fanout depth 3", ExpectedRows: expectRows(15), Query: cypherQuery(fmt.Sprintf( + "MATCH (s)-[:FanoutEdge*1..3]->(e) WHERE id(s) = %d RETURN e", + idMap["f0"], + ))}, + {Section: "Traversal Cycle", Dataset: ds, Label: "bounded cycle", ExpectedRows: expectRows(4), Query: cypherQuery(fmt.Sprintf( + "MATCH (s)-[:CycleEdge*1..4]->(e) WHERE id(s) = %d RETURN e", + idMap["y0"], + ))}, + {Section: "Traversal Dead End", Dataset: ds, Label: "chain terminal", ExpectedRows: expectRows(0), Query: cypherQuery(fmt.Sprintf( + "MATCH (s)-[:ChainEdge*1..]->(e) WHERE id(s) = %d RETURN e", + idMap["c10"], + ))}, + {Section: "Edge Kind Traversal", Dataset: ds, Label: "Allowed", ExpectedRows: expectRows(3), Query: cypherQuery(fmt.Sprintf( + "MATCH (s)-[:Allowed*1..]->(e) WHERE id(s) = %d RETURN e", + idMap["s0"], + ))}, + {Section: "Edge Kind Traversal", Dataset: ds, Label: "all kinds", ExpectedRows: expectRows(6), Query: cypherQuery(fmt.Sprintf( + "MATCH (s)-[*1..]->(e) WHERE id(s) = %d RETURN e", + idMap["s0"], + ))}, + {Section: "Shortest Paths", Dataset: ds, Label: "diamond many paths", ExpectedRows: expectRows(3), Query: cypherQuery(fmt.Sprintf( + "MATCH p = allShortestPaths((s)-[*1..]->(e)) WHERE id(s) = %d AND id(e) = %d RETURN p", + idMap["d0"], idMap["d4"], + ))}, + {Section: "Shortest Paths", Dataset: ds, Label: "disconnected", ExpectedRows: expectRows(0), Query: cypherQuery(fmt.Sprintf( + 
"MATCH p = allShortestPaths((s)-[*1..]->(e)) WHERE id(s) = %d AND id(e) = %d RETURN p", + idMap["x0"], idMap["x1"], + ))}, + } +} + // --- Phantom scenarios (hardcoded node IDs from the dataset) --- func phantomScenarios(idMap opengraph.IDMap) []Scenario { diff --git a/cmd/benchmark/scenarios_test.go b/cmd/benchmark/scenarios_test.go index 3bd6a450..647ee8ae 100644 --- a/cmd/benchmark/scenarios_test.go +++ b/cmd/benchmark/scenarios_test.go @@ -17,6 +17,7 @@ package main import ( + "os" "testing" "github.com/specterops/dawgs/graph" @@ -40,6 +41,40 @@ func TestBaseScenariosDeclareExpectedRows(t *testing.T) { requireExpectedRows(t, scenarios, "Filter By Kind", "NodeKind2", 2) } +func TestTraversalShapesDatasetIsValid(t *testing.T) { + file, err := os.Open("../../integration/testdata/traversal_shapes.json") + require.NoError(t, err) + defer file.Close() + + doc, err := opengraph.ParseDocument(file) + require.NoError(t, err) + require.Len(t, doc.Graph.Nodes, 45) + require.Len(t, doc.Graph.Edges, 41) +} + +func TestTraversalShapesScenariosDeclareExpectedRows(t *testing.T) { + scenarios := traversalShapesScenarios(traversalShapesIDMap()) + + requireExpectedRows(t, scenarios, "Match Nodes", traversalShapesDataset, 45) + requireExpectedRows(t, scenarios, "Match Edges", traversalShapesDataset, 41) + requireExpectedRows(t, scenarios, "Traversal Depth", "chain depth 1", 1) + requireExpectedRows(t, scenarios, "Traversal Depth", "chain depth 3", 3) + requireExpectedRows(t, scenarios, "Traversal Depth", "chain depth 10", 10) + requireExpectedRows(t, scenarios, "Traversal Depth", "fanout depth 1", 3) + requireExpectedRows(t, scenarios, "Traversal Depth", "fanout depth 2", 9) + requireExpectedRows(t, scenarios, "Traversal Depth", "fanout depth 3", 15) + requireExpectedRows(t, scenarios, "Traversal Cycle", "bounded cycle", 4) + requireExpectedRows(t, scenarios, "Traversal Dead End", "chain terminal", 0) + requireExpectedRows(t, scenarios, "Edge Kind Traversal", "Allowed", 3) + 
requireExpectedRows(t, scenarios, "Edge Kind Traversal", "all kinds", 6) + requireExpectedRows(t, scenarios, "Shortest Paths", "diamond many paths", 3) + requireExpectedRows(t, scenarios, "Shortest Paths", "disconnected", 0) +} + +func TestDefaultDatasetsIncludeTraversalShapes(t *testing.T) { + require.Contains(t, defaultDatasets, traversalShapesDataset) +} + func TestValidateScenarioRows(t *testing.T) { scenario := Scenario{ Section: "Traversal", @@ -52,6 +87,24 @@ func TestValidateScenarioRows(t *testing.T) { require.ErrorContains(t, validateScenarioRows(scenario, 1), "Traversal/n1 on base expected 2 rows, got 1") } +func traversalShapesIDMap() opengraph.IDMap { + ids := []string{ + "c0", "c10", + "f0", + "d0", "d4", + "y0", + "x0", "x1", + "s0", + } + + idMap := opengraph.IDMap{} + for idx, id := range ids { + idMap[id] = graph.ID(idx + 1) + } + + return idMap +} + func requireExpectedRows(t *testing.T, scenarios []Scenario, section, label string, expectedRows int) { t.Helper() diff --git a/integration/testdata/traversal_shapes.json b/integration/testdata/traversal_shapes.json new file mode 100644 index 00000000..d1041096 --- /dev/null +++ b/integration/testdata/traversal_shapes.json @@ -0,0 +1,94 @@ +{ + "graph": { + "nodes": [ + {"id": "c0", "kinds": ["TraversalNode", "ChainNode"]}, + {"id": "c1", "kinds": ["TraversalNode", "ChainNode"]}, + {"id": "c2", "kinds": ["TraversalNode", "ChainNode"]}, + {"id": "c3", "kinds": ["TraversalNode", "ChainNode"]}, + {"id": "c4", "kinds": ["TraversalNode", "ChainNode"]}, + {"id": "c5", "kinds": ["TraversalNode", "ChainNode"]}, + {"id": "c6", "kinds": ["TraversalNode", "ChainNode"]}, + {"id": "c7", "kinds": ["TraversalNode", "ChainNode"]}, + {"id": "c8", "kinds": ["TraversalNode", "ChainNode"]}, + {"id": "c9", "kinds": ["TraversalNode", "ChainNode"]}, + {"id": "c10", "kinds": ["TraversalNode", "ChainNode"]}, + {"id": "f0", "kinds": ["TraversalNode", "FanoutNode"]}, + {"id": "f1", "kinds": ["TraversalNode", "FanoutNode"]}, + 
{"id": "f2", "kinds": ["TraversalNode", "FanoutNode"]}, + {"id": "f3", "kinds": ["TraversalNode", "FanoutNode"]}, + {"id": "f1a", "kinds": ["TraversalNode", "FanoutNode"]}, + {"id": "f1b", "kinds": ["TraversalNode", "FanoutNode"]}, + {"id": "f2a", "kinds": ["TraversalNode", "FanoutNode"]}, + {"id": "f2b", "kinds": ["TraversalNode", "FanoutNode"]}, + {"id": "f3a", "kinds": ["TraversalNode", "FanoutNode"]}, + {"id": "f3b", "kinds": ["TraversalNode", "FanoutNode"]}, + {"id": "f1a1", "kinds": ["TraversalNode", "FanoutNode"]}, + {"id": "f1b1", "kinds": ["TraversalNode", "FanoutNode"]}, + {"id": "f2a1", "kinds": ["TraversalNode", "FanoutNode"]}, + {"id": "f2b1", "kinds": ["TraversalNode", "FanoutNode"]}, + {"id": "f3a1", "kinds": ["TraversalNode", "FanoutNode"]}, + {"id": "f3b1", "kinds": ["TraversalNode", "FanoutNode"]}, + {"id": "d0", "kinds": ["TraversalNode", "DiamondNode"]}, + {"id": "d1", "kinds": ["TraversalNode", "DiamondNode"]}, + {"id": "d2", "kinds": ["TraversalNode", "DiamondNode"]}, + {"id": "d3", "kinds": ["TraversalNode", "DiamondNode"]}, + {"id": "d4", "kinds": ["TraversalNode", "DiamondNode"]}, + {"id": "y0", "kinds": ["TraversalNode", "CycleNode"]}, + {"id": "y1", "kinds": ["TraversalNode", "CycleNode"]}, + {"id": "y2", "kinds": ["TraversalNode", "CycleNode"]}, + {"id": "y3", "kinds": ["TraversalNode", "CycleNode"]}, + {"id": "x0", "kinds": ["TraversalNode", "DisconnectedNode"]}, + {"id": "x1", "kinds": ["TraversalNode", "DisconnectedNode"]}, + {"id": "s0", "kinds": ["TraversalNode", "SelectiveNode"]}, + {"id": "s1", "kinds": ["TraversalNode", "SelectiveNode"]}, + {"id": "s2", "kinds": ["TraversalNode", "SelectiveNode"]}, + {"id": "s3", "kinds": ["TraversalNode", "SelectiveNode"]}, + {"id": "t1", "kinds": ["TraversalNode", "SelectiveNode"]}, + {"id": "t2", "kinds": ["TraversalNode", "SelectiveNode"]}, + {"id": "t3", "kinds": ["TraversalNode", "SelectiveNode"]} + ], + "edges": [ + {"start_id": "c0", "end_id": "c1", "kind": "ChainEdge"}, + {"start_id": 
"c1", "end_id": "c2", "kind": "ChainEdge"}, + {"start_id": "c2", "end_id": "c3", "kind": "ChainEdge"}, + {"start_id": "c3", "end_id": "c4", "kind": "ChainEdge"}, + {"start_id": "c4", "end_id": "c5", "kind": "ChainEdge"}, + {"start_id": "c5", "end_id": "c6", "kind": "ChainEdge"}, + {"start_id": "c6", "end_id": "c7", "kind": "ChainEdge"}, + {"start_id": "c7", "end_id": "c8", "kind": "ChainEdge"}, + {"start_id": "c8", "end_id": "c9", "kind": "ChainEdge"}, + {"start_id": "c9", "end_id": "c10", "kind": "ChainEdge"}, + {"start_id": "f0", "end_id": "f1", "kind": "FanoutEdge"}, + {"start_id": "f0", "end_id": "f2", "kind": "FanoutEdge"}, + {"start_id": "f0", "end_id": "f3", "kind": "FanoutEdge"}, + {"start_id": "f1", "end_id": "f1a", "kind": "FanoutEdge"}, + {"start_id": "f1", "end_id": "f1b", "kind": "FanoutEdge"}, + {"start_id": "f2", "end_id": "f2a", "kind": "FanoutEdge"}, + {"start_id": "f2", "end_id": "f2b", "kind": "FanoutEdge"}, + {"start_id": "f3", "end_id": "f3a", "kind": "FanoutEdge"}, + {"start_id": "f3", "end_id": "f3b", "kind": "FanoutEdge"}, + {"start_id": "f1a", "end_id": "f1a1", "kind": "FanoutEdge"}, + {"start_id": "f1b", "end_id": "f1b1", "kind": "FanoutEdge"}, + {"start_id": "f2a", "end_id": "f2a1", "kind": "FanoutEdge"}, + {"start_id": "f2b", "end_id": "f2b1", "kind": "FanoutEdge"}, + {"start_id": "f3a", "end_id": "f3a1", "kind": "FanoutEdge"}, + {"start_id": "f3b", "end_id": "f3b1", "kind": "FanoutEdge"}, + {"start_id": "d0", "end_id": "d1", "kind": "DiamondEdge"}, + {"start_id": "d0", "end_id": "d2", "kind": "DiamondEdge"}, + {"start_id": "d0", "end_id": "d3", "kind": "DiamondEdge"}, + {"start_id": "d1", "end_id": "d4", "kind": "DiamondEdge"}, + {"start_id": "d2", "end_id": "d4", "kind": "DiamondEdge"}, + {"start_id": "d3", "end_id": "d4", "kind": "DiamondEdge"}, + {"start_id": "y0", "end_id": "y1", "kind": "CycleEdge"}, + {"start_id": "y1", "end_id": "y2", "kind": "CycleEdge"}, + {"start_id": "y2", "end_id": "y0", "kind": "CycleEdge"}, + {"start_id": 
"y2", "end_id": "y3", "kind": "CycleEdge"}, + {"start_id": "s0", "end_id": "s1", "kind": "Allowed"}, + {"start_id": "s1", "end_id": "s2", "kind": "Allowed"}, + {"start_id": "s2", "end_id": "s3", "kind": "Allowed"}, + {"start_id": "s0", "end_id": "t1", "kind": "Blocked"}, + {"start_id": "t1", "end_id": "t2", "kind": "Blocked"}, + {"start_id": "t2", "end_id": "t3", "kind": "Blocked"} + ] + } +} From f847ea641869ddb6622034984bdf10005ff08833 Mon Sep 17 00:00:00 2001 From: John Hopper Date: Mon, 11 May 2026 08:55:10 -0700 Subject: [PATCH 55/55] feat: include all benchmark numbers in diff report --- README.md | 3 +- cmd/benchdiff/README.md | 5 +-- cmd/benchdiff/benchfmt.go | 32 ++++++++++++++++++ cmd/benchdiff/benchfmt_test.go | 26 +++++++++++++++ cmd/benchdiff/report.go | 60 ++++++++++++++++++++++++++++++++++ cmd/benchdiff/report_test.go | 40 +++++++++++++++++++++++ 6 files changed, 163 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 84d02f2a..72b280a5 100644 --- a/README.md +++ b/README.md @@ -78,7 +78,8 @@ go run ./cmd/benchdiff -base main -target HEAD -kind all -driver pg -fail-regres ``` The harness writes raw outputs and a Markdown report under `.bench/runs/` by default. The report begins with comparison -findings before listing the raw `benchstat` output for each benchmark suite. +findings, includes the raw `benchstat` output for each benchmark suite, and ends with a table of all captured benchmark +numbers. The integration benchmark runner includes committed `base` and `traversal_shapes` datasets by default. 
The traversal shape suite checks expected result counts for chain, fanout, bounded cycle, disconnected, edge-kind-selective, and diff --git a/cmd/benchdiff/README.md b/cmd/benchdiff/README.md index c898ed80..41774194 100644 --- a/cmd/benchdiff/README.md +++ b/cmd/benchdiff/README.md @@ -18,8 +18,9 @@ go run ./cmd/benchdiff -base main -target HEAD -kind unit -fail-regression 10% `benchdiff` creates detached worktrees under `.bench/`, runs each selected benchmark suite, writes raw output, and produces a Markdown report. The report starts with comparison findings, including median regressions, improvements, -unchanged counts, and benchmark names that only appeared in one ref. Worktrees are removed by default after the run; -pass `-keep-worktrees` to preserve them. +unchanged counts, and benchmark names that only appeared in one ref. It ends with an `All Executed Benchmark Numbers` +section that lists the median, percent change, and sample counts for every benchmark captured in either ref. Worktrees +are removed by default after the run; pass `-keep-worktrees` to preserve them. 
## Flags diff --git a/cmd/benchdiff/benchfmt.go b/cmd/benchdiff/benchfmt.go index 6ec669ea..cdc5a68d 100644 --- a/cmd/benchdiff/benchfmt.go +++ b/cmd/benchdiff/benchfmt.go @@ -43,6 +43,7 @@ type comparisonFindings struct { Unchanged int OnlyBase []string OnlyTarget []string + Results []benchmarkResult } type benchmarkFinding struct { @@ -52,6 +53,17 @@ type benchmarkFinding struct { DeltaPercent float64 } +type benchmarkResult struct { + Name string + BaseMedianNS float64 + TargetMedianNS float64 + DeltaPercent float64 + BaseSamples int + TargetSamples int + HasBase bool + HasTarget bool +} + type regression struct { Name string BaseMedianNS float64 @@ -103,12 +115,23 @@ func summarizeFindings(base, target benchmarkSamples) comparisonFindings { for name := range names { baseValues := base[name] targetValues := target[name] + result := benchmarkResult{ + Name: name, + BaseSamples: len(baseValues), + TargetSamples: len(targetValues), + HasBase: len(baseValues) > 0, + HasTarget: len(targetValues) > 0, + } switch { case len(baseValues) == 0: + result.TargetMedianNS = median(targetValues) + findings.Results = append(findings.Results, result) findings.OnlyTarget = append(findings.OnlyTarget, name) continue case len(targetValues) == 0: + result.BaseMedianNS = median(baseValues) + findings.Results = append(findings.Results, result) findings.OnlyBase = append(findings.OnlyBase, name) continue } @@ -116,11 +139,17 @@ func summarizeFindings(base, target benchmarkSamples) comparisonFindings { baseMedian := median(baseValues) targetMedian := median(targetValues) if baseMedian <= 0 { + findings.Results = append(findings.Results, result) continue } + result.BaseMedianNS = baseMedian + result.TargetMedianNS = targetMedian findings.Compared++ deltaPercent := ((targetMedian - baseMedian) / baseMedian) * 100 + result.DeltaPercent = deltaPercent + findings.Results = append(findings.Results, result) + finding := benchmarkFinding{ Name: name, BaseMedianNS: baseMedian, @@ -146,6 +175,9 
@@ func summarizeFindings(base, target benchmarkSamples) comparisonFindings { }) sort.Strings(findings.OnlyBase) sort.Strings(findings.OnlyTarget) + sort.Slice(findings.Results, func(i, j int) bool { + return findings.Results[i].Name < findings.Results[j].Name + }) return findings } diff --git a/cmd/benchdiff/benchfmt_test.go b/cmd/benchdiff/benchfmt_test.go index 9cdb1de1..0b692352 100644 --- a/cmd/benchdiff/benchfmt_test.go +++ b/cmd/benchdiff/benchfmt_test.go @@ -56,6 +56,7 @@ func TestSummarizeFindings(t *testing.T) { require.Equal(t, 1, findings.Unchanged) require.Equal(t, []string{"BenchmarkOnlyBase-12"}, findings.OnlyBase) require.Equal(t, []string{"BenchmarkOnlyTarget-12"}, findings.OnlyTarget) + require.Len(t, findings.Results, 5) require.Len(t, findings.Regressions, 1) require.Equal(t, "BenchmarkRegression-12", findings.Regressions[0].Name) require.InDelta(t, 36.36, findings.Regressions[0].DeltaPercent, 0.01) @@ -67,6 +68,18 @@ func TestSummarizeFindings(t *testing.T) { require.Len(t, regressions, 1) require.Equal(t, "BenchmarkRegression-12", regressions[0].Name) require.InDelta(t, 36.36, regressions[0].Percent, 0.01) + + onlyBase := requireBenchmarkResult(t, findings.Results, "BenchmarkOnlyBase-12") + require.True(t, onlyBase.HasBase) + require.False(t, onlyBase.HasTarget) + require.Equal(t, 1, onlyBase.BaseSamples) + require.Equal(t, 50.0, onlyBase.BaseMedianNS) + + onlyTarget := requireBenchmarkResult(t, findings.Results, "BenchmarkOnlyTarget-12") + require.False(t, onlyTarget.HasBase) + require.True(t, onlyTarget.HasTarget) + require.Equal(t, 1, onlyTarget.TargetSamples) + require.Equal(t, 75.0, onlyTarget.TargetMedianNS) } func TestParseBenchmarkMarkdown(t *testing.T) { @@ -115,3 +128,16 @@ func TestValidateBenchtime(t *testing.T) { require.Error(t, validateBenchtime("0x")) require.Error(t, validateBenchtime("soon")) } + +func requireBenchmarkResult(t *testing.T, results []benchmarkResult, name string) benchmarkResult { + t.Helper() + + for _, result 
:= range results { + if result.Name == name { + return result + } + } + + require.Failf(t, "benchmark result not found", "%s", name) + return benchmarkResult{} +} diff --git a/cmd/benchdiff/report.go b/cmd/benchdiff/report.go index b8d17488..416c69b3 100644 --- a/cmd/benchdiff/report.go +++ b/cmd/benchdiff/report.go @@ -86,6 +86,8 @@ func writeRunReport(path string, summary runSummary) error { writeRegressionSection(&out, comparison, cfg.Threshold) } + writeAllExecutedNumbers(&out, summary) + if err := os.MkdirAll(filepath.Dir(path), 0755); err != nil { return err } @@ -93,6 +95,40 @@ func writeRunReport(path string, summary runSummary) error { return os.WriteFile(path, out.Bytes(), 0644) } +func writeAllExecutedNumbers(out *bytes.Buffer, summary runSummary) { + fmt.Fprintln(out, "## All Executed Benchmark Numbers") + fmt.Fprintln(out) + + if len(summary.Comparisons) == 0 { + fmt.Fprintln(out, "No benchmark numbers were captured.") + fmt.Fprintln(out) + return + } + + for _, comparison := range summary.Comparisons { + fmt.Fprintf(out, "### %s\n\n", comparison.Name) + if len(comparison.Findings.Results) == 0 { + fmt.Fprintln(out, "No benchmark numbers were captured.") + fmt.Fprintln(out) + continue + } + + fmt.Fprintln(out, "| Benchmark | Base Median | Target Median | Change | Base Samples | Target Samples |") + fmt.Fprintln(out, "|-----------|------------:|--------------:|-------:|-------------:|---------------:|") + for _, result := range comparison.Findings.Results { + fmt.Fprintf(out, "| `%s` | %s | %s | %s | %s | %s |\n", + result.Name, + formatOptionalNS(result.HasBase, result.BaseMedianNS), + formatOptionalNS(result.HasTarget, result.TargetMedianNS), + formatOptionalPercent(result.HasBase && result.HasTarget, result.DeltaPercent), + formatOptionalInt(result.HasBase, result.BaseSamples), + formatOptionalInt(result.HasTarget, result.TargetSamples), + ) + } + fmt.Fprintln(out) + } +} + func writeFindingsSummary(out *bytes.Buffer, summary runSummary) { 
fmt.Fprintln(out, "## Findings") fmt.Fprintln(out) @@ -197,6 +233,30 @@ func formatNS(value float64) string { } } +func formatOptionalNS(ok bool, value float64) string { + if !ok { + return "-" + } + + return formatNS(value) +} + +func formatOptionalPercent(ok bool, value float64) string { + if !ok { + return "-" + } + + return fmt.Sprintf("%+.2f%%", value) +} + +func formatOptionalInt(ok bool, value int) string { + if !ok { + return "-" + } + + return fmt.Sprintf("%d", value) +} + func writeRegressionSection(out *bytes.Buffer, comparison comparison, threshold float64) { if threshold <= 0 { return diff --git a/cmd/benchdiff/report_test.go b/cmd/benchdiff/report_test.go index 439d5b16..1e09694d 100644 --- a/cmd/benchdiff/report_test.go +++ b/cmd/benchdiff/report_test.go @@ -70,6 +70,40 @@ func TestWriteRunReportIncludesTopLevelFindings(t *testing.T) { }}, OnlyBase: []string{"BenchmarkRemoved-12"}, OnlyTarget: []string{"BenchmarkAdded-12"}, + Results: []benchmarkResult{ + { + Name: "BenchmarkAdded-12", + TargetMedianNS: 75, + TargetSamples: 1, + HasTarget: true, + }, + { + Name: "BenchmarkFast-12", + BaseMedianNS: 200, + TargetMedianNS: 100, + DeltaPercent: -50, + BaseSamples: 3, + TargetSamples: 3, + HasBase: true, + HasTarget: true, + }, + { + Name: "BenchmarkRemoved-12", + BaseMedianNS: 50, + BaseSamples: 1, + HasBase: true, + }, + { + Name: "BenchmarkSlow-12", + BaseMedianNS: 100, + TargetMedianNS: 150, + DeltaPercent: 50, + BaseSamples: 3, + TargetSamples: 3, + HasBase: true, + HasTarget: true, + }, + }, }, Regressions: []regression{{ Name: "BenchmarkSlow-12", @@ -98,4 +132,10 @@ func TestWriteRunReportIncludesTopLevelFindings(t *testing.T) { require.Contains(t, output, "| `BenchmarkFast-12` | 200ns | 100ns | -50.00% |") require.Contains(t, output, "## Unit Benchmarks") require.Contains(t, output, "benchstat output") + require.Contains(t, output, "## All Executed Benchmark Numbers") + require.Contains(t, output, "| Benchmark | Base Median | Target Median | 
Change | Base Samples | Target Samples |") + require.Contains(t, output, "| `BenchmarkSlow-12` | 100ns | 150ns | +50.00% | 3 | 3 |") + require.Contains(t, output, "| `BenchmarkFast-12` | 200ns | 100ns | -50.00% | 3 | 3 |") + require.Contains(t, output, "| `BenchmarkRemoved-12` | 50ns | - | - | 1 | - |") + require.Contains(t, output, "| `BenchmarkAdded-12` | - | 75ns | - | - | 1 |") }