From 7a16e2e6043b17e7813e41450b4de9c40de435f4 Mon Sep 17 00:00:00 2001 From: Blaise Bruer Date: Fri, 28 Jun 2024 13:21:39 -0500 Subject: [PATCH] [Refactor] Move scheduler state behind mutex In preparation for a distributed/Redis scheduler, change the state interface so that it no longer takes mutable references. This is a partial PR and should be landed immediately, with follow-up PRs that will remove much of the locking in the SimpleScheduler. towards: #359 --- nativelink-scheduler/src/action_scheduler.rs | 2 +- .../src/cache_lookup_scheduler.rs | 4 +- nativelink-scheduler/src/grpc_scheduler.rs | 6 +- .../src/operation_state_manager.rs | 6 +- .../src/property_modifier_scheduler.rs | 2 +- .../src/redis_operation_state.rs | 6 +- .../src/scheduler_state/state_manager.rs | 441 +++++++++--------- nativelink-scheduler/src/simple_scheduler.rs | 243 ++++++---- .../tests/cache_lookup_scheduler_test.rs | 4 +- .../tests/property_modifier_scheduler_test.rs | 4 +- .../tests/simple_scheduler_test.rs | 6 +- .../tests/utils/mock_scheduler.rs | 6 +- nativelink-service/src/execution_server.rs | 1 + 13 files changed, 376 insertions(+), 355 deletions(-) diff --git a/nativelink-scheduler/src/action_scheduler.rs b/nativelink-scheduler/src/action_scheduler.rs index 5460209d7..21d293d23 100644 --- a/nativelink-scheduler/src/action_scheduler.rs +++ b/nativelink-scheduler/src/action_scheduler.rs @@ -42,7 +42,7 @@ pub trait ActionScheduler: Sync + Send + Unpin { async fn find_existing_action( &self, unique_qualifier: &ActionInfoHashKey, - ) -> Option>>; + ) -> Result>>, Error>; /// Cleans up the cache of recently completed actions. async fn clean_recently_completed_actions(&self); diff --git a/nativelink-scheduler/src/cache_lookup_scheduler.rs b/nativelink-scheduler/src/cache_lookup_scheduler.rs index 18c7a7f4e..be26b1d55 100644 --- a/nativelink-scheduler/src/cache_lookup_scheduler.rs +++ b/nativelink-scheduler/src/cache_lookup_scheduler.rs @@ -222,11 +222,11 @@ impl ActionScheduler for CacheLookupScheduler { async fn find_existing_action( &self, unique_qualifier: &ActionInfoHashKey, - ) -> Option>> { + ) -> Result>>, Error> { { let cache_check_actions = self.cache_check_actions.lock(); if let Some(rx) = subscribe_to_existing_action(&cache_check_actions, unique_qualifier) { - return Some(rx); + return Ok(Some(rx)); } } // Cache skipped may be in the upstream scheduler.
diff --git a/nativelink-scheduler/src/grpc_scheduler.rs b/nativelink-scheduler/src/grpc_scheduler.rs index 646cb4e89..f284688ab 100644 --- a/nativelink-scheduler/src/grpc_scheduler.rs +++ b/nativelink-scheduler/src/grpc_scheduler.rs @@ -260,7 +260,7 @@ impl ActionScheduler for GrpcScheduler { async fn find_existing_action( &self, unique_qualifier: &ActionInfoHashKey, - ) -> Option>> { + ) -> Result>>, Error> { let request = WaitExecutionRequest { name: unique_qualifier.action_name(), }; @@ -279,14 +279,14 @@ impl ActionScheduler for GrpcScheduler { .and_then(|result_stream| Self::stream_state(result_stream.into_inner())) .await; match result_stream { - Ok(result_stream) => Some(result_stream), + Ok(result_stream) => Ok(Some(result_stream)), Err(err) => { event!( Level::WARN, ?err, "Error looking up action with upstream scheduler" ); - None + Ok(None) } } } diff --git a/nativelink-scheduler/src/operation_state_manager.rs b/nativelink-scheduler/src/operation_state_manager.rs index 2b7184d3f..7baa5f156 100644 --- a/nativelink-scheduler/src/operation_state_manager.rs +++ b/nativelink-scheduler/src/operation_state_manager.rs @@ -101,7 +101,7 @@ pub type ActionStateResultStream = Pin Result, Error>; @@ -119,7 +119,7 @@ pub trait WorkerStateManager { /// did not change with a modified timestamp in order to prevent /// the operation from being considered stale and being rescheduled. async fn update_operation( - &mut self, + &self, operation_id: OperationId, worker_id: WorkerId, action_stage: Result, @@ -136,7 +136,7 @@ pub trait MatchingEngineStateManager { /// Update that state of an operation. async fn update_operation( - &mut self, + &self, operation_id: OperationId, worker_id: Option, action_stage: Result, diff --git a/nativelink-scheduler/src/property_modifier_scheduler.rs b/nativelink-scheduler/src/property_modifier_scheduler.rs index c7b827426..93f503e8c 100644 --- a/nativelink-scheduler/src/property_modifier_scheduler.rs +++ b/nativelink-scheduler/src/property_modifier_scheduler.rs @@ -117,7 +117,7 @@ impl ActionScheduler for PropertyModifierScheduler { async fn find_existing_action( &self, unique_qualifier: &ActionInfoHashKey, - ) -> Option>> { + ) -> Result>>, Error> { self.scheduler.find_existing_action(unique_qualifier).await } diff --git a/nativelink-scheduler/src/redis_operation_state.rs b/nativelink-scheduler/src/redis_operation_state.rs index 5dd4c13d0..12593980f 100644 --- a/nativelink-scheduler/src/redis_operation_state.rs +++ b/nativelink-scheduler/src/redis_operation_state.rs @@ -413,7 +413,7 @@ impl RedisStateManage #[async_trait] impl ClientStateManager for RedisStateManager { async fn add_action( - &mut self, + &self, action_info: ActionInfo, ) -> Result, Error> { self.inner_add_action(action_info).await @@ -430,7 +430,7 @@ impl ClientStateManager for RedisStateManager { #[async_trait] impl WorkerStateManager for RedisStateManager { async fn update_operation( - &mut self, + &self, operation_id: OperationId, worker_id: WorkerId, action_stage: Result, @@ -450,7 +450,7 @@ impl MatchingEngineStateManager for RedisStateManager { } async fn update_operation( - &mut self, + &self, operation_id: OperationId, worker_id: Option, action_stage: Result, diff --git a/nativelink-scheduler/src/scheduler_state/state_manager.rs b/nativelink-scheduler/src/scheduler_state/state_manager.rs index 8dd0def9c..087fc5e93 100644 --- a/nativelink-scheduler/src/scheduler_state/state_manager.rs +++ b/nativelink-scheduler/src/scheduler_state/state_manager.rs @@ -17,6 +17,7 @@ use 
std::collections::BTreeMap; use std::sync::Arc; use std::time::SystemTime; +use async_lock::Mutex; use async_trait::async_trait; use futures::stream; use hashbrown::{HashMap, HashSet}; @@ -43,7 +44,7 @@ use crate::worker::WorkerUpdate; #[repr(transparent)] pub(crate) struct StateManager { - pub inner: StateManagerImpl, + pub inner: Mutex, } impl StateManager { @@ -56,10 +57,10 @@ impl StateManager { recently_completed_actions: HashSet, metrics: Arc, max_job_retries: usize, - tasks_or_workers_change_notify: Arc, + tasks_change_notify: Arc, ) -> Self { Self { - inner: StateManagerImpl { + inner: Mutex::new(StateManagerImpl { queued_actions_set, queued_actions, workers, @@ -67,33 +68,159 @@ impl StateManager { recently_completed_actions, metrics, max_job_retries, - tasks_or_workers_change_notify, - }, + tasks_change_notify, + }), } } +} +/// StateManager is responsible for maintaining the state of the scheduler. Scheduler state +/// includes the actions that are queued, active, and recently completed. It also includes the +/// workers that are available to execute actions based on allocation strategy. +pub(crate) struct StateManagerImpl { + // TODO(adams): Move `queued_actions_set` and `queued_actions` into a single struct that + // provides a unified interface for interacting with the two containers. + + // Important: `queued_actions_set` and `queued_actions` are two containers that provide + // different search and sort capabilities. We are using the two different containers to + // optimize different use cases. `HashSet` is used to look up actions in O(1) time. The + // `BTreeMap` is used to sort actions in O(log n) time based on priority and timestamp. + // These two fields must be kept in-sync, so if you modify one, you likely need to modify the + // other. + /// A `HashSet` of all actions that are queued. A hashset is used to find actions that are queued + /// in O(1) time. This set allows us to find and join on new actions onto already existing + /// (or queued) actions where insert timestamp of queued actions is not known. Using an + /// additional `HashSet` will prevent us from having to iterate the `BTreeMap` to find actions. + /// + /// Important: `queued_actions_set` and `queued_actions` must be kept in sync. + pub(crate) queued_actions_set: HashSet>, + + /// A BTreeMap of sorted actions that are primarily based on priority and insert timestamp. + /// `ActionInfo` implements `Ord` that defines the `cmp` function for order. Using a BTreeMap + /// gives us to sorted actions that are queued in O(log n) time. + /// + /// Important: `queued_actions_set` and `queued_actions` must be kept in sync. + pub(crate) queued_actions: BTreeMap, AwaitedAction>, + + /// A `Workers` pool that contains all workers that are available to execute actions in a priority + /// order based on the allocation strategy. + pub(crate) workers: Workers, + + /// A map of all actions that are active. A hashmap is used to find actions that are active in + /// O(1) time. The key is the `ActionInfo` struct. The value is the `AwaitedAction` struct. + pub(crate) active_actions: HashMap, AwaitedAction>, + + /// These actions completed recently but had no listener, they might have + /// completed while the caller was thinking about calling wait_execution, so + /// keep their completion state around for a while to send back. + /// TODO(#192) Revisit if this is the best way to handle recently completed actions. 
+ pub(crate) recently_completed_actions: HashSet, + + pub(crate) metrics: Arc, + + /// Default times a job can retry before failing. + pub(crate) max_job_retries: usize, + + /// Notify task<->worker matching engine that work needs to be done. + pub(crate) tasks_change_notify: Arc, +} + +/// Modifies the `stage` of `current_state` within `AwaitedAction`. Sends notification channel +/// the new state. +/// +/// +/// # Discussion +/// +/// The use of `Arc::make_mut` is potentially dangerous because it clones the data and +/// invalidates all weak references to it. However, in this context, it is considered +/// safe because the data is going to be re-sent back out. The primary reason for using +/// `Arc` is to reduce the number of copies, not to enforce read-only access. This approach +/// ensures that all downstream components receive the same pointer. If an update occurs +/// while another thread is operating on the data, it is acceptable, since the other thread +/// will receive another update with the new version. +/// +pub(crate) fn mutate_stage( + awaited_action: &mut AwaitedAction, + action_stage: ActionStage, +) -> Result<(), SendError>> { + Arc::make_mut(&mut awaited_action.current_state).stage = action_stage; + awaited_action + .notify_channel + .send(awaited_action.current_state.clone()) +} + +/// Updates the `last_error` field of the provided `AwaitedAction` and sends the current state +/// to the notify channel. +/// +fn mutate_last_error( + awaited_action: &mut AwaitedAction, + last_error: Error, +) -> Result<(), SendError>> { + awaited_action.last_error = Some(last_error); + awaited_action + .notify_channel + .send(awaited_action.current_state.clone()) +} + +/// Sets the action stage for the given `AwaitedAction` based on the result of the provided +/// `action_stage`. If the `action_stage` is an error, it updates the `last_error` field +/// and logs a warning. +/// +/// # Note +/// +/// Intended utility function for matching engine. +/// +/// # Errors +/// +/// This function will return an error if updating the state of the `awaited_action` fails. +/// +async fn worker_set_action_stage( + awaited_action: &mut AwaitedAction, + action_stage: Result, + worker_id: WorkerId, +) -> Result<(), SendError>> { + match action_stage { + Ok(action_stage) => mutate_stage(awaited_action, action_stage), + Err(e) => { + event!( + Level::WARN, + ?worker_id, + "Action stage setting error during do_try_match()" + ); + mutate_last_error(awaited_action, e) + } + } +} + +/// Modifies the `priority` of `action_info` within `ActionInfo`. +/// +fn mutate_priority(action_info: &mut Arc, priority: i32) { + Arc::make_mut(action_info).priority = priority; +} + +impl StateManagerImpl { fn immediate_evict_worker(&mut self, worker_id: &WorkerId, err: Error) { - if let Some(mut worker) = self.inner.workers.remove_worker(worker_id) { - self.inner.metrics.workers_evicted.inc(); + if let Some(mut worker) = self.workers.remove_worker(worker_id) { + self.metrics.workers_evicted.inc(); // We don't care if we fail to send message to worker, this is only a best attempt. let _ = worker.notify_update(WorkerUpdate::Disconnect); // We create a temporary Vec to avoid doubt about a possible code // path touching the worker.running_action_infos elsewhere. 
for action_info in worker.running_action_infos.drain() { - self.inner.metrics.workers_evicted_with_running_action.inc(); + self.metrics.workers_evicted_with_running_action.inc(); self.retry_action(&action_info, worker_id, err.clone()); } // Note: Calling this multiple times is very cheap, it'll only trigger `do_try_match` once. - self.inner.tasks_or_workers_change_notify.notify_one(); + self.tasks_change_notify.notify_one(); } } fn retry_action(&mut self, action_info: &Arc, worker_id: &WorkerId, err: Error) { - match self.inner.active_actions.remove(action_info) { + match self.active_actions.remove(action_info) { Some(running_action) => { let mut awaited_action = running_action; - let send_result = if awaited_action.attempts >= self.inner.max_job_retries { - self.inner.metrics.retry_action_max_attempts_reached.inc(); + let send_result = if awaited_action.attempts >= self.max_job_retries { + self.metrics.retry_action_max_attempts_reached.inc(); Arc::make_mut(&mut awaited_action.current_state).stage = ActionStage::Completed(ActionResult { execution_metadata: ExecutionMetadata { worker: format!("{worker_id}"), @@ -111,20 +238,19 @@ impl StateManager { // Do not put the action back in the queue here, as this action attempted to run too many // times. } else { - self.inner.metrics.retry_action.inc(); + self.metrics.retry_action.inc(); Arc::make_mut(&mut awaited_action.current_state).stage = ActionStage::Queued; let send_result = awaited_action .notify_channel .send(awaited_action.current_state.clone()); - self.inner.queued_actions_set.insert(action_info.clone()); - self.inner - .queued_actions + self.queued_actions_set.insert(action_info.clone()); + self.queued_actions .insert(action_info.clone(), awaited_action); send_result }; if send_result.is_err() { - self.inner.metrics.retry_action_no_more_listeners.inc(); + self.metrics.retry_action_no_more_listeners.inc(); // Don't remove this task, instead we keep them around for a bit just in case // the client disconnected and will reconnect and ask for same job to be executed // again. @@ -137,7 +263,7 @@ impl StateManager { } } None => { - self.inner.metrics.retry_action_but_action_missing.inc(); + self.metrics.retry_action_but_action_missing.inc(); event!( Level::ERROR, ?action_info, @@ -147,102 +273,6 @@ impl StateManager { } } } -} - -/// StateManager is responsible for maintaining the state of the scheduler. Scheduler state -/// includes the actions that are queued, active, and recently completed. It also includes the -/// workers that are available to execute actions based on allocation strategy. -pub(crate) struct StateManagerImpl { - // TODO(adams): Move `queued_actions_set` and `queued_actions` into a single struct that - // provides a unified interface for interacting with the two containers. - - // Important: `queued_actions_set` and `queued_actions` are two containers that provide - // different search and sort capabilities. We are using the two different containers to - // optimize different use cases. `HashSet` is used to look up actions in O(1) time. The - // `BTreeMap` is used to sort actions in O(log n) time based on priority and timestamp. - // These two fields must be kept in-sync, so if you modify one, you likely need to modify the - // other. - /// A `HashSet` of all actions that are queued. A hashset is used to find actions that are queued - /// in O(1) time. This set allows us to find and join on new actions onto already existing - /// (or queued) actions where insert timestamp of queued actions is not known. 
Using an - /// additional `HashSet` will prevent us from having to iterate the `BTreeMap` to find actions. - /// - /// Important: `queued_actions_set` and `queued_actions` must be kept in sync. - pub(crate) queued_actions_set: HashSet>, - - /// A BTreeMap of sorted actions that are primarily based on priority and insert timestamp. - /// `ActionInfo` implements `Ord` that defines the `cmp` function for order. Using a BTreeMap - /// gives us to sorted actions that are queued in O(log n) time. - /// - /// Important: `queued_actions_set` and `queued_actions` must be kept in sync. - pub(crate) queued_actions: BTreeMap, AwaitedAction>, - - /// A `Workers` pool that contains all workers that are available to execute actions in a priority - /// order based on the allocation strategy. - pub(crate) workers: Workers, - - /// A map of all actions that are active. A hashmap is used to find actions that are active in - /// O(1) time. The key is the `ActionInfo` struct. The value is the `AwaitedAction` struct. - pub(crate) active_actions: HashMap, AwaitedAction>, - - /// These actions completed recently but had no listener, they might have - /// completed while the caller was thinking about calling wait_execution, so - /// keep their completion state around for a while to send back. - /// TODO(#192) Revisit if this is the best way to handle recently completed actions. - pub(crate) recently_completed_actions: HashSet, - - pub(crate) metrics: Arc, - - /// Default times a job can retry before failing. - pub(crate) max_job_retries: usize, - - /// Notify task<->worker matching engine that work needs to be done. - pub(crate) tasks_or_workers_change_notify: Arc, -} - -impl StateManager { - /// Modifies the `stage` of `current_state` within `AwaitedAction`. Sends notification channel - /// the new state. - /// - /// - /// # Discussion - /// - /// The use of `Arc::make_mut` is potentially dangerous because it clones the data and - /// invalidates all weak references to it. However, in this context, it is considered - /// safe because the data is going to be re-sent back out. The primary reason for using - /// `Arc` is to reduce the number of copies, not to enforce read-only access. This approach - /// ensures that all downstream components receive the same pointer. If an update occurs - /// while another thread is operating on the data, it is acceptable, since the other thread - /// will receive another update with the new version. - /// - pub(crate) fn mutate_stage( - awaited_action: &mut AwaitedAction, - action_stage: ActionStage, - ) -> Result<(), SendError>> { - Arc::make_mut(&mut awaited_action.current_state).stage = action_stage; - awaited_action - .notify_channel - .send(awaited_action.current_state.clone()) - } - - /// Modifies the `priority` of `action_info` within `ActionInfo`. - /// - fn mutate_priority(action_info: &mut Arc, priority: i32) { - Arc::make_mut(action_info).priority = priority; - } - - /// Updates the `last_error` field of the provided `AwaitedAction` and sends the current state - /// to the notify channel. - /// - fn mutate_last_error( - awaited_action: &mut AwaitedAction, - last_error: Error, - ) -> Result<(), SendError>> { - awaited_action.last_error = Some(last_error); - awaited_action - .notify_channel - .send(awaited_action.current_state.clone()) - } /// Notifies the specified worker to run the given action and handles errors by evicting /// the worker if the notification fails. 
@@ -261,7 +291,7 @@ impl StateManager { worker_id: WorkerId, action_info: Arc, ) -> Result<(), Error> { - if let Some(worker) = self.inner.workers.workers.get_mut(&worker_id) { + if let Some(worker) = self.workers.workers.get_mut(&worker_id) { let notify_worker_result = worker.notify_update(WorkerUpdate::RunAction(action_info.clone())); @@ -286,36 +316,6 @@ impl StateManager { Ok(()) } - /// Sets the action stage for the given `AwaitedAction` based on the result of the provided - /// `action_stage`. If the `action_stage` is an error, it updates the `last_error` field - /// and logs a warning. - /// - /// # Note - /// - /// Intended utility function for matching engine. - /// - /// # Errors - /// - /// This function will return an error if updating the state of the `awaited_action` fails. - /// - async fn worker_set_action_stage( - awaited_action: &mut AwaitedAction, - action_stage: Result, - worker_id: WorkerId, - ) -> Result<(), SendError>> { - match action_stage { - Ok(action_stage) => StateManager::mutate_stage(awaited_action, action_stage), - Err(e) => { - event!( - Level::WARN, - ?worker_id, - "Action stage setting error during do_try_match()" - ); - StateManager::mutate_last_error(awaited_action, e) - } - } - } - /// Marks the specified action as active, assigns it to the given worker, and updates the /// action stage. This function removes the action from the queue, updates the action's state /// or error, and inserts it into the set of active actions. @@ -336,18 +336,17 @@ impl StateManager { action_stage: Result, ) -> Result<(), Error> { if let Some((action_info, mut awaited_action)) = - self.inner.queued_actions.remove_entry(action_info.as_ref()) + self.queued_actions.remove_entry(action_info.as_ref()) { assert!( - self.inner.queued_actions_set.remove(&action_info), + self.queued_actions_set.remove(&action_info), "queued_actions_set should always have same keys as queued_actions" ); awaited_action.worker_id = Some(worker_id); let send_result = - StateManager::worker_set_action_stage(&mut awaited_action, action_stage, worker_id) - .await; + worker_set_action_stage(&mut awaited_action, action_stage, worker_id).await; if send_result.is_err() { event!( @@ -359,9 +358,7 @@ impl StateManager { } awaited_action.attempts += 1; - self.inner - .active_actions - .insert(action_info, awaited_action); + self.active_actions.insert(action_info, awaited_action); Ok(()) } else { Err(make_err!( @@ -377,14 +374,11 @@ impl StateManager { action_info_hash_key: ActionInfoHashKey, err: Error, ) { - self.inner.metrics.update_action_with_internal_error.inc(); - let Some((action_info, mut running_action)) = self - .inner - .active_actions - .remove_entry(&action_info_hash_key) + self.metrics.update_action_with_internal_error.inc(); + let Some((action_info, mut running_action)) = + self.active_actions.remove_entry(&action_info_hash_key) else { - self.inner - .metrics + self.metrics .update_action_with_internal_error_no_action .inc(); event!( @@ -399,8 +393,7 @@ impl StateManager { let due_to_backpressure = err.code == Code::ResourceExhausted; // Don't count a backpressure failure as an attempt for an action. if due_to_backpressure { - self.inner - .metrics + self.metrics .update_action_with_internal_error_backpressure .inc(); running_action.attempts -= 1; @@ -426,19 +419,17 @@ impl StateManager { ); running_action.last_error = Some(err.clone()); } else { - self.inner - .metrics + self.metrics .update_action_with_internal_error_from_wrong_worker .inc(); } // Now put it back. 
retry_action() needs it to be there to send errors properly. - self.inner - .active_actions + self.active_actions .insert(action_info.clone(), running_action); // Clear this action from the current worker. - if let Some(worker) = self.inner.workers.workers.get_mut(worker_id) { + if let Some(worker) = self.workers.workers.get_mut(worker_id) { let was_paused = !worker.can_accept_work(); // This unpauses, but since we're completing with an error, don't // unpause unless all actions have completed. @@ -451,40 +442,39 @@ impl StateManager { // Re-queue the action or fail on max attempts. self.retry_action(&action_info, worker_id, err); - self.inner.tasks_or_workers_change_notify.notify_one(); + self.tasks_change_notify.notify_one(); } } #[async_trait] impl ClientStateManager for StateManager { async fn add_action( - &mut self, + &self, action_info: ActionInfo, ) -> Result, Error> { + let mut inner = self.inner.lock().await; // Check to see if the action is running, if it is and cacheable, merge the actions. - if let Some(running_action) = self.inner.active_actions.get_mut(&action_info) { - self.inner.metrics.add_action_joined_running_action.inc(); - self.inner.tasks_or_workers_change_notify.notify_one(); - return Ok(Arc::new(ClientActionStateResult::new( - running_action.notify_channel.subscribe(), - ))); + if let Some(running_action) = inner.active_actions.get_mut(&action_info) { + let subscription = running_action.notify_channel.subscribe(); + inner.metrics.add_action_joined_running_action.inc(); + inner.tasks_change_notify.notify_one(); + return Ok(Arc::new(ClientActionStateResult::new(subscription))); } // Check to see if the action is queued, if it is and cacheable, merge the actions. - if let Some(mut arc_action_info) = self.inner.queued_actions_set.take(&action_info) { - let (original_action_info, queued_action) = self - .inner + if let Some(mut arc_action_info) = inner.queued_actions_set.take(&action_info) { + let (original_action_info, queued_action) = inner .queued_actions .remove_entry(&arc_action_info) .err_tip(|| "Internal error queued_actions and queued_actions_set should match")?; - self.inner.metrics.add_action_joined_queued_action.inc(); + inner.metrics.add_action_joined_queued_action.inc(); let new_priority = cmp::max(original_action_info.priority, action_info.priority); drop(original_action_info); // This increases the chance Arc::make_mut won't copy. // In the event our task is higher priority than the one already scheduled, increase // the priority of the scheduled one. - StateManager::mutate_priority(&mut arc_action_info, new_priority); + mutate_priority(&mut arc_action_info, new_priority); let result = Arc::new(ClientActionStateResult::new( queued_action.notify_channel.subscribe(), @@ -492,15 +482,15 @@ impl ClientStateManager for StateManager { // Even if we fail to send our action to the client, we need to add this action back to the // queue because it was remove earlier. - self.inner + inner .queued_actions .insert(arc_action_info.clone(), queued_action); - self.inner.queued_actions_set.insert(arc_action_info); - self.inner.tasks_or_workers_change_notify.notify_one(); + inner.queued_actions_set.insert(arc_action_info); + inner.tasks_change_notify.notify_one(); return Ok(result); } - self.inner.metrics.add_action_new_action_created.inc(); + inner.metrics.add_action_new_action_created.inc(); // Action needs to be added to queue or is not cacheable. 
let action_info = Arc::new(action_info); @@ -513,8 +503,8 @@ impl ClientStateManager for StateManager { let (tx, rx) = watch::channel(current_state.clone()); - self.inner.queued_actions_set.insert(action_info.clone()); - self.inner.queued_actions.insert( + inner.queued_actions_set.insert(action_info.clone()); + inner.queued_actions.insert( action_info.clone(), AwaitedAction { action_info, @@ -525,7 +515,7 @@ impl ClientStateManager for StateManager { worker_id: None, }, ); - self.inner.tasks_or_workers_change_notify.notify_one(); + inner.tasks_change_notify.notify_one(); return Ok(Arc::new(ClientActionStateResult::new(rx))); } @@ -533,17 +523,17 @@ impl ClientStateManager for StateManager { &self, filter: OperationFilter, ) -> Result { + let inner = self.inner.lock().await; // TODO(adams): Build out a proper filter for other fields for state, at the moment // this only supports the unique qualifier. let unique_qualifier = &filter .unique_qualifier .err_tip(|| "No unique qualifier provided")?; - let maybe_awaited_action = self - .inner + let maybe_awaited_action = inner .queued_actions_set .get(unique_qualifier) - .and_then(|action_info| self.inner.queued_actions.get(action_info)) - .or_else(|| self.inner.active_actions.get(unique_qualifier)); + .and_then(|action_info| inner.queued_actions.get(action_info)) + .or_else(|| inner.active_actions.get(unique_qualifier)); let Some(awaited_action) = maybe_awaited_action else { return Ok(Box::pin(stream::empty())); @@ -559,16 +549,17 @@ impl ClientStateManager for StateManager { #[async_trait] impl WorkerStateManager for StateManager { async fn update_operation( - &mut self, + &self, operation_id: OperationId, worker_id: WorkerId, action_stage: Result, ) -> Result<(), Error> { + let mut inner = self.inner.lock().await; match action_stage { Ok(action_stage) => { let action_info_hash_key = operation_id.unique_qualifier; if !action_stage.has_action_result() { - self.inner.metrics.update_action_missing_action_result.inc(); + inner.metrics.update_action_missing_action_result.inc(); event!( Level::ERROR, ?action_info_hash_key, @@ -580,12 +571,11 @@ impl WorkerStateManager for StateManager { Code::Internal, "Worker '{worker_id}' set the action_stage of running action {action_info_hash_key:?} to {action_stage:?}. Removing worker.", ); - self.immediate_evict_worker(&worker_id, err.clone()); + inner.immediate_evict_worker(&worker_id, err.clone()); return Err(err); } - let (action_info, mut running_action) = self - .inner + let (action_info, mut running_action) = inner .active_actions .remove_entry(&action_info_hash_key) .err_tip(|| { @@ -593,7 +583,7 @@ impl WorkerStateManager for StateManager { })?; if running_action.worker_id != Some(worker_id) { - self.inner.metrics.update_action_from_wrong_worker.inc(); + inner.metrics.update_action_from_wrong_worker.inc(); let err = match running_action.worker_id { Some(running_action_worker_id) => make_err!( @@ -614,18 +604,16 @@ impl WorkerStateManager for StateManager { "Got a result from a worker that should not be running the action, Removing worker" ); // First put it back in our active_actions or we will drop the task. 
- self.inner - .active_actions - .insert(action_info, running_action); - self.immediate_evict_worker(&worker_id, err.clone()); + inner.active_actions.insert(action_info, running_action); + inner.immediate_evict_worker(&worker_id, err.clone()); return Err(err); } - let send_result = StateManager::mutate_stage(&mut running_action, action_stage); + let send_result = mutate_stage(&mut running_action, action_stage); if !running_action.current_state.stage.is_finished() { if send_result.is_err() { - self.inner.metrics.update_action_no_more_listeners.inc(); + inner.metrics.update_action_no_more_listeners.inc(); event!( Level::WARN, ?action_info, @@ -635,36 +623,27 @@ impl WorkerStateManager for StateManager { } // If the operation is not finished it means the worker is still working on it, so put it // back or else we will lose track of the task. - self.inner - .active_actions - .insert(action_info, running_action); + inner.active_actions.insert(action_info, running_action); - self.inner.tasks_or_workers_change_notify.notify_one(); + inner.tasks_change_notify.notify_one(); return Ok(()); } // Keep in case this is asked for soon. - self.inner - .recently_completed_actions - .insert(CompletedAction { - completed_time: SystemTime::now(), - state: running_action.current_state, - }); + inner.recently_completed_actions.insert(CompletedAction { + completed_time: SystemTime::now(), + state: running_action.current_state, + }); - let worker = self - .inner - .workers - .workers - .get_mut(&worker_id) - .ok_or_else(|| { - make_input_err!("WorkerId '{}' does not exist in workers map", worker_id) - })?; + let worker = inner.workers.workers.get_mut(&worker_id).ok_or_else(|| { + make_input_err!("WorkerId '{}' does not exist in workers map", worker_id) + })?; worker.complete_action(&action_info); - self.inner.tasks_or_workers_change_notify.notify_one(); + inner.tasks_change_notify.notify_one(); Ok(()) } Err(e) => { - self.update_action_with_internal_error( + inner.update_action_with_internal_error( &worker_id, operation_id.unique_qualifier, e.clone(), @@ -681,9 +660,10 @@ impl MatchingEngineStateManager for StateManager { &self, _filter: OperationFilter, // TODO(adam): reference filter ) -> Result { + let inner = self.inner.lock().await; // TODO(adams): use OperationFilter vs directly encoding it. 
let action_infos = - self.inner + inner .queued_actions .iter() .rev() @@ -700,21 +680,20 @@ impl MatchingEngineStateManager for StateManager { } async fn update_operation( - &mut self, + &self, operation_id: OperationId, worker_id: Option, action_stage: Result, ) -> Result<(), Error> { - if let Some(action_info) = self - .inner - .queued_actions_set - .get(&operation_id.unique_qualifier) - { + let mut inner = self.inner.lock().await; + if let Some(action_info) = inner.queued_actions_set.get(&operation_id.unique_qualifier) { if let Some(worker_id) = worker_id { let action_info = action_info.clone(); - self.worker_notify_run_action(worker_id, action_info.clone()) + inner + .worker_notify_run_action(worker_id, action_info.clone()) .await?; - self.worker_set_as_active(action_info, worker_id, action_stage) + inner + .worker_set_as_active(action_info, worker_id, action_stage) .await?; } else { event!( diff --git a/nativelink-scheduler/src/simple_scheduler.rs b/nativelink-scheduler/src/simple_scheduler.rs index 4a651fe5a..33f9e0122 100644 --- a/nativelink-scheduler/src/simple_scheduler.rs +++ b/nativelink-scheduler/src/simple_scheduler.rs @@ -20,7 +20,7 @@ use std::time::{Instant, SystemTime}; use async_lock::{Mutex, MutexGuard}; use async_trait::async_trait; -use futures::{Future, Stream}; +use futures::{Future, Stream, TryFutureExt}; use hashbrown::{HashMap, HashSet}; use nativelink_error::{make_err, make_input_err, Code, Error, ResultExt}; use nativelink_util::action_messages::{ @@ -46,7 +46,7 @@ use crate::operation_state_manager::{ }; use crate::platform_property_manager::PlatformPropertyManager; use crate::scheduler_state::metrics::Metrics as SchedulerMetrics; -use crate::scheduler_state::state_manager::StateManager; +use crate::scheduler_state::state_manager::{mutate_stage, StateManager, StateManagerImpl}; use crate::scheduler_state::workers::Workers; use crate::worker::{Worker, WorkerTimestamp, WorkerUpdate}; use crate::worker_scheduler::WorkerScheduler; @@ -90,31 +90,36 @@ impl SimpleSchedulerImpl { add_action_result.as_receiver().await.cloned() } - fn clean_recently_completed_actions(&mut self) { + async fn clean_recently_completed_actions(&mut self) { let expiry_time = SystemTime::now() .checked_sub(self.retain_completed_for) .unwrap(); self.state_manager .inner + .lock() + .await .recently_completed_actions .retain(|action| action.completed_time > expiry_time); } - fn find_recently_completed_action( + async fn find_recently_completed_action( &self, unique_qualifier: &ActionInfoHashKey, - ) -> Option>> { - self.state_manager + ) -> Result>>, Error> { + Ok(self + .state_manager .inner + .lock() + .await .recently_completed_actions .get(unique_qualifier) - .map(|action| watch::channel(action.state.clone()).1) + .map(|action| watch::channel(action.state.clone()).1)) } async fn find_existing_action( &self, unique_qualifier: &ActionInfoHashKey, - ) -> Option>> { + ) -> Result>>, Error> { let filter_result = ::filter_operations( &self.state_manager, OperationFilter { @@ -131,22 +136,36 @@ impl SimpleSchedulerImpl { ) .await; - let mut stream = filter_result.ok()?; + let mut stream = filter_result + .err_tip(|| "In SimpleScheduler::find_existing_action getting filter result")?; if let Some(result) = stream.next().await { - result.as_receiver().await.ok().cloned() + Ok(Some( + result + .as_receiver() + .await + .err_tip(|| "In SimpleScheduler::find_existing_action getting receiver")? 
+ .clone(), + )) } else { - None + Ok(None) } } - fn retry_action(&mut self, action_info: &Arc, worker_id: &WorkerId, err: Error) { - match self.state_manager.inner.active_actions.remove(action_info) { + fn retry_action( + inner_state: &mut MutexGuard<'_, StateManagerImpl>, + max_job_retries: usize, + metrics: &Metrics, + action_info: &Arc, + worker_id: &WorkerId, + err: Error, + ) { + match inner_state.active_actions.remove(action_info) { Some(running_action) => { let mut awaited_action = running_action; - let send_result = if awaited_action.attempts >= self.max_job_retries { - self.metrics.retry_action_max_attempts_reached.inc(); + let send_result = if awaited_action.attempts >= max_job_retries { + metrics.retry_action_max_attempts_reached.inc(); - StateManager::mutate_stage(&mut awaited_action, ActionStage::Completed(ActionResult { + mutate_stage(&mut awaited_action, ActionStage::Completed(ActionResult { execution_metadata: ExecutionMetadata { worker: format!("{worker_id}"), ..ExecutionMetadata::default() @@ -160,22 +179,17 @@ impl SimpleSchedulerImpl { // Do not put the action back in the queue here, as this action attempted to run too many // times. } else { - self.metrics.retry_action.inc(); - let send_result = - StateManager::mutate_stage(&mut awaited_action, ActionStage::Queued); - self.state_manager - .inner - .queued_actions_set - .insert(action_info.clone()); - self.state_manager - .inner + metrics.retry_action.inc(); + let send_result = mutate_stage(&mut awaited_action, ActionStage::Queued); + inner_state.queued_actions_set.insert(action_info.clone()); + inner_state .queued_actions .insert(action_info.clone(), awaited_action); send_result }; if send_result.is_err() { - self.metrics.retry_action_no_more_listeners.inc(); + metrics.retry_action_no_more_listeners.inc(); // Don't remove this task, instead we keep them around for a bit just in case // the client disconnected and will reconnect and ask for same job to be executed // again. @@ -188,7 +202,7 @@ impl SimpleSchedulerImpl { } } None => { - self.metrics.retry_action_but_action_missing.inc(); + metrics.retry_action_but_action_missing.inc(); event!( Level::ERROR, ?action_info, @@ -200,40 +214,50 @@ impl SimpleSchedulerImpl { } /// Evicts the worker from the pool and puts items back into the queue if anything was being executed on it. - fn immediate_evict_worker(&mut self, worker_id: &WorkerId, err: Error) { - if let Some(mut worker) = self.state_manager.inner.workers.remove_worker(worker_id) { - self.metrics.workers_evicted.inc(); + fn immediate_evict_worker( + inner_state: &mut MutexGuard<'_, StateManagerImpl>, + max_job_retries: usize, + metrics: &Metrics, + worker_id: &WorkerId, + err: Error, + ) { + if let Some(mut worker) = inner_state.workers.remove_worker(worker_id) { + metrics.workers_evicted.inc(); // We don't care if we fail to send message to worker, this is only a best attempt. let _ = worker.notify_update(WorkerUpdate::Disconnect); // We create a temporary Vec to avoid doubt about a possible code // path touching the worker.running_action_infos elsewhere. for action_info in worker.running_action_infos.drain() { - self.metrics.workers_evicted_with_running_action.inc(); - self.retry_action(&action_info, worker_id, err.clone()); + metrics.workers_evicted_with_running_action.inc(); + SimpleSchedulerImpl::retry_action( + inner_state, + max_job_retries, + metrics, + &action_info, + worker_id, + err.clone(), + ); } } // Note: Calling this many time is very cheap, it'll only trigger `do_try_match` once. 
- self.state_manager - .inner - .tasks_or_workers_change_notify - .notify_one(); + inner_state.tasks_change_notify.notify_one(); } /// Sets if the worker is draining or not. - fn set_drain_worker(&mut self, worker_id: WorkerId, is_draining: bool) -> Result<(), Error> { - let worker = self - .state_manager - .inner + async fn set_drain_worker( + &mut self, + worker_id: WorkerId, + is_draining: bool, + ) -> Result<(), Error> { + let mut inner_state = self.state_manager.inner.lock().await; + let worker = inner_state .workers .workers .get_mut(&worker_id) .err_tip(|| format!("Worker {worker_id} doesn't exist in the pool"))?; self.metrics.workers_drained.inc(); worker.is_draining = is_draining; - self.state_manager - .inner - .tasks_or_workers_change_notify - .notify_one(); + inner_state.tasks_change_notify.notify_one(); Ok(()) } @@ -297,15 +321,16 @@ impl SimpleSchedulerImpl { }; let maybe_worker_id: Option = { - self.state_manager - .inner + let inner_state = self.state_manager.inner.lock().await; + + inner_state .workers .find_worker_for_action(&action_info.platform_properties) }; let operation_id = state.id.clone(); let ret = ::update_operation( - &mut self.state_manager, + &self.state_manager, operation_id.clone(), maybe_worker_id, Ok(ActionStage::Executing), @@ -335,7 +360,7 @@ impl SimpleSchedulerImpl { action_stage: Result, ) -> Result<(), Error> { let update_operation_result = ::update_operation( - &mut self.state_manager, + &self.state_manager, OperationId::new(action_info_hash_key.clone()), *worker_id, action_stage, @@ -410,7 +435,7 @@ impl SimpleScheduler { max_job_retries = DEFAULT_MAX_JOB_RETRIES; } - let tasks_or_workers_change_notify = Arc::new(Notify::new()); + let tasks_change_notify = Arc::new(Notify::new()); let state_manager = StateManager::new( HashSet::new(), BTreeMap::new(), @@ -419,7 +444,7 @@ impl SimpleScheduler { HashSet::new(), Arc::new(SchedulerMetrics::default()), max_job_retries, - tasks_or_workers_change_notify.clone(), + tasks_change_notify.clone(), ); let metrics = Arc::new(Metrics::default()); let metrics_for_do_try_match = metrics.clone(); @@ -439,7 +464,7 @@ impl SimpleScheduler { async move { // Break out of the loop only when the inner is dropped. loop { - tasks_or_workers_change_notify.notified().await; + tasks_change_notify.notified().await; match weak_inner.upgrade() { // Note: According to `parking_lot` documentation, the default // `Mutex` implementation is eventual fairness, so we don't @@ -467,13 +492,9 @@ impl SimpleScheduler { /// Checks to see if the worker exists in the worker pool. Should only be used in unit tests. #[must_use] pub async fn contains_worker_for_test(&self, worker_id: &WorkerId) -> bool { - let inner = self.get_inner_lock().await; - inner - .state_manager - .inner - .workers - .workers - .contains(worker_id) + let inner_scheduler = self.get_inner_lock().await; + let inner_state = inner_scheduler.state_manager.inner.lock().await; + inner_state.workers.workers.contains(worker_id) } /// A unit test function used to send the keep alive message to the worker from the server. 
@@ -481,10 +502,9 @@ impl SimpleScheduler { &self, worker_id: &WorkerId, ) -> Result<(), Error> { - let mut inner = self.get_inner_lock().await; - let worker = inner - .state_manager - .inner + let inner_scheduler = self.get_inner_lock().await; + let mut inner_state = inner_scheduler.state_manager.inner.lock().await; + let worker = inner_state .workers .workers .get_mut(worker_id) @@ -532,24 +552,32 @@ impl ActionScheduler for SimpleScheduler { async fn find_existing_action( &self, unique_qualifier: &ActionInfoHashKey, - ) -> Option>> { + ) -> Result>>, Error> { let inner = self.get_inner_lock().await; - let result = inner + let maybe_receiver = inner .find_existing_action(unique_qualifier) + .and_then(|maybe_action| async { + if let Some(action) = maybe_action { + Ok(Some(action)) + } else { + inner.find_recently_completed_action(unique_qualifier).await + } + }) .await - .or_else(|| inner.find_recently_completed_action(unique_qualifier)); - if result.is_some() { + .err_tip(|| "Error while finding existing action")?; + if maybe_receiver.is_some() { self.metrics.existing_actions_found.inc(); } else { self.metrics.existing_actions_not_found.inc(); } - result + Ok(maybe_receiver) } async fn clean_recently_completed_actions(&self) { self.get_inner_lock() .await - .clean_recently_completed_actions(); + .clean_recently_completed_actions() + .await; self.metrics.clean_recently_completed_actions.inc() } @@ -566,22 +594,24 @@ impl WorkerScheduler for SimpleScheduler { async fn add_worker(&self, worker: Worker) -> Result<(), Error> { let worker_id = worker.id; - let mut inner = self.get_inner_lock().await; + let inner_scheduler = self.get_inner_lock().await; + let max_job_retries = inner_scheduler.max_job_retries; + let mut inner_state = inner_scheduler.state_manager.inner.lock().await; self.metrics.add_worker.wrap(move || { - let res = inner - .state_manager - .inner + let res = inner_state .workers .add_worker(worker) .err_tip(|| "Error while adding worker, removing from pool"); if let Err(err) = &res { - inner.immediate_evict_worker(&worker_id, err.clone()); + SimpleSchedulerImpl::immediate_evict_worker( + &mut inner_state, + max_job_retries, + &self.metrics, + &worker_id, + err.clone(), + ); } - inner - .state_manager - .inner - .tasks_or_workers_change_notify - .notify_one(); + inner_state.tasks_change_notify.notify_one(); res }) } @@ -604,37 +634,42 @@ impl WorkerScheduler for SimpleScheduler { worker_id: &WorkerId, timestamp: WorkerTimestamp, ) -> Result<(), Error> { - let mut inner = self.get_inner_lock().await; - inner - .state_manager - .inner + let inner_scheduler = self.get_inner_lock().await; + let mut inner_state = inner_scheduler.state_manager.inner.lock().await; + inner_state .workers .refresh_lifetime(worker_id, timestamp) .err_tip(|| "Error refreshing lifetime in worker_keep_alive_received()") } async fn remove_worker(&self, worker_id: WorkerId) { - let mut inner = self.get_inner_lock().await; - inner.immediate_evict_worker( + let inner_scheduler = self.get_inner_lock().await; + let mut inner_state = inner_scheduler.state_manager.inner.lock().await; + SimpleSchedulerImpl::immediate_evict_worker( + &mut inner_state, + inner_scheduler.max_job_retries, + &inner_scheduler.metrics, &worker_id, make_err!(Code::Internal, "Received request to remove worker"), ); } async fn remove_timedout_workers(&self, now_timestamp: WorkerTimestamp) -> Result<(), Error> { - let mut inner = self.get_inner_lock().await; + let inner_scheduler = self.get_inner_lock().await; + let worker_timeout_s = 
inner_scheduler.worker_timeout_s; + let max_job_retries = inner_scheduler.max_job_retries; + let metrics = inner_scheduler.metrics.clone(); + let mut inner_state = inner_scheduler.state_manager.inner.lock().await; self.metrics.remove_timedout_workers.wrap(move || { // Items should be sorted based on last_update_timestamp, so we don't need to iterate the entire // map most of the time. - let worker_ids_to_remove: Vec = inner - .state_manager - .inner + let worker_ids_to_remove: Vec = inner_state .workers .workers .iter() .rev() .map_while(|(worker_id, worker)| { - if worker.last_update_timestamp <= now_timestamp - inner.worker_timeout_s { + if worker.last_update_timestamp <= now_timestamp - worker_timeout_s { Some(*worker_id) } else { None @@ -647,7 +682,10 @@ impl WorkerScheduler for SimpleScheduler { ?worker_id, "Worker timed out, removing from pool" ); - inner.immediate_evict_worker( + SimpleSchedulerImpl::immediate_evict_worker( + &mut inner_state, + max_job_retries, + &metrics, worker_id, make_err!( Code::Internal, @@ -662,7 +700,7 @@ impl WorkerScheduler for SimpleScheduler { async fn set_drain_worker(&self, worker_id: WorkerId, is_draining: bool) -> Result<(), Error> { let mut inner = self.get_inner_lock().await; - inner.set_drain_worker(worker_id, is_draining) + inner.set_drain_worker(worker_id, is_draining).await } fn register_metrics(self: Arc, _registry: &mut Registry) { @@ -676,45 +714,46 @@ impl MetricsComponent for SimpleScheduler { self.metrics.gather_metrics(c); { // We use the raw lock because we dont gather stats about gathering stats. - let inner = self.inner.lock_blocking(); - inner.state_manager.inner.metrics.gather_metrics(c); + let inner_scheduler = self.inner.lock_blocking(); + let inner_state = inner_scheduler.state_manager.inner.lock_blocking(); + inner_state.metrics.gather_metrics(c); c.publish( "queued_actions_total", - &inner.state_manager.inner.queued_actions.len(), + &inner_state.queued_actions.len(), "The number actions in the queue.", ); c.publish( "workers_total", - &inner.state_manager.inner.workers.workers.len(), + &inner_state.workers.workers.len(), "The number workers active.", ); c.publish( "active_actions_total", - &inner.state_manager.inner.active_actions.len(), + &inner_state.active_actions.len(), "The number of running actions.", ); c.publish( "recently_completed_actions_total", - &inner.state_manager.inner.recently_completed_actions.len(), + &inner_state.recently_completed_actions.len(), "The number of recently completed actions in the buffer.", ); c.publish( "retain_completed_for_seconds", - &inner.retain_completed_for, + &inner_scheduler.retain_completed_for, "The duration completed actions are retained for.", ); c.publish( "worker_timeout_seconds", - &inner.worker_timeout_s, + &inner_scheduler.worker_timeout_s, "The configured timeout if workers have not responded for a while.", ); c.publish( "max_job_retries", - &inner.max_job_retries, + &inner_scheduler.max_job_retries, "The amount of times a job is allowed to retry from an internal error before it is dropped.", ); let mut props = HashMap::<&String, u64>::new(); - for (_worker_id, worker) in inner.state_manager.inner.workers.workers.iter() { + for (_worker_id, worker) in inner_state.workers.workers.iter() { c.publish_with_labels( "workers", worker, @@ -735,7 +774,7 @@ impl MetricsComponent for SimpleScheduler { format!("Total sum of available properties for {property}"), ); } - for (_, active_action) in inner.state_manager.inner.active_actions.iter() { + for (_, active_action) in 
inner_state.active_actions.iter() { let action_name = active_action .action_info .unique_qualifier diff --git a/nativelink-scheduler/tests/cache_lookup_scheduler_test.rs b/nativelink-scheduler/tests/cache_lookup_scheduler_test.rs index 709313078..db26482cb 100644 --- a/nativelink-scheduler/tests/cache_lookup_scheduler_test.rs +++ b/nativelink-scheduler/tests/cache_lookup_scheduler_test.rs @@ -118,9 +118,9 @@ async fn find_existing_action_call_passed() -> Result<(), Error> { }; let (actual_result, actual_action_name) = join!( context.cache_scheduler.find_existing_action(&action_name), - context.mock_scheduler.expect_find_existing_action(None), + context.mock_scheduler.expect_find_existing_action(Ok(None)), ); - assert_eq!(true, actual_result.is_none()); + assert_eq!(true, actual_result.unwrap().is_none()); assert_eq!(action_name, actual_action_name); Ok(()) } diff --git a/nativelink-scheduler/tests/property_modifier_scheduler_test.rs b/nativelink-scheduler/tests/property_modifier_scheduler_test.rs index 96af1fb27..58b91c50b 100644 --- a/nativelink-scheduler/tests/property_modifier_scheduler_test.rs +++ b/nativelink-scheduler/tests/property_modifier_scheduler_test.rs @@ -251,9 +251,9 @@ async fn find_existing_action_call_passed() -> Result<(), Error> { context .modifier_scheduler .find_existing_action(&action_name), - context.mock_scheduler.expect_find_existing_action(None), + context.mock_scheduler.expect_find_existing_action(Ok(None)), ); - assert_eq!(true, actual_result.is_none()); + assert_eq!(true, actual_result.unwrap().is_none()); assert_eq!(action_name, actual_action_name); Ok(()) } diff --git a/nativelink-scheduler/tests/simple_scheduler_test.rs b/nativelink-scheduler/tests/simple_scheduler_test.rs index 1b4bfcb41..a7993831b 100644 --- a/nativelink-scheduler/tests/simple_scheduler_test.rs +++ b/nativelink-scheduler/tests/simple_scheduler_test.rs @@ -179,7 +179,8 @@ async fn find_executing_action() -> Result<(), Error> { let mut client_rx = scheduler .find_existing_action(&unique_qualifier) .await - .err_tip(|| "Action not found")?; + .expect("Action not found") + .unwrap(); { // Worker should have been sent an execute command. @@ -955,7 +956,8 @@ async fn update_action_sends_completed_result_after_disconnect() -> Result<(), E let mut client_rx = scheduler .find_existing_action(&unique_qualifier) .await - .err_tip(|| "Action not found")?; + .unwrap() + .expect("Action not found"); { // Client should get notification saying it has been completed. 
let action_state = client_rx.borrow_and_update(); diff --git a/nativelink-scheduler/tests/utils/mock_scheduler.rs b/nativelink-scheduler/tests/utils/mock_scheduler.rs index 9afd1dd6b..b4ca3c2f8 100644 --- a/nativelink-scheduler/tests/utils/mock_scheduler.rs +++ b/nativelink-scheduler/tests/utils/mock_scheduler.rs @@ -31,7 +31,7 @@ enum ActionSchedulerCalls { enum ActionSchedulerReturns { GetPlatformPropertyManager(Result, Error>), AddAction(Result>, Error>), - FindExistingAction(Option>>), + FindExistingAction(Result>>, Error>), } pub struct MockActionScheduler { @@ -100,7 +100,7 @@ impl MockActionScheduler { pub async fn expect_find_existing_action( &self, - result: Option>>, + result: Result>>, Error>, ) -> ActionInfoHashKey { let mut rx_call_lock = self.rx_call.lock().await; let ActionSchedulerCalls::FindExistingAction(req) = rx_call_lock @@ -161,7 +161,7 @@ impl ActionScheduler for MockActionScheduler { async fn find_existing_action( &self, unique_qualifier: &ActionInfoHashKey, - ) -> Option>> { + ) -> Result>>, Error> { self.tx_call .send(ActionSchedulerCalls::FindExistingAction( unique_qualifier.clone(), diff --git a/nativelink-service/src/execution_server.rs b/nativelink-service/src/execution_server.rs index a42f0ef03..b385f1db6 100644 --- a/nativelink-service/src/execution_server.rs +++ b/nativelink-service/src/execution_server.rs @@ -253,6 +253,7 @@ impl ExecutionServer { .scheduler .find_existing_action(&operation_id.unique_qualifier) .await + .err_tip(|| "Error running find_existing_action in ExecutionServer::wait_execution")? else { return Err(Status::not_found("Failed to find existing task")); };
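The core shape of this change, stripped of nativelink specifics: all mutable scheduler state moves behind an async mutex, so that the state-manager traits (ClientStateManager, WorkerStateManager, MatchingEngineStateManager) can take &self instead of &mut self, with each method locking internally only for the duration of its own mutation. The sketch below shows that pattern under simplified assumptions: the trait and notify names echo the patch, but InnerState, the `queued` field, and the String payload are illustrative stand-ins, not the actual nativelink definitions or signatures.

use std::sync::Arc;

use async_lock::Mutex;
use async_trait::async_trait;
use tokio::sync::Notify;

// Stand-in for StateManagerImpl: everything mutable lives in here.
struct InnerState {
    queued: Vec<String>,
    tasks_change_notify: Arc<Notify>,
}

// Stand-in for StateManager: its only field is the mutex-wrapped state.
struct StateManager {
    inner: Mutex<InnerState>,
}

#[async_trait]
trait ClientStateManager: Send + Sync {
    // `&self`, not `&mut self`: callers no longer need exclusive access,
    // which is what allows a future distributed/Redis backend to implement
    // the same trait without a single in-process mutable owner.
    async fn add_action(&self, action: String);
}

#[async_trait]
impl ClientStateManager for StateManager {
    async fn add_action(&self, action: String) {
        // Lock only while mutating, then wake the matching engine,
        // mirroring the `tasks_change_notify.notify_one()` calls in the patch.
        let mut inner = self.inner.lock().await;
        inner.queued.push(action);
        inner.tasks_change_notify.notify_one();
    }
}

The same direction presumably motivates changing ActionScheduler::find_existing_action from returning an Option to returning a Result wrapping an Option: once the lookup goes through a state manager (and eventually a network-backed one), it can fail for reasons other than "not found", so errors are propagated to callers rather than being collapsed into None, as the execution_server.rs and test changes above reflect.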