@@ -16,6 +16,7 @@ import (
1616 "github.com/kyleconroy/sqlc/internal/config"
1717 core "github.com/kyleconroy/sqlc/internal/pg"
1818 "github.com/kyleconroy/sqlc/internal/postgres"
19+ "github.com/kyleconroy/sqlc/internal/postgresql/ast"
1920
2021 "github.com/davecgh/go-spew/spew"
2122 pg "github.com/lfittl/pg_query_go"
@@ -315,13 +316,13 @@ func pluckQuery(source string, n nodes.RawStmt) (string, error) {
315316
316317func rangeVars (root nodes.Node ) []nodes.RangeVar {
317318 var vars []nodes.RangeVar
318- find := VisitorFunc (func (node nodes.Node ) {
319+ find := ast . VisitorFunc (func (node nodes.Node ) {
319320 switch n := node .(type ) {
320321 case nodes.RangeVar :
321322 vars = append (vars , n )
322323 }
323324 })
324- Walk (find , root )
325+ ast . Walk (find , root )
325326 return vars
326327}
327328
@@ -416,6 +417,9 @@ func parseQuery(c core.Catalog, stmt nodes.Node, source string) (*Query, error)
416417 if err := validateParamRef (stmt ); err != nil {
417418 return nil , err
418419 }
420+ if err := validateParamStyle (stmt ); err != nil {
421+ return nil , err
422+ }
419423 raw , ok := stmt .(nodes.RawStmt )
420424 if ! ok {
421425 return nil , errors .New ("node is not a statement" )
@@ -449,9 +453,12 @@ func parseQuery(c core.Catalog, stmt nodes.Node, source string) (*Query, error)
449453 if err := validateCmd (raw .Stmt , name , cmd ); err != nil {
450454 return nil , err
451455 }
456+
457+ // Re-write query AST
458+ raw , namedParams , edits := rewriteNamedParameters (raw )
452459 rvs := rangeVars (raw .Stmt )
453460 refs := findParameters (raw .Stmt )
454- params , err := resolveCatalogRefs (c , rvs , refs )
461+ params , err := resolveCatalogRefs (c , rvs , refs , namedParams )
455462 if err != nil {
456463 return nil , err
457464 }
@@ -464,10 +471,23 @@ func parseQuery(c core.Catalog, stmt nodes.Node, source string) (*Query, error)
464471 if err != nil {
465472 return nil , err
466473 }
467- expanded , err := expand (qc , raw , rawSQL )
474+
475+ expandEdits , err := expand (qc , raw )
468476 if err != nil {
469477 return nil , err
470478 }
479+ edits = append (edits , expandEdits ... )
480+ expanded , err := editQuery (rawSQL , edits )
481+ if err != nil {
482+ return nil , err
483+ }
484+
485+ // If the query string was edited, make sure the syntax is valid
486+ if expanded != rawSQL {
487+ if _ , err := pg .Parse (expanded ); err != nil {
488+ return nil , fmt .Errorf ("edited query syntax is invalid: %w" , err )
489+ }
490+ }
471491
472492 trimmed , comments , err := stripComments (strings .TrimSpace (expanded ))
473493 if err != nil {
@@ -506,7 +526,7 @@ type edit struct {
506526 New string
507527}
508528
509- func expand (qc * QueryCatalog , raw nodes.RawStmt , sql string ) (string , error ) {
529+ func expand (qc * QueryCatalog , raw nodes.RawStmt ) ([] edit , error ) {
510530 list := search (raw , func (node nodes.Node ) bool {
511531 switch node .(type ) {
512532 case nodes.DeleteStmt :
@@ -519,17 +539,17 @@ func expand(qc *QueryCatalog, raw nodes.RawStmt, sql string) (string, error) {
519539 return true
520540 })
521541 if len (list .Items ) == 0 {
522- return sql , nil
542+ return nil , nil
523543 }
524544 var edits []edit
525545 for _ , item := range list .Items {
526546 edit , err := expandStmt (qc , raw , item )
527547 if err != nil {
528- return "" , err
548+ return nil , err
529549 }
530550 edits = append (edits , edit ... )
531551 }
532- return editQuery ( sql , edits )
552+ return edits , nil
533553}
534554
535555func expandStmt (qc * QueryCatalog , raw nodes.RawStmt , node nodes.Node ) ([]edit , error ) {
@@ -983,6 +1003,7 @@ type paramRef struct {
9831003 parent nodes.Node
9841004 rv * nodes.RangeVar
9851005 ref nodes.ParamRef
1006+ name string // Named parameter support
9861007}
9871008
9881009type paramSearch struct {
@@ -1014,10 +1035,16 @@ type limitOffset struct {
10141035 nodeImpl
10151036}
10161037
1017- func (p paramSearch ) Visit (node nodes.Node ) Visitor {
1038+ func (p paramSearch ) Visit (node nodes.Node ) ast. Visitor {
10181039 switch n := node .(type ) {
10191040
10201041 case nodes.A_Expr :
1042+ if join (n .Name , "-" ) == "@" && n .Lexpr == nil {
1043+ param := nodes.ParamRef {Number : 1 }
1044+ // TODO: Remove hard-coded slug
1045+ p .refs [1 ] = paramRef {parent : p .parent , rv : p .rangeVar , name : "slug" , ref : param }
1046+ return nil
1047+ }
10211048 p .parent = node
10221049
10231050 case nodes.FuncCall :
@@ -1111,7 +1138,7 @@ func (p paramSearch) Visit(node nodes.Node) Visitor {
11111138
11121139func findParameters (root nodes.Node ) []paramRef {
11131140 v := paramSearch {refs : map [int ]paramRef {}}
1114- Walk (v , root )
1141+ ast . Walk (v , root )
11151142 refs := make ([]paramRef , 0 )
11161143 for _ , r := range v .refs {
11171144 refs = append (refs , r )
@@ -1125,7 +1152,7 @@ type nodeSearch struct {
11251152 check func (nodes.Node ) bool
11261153}
11271154
1128- func (s * nodeSearch ) Visit (node nodes.Node ) Visitor {
1155+ func (s * nodeSearch ) Visit (node nodes.Node ) ast. Visitor {
11291156 if s .check (node ) {
11301157 s .list .Items = append (s .list .Items , node )
11311158 }
@@ -1134,16 +1161,23 @@ func (s *nodeSearch) Visit(node nodes.Node) Visitor {
11341161
11351162func search (root nodes.Node , f func (nodes.Node ) bool ) nodes.List {
11361163 ns := & nodeSearch {check : f }
1137- Walk (ns , root )
1164+ ast . Walk (ns , root )
11381165 return ns .list
11391166}
11401167
1141- func resolveCatalogRefs (c core.Catalog , rvs []nodes.RangeVar , args []paramRef ) ([]Parameter , error ) {
1168+ func resolveCatalogRefs (c core.Catalog , rvs []nodes.RangeVar , args []paramRef , names map [ int ] string ) ([]Parameter , error ) {
11421169 aliasMap := map [string ]core.FQN {}
11431170 // TODO: Deprecate defaultTable
11441171 var defaultTable * core.FQN
11451172 var tables []core.FQN
11461173
1174+ parameterName := func (n int , defaultName string ) string {
1175+ if n , ok := names [n ]; ok {
1176+ return n
1177+ }
1178+ return defaultName
1179+ }
1180+
11471181 for _ , rv := range rvs {
11481182 if rv .Relname == nil {
11491183 continue
@@ -1193,7 +1227,7 @@ func resolveCatalogRefs(c core.Catalog, rvs []nodes.RangeVar, args []paramRef) (
11931227 a = append (a , Parameter {
11941228 Number : ref .ref .Number ,
11951229 Column : core.Column {
1196- Name : "offset" ,
1230+ Name : parameterName ( ref . ref . Number , "offset" ) ,
11971231 DataType : "integer" ,
11981232 NotNull : true ,
11991233 },
@@ -1203,7 +1237,7 @@ func resolveCatalogRefs(c core.Catalog, rvs []nodes.RangeVar, args []paramRef) (
12031237 a = append (a , Parameter {
12041238 Number : ref .ref .Number ,
12051239 Column : core.Column {
1206- Name : "limit" ,
1240+ Name : parameterName ( ref . ref . Number , "limit" ) ,
12071241 DataType : "integer" ,
12081242 NotNull : true ,
12091243 },
@@ -1256,10 +1290,13 @@ func resolveCatalogRefs(c core.Catalog, rvs []nodes.RangeVar, args []paramRef) (
12561290 for _ , table := range search {
12571291 if c , ok := typeMap [table.Schema ][table.Rel ][key ]; ok {
12581292 found += 1
1293+ if ref .name != "" {
1294+ key = ref .name
1295+ }
12591296 a = append (a , Parameter {
12601297 Number : ref .ref .Number ,
12611298 Column : core.Column {
1262- Name : key ,
1299+ Name : parameterName ( ref . ref . Number , key ) ,
12631300 DataType : c .DataType ,
12641301 NotNull : c .NotNull ,
12651302 IsArray : c .IsArray ,
@@ -1312,7 +1349,7 @@ func resolveCatalogRefs(c core.Catalog, rvs []nodes.RangeVar, args []paramRef) (
13121349 a = append (a , Parameter {
13131350 Number : ref .ref .Number ,
13141351 Column : core.Column {
1315- Name : fun .Name ,
1352+ Name : parameterName ( ref . ref . Number , fun .Name ) ,
13161353 DataType : "any" ,
13171354 },
13181355 })
@@ -1329,7 +1366,7 @@ func resolveCatalogRefs(c core.Catalog, rvs []nodes.RangeVar, args []paramRef) (
13291366 a = append (a , Parameter {
13301367 Number : ref .ref .Number ,
13311368 Column : core.Column {
1332- Name : name ,
1369+ Name : parameterName ( ref . ref . Number , name ) ,
13331370 DataType : arg .DataType ,
13341371 NotNull : true ,
13351372 },
@@ -1345,7 +1382,7 @@ func resolveCatalogRefs(c core.Catalog, rvs []nodes.RangeVar, args []paramRef) (
13451382 a = append (a , Parameter {
13461383 Number : ref .ref .Number ,
13471384 Column : core.Column {
1348- Name : key ,
1385+ Name : parameterName ( ref . ref . Number , key ) ,
13491386 DataType : c .DataType ,
13501387 NotNull : c .NotNull ,
13511388 IsArray : c .IsArray ,
@@ -1364,9 +1401,11 @@ func resolveCatalogRefs(c core.Catalog, rvs []nodes.RangeVar, args []paramRef) (
13641401 if n .TypeName == nil {
13651402 return nil , fmt .Errorf ("nodes.TypeCast has nil type name" )
13661403 }
1404+ col := catalog .ToColumn (n .TypeName )
1405+ col .Name = parameterName (ref .ref .Number , col .Name )
13671406 a = append (a , Parameter {
13681407 Number : ref .ref .Number ,
1369- Column : catalog . ToColumn ( n . TypeName ) ,
1408+ Column : col ,
13701409 })
13711410
13721411 case nodes.ParamRef :
0 commit comments