Skip to content

Commit f7c0286

Browse files
PECO-1054 Expose Arrow batches to users (#160)
Step one of exposing arrow batches directly to users. Moved the logic for iterating over the pages in a result set into ResultPageIterator. Rows now composes in ResultPageIterator. Introduced Delimeter type. Delimeter tracks a start/end point and provides functions for determining if a point is within the delimiter range and the direction of the point if it is outside the delimeter range. Updated sparkArrowBatch, arrowRowScanner, columnRows, rows to use Delimiter. Updated the Fetch logic for cloudURL and localBatch so that the concurrentFetcher doesn't need to hold or pass through a Config instance.
2 parents 73073d2 + 69bfdef commit f7c0286

14 files changed

Lines changed: 735 additions & 446 deletions

internal/fetcher/fetcher.go

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,14 @@ package fetcher
22

33
import (
44
"context"
5-
"github.com/databricks/databricks-sql-go/internal/config"
65
"sync"
76

87
"github.com/databricks/databricks-sql-go/driverctx"
98
dbsqllog "github.com/databricks/databricks-sql-go/logger"
109
)
1110

1211
type FetchableItems[OutputType any] interface {
13-
Fetch(ctx context.Context, cfg *config.Config) ([]OutputType, error)
12+
Fetch(ctx context.Context) ([]OutputType, error)
1413
}
1514

1615
type Fetcher[OutputType any] interface {
@@ -24,7 +23,6 @@ type concurrentFetcher[I FetchableItems[O], O any] struct {
2423
outChan chan O
2524
err error
2625
nWorkers int
27-
cfg *config.Config
2826
mu sync.Mutex
2927
start sync.Once
3028
ctx context.Context
@@ -100,10 +98,17 @@ func (f *concurrentFetcher[I, O]) logger() *dbsqllog.DBSQLLogger {
10098
return f.DBSQLLogger
10199
}
102100

103-
func NewConcurrentFetcher[I FetchableItems[O], O any](ctx context.Context, nWorkers int, cfg *config.Config, inputChan <-chan FetchableItems[O]) (Fetcher[O], error) {
101+
func NewConcurrentFetcher[I FetchableItems[O], O any](ctx context.Context, nWorkers, maxItemsInMemory int, inputChan <-chan FetchableItems[O]) (Fetcher[O], error) {
102+
if nWorkers < 1 {
103+
nWorkers = 1
104+
}
105+
if maxItemsInMemory < 1 {
106+
maxItemsInMemory = 1
107+
}
108+
104109
// channel for loaded items
105110
// TODO: pass buffer size
106-
outputChannel := make(chan O, 100)
111+
outputChannel := make(chan O, maxItemsInMemory)
107112

108113
// channel to signal a cancel
109114
stopChannel := make(chan bool)
@@ -118,7 +123,6 @@ func NewConcurrentFetcher[I FetchableItems[O], O any](ctx context.Context, nWork
118123
cancelChan: stopChannel,
119124
ctx: ctx,
120125
nWorkers: nWorkers,
121-
cfg: cfg,
122126
}
123127

124128
return fetcher, nil
@@ -139,7 +143,7 @@ func work[I FetchableItems[O], O any](f *concurrentFetcher[I, O], workerIndex in
139143
case input, ok := <-f.inputChan:
140144
if ok {
141145
f.logger().Debug().Msgf("concurrent fetcher worker %d loading item", workerIndex)
142-
result, err := input.Fetch(f.ctx, f.cfg)
146+
result, err := input.Fetch(f.ctx)
143147
if err != nil {
144148
f.logger().Debug().Msgf("concurrent fetcher worker %d received error", workerIndex)
145149
f.setErr(err)

internal/fetcher/fetcher_test.go

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,11 @@ package fetcher
22

33
import (
44
"context"
5-
"github.com/databricks/databricks-sql-go/internal/config"
6-
"github.com/pkg/errors"
75
"math"
86
"testing"
97
"time"
8+
9+
"github.com/pkg/errors"
1010
)
1111

1212
// Create a mock struct for FetchableItems
@@ -20,7 +20,7 @@ type mockOutput struct {
2020
}
2121

2222
// Implement the Fetch method
23-
func (m *mockFetchableItem) Fetch(ctx context.Context, cfg *config.Config) ([]*mockOutput, error) {
23+
func (m *mockFetchableItem) Fetch(ctx context.Context) ([]*mockOutput, error) {
2424
time.Sleep(m.wait)
2525
outputs := make([]*mockOutput, 5)
2626
for i := range outputs {
@@ -35,7 +35,7 @@ var _ FetchableItems[*mockOutput] = (*mockFetchableItem)(nil)
3535
func TestConcurrentFetcher(t *testing.T) {
3636
t.Run("Comprehensively tests the concurrent fetcher", func(t *testing.T) {
3737
ctx := context.Background()
38-
cfg := &config.Config{}
38+
3939
inputChan := make(chan FetchableItems[*mockOutput], 10)
4040
for i := 0; i < 10; i++ {
4141
item := mockFetchableItem{item: i, wait: 1 * time.Second}
@@ -44,7 +44,7 @@ func TestConcurrentFetcher(t *testing.T) {
4444
close(inputChan)
4545

4646
// Create a fetcher
47-
fetcher, err := NewConcurrentFetcher[*mockFetchableItem](ctx, 3, cfg, inputChan)
47+
fetcher, err := NewConcurrentFetcher[*mockFetchableItem](ctx, 3, 3, inputChan)
4848
if err != nil {
4949
t.Fatalf("Error creating fetcher: %v", err)
5050
}
@@ -95,7 +95,7 @@ func TestConcurrentFetcher(t *testing.T) {
9595
close(inputChan)
9696

9797
// Create a new fetcher
98-
fetcher, err := NewConcurrentFetcher[*mockFetchableItem](ctx, 2, &config.Config{}, inputChan)
98+
fetcher, err := NewConcurrentFetcher[*mockFetchableItem](ctx, 2, 2, inputChan)
9999
if err != nil {
100100
t.Fatalf("Error creating fetcher: %v", err)
101101
}

internal/rows/arrowbased/arrowRows.go

Lines changed: 8 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -27,13 +27,9 @@ type valueContainerMaker interface {
2727
}
2828

2929
type sparkArrowBatch struct {
30-
rowCount, startRow, endRow int64
31-
arrowRecordBytes []byte
32-
hasSchema bool
33-
}
34-
35-
func (sab *sparkArrowBatch) contains(rowIndex int64) bool {
36-
return sab != nil && sab.startRow <= rowIndex && sab.endRow >= rowIndex
30+
rowscanner.Delimiter
31+
arrowRecordBytes []byte
32+
hasSchema bool
3733
}
3834

3935
type timeStampFn func(arrow.Timestamp) time.Time
@@ -46,6 +42,7 @@ type colInfo struct {
4642

4743
// arrowRowScanner handles extracting values from arrow records
4844
type arrowRowScanner struct {
45+
rowscanner.Delimiter
4946
recordReader
5047
valueContainerMaker
5148

@@ -61,9 +58,6 @@ type arrowRowScanner struct {
6158
// database types for the columns
6259
colInfo []colInfo
6360

64-
// number of rows in the current TRowSet
65-
nRows int64
66-
6761
// a TRowSet contains multiple arrow batches
6862
currentBatch *sparkArrowBatch
6963

@@ -140,12 +134,12 @@ func NewArrowRowScanner(resultSetMetadata *cli_service.TGetResultSetMetadataResp
140134
}
141135

142136
rs := &arrowRowScanner{
137+
Delimiter: rowscanner.NewDelimiter(rowSet.StartRowOffset, rowscanner.CountRows(rowSet)),
143138
recordReader: sparkRecordReader{
144139
ctx: ctx,
145140
},
146141
valueContainerMaker: &arrowValueContainerMaker{},
147142
ArrowConfig: arrowConfig,
148-
nRows: countRows(rowSet),
149143
arrowSchemaBytes: schemaBytes,
150144
arrowSchema: arrowSchema,
151145
toTimestampFn: ttsf,
@@ -172,7 +166,7 @@ func (ars *arrowRowScanner) Close() {
172166
// NRows returns the number of rows in the current set of batches
173167
func (ars *arrowRowScanner) NRows() int64 {
174168
if ars != nil {
175-
return ars.nRows
169+
return ars.Count()
176170
}
177171

178172
return 0
@@ -203,7 +197,7 @@ func (ars *arrowRowScanner) ScanRow(
203197
return err
204198
}
205199

206-
var rowInBatchIndex int = int(rowIndex - ars.currentBatch.startRow)
200+
var rowInBatchIndex int = int(rowIndex - ars.currentBatch.Start())
207201

208202
// if no location is provided default to UTC
209203
if ars.location == nil {
@@ -248,41 +242,14 @@ func isIntervalType(typeId cli_service.TTypeId) bool {
248242
return ok
249243
}
250244

251-
// countRows returns the number of rows in the TRowSet
252-
func countRows(rowSet *cli_service.TRowSet) int64 {
253-
if rowSet == nil {
254-
return 0
255-
}
256-
257-
if rowSet.ArrowBatches != nil {
258-
batches := rowSet.ArrowBatches
259-
var n int64
260-
for i := range batches {
261-
n += batches[i].RowCount
262-
}
263-
return n
264-
}
265-
266-
if rowSet.ResultLinks != nil {
267-
links := rowSet.ResultLinks
268-
var n int64
269-
for i := range links {
270-
n += links[i].RowCount
271-
}
272-
return n
273-
}
274-
275-
return 0
276-
}
277-
278245
// loadBatchFor loads the batch containing the specified row if necessary
279246
func (ars *arrowRowScanner) loadBatchFor(rowIndex int64) dbsqlerr.DBError {
280247

281248
if ars == nil || ars.BatchLoader == nil {
282249
return dbsqlerrint.NewDriverError(context.Background(), errArrowRowsNoArrowBatches, nil)
283250
}
284251
// if the batch already loaded we can just return
285-
if ars.currentBatch != nil && ars.currentBatch.contains(rowIndex) && ars.columnValues != nil {
252+
if ars.currentBatch != nil && ars.currentBatch.Contains(rowIndex) && ars.columnValues != nil {
286253
return nil
287254
}
288255

internal/rows/arrowbased/arrowRows_test.go

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -663,23 +663,23 @@ func TestArrowRowScanner(t *testing.T) {
663663
assert.Nil(t, err)
664664
assert.NotNil(t, lastReadBatch)
665665
assert.Equal(t, 1, callCount)
666-
assert.Equal(t, int64(0), lastReadBatch.startRow)
666+
assert.Equal(t, int64(0), lastReadBatch.Start())
667667
}
668668

669669
for _, i := range []int64{5, 6, 7} {
670670
err := ars.loadBatchFor(i)
671671
assert.Nil(t, err)
672672
assert.NotNil(t, lastReadBatch)
673673
assert.Equal(t, 2, callCount)
674-
assert.Equal(t, int64(5), lastReadBatch.startRow)
674+
assert.Equal(t, int64(5), lastReadBatch.Start())
675675
}
676676

677677
for _, i := range []int64{8, 9, 10, 11, 12, 13, 14} {
678678
err := ars.loadBatchFor(i)
679679
assert.Nil(t, err)
680680
assert.NotNil(t, lastReadBatch)
681681
assert.Equal(t, 3, callCount)
682-
assert.Equal(t, int64(8), lastReadBatch.startRow)
682+
assert.Equal(t, int64(8), lastReadBatch.Start())
683683
}
684684

685685
err := ars.loadBatchFor(-1)
@@ -983,13 +983,13 @@ func TestArrowRowScanner(t *testing.T) {
983983

984984
if i%1000 == 0 {
985985
assert.NotNil(t, ars.currentBatch)
986-
assert.Equal(t, int64(i), ars.currentBatch.startRow)
986+
assert.Equal(t, int64(i), ars.currentBatch.Start())
987987
if i < 53000 {
988-
assert.Equal(t, int64(1000), ars.currentBatch.rowCount)
988+
assert.Equal(t, int64(1000), ars.currentBatch.Count())
989989
} else {
990-
assert.Equal(t, int64(940), ars.currentBatch.rowCount)
990+
assert.Equal(t, int64(940), ars.currentBatch.Count())
991991
}
992-
assert.Equal(t, ars.currentBatch.startRow+ars.currentBatch.rowCount-1, ars.currentBatch.endRow)
992+
assert.Equal(t, ars.currentBatch.Start()+ars.currentBatch.Count()-1, ars.currentBatch.End())
993993
}
994994
}
995995

0 commit comments

Comments
 (0)