diff --git a/internal/graph/graph.go b/internal/graph/graph.go index cadb178..5662a50 100644 --- a/internal/graph/graph.go +++ b/internal/graph/graph.go @@ -219,3 +219,79 @@ func (g *Graph[V]) TopologicallySortWithPriority(isLowerPriority func(V, V) bool return output, nil } + +// taggedVertex pairs a vertex with its pre-computed string ID, sparing the +// sort repeated GetId() calls and the extra allocations they can cause +type taggedVertex[V Vertex] struct { + v V + id string +} + +// TopologicallySortWithPriorityFast is an optimized version of TopologicallySortWithPriority that: +// - Counts incoming edges in a single pass over forward edges (no Copy + Reverse) +// - Pre-caches vertex ID strings to avoid repeated GetId() calls during sorting +// +// It produces identical output to TopologicallySortWithPriority +func (g *Graph[V]) TopologicallySortWithPriorityFast(isLowerPriority func(V, V) bool) ([]V, error) { + numVertices := len(g.verticesById) + if numVertices == 0 { + return nil, nil + } + + // Count incoming edges directly from the forward edge map + incomingEdgeCount := make(map[string]int, numVertices) + for id := range g.verticesById { + incomingEdgeCount[id] = 0 + } + for _, adjacentEdges := range g.edges { + for target, hasEdge := range adjacentEdges { + if hasEdge { + incomingEdgeCount[target]++ + } + } + } + + output := make([]V, 0, numVertices) + for len(incomingEdgeCount) > 0 { + var sources []taggedVertex[V] + for id, count := range incomingEdgeCount { + if count == 0 { + sources = append(sources, taggedVertex[V]{v: g.verticesById[id], id: id}) + } + } + + sort.Slice(sources, func(i, j int) bool { + return sources[i].id < sources[j].id + }) + + // Take the source with the highest priority from the sorted array of sources + bestIdx := -1 + for i := range sources { + if bestIdx == -1 || isLowerPriority(sources[bestIdx].v, sources[i].v) { + bestIdx = i + } + } + if bestIdx == -1 { + dotSB := strings.Builder{} + if err := EncodeDOT(g, &dotSB, true); err != 
nil { + dotSB.Reset() + dotSB.WriteString(fmt.Sprintf("failed to encode graph to DOT: %v", err)) + } + return nil, fmt.Errorf("cycle detected: %+v\n%s", incomingEdgeCount, dotSB.String()) + } + best := sources[bestIdx] + + output = append(output, best.v) + + // Decrement incoming edge counts for vertices adjacent to the removed source + for target, hasEdge := range g.edges[best.id] { + if hasEdge { + incomingEdgeCount[target]-- + } + } + + delete(incomingEdgeCount, best.id) + } + + return output, nil +} diff --git a/internal/graph/graph_test.go b/internal/graph/graph_test.go index c6b765e..5b4f3d2 100644 --- a/internal/graph/graph_test.go +++ b/internal/graph/graph_test.go @@ -430,3 +430,144 @@ func getVertexIds(g *Graph[vertex]) []string { } return output } + +// sprintfVertex is a vertex whose GetId() calls fmt.Sprintf every time +type sprintfVertex struct { + objType string + objId string + diffType string + priority int +} + +func (s sprintfVertex) GetId() string { + return fmt.Sprintf("%s:%s:%s", s.objType, s.objId, s.diffType) +} + +func TestTopologicallySortWithPriorityFast(t *testing.T) { + // Same graph and test cases as TestTopologicallySortWithPriority to prove + // the fast version produces identical output. 
+ // Source: https://en.wikipedia.org/wiki/Topological_sorting#Examples + g := NewGraph[vertex]() + v5 := NewV("05") + g.AddVertex(v5) + v7 := NewV("07") + g.AddVertex(v7) + v3 := NewV("03") + g.AddVertex(v3) + v11 := NewV("11") + g.AddVertex(v11) + v8 := NewV("08") + g.AddVertex(v8) + v2 := NewV("02") + g.AddVertex(v2) + v9 := NewV("09") + g.AddVertex(v9) + v10 := NewV("10") + g.AddVertex(v10) + assert.NoError(t, g.AddEdge("05", "11")) + assert.NoError(t, g.AddEdge("07", "11")) + assert.NoError(t, g.AddEdge("07", "08")) + assert.NoError(t, g.AddEdge("03", "08")) + assert.NoError(t, g.AddEdge("03", "10")) + assert.NoError(t, g.AddEdge("11", "02")) + assert.NoError(t, g.AddEdge("11", "09")) + assert.NoError(t, g.AddEdge("11", "10")) + assert.NoError(t, g.AddEdge("08", "09")) + + for _, tc := range []struct { + name string + isLowerPriority func(v1, v2 vertex) bool + expectedOrdering []vertex + }{ + { + name: "largest-numbered available vertex first (string-based GetPriority)", + isLowerPriority: IsLowerPriorityFromGetPriority(func(v vertex) string { + return v.GetId() + }), + expectedOrdering: []vertex{v7, v5, v11, v3, v10, v8, v9, v2}, + }, + { + name: "smallest-numbered available vertex first (numeric-based GetPriority)", + isLowerPriority: IsLowerPriorityFromGetPriority(func(v vertex) int { + idAsInt, err := strconv.Atoi(v.GetId()) + require.NoError(t, err) + return -idAsInt + }), + expectedOrdering: []vertex{v3, v5, v7, v8, v11, v2, v9, v10}, + }, + { + name: "fewest edges first (prioritize high id's for tie breakers)", + isLowerPriority: func(v1, v2 vertex) bool { + v1EdgeCount := getEdgeCount(g, v1) + v2EdgeCount := getEdgeCount(g, v2) + if v1EdgeCount == v2EdgeCount { + return v1.GetId() < v2.GetId() + } + return v1EdgeCount > v2EdgeCount + }, + expectedOrdering: []vertex{v5, v7, v3, v8, v11, v10, v9, v2}, + }, + } { + t.Run(tc.name, func(t *testing.T) { + fast, err := g.TopologicallySortWithPriorityFast(tc.isLowerPriority) + assert.NoError(t, err) + 
assert.Equal(t, tc.expectedOrdering, fast, "fast version should match expected ordering") + + original, err := g.TopologicallySortWithPriority(tc.isLowerPriority) + assert.NoError(t, err) + assert.Equal(t, original, fast, "fast version should match original version") + }) + } + + // Cycle should error + assert.NoError(t, g.AddEdge("10", "07")) + _, err := g.TopologicallySortWithPriorityFast(func(_, _ vertex) bool { return false }) + assert.Error(t, err) +} + +// buildBenchGraph creates a graph with N vertices simulating schema objects, +// using sprintfVertex +func buildBenchGraph(n int) *Graph[sprintfVertex] { + g := NewGraph[sprintfVertex]() + for i := 0; i < n; i++ { + g.AddVertex(sprintfVertex{ + objType: "TABLE", + objId: fmt.Sprintf("public.table_%04d", i), + diffType: "ADDALTER", + priority: i % 3, + }) + } + // Add edges: each vertex i depends on vertex i-1 (linear chain) + // plus some cross-edges to make it more realistic + for i := 1; i < n; i++ { + src := fmt.Sprintf("TABLE:public.table_%04d:ADDALTER", i-1) + dst := fmt.Sprintf("TABLE:public.table_%04d:ADDALTER", i) + _ = g.AddEdge(src, dst) + // Add a cross-edge from every 10th vertex to create wider fan-out + if i >= 10 && i%10 == 0 { + crossSrc := fmt.Sprintf("TABLE:public.table_%04d:ADDALTER", i-10) + _ = g.AddEdge(crossSrc, dst) + } + } + return g +} + +func BenchmarkTopologicallySortWithPriority(b *testing.B) { + for _, size := range []int{50, 200, 500} { + g := buildBenchGraph(size) + isLowerPriority := IsLowerPriorityFromGetPriority(func(v sprintfVertex) int { + return v.priority + }) + + b.Run(fmt.Sprintf("original/n=%d", size), func(b *testing.B) { + for i := 0; i < b.N; i++ { + _, _ = g.TopologicallySortWithPriority(isLowerPriority) + } + }) + b.Run(fmt.Sprintf("fast/n=%d", size), func(b *testing.B) { + for i := 0; i < b.N; i++ { + _, _ = g.TopologicallySortWithPriorityFast(isLowerPriority) + } + }) + } +} diff --git a/pkg/diff/materialized_view_sql_generator.go 
b/pkg/diff/materialized_view_sql_generator.go index 9b5d216..be9bec6 100644 --- a/pkg/diff/materialized_view_sql_generator.go +++ b/pkg/diff/materialized_view_sql_generator.go @@ -110,15 +110,11 @@ func (mvsg *materializedViewSQLGenerator) Add(mv schema.MaterializedView) (parti } return partialSQLGraph{ - vertices: []sqlVertex{{ - id: addVertexId, - priority: sqlPrioritySooner, - statements: []Statement{{ - DDL: materializedViewSb.String(), - Timeout: statementTimeoutDefault, - LockTimeout: lockTimeoutDefault, - }}, - }}, + vertices: []sqlVertex{newSqlVertex(addVertexId, sqlPrioritySooner, []Statement{{ + DDL: materializedViewSb.String(), + Timeout: statementTimeoutDefault, + LockTimeout: lockTimeoutDefault, + }})}, dependencies: deps, }, nil } @@ -134,15 +130,11 @@ func (mvsg *materializedViewSQLGenerator) Delete(mv schema.MaterializedView) (pa } return partialSQLGraph{ - vertices: []sqlVertex{{ - id: deleteVertexId, - priority: sqlPriorityLater, - statements: []Statement{{ - DDL: fmt.Sprintf("DROP MATERIALIZED VIEW %s", mv.GetFQEscapedName()), - Timeout: statementTimeoutDefault, - LockTimeout: lockTimeoutDefault, - }}, - }}, + vertices: []sqlVertex{newSqlVertex(deleteVertexId, sqlPriorityLater, []Statement{{ + DDL: fmt.Sprintf("DROP MATERIALIZED VIEW %s", mv.GetFQEscapedName()), + Timeout: statementTimeoutDefault, + LockTimeout: lockTimeoutDefault, + }})}, dependencies: deps, }, nil } diff --git a/pkg/diff/procedure_sql_vertex_generator.go b/pkg/diff/procedure_sql_vertex_generator.go index d9f725b..eb19b5c 100644 --- a/pkg/diff/procedure_sql_vertex_generator.go +++ b/pkg/diff/procedure_sql_vertex_generator.go @@ -40,22 +40,18 @@ func (p procedureSQLVertexGenerator) Add(s schema.Procedure) (partialSQLGraph, e } return partialSQLGraph{ - vertices: []sqlVertex{{ - id: buildProcedureVertexId(s.SchemaQualifiedName, diffTypeAddAlter), - priority: sqlPrioritySooner, - statements: []Statement{{ - DDL: s.Def, - Timeout: statementTimeoutDefault, - LockTimeout: 
lockTimeoutDefault, - Hazards: []MigrationHazard{{ - Type: MigrationHazardTypeHasUntrackableDependencies, - Message: "Dependencies of procedures are not tracked by Postgres. " + - "As a result, we cannot guarantee that this procedure's dependencies are ordered properly relative to " + - "this statement. For adds, this means you need to ensure that all objects this function depends on " + - "are added before this statement.", - }}, + vertices: []sqlVertex{newSqlVertex(buildProcedureVertexId(s.SchemaQualifiedName, diffTypeAddAlter), sqlPrioritySooner, []Statement{{ + DDL: s.Def, + Timeout: statementTimeoutDefault, + LockTimeout: lockTimeoutDefault, + Hazards: []MigrationHazard{{ + Type: MigrationHazardTypeHasUntrackableDependencies, + Message: "Dependencies of procedures are not tracked by Postgres. " + + "As a result, we cannot guarantee that this procedure's dependencies are ordered properly relative to " + + "this statement. For adds, this means you need to ensure that all objects this function depends on " + + "are added before this statement.", }}, - }}, + }})}, dependencies: deps, }, nil } @@ -84,22 +80,18 @@ func (p procedureSQLVertexGenerator) Delete(s schema.Procedure) (partialSQLGraph } return partialSQLGraph{ - vertices: []sqlVertex{{ - id: buildProcedureVertexId(s.SchemaQualifiedName, diffTypeDelete), - priority: sqlPriorityLater, - statements: []Statement{{ - DDL: fmt.Sprintf("DROP PROCEDURE %s", s.GetFQEscapedName()), - Timeout: statementTimeoutDefault, - LockTimeout: lockTimeoutDefault, - Hazards: []MigrationHazard{{ - Type: MigrationHazardTypeHasUntrackableDependencies, - Message: "Dependencies of procedures are not tracked by Postgres. " + - "As a result, we cannot guarantee that this procedure's dependencies are ordered properly relative to " + - "this statement. 
For drops, this means you need to ensure that all objects this function depends on " + - "are dropped after this statement.", - }}, + vertices: []sqlVertex{newSqlVertex(buildProcedureVertexId(s.SchemaQualifiedName, diffTypeDelete), sqlPriorityLater, []Statement{{ + DDL: fmt.Sprintf("DROP PROCEDURE %s", s.GetFQEscapedName()), + Timeout: statementTimeoutDefault, + LockTimeout: lockTimeoutDefault, + Hazards: []MigrationHazard{{ + Type: MigrationHazardTypeHasUntrackableDependencies, + Message: "Dependencies of procedures are not tracked by Postgres. " + + "As a result, we cannot guarantee that this procedure's dependencies are ordered properly relative to " + + "this statement. For drops, this means you need to ensure that all objects this function depends on " + + "are dropped after this statement.", }}, - }}, + }})}, dependencies: deps, }, nil } diff --git a/pkg/diff/sql_graph.go b/pkg/diff/sql_graph.go index e5bf681..69a04cb 100644 --- a/pkg/diff/sql_graph.go +++ b/pkg/diff/sql_graph.go @@ -48,6 +48,10 @@ type sqlVertex struct { // id is used to identify the sql vertex id sqlVertexId + // idStr is the pre-computed string representation of id, cached to avoid + // repeated fmt.Sprintf calls (e.g., during topological sort comparisons) + idStr string + // priority is used to determine if the sql vertex should be included sooner or later in the topological // sort of the graph priority sqlPriority @@ -56,8 +60,17 @@ type sqlVertex struct { statements []Statement } +func newSqlVertex(id sqlVertexId, priority sqlPriority, statements []Statement) sqlVertex { + return sqlVertex{ + id: id, + idStr: id.String(), + priority: priority, + statements: statements, + } +} + func (s sqlVertex) GetId() string { - return s.id.String() + return s.idStr } func (s sqlVertex) GetPriority() int { @@ -114,7 +127,7 @@ func newSqlGraph() *sqlGraph { } func (s *sqlGraph) toOrderedStatements() ([]Statement, error) { - vertices, err := 
s.TopologicallySortWithPriority(graph.IsLowerPriorityFromGetPriority(func(v sqlVertex) int { + vertices, err := s.TopologicallySortWithPriorityFast(graph.IsLowerPriorityFromGetPriority(func(v sqlVertex) int { return v.GetPriority() })) if err != nil { diff --git a/pkg/diff/sql_vertex_generator.go b/pkg/diff/sql_vertex_generator.go index 3b54459..bf85cd9 100644 --- a/pkg/diff/sql_vertex_generator.go +++ b/pkg/diff/sql_vertex_generator.go @@ -72,21 +72,13 @@ func mergeVertices(old, new sqlVertex) sqlVertex { priority = new.priority } - return sqlVertex{ - id: old.id, - priority: priority, - statements: append(old.statements, new.statements...), - } + return newSqlVertex(old.id, priority, append(old.statements, new.statements...)) } func addVertexIfNotExists(graph *sqlGraph, id sqlVertexId) { if !graph.HasVertexWithId(id.String()) { // Create a filler node - graph.AddVertex(sqlVertex{ - id: id, - priority: sqlPriorityUnset, - statements: nil, - }) + graph.AddVertex(newSqlVertex(id, sqlPriorityUnset, nil)) } } @@ -175,11 +167,7 @@ func (s *wrappedLegacySqlVertexGenerator[S, Diff]) Add(o S) (partialSQLGraph, er } return partialSQLGraph{ - vertices: []sqlVertex{{ - id: s.generator.GetSQLVertexId(o, diffTypeAddAlter), - priority: sqlPrioritySooner, - statements: statements, - }}, + vertices: []sqlVertex{newSqlVertex(s.generator.GetSQLVertexId(o, diffTypeAddAlter), sqlPrioritySooner, statements)}, dependencies: deps, }, nil } @@ -195,11 +183,7 @@ func (s *wrappedLegacySqlVertexGenerator[S, Diff]) Delete(o S) (partialSQLGraph, } return partialSQLGraph{ - vertices: []sqlVertex{{ - id: s.generator.GetSQLVertexId(o, diffTypeDelete), - priority: sqlPriorityLater, - statements: statements, - }}, + vertices: []sqlVertex{newSqlVertex(s.generator.GetSQLVertexId(o, diffTypeDelete), sqlPriorityLater, statements)}, dependencies: deps, }, nil } @@ -215,11 +199,7 @@ func (s *wrappedLegacySqlVertexGenerator[S, Diff]) Alter(d Diff) (partialSQLGrap } return partialSQLGraph{ - vertices: 
[]sqlVertex{{ - id: s.generator.GetSQLVertexId(d.GetNew(), diffTypeAddAlter), - priority: sqlPrioritySooner, - statements: statements, - }}, + vertices: []sqlVertex{newSqlVertex(s.generator.GetSQLVertexId(d.GetNew(), diffTypeAddAlter), sqlPrioritySooner, statements)}, dependencies: deps, }, nil } diff --git a/pkg/diff/view_sql_generator.go b/pkg/diff/view_sql_generator.go index d96f28d..10aa3cf 100644 --- a/pkg/diff/view_sql_generator.go +++ b/pkg/diff/view_sql_generator.go @@ -105,15 +105,11 @@ func (vsg *viewSQLGenerator) Add(v schema.View) (partialSQLGraph, error) { } return partialSQLGraph{ - vertices: []sqlVertex{{ - id: addVertexId, - priority: sqlPrioritySooner, - statements: []Statement{{ - DDL: viewSb.String(), - Timeout: statementTimeoutDefault, - LockTimeout: lockTimeoutDefault, - }}, - }}, + vertices: []sqlVertex{newSqlVertex(addVertexId, sqlPrioritySooner, []Statement{{ + DDL: viewSb.String(), + Timeout: statementTimeoutDefault, + LockTimeout: lockTimeoutDefault, + }})}, dependencies: deps, }, nil } @@ -129,15 +125,11 @@ func (vsg *viewSQLGenerator) Delete(v schema.View) (partialSQLGraph, error) { } return partialSQLGraph{ - vertices: []sqlVertex{{ - id: deleteVertexId, - priority: sqlPriorityLater, - statements: []Statement{{ - DDL: fmt.Sprintf("DROP VIEW %s", v.GetFQEscapedName()), - Timeout: statementTimeoutDefault, - LockTimeout: lockTimeoutDefault, - }}, - }}, + vertices: []sqlVertex{newSqlVertex(deleteVertexId, sqlPriorityLater, []Statement{{ + DDL: fmt.Sprintf("DROP VIEW %s", v.GetFQEscapedName()), + Timeout: statementTimeoutDefault, + LockTimeout: lockTimeoutDefault, + }})}, dependencies: deps, }, nil }