Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 76 additions & 0 deletions internal/graph/graph.go
Original file line number Diff line number Diff line change
Expand Up @@ -219,3 +219,79 @@ func (g *Graph[V]) TopologicallySortWithPriority(isLowerPriority func(V, V) bool

return output, nil
}

// taggedVertex pairs a vertex with its pre-computed string ID to avoid
// repeated GetId() calls to avoid extra allocations during sorting
type taggedVertex[V Vertex] struct {
v V
id string
}

// TopologicallySortWithPriorityFast is an optimized version of TopologicallySortWithPriority that:
// - Counts incoming edges in a single pass over forward edges (no Copy + Reverse)
// - Pre-caches vertex ID strings to avoid repeated GetId() calls during sorting
//
// It produces identical output to TopologicallySortWithPriority
func (g *Graph[V]) TopologicallySortWithPriorityFast(isLowerPriority func(V, V) bool) ([]V, error) {
numVertices := len(g.verticesById)
if numVertices == 0 {
return nil, nil
}

// Count incoming edges directly from the forward edge map
incomingEdgeCount := make(map[string]int, numVertices)
for id := range g.verticesById {
incomingEdgeCount[id] = 0
}
for _, adjacentEdges := range g.edges {
for target, hasEdge := range adjacentEdges {
if hasEdge {
incomingEdgeCount[target]++
}
}
}

output := make([]V, 0, numVertices)
for len(incomingEdgeCount) > 0 {
var sources []taggedVertex[V]
for id, count := range incomingEdgeCount {
if count == 0 {
sources = append(sources, taggedVertex[V]{v: g.verticesById[id], id: id})
}
}

sort.Slice(sources, func(i, j int) bool {
return sources[i].id < sources[j].id
})

// Take the source with the highest priority from the sorted array of sources
bestIdx := -1
for i := range sources {
if bestIdx == -1 || isLowerPriority(sources[bestIdx].v, sources[i].v) {
bestIdx = i
}
}
if bestIdx == -1 {
dotSB := strings.Builder{}
if err := EncodeDOT(g, &dotSB, true); err != nil {
dotSB.Reset()
dotSB.WriteString(fmt.Sprintf("failed to encode graph to DOT: %v", err))
}
return nil, fmt.Errorf("cycle detected: %+v\n%s", incomingEdgeCount, dotSB.String())
}
best := sources[bestIdx]

output = append(output, best.v)

// Decrement incoming edge counts for vertices adjacent to the removed source
for target, hasEdge := range g.edges[best.id] {
if hasEdge {
incomingEdgeCount[target]--
}
}

delete(incomingEdgeCount, best.id)
}

return output, nil
}
141 changes: 141 additions & 0 deletions internal/graph/graph_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -430,3 +430,144 @@ func getVertexIds(g *Graph[vertex]) []string {
}
return output
}

// sprintfVertex is a vertex whose GetId() calls fmt.Sprintf every time
type sprintfVertex struct {
objType string
objId string
diffType string
priority int
}

func (s sprintfVertex) GetId() string {
return fmt.Sprintf("%s:%s:%s", s.objType, s.objId, s.diffType)
}

func TestTopologicallySortWithPriorityFast(t *testing.T) {
// Same graph and test cases as TestTopologicallySortWithPriority to prove
// the fast version produces identical output.
// Source: https://en.wikipedia.org/wiki/Topological_sorting#Examples
g := NewGraph[vertex]()
v5 := NewV("05")
g.AddVertex(v5)
v7 := NewV("07")
g.AddVertex(v7)
v3 := NewV("03")
g.AddVertex(v3)
v11 := NewV("11")
g.AddVertex(v11)
v8 := NewV("08")
g.AddVertex(v8)
v2 := NewV("02")
g.AddVertex(v2)
v9 := NewV("09")
g.AddVertex(v9)
v10 := NewV("10")
g.AddVertex(v10)
assert.NoError(t, g.AddEdge("05", "11"))
assert.NoError(t, g.AddEdge("07", "11"))
assert.NoError(t, g.AddEdge("07", "08"))
assert.NoError(t, g.AddEdge("03", "08"))
assert.NoError(t, g.AddEdge("03", "10"))
assert.NoError(t, g.AddEdge("11", "02"))
assert.NoError(t, g.AddEdge("11", "09"))
assert.NoError(t, g.AddEdge("11", "10"))
assert.NoError(t, g.AddEdge("08", "09"))

for _, tc := range []struct {
name string
isLowerPriority func(v1, v2 vertex) bool
expectedOrdering []vertex
}{
{
name: "largest-numbered available vertex first (string-based GetPriority)",
isLowerPriority: IsLowerPriorityFromGetPriority(func(v vertex) string {
return v.GetId()
}),
expectedOrdering: []vertex{v7, v5, v11, v3, v10, v8, v9, v2},
},
{
name: "smallest-numbered available vertex first (numeric-based GetPriority)",
isLowerPriority: IsLowerPriorityFromGetPriority(func(v vertex) int {
idAsInt, err := strconv.Atoi(v.GetId())
require.NoError(t, err)
return -idAsInt
}),
expectedOrdering: []vertex{v3, v5, v7, v8, v11, v2, v9, v10},
},
{
name: "fewest edges first (prioritize high id's for tie breakers)",
isLowerPriority: func(v1, v2 vertex) bool {
v1EdgeCount := getEdgeCount(g, v1)
v2EdgeCount := getEdgeCount(g, v2)
if v1EdgeCount == v2EdgeCount {
return v1.GetId() < v2.GetId()
}
return v1EdgeCount > v2EdgeCount
},
expectedOrdering: []vertex{v5, v7, v3, v8, v11, v10, v9, v2},
},
} {
t.Run(tc.name, func(t *testing.T) {
fast, err := g.TopologicallySortWithPriorityFast(tc.isLowerPriority)
assert.NoError(t, err)
assert.Equal(t, tc.expectedOrdering, fast, "fast version should match expected ordering")

original, err := g.TopologicallySortWithPriority(tc.isLowerPriority)
assert.NoError(t, err)
assert.Equal(t, original, fast, "fast version should match original version")
})
}

// Cycle should error
assert.NoError(t, g.AddEdge("10", "07"))
_, err := g.TopologicallySortWithPriorityFast(func(_, _ vertex) bool { return false })
assert.Error(t, err)
}

