@@ -202,10 +202,88 @@ func Imports(r Generateable, settings GenerateSettings) func(string) [][]string
202202 return ModelImports (r , settings )
203203 }
204204
205+ if filename == "querier.go" {
206+ return InterfaceImports (r , settings )
207+ }
208+
205209 return QueryImports (r , settings , filename )
206210 }
207211}
208212
213+ func InterfaceImports (r Generateable , settings GenerateSettings ) [][]string {
214+ gq := r .GoQueries (settings )
215+ uses := func (name string ) bool {
216+ for _ , q := range gq {
217+ if ! q .Ret .isEmpty () {
218+ if strings .HasPrefix (q .Ret .Type (), name ) {
219+ return true
220+ }
221+ }
222+ if ! q .Arg .isEmpty () {
223+ if strings .HasPrefix (q .Arg .Type (), name ) {
224+ return true
225+ }
226+ }
227+ }
228+ return false
229+ }
230+
231+ std := map [string ]struct {}{
232+ "context" : struct {}{},
233+ }
234+ if uses ("sql.Null" ) {
235+ std ["database/sql" ] = struct {}{}
236+ }
237+ if uses ("json.RawMessage" ) {
238+ std ["encoding/json" ] = struct {}{}
239+ }
240+ if uses ("time.Time" ) {
241+ std ["time" ] = struct {}{}
242+ }
243+ if uses ("net.IP" ) {
244+ std ["net" ] = struct {}{}
245+ }
246+
247+ pkg := make (map [string ]struct {})
248+ overrideTypes := map [string ]string {}
249+ for _ , o := range append (settings .Overrides , settings .PackageMap [r .PkgName ()].Overrides ... ) {
250+ if o .goBasicType {
251+ continue
252+ }
253+ overrideTypes [o .goTypeName ] = o .goPackage
254+ }
255+
256+ _ , overrideNullTime := overrideTypes ["pq.NullTime" ]
257+ if uses ("pq.NullTime" ) && ! overrideNullTime {
258+ pkg ["github.com/lib/pq" ] = struct {}{}
259+ }
260+ _ , overrideUUID := overrideTypes ["uuid.UUID" ]
261+ if uses ("uuid.UUID" ) && ! overrideUUID {
262+ pkg ["github.com/google/uuid" ] = struct {}{}
263+ }
264+
265+ // Custom imports
266+ for goType , importPath := range overrideTypes {
267+ if _ , ok := std [importPath ]; ! ok && uses (goType ) {
268+ pkg [importPath ] = struct {}{}
269+ }
270+ }
271+
272+ pkgs := make ([]string , 0 , len (pkg ))
273+ for p , _ := range pkg {
274+ pkgs = append (pkgs , p )
275+ }
276+
277+ stds := make ([]string , 0 , len (std ))
278+ for s , _ := range std {
279+ stds = append (stds , s )
280+ }
281+
282+ sort .Strings (stds )
283+ sort .Strings (pkgs )
284+ return [][]string {stds , pkgs }
285+ }
286+
209287func ModelImports (r Generateable , settings GenerateSettings ) [][]string {
210288 std := make (map [string ]struct {})
211289 if UsesType (r , "sql.Null" , settings ) {
@@ -903,8 +981,19 @@ func (q *Queries) WithTx(tx *sql.Tx) *Queries {
903981 {{- end}}
904982 }
905983}
984+ `
985+
986+ var ifaceTmpl = `// Code generated by sqlc. DO NOT EDIT.
987+
988+ package {{.Package}}
989+
990+ import (
991+ {{range imports .SourceName}}
992+ {{range .}}"{{.}}"
993+ {{end}}
994+ {{end}}
995+ )
906996
907- {{if .EmitInterface }}
908997type Querier interface {
909998 {{- range .GoQueries}}
910999 {{- if eq .Cmd ":one"}}
@@ -923,7 +1012,6 @@ type Querier interface {
9231012}
9241013
9251014var _ Querier = (*Queries)(nil)
926- {{end}}
9271015`
9281016
9291017var modelsTmpl = `// Code generated by sqlc. DO NOT EDIT.
@@ -1112,6 +1200,7 @@ func Generate(r Generateable, settings GenerateSettings) (map[string]string, err
11121200 dbFile := template .Must (template .New ("table" ).Funcs (funcMap ).Parse (dbTmpl ))
11131201 modelsFile := template .Must (template .New ("table" ).Funcs (funcMap ).Parse (modelsTmpl ))
11141202 sqlFile := template .Must (template .New ("table" ).Funcs (funcMap ).Parse (sqlTmpl ))
1203+ ifaceFile := template .Must (template .New ("table" ).Funcs (funcMap ).Parse (ifaceTmpl ))
11151204
11161205 tctx := tmplCtx {
11171206 Settings : settings ,
@@ -1154,6 +1243,11 @@ func Generate(r Generateable, settings GenerateSettings) (map[string]string, err
11541243 if err := execute ("models.go" , modelsFile ); err != nil {
11551244 return nil , err
11561245 }
1246+ if pkgConfig .EmitInterface {
1247+ if err := execute ("querier.go" , ifaceFile ); err != nil {
1248+ return nil , err
1249+ }
1250+ }
11571251
11581252 files := map [string ]struct {}{}
11591253 for _ , gq := range r .GoQueries (settings ) {
0 commit comments