Skip to content

Commit 5584dd5

Browse files
PECO-1054 Expose Arrow batches to users, part three
Added DBSqlRows and DBSQLArrowBatchIterator public interfaces. Added arrowRecordIterator, which implements DBSQLArrowBatchIterator. Moved closing the db operation from the rows type into resultPageIterator, along with properties that are only used by resultPageIterator. Added GetArrowBatches function to the rows and arrowRowScanner types. Added HasNext function to the BatchIterator and SparkArrowBatch interfaces. Added an example for accessing Arrow batches and updated doc.go.

Signed-off-by: Raymond Cypher <raymond.cypher@databricks.com>
1 parent 86525e6 commit 5584dd5

19 files changed

Lines changed: 1780 additions & 99 deletions

File tree

examples/arrrowbatches/main.go

Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,139 @@
1+
package main
2+
3+
import (
4+
"context"
5+
"database/sql"
6+
"database/sql/driver"
7+
"io"
8+
"log"
9+
"os"
10+
"strconv"
11+
"time"
12+
13+
"github.com/apache/arrow/go/v12/arrow"
14+
dbsql "github.com/databricks/databricks-sql-go"
15+
dbsqlrows "github.com/databricks/databricks-sql-go/rows"
16+
"github.com/joho/godotenv"
17+
)
18+
19+
func main() {
20+
// Opening a driver typically will not attempt to connect to the database.
21+
err := godotenv.Load()
22+
if err != nil {
23+
log.Fatal(err.Error())
24+
}
25+
26+
// dbsqllog.SetLogLevel("debug")
27+
28+
port, err := strconv.Atoi(os.Getenv("DATABRICKS_PORT"))
29+
if err != nil {
30+
log.Fatal(err.Error())
31+
}
32+
connector, err := dbsql.NewConnector(
33+
dbsql.WithServerHostname(os.Getenv("DATABRICKS_HOST")),
34+
dbsql.WithPort(port),
35+
dbsql.WithHTTPPath(os.Getenv("DATABRICKS_HTTPPATH")),
36+
dbsql.WithAccessToken(os.Getenv("DATABRICKS_ACCESSTOKEN")),
37+
dbsql.WithMaxRows(10000),
38+
)
39+
40+
if err != nil {
41+
// This will not be a connection error, but a DSN parse error or
42+
// another initialization error.
43+
log.Fatal(err)
44+
}
45+
46+
db := sql.OpenDB(connector)
47+
defer db.Close()
48+
49+
loopWithHasNext(db)
50+
loopWithNext(db)
51+
}
52+
53+
func loopWithHasNext(db *sql.DB) {
54+
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
55+
defer cancel()
56+
57+
conn, _ := db.Conn(ctx)
58+
defer conn.Close()
59+
60+
query := `select * from hive_metastore.main.taxi_trip_data`
61+
62+
var rows driver.Rows
63+
var err error
64+
err = conn.Raw(func(d interface{}) error {
65+
rows, err = d.(driver.QueryerContext).QueryContext(ctx, query, nil)
66+
return err
67+
})
68+
69+
if err != nil {
70+
log.Fatalf("unable to run the query. err: %v", err)
71+
}
72+
defer rows.Close()
73+
74+
ctx2, cancel2 := context.WithTimeout(context.Background(), 30*time.Second)
75+
defer cancel2()
76+
77+
batches, err := rows.(dbsqlrows.DBSQLRows).GetArrowBatches(ctx2)
78+
if err != nil {
79+
log.Fatalf("unable to get arrow batches. err: %v", err)
80+
}
81+
82+
var iBatch, nRows int
83+
for batches.HasNext() {
84+
b, err := batches.Next()
85+
if err != nil {
86+
log.Fatalf("Failure retrieving batch. err: %v", err)
87+
}
88+
89+
log.Printf("batch %v: nRecords=%v\n", iBatch, b.NumRows())
90+
iBatch += 1
91+
nRows += int(b.NumRows())
92+
}
93+
log.Printf("NRows: %v\n", nRows)
94+
}
95+
96+
func loopWithNext(db *sql.DB) {
97+
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
98+
defer cancel()
99+
100+
conn, _ := db.Conn(ctx)
101+
defer conn.Close()
102+
103+
query := `select * from hive_metastore.main.taxi_trip_data`
104+
105+
var rows driver.Rows
106+
var err error
107+
108+
err = conn.Raw(func(d interface{}) error {
109+
rows, err = d.(driver.QueryerContext).QueryContext(ctx, query, nil)
110+
return err
111+
})
112+
if err != nil {
113+
log.Fatalf("unable to run the query. err: %v", err)
114+
}
115+
defer rows.Close()
116+
117+
ctx2, cancel2 := context.WithTimeout(context.Background(), 30*time.Second)
118+
defer cancel2()
119+
120+
batches, err := rows.(dbsqlrows.DBSQLRows).GetArrowBatches(ctx2)
121+
if err != nil {
122+
log.Fatalf("unable to get arrow batches. err: %v", err)
123+
}
124+
125+
var iBatch, nRows int
126+
var b arrow.Record
127+
for b, err = batches.Next(); err == nil; b, err = batches.Next() {
128+
log.Printf("batch %v: nRecords=%v\n", iBatch, b.NumRows())
129+
iBatch += 1
130+
nRows += int(b.NumRows())
131+
}
132+
133+
log.Printf("NRows: %v\n", nRows)
134+
if err == io.EOF {
135+
log.Println("normal loop termination")
136+
} else {
137+
log.Printf("loop terminated with error: %v", err)
138+
}
139+
}
Lines changed: 187 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,187 @@
1+
package arrowbased
2+
3+
import (
4+
"context"
5+
"io"
6+
7+
"github.com/apache/arrow/go/v12/arrow"
8+
"github.com/databricks/databricks-sql-go/internal/cli_service"
9+
"github.com/databricks/databricks-sql-go/internal/config"
10+
dbsqlerr "github.com/databricks/databricks-sql-go/internal/errors"
11+
"github.com/databricks/databricks-sql-go/internal/rows/rowscanner"
12+
"github.com/databricks/databricks-sql-go/rows"
13+
)
14+
15+
func NewArrowRecordIterator(ctx context.Context, rpi rowscanner.ResultPageIterator, bi BatchIterator, arrowSchemaBytes []byte, cfg config.Config) rows.DBSQLArrowBatchIterator {
16+
ari := arrowRecordIterator{
17+
cfg: cfg,
18+
batchIterator: bi,
19+
resultPageIterator: rpi,
20+
ctx: ctx,
21+
arrowSchemaBytes: arrowSchemaBytes,
22+
}
23+
24+
return &ari
25+
26+
}
27+
28+
// arrowRecordIterator implements rows.DBSQLArrowBatchIterator. It iterates
// over the arrow record batches of a result set, fetching additional result
// pages via resultPageIterator as earlier ones are drained.
type arrowRecordIterator struct {
	ctx                context.Context               // context used when fetching pages and constructing errors
	cfg                config.Config                 // configuration handed to the batch loaders
	batchIterator      BatchIterator                 // iterator over the batches of the current result page
	resultPageIterator rowscanner.ResultPageIterator // source of further result pages
	currentBatch       SparkArrowBatch               // batch currently being read
	isFinished         bool                          // set once the iterator is exhausted or closed
	arrowSchemaBytes   []byte                        // serialized arrow schema; taken from the first page if not supplied
}