// buildBenchGraph creates a graph with N vertices simulating schema objects,
// using sprintfVertex
func buildBenchGraph(n int) *Graph[sprintfVertex] {
g := NewGraph[sprintfVertex]()
for i := 0; i < n; i++ {
g.AddVertex(sprintfVertex{
objType: "TABLE",
objId: fmt.Sprintf("public.table_%04d", i),
diffType: "ADDALTER",
priority: i % 3,
})
}
// Add edges: each vertex i depends on vertex i-1 (linear chain)
// plus some cross-edges to make it more realistic
for i := 1; i < n; i++ {
src := fmt.Sprintf("TABLE:public.table_%04d:ADDALTER", i-1)
dst := fmt.Sprintf("TABLE:public.table_%04d:ADDALTER", i)
_ = g.AddEdge(src, dst)
// Add a cross-edge from every 10th vertex to create wider fan-out
if i >= 10 && i%10 == 0 {
crossSrc := fmt.Sprintf("TABLE:public.table_%04d:ADDALTER", i-10)
_ = g.AddEdge(crossSrc, dst)
}
}
return g
}

func BenchmarkTopologicallySortWithPriority(b *testing.B) {
for _, size := range []int{50, 200, 500} {
g := buildBenchGraph(size)
isLowerPriority := IsLowerPriorityFromGetPriority(func(v sprintfVertex) int {
return v.priority
})

b.Run(fmt.Sprintf("original/n=%d", size), func(b *testing.B) {
for i := 0; i < b.N; i++ {
_, _ = g.TopologicallySortWithPriority(isLowerPriority)
}
})
b.Run(fmt.Sprintf("fast/n=%d", size), func(b *testing.B) {
for i := 0; i < b.N; i++ {
_, _ = g.TopologicallySortWithPriorityFast(isLowerPriority)
}
})
}
}
28 changes: 10 additions & 18 deletions pkg/diff/materialized_view_sql_generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,15 +110,11 @@ func (mvsg *materializedViewSQLGenerator) Add(mv schema.MaterializedView) (parti
}

return partialSQLGraph{
vertices: []sqlVertex{{
id: addVertexId,
priority: sqlPrioritySooner,
statements: []Statement{{
DDL: materializedViewSb.String(),
Timeout: statementTimeoutDefault,
LockTimeout: lockTimeoutDefault,
}},
}},
vertices: []sqlVertex{newSqlVertex(addVertexId, sqlPrioritySooner, []Statement{{
DDL: materializedViewSb.String(),
Timeout: statementTimeoutDefault,
LockTimeout: lockTimeoutDefault,
}})},
dependencies: deps,
}, nil
}
Expand All @@ -134,15 +130,11 @@ func (mvsg *materializedViewSQLGenerator) Delete(mv schema.MaterializedView) (pa
}

return partialSQLGraph{
vertices: []sqlVertex{{
id: deleteVertexId,
priority: sqlPriorityLater,
statements: []Statement{{
DDL: fmt.Sprintf("DROP MATERIALIZED VIEW %s", mv.GetFQEscapedName()),
Timeout: statementTimeoutDefault,
LockTimeout: lockTimeoutDefault,
}},
}},
vertices: []sqlVertex{newSqlVertex(deleteVertexId, sqlPriorityLater, []Statement{{
DDL: fmt.Sprintf("DROP MATERIALIZED VIEW %s", mv.GetFQEscapedName()),
Timeout: statementTimeoutDefault,
LockTimeout: lockTimeoutDefault,
}})},
dependencies: deps,
}, nil
}
Expand Down
52 changes: 22 additions & 30 deletions pkg/diff/procedure_sql_vertex_generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,22 +40,18 @@ func (p procedureSQLVertexGenerator) Add(s schema.Procedure) (partialSQLGraph, e
}

