@@ -52,15 +52,18 @@ func (c *conn) Close() error {
5252 ctx := driverctx .NewContextWithConnId (context .Background (), c .id )
5353
5454 // Close telemetry and release resources
55+ closeStart := time .Now ()
56+ _ , err := c .client .CloseSession (ctx , & cli_service.TCloseSessionReq {
57+ SessionHandle : c .session .SessionHandle ,
58+ })
59+ closeLatencyMs := time .Since (closeStart ).Milliseconds ()
60+
5561 if c .telemetry != nil {
62+ c .telemetry .RecordOperation (ctx , c .id , telemetry .OperationTypeDeleteSession , closeLatencyMs )
5663 _ = c .telemetry .Close (ctx )
5764 telemetry .ReleaseForConnection (c .cfg .Host )
5865 }
5966
60- _ , err := c .client .CloseSession (ctx , & cli_service.TCloseSessionReq {
61- SessionHandle : c .session .SessionHandle ,
62- })
63-
6467 if err != nil {
6568 log .Err (err ).Msg ("databricks: failed to close connection" )
6669 return dbsqlerrint .NewRequestError (ctx , dbsqlerr .ErrCloseConnection , err )
@@ -122,15 +125,16 @@ func (c *conn) ExecContext(ctx context.Context, query string, args []driver.Name
122125
123126 corrId := driverctx .CorrelationIdFromContext (ctx )
124127
125- exStmtResp , opStatusResp , err := c .runQuery (ctx , query , args )
128+ var pollCount int
129+ exStmtResp , opStatusResp , err := c .runQuery (ctx , query , args , & pollCount )
126130 log , ctx = client .LoggerAndContext (ctx , exStmtResp )
127131 stagingErr := c .execStagingOperation (exStmtResp , ctx )
128132
129133 // Telemetry: track statement execution
130134 var statementID string
131135 if c .telemetry != nil && exStmtResp != nil && exStmtResp .OperationHandle != nil && exStmtResp .OperationHandle .OperationId != nil {
132136 statementID = client .SprintGuid (exStmtResp .OperationHandle .OperationId .GUID )
133- ctx = c .telemetry .BeforeExecute (ctx , statementID )
137+ ctx = c .telemetry .BeforeExecute (ctx , c . id , statementID )
134138 defer func () {
135139 finalErr := err
136140 if stagingErr != nil {
@@ -139,6 +143,7 @@ func (c *conn) ExecContext(ctx context.Context, query string, args []driver.Name
139143 c .telemetry .AfterExecute (ctx , finalErr )
140144 c .telemetry .CompleteStatement (ctx , statementID , finalErr != nil )
141145 }()
146+ c .telemetry .AddTag (ctx , "poll_count" , pollCount )
142147 }
143148
144149 if exStmtResp != nil && exStmtResp .OperationHandle != nil {
@@ -180,21 +185,30 @@ func (c *conn) QueryContext(ctx context.Context, query string, args []driver.Nam
180185 log , _ := client .LoggerAndContext (ctx , nil )
181186 msg , start := log .Track ("QueryContext" )
182187
183- // first we try to get the results synchronously.
184- // at any point in time that the context is done we must cancel and return
185- exStmtResp , opStatusResp , err := c .runQuery (ctx , query , args )
188+ // Capture execution start time for telemetry before running the query
189+ executeStart := time .Now ()
190+ var pollCount int
191+ exStmtResp , opStatusResp , pollCount , err := c .runQueryWithTelemetry (ctx , query , args , & pollCount )
186192 log , ctx = client .LoggerAndContext (ctx , exStmtResp )
187193 defer log .Duration (msg , start )
188194
189- // Telemetry: track statement execution
190195 var statementID string
191196 if c .telemetry != nil && exStmtResp != nil && exStmtResp .OperationHandle != nil && exStmtResp .OperationHandle .OperationId != nil {
192197 statementID = client .SprintGuid (exStmtResp .OperationHandle .OperationId .GUID )
193- ctx = c .telemetry .BeforeExecute (ctx , statementID )
198+ // Use BeforeExecuteWithTime to set the correct start time (before execution)
199+ ctx = c .telemetry .BeforeExecuteWithTime (ctx , c .id , statementID , executeStart )
194200 defer func () {
195201 c .telemetry .AfterExecute (ctx , err )
196202 c .telemetry .CompleteStatement (ctx , statementID , err != nil )
197203 }()
204+
205+ c .telemetry .AddTag (ctx , "poll_count" , pollCount )
206+ c .telemetry .AddTag (ctx , "operation_type" , telemetry .OperationTypeExecuteStatement )
207+
208+ if exStmtResp .DirectResults != nil && exStmtResp .DirectResults .ResultSetMetadata != nil {
209+ resultFormat := exStmtResp .DirectResults .ResultSetMetadata .GetResultFormat ()
210+ c .telemetry .AddTag (ctx , "result.format" , resultFormat .String ())
211+ }
198212 }
199213
200214 if err != nil {
@@ -203,13 +217,31 @@ func (c *conn) QueryContext(ctx context.Context, query string, args []driver.Nam
203217 }
204218
205219 corrId := driverctx .CorrelationIdFromContext (ctx )
206- rows , err := rows .NewRows (c .id , corrId , exStmtResp .OperationHandle , c .client , c .cfg , exStmtResp .DirectResults )
220+
221+ var telemetryUpdate func (int , int64 )
222+ if c .telemetry != nil {
223+ telemetryUpdate = func (chunkCount int , bytesDownloaded int64 ) {
224+ c .telemetry .AddTag (ctx , "chunk_count" , chunkCount )
225+ c .telemetry .AddTag (ctx , "bytes_downloaded" , bytesDownloaded )
226+ }
227+ }
228+
229+ rows , err := rows .NewRows (c .id , corrId , exStmtResp .OperationHandle , c .client , c .cfg , exStmtResp .DirectResults , ctx , telemetryUpdate )
207230
208231 return rows , err
209232
210233}
211234
212- func (c * conn ) runQuery (ctx context.Context , query string , args []driver.NamedValue ) (* cli_service.TExecuteStatementResp , * cli_service.TGetOperationStatusResp , error ) {
235+ func (c * conn ) runQueryWithTelemetry (ctx context.Context , query string , args []driver.NamedValue , pollCount * int ) (* cli_service.TExecuteStatementResp , * cli_service.TGetOperationStatusResp , int , error ) {
236+ exStmtResp , opStatusResp , err := c .runQuery (ctx , query , args , pollCount )
237+ count := 0
238+ if pollCount != nil {
239+ count = * pollCount
240+ }
241+ return exStmtResp , opStatusResp , count , err
242+ }
243+
244+ func (c * conn ) runQuery (ctx context.Context , query string , args []driver.NamedValue , pollCount * int ) (* cli_service.TExecuteStatementResp , * cli_service.TGetOperationStatusResp , error ) {
213245 // first we try to get the results synchronously.
214246 // at any point in time that the context is done we must cancel and return
215247 exStmtResp , err := c .executeStatement (ctx , query , args )
@@ -241,7 +273,7 @@ func (c *conn) runQuery(ctx context.Context, query string, args []driver.NamedVa
241273 case cli_service .TOperationState_INITIALIZED_STATE ,
242274 cli_service .TOperationState_PENDING_STATE ,
243275 cli_service .TOperationState_RUNNING_STATE :
244- statusResp , err := c .pollOperation (ctx , opHandle )
276+ statusResp , err := c .pollOperationWithCount (ctx , opHandle , pollCount )
245277 if err != nil {
246278 return exStmtResp , statusResp , err
247279 }
@@ -269,7 +301,7 @@ func (c *conn) runQuery(ctx context.Context, query string, args []driver.NamedVa
269301 }
270302
271303 } else {
272- statusResp , err := c .pollOperation (ctx , opHandle )
304+ statusResp , err := c .pollOperationWithCount (ctx , opHandle , pollCount )
273305 if err != nil {
274306 return exStmtResp , statusResp , err
275307 }
@@ -396,7 +428,7 @@ func (c *conn) executeStatement(ctx context.Context, query string, args []driver
396428 return resp , err
397429}
398430
399- func (c * conn ) pollOperation (ctx context.Context , opHandle * cli_service.TOperationHandle ) (* cli_service.TGetOperationStatusResp , error ) {
431+ func (c * conn ) pollOperationWithCount (ctx context.Context , opHandle * cli_service.TOperationHandle , pollCount * int ) (* cli_service.TGetOperationStatusResp , error ) {
400432 corrId := driverctx .CorrelationIdFromContext (ctx )
401433 log := logger .WithContext (c .id , corrId , client .SprintGuid (opHandle .OperationId .GUID ))
402434 var statusResp * cli_service.TGetOperationStatusResp
@@ -413,6 +445,10 @@ func (c *conn) pollOperation(ctx context.Context, opHandle *cli_service.TOperati
413445 OperationHandle : opHandle ,
414446 })
415447
448+ if pollCount != nil {
449+ * pollCount ++
450+ }
451+
416452 if statusResp != nil && statusResp .OperationState != nil {
417453 log .Debug ().Msgf ("databricks: status %s" , statusResp .GetOperationState ().String ())
418454 }
@@ -455,6 +491,10 @@ func (c *conn) pollOperation(ctx context.Context, opHandle *cli_service.TOperati
455491 return statusResp , nil
456492}
457493
494+ func (c * conn ) pollOperation (ctx context.Context , opHandle * cli_service.TOperationHandle ) (* cli_service.TGetOperationStatusResp , error ) {
495+ return c .pollOperationWithCount (ctx , opHandle , nil )
496+ }
497+
458498func (c * conn ) CheckNamedValue (nv * driver.NamedValue ) error {
459499 var err error
460500 if parameter , ok := nv .Value .(Parameter ); ok {
@@ -623,7 +663,7 @@ func (c *conn) execStagingOperation(
623663 }
624664
625665 if len (driverctx .StagingPathsFromContext (ctx )) != 0 {
626- row , err = rows .NewRows (c .id , corrId , exStmtResp .OperationHandle , c .client , c .cfg , exStmtResp .DirectResults )
666+ row , err = rows .NewRows (c .id , corrId , exStmtResp .OperationHandle , c .client , c .cfg , exStmtResp .DirectResults , nil , nil )
627667 if err != nil {
628668 return dbsqlerrint .NewDriverError (ctx , "error reading row." , err )
629669 }
0 commit comments