Skip to content

Commit dbef3ff

Browse files
committed
Introduce execution_completion_behaviour: one_shot_always for workers.
# Conflicts: # nativelink-worker/src/local_worker.rs
1 parent a64e2a0 commit dbef3ff

5 files changed

Lines changed: 268 additions & 35 deletions

File tree

nativelink-config/src/cas_server.rs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -853,6 +853,17 @@ pub struct LocalWorkerConfig {
853853
/// them from CAS for every action.
854854
/// Default: None (directory cache disabled)
855855
pub directory_cache: Option<DirectoryCacheConfig>,
856+
857+
#[serde(default)]
858+
pub execution_completion_behaviour: ExecutionCompletionBehaviour,
859+
}
860+
861+
/// Selects what a local worker does after an action it ran has completed.
///
/// Variants are (de)serialized with snake_case names, i.e. `default` and
/// `one_shot_always`. `LocalWorkerConfig` marks the field `#[serde(default)]`,
/// so an omitted value resolves to [`ExecutionCompletionBehaviour::Default`].
#[derive(Deserialize, Serialize, Debug, Default, Copy, Clone)]
862+
#[serde(rename_all = "snake_case")]
863+
pub enum ExecutionCompletionBehaviour {
864+
/// No extra step is taken on completion; the worker loop carries on as before.
#[default]
865+
Default,
866+
/// Once the result-publishing future of an action completes, the worker
/// loop broadcasts a shutdown guard to request its own shutdown; after the
/// shutdown future finishes, the loop returns `Ok(())` and the worker's
/// `run` returns without retrying the scheduler connection.
OneShotAlways,
856867
}
857868

858869
#[derive(Deserialize, Serialize, Debug, Clone)]

nativelink-worker/src/local_worker.rs

Lines changed: 65 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@ use std::collections::HashMap;
2222
use std::env;
2323
use std::process::Stdio;
2424
use std::sync::{Arc, Weak};
25-
2625
use futures::future::BoxFuture;
2726
use futures::stream::FuturesUnordered;
2827
use futures::{Future, FutureExt, StreamExt, TryFutureExt, select};
@@ -32,8 +31,8 @@ use nativelink_metric::{MetricsComponent, RootMetricsComponent};
3231
use nativelink_proto::com::github::trace_machina::nativelink::remote_execution::update_for_worker::Update;
3332
use nativelink_proto::com::github::trace_machina::nativelink::remote_execution::worker_api_client::WorkerApiClient;
3433
use nativelink_proto::com::github::trace_machina::nativelink::remote_execution::{
35-
ExecuteComplete, ExecuteResult, GoingAwayRequest, KeepAliveRequest, UpdateForWorker,
36-
execute_result,
34+
execute_result, ExecuteComplete, ExecuteResult, GoingAwayRequest, KeepAliveRequest,
35+
UpdateForWorker,
3736
};
3837
use nativelink_store::fast_slow_store::FastSlowStore;
3938
use nativelink_util::action_messages::{ActionResult, ActionStage, OperationId};
@@ -46,6 +45,7 @@ use nativelink_util::{spawn, tls_utils};
4645
use opentelemetry::context::Context;
4746
use tokio::process;
4847
use tokio::sync::{broadcast, mpsc};
48+
use tokio::sync::broadcast::{Receiver, Sender};
4949
use tokio::time::sleep;
5050
use tokio_stream::wrappers::UnboundedReceiverStream;
5151
use tonic::Streaming;
@@ -87,6 +87,7 @@ struct LocalWorkerImpl<'a, T: WorkerApiClientTrait + 'static, U: RunningActionsM
8787
// on by the scheduler.
8888
actions_in_transit: Arc<AtomicU64>,
8989
metrics: Arc<Metrics>,
90+
shutdown_tx: Sender<ShutdownGuard>,
9091
}
9192

