diff --git a/README.md b/README.md index 3de3a89..eb9662d 100644 --- a/README.md +++ b/README.md @@ -29,6 +29,9 @@ explains how to use `database/sql` along with sqlx. ## Changes compared to the original sqlx +* Better scanning in the case of outer joins. If a struct contains a nested + struct pointer, it will no longer be a scan error. + * Made complex joins easier to scan by using the position of the field to help map duplicate column names into structs. See the [joins example](./examples/joins/main.go). diff --git a/convert.go b/convert.go new file mode 100644 index 0000000..3964a91 --- /dev/null +++ b/convert.go @@ -0,0 +1,8 @@ +package sqlx + +import ( + _ "unsafe" +) + +//go:linkname convertAssign database/sql.convertAssign +func convertAssign(dest, src interface{}) error diff --git a/examples/generics/main.go b/examples/generics/main.go index ed82b82..1837488 100644 --- a/examples/generics/main.go +++ b/examples/generics/main.go @@ -15,7 +15,8 @@ import ( // docker run --name sqlxpg -p 5444:5432 -e POSTGRES_PASSWORD=password -d docker.io/postgres:17.4 const schema = ` - CREATE TABLE IF NOT EXISTS person ( + DROP TABLE IF EXISTS person; + CREATE TABLE person ( id SERIAL PRIMARY KEY, first_name text, last_name text, @@ -23,7 +24,8 @@ const schema = ` ); TRUNCATE TABLE person; - CREATE TABLE IF NOT EXISTS place ( + DROP TABLE IF EXISTS place; + CREATE TABLE place ( country text, city text NULL, telcode integer diff --git a/reflectx/reflect.go b/reflectx/reflect.go index 1663bf0..8c3ce04 100644 --- a/reflectx/reflect.go +++ b/reflectx/reflect.go @@ -235,8 +235,7 @@ func FieldByIndexes(v reflect.Value, indexes []int) reflect.Value { v = reflect.Indirect(v).Field(i) // if this is a pointer and it's nil, allocate a new value and set it if v.Kind() == reflect.Ptr && v.IsNil() { - alloc := reflect.New(Deref(v.Type())) - v.Set(alloc) + v.Set(reflect.New(v.Type().Elem())) } if v.Kind() == reflect.Map && v.IsNil() { v.Set(reflect.MakeMap(v.Type())) diff --git a/sqlx.go b/sqlx.go index a0870cf..341bccc 100644 --- a/sqlx.go +++ b/sqlx.go @@ -773,7 +773,7 @@ func (r *Rows) StructScan(dest interface{}) error { r.started = true } - err := fieldsByTraversal(v, r.fields, r.values, true) + err := fieldsByTraversal(v, r.fields, r.values) if err != nil { return err } @@ -990,7 +990,7 @@ func (r *Row) scanAny(dest interface{}, structOnly bool) error { } values := make([]interface{}, len(columns)) - err = fieldsByTraversal(v, fields, values, true) + err = fieldsByTraversal(v, fields, values) if err != nil { return err } @@ -1165,7 +1165,7 @@ func scanAll(rows rowsi, dest interface{}, structOnly bool) error { vp = reflect.New(base) v = reflect.Indirect(vp) - err = fieldsByTraversal(v, fields, values, true) + err = fieldsByTraversal(v, fields, values) if err != nil { return err } @@ -1231,7 +1231,7 @@ func baseType(t reflect.Type, expected reflect.Kind) (reflect.Type, error) { // when iterating over many rows. Empty traversals will get an interface pointer. // Because of the necessity of requesting ptrs or values, it's considered a bit too // specialized for inclusion in reflectx itself. -func fieldsByTraversal(v reflect.Value, traversals [][]int, values []interface{}, ptrs bool) error { +func fieldsByTraversal(v reflect.Value, traversals [][]int, values []interface{}) error { v = reflect.Indirect(v) if v.Kind() != reflect.Struct { return errors.New("argument not a struct") @@ -1240,23 +1240,37 @@ func fieldsByTraversal(v reflect.Value, traversals [][]int, values []interface{} for i, traversal := range traversals { if len(traversal) == 0 { values[i] = new(interface{}) - continue - } - f := reflectx.FieldByIndexes(v, traversal) - if ptrs { - values[i] = f.Addr().Interface() + } else if len(traversal) == 1 { + values[i] = reflectx.FieldByIndexes(v, traversal).Addr().Interface() } else { - values[i] = f.Interface() + // reflectx.FieldByIndexes initializes pointer fields, including pointers to nested structs. + // Use optDest to delay it until the first non-NULL value is scanned into a field of a nested struct. + // That way we can support LEFT JOINs with optional nested structs. + values[i] = optDest(func() interface{} { + return reflectx.FieldByIndexes(v, traversal).Addr().Interface() + }) } } return nil } -func missingFields(transversals [][]int) (field int, err error) { - for i, t := range transversals { +func missingFields(traversals [][]int) (field int, err error) { + for i, t := range traversals { if len(t) == 0 { return i, errors.New("missing field") } } return 0, nil } + +// optDest will only forward the Scan to the nested value if +// the database value is not nil. +type optDest func() interface{} + +// Scan implements sql.Scanner. +func (dest optDest) Scan(src interface{}) error { + if src == nil { + return nil + } + return convertAssign(dest(), src) +} diff --git a/sqlx_context_test.go b/sqlx_context_test.go index 4fe93b1..c0762cc 100644 --- a/sqlx_context_test.go +++ b/sqlx_context_test.go @@ -473,12 +473,17 @@ func TestNamedQueryContext(t *testing.T) { "FIRST" text NULL, last_name text NULL, "EMAIL" text NULL + ); + CREATE TABLE persondetails ( + email text NULL, + notes text NULL );`, drop: ` drop table person; drop table jsperson; drop table place; drop table placeperson; + drop table persondetails; `, } @@ -495,28 +500,28 @@ func TestNamedQueryContext(t *testing.T) { Email: sql.NullString{String: "ben@doe.com", Valid: true}, } - q1 := `INSERT INTO person (first_name, last_name, email) VALUES (:first_name, :last_name, :email)` - _, err := db.NamedExecContext(ctx, q1, p) - if err != nil { - log.Fatal(err) - } + _, err := db.NamedExecContext(ctx, `INSERT INTO person (first_name, last_name, email) VALUES (:first_name, :last_name, :email)`, p) + require.NoError(t, err) - p2 := &Person{} - rows, err := db.NamedQueryContext(ctx, "SELECT * FROM person WHERE first_name=:first_name", p) - if err != nil { - log.Fatal(err) - } - for rows.Next() { - err = rows.StructScan(p2) + { + p2 := &Person{} + rows, err := db.NamedQueryContext(ctx, "SELECT * FROM person WHERE first_name=:first_name", p) if err != nil { - t.Error(err) + log.Fatal(err) } - if p2.FirstName.String != "ben" { - t.Error("Expected first name of `ben`, got " + p2.FirstName.String) - } - if p2.LastName.String != "doe" { - t.Error("Expected first name of `doe`, got " + p2.LastName.String) + for rows.Next() { + err = rows.StructScan(p2) + if err != nil { + t.Error(err) + } + if p2.FirstName.String != "ben" { + t.Error("Expected first name of `ben`, got " + p2.FirstName.String) + } + if p2.LastName.String != "doe" { + t.Error("Expected first name of `doe`, got " + p2.LastName.String) + } } + rows.Close() } // these are tests for #73; they verify that named queries work if you've @@ -548,8 +553,7 @@ func TestNamedQueryContext(t *testing.T) { return s } - q1 = `INSERT INTO jsperson ("FIRST", last_name, "EMAIL") VALUES (:FIRST, :last_name, :EMAIL)` - _, err = db.NamedExecContext(ctx, pdb(q1, db), jp) + _, err = db.NamedExecContext(ctx, pdb(`INSERT INTO jsperson ("FIRST", last_name, "EMAIL") VALUES (:FIRST, :last_name, :EMAIL)`, db), jp) if err != nil { t.Fatal(err, db.DriverName()) } @@ -581,16 +585,13 @@ func TestNamedQueryContext(t *testing.T) { last_name=:last_name AND "EMAIL"=:EMAIL `, db)) + require.NoError(t, err) - if err != nil { - t.Fatal(err) - } - rows, err = ns.QueryxContext(ctx, jp) - if err != nil { - t.Fatal(err) - } + rows, err := ns.QueryxContext(ctx, jp) + require.NoError(t, err) check(t, rows) + rows.Close() // Check exactly the same thing, but with db.NamedQuery, which does not go // through the PrepareNamed/NamedStmt path. @@ -601,11 +602,10 @@ func TestNamedQueryContext(t *testing.T) { last_name=:last_name AND "EMAIL"=:EMAIL `, db), jp) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) check(t, rows) + rows.Close() db.Mapper = old @@ -625,29 +625,23 @@ func TestNamedQueryContext(t *testing.T) { Name: sql.NullString{String: "myplace", Valid: true}, } - pp := PlacePerson{ + benDoe := PlacePerson{ FirstName: sql.NullString{String: "ben", Valid: true}, LastName: sql.NullString{String: "doe", Valid: true}, Email: sql.NullString{String: "ben@doe.com", Valid: true}, } - q2 := `INSERT INTO place (id, name) VALUES (1, :name)` - _, err = db.NamedExecContext(ctx, q2, pl) - if err != nil { - log.Fatal(err) - } + _, err = db.NamedExecContext(ctx, `INSERT INTO place (id, name) VALUES (1, :name)`, pl) + require.NoError(t, err) id := 1 - pp.Place.ID = id + benDoe.Place.ID = id - q3 := `INSERT INTO placeperson (first_name, last_name, email, place_id) VALUES (:first_name, :last_name, :email, :place.id)` - _, err = db.NamedExecContext(ctx, q3, pp) - if err != nil { - log.Fatal(err) - } + _, err = db.NamedExecContext(ctx, `INSERT INTO placeperson (first_name, last_name, email, place_id) VALUES (:first_name, :last_name, :email, :place.id)`, benDoe) + require.NoError(t, err) - pp2 := &PlacePerson{} - rows, err = db.NamedQueryContext(ctx, ` + { + rows, err = db.NamedQueryContext(ctx, ` SELECT first_name, last_name, @@ -657,26 +651,170 @@ func TestNamedQueryContext(t *testing.T) { FROM placeperson INNER JOIN place ON place.id = placeperson.place_id WHERE - place.id=:place.id`, pp) - if err != nil { - log.Fatal(err) + place.id = :place.id`, benDoe) + require.NoError(t, err) + + for pp2, err := range AllRows[PlacePerson](rows) { + require.NoError(t, err) + assert.Equal(t, benDoe.FirstName.String, pp2.FirstName.String) + assert.Equal(t, benDoe.LastName.String, pp2.LastName.String) + assert.Equal(t, benDoe.Email.String, pp2.Email.String) + assert.Equal(t, benDoe.Place.ID, pp2.Place.ID) + } } - for rows.Next() { - err = rows.StructScan(pp2) - if err != nil { - t.Error(err) + + type Owner struct { + Email *string `db:"email"` + FirstName string `db:"first_name"` + LastName string `db:"last_name"` + } + + // Test optional nested structs with left join + type PlaceOwner struct { + Place Place `db:"place"` + Owner *Owner `db:"owner"` + } + + pl = Place{ + Name: sql.NullString{String: "the-house", Valid: true}, + } + + _, err = db.NamedExecContext(ctx, `INSERT INTO place (id, name) VALUES (2, :name)`, pl) + require.NoError(t, err) + + id = 2 + benDoe.Place.ID = id + + _, err = db.NamedExecContext(ctx, `INSERT INTO placeperson (first_name, last_name, email, place_id) VALUES (:first_name, :last_name, :email, :place.id)`, benDoe) + require.NoError(t, err) + + { + rows, err = db.NamedQueryContext(ctx, ` + SELECT + place.id, + place.name, + placeperson.first_name, + placeperson.last_name, + placeperson.email + FROM place + LEFT JOIN placeperson ON false -- null left join + WHERE + place.id = :place.id`, benDoe) + require.NoError(t, err) + + for pp3, err := range AllRows[PlaceOwner](rows) { + require.NoError(t, err) + assert.Nil(t, pp3.Owner, "Expected `Owner` to be nil") + assert.Equal(t, "the-house", pp3.Place.Name.String) + assert.Equal(t, benDoe.Place.ID, pp3.Place.ID) } - if pp2.FirstName.String != "ben" { - t.Error("Expected first name of `ben`, got " + pp2.FirstName.String) + } + + { + rows, err = db.NamedQueryContext(ctx, ` + SELECT + place.id, + place.name, + placeperson.first_name, + placeperson.last_name, + placeperson.email + FROM place + LEFT JOIN placeperson ON placeperson.place_id = place.id + WHERE + place.id = :place.id`, benDoe) + require.NoError(t, err) + + for pp4, err := range AllRows[PlaceOwner](rows) { + require.NoError(t, err) + assert.NotNil(t, pp4.Owner, "Expected `Owner` to not be nil") + assert.Equal(t, "ben", pp4.Owner.FirstName) + assert.Equal(t, "doe", pp4.Owner.LastName) + assert.Equal(t, "the-house", pp4.Place.Name.String) + assert.Equal(t, benDoe.Place.ID, pp4.Place.ID) } - if pp2.LastName.String != "doe" { - t.Error("Expected first name of `doe`, got " + pp2.LastName.String) + } + + type Details struct { + Email string `db:"email"` + Notes string `db:"notes"` + } + + type OwnerDetails struct { + Email *string `db:"email"` + FirstName string `db:"first_name"` + LastName string `db:"last_name"` + Details *Details `db:"details"` + } + + type PlaceOwnerDetails struct { + Place Place `db:"place"` + Owner *OwnerDetails `db:"owner"` + } + + { + rows, err = db.NamedQueryContext(ctx, ` + SELECT + place.id, + place.name, + placeperson.first_name, + placeperson.last_name, + placeperson.email, + persondetails.email, + persondetails.notes + FROM place + LEFT JOIN placeperson ON placeperson.place_id = place.id + LEFT JOIN persondetails ON false + WHERE + place.id = :place.id`, benDoe) + require.NoError(t, err) + + for pp5, err := range AllRows[PlaceOwnerDetails](rows) { + require.NoError(t, err) + assert.NotNil(t, pp5.Owner, "Expected `Owner`, to not be nil") + assert.Equal(t, "ben", pp5.Owner.FirstName) + assert.Equal(t, "doe", pp5.Owner.LastName) + assert.Equal(t, benDoe.Email.String, *pp5.Owner.Email) + assert.Equal(t, "the-house", pp5.Place.Name.String) + assert.Equal(t, pp5.Place.ID, benDoe.Place.ID) + assert.Nil(t, pp5.Owner.Details) } - if pp2.Place.Name.String != "myplace" { - t.Error("Expected place name of `myplace`, got " + pp2.Place.Name.String) + } + + { + details := Details{ + Email: benDoe.Email.String, + Notes: "this is a test person", } - if pp2.Place.ID != pp.Place.ID { - t.Errorf("Expected place name of %v, got %v", pp.Place.ID, pp2.Place.ID) + + _, err = db.NamedExecContext(ctx, `INSERT INTO persondetails (email, notes) VALUES (:email, :notes)`, details) + require.NoError(t, err) + + rows, err = db.NamedQueryContext(ctx, ` + SELECT + place.id, + place.name, + placeperson.first_name, + placeperson.last_name, + placeperson.email, + persondetails.email, + persondetails.notes + FROM place + LEFT JOIN placeperson ON placeperson.place_id = place.id + LEFT JOIN persondetails ON persondetails.email = placeperson.email + WHERE + place.id = :place.id`, benDoe) + require.NoError(t, err) + + for pp6, err := range AllRows[PlaceOwnerDetails](rows) { + require.NoError(t, err) + assert.NotNil(t, pp6.Owner, "Expected `Owner`, to not be nil") + assert.Equal(t, "ben", pp6.Owner.FirstName) + assert.Equal(t, "doe", pp6.Owner.LastName) + assert.Equal(t, "the-house", pp6.Place.Name.String) + assert.Equal(t, pp6.Place.ID, pp6.Place.ID) + assert.NotNil(t, pp6.Owner.Details, "Expected `Details` to not be nil") + assert.Equal(t, details.Email, pp6.Owner.Details.Email) + assert.Equal(t, details.Notes, pp6.Owner.Details.Notes) } } })