return partialSQLGraph{
vertices: []sqlVertex{{
id: buildProcedureVertexId(s.SchemaQualifiedName, diffTypeAddAlter),
priority: sqlPrioritySooner,
statements: []Statement{{
DDL: s.Def,
Timeout: statementTimeoutDefault,
LockTimeout: lockTimeoutDefault,
Hazards: []MigrationHazard{{
Type: MigrationHazardTypeHasUntrackableDependencies,
Message: "Dependencies of procedures are not tracked by Postgres. " +
"As a result, we cannot guarantee that this procedure's dependencies are ordered properly relative to " +
"this statement. For adds, this means you need to ensure that all objects this function depends on " +
"are added before this statement.",
}},
vertices: []sqlVertex{newSqlVertex(buildProcedureVertexId(s.SchemaQualifiedName, diffTypeAddAlter), sqlPrioritySooner, []Statement{{
DDL: s.Def,
Timeout: statementTimeoutDefault,
LockTimeout: lockTimeoutDefault,
Hazards: []MigrationHazard{{
Type: MigrationHazardTypeHasUntrackableDependencies,
Message: "Dependencies of procedures are not tracked by Postgres. " +
"As a result, we cannot guarantee that this procedure's dependencies are ordered properly relative to " +
"this statement. For adds, this means you need to ensure that all objects this function depends on " +
"are added before this statement.",
}},
}},
}})},
dependencies: deps,
}, nil
}
Expand Down Expand Up @@ -84,22 +80,18 @@ func (p procedureSQLVertexGenerator) Delete(s schema.Procedure) (partialSQLGraph
}

return partialSQLGraph{
vertices: []sqlVertex{{
id: buildProcedureVertexId(s.SchemaQualifiedName, diffTypeDelete),
priority: sqlPriorityLater,
statements: []Statement{{
DDL: fmt.Sprintf("DROP PROCEDURE %s", s.GetFQEscapedName()),
Timeout: statementTimeoutDefault,
LockTimeout: lockTimeoutDefault,
Hazards: []MigrationHazard{{
Type: MigrationHazardTypeHasUntrackableDependencies,
Message: "Dependencies of procedures are not tracked by Postgres. " +
"As a result, we cannot guarantee that this procedure's dependencies are ordered properly relative to " +
"this statement. For drops, this means you need to ensure that all objects this function depends on " +
"are dropped after this statement.",
}},
vertices: []sqlVertex{newSqlVertex(buildProcedureVertexId(s.SchemaQualifiedName, diffTypeDelete), sqlPriorityLater, []Statement{{
DDL: fmt.Sprintf("DROP PROCEDURE %s", s.GetFQEscapedName()),
Timeout: statementTimeoutDefault,
LockTimeout: lockTimeoutDefault,
Hazards: []MigrationHazard{{
Type: MigrationHazardTypeHasUntrackableDependencies,
Message: "Dependencies of procedures are not tracked by Postgres. " +
"As a result, we cannot guarantee that this procedure's dependencies are ordered properly relative to " +
"this statement. For drops, this means you need to ensure that all objects this function depends on " +
"are dropped after this statement.",
}},
}},
}})},
dependencies: deps,
}, nil
}
Expand Down
17 changes: 15 additions & 2 deletions pkg/diff/sql_graph.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,10 @@ type sqlVertex struct {
// id is used to identify the sql vertex
id sqlVertexId

// idStr is the pre-computed string representation of id, cached to avoid
// repeated fmt.Sprintf calls (e.g., during topological sort comparisons)
idStr string

// priority is used to determine if the sql vertex should be included sooner or later in the topological
// sort of the graph
priority sqlPriority
Expand All @@ -56,8 +60,17 @@ type sqlVertex struct {
statements []Statement
}

func newSqlVertex(id sqlVertexId, priority sqlPriority, statements []Statement) sqlVertex {
return sqlVertex{
id: id,
idStr: id.String(),
priority: priority,
statements: statements,
}
}

func (s sqlVertex) GetId() string {
return s.id.String()
return s.idStr
}

func (s sqlVertex) GetPriority() int {
Expand Down Expand Up @@ -114,7 +127,7 @@ func newSqlGraph() *sqlGraph {
}

func (s *sqlGraph) toOrderedStatements() ([]Statement, error) {
vertices, err := s.TopologicallySortWithPriority(graph.IsLowerPriorityFromGetPriority(func(v sqlVertex) int {
vertices, err := s.TopologicallySortWithPriorityFast(graph.IsLowerPriorityFromGetPriority(func(v sqlVertex) int {
return v.GetPriority()
}))
if err != nil {
Expand Down
Loading