9293
pub async fn preconditions_met<H: BuildHasher + Sync>(
@@ -147,6 +148,7 @@ impl<'a, T: WorkerApiClientTrait + 'static, U: RunningActionsManager> LocalWorke
147148
worker_id: String,
148149
running_actions_manager: Arc<U>,
149150
metrics: Arc<Metrics>,
151+
shutdown_tx: Sender<ShutdownGuard>,
150152
) -> Self {
151153
Self {
152154
config,
@@ -159,6 +161,7 @@ impl<'a, T: WorkerApiClientTrait + 'static, U: RunningActionsManager> LocalWorke
159161
// on by the scheduler.
160162
actions_in_transit: Arc::new(AtomicU64::new(0)),
161163
metrics,
164+
shutdown_tx,
162165
}
163166
}
164167

@@ -208,6 +211,8 @@ impl<'a, T: WorkerApiClientTrait + 'static, U: RunningActionsManager> LocalWorke
208211

209212
let (add_future_channel, add_future_rx) = mpsc::unbounded_channel();
210213
let mut add_future_rx = UnboundedReceiverStream::new(add_future_rx).fuse();
214+
let (inner_shutdown_channel, inner_shutdown_rx) = mpsc::unbounded_channel();
215+
let mut inner_shutdown_rx = UnboundedReceiverStream::new(inner_shutdown_rx).fuse();
211216

212217
let mut update_for_worker_stream = update_for_worker_stream.fuse();
213218
// A notify which is triggered every time actions_in_flight is subtracted.
@@ -217,6 +222,9 @@ impl<'a, T: WorkerApiClientTrait + 'static, U: RunningActionsManager> LocalWorke
217222
let actions_in_flight = Arc::new(AtomicU64::new(0));
218223
// Set to true when shutting down, this stops any new StartAction.
219224
let mut shutting_down = false;
225+
// Channel to signal when shutdown is complete (GoingAway sent, ready to exit).
226+
let (shutdown_complete_tx, shutdown_complete_rx) = mpsc::unbounded_channel::<()>();
227+
let mut shutdown_complete_rx = UnboundedReceiverStream::new(shutdown_complete_rx).fuse();
220228

