@@ -453,11 +453,15 @@ func parseQuery(c core.Catalog, stmt nodes.Node, source string) (*Query, error)
453453 return nil , err
454454 }
455455
456- cols , err := outputColumns (c , raw .Stmt )
456+ qc , err := buildQueryCatalog (c , raw .Stmt )
457457 if err != nil {
458458 return nil , err
459459 }
460- expanded , err := expand (c , raw , rawSQL )
460+ cols , err := outputColumns (qc , raw .Stmt )
461+ if err != nil {
462+ return nil , err
463+ }
464+ expanded , err := expand (qc , raw , rawSQL )
461465 if err != nil {
462466 return nil , err
463467 }
@@ -499,7 +503,7 @@ type edit struct {
499503 New string
500504}
501505
502- func expand (c core. Catalog , raw nodes.RawStmt , sql string ) (string , error ) {
506+ func expand (qc * QueryCatalog , raw nodes.RawStmt , sql string ) (string , error ) {
503507 list := search (raw , func (node nodes.Node ) bool {
504508 switch node .(type ) {
505509 case nodes.DeleteStmt :
@@ -516,7 +520,7 @@ func expand(c core.Catalog, raw nodes.RawStmt, sql string) (string, error) {
516520 }
517521 var edits []edit
518522 for _ , item := range list .Items {
519- edit , err := expandStmt (c , raw , item )
523+ edit , err := expandStmt (qc , raw , item )
520524 if err != nil {
521525 return "" , err
522526 }
@@ -525,8 +529,8 @@ func expand(c core.Catalog, raw nodes.RawStmt, sql string) (string, error) {
525529 return editQuery (sql , edits )
526530}
527531
528- func expandStmt (c core. Catalog , raw nodes.RawStmt , node nodes.Node ) ([]edit , error ) {
529- tables , err := sourceTables (c , node )
532+ func expandStmt (qc * QueryCatalog , raw nodes.RawStmt , node nodes.Node ) ([]edit , error ) {
533+ tables , err := sourceTables (qc , node )
530534 if err != nil {
531535 return nil , err
532536 }
@@ -629,23 +633,32 @@ type QueryCatalog struct {
629633 ctes map [string ]core.Table
630634}
631635
632- func NewQueryCatalog (c core.Catalog , with * nodes.WithClause ) QueryCatalog {
633- ctes := map [string ]core.Table {}
636+ func buildQueryCatalog (c core.Catalog , node nodes.Node ) (* QueryCatalog , error ) {
637+ var with * nodes.WithClause
638+ switch n := node .(type ) {
639+ case nodes.UpdateStmt :
640+ with = n .WithClause
641+ case nodes.SelectStmt :
642+ with = n .WithClause
643+ default :
644+ with = nil
645+ }
646+ qc := & QueryCatalog {catalog : c , ctes : map [string ]core.Table {}}
634647 if with != nil {
635648 for _ , item := range with .Ctes .Items {
636649 if cte , ok := item .(nodes.CommonTableExpr ); ok {
637- cols , err := outputColumns (c , cte .Ctequery )
650+ cols , err := outputColumns (qc , cte .Ctequery )
638651 if err != nil {
639- panic ( err . Error ())
652+ return nil , err
640653 }
641- ctes [* cte .Ctename ] = core.Table {
654+ qc . ctes [* cte .Ctename ] = core.Table {
642655 Name : * cte .Ctename ,
643656 Columns : cols ,
644657 }
645658 }
646659 }
647660 }
648- return QueryCatalog { catalog : c , ctes : ctes }
661+ return qc , nil
649662}
650663
651664func (qc QueryCatalog ) GetTable (fqn core.FQN ) (core.Table , * core.Error ) {
@@ -673,9 +686,8 @@ func (qc QueryCatalog) GetTable(fqn core.FQN) (core.Table, *core.Error) {
673686// Return an error if column references don't exist
674687// Return an error if a table is referenced twice
675688// Return an error if an unknown column is referenced
676- func sourceTables (c core. Catalog , node nodes.Node ) ([]core.Table , error ) {
689+ func sourceTables (qc * QueryCatalog , node nodes.Node ) ([]core.Table , error ) {
677690 var list nodes.List
678- var with * nodes.WithClause
679691 switch n := node .(type ) {
680692 case nodes.DeleteStmt :
681693 list = nodes.List {
@@ -686,12 +698,10 @@ func sourceTables(c core.Catalog, node nodes.Node) ([]core.Table, error) {
686698 Items : []nodes.Node {* n .Relation },
687699 }
688700 case nodes.UpdateStmt :
689- with = n .WithClause
690701 list = nodes.List {
691702 Items : append (n .FromClause .Items , * n .Relation ),
692703 }
693704 case nodes.SelectStmt :
694- with = n .WithClause
695705 list = search (n .FromClause , func (node nodes.Node ) bool {
696706 _ , ok := node .(nodes.RangeVar )
697707 return ok
@@ -700,8 +710,6 @@ func sourceTables(c core.Catalog, node nodes.Node) ([]core.Table, error) {
700710 return nil , fmt .Errorf ("sourceTables: unsupported node type: %T" , n )
701711 }
702712
703- qc := NewQueryCatalog (c , with )
704-
705713 var tables []core.Table
706714 for _ , item := range list .Items {
707715 switch n := item .(type ) {
@@ -736,8 +744,8 @@ func HasStarRef(cf nodes.ColumnRef) bool {
736744//
737745// Return an error if column references are ambiguous
738746// Return an error if column references don't exist
739- func outputColumns (c core. Catalog , node nodes.Node ) ([]core.Column , error ) {
740- tables , err := sourceTables (c , node )
747+ func outputColumns (qc * QueryCatalog , node nodes.Node ) ([]core.Column , error ) {
748+ tables , err := sourceTables (qc , node )
741749 if err != nil {
742750 return nil , err
743751 }
@@ -865,7 +873,7 @@ func outputColumns(c core.Catalog, node nodes.Node) ([]core.Column, error) {
865873 name = * res .Name
866874 }
867875
868- fun , err := c .LookupFunctionN (fqn , len (n .Args .Items ))
876+ fun , err := qc . catalog .LookupFunctionN (fqn , len (n .Args .Items ))
869877 if err == nil {
870878 cols = append (cols , core.Column {Name : name , DataType : fun .ReturnType , NotNull : true })
871879 } else {
0 commit comments