@@ -22,7 +22,6 @@ use std::collections::HashMap;
2222use std:: env;
2323use std:: process:: Stdio ;
2424use std:: sync:: { Arc , Weak } ;
25-
2625use futures:: future:: BoxFuture ;
2726use futures:: stream:: FuturesUnordered ;
2827use futures:: { Future , FutureExt , StreamExt , TryFutureExt , select} ;
@@ -32,8 +31,8 @@ use nativelink_metric::{MetricsComponent, RootMetricsComponent};
3231use nativelink_proto:: com:: github:: trace_machina:: nativelink:: remote_execution:: update_for_worker:: Update ;
3332use nativelink_proto:: com:: github:: trace_machina:: nativelink:: remote_execution:: worker_api_client:: WorkerApiClient ;
3433use nativelink_proto:: com:: github:: trace_machina:: nativelink:: remote_execution:: {
35- ExecuteComplete , ExecuteResult , GoingAwayRequest , KeepAliveRequest , UpdateForWorker ,
36- execute_result ,
34+ execute_result , ExecuteComplete , ExecuteResult , GoingAwayRequest , KeepAliveRequest ,
35+ UpdateForWorker ,
3736} ;
3837use nativelink_store:: fast_slow_store:: FastSlowStore ;
3938use nativelink_util:: action_messages:: { ActionResult , ActionStage , OperationId } ;
@@ -46,6 +45,7 @@ use nativelink_util::{spawn, tls_utils};
4645use opentelemetry:: context:: Context ;
4746use tokio:: process;
4847use tokio:: sync:: { broadcast, mpsc} ;
48+ use tokio:: sync:: broadcast:: { Receiver , Sender } ;
4949use tokio:: time:: sleep;
5050use tokio_stream:: wrappers:: UnboundedReceiverStream ;
5151use tonic:: Streaming ;
@@ -87,6 +87,7 @@ struct LocalWorkerImpl<'a, T: WorkerApiClientTrait + 'static, U: RunningActionsM
8787 // on by the scheduler.
8888 actions_in_transit : Arc < AtomicU64 > ,
8989 metrics : Arc < Metrics > ,
90+ shutdown_tx : Sender < ShutdownGuard > ,
9091}
9192
9293pub async fn preconditions_met < H : BuildHasher + Sync > (
@@ -147,6 +148,7 @@ impl<'a, T: WorkerApiClientTrait + 'static, U: RunningActionsManager> LocalWorke
147148 worker_id : String ,
148149 running_actions_manager : Arc < U > ,
149150 metrics : Arc < Metrics > ,
151+ shutdown_tx : Sender < ShutdownGuard > ,
150152 ) -> Self {
151153 Self {
152154 config,
@@ -159,6 +161,7 @@ impl<'a, T: WorkerApiClientTrait + 'static, U: RunningActionsManager> LocalWorke
159161 // on by the scheduler.
160162 actions_in_transit : Arc :: new ( AtomicU64 :: new ( 0 ) ) ,
161163 metrics,
164+ shutdown_tx,
162165 }
163166 }
164167
@@ -208,6 +211,8 @@ impl<'a, T: WorkerApiClientTrait + 'static, U: RunningActionsManager> LocalWorke
208211
209212 let ( add_future_channel, add_future_rx) = mpsc:: unbounded_channel ( ) ;
210213 let mut add_future_rx = UnboundedReceiverStream :: new ( add_future_rx) . fuse ( ) ;
214+ let ( inner_shutdown_channel, inner_shutdown_rx) = mpsc:: unbounded_channel ( ) ;
215+ let mut inner_shutdown_rx = UnboundedReceiverStream :: new ( inner_shutdown_rx) . fuse ( ) ;
211216
212217 let mut update_for_worker_stream = update_for_worker_stream. fuse ( ) ;
213218 // A notify which is triggered every time actions_in_flight is subtracted.
@@ -217,6 +222,9 @@ impl<'a, T: WorkerApiClientTrait + 'static, U: RunningActionsManager> LocalWorke
217222 let actions_in_flight = Arc :: new ( AtomicU64 :: new ( 0 ) ) ;
218223 // Set to true when shutting down, this stops any new StartAction.
219224 let mut shutting_down = false ;
225+ // Channel to signal when shutdown is complete (GoingAway sent, ready to exit).
226+ let ( shutdown_complete_tx, shutdown_complete_rx) = mpsc:: unbounded_channel :: < ( ) > ( ) ;
227+ let mut shutdown_complete_rx = UnboundedReceiverStream :: new ( shutdown_complete_rx) . fuse ( ) ;
220228
221229 loop {
222230 select ! {
@@ -406,6 +414,7 @@ impl<'a, T: WorkerApiClientTrait + 'static, U: RunningActionsManager> LocalWorke
406414 self . actions_in_transit. fetch_add( 1 , Ordering :: Release ) ;
407415
408416 let add_future_channel = add_future_channel. clone( ) ;
417+ let inner_shutdown_channel = inner_shutdown_channel. clone( ) ;
409418
410419 info_span!(
411420 "worker_start_action_ctx" ,
@@ -428,7 +437,16 @@ impl<'a, T: WorkerApiClientTrait + 'static, U: RunningActionsManager> LocalWorke
428437 error!( ?err, "Error executing action" ) ;
429438 }
430439 add_future_channel
431- . send( make_publish_future( res) . then( move |res| {
440+ . send( make_publish_future( res)
441+ . then( move |res| {
442+ match self . config. execution_completion_behaviour {
443+ ExecutionCompletionBehaviour :: OneShotAlways => {
444+ inner_shutdown_channel. send( ( ) ) . ok( ) ;
445+ }
446+ ExecutionCompletionBehaviour :: Default => {
447+ // Do nothing
448+ }
449+ }
432450 actions_in_flight. fetch_sub( 1 , Ordering :: Release ) ;
433451 actions_notify. notify_one( ) ;
434452 core:: future:: ready( res)
@@ -452,13 +470,23 @@ impl<'a, T: WorkerApiClientTrait + 'static, U: RunningActionsManager> LocalWorke
452470 let fut = res. err_tip( || "New future stream receives should never be closed" ) ?;
453471 futures. push( fut) ;
454472 } ,
473+ _ = inner_shutdown_rx. next( ) => {
474+ warn!( "Shutting down worker because of inner shutdown signal" , ) ;
475+ let guard = ShutdownGuard :: default ( ) ;
476+ drop( self . shutdown_tx. send( guard. clone( ) ) ) ;
477+ }
455478 res = futures. next( ) => res. err_tip( || "Keep-alive should always pending. Likely unable to send data to scheduler" ) ??,
479+ _ = shutdown_complete_rx. next( ) => {
480+ info!( "Shutdown complete, exiting worker loop" ) ;
481+ return Ok ( ( ) ) ;
482+ } ,
456483 complete_msg = shutdown_rx. recv( ) . fuse( ) => {
457484 warn!( "Worker loop received shutdown signal. Shutting down worker..." , ) ;
458485 let mut grpc_client = self . grpc_client. clone( ) ;
459486 let shutdown_guard = complete_msg. map_err( |e| make_err!( Code :: Internal , "Failed to receive shutdown message: {e:?}" ) ) ?;
460487 let actions_in_flight = actions_in_flight. clone( ) ;
461488 let actions_notify = actions_notify. clone( ) ;
489+ let shutdown_complete_tx = shutdown_complete_tx. clone( ) ;
462490 let shutdown_future = async move {
463491 // Wait for in-flight operations to be fully completed.
464492 while actions_in_flight. load( Ordering :: Acquire ) > 0 {
@@ -472,6 +500,8 @@ impl<'a, T: WorkerApiClientTrait + 'static, U: RunningActionsManager> LocalWorke
472500 }
473501 // Allow shutdown to occur now.
474502 drop( shutdown_guard) ;
503+ // Signal that shutdown is complete.
504+ let _ = shutdown_complete_tx. send( ( ) ) ;
475505 Ok :: <( ) , Error >( ( ) )
476506 } ;
477507 futures. push( shutdown_future. boxed( ) ) ;
@@ -732,7 +762,8 @@ impl<T: WorkerApiClientTrait + 'static, U: RunningActionsManager> LocalWorker<T,
732762 #[ instrument( skip( self ) , level = Level :: INFO ) ]
733763 pub async fn run (
734764 mut self ,
735- mut shutdown_rx : broadcast:: Receiver < ShutdownGuard > ,
765+ shutdown_tx : Sender < ShutdownGuard > ,
766+ mut shutdown_rx : Receiver < ShutdownGuard > ,
736767 ) -> Result < ( ) , Error > {
737768 let sleep_fn = self
738769 . sleep_fn
@@ -767,6 +798,7 @@ impl<T: WorkerApiClientTrait + 'static, U: RunningActionsManager> LocalWorker<T,
767798 worker_id,
768799 self . running_actions_manager . clone ( ) ,
769800 self . metrics . clone ( ) ,
801+ shutdown_tx. clone ( ) ,
770802 ) ,
771803 update_for_worker_stream,
772804 ) ,
@@ -777,30 +809,37 @@ impl<T: WorkerApiClientTrait + 'static, U: RunningActionsManager> LocalWorker<T,
777809 ) ;
778810
779811 // Now listen for connections and run all other services.
780- if let Err ( err) = inner. run ( update_for_worker_stream, & mut shutdown_rx) . await {
781- ' no_more_actions: {
782- // Ensure there are no actions in transit before we try to kill
783- // all our actions.
784- const ITERATIONS : usize = 1_000 ;
785-
786- const ERROR_MSG : & str = "Actions in transit did not reach zero before we disconnected from the scheduler" ;
787-
788- let sleep_duration = ACTIONS_IN_TRANSIT_TIMEOUT_S / ITERATIONS as f32 ;
789- for _ in 0 ..ITERATIONS {
790- if inner. actions_in_transit . load ( Ordering :: Acquire ) == 0 {
791- break ' no_more_actions;
812+ match inner. run ( update_for_worker_stream, & mut shutdown_rx) . await {
813+ Ok ( ( ) ) => {
814+ // Graceful shutdown completed, return without retrying.
815+ info ! ( "Worker completed graceful shutdown" ) ;
816+ return Ok ( ( ) ) ;
817+ }
818+ Err ( err) => {
819+ ' no_more_actions: {
820+ // Ensure there are no actions in transit before we try to kill
821+ // all our actions.
822+ const ITERATIONS : usize = 1_000 ;
823+
824+ const ERROR_MSG : & str = "Actions in transit did not reach zero before we disconnected from the scheduler" ;
825+
826+ let sleep_duration = ACTIONS_IN_TRANSIT_TIMEOUT_S / ITERATIONS as f32 ;
827+ for _ in 0 ..ITERATIONS {
828+ if inner. actions_in_transit . load ( Ordering :: Acquire ) == 0 {
829+ break ' no_more_actions;
830+ }
831+ ( sleep_fn_pin) ( Duration :: from_secs_f32 ( sleep_duration) ) . await ;
792832 }
793- ( sleep_fn_pin) ( Duration :: from_secs_f32 ( sleep_duration) ) . await ;
833+ error ! ( ERROR_MSG ) ;
834+ return Err ( err. append ( ERROR_MSG ) ) ;
794835 }
795- error ! ( ERROR_MSG ) ;
796- return Err ( err. append ( ERROR_MSG ) ) ;
797- }
798- error ! ( ?err, "Worker disconnected from scheduler" ) ;
799- // Kill off any existing actions because if we re-connect, we'll
800- // get some more and it might resource lock us.
801- self . running_actions_manager . kill_all ( ) . await ;
836+ error ! ( ?err, "Worker disconnected from scheduler" ) ;
837+ // Kill off any existing actions because if we re-connect, we'll
838+ // get some more and it might resource lock us.
839+ self . running_actions_manager . kill_all ( ) . await ;
802840
803- ( error_handler) ( err) . await ; // Try to connect again.
841+ ( error_handler) ( err) . await ; // Try to connect again.
842+ }
804843 }
805844 }
806845 // Unreachable.
0 commit comments