Skip to content

Commit 2e9294f

Browse files
authored
dolphin: Implement exapansion with reserved words (#620)
compiler: rename the Engine struct to Compiler compiler: move methods onto the Compiler struct compiler: Use backticks for reserved words in MySQL codegen: Handle queries with backticks
1 parent 07eb5d4 commit 2e9294f

18 files changed

Lines changed: 170 additions & 54 deletions

File tree

internal/cmd/generate.go

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -208,8 +208,8 @@ func parseMySQL(e Env, name, dir string, sql config.SQL, combo config.CombinedSe
208208
}
209209

210210
func parse(e Env, name, dir string, sql config.SQL, combo config.CombinedSettings, parserOpts opts.Parser, stderr io.Writer) (*compiler.Result, bool) {
211-
eng := compiler.NewEngine(sql, combo)
212-
if err := eng.ParseCatalog(sql.Schema); err != nil {
211+
c := compiler.NewCompiler(sql, combo)
212+
if err := c.ParseCatalog(sql.Schema); err != nil {
213213
fmt.Fprintf(stderr, "# package %s\n", name)
214214
if parserErr, ok := err.(*multierr.Error); ok {
215215
for _, fileErr := range parserErr.Errs() {
@@ -221,9 +221,9 @@ func parse(e Env, name, dir string, sql config.SQL, combo config.CombinedSetting
221221
return nil, true
222222
}
223223
if parserOpts.Debug.DumpCatalog {
224-
debug.Dump(eng.Catalog())
224+
debug.Dump(c.Catalog())
225225
}
226-
if err := eng.ParseQueries(sql.Queries, parserOpts); err != nil {
226+
if err := c.ParseQueries(sql.Queries, parserOpts); err != nil {
227227
fmt.Fprintf(stderr, "# package %s\n", name)
228228
if parserErr, ok := err.(*multierr.Error); ok {
229229
for _, fileErr := range parserErr.Errs() {
@@ -234,5 +234,5 @@ func parse(e Env, name, dir string, sql config.SQL, combo config.CombinedSetting
234234
}
235235
return nil, true
236236
}
237-
return eng.Result(), false
237+
return c.Result(), false
238238
}

internal/codegen/golang/gen.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -238,7 +238,7 @@ import (
238238
{{range .GoQueries}}
239239
{{if $.OutputQuery .SourceName}}
240240
const {{.ConstantName}} = {{$.Q}}-- name: {{.MethodName}} {{.Cmd}}
241-
{{.SQL}}
241+
{{escape .SQL}}
242242
{{$.Q}}
243243
244244
{{if .Arg.EmitStruct}}
@@ -393,6 +393,7 @@ func generate(settings config.CombinedSettings, enums []Enum, structs []Struct,
393393
funcMap := template.FuncMap{
394394
"lowerTitle": codegen.LowerTitle,
395395
"comment": codegen.DoubleSlashComment,
396+
"escape": codegen.EscapeBacktick,
396397
"imports": i.Imports,
397398
}
398399

internal/codegen/utils.go

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,22 @@ func LowerTitle(s string) string {
1111
return string(a)
1212
}
1313

14+
// Go string literals cannot contain backtick. If a string contains
15+
// a backtick, replace it the following way:
16+
//
17+
// input:
18+
// SELECT `group` FROM foo
19+
//
20+
// output:
21+
// SELECT ` + "`" + `group` + "`" + ` FROM foo
22+
//
23+
// The escaped string must be rendered inside an existing string literal
24+
//
25+
// A string cannot be escaped twice
26+
func EscapeBacktick(s string) string {
27+
return strings.Replace(s, "`", "`+\"`\"+`", -1)
28+
}
29+
1430
func DoubleSlashComment(s string) string {
1531
return "// " + strings.ReplaceAll(s, "\n", "\n// ")
1632
}

internal/compiler/compile.go

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -83,11 +83,11 @@ func parseCatalog(p Parser, c *catalog.Catalog, schemas []string) error {
8383
return nil
8484
}
8585

86-
func parseQueries(p Parser, c *catalog.Catalog, queries []string, o opts.Parser) (*Result, error) {
86+
func (c *Compiler) parseQueries(o opts.Parser) (*Result, error) {
8787
var q []*Query
8888
merr := multierr.New()
8989
set := map[string]struct{}{}
90-
files, err := sqlpath.Glob(queries)
90+
files, err := sqlpath.Glob(c.conf.Queries)
9191
if err != nil {
9292
return nil, err
9393
}
@@ -98,13 +98,13 @@ func parseQueries(p Parser, c *catalog.Catalog, queries []string, o opts.Parser)
9898
continue
9999
}
100100
src := string(blob)
101-
stmts, err := p.Parse(strings.NewReader(src))
101+
stmts, err := c.parser.Parse(strings.NewReader(src))
102102
if err != nil {
103103
merr.Add(filename, src, 0, err)
104104
continue
105105
}
106106
for _, stmt := range stmts {
107-
query, err := parseQuery(p, c, stmt.Raw, src, o)
107+
query, err := c.parseQuery(stmt.Raw, src, o)
108108
if err == ErrUnsupportedStatementType {
109109
continue
110110
}
@@ -134,10 +134,10 @@ func parseQueries(p Parser, c *catalog.Catalog, queries []string, o opts.Parser)
134134
return nil, merr
135135
}
136136
if len(q) == 0 {
137-
return nil, fmt.Errorf("no queries contained in paths %s", strings.Join(queries, ","))
137+
return nil, fmt.Errorf("no queries contained in paths %s", strings.Join(c.conf.Queries, ","))
138138
}
139139
return &Result{
140-
Catalog: c,
140+
Catalog: c.catalog,
141141
Queries: q,
142142
}, nil
143143
}

internal/compiler/engine.go

Lines changed: 19 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -11,51 +11,49 @@ import (
1111
"github.com/kyleconroy/sqlc/internal/sql/catalog"
1212
)
1313

14-
// The Engine type only exists as a compatibility shim between the old dinosql
15-
// package and the new compiler package.
16-
type Engine struct {
14+
type Compiler struct {
1715
conf config.SQL
1816
combo config.CombinedSettings
1917
catalog *catalog.Catalog
2018
parser Parser
2119
result *Result
2220
}
2321

24-
func NewEngine(conf config.SQL, combo config.CombinedSettings) *Engine {
25-
e := &Engine{conf: conf, combo: combo}
22+
func NewCompiler(conf config.SQL, combo config.CombinedSettings) *Compiler {
23+
c := &Compiler{conf: conf, combo: combo}
2624
switch conf.Engine {
2725
case config.EngineXLemon:
28-
e.parser = sqlite.NewParser()
29-
e.catalog = catalog.New("main")
26+
c.parser = sqlite.NewParser()
27+
c.catalog = catalog.New("main")
3028
case config.EngineMySQL, config.EngineXDolphin:
31-
e.parser = dolphin.NewParser()
32-
e.catalog = catalog.New("public") // TODO: What is the default database for MySQL?
29+
c.parser = dolphin.NewParser()
30+
c.catalog = catalog.New("public") // TODO: What is the default database for MySQL?
3331
case config.EnginePostgreSQL:
34-
e.parser = postgresql.NewParser()
35-
e.catalog = postgresql.NewCatalog()
32+
c.parser = postgresql.NewParser()
33+
c.catalog = postgresql.NewCatalog()
3634
default:
3735
panic(fmt.Sprintf("unknown engine: %s", conf.Engine))
3836
}
39-
return e
37+
return c
4038
}
4139

42-
func (e *Engine) Catalog() *catalog.Catalog {
43-
return e.catalog
40+
func (c *Compiler) Catalog() *catalog.Catalog {
41+
return c.catalog
4442
}
4543

46-
func (e *Engine) ParseCatalog(schema []string) error {
47-
return parseCatalog(e.parser, e.catalog, schema)
44+
func (c *Compiler) ParseCatalog(schema []string) error {
45+
return parseCatalog(c.parser, c.catalog, schema)
4846
}
4947

50-
func (e *Engine) ParseQueries(queries []string, o opts.Parser) error {
51-
r, err := parseQueries(e.parser, e.catalog, e.conf.Queries, o)
48+
func (c *Compiler) ParseQueries(queries []string, o opts.Parser) error {
49+
r, err := c.parseQueries(o)
5250
if err != nil {
5351
return err
5452
}
55-
e.result = r
53+
c.result = r
5654
return nil
5755
}
5856

59-
func (e *Engine) Result() *Result {
60-
return e.result
57+
func (c *Compiler) Result() *Result {
58+
return c.result
6159
}

internal/compiler/expand.go

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,15 @@ import (
44
"fmt"
55
"strings"
66

7+
"github.com/kyleconroy/sqlc/internal/config"
78
"github.com/kyleconroy/sqlc/internal/source"
89
"github.com/kyleconroy/sqlc/internal/sql/ast"
910
"github.com/kyleconroy/sqlc/internal/sql/ast/pg"
1011
"github.com/kyleconroy/sqlc/internal/sql/astutils"
1112
"github.com/kyleconroy/sqlc/internal/sql/lang"
1213
)
1314

14-
func expand(qc *QueryCatalog, raw *ast.RawStmt) ([]source.Edit, error) {
15+
func (c *Compiler) expand(qc *QueryCatalog, raw *ast.RawStmt) ([]source.Edit, error) {
1516
list := astutils.Search(raw, func(node ast.Node) bool {
1617
switch node.(type) {
1718
case *pg.DeleteStmt:
@@ -28,7 +29,7 @@ func expand(qc *QueryCatalog, raw *ast.RawStmt) ([]source.Edit, error) {
2829
}
2930
var edits []source.Edit
3031
for _, item := range list.Items {
31-
edit, err := expandStmt(qc, raw, item)
32+
edit, err := c.expandStmt(qc, raw, item)
3233
if err != nil {
3334
return nil, err
3435
}
@@ -37,14 +38,20 @@ func expand(qc *QueryCatalog, raw *ast.RawStmt) ([]source.Edit, error) {
3738
return edits, nil
3839
}
3940

40-
func quoteIdent(ident string) string {
41+
func (c *Compiler) quoteIdent(ident string) string {
42+
// TODO: Add a method to the parser / engine for this instead
4143
if lang.IsReservedKeyword(ident) {
42-
return "\"" + ident + "\""
44+
switch c.conf.Engine {
45+
case config.EngineMySQL, config.EngineXDolphin:
46+
return "`" + ident + "`"
47+
default:
48+
return "\"" + ident + "\""
49+
}
4350
}
4451
return ident
4552
}
4653

47-
func expandStmt(qc *QueryCatalog, raw *ast.RawStmt, node ast.Node) ([]source.Edit, error) {
54+
func (c *Compiler) expandStmt(qc *QueryCatalog, raw *ast.RawStmt, node ast.Node) ([]source.Edit, error) {
4855
tables, err := sourceTables(qc, node)
4956
if err != nil {
5057
return nil, err
@@ -103,14 +110,14 @@ func expandStmt(qc *QueryCatalog, raw *ast.RawStmt, node ast.Node) ([]source.Edi
103110
if scope != "" && scope != t.Rel.Name {
104111
continue
105112
}
106-
tableName := quoteIdent(t.Rel.Name)
107-
scopeName := quoteIdent(scope)
108-
for _, c := range t.Columns {
109-
cname := c.Name
113+
tableName := c.quoteIdent(t.Rel.Name)
114+
scopeName := c.quoteIdent(scope)
115+
for _, column := range t.Columns {
116+
cname := column.Name
110117
if res.Name != nil {
111118
cname = *res.Name
112119
}
113-
cname = quoteIdent(cname)
120+
cname = c.quoteIdent(cname)
114121
if scope != "" {
115122
cname = scopeName + "." + cname
116123
}
@@ -122,7 +129,7 @@ func expandStmt(qc *QueryCatalog, raw *ast.RawStmt, node ast.Node) ([]source.Edi
122129
}
123130
var old []string
124131
for _, p := range parts {
125-
old = append(old, quoteIdent(p))
132+
old = append(old, c.quoteIdent(p))
126133
}
127134
edits = append(edits, source.Edit{
128135
Location: res.Location - raw.StmtLocation,

internal/compiler/parse.go

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@ import (
1313
"github.com/kyleconroy/sqlc/internal/sql/ast"
1414
"github.com/kyleconroy/sqlc/internal/sql/ast/pg"
1515
"github.com/kyleconroy/sqlc/internal/sql/astutils"
16-
"github.com/kyleconroy/sqlc/internal/sql/catalog"
1716
"github.com/kyleconroy/sqlc/internal/sql/rewrite"
1817
"github.com/kyleconroy/sqlc/internal/sql/validate"
1918
)
@@ -32,7 +31,7 @@ func rewriteNumberedParameters(refs []paramRef, raw *ast.RawStmt, sql string) ([
3231
return edits, nil
3332
}
3433

35-
func parseQuery(p Parser, c *catalog.Catalog, stmt ast.Node, src string, o opts.Parser) (*Query, error) {
34+
func (c *Compiler) parseQuery(stmt ast.Node, src string, o opts.Parser) (*Query, error) {
3635
if o.Debug.DumpAST {
3736
debug.Dump(stmt)
3837
}
@@ -66,10 +65,10 @@ func parseQuery(p Parser, c *catalog.Catalog, stmt ast.Node, src string, o opts.
6665
if rawSQL == "" {
6766
return nil, errors.New("missing semicolon at end of file")
6867
}
69-
if err := validate.FuncCall(c, raw); err != nil {
68+
if err := validate.FuncCall(c.catalog, raw); err != nil {
7069
return nil, err
7170
}
72-
name, cmd, err := metadata.Parse(strings.TrimSpace(rawSQL), p.CommentSyntax())
71+
name, cmd, err := metadata.Parse(strings.TrimSpace(rawSQL), c.parser.CommentSyntax())
7372
if err != nil {
7473
return nil, err
7574
}
@@ -89,12 +88,12 @@ func parseQuery(p Parser, c *catalog.Catalog, stmt ast.Node, src string, o opts.
8988
refs = uniqueParamRefs(refs)
9089
sort.Slice(refs, func(i, j int) bool { return refs[i].ref.Number < refs[j].ref.Number })
9190
}
92-
params, err := resolveCatalogRefs(c, rvs, refs, namedParams)
91+
params, err := resolveCatalogRefs(c.catalog, rvs, refs, namedParams)
9392
if err != nil {
9493
return nil, err
9594
}
9695

97-
qc, err := buildQueryCatalog(c, raw.Stmt)
96+
qc, err := buildQueryCatalog(c.catalog, raw.Stmt)
9897
if err != nil {
9998
return nil, err
10099
}
@@ -103,7 +102,7 @@ func parseQuery(p Parser, c *catalog.Catalog, stmt ast.Node, src string, o opts.
103102
return nil, err
104103
}
105104

106-
expandEdits, err := expand(qc, raw)
105+
expandEdits, err := c.expand(qc, raw)
107106
if err != nil {
108107
return nil, err
109108
}
@@ -116,7 +115,7 @@ func parseQuery(p Parser, c *catalog.Catalog, stmt ast.Node, src string, o opts.
116115

117116
// If the query string was edited, make sure the syntax is valid
118117
if expanded != rawSQL {
119-
if _, err := p.Parse(strings.NewReader(expanded)); err != nil {
118+
if _, err := c.parser.Parse(strings.NewReader(expanded)); err != nil {
120119
return nil, fmt.Errorf("edited query syntax is invalid: %w", err)
121120
}
122121
}

internal/endtoend/testdata/star_expansion_reserved/go/db.go renamed to internal/endtoend/testdata/star_expansion_reserved/mysql/go/db.go

File renamed without changes.

internal/endtoend/testdata/star_expansion_reserved/go/models.go renamed to internal/endtoend/testdata/star_expansion_reserved/mysql/go/models.go

File renamed without changes.

internal/endtoend/testdata/star_expansion_reserved/mysql/go/query.sql.go

Lines changed: 36 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)