Skip to content

Commit d247b40

Browse files
authored
Named parameters for PostgreSQL (#328)
This change adds PostgreSQL support for sqlc.arg(name) and @name parameter styles. It's powered by a new postgresql/ast package which support rewriting query ASTs.
1 parent bb679c4 commit d247b40

15 files changed

Lines changed: 1985 additions & 25 deletions

File tree

internal/dinosql/checks.go

Lines changed: 32 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,18 @@ import (
44
"fmt"
55
"strings"
66

7+
nodes "github.com/lfittl/pg_query_go/nodes"
8+
79
"github.com/kyleconroy/sqlc/internal/catalog"
810
"github.com/kyleconroy/sqlc/internal/pg"
9-
nodes "github.com/lfittl/pg_query_go/nodes"
11+
"github.com/kyleconroy/sqlc/internal/postgresql/ast"
1012
)
1113

1214
func validateParamRef(n nodes.Node) error {
1315
var allrefs []nodes.ParamRef
1416

1517
// Find all parameter references
16-
Walk(VisitorFunc(func(node nodes.Node) {
18+
ast.Walk(ast.VisitorFunc(func(node nodes.Node) {
1719
switch n := node.(type) {
1820
case nodes.ParamRef:
1921
allrefs = append(allrefs, n)
@@ -41,7 +43,7 @@ type funcCallVisitor struct {
4143
err error
4244
}
4345

44-
func (v *funcCallVisitor) Visit(node nodes.Node) Visitor {
46+
func (v *funcCallVisitor) Visit(node nodes.Node) ast.Visitor {
4547
if v.err != nil {
4648
return nil
4749
}
@@ -91,7 +93,7 @@ func (v *funcCallVisitor) Visit(node nodes.Node) Visitor {
9193

9294
func validateFuncCall(c *pg.Catalog, n nodes.Node) error {
9395
visitor := funcCallVisitor{catalog: c}
94-
Walk(&visitor, n)
96+
ast.Walk(&visitor, n)
9597
return visitor.err
9698
}
9799

@@ -120,3 +122,29 @@ func validateInsertStmt(stmt nodes.InsertStmt) error {
120122
}
121123
return nil
122124
}
125+
126+
// A query can use one (and only one) of the following formats:
127+
// - positional parameters $1
128+
// - named parameter operator @param
129+
// - named parameter function calls sqlc.arg(param)
130+
func validateParamStyle(n nodes.Node) error {
131+
positional := search(n, func(node nodes.Node) bool {
132+
_, ok := node.(nodes.ParamRef)
133+
return ok
134+
})
135+
namedFunc := search(n, isNamedParamFunc)
136+
namedSign := search(n, isNamedParamSign)
137+
for _, check := range []bool{
138+
len(positional.Items) > 0 && len(namedSign.Items)+len(namedFunc.Items) > 0,
139+
len(namedFunc.Items) > 0 && len(namedSign.Items)+len(positional.Items) > 0,
140+
len(namedSign.Items) > 0 && len(positional.Items)+len(namedFunc.Items) > 0,
141+
} {
142+
if check {
143+
return pg.Error{
144+
Code: "", // TODO: Pick a new error code
145+
Message: "query mixes positional parameters ($1) and named parameters (sqlc.arg or @arg)",
146+
}
147+
}
148+
}
149+
return nil
150+
}

internal/dinosql/parser.go

Lines changed: 59 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ import (
1616
"github.com/kyleconroy/sqlc/internal/config"
1717
core "github.com/kyleconroy/sqlc/internal/pg"
1818
"github.com/kyleconroy/sqlc/internal/postgres"
19+
"github.com/kyleconroy/sqlc/internal/postgresql/ast"
1920

2021
"github.com/davecgh/go-spew/spew"
2122
pg "github.com/lfittl/pg_query_go"
@@ -315,13 +316,13 @@ func pluckQuery(source string, n nodes.RawStmt) (string, error) {
315316

316317
func rangeVars(root nodes.Node) []nodes.RangeVar {
317318
var vars []nodes.RangeVar
318-
find := VisitorFunc(func(node nodes.Node) {
319+
find := ast.VisitorFunc(func(node nodes.Node) {
319320
switch n := node.(type) {
320321
case nodes.RangeVar:
321322
vars = append(vars, n)
322323
}
323324
})
324-
Walk(find, root)
325+
ast.Walk(find, root)
325326
return vars
326327
}
327328

@@ -416,6 +417,9 @@ func parseQuery(c core.Catalog, stmt nodes.Node, source string) (*Query, error)
416417
if err := validateParamRef(stmt); err != nil {
417418
return nil, err
418419
}
420+
if err := validateParamStyle(stmt); err != nil {
421+
return nil, err
422+
}
419423
raw, ok := stmt.(nodes.RawStmt)
420424
if !ok {
421425
return nil, errors.New("node is not a statement")
@@ -449,9 +453,12 @@ func parseQuery(c core.Catalog, stmt nodes.Node, source string) (*Query, error)
449453
if err := validateCmd(raw.Stmt, name, cmd); err != nil {
450454
return nil, err
451455
}
456+
457+
// Re-write query AST
458+
raw, namedParams, edits := rewriteNamedParameters(raw)
452459
rvs := rangeVars(raw.Stmt)
453460
refs := findParameters(raw.Stmt)
454-
params, err := resolveCatalogRefs(c, rvs, refs)
461+
params, err := resolveCatalogRefs(c, rvs, refs, namedParams)
455462
if err != nil {
456463
return nil, err
457464
}
@@ -464,10 +471,23 @@ func parseQuery(c core.Catalog, stmt nodes.Node, source string) (*Query, error)
464471
if err != nil {
465472
return nil, err
466473
}
467-
expanded, err := expand(qc, raw, rawSQL)
474+
475+
expandEdits, err := expand(qc, raw)
468476
if err != nil {
469477
return nil, err
470478
}
479+
edits = append(edits, expandEdits...)
480+
expanded, err := editQuery(rawSQL, edits)
481+
if err != nil {
482+
return nil, err
483+
}
484+
485+
// If the query string was edited, make sure the syntax is valid
486+
if expanded != rawSQL {
487+
if _, err := pg.Parse(expanded); err != nil {
488+
return nil, fmt.Errorf("edited query syntax is invalid: %w", err)
489+
}
490+
}
471491

472492
trimmed, comments, err := stripComments(strings.TrimSpace(expanded))
473493
if err != nil {
@@ -506,7 +526,7 @@ type edit struct {
506526
New string
507527
}
508528

509-
func expand(qc *QueryCatalog, raw nodes.RawStmt, sql string) (string, error) {
529+
func expand(qc *QueryCatalog, raw nodes.RawStmt) ([]edit, error) {
510530
list := search(raw, func(node nodes.Node) bool {
511531
switch node.(type) {
512532
case nodes.DeleteStmt:
@@ -519,17 +539,17 @@ func expand(qc *QueryCatalog, raw nodes.RawStmt, sql string) (string, error) {
519539
return true
520540
})
521541
if len(list.Items) == 0 {
522-
return sql, nil
542+
return nil, nil
523543
}
524544
var edits []edit
525545
for _, item := range list.Items {
526546
edit, err := expandStmt(qc, raw, item)
527547
if err != nil {
528-
return "", err
548+
return nil, err
529549
}
530550
edits = append(edits, edit...)
531551
}
532-
return editQuery(sql, edits)
552+
return edits, nil
533553
}
534554

535555
func expandStmt(qc *QueryCatalog, raw nodes.RawStmt, node nodes.Node) ([]edit, error) {
@@ -983,6 +1003,7 @@ type paramRef struct {
9831003
parent nodes.Node
9841004
rv *nodes.RangeVar
9851005
ref nodes.ParamRef
1006+
name string // Named parameter support
9861007
}
9871008

9881009
type paramSearch struct {
@@ -1014,10 +1035,16 @@ type limitOffset struct {
10141035
nodeImpl
10151036
}
10161037

1017-
func (p paramSearch) Visit(node nodes.Node) Visitor {
1038+
func (p paramSearch) Visit(node nodes.Node) ast.Visitor {
10181039
switch n := node.(type) {
10191040

10201041
case nodes.A_Expr:
1042+
if join(n.Name, "-") == "@" && n.Lexpr == nil {
1043+
param := nodes.ParamRef{Number: 1}
1044+
// TODO: Remove hard-coded slug
1045+
p.refs[1] = paramRef{parent: p.parent, rv: p.rangeVar, name: "slug", ref: param}
1046+
return nil
1047+
}
10211048
p.parent = node
10221049

10231050
case nodes.FuncCall:
@@ -1111,7 +1138,7 @@ func (p paramSearch) Visit(node nodes.Node) Visitor {
11111138

11121139
func findParameters(root nodes.Node) []paramRef {
11131140
v := paramSearch{refs: map[int]paramRef{}}
1114-
Walk(v, root)
1141+
ast.Walk(v, root)
11151142
refs := make([]paramRef, 0)
11161143
for _, r := range v.refs {
11171144
refs = append(refs, r)
@@ -1125,7 +1152,7 @@ type nodeSearch struct {
11251152
check func(nodes.Node) bool
11261153
}
11271154

1128-
func (s *nodeSearch) Visit(node nodes.Node) Visitor {
1155+
func (s *nodeSearch) Visit(node nodes.Node) ast.Visitor {
11291156
if s.check(node) {
11301157
s.list.Items = append(s.list.Items, node)
11311158
}
@@ -1134,16 +1161,23 @@ func (s *nodeSearch) Visit(node nodes.Node) Visitor {
11341161

11351162
func search(root nodes.Node, f func(nodes.Node) bool) nodes.List {
11361163
ns := &nodeSearch{check: f}
1137-
Walk(ns, root)
1164+
ast.Walk(ns, root)
11381165
return ns.list
11391166
}
11401167

1141-
func resolveCatalogRefs(c core.Catalog, rvs []nodes.RangeVar, args []paramRef) ([]Parameter, error) {
1168+
func resolveCatalogRefs(c core.Catalog, rvs []nodes.RangeVar, args []paramRef, names map[int]string) ([]Parameter, error) {
11421169
aliasMap := map[string]core.FQN{}
11431170
// TODO: Deprecate defaultTable
11441171
var defaultTable *core.FQN
11451172
var tables []core.FQN
11461173

1174+
parameterName := func(n int, defaultName string) string {
1175+
if n, ok := names[n]; ok {
1176+
return n
1177+
}
1178+
return defaultName
1179+
}
1180+
11471181
for _, rv := range rvs {
11481182
if rv.Relname == nil {
11491183
continue
@@ -1193,7 +1227,7 @@ func resolveCatalogRefs(c core.Catalog, rvs []nodes.RangeVar, args []paramRef) (
11931227
a = append(a, Parameter{
11941228
Number: ref.ref.Number,
11951229
Column: core.Column{
1196-
Name: "offset",
1230+
Name: parameterName(ref.ref.Number, "offset"),
11971231
DataType: "integer",
11981232
NotNull: true,
11991233
},
@@ -1203,7 +1237,7 @@ func resolveCatalogRefs(c core.Catalog, rvs []nodes.RangeVar, args []paramRef) (
12031237
a = append(a, Parameter{
12041238
Number: ref.ref.Number,
12051239
Column: core.Column{
1206-
Name: "limit",
1240+
Name: parameterName(ref.ref.Number, "limit"),
12071241
DataType: "integer",
12081242
NotNull: true,
12091243
},
@@ -1256,10 +1290,13 @@ func resolveCatalogRefs(c core.Catalog, rvs []nodes.RangeVar, args []paramRef) (
12561290
for _, table := range search {
12571291
if c, ok := typeMap[table.Schema][table.Rel][key]; ok {
12581292
found += 1
1293+
if ref.name != "" {
1294+
key = ref.name
1295+
}
12591296
a = append(a, Parameter{
12601297
Number: ref.ref.Number,
12611298
Column: core.Column{
1262-
Name: key,
1299+
Name: parameterName(ref.ref.Number, key),
12631300
DataType: c.DataType,
12641301
NotNull: c.NotNull,
12651302
IsArray: c.IsArray,
@@ -1312,7 +1349,7 @@ func resolveCatalogRefs(c core.Catalog, rvs []nodes.RangeVar, args []paramRef) (
13121349
a = append(a, Parameter{
13131350
Number: ref.ref.Number,
13141351
Column: core.Column{
1315-
Name: fun.Name,
1352+
Name: parameterName(ref.ref.Number, fun.Name),
13161353
DataType: "any",
13171354
},
13181355
})
@@ -1329,7 +1366,7 @@ func resolveCatalogRefs(c core.Catalog, rvs []nodes.RangeVar, args []paramRef) (
13291366
a = append(a, Parameter{
13301367
Number: ref.ref.Number,
13311368
Column: core.Column{
1332-
Name: name,
1369+
Name: parameterName(ref.ref.Number, name),
13331370
DataType: arg.DataType,
13341371
NotNull: true,
13351372
},
@@ -1345,7 +1382,7 @@ func resolveCatalogRefs(c core.Catalog, rvs []nodes.RangeVar, args []paramRef) (
13451382
a = append(a, Parameter{
13461383
Number: ref.ref.Number,
13471384
Column: core.Column{
1348-
Name: key,
1385+
Name: parameterName(ref.ref.Number, key),
13491386
DataType: c.DataType,
13501387
NotNull: c.NotNull,
13511388
IsArray: c.IsArray,
@@ -1364,9 +1401,11 @@ func resolveCatalogRefs(c core.Catalog, rvs []nodes.RangeVar, args []paramRef) (
13641401
if n.TypeName == nil {
13651402
return nil, fmt.Errorf("nodes.TypeCast has nil type name")
13661403
}
1404+
col := catalog.ToColumn(n.TypeName)
1405+
col.Name = parameterName(ref.ref.Number, col.Name)
13671406
a = append(a, Parameter{
13681407
Number: ref.ref.Number,
1369-
Column: catalog.ToColumn(n.TypeName),
1408+
Column: col,
13701409
})
13711410

13721411
case nodes.ParamRef:

0 commit comments

Comments
 (0)