221229
loop {
222230
select! {
@@ -406,6 +414,7 @@ impl<'a, T: WorkerApiClientTrait + 'static, U: RunningActionsManager> LocalWorke
406414
self.actions_in_transit.fetch_add(1, Ordering::Release);
407415

408416
let add_future_channel = add_future_channel.clone();
417+
let inner_shutdown_channel = inner_shutdown_channel.clone();
409418

410419
info_span!(
411420
"worker_start_action_ctx",
@@ -428,7 +437,16 @@ impl<'a, T: WorkerApiClientTrait + 'static, U: RunningActionsManager> LocalWorke
428437
error!(?err, "Error executing action");
429438
}
430439
add_future_channel
431-
.send(make_publish_future(res).then(move |res| {
440+
.send(make_publish_future(res)
441+
.then(move |res| {
442+
match self.config.execution_completion_behaviour {
443+
ExecutionCompletionBehaviour::OneShotAlways => {
444+
inner_shutdown_channel.send(()).ok();
445+
}
446+
ExecutionCompletionBehaviour::Default => {
447+
// Nothing to do: no shutdown is requested, so the worker loop keeps running.
448+
}
449+
}
432450
actions_in_flight.fetch_sub(1, Ordering::Release);
433451
actions_notify.notify_one();
434452
core::future::ready(res)
@@ -452,13 +470,23 @@ impl<'a, T: WorkerApiClientTrait + 'static, U: RunningActionsManager> LocalWorke
452470
let fut = res.err_tip(|| "New future stream receives should never be closed")?;
453471
futures.push(fut);
454472
},
473+
_ = inner_shutdown_rx.next() => {
474+
warn!("Shutting down worker because of inner shutdown signal",);
475+
let guard = ShutdownGuard::default();
476+
drop(self.shutdown_tx.send(guard.clone()));
477+
}
455478
res = futures.next() => res.err_tip(|| "Keep-alive should always pending. Likely unable to send data to scheduler")??,
479+
_ = shutdown_complete_rx.next() => {
480+
info!("Shutdown complete, exiting worker loop");
481+
return Ok(());
482+
},
456483
complete_msg = shutdown_rx.recv().fuse() => {
457484
warn!("Worker loop received shutdown signal. Shutting down worker...",);
458485
let mut grpc_client = self.grpc_client.clone();
459486
let shutdown_guard = complete_msg.map_err(|e| make_err!(Code::Internal, "Failed to receive shutdown message: {e:?}"))?;
460487
let actions_in_flight = actions_in_flight.clone();
461488
let actions_notify = actions_notify.clone();
489+
let shutdown_complete_tx = shutdown_complete_tx.clone();
462490
let shutdown_future = async move {
463491
// Wait for in-flight operations to be fully completed.
464492
while actions_in_flight.load(Ordering::Acquire) > 0 {
@@ -472,6 +500,8 @@ impl<'a, T: WorkerApiClientTrait + 'static, U: RunningActionsManager> LocalWorke
472500
}
473501
// Allow shutdown to occur now.
474502
drop(shutdown_guard);
503+
// Signal that shutdown is complete.
504+
let _ = shutdown_complete_tx.send(());
475505
Ok::<(), Error>(())
476506
};
477507
futures.push(shutdown_future.boxed());
@@ -732,7 +762,8 @@ impl<T: WorkerApiClientTrait + 'static, U: RunningActionsManager> LocalWorker<T,
732762
#[instrument(skip(self), level = Level::INFO)]
733763
pub async fn run(
734764
mut self,
735-
mut shutdown_rx: broadcast::Receiver<ShutdownGuard>,
765+
shutdown_tx: Sender<ShutdownGuard>,
766+
mut shutdown_rx: Receiver<ShutdownGuard>,
736767
) -> Result<(), Error> {
737768
let sleep_fn = self
738769
.sleep_fn
@@ -767,6 +798,7 @@ impl<T: WorkerApiClientTrait + 'static, U: RunningActionsManager> LocalWorker<T,
767798
worker_id,
768799
self.running_actions_manager.clone(),
769800
self.metrics.clone(),
801+
shutdown_tx.clone(),
770802
),
771803
update_for_worker_stream,
772804
),
@@ -777,30 +809,37 @@ impl<T: WorkerApiClientTrait + 'static, U: RunningActionsManager> LocalWorker<T,
777809
);
778810

779811
// Now listen for connections and run all other services.
780-
if let Err(err) = inner.run(update_for_worker_stream, &mut shutdown_rx).await {
781-
'no_more_actions: {
782-
// Ensure there are no actions in transit before we try to kill
783-
// all our actions.
784-
const ITERATIONS: usize = 1_000;
785-
786-
const ERROR_MSG: &str = "Actions in transit did not reach zero before we disconnected from the scheduler";
787-
788-
let sleep_duration = ACTIONS_IN_TRANSIT_TIMEOUT_S / ITERATIONS as f32;
789-
for _ in 0..ITERATIONS {
790-
if inner.actions_in_transit.load(Ordering::Acquire) == 0 {
791-
break 'no_more_actions;
812+
match inner.run(update_for_worker_stream, &mut shutdown_rx).await {
813+
Ok(()) => {
814+
// Graceful shutdown completed, return without retrying.
815+
info!("Worker completed graceful shutdown");
816+
return Ok(());
817+
}
818+
Err(err) => {
819+
'no_more_actions: {
820+
// Ensure there are no actions in transit before we try to kill
821+
// all our actions.
822+
const ITERATIONS: usize = 1_000;
823+
824+
const ERROR_MSG: &str = "Actions in transit did not reach zero before we disconnected from the scheduler";
825+
826+
let sleep_duration = ACTIONS_IN_TRANSIT_TIMEOUT_S / ITERATIONS as f32;
827+
for _ in 0..ITERATIONS {
828+
if inner.actions_in_transit.load(Ordering::Acquire) == 0 {
829+
break 'no_more_actions;
830+
}
831+
(sleep_fn_pin)(Duration::from_secs_f32(sleep_duration)).await;
792832
}
793-
(sleep_fn_pin)(Duration::from_secs_f32(sleep_duration)).await;
833+
error!(ERROR_MSG);
834+
return Err(err.append(ERROR_MSG));
794835
}
795-
error!(ERROR_MSG);
796-
return Err(err.append(ERROR_MSG));
797-
}
798-
error!(?err, "Worker disconnected from scheduler");
799-
// Kill off any existing actions because if we re-connect, we'll
800-
// get some more and it might resource lock us.
801-
self.running_actions_manager.kill_all().await;
836+
error!(?err, "Worker disconnected from scheduler");
837+
// Kill off any existing actions because if we re-connect, we'll
838+
// get some more and it might resource lock us.
839+
self.running_actions_manager.kill_all().await;
802840

803-
(error_handler)(err).await; // Try to connect again.
841+
(error_handler)(err).await; // Try to connect again.
842+
}
804843
}
805844
}
806845
// Unreachable.

nativelink-worker/tests/local_worker_test.rs

Lines changed: 121 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ use core::time::Duration;
1616
use std::collections::HashMap;
1717
use std::env;
1818
use std::ffi::OsString;
19-
use std::io::Write;
19+
use std::io::{Write};
2020
#[cfg(target_family = "unix")]
2121
use std::os::unix::fs::OpenOptionsExt;
2222
use std::path::PathBuf;
@@ -29,7 +29,7 @@ mod utils {
2929
}
3030

3131
use hyper::body::Frame;
32-
use nativelink_config::cas_server::{LocalWorkerConfig, WorkerProperty};
32+
use nativelink_config::cas_server::{ExecutionCompletionBehaviour, LocalWorkerConfig, WorkerProperty};
3333
use nativelink_config::stores::{
3434
FastSlowSpec, FilesystemSpec, MemorySpec, StoreDirection, StoreSpec,
3535
};
@@ -424,6 +424,125 @@ async fn simple_worker_start_action_test() -> Result<(), Error> {
424424
Ok(())
425425
}
426426

427+
/// Checks the `OneShotAlways` completion behaviour end to end against the mocks:
/// a worker configured with it runs a single action, reports the execution
/// response to the scheduler, and is then expected to send a going-away message.
#[nativelink_test]
428+
async fn one_shot_shutdowns_worker_test() -> Result<(), Error> {
429+
// Only the completion behaviour is overridden; every other setting keeps its default.
let config = LocalWorkerConfig {
430+
execution_completion_behaviour: ExecutionCompletionBehaviour::OneShotAlways,
431+
..Default::default()
432+
};
433+
let mut test_context = setup_local_worker_with_config(config).await;
434+
let streaming_response = test_context.maybe_streaming_response.take().unwrap();
435+
436+
{
437+
// Answer the worker's connection request with the prepared update stream and
// check that it connected with default properties.
let props = test_context
438+
.client
439+
.expect_connect_worker(Ok(streaming_response))
440+
.await;
441+
assert_eq!(props, ConnectWorkerRequest::default());
442+
}
443+
444+
let expected_worker_id = "foobar".to_string();
445+
446+
let tx_stream = test_context.maybe_tx_stream.take().unwrap();
447+
{
448+
// First initialize our worker by sending the response to the connection request.
449+
tx_stream
450+
.send(Frame::data(
451+
encode_stream_proto(&UpdateForWorker {
452+
update: Some(Update::ConnectionResult(ConnectionResult {
453+
worker_id: expected_worker_id.clone(),
454+
})),
455+
})
456+
.unwrap(),
457+
))
458+
.await
459+
.map_err(|e| make_input_err!("Could not send : {:?}", e))?;
460+
}
461+
462+
// Arbitrary, fixed digests; the action is marked uncacheable.
let action_digest = DigestInfo::new([3u8; 32], 10);
463+
let action_info = ActionInfo {
464+
command_digest: DigestInfo::new([1u8; 32], 10),
465+
input_root_digest: DigestInfo::new([2u8; 32], 10),
466+
timeout: Duration::from_secs(1),
467+
platform_properties: HashMap::new(),
468+
priority: 0,
469+
load_timestamp: SystemTime::UNIX_EPOCH,
470+
insert_timestamp: SystemTime::UNIX_EPOCH,
471+
unique_qualifier: ActionUniqueQualifier::Uncacheable(ActionUniqueKey {
472+
instance_name: INSTANCE_NAME.to_string(),
473+
digest_function: DigestHasherFunc::Sha256,
474+
digest: action_digest,
475+
}),
476+
};
477+
478+
{
479+
// Send execution request.
480+
tx_stream
481+
.send(Frame::data(
482+
encode_stream_proto(&UpdateForWorker {
483+
update: Some(Update::StartAction(StartExecute {
484+
execute_request: Some((&action_info).into()),
485+
operation_id: String::new(),
486+
queued_timestamp: None,
487+
platform: Some(Platform::default()),
488+
worker_id: expected_worker_id.clone(),
489+
})),
490+
})
491+
.unwrap(),
492+
))
493+
.await
494+
.map_err(|e| make_input_err!("Could not send : {:?}", e))?;
495+
}
496+
497+
let running_action = Arc::new(MockRunningAction::new());
498+
499+
// Canned result the mock action hands back to the worker.
let action_result = ActionResult {
500+
output_files: vec![],
501+
output_folders: vec![],
502+
output_file_symlinks: vec![],
503+
output_directory_symlinks: vec![],
504+
exit_code: 5,
505+
stdout_digest: DigestInfo::new([21u8; 32], 10),
506+
stderr_digest: DigestInfo::new([22u8; 32], 10),
507+
execution_metadata: ExecutionMetadata {
508+
worker: expected_worker_id.clone(),
509+
queued_timestamp: SystemTime::UNIX_EPOCH,
510+
worker_start_timestamp: SystemTime::UNIX_EPOCH,
511+
worker_completed_timestamp: SystemTime::UNIX_EPOCH,
512+
input_fetch_start_timestamp: SystemTime::UNIX_EPOCH,
513+
input_fetch_completed_timestamp: SystemTime::UNIX_EPOCH,
514+
execution_start_timestamp: SystemTime::UNIX_EPOCH,
515+
execution_completed_timestamp: SystemTime::UNIX_EPOCH,
516+
output_upload_start_timestamp: SystemTime::UNIX_EPOCH,
517+
output_upload_completed_timestamp: SystemTime::UNIX_EPOCH,
518+
},
519+
server_logs: HashMap::new(),
520+
error: None,
521+
message: String::new(),
522+
};
523+
524+
// Send and wait for response from create_and_add_action to RunningActionsManager.
525+
test_context
526+
.actions_manager
527+
.expect_create_and_add_action(Ok(running_action.clone()))
528+
.await;
529+
530+
531+
// Now the RunningAction needs to send a series of state updates. This shortcuts them
532+
// into a single call (shortcut for prepare, execute, upload, collect_results, cleanup).
533+
running_action
534+
.simple_expect_get_finished_result(Ok(action_result.clone()))
535+
.await?;
536+
537+
// The worker is expected to report the finished action to the scheduler first.
test_context.client.expect_execution_response(Ok(())).await;
538+
539+
// With OneShotAlways the worker should then announce that it is going away;
// this is the behaviour the test exists to pin.
test_context.client
540+
.expect_going_away(Ok(()))
541+
.await;
542+
543+
Ok(())
544+
}
545+
427546
#[nativelink_test]
428547
async fn new_local_worker_creates_work_directory_test() -> Result<(), Error> {
429548
let cas_store = Store::new(FastSlowStore::new(

0 commit comments

Comments
 (0)