@@ -53,15 +53,18 @@ func (c *conn) Close() error {
5353 ctx := driverctx .NewContextWithConnId (context .Background (), c .id )
5454
5555 // Close telemetry and release resources
56+ closeStart := time .Now ()
57+ _ , err := c .client .CloseSession (ctx , & cli_service.TCloseSessionReq {
58+ SessionHandle : c .session .SessionHandle ,
59+ })
60+ closeLatencyMs := time .Since (closeStart ).Milliseconds ()
61+
5662 if c .telemetry != nil {
63+ c .telemetry .RecordOperation (ctx , c .id , telemetry .OperationTypeDeleteSession , closeLatencyMs )
5764 _ = c .telemetry .Close (ctx )
5865 telemetry .ReleaseForConnection (c .cfg .Host )
5966 }
6067
61- _ , err := c .client .CloseSession (ctx , & cli_service.TCloseSessionReq {
62- SessionHandle : c .session .SessionHandle ,
63- })
64-
6568 if err != nil {
6669 log .Err (err ).Msg ("databricks: failed to close connection" )
6770 return dbsqlerrint .NewBadConnectionError (err )
@@ -123,15 +126,16 @@ func (c *conn) ExecContext(ctx context.Context, query string, args []driver.Name
123126
124127 corrId := driverctx .CorrelationIdFromContext (ctx )
125128
126- exStmtResp , opStatusResp , err := c .runQuery (ctx , query , args )
129+ var pollCount int
130+ exStmtResp , opStatusResp , err := c .runQuery (ctx , query , args , & pollCount )
127131 log , ctx = client .LoggerAndContext (ctx , exStmtResp )
128132 stagingErr := c .execStagingOperation (exStmtResp , ctx )
129133
130134 // Telemetry: track statement execution
131135 var statementID string
132136 if c .telemetry != nil && exStmtResp != nil && exStmtResp .OperationHandle != nil && exStmtResp .OperationHandle .OperationId != nil {
133137 statementID = client .SprintGuid (exStmtResp .OperationHandle .OperationId .GUID )
134- ctx = c .telemetry .BeforeExecute (ctx , statementID )
138+ ctx = c .telemetry .BeforeExecute (ctx , c . id , statementID )
135139 defer func () {
136140 finalErr := err
137141 if stagingErr != nil {
@@ -140,6 +144,7 @@ func (c *conn) ExecContext(ctx context.Context, query string, args []driver.Name
140144 c .telemetry .AfterExecute (ctx , finalErr )
141145 c .telemetry .CompleteStatement (ctx , statementID , finalErr != nil )
142146 }()
147+ c .telemetry .AddTag (ctx , "poll_count" , pollCount )
143148 }
144149
145150 if exStmtResp != nil && exStmtResp .OperationHandle != nil {
@@ -181,34 +186,60 @@ func (c *conn) QueryContext(ctx context.Context, query string, args []driver.Nam
181186 log , _ := client .LoggerAndContext (ctx , nil )
182187 msg , start := log .Track ("QueryContext" )
183188
184- // first we try to get the results synchronously.
185- // at any point in time that the context is done we must cancel and return
186- exStmtResp , opStatusResp , err := c .runQuery (ctx , query , args )
189+ // Capture execution start time for telemetry before running the query
190+ executeStart := time .Now ()
191+ var pollCount int
192+ exStmtResp , opStatusResp , pollCount , err := c .runQueryWithTelemetry (ctx , query , args , & pollCount )
187193 log , ctx = client .LoggerAndContext (ctx , exStmtResp )
188194 defer log .Duration (msg , start )
189195
190- // Telemetry: track statement execution
191196 var statementID string
192197 if c .telemetry != nil && exStmtResp != nil && exStmtResp .OperationHandle != nil && exStmtResp .OperationHandle .OperationId != nil {
193198 statementID = client .SprintGuid (exStmtResp .OperationHandle .OperationId .GUID )
194- ctx = c .telemetry .BeforeExecute (ctx , statementID )
199+ // Use BeforeExecuteWithTime to set the correct start time (before execution)
200+ ctx = c .telemetry .BeforeExecuteWithTime (ctx , c .id , statementID , executeStart )
195201 defer func () {
196202 c .telemetry .AfterExecute (ctx , err )
197203 c .telemetry .CompleteStatement (ctx , statementID , err != nil )
198204 }()
205+
206+ c .telemetry .AddTag (ctx , "poll_count" , pollCount )
207+ c .telemetry .AddTag (ctx , "operation_type" , telemetry .OperationTypeExecuteStatement )
208+
209+ if exStmtResp .DirectResults != nil && exStmtResp .DirectResults .ResultSetMetadata != nil {
210+ resultFormat := exStmtResp .DirectResults .ResultSetMetadata .GetResultFormat ()
211+ c .telemetry .AddTag (ctx , "result.format" , resultFormat .String ())
212+ }
199213 }
200214
201215 if err != nil {
202216 log .Err (err ).Msg ("databricks: failed to run query" ) // To log query we need to redact credentials
203217 return nil , dbsqlerrint .NewExecutionError (ctx , dbsqlerr .ErrQueryExecution , err , opStatusResp )
204218 }
205219
206- rows , err := rows .NewRows (ctx , exStmtResp .OperationHandle , c .client , c .cfg , exStmtResp .DirectResults )
220+ var telemetryUpdate func (int , int64 )
221+ if c .telemetry != nil {
222+ telemetryUpdate = func (chunkCount int , bytesDownloaded int64 ) {
223+ c .telemetry .AddTag (ctx , "chunk_count" , chunkCount )
224+ c .telemetry .AddTag (ctx , "bytes_downloaded" , bytesDownloaded )
225+ }
226+ }
227+
228+ rows , err := rows .NewRows (ctx , exStmtResp .OperationHandle , c .client , c .cfg , exStmtResp .DirectResults , ctx , telemetryUpdate )
207229 return rows , err
208230
209231}
210232
211- func (c * conn ) runQuery (ctx context.Context , query string , args []driver.NamedValue ) (* cli_service.TExecuteStatementResp , * cli_service.TGetOperationStatusResp , error ) {
233+ func (c * conn ) runQueryWithTelemetry (ctx context.Context , query string , args []driver.NamedValue , pollCount * int ) (* cli_service.TExecuteStatementResp , * cli_service.TGetOperationStatusResp , int , error ) {
234+ exStmtResp , opStatusResp , err := c .runQuery (ctx , query , args , pollCount )
235+ count := 0
236+ if pollCount != nil {
237+ count = * pollCount
238+ }
239+ return exStmtResp , opStatusResp , count , err
240+ }
241+
242+ func (c * conn ) runQuery (ctx context.Context , query string , args []driver.NamedValue , pollCount * int ) (* cli_service.TExecuteStatementResp , * cli_service.TGetOperationStatusResp , error ) {
212243 // first we try to get the results synchronously.
213244 // at any point in time that the context is done we must cancel and return
214245 exStmtResp , err := c .executeStatement (ctx , query , args )
@@ -240,7 +271,7 @@ func (c *conn) runQuery(ctx context.Context, query string, args []driver.NamedVa
240271 case cli_service .TOperationState_INITIALIZED_STATE ,
241272 cli_service .TOperationState_PENDING_STATE ,
242273 cli_service .TOperationState_RUNNING_STATE :
243- statusResp , err := c .pollOperation (ctx , opHandle )
274+ statusResp , err := c .pollOperationWithCount (ctx , opHandle , pollCount )
244275 if err != nil {
245276 return exStmtResp , statusResp , err
246277 }
@@ -268,7 +299,7 @@ func (c *conn) runQuery(ctx context.Context, query string, args []driver.NamedVa
268299 }
269300
270301 } else {
271- statusResp , err := c .pollOperation (ctx , opHandle )
302+ statusResp , err := c .pollOperationWithCount (ctx , opHandle , pollCount )
272303 if err != nil {
273304 return exStmtResp , statusResp , err
274305 }
@@ -396,7 +427,7 @@ func (c *conn) executeStatement(ctx context.Context, query string, args []driver
396427 return resp , err
397428}
398429
399- func (c * conn ) pollOperation (ctx context.Context , opHandle * cli_service.TOperationHandle ) (* cli_service.TGetOperationStatusResp , error ) {
430+ func (c * conn ) pollOperationWithCount (ctx context.Context , opHandle * cli_service.TOperationHandle , pollCount * int ) (* cli_service.TGetOperationStatusResp , error ) {
400431 corrId := driverctx .CorrelationIdFromContext (ctx )
401432 log := logger .WithContext (c .id , corrId , client .SprintGuid (opHandle .OperationId .GUID ))
402433 var statusResp * cli_service.TGetOperationStatusResp
@@ -413,6 +444,10 @@ func (c *conn) pollOperation(ctx context.Context, opHandle *cli_service.TOperati
413444 OperationHandle : opHandle ,
414445 })
415446
447+ if pollCount != nil {
448+ * pollCount ++
449+ }
450+
416451 if statusResp != nil && statusResp .OperationState != nil {
417452 log .Debug ().Msgf ("databricks: status %s" , statusResp .GetOperationState ().String ())
418453 }
@@ -455,6 +490,10 @@ func (c *conn) pollOperation(ctx context.Context, opHandle *cli_service.TOperati
455490 return statusResp , nil
456491}
457492
493+ func (c * conn ) pollOperation (ctx context.Context , opHandle * cli_service.TOperationHandle ) (* cli_service.TGetOperationStatusResp , error ) {
494+ return c .pollOperationWithCount (ctx , opHandle , nil )
495+ }
496+
458497func (c * conn ) CheckNamedValue (nv * driver.NamedValue ) error {
459498 var err error
460499 if parameter , ok := nv .Value .(Parameter ); ok {
@@ -622,7 +661,7 @@ func (c *conn) execStagingOperation(
622661 }
623662
624663 if len (driverctx .StagingPathsFromContext (ctx )) != 0 {
625- row , err = rows .NewRows (ctx , exStmtResp .OperationHandle , c .client , c .cfg , exStmtResp .DirectResults )
664+ row , err = rows .NewRows (ctx , exStmtResp .OperationHandle , c .client , c .cfg , exStmtResp .DirectResults , nil , nil )
626665 if err != nil {
627666 return dbsqlerrint .NewDriverError (ctx , "error reading row." , err )
628667 }
0 commit comments