// Compile-time check that arrowRecordIterator satisfies the public interface.
var _ rows.DBSQLArrowBatchIterator = (*arrowRecordIterator)(nil)
40+
41+
// Retrieve the next arrow record
42+
func (ri *arrowRecordIterator) Next() (arrow.Record, error) {
43+
if !ri.HasNext() {
44+
// returning EOF indicates that there are no more records to iterate
45+
return nil, io.EOF
46+
}
47+
48+
// make sure we have the current batch
49+
err := ri.getCurrentBatch()
50+
if err != nil {
51+
return nil, err
52+
}
53+
54+
// return next record in current batch
55+
r, err := ri.currentBatch.Next()
56+
57+
ri.checkFinished()
58+
59+
return r, err
60+
}
61+
62+
// Indicate whether there are any more records available
63+
func (ri *arrowRecordIterator) HasNext() bool {
64+
return !ri.isFinished
65+
}
66+
67+
// Free any resources associated with this iterator
68+
func (ri *arrowRecordIterator) Close() {
69+
if !ri.isFinished {
70+
ri.isFinished = true
71+
if ri.currentBatch != nil {
72+
ri.currentBatch.Close()
73+
}
74+
75+
if ri.batchIterator != nil {
76+
ri.batchIterator.Close()
77+
}
78+
79+
if ri.resultPageIterator != nil {
80+
ri.resultPageIterator.Close()
81+
}
82+
}
83+
}
84+
85+
func (ri *arrowRecordIterator) checkFinished() {
86+
finished := !ri.currentBatch.HasNext() && !ri.batchIterator.HasNext() && !ri.resultPageIterator.HasNext()
87+
88+
if finished {
89+
// Reached end of result set so Close
90+
ri.Close()
91+
}
92+
}
93+
94+
// Update the current batch if necessary
95+
func (ri *arrowRecordIterator) getCurrentBatch() error {
96+
97+
// only need to update if no current batch or current batch has no more records
98+
if ri.currentBatch == nil || !ri.currentBatch.HasNext() {
99+
100+
// ensure up to date batch iterator
101+
err := ri.getBatchIterator()
102+
if err != nil {
103+
return err
104+
}
105+
106+
// release current batch
107+
if ri.currentBatch != nil {
108+
ri.currentBatch.Close()
109+
}
110+
111+
// Get next batch from batch iterator
112+
ri.currentBatch, err = ri.batchIterator.Next()
113+
if err != nil {
114+
return err
115+
}
116+
}
117+
118+
return nil
119+
}
120+
121+
// Update batch iterator if necessary
122+
func (ri *arrowRecordIterator) getBatchIterator() error {
123+
// only need to update if there is no batch iterator or the
124+
// batch iterator has no more batches
125+
if ri.batchIterator == nil || !ri.batchIterator.HasNext() {
126+
if ri.batchIterator != nil {
127+
// release any resources held by the current batch iterator
128+
ri.batchIterator.Close()
129+
ri.batchIterator = nil
130+
}
131+
132+
// Get the next page of the result set
133+
resp, err := ri.resultPageIterator.Next()
134+
if err != nil {
135+
return err
136+
}
137+
138+
// Check the result format
139+
resultFormat := resp.ResultSetMetadata.GetResultFormat()
140+
if resultFormat != cli_service.TSparkRowSetType_ARROW_BASED_SET && resultFormat != cli_service.TSparkRowSetType_URL_BASED_SET {
141+
return dbsqlerr.NewDriverError(ri.ctx, errArrowRowsNotArrowFormat, nil)
142+
}
143+
144+
if ri.arrowSchemaBytes == nil {
145+
ri.arrowSchemaBytes = resp.ResultSetMetadata.ArrowSchema
146+
}
147+
148+
// Create a new batch iterator for the batches in the result page
149+
bi, err := ri.newBatchIterator(resp)
150+
if err != nil {
151+
return err
152+
}
153+
154+
ri.batchIterator = bi
155+
}
156+
157+
return nil
158+
}
159+
160+
// Create a new batch iterator from a page of the result set
161+
func (ri *arrowRecordIterator) newBatchIterator(fr *cli_service.TFetchResultsResp) (BatchIterator, error) {
162+
bl, err := ri.newBatchLoader(fr)
163+
if err != nil {
164+
return nil, err
165+
}
166+
167+
bi, err := NewBatchIterator(bl)
168+
169+
return bi, err
170+
}
171+
172+
// Create a new batch loader from a page of the result set
173+
func (ri *arrowRecordIterator) newBatchLoader(fr *cli_service.TFetchResultsResp) (BatchLoader, error) {
174+
rowSet := fr.Results
175+
var bl BatchLoader
176+
var err error
177+
if len(rowSet.ResultLinks) > 0 {
178+
bl, err = NewCloudBatchLoader(ri.ctx, rowSet.ResultLinks, rowSet.StartRowOffset, &ri.cfg)
179+
} else {
180+
bl, err = NewLocalBatchLoader(ri.ctx, rowSet.ArrowBatches, rowSet.StartRowOffset, ri.arrowSchemaBytes, &ri.cfg)
181+
}
182+
if err != nil {
183+
return nil, err
184+
}
185+
186+
return bl, nil
187+
}

0 commit comments

Comments
 (0)