diff --git a/Cargo.lock b/Cargo.lock index c0c945860..f943dc7ab 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2053,6 +2053,7 @@ dependencies = [ "nativelink-macro", "nativelink-proto", "parking_lot", + "pin-project", "pin-project-lite", "pretty_assertions", "prometheus-client", diff --git a/nativelink-scheduler/BUILD.bazel b/nativelink-scheduler/BUILD.bazel index 87edb6b4e..68e9f331f 100644 --- a/nativelink-scheduler/BUILD.bazel +++ b/nativelink-scheduler/BUILD.bazel @@ -15,18 +15,15 @@ rust_library( "src/default_scheduler_factory.rs", "src/grpc_scheduler.rs", "src/lib.rs", - "src/operation_state_manager.rs", "src/platform_property_manager.rs", + "src/memory_awaited_action_db.rs", + "src/awaited_action_db/awaited_action.rs", + "src/awaited_action_db/mod.rs", + "src/simple_scheduler_state_manager.rs", "src/property_modifier_scheduler.rs", "src/redis_action_stage.rs", + "src/api_worker_scheduler.rs", "src/redis_operation_state.rs", - "src/scheduler_state/awaited_action.rs", - "src/scheduler_state/client_action_state_result.rs", - "src/scheduler_state/completed_action.rs", - "src/scheduler_state/matching_engine_action_state_result.rs", - "src/scheduler_state/mod.rs", - "src/scheduler_state/state_manager.rs", - "src/scheduler_state/workers.rs", "src/simple_scheduler.rs", "src/worker.rs", "src/worker_scheduler.rs", @@ -42,7 +39,6 @@ rust_library( "//nativelink-store", "//nativelink-util", "@crates//:async-lock", - "@crates//:bitflags", "@crates//:blake3", "@crates//:futures", "@crates//:hashbrown", diff --git a/nativelink-scheduler/src/awaited_action_db/awaited_action.rs b/nativelink-scheduler/src/awaited_action_db/awaited_action.rs new file mode 100644 index 000000000..0cc181a33 --- /dev/null +++ b/nativelink-scheduler/src/awaited_action_db/awaited_action.rs @@ -0,0 +1,224 @@ +// Copyright 2024 The NativeLink Authors. All rights reserved. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::hash::{DefaultHasher, Hash, Hasher}; +use std::sync::Arc; +use std::time::SystemTime; + +use nativelink_util::action_messages::{ + ActionInfo, ActionStage, ActionState, ActionUniqueKey, ActionUniqueQualifier, OperationId, + WorkerId, +}; +use nativelink_util::evicting_map::InstantWrapper; +use static_assertions::{assert_eq_size, const_assert, const_assert_eq}; + +/// An action that is being awaited on and last known state. +#[derive(Debug, Clone)] +pub struct AwaitedAction { + /// The action that is being awaited on. + action_info: Arc, + + /// The operation id of the action. + operation_id: OperationId, + + /// The current sort key used to order the actions. + sort_key: AwaitedActionSortKey, + + /// The time the action was last updated. + last_worker_updated_timestamp: SystemTime, + + /// Worker that is currently running this action, None if unassigned. + worker_id: Option, + + /// The current state of the action. + state: Arc, + + /// Number of attempts the job has been tried. + pub attempts: usize, + + /// Number of clients listening to the state of the action. 
+ pub connected_clients: usize, +} + +impl AwaitedAction { + pub fn new(operation_id: OperationId, action_info: Arc) -> Self { + let unique_key = match &action_info.unique_qualifier { + ActionUniqueQualifier::Cachable(unique_key) => unique_key, + ActionUniqueQualifier::Uncachable(unique_key) => unique_key, + }; + let stage = ActionStage::Queued; + let sort_key = AwaitedActionSortKey::new_with_unique_key( + action_info.priority, + &action_info.insert_timestamp, + unique_key, + ); + let state = Arc::new(ActionState { + stage, + id: operation_id.clone(), + }); + Self { + action_info, + operation_id, + sort_key, + attempts: 0, + last_worker_updated_timestamp: SystemTime::now(), + connected_clients: 1, + worker_id: None, + state, + } + } + + pub fn action_info(&self) -> &Arc { + &self.action_info + } + + pub fn operation_id(&self) -> &OperationId { + &self.operation_id + } + + pub fn sort_key(&self) -> AwaitedActionSortKey { + self.sort_key + } + + pub fn state(&self) -> &Arc { + &self.state + } + + pub fn worker_id(&self) -> Option { + self.worker_id + } + + pub fn last_worker_updated_timestamp(&self) -> SystemTime { + self.last_worker_updated_timestamp + } + + /// Sets the worker id that is currently processing this action. + pub fn set_worker_id(&mut self, new_maybe_worker_id: Option) { + if self.worker_id != new_maybe_worker_id { + self.worker_id = new_maybe_worker_id; + self.last_worker_updated_timestamp = SystemTime::now(); + } + } + + /// Replaces the current state of the action and refreshes the last worker + /// updated timestamp. Does not notify subscribers and returns nothing. + pub fn set_state(&mut self, mut state: Arc) { + std::mem::swap(&mut self.state, &mut state); + self.last_worker_updated_timestamp = SystemTime::now(); + } +} + +/// The key used to sort the awaited actions. +/// +/// The rules for sorting are as follows: +/// 1. priority of the action +/// 2. insert order of the action (lower = higher priority) +/// 3. 
(mostly random hash based on the action info) +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] +#[repr(transparent)] +pub struct AwaitedActionSortKey(u128); + +impl AwaitedActionSortKey { + #[rustfmt::skip] + const fn new(priority: i32, insert_timestamp: u64, hash: [u8; 4]) -> Self { + // Shift `priority` so [`i32::MIN`] is represented by zero. + // This makes it so any negative values are positive, but + // maintains ordering. + const MIN_I32: i64 = (i32::MIN as i64).abs(); + let priority = ((priority as i64 + MIN_I32) as u32).to_be_bytes(); + + // Invert our timestamp so the larger the timestamp the lower the number. + // This makes timestamp descending order instead of ascending. + let timestamp = (insert_timestamp ^ u64::MAX).to_be_bytes(); + + AwaitedActionSortKey(u128::from_be_bytes([ + priority[0], priority[1], priority[2], priority[3], + timestamp[0], timestamp[1], timestamp[2], timestamp[3], + timestamp[4], timestamp[5], timestamp[6], timestamp[7], + hash[0], hash[1], hash[2], hash[3] + ])) + } + + fn new_with_unique_key( + priority: i32, + insert_timestamp: &SystemTime, + action_hash: &ActionUniqueKey, + ) -> Self { + let hash = { + let mut hasher = DefaultHasher::new(); + ActionUniqueKey::hash(action_hash, &mut hasher); + hasher.finish().to_le_bytes()[0..4].try_into().unwrap() + }; + Self::new(priority, insert_timestamp.unix_timestamp(), hash) + } +} + +// Ensure the size of the sort key is the same as a `u128`. +assert_eq_size!(AwaitedActionSortKey, u128); + +const_assert_eq!( + AwaitedActionSortKey::new(0x1234_5678, 0x9abc_def0_1234_5678, [0x9a, 0xbc, 0xde, 0xf0]).0, + // Note: Result has 0x12345678 + 0x80000000 = 0x92345678 because we need + // to shift the `i32::MIN` value to be represented by zero. + // Note: `6543210fedcba987` are the inverted bits of `9abcdef012345678`. + // This effectively inverts the priority to now have the highest priority + // be the lowest timestamps. 
+ AwaitedActionSortKey(0x9234_5678_6543_210f_edcb_a987_9abc_def0).0 +); +// Ensure the priority is used as the sort key first. +const_assert!( + AwaitedActionSortKey::new(i32::MAX, 0, [0xff; 4]).0 + > AwaitedActionSortKey::new(i32::MAX - 1, 0, [0; 4]).0 +); +const_assert!( + AwaitedActionSortKey::new(i32::MAX - 1, 0, [0xff; 4]).0 + > AwaitedActionSortKey::new(1, 0, [0; 4]).0 +); +const_assert!( + AwaitedActionSortKey::new(1, 0, [0xff; 4]).0 > AwaitedActionSortKey::new(0, 0, [0; 4]).0 +); +const_assert!( + AwaitedActionSortKey::new(0, 0, [0xff; 4]).0 > AwaitedActionSortKey::new(-1, 0, [0; 4]).0 +); +const_assert!( + AwaitedActionSortKey::new(-1, 0, [0xff; 4]).0 + > AwaitedActionSortKey::new(i32::MIN + 1, 0, [0; 4]).0 +); +const_assert!( + AwaitedActionSortKey::new(i32::MIN + 1, 0, [0xff; 4]).0 + > AwaitedActionSortKey::new(i32::MIN, 0, [0; 4]).0 +); + +// Ensure the insert timestamp is used as the sort key second. +const_assert!( + AwaitedActionSortKey::new(0, u64::MIN, [0; 4]).0 + > AwaitedActionSortKey::new(0, u64::MAX, [0; 4]).0 +); + +// Ensure the hash is used as the sort key third. +const_assert!( + AwaitedActionSortKey::new(0, 0, [0xff, 0xff, 0xff, 0xff]).0 + > AwaitedActionSortKey::new(0, 0, [0; 4]).0 +); +const_assert!( + AwaitedActionSortKey::new(1, 0, [0xff, 0xff, 0xff, 0xff]).0 + > AwaitedActionSortKey::new(0, 0, [0; 4]).0 +); +const_assert!( + AwaitedActionSortKey::new(0, 0, [0; 4]).0 > AwaitedActionSortKey::new(0, 1, [0; 4]).0 +); +const_assert!( + AwaitedActionSortKey::new(0, 0, [0xff, 0xff, 0xff, 0xff]).0 + > AwaitedActionSortKey::new(0, 0, [0; 4]).0 +); diff --git a/nativelink-scheduler/src/awaited_action_db/mod.rs b/nativelink-scheduler/src/awaited_action_db/mod.rs new file mode 100644 index 000000000..7878e9e93 --- /dev/null +++ b/nativelink-scheduler/src/awaited_action_db/mod.rs @@ -0,0 +1,120 @@ +// Copyright 2024 The NativeLink Authors. All rights reserved. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::cmp; +use std::ops::Bound; +use std::sync::Arc; + +pub use awaited_action::{AwaitedAction, AwaitedActionSortKey}; +use futures::{Future, Stream}; +use nativelink_error::Error; +use nativelink_util::action_messages::{ActionInfo, ClientOperationId, OperationId}; + +mod awaited_action; + +/// A simple enum to represent the state of an AwaitedAction. +#[derive(Debug, Clone, Copy)] +pub enum SortedAwaitedActionState { + CacheCheck, + Queued, + Executing, + Completed, +} + +/// A struct pointing to an AwaitedAction that can be sorted. +#[derive(Debug, Clone)] +pub struct SortedAwaitedAction { + pub sort_key: AwaitedActionSortKey, + pub operation_id: OperationId, +} + +impl PartialEq for SortedAwaitedAction { + fn eq(&self, other: &Self) -> bool { + self.sort_key == other.sort_key && self.operation_id == other.operation_id + } +} + +impl Eq for SortedAwaitedAction {} + +impl PartialOrd for SortedAwaitedAction { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl Ord for SortedAwaitedAction { + fn cmp(&self, other: &Self) -> cmp::Ordering { + self.sort_key + .cmp(&other.sort_key) + .then_with(|| self.operation_id.cmp(&other.operation_id)) + } +} + +/// Subscriber that can be used to monitor when AwaitedActions change. +pub trait AwaitedActionSubscriber: Send + Sync + Sized + 'static { + /// Wait for AwaitedAction to change. 
+ fn changed(&mut self) -> impl Future> + Send; + + /// Get the current awaited action. + fn borrow(&self) -> AwaitedAction; +} + +/// A trait that defines the interface for an AwaitedActionDb. +pub trait AwaitedActionDb: Send + Sync { + type Subscriber: AwaitedActionSubscriber; + + /// Get the AwaitedAction by the client operation id. + fn get_awaited_action_by_id( + &self, + client_operation_id: &ClientOperationId, + ) -> impl Future, Error>> + Send + Sync; + + /// Get all AwaitedActions. This call should be avoided as much as possible. + fn get_all_awaited_actions( + &self, + ) -> impl Future> + Send + Sync> + + Send + + Sync; + + /// Get the AwaitedAction by the operation id. + fn get_by_operation_id( + &self, + operation_id: &OperationId, + ) -> impl Future, Error>> + Send + Sync; + + /// Get a range of AwaitedActions of a specific state in sorted order. + fn get_range_of_actions( + &self, + state: SortedAwaitedActionState, + start: Bound, + end: Bound, + desc: bool, + ) -> impl Future> + Send + Sync> + + Send + + Sync; + + /// Process a changed AwaitedAction and notify any listeners. + fn update_awaited_action( + &self, + new_awaited_action: AwaitedAction, + ) -> impl Future> + Send + Sync; + + /// Add (or join) an action to the AwaitedActionDb and subscribe + /// to changes. 
+ fn add_action( + &self, + client_operation_id: ClientOperationId, + action_info: Arc, + ) -> impl Future> + Send + Sync; +} diff --git a/nativelink-scheduler/src/lib.rs b/nativelink-scheduler/src/lib.rs index c60675621..0c731ff09 100644 --- a/nativelink-scheduler/src/lib.rs +++ b/nativelink-scheduler/src/lib.rs @@ -14,15 +14,17 @@ pub mod action_scheduler; pub mod api_worker_scheduler; +mod awaited_action_db; pub mod cache_lookup_scheduler; pub mod default_action_listener; pub mod default_scheduler_factory; pub mod grpc_scheduler; -pub mod memory_scheduler_state; +mod memory_awaited_action_db; pub mod platform_property_manager; pub mod property_modifier_scheduler; pub mod redis_action_stage; pub mod redis_operation_state; pub mod simple_scheduler; +mod simple_scheduler_state_manager; pub mod worker; pub mod worker_scheduler; diff --git a/nativelink-scheduler/src/memory_awaited_action_db.rs b/nativelink-scheduler/src/memory_awaited_action_db.rs new file mode 100644 index 000000000..d454f9335 --- /dev/null +++ b/nativelink-scheduler/src/memory_awaited_action_db.rs @@ -0,0 +1,906 @@ +// Copyright 2024 The NativeLink Authors. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +use std::collections::{BTreeMap, BTreeSet, HashMap}; +use std::ops::{Bound, RangeBounds}; +use std::sync::Arc; +use std::time::{Duration, Instant, SystemTime}; + +use async_lock::Mutex; +use async_trait::async_trait; +use futures::{FutureExt, Stream}; +use nativelink_config::stores::EvictionPolicy; +use nativelink_error::{error_if, make_err, Code, Error, ResultExt}; +use nativelink_util::action_messages::{ + ActionInfo, ActionStage, ActionState, ActionUniqueKey, ActionUniqueQualifier, + ClientOperationId, OperationId, +}; +use nativelink_util::chunked_stream::ChunkedStream; +use nativelink_util::evicting_map::{EvictingMap, LenEntry}; +use nativelink_util::operation_state_manager::ActionStateResult; +use nativelink_util::spawn; +use nativelink_util::task::JoinHandleDropGuard; +use tokio::sync::{mpsc, watch}; +use tracing::{event, Level}; + +use crate::awaited_action_db::{ + AwaitedAction, AwaitedActionDb, AwaitedActionSubscriber, SortedAwaitedAction, + SortedAwaitedActionState, +}; + +/// Number of events to process per cycle. +const MAX_ACTION_EVENTS_RX_PER_CYCLE: usize = 1024; + +/// Duration to wait before sending client keep alive messages. +const CLIENT_KEEPALIVE_DURATION: Duration = Duration::from_secs(10); + +/// Represents a client that is currently listening to an action. +/// When the client is dropped, it will send the [`AwaitedAction`] to the +/// `drop_tx` if there are other cleanups needed. +#[derive(Debug)] +struct ClientAwaitedAction { + /// The OperationId that the client is listening to. + operation_id: OperationId, + + /// The sender to notify of this struct being dropped. 
+ drop_tx: mpsc::UnboundedSender, +} + +impl ClientAwaitedAction { + pub fn new(operation_id: OperationId, drop_tx: mpsc::UnboundedSender) -> Self { + Self { + operation_id, + drop_tx, + } + } + + pub fn operation_id(&self) -> &OperationId { + &self.operation_id + } +} + +impl Drop for ClientAwaitedAction { + fn drop(&mut self) { + // If we failed to send it means no one is listening. + let _ = self.drop_tx.send(ActionEvent::ClientDroppedOperation( + self.operation_id.clone(), + )); + } +} + +/// Trait to be able to use the EvictingMap with [`ClientAwaitedAction`]. +/// Note: We only use EvictingMap for a time based eviction, which is +/// why the implementation has fixed default values in it. +impl LenEntry for ClientAwaitedAction { + #[inline] + fn len(&self) -> usize { + 0 + } + + #[inline] + fn is_empty(&self) -> bool { + true + } +} + +/// Actions the AwaitedActionsDb needs to process. +pub(crate) enum ActionEvent { + /// A client has sent a keep alive message. + ClientKeepAlive(ClientOperationId), + /// A client has dropped and pointed to OperationId. + ClientDroppedOperation(OperationId), +} + +/// Information required to track an individual client +/// keep alive config and state. +struct ClientKeepAlive { + /// The client operation id. + client_operation_id: ClientOperationId, + /// The last time a keep alive was sent. + last_keep_alive: Instant, + /// The sender to notify of this struct being dropped. + drop_tx: mpsc::UnboundedSender, +} + +/// Subscriber that can be used to monitor when AwaitedActions change. +pub struct MemoryAwaitedActionSubscriber { + /// The receiver to listen for changes. + awaited_action_rx: watch::Receiver, + /// The client operation id and keep alive information. 
+ client_operation_info: Option, +} + +impl MemoryAwaitedActionSubscriber { + pub fn new(mut awaited_action_rx: watch::Receiver) -> Self { + awaited_action_rx.mark_changed(); + Self { + awaited_action_rx, + client_operation_info: None, + } + } + + pub fn new_with_client( + mut awaited_action_rx: watch::Receiver, + client_operation_id: ClientOperationId, + drop_tx: mpsc::UnboundedSender, + ) -> Self { + awaited_action_rx.mark_changed(); + Self { + awaited_action_rx, + client_operation_info: Some(ClientKeepAlive { + client_operation_id, + last_keep_alive: Instant::now(), + drop_tx, + }), + } + } +} + +impl AwaitedActionSubscriber for MemoryAwaitedActionSubscriber { + async fn changed(&mut self) -> Result { + { + let changed_fut = self.awaited_action_rx.changed().map(|r| { + r.map_err(|e| { + make_err!( + Code::Internal, + "Failed to wait for awaited action to change {e:?}" + ) + }) + }); + let Some(client_keep_alive) = self.client_operation_info.as_mut() else { + changed_fut.await?; + return Ok(self.awaited_action_rx.borrow().clone()); + }; + tokio::pin!(changed_fut); + loop { + if client_keep_alive.last_keep_alive.elapsed() > CLIENT_KEEPALIVE_DURATION { + client_keep_alive.last_keep_alive = Instant::now(); + // Failing to send just means our receiver dropped. + let _ = client_keep_alive.drop_tx.send(ActionEvent::ClientKeepAlive( + client_keep_alive.client_operation_id.clone(), + )); + } + tokio::select! { + result = &mut changed_fut => { + result?; + break; + } + _ = tokio::time::sleep(CLIENT_KEEPALIVE_DURATION) => { + // If we haven't received any updates for a while, we should + // let the database know that we are still listening to prevent + // the action from being dropped. 
+ } + + } + } + } + Ok(self.awaited_action_rx.borrow().clone()) + } + + fn borrow(&self) -> AwaitedAction { + self.awaited_action_rx.borrow().clone() + } +} + +pub struct MatchingEngineActionStateResult { + awaited_action_sub: T, +} +impl MatchingEngineActionStateResult { + pub fn new(awaited_action_sub: T) -> Self { + Self { awaited_action_sub } + } +} + +#[async_trait] +impl ActionStateResult for MatchingEngineActionStateResult { + async fn as_state(&self) -> Result, Error> { + Ok(self.awaited_action_sub.borrow().state().clone()) + } + + async fn changed(&mut self) -> Result, Error> { + let awaited_action = self.awaited_action_sub.changed().await.map_err(|e| { + make_err!( + Code::Internal, + "Failed to wait for awaited action to change {e:?}" + ) + })?; + Ok(awaited_action.state().clone()) + } + + async fn as_action_info(&self) -> Result, Error> { + Ok(self.awaited_action_sub.borrow().action_info().clone()) + } +} + +pub(crate) struct ClientActionStateResult { + inner: MatchingEngineActionStateResult, +} + +impl ClientActionStateResult { + pub fn new(sub: T) -> Self { + Self { + inner: MatchingEngineActionStateResult::new(sub), + } + } +} + +#[async_trait] +impl ActionStateResult for ClientActionStateResult { + async fn as_state(&self) -> Result, Error> { + self.inner.as_state().await + } + + async fn changed(&mut self) -> Result, Error> { + self.inner.changed().await + } + + async fn as_action_info(&self) -> Result, Error> { + self.inner.as_action_info().await + } +} + +/// A struct that is used to keep the developer from trying to +/// return early from a function. 
+struct NoEarlyReturn; + +#[derive(Default)] +struct SortedAwaitedActions { + unknown: BTreeSet, + cache_check: BTreeSet, + queued: BTreeSet, + executing: BTreeSet, + completed: BTreeSet, +} + +impl SortedAwaitedActions { + fn btree_for_state(&mut self, state: &ActionStage) -> &mut BTreeSet { + match state { + ActionStage::Unknown => &mut self.unknown, + ActionStage::CacheCheck => &mut self.cache_check, + ActionStage::Queued => &mut self.queued, + ActionStage::Executing => &mut self.executing, + ActionStage::Completed(_) => &mut self.completed, + ActionStage::CompletedFromCache(_) => &mut self.completed, + } + } + + fn insert_sort_map_for_stage( + &mut self, + stage: &ActionStage, + sorted_awaited_action: SortedAwaitedAction, + ) -> Result<(), Error> { + let newly_inserted = match stage { + ActionStage::Unknown => self.unknown.insert(sorted_awaited_action.clone()), + ActionStage::CacheCheck => self.cache_check.insert(sorted_awaited_action.clone()), + ActionStage::Queued => self.queued.insert(sorted_awaited_action.clone()), + ActionStage::Executing => self.executing.insert(sorted_awaited_action.clone()), + ActionStage::Completed(_) => self.completed.insert(sorted_awaited_action.clone()), + ActionStage::CompletedFromCache(_) => { + self.completed.insert(sorted_awaited_action.clone()) + } + }; + if !newly_inserted { + return Err(make_err!( + Code::Internal, + "Tried to insert an action that was already in the sorted map. This should never happen. 
{:?} - {:?}", + stage, + sorted_awaited_action + )); + } + Ok(()) + } + + fn process_state_changes( + &mut self, + old_awaited_action: &AwaitedAction, + new_awaited_action: &AwaitedAction, + ) -> Result<(), Error> { + let btree = self.btree_for_state(&old_awaited_action.state().stage); + let maybe_sorted_awaited_action = btree.take(&SortedAwaitedAction { + sort_key: old_awaited_action.sort_key(), + operation_id: new_awaited_action.operation_id().clone(), + }); + + let Some(sorted_awaited_action) = maybe_sorted_awaited_action else { + return Err(make_err!( + Code::Internal, + "sorted_action_info_hash_keys and action_info_hash_key_to_awaited_action are out of sync - {} - {:?}", + new_awaited_action.operation_id(), + new_awaited_action, + )); + }; + + self.insert_sort_map_for_stage(&new_awaited_action.state().stage, sorted_awaited_action) + .err_tip(|| "In AwaitedActionDb::update_awaited_action")?; + Ok(()) + } +} + +/// The database for storing the state of all actions. +pub struct AwaitedActionDbImpl { + /// A lookup table to lookup the state of an action by its client operation id. + client_operation_to_awaited_action: + EvictingMap, SystemTime>, + + /// A lookup table to lookup the state of an action by its worker operation id. + operation_id_to_awaited_action: BTreeMap>, + + /// A lookup table to lookup the state of an action by its unique qualifier. + action_info_hash_key_to_awaited_action: HashMap, + + /// A sorted set of [`AwaitedAction`]s. A wrapper is used to perform sorting + /// based on the [`AwaitedActionSortKey`] of the [`AwaitedAction`]. + /// + /// See [`AwaitedActionSortKey`] for more information on the ordering. 
+ sorted_action_info_hash_keys: SortedAwaitedActions, + + action_event_tx: mpsc::UnboundedSender, +} + +impl AwaitedActionDbImpl { + async fn get_awaited_action_by_id( + &self, + client_operation_id: &ClientOperationId, + ) -> Result, Error> { + let maybe_client_awaited_action = self + .client_operation_to_awaited_action + .get(client_operation_id) + .await; + let client_awaited_action = match maybe_client_awaited_action { + Some(client_awaited_action) => client_awaited_action, + None => return Ok(None), + }; + + self.operation_id_to_awaited_action + .get(client_awaited_action.operation_id()) + .map(|tx| Some(MemoryAwaitedActionSubscriber::new(tx.subscribe()))) + .ok_or_else(|| { + make_err!( + Code::Internal, + "Failed to get client operation id {client_operation_id:?}" + ) + }) + } + + /// Processes action events that need to be handled by the database. + async fn handle_action_events( + &mut self, + action_events: impl IntoIterator, + ) -> NoEarlyReturn { + for drop_action in action_events.into_iter() { + match drop_action { + ActionEvent::ClientDroppedOperation(operation_id) => { + // Cleanup operation_id_to_awaited_action. + let Some(tx) = self.operation_id_to_awaited_action.remove(&operation_id) else { + event!( + Level::ERROR, + ?operation_id, + "operation_id_to_awaited_action does not have operation_id", + ); + continue; + }; + let mut connected_clients = 0; + // Note: We use this trick to modify the value, but we don't actually + // want to notify any listeners of the change. + tx.send_if_modified(|awaited_action| { + awaited_action.connected_clients -= 1; + connected_clients = awaited_action.connected_clients; + false + }); + // Note: It is rare to have more than one client listening + // to the same action, so we assume that we are the last + // client and insert it back into the map if we detect that + // there are still clients listening (ie: the happy path + // is operation.connected_clients == 0). 
+ if connected_clients != 0 { + self.operation_id_to_awaited_action.insert(operation_id, tx); + continue; + } + let awaited_action = tx.borrow().clone(); + // Cleanup action_info_hash_key_to_awaited_action if it was marked cached. + match &awaited_action.action_info().unique_qualifier { + ActionUniqueQualifier::Cachable(action_key) => { + let maybe_awaited_action = self + .action_info_hash_key_to_awaited_action + .remove(action_key); + if !awaited_action.state().stage.is_finished() + && maybe_awaited_action.is_none() + { + event!( + Level::ERROR, + ?operation_id, + ?awaited_action, + ?action_key, + "action_info_hash_key_to_awaited_action and operation_id_to_awaited_action are out of sync", + ); + } + } + ActionUniqueQualifier::Uncachable(_action_key) => { + // This Operation should not be in the hash_key map. + } + } + + // Cleanup sorted_awaited_action. + let sort_key = awaited_action.sort_key(); + let sort_btree_for_state = self + .sorted_action_info_hash_keys + .btree_for_state(&awaited_action.state().stage); + + let maybe_sorted_awaited_action = + sort_btree_for_state.take(&SortedAwaitedAction { + sort_key, + operation_id: operation_id.clone(), + }); + if maybe_sorted_awaited_action.is_none() { + event!( + Level::ERROR, + ?operation_id, + ?sort_key, + "Expected maybe_sorted_awaited_action to have {sort_key:?}", + ); + } + } + ActionEvent::ClientKeepAlive(client_id) => { + let maybe_size = self + .client_operation_to_awaited_action + .size_for_key(&client_id) + .await; + if maybe_size.is_none() { + event!( + Level::ERROR, + ?client_id, + "client_operation_to_awaited_action does not have client_id", + ); + } + } + } + } + NoEarlyReturn + } + + fn get_awaited_actions_range( + &self, + start: Bound<&OperationId>, + end: Bound<&OperationId>, + ) -> impl Iterator { + self.operation_id_to_awaited_action + .range((start, end)) + .map(|(operation_id, tx)| { + ( + operation_id, + MemoryAwaitedActionSubscriber::new(tx.subscribe()), + ) + }) + } + + fn get_by_operation_id( 
+ &self, + operation_id: &OperationId, + ) -> Option { + self.operation_id_to_awaited_action + .get(operation_id) + .map(|tx| MemoryAwaitedActionSubscriber::new(tx.subscribe())) + } + + // TODO!(rename) + fn get_range_of_actions<'a, 'b>( + &'a self, + state: SortedAwaitedActionState, + range: impl RangeBounds + 'b, + ) -> impl DoubleEndedIterator< + Item = Result<(&'a SortedAwaitedAction, MemoryAwaitedActionSubscriber), Error>, + > + 'a { + let btree = match state { + SortedAwaitedActionState::CacheCheck => &self.sorted_action_info_hash_keys.cache_check, + SortedAwaitedActionState::Queued => &self.sorted_action_info_hash_keys.queued, + SortedAwaitedActionState::Executing => &self.sorted_action_info_hash_keys.executing, + SortedAwaitedActionState::Completed => &self.sorted_action_info_hash_keys.completed, + }; + btree.range(range).map(|sorted_awaited_action| { + let operation_id = &sorted_awaited_action.operation_id; + self.get_by_operation_id(operation_id) + .ok_or_else(|| { + make_err!( + Code::Internal, + "Failed to get operation id {}", + operation_id + ) + }) + .map(|subscriber| (sorted_awaited_action, subscriber)) + }) + } + + fn process_state_changes_for_hash_key_map( + action_info_hash_key_to_awaited_action: &mut HashMap, + new_awaited_action: &AwaitedAction, + ) -> Result<(), Error> { + // Do not allow future subscribes if the action is already completed, + // this is the responsibility of the CacheLookupScheduler. + // TODO(allad) Once we land the new scheduler onto main, we can remove this check. + // It makes sense to allow users to subscribe to already completed items. + // This can be changed to `.is_error()` later. 
+ if !new_awaited_action.state().stage.is_finished() { + return Ok(()); + } + match &new_awaited_action.action_info().unique_qualifier { + ActionUniqueQualifier::Cachable(action_key) => { + let maybe_awaited_action = + action_info_hash_key_to_awaited_action.remove(action_key); + match maybe_awaited_action { + Some(removed_operation_id) => { + if &removed_operation_id != new_awaited_action.operation_id() { + event!( + Level::ERROR, + ?removed_operation_id, + ?new_awaited_action, + ?action_key, + "action_info_hash_key_to_awaited_action and operation_id_to_awaited_action are out of sync", + ); + } + } + None => { + event!( + Level::ERROR, + ?new_awaited_action, + ?action_key, + "action_info_hash_key_to_awaited_action out of sync, it should have had the unique_key", + ); + } + } + Ok(()) + } + ActionUniqueQualifier::Uncachable(_action_key) => { + // If we are not cachable, the action should not be in the + // hash_key map, so we don't need to process anything in + // action_info_hash_key_to_awaited_action. + Ok(()) + } + } + } + + fn update_awaited_action(&mut self, new_awaited_action: AwaitedAction) -> Result<(), Error> { + let tx = self + .operation_id_to_awaited_action + .get(new_awaited_action.operation_id()) + .ok_or_else(|| { + make_err!( + Code::Internal, + "OperationId does not exist in map in AwaitedActionDb::update_awaited_action" + ) + })?; + { + // Note: It's important to drop old_awaited_action before we call + // send_replace or we will have a deadlock. + let old_awaited_action = tx.borrow(); + error_if!( + old_awaited_action.action_info().unique_qualifier + != new_awaited_action.action_info().unique_qualifier, + "Unique key changed for operation_id {:?} - {:?} - {:?}", + new_awaited_action.operation_id(), + old_awaited_action.action_info(), + new_awaited_action.action_info(), + ); + let is_same_stage = old_awaited_action + .state() + .stage + .is_same_stage(&new_awaited_action.state().stage); + + // TODO!(Handle priority changes here). 
+ if !is_same_stage { + self.sorted_action_info_hash_keys + .process_state_changes(&old_awaited_action, &new_awaited_action)?; + Self::process_state_changes_for_hash_key_map( + &mut self.action_info_hash_key_to_awaited_action, + &new_awaited_action, + )?; + } + } + + // Notify all listeners of the new state and ignore if no one is listening. + // Note: Do not use `.send()` as it will not update the state if all listeners + // are dropped. + let _ = tx.send_replace(new_awaited_action); + + Ok(()) + } + + /// Creates a new [`ClientAwaitedAction`] and a [`watch::Receiver`] to + /// listen for changes. We don't do this in-line because it is important + /// to ALWAYS construct a [`ClientAwaitedAction`] before inserting it into + /// the map. Failing to do so may result in memory leaks. This is because + /// [`ClientAwaitedAction`] implements a drop function that will trigger + /// cleanup of the other maps on drop. + fn make_client_awaited_action( + &mut self, + operation_id: OperationId, + awaited_action: AwaitedAction, + ) -> (Arc, watch::Receiver) { + let (tx, rx) = watch::channel(awaited_action); + let client_awaited_action = Arc::new(ClientAwaitedAction::new( + operation_id.clone(), + self.action_event_tx.clone(), + )); + self.operation_id_to_awaited_action + .insert(operation_id.clone(), tx); + (client_awaited_action, rx) + } + + async fn add_action( + &mut self, + client_operation_id: ClientOperationId, + action_info: Arc, + ) -> Result { + // Check to see if the action is already known and subscribe if it is. + let subscription_result = self + .try_subscribe( + &client_operation_id, + &action_info.unique_qualifier, + action_info.priority, + ) + .await + .err_tip(|| "In AwaitedActionDb::subscribe_or_add_action"); + match subscription_result { + Err(err) => return Err(err), + Ok(Some(subscription)) => return Ok(subscription), + Ok(None) => { /* Add item to queue. 
*/ } + } + + let maybe_unique_key = match &action_info.unique_qualifier { + ActionUniqueQualifier::Cachable(unique_key) => Some(unique_key.clone()), + ActionUniqueQualifier::Uncachable(_unique_key) => None, + }; + let operation_id = OperationId::new(action_info.unique_qualifier.clone()); + let awaited_action = AwaitedAction::new(operation_id.clone(), action_info); + debug_assert!( + awaited_action.connected_clients == 1, + "Expected connected_clients to be 1" + ); + debug_assert!( + ActionStage::Queued == awaited_action.state().stage, + "Expected action to be queued" + ); + let sort_key = awaited_action.sort_key(); + + let (client_awaited_action, rx) = + self.make_client_awaited_action(operation_id.clone(), awaited_action); + + self.client_operation_to_awaited_action + .insert(client_operation_id.clone(), client_awaited_action) + .await; + + // Note: We only put items in the map that are cachable. + if let Some(unique_key) = maybe_unique_key { + let old_value = self + .action_info_hash_key_to_awaited_action + .insert(unique_key, operation_id.clone()); + if let Some(old_value) = old_value { + event!( + Level::ERROR, + ?operation_id, + ?old_value, + "action_info_hash_key_to_awaited_action already has unique_key" + ); + } + } + + self.sorted_action_info_hash_keys + .insert_sort_map_for_stage( + &ActionStage::Queued, + SortedAwaitedAction { + sort_key, + operation_id, + }, + ) + .err_tip(|| "In AwaitedActionDb::subscribe_or_add_action")?; + + Ok(MemoryAwaitedActionSubscriber::new_with_client( + rx, + client_operation_id, + self.action_event_tx.clone(), + )) + } + + async fn try_subscribe( + &mut self, + client_operation_id: &ClientOperationId, + unique_qualifier: &ActionUniqueQualifier, + // TODO!() + _priority: i32, + ) -> Result, Error> { + let unique_key = match unique_qualifier { + ActionUniqueQualifier::Cachable(unique_key) => unique_key, + ActionUniqueQualifier::Uncachable(_unique_key) => return Ok(None), + }; + + let Some(operation_id) = 
self.action_info_hash_key_to_awaited_action.get(unique_key) else { + return Ok(None); // Not currently running. + }; + + let Some(tx) = self.operation_id_to_awaited_action.get(operation_id) else { + return Err(make_err!( + Code::Internal, + "operation_id_to_awaited_action and action_info_hash_key_to_awaited_action are out of sync for {unique_key:?} - {operation_id}" + )); + }; + + error_if!( + tx.borrow().state().stage.is_finished(), + "Tried to subscribe to a completed action but it already finished. This should never happen. {:?}", + tx.borrow() + ); + + let subscription = tx.subscribe(); + // Note: We use this trick to modify the value, but we don't actually + // want to notify any listeners of the change. + tx.send_if_modified(|awaited_action| { + awaited_action.connected_clients += 1; + false + }); + + self.client_operation_to_awaited_action + .insert( + client_operation_id.clone(), + Arc::new(ClientAwaitedAction::new( + operation_id.clone(), + self.action_event_tx.clone(), + )), + ) + .await; + + Ok(Some(MemoryAwaitedActionSubscriber::new(subscription))) + } +} + +pub struct MemoryAwaitedActionDb { + inner: Arc>, + _handle_awaited_action_events: JoinHandleDropGuard<()>, +} + +impl MemoryAwaitedActionDb { + pub fn new(eviction_config: &EvictionPolicy) -> Self { + let (action_event_tx, mut action_event_rx) = mpsc::unbounded_channel(); + let inner = Arc::new(Mutex::new(AwaitedActionDbImpl { + client_operation_to_awaited_action: EvictingMap::new( + eviction_config, + SystemTime::now(), + ), + operation_id_to_awaited_action: BTreeMap::new(), + action_info_hash_key_to_awaited_action: HashMap::new(), + sorted_action_info_hash_keys: SortedAwaitedActions::default(), + action_event_tx, + })); + let weak_inner = Arc::downgrade(&inner); + Self { + inner, + _handle_awaited_action_events: spawn!("handle_awaited_action_events", async move { + let mut dropped_operation_ids = Vec::with_capacity(MAX_ACTION_EVENTS_RX_PER_CYCLE); + loop { + dropped_operation_ids.clear(); + 
action_event_rx + .recv_many(&mut dropped_operation_ids, MAX_ACTION_EVENTS_RX_PER_CYCLE) + .await; + let Some(inner) = weak_inner.upgrade() else { + return; // Nothing to cleanup, our struct is dropped. + }; + let mut inner = inner.lock().await; + inner + .handle_action_events(dropped_operation_ids.drain(..)) + .await; + } + }), + } + } +} + +impl AwaitedActionDb for MemoryAwaitedActionDb { + type Subscriber = MemoryAwaitedActionSubscriber; + + async fn get_awaited_action_by_id( + &self, + client_operation_id: &ClientOperationId, + ) -> Result, Error> { + self.inner + .lock() + .await + .get_awaited_action_by_id(client_operation_id) + .await + } + + async fn get_all_awaited_actions(&self) -> impl Stream> { + ChunkedStream::new( + Bound::Unbounded, + Bound::Unbounded, + move |start, end, mut output| async move { + let inner = self.inner.lock().await; + let mut maybe_new_start = None; + + for (operation_id, item) in + inner.get_awaited_actions_range(start.as_ref(), end.as_ref()) + { + output.push_back(item); + maybe_new_start = Some(operation_id); + } + + Ok(maybe_new_start + .map(|new_start| ((Bound::Excluded(new_start.clone()), end), output))) + }, + ) + } + + async fn get_by_operation_id( + &self, + operation_id: &OperationId, + ) -> Result, Error> { + Ok(self.inner.lock().await.get_by_operation_id(operation_id)) + } + + async fn get_range_of_actions( + &self, + state: SortedAwaitedActionState, + start: Bound, + end: Bound, + desc: bool, + ) -> impl Stream> + Send + Sync { + ChunkedStream::new(start, end, move |start, end, mut output| async move { + let inner = self.inner.lock().await; + let mut done = true; + let mut new_start = start.as_ref(); + let mut new_end = end.as_ref(); + + let iterator = inner.get_range_of_actions(state, (start.as_ref(), end.as_ref())); + // TODO(allada) This should probably use the `.left()/right()` pattern, + // but that doesn't exist in the std or any libraries we use. 
+ if desc { + for result in iterator.rev() { + let (sorted_awaited_action, item) = + result.err_tip(|| "In AwaitedActionDb::get_range_of_actions")?; + output.push_back(item); + new_end = Bound::Excluded(sorted_awaited_action); + done = false; + } + } else { + for result in iterator { + let (sorted_awaited_action, item) = + result.err_tip(|| "In AwaitedActionDb::get_range_of_actions")?; + output.push_back(item); + new_start = Bound::Excluded(sorted_awaited_action); + done = false; + } + } + if done { + return Ok(None); + } + Ok(Some(((new_start.cloned(), new_end.cloned()), output))) + }) + } + + async fn update_awaited_action(&self, new_awaited_action: AwaitedAction) -> Result<(), Error> { + self.inner + .lock() + .await + .update_awaited_action(new_awaited_action) + } + + async fn add_action( + &self, + client_operation_id: ClientOperationId, + action_info: Arc, + ) -> Result { + self.inner + .lock() + .await + .add_action(client_operation_id, action_info) + .await + } +} diff --git a/nativelink-scheduler/src/memory_scheduler_state/awaited_action_db/awaited_action.rs b/nativelink-scheduler/src/memory_scheduler_state/awaited_action_db/awaited_action.rs deleted file mode 100644 index c7df604c2..000000000 --- a/nativelink-scheduler/src/memory_scheduler_state/awaited_action_db/awaited_action.rs +++ /dev/null @@ -1,353 +0,0 @@ -// Copyright 2024 The NativeLink Authors. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -use std::hash::{DefaultHasher, Hash, Hasher}; -use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering}; -use std::sync::Arc; -use std::time::{Duration, SystemTime}; - -use nativelink_util::action_messages::{ - ActionInfo, ActionStage, ActionState, ActionUniqueKey, ActionUniqueQualifier, OperationId, - WorkerId, -}; -use nativelink_util::evicting_map::InstantWrapper; -use parking_lot::{RwLock, RwLockReadGuard, RwLockUpgradableReadGuard, RwLockWriteGuard}; -use static_assertions::{assert_eq_size, const_assert, const_assert_eq}; -use tokio::sync::watch; - -enum ReadOrWriteGuard<'a, T> { - Read(RwLockReadGuard<'a, T>), - Write(RwLockWriteGuard<'a, T>), -} - -#[derive(Debug)] -struct SortInfo { - priority: i32, - sort_key: AwaitedActionSortKey, -} - -/// This struct's main purpose is to allow the caller of `set_priority` to -/// get the previous sort key and perform any additional operations -/// that may be needed after the sort key has been updated without allowing -/// anyone else to read or modify the sort key until the caller is done. -pub struct SortInfoLock<'a> { - previous_sort_key: AwaitedActionSortKey, - new_sort_info: ReadOrWriteGuard<'a, SortInfo>, -} - -impl SortInfoLock<'_> { - /// Gets the previous sort key of the action. - pub fn get_previous_sort_key(&self) -> AwaitedActionSortKey { - self.previous_sort_key - } - - /// Gets the new sort key of the action. - pub fn get_new_sort_key(&self) -> AwaitedActionSortKey { - match &self.new_sort_info { - ReadOrWriteGuard::Read(sort_info) => sort_info.sort_key, - ReadOrWriteGuard::Write(sort_info) => sort_info.sort_key, - } - } -} - -/// An action that is being awaited on and last known state. -#[derive(Debug)] -pub struct AwaitedAction { - /// The action that is being awaited on. - action_info: Arc, - - /// The data that is used to sort the action in the queue. - /// The first item in the tuple is the current priority, - /// the second item is the sort key. 
- sort_info: RwLock, - - /// Number of attempts the job has been tried. - attempts: AtomicUsize, - - /// The time the action was last updated. - last_worker_updated_timestamp: AtomicU64, - - /// Number of clients listening to the state of the action. - listening_clients: AtomicUsize, - - /// Worker that is currently running this action, None if unassigned. - worker_id: RwLock>, - - /// The channel to notify subscribers of state changes when updated, completed or retrying. - notify_channel: watch::Sender>, -} - -impl AwaitedAction { - pub fn new_with_subscription( - action_info: Arc, - ) -> ( - Self, - AwaitedActionSortKey, - watch::Receiver>, - ) { - let unique_key = match &action_info.unique_qualifier { - ActionUniqueQualifier::Cachable(unique_key) => unique_key, - ActionUniqueQualifier::Uncachable(unique_key) => unique_key, - }; - let stage = ActionStage::Queued; - let sort_key = AwaitedActionSortKey::new_with_unique_key( - action_info.priority, - &action_info.insert_timestamp, - unique_key, - ); - let sort_info = RwLock::new(SortInfo { - priority: action_info.priority, - sort_key, - }); - let current_state = Arc::new(ActionState { - stage, - id: OperationId::new(action_info.unique_qualifier.clone()), - }); - let (tx, rx) = watch::channel(current_state); - ( - Self { - action_info, - notify_channel: tx, - sort_info, - attempts: AtomicUsize::new(0), - last_worker_updated_timestamp: AtomicU64::new(SystemTime::now().unix_timestamp()), - listening_clients: AtomicUsize::new(0), - worker_id: RwLock::new(None), - }, - sort_key, - rx, - ) - } - - /// Gets the action info. 
- pub fn get_action_info(&self) -> &Arc { - &self.action_info - } - - pub fn get_operation_id(&self) -> OperationId { - self.notify_channel.borrow().id.clone() - } - - pub fn get_last_worker_updated_timestamp(&self) -> SystemTime { - let timestamp = self.last_worker_updated_timestamp.load(Ordering::Acquire); - SystemTime::UNIX_EPOCH + Duration::from_secs(timestamp) - } - - pub fn get_listening_clients(&self) -> usize { - self.listening_clients.load(Ordering::Acquire) - } - - pub fn inc_listening_clients(&self) { - self.listening_clients.fetch_add(1, Ordering::Release); - } - - pub fn dec_listening_clients(&self) { - self.listening_clients.fetch_sub(1, Ordering::Release); - } - - /// Updates the timestamp of the action. - fn update_worker_timestamp(&self) { - self.last_worker_updated_timestamp - .store(SystemTime::now().unix_timestamp(), Ordering::Release); - } - - /// Upgrades the priority of the action if new priority is higher. - /// - /// If the priority was already set to `new_priority`, this function will - /// return `None`. If the priority was different, it will return a - /// struct that contains the previous sort key and the new sort key and - /// will hold a lock preventing anyone else from reading or modifying the - /// sort key until the result is dropped. 
- #[must_use] - pub fn upgrade_priority(&self, new_priority: i32) -> Option { - let sort_info_lock = self.sort_info.upgradable_read(); - if sort_info_lock.priority >= new_priority { - return None; - } - let unique_key = match &self.action_info.unique_qualifier { - ActionUniqueQualifier::Cachable(unique_key) => unique_key, - ActionUniqueQualifier::Uncachable(unique_key) => unique_key, - }; - let mut sort_info_lock = RwLockUpgradableReadGuard::upgrade(sort_info_lock); - let previous_sort_key = sort_info_lock.sort_key; - sort_info_lock.priority = new_priority; - sort_info_lock.sort_key = AwaitedActionSortKey::new_with_unique_key( - new_priority, - &self.action_info.insert_timestamp, - unique_key, - ); - Some(SortInfoLock { - previous_sort_key, - new_sort_info: ReadOrWriteGuard::Write(sort_info_lock), - }) - } - - /// Gets the sort info of the action. - pub fn get_sort_info(&self) -> SortInfoLock { - let sort_info = self.sort_info.read(); - SortInfoLock { - previous_sort_key: sort_info.sort_key, - new_sort_info: ReadOrWriteGuard::Read(sort_info), - } - } - - /// Gets the number of times the action has been attempted - /// to be executed. - pub fn get_attempts(&self) -> usize { - self.attempts.load(Ordering::Acquire) - } - - /// Adds one to the number of attempts the action has been tried. - pub fn inc_attempts(&self) { - self.attempts.fetch_add(1, Ordering::Release); - } - - /// Gets the worker id that is currently processing this action. - pub fn get_worker_id(&self) -> Option { - *self.worker_id.read() - } - - /// Sets the worker id that is currently processing this action. - pub fn set_worker_id(&self, new_worker_id: Option) { - let mut worker_id = self.worker_id.write(); - if *worker_id != new_worker_id { - self.update_worker_timestamp(); - *worker_id = new_worker_id; - } - } - - /// Gets the current state of the action. 
- pub fn get_current_state(&self) -> Arc { - self.notify_channel.borrow().clone() - } - - /// Sets the current state of the action and notifies subscribers. - /// Returns true if the state was set, false if there are no subscribers. - #[must_use] - pub fn set_current_state(&self, state: Arc) -> bool { - self.update_worker_timestamp(); - // Note: Use `send_replace()`. Using `send()` will not change the value if - // there are no subscribers. - self.notify_channel.send_replace(state); - self.notify_channel.receiver_count() > 0 - } - - pub fn subscribe(&self) -> watch::Receiver> { - self.notify_channel.subscribe() - } -} - -/// The key used to sort the awaited actions. -/// -/// The rules for sorting are as follows: -/// 1. priority of the action -/// 2. insert order of the action (lower = higher priority) -/// 3. (mostly random hash based on the action info) -#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] -#[repr(transparent)] -pub struct AwaitedActionSortKey(u128); - -impl AwaitedActionSortKey { - #[rustfmt::skip] - const fn new(priority: i32, insert_timestamp: u64, hash: [u8; 4]) -> Self { - // Shift `new_priority` so [`i32::MIN`] is represented by zero. - // This makes it so any nagative values are positive, but - // maintains ordering. - const MIN_I32: i64 = (i32::MIN as i64).abs(); - let priority = ((priority as i64 + MIN_I32) as u32).to_be_bytes(); - - // Invert our timestamp so the larger the timestamp the lower the number. - // This makes timestamp descending order instead of ascending. 
- let timestamp = (insert_timestamp ^ u64::MAX).to_be_bytes(); - - AwaitedActionSortKey(u128::from_be_bytes([ - priority[0], priority[1], priority[2], priority[3], - timestamp[0], timestamp[1], timestamp[2], timestamp[3], - timestamp[4], timestamp[5], timestamp[6], timestamp[7], - hash[0], hash[1], hash[2], hash[3] - ])) - } - - fn new_with_unique_key( - priority: i32, - insert_timestamp: &SystemTime, - action_hash: &ActionUniqueKey, - ) -> Self { - let hash = { - let mut hasher = DefaultHasher::new(); - ActionUniqueKey::hash(action_hash, &mut hasher); - hasher.finish().to_le_bytes()[0..4].try_into().unwrap() - }; - Self::new(priority, insert_timestamp.unix_timestamp(), hash) - } -} - -// Ensure the size of the sort key is the same as a `u64`. -assert_eq_size!(AwaitedActionSortKey, u128); - -const_assert_eq!( - AwaitedActionSortKey::new(0x1234_5678, 0x9abc_def0_1234_5678, [0x9a, 0xbc, 0xde, 0xf0]).0, - // Note: Result has 0x12345678 + 0x80000000 = 0x92345678 because we need - // to shift the `i32::MIN` value to be represented by zero. - // Note: `6543210fedcba987` are the inverted bits of `9abcdef012345678`. - // This effectively inverts the priority to now have the highest priority - // be the lowest timestamps. - AwaitedActionSortKey(0x9234_5678_6543_210f_edcb_a987_9abc_def0).0 -); -// Ensure the priority is used as the sort key first. 
-const_assert!( - AwaitedActionSortKey::new(i32::MAX, 0, [0xff; 4]).0 - > AwaitedActionSortKey::new(i32::MAX - 1, 0, [0; 4]).0 -); -const_assert!( - AwaitedActionSortKey::new(i32::MAX - 1, 0, [0xff; 4]).0 - > AwaitedActionSortKey::new(1, 0, [0; 4]).0 -); -const_assert!( - AwaitedActionSortKey::new(1, 0, [0xff; 4]).0 > AwaitedActionSortKey::new(0, 0, [0; 4]).0 -); -const_assert!( - AwaitedActionSortKey::new(0, 0, [0xff; 4]).0 > AwaitedActionSortKey::new(-1, 0, [0; 4]).0 -); -const_assert!( - AwaitedActionSortKey::new(-1, 0, [0xff; 4]).0 - > AwaitedActionSortKey::new(i32::MIN + 1, 0, [0; 4]).0 -); -const_assert!( - AwaitedActionSortKey::new(i32::MIN + 1, 0, [0xff; 4]).0 - > AwaitedActionSortKey::new(i32::MIN, 0, [0; 4]).0 -); - -// Ensure the insert timestamp is used as the sort key second. -const_assert!( - AwaitedActionSortKey::new(0, u64::MIN, [0; 4]).0 - > AwaitedActionSortKey::new(0, u64::MAX, [0; 4]).0 -); - -// Ensure the hash is used as the sort key third. -const_assert!( - AwaitedActionSortKey::new(0, 0, [0xff, 0xff, 0xff, 0xff]).0 - > AwaitedActionSortKey::new(0, 0, [0; 4]).0 -); -const_assert!( - AwaitedActionSortKey::new(1, 0, [0xff, 0xff, 0xff, 0xff]).0 - > AwaitedActionSortKey::new(0, 0, [0; 4]).0 -); -const_assert!( - AwaitedActionSortKey::new(0, 0, [0; 4]).0 > AwaitedActionSortKey::new(0, 1, [0; 4]).0 -); -const_assert!( - AwaitedActionSortKey::new(0, 0, [0xff, 0xff, 0xff, 0xff]).0 - > AwaitedActionSortKey::new(0, 0, [0; 4]).0 -); diff --git a/nativelink-scheduler/src/memory_scheduler_state/awaited_action_db/awaited_action_db.rs b/nativelink-scheduler/src/memory_scheduler_state/awaited_action_db/awaited_action_db.rs deleted file mode 100644 index ee7777967..000000000 --- a/nativelink-scheduler/src/memory_scheduler_state/awaited_action_db/awaited_action_db.rs +++ /dev/null @@ -1,399 +0,0 @@ -// Copyright 2024 The NativeLink Authors. All rights reserved. 
-// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -use std::collections::{BTreeSet, HashMap}; -use std::sync::Arc; -use std::time::SystemTime; - -use nativelink_config::stores::EvictionPolicy; -use nativelink_error::{make_err, Code, Error, ResultExt}; -use nativelink_util::action_messages::{ - ActionInfo, ActionStage, ActionState, ActionUniqueKey, ActionUniqueQualifier, - ClientOperationId, OperationId, -}; -use nativelink_util::evicting_map::EvictingMap; -use tokio::sync::{mpsc, watch}; -use tracing::{event, Level}; - -use super::client_awaited_action::ClientAwaitedAction; -use super::{AwaitedAction, SortedAwaitedAction}; - -#[derive(Default)] -struct SortedAwaitedActions { - unknown: BTreeSet, - cache_check: BTreeSet, - queued: BTreeSet, - executing: BTreeSet, - completed: BTreeSet, - completed_from_cache: BTreeSet, -} - -/// The database for storing the state of all actions. -/// IMPORTANT: Any time an item is removed from -/// [`AwaitedActionDb::client_operation_to_awaited_action`], it must -/// also remove the entries from all the other maps. -pub struct AwaitedActionDb { - /// A lookup table to lookup the state of an action by its client operation id. - client_operation_to_awaited_action: - EvictingMap, SystemTime>, - - /// A lookup table to lookup the state of an action by its worker operation id. - operation_id_to_awaited_action: HashMap>, - - /// A lookup table to lookup the state of an action by its unique qualifier. 
- action_info_hash_key_to_awaited_action: HashMap>, - - /// A sorted set of [`AwaitedAction`]s. A wrapper is used to perform sorting - /// based on the [`AwaitedActionSortKey`] of the [`AwaitedAction`]. - /// - /// See [`AwaitedActionSortKey`] for more information on the ordering. - sorted_action_info_hash_keys: SortedAwaitedActions, -} - -#[allow(clippy::mutable_key_type)] -impl AwaitedActionDb { - pub fn new(eviction_config: &EvictionPolicy) -> Self { - Self { - client_operation_to_awaited_action: EvictingMap::new( - eviction_config, - SystemTime::now(), - ), - operation_id_to_awaited_action: HashMap::new(), - action_info_hash_key_to_awaited_action: HashMap::new(), - sorted_action_info_hash_keys: SortedAwaitedActions::default(), - } - } - - /// Refreshes/Updates the time to live of the [`ClientOperationId`] in - /// the [`EvictingMap`] by touching the key. - pub async fn refresh_client_operation_id( - &self, - client_operation_id: &ClientOperationId, - ) -> bool { - self.client_operation_to_awaited_action - .size_for_key(client_operation_id) - .await - .is_some() - } - - pub async fn get_by_client_operation_id( - &self, - client_operation_id: &ClientOperationId, - ) -> Option> { - self.client_operation_to_awaited_action - .get(client_operation_id) - .await - } - - /// When a client operation is dropped, we need to remove it from the - /// other maps and update the listening clients count on the [`AwaitedAction`]. - pub fn on_client_operations_drop( - &mut self, - awaited_actions: impl IntoIterator>, - ) { - for awaited_action in awaited_actions.into_iter() { - if awaited_action.get_listening_clients() != 0 { - // We still have other clients listening to this action. - continue; - } - - let operation_id = awaited_action.get_operation_id(); - - // Cleanup operation_id_to_awaited_action. 
- if self - .operation_id_to_awaited_action - .remove(&operation_id) - .is_none() - { - event!( - Level::ERROR, - ?operation_id, - ?awaited_action, - "operation_id_to_awaited_action and client_operation_to_awaited_action are out of sync", - ); - } - - // Cleanup action_info_hash_key_to_awaited_action if it was marked cached. - let action_info = awaited_action.get_action_info(); - match &action_info.unique_qualifier { - ActionUniqueQualifier::Cachable(action_key) => { - let maybe_awaited_action = self - .action_info_hash_key_to_awaited_action - .remove(action_key); - if maybe_awaited_action.is_none() { - event!( - Level::ERROR, - ?operation_id, - ?awaited_action, - ?action_key, - "action_info_hash_key_to_awaited_action and operation_id_to_awaited_action are out of sync", - ); - } - } - ActionUniqueQualifier::Uncachable(_action_key) => { - // This Operation should not be in the hash_key map. - } - } - - // Cleanup sorted_awaited_action. - let sort_info = awaited_action.get_sort_info(); - let sort_key = sort_info.get_previous_sort_key(); - let sort_map_for_state = - self.get_sort_map_for_state(&awaited_action.get_current_state().stage); - drop(sort_info); - let maybe_sorted_awaited_action = sort_map_for_state.take(&SortedAwaitedAction { - sort_key, - awaited_action, - }); - if maybe_sorted_awaited_action.is_none() { - event!( - Level::ERROR, - ?operation_id, - ?sort_key, - "Expected maybe_sorted_awaited_action to have {sort_key:?}", - ); - } - } - } - - pub fn get_all_awaited_actions(&self) -> impl Iterator> { - self.operation_id_to_awaited_action.values() - } - - pub fn get_by_operation_id(&self, operation_id: &OperationId) -> Option<&Arc> { - self.operation_id_to_awaited_action.get(operation_id) - } - - pub fn get_cache_check_actions(&self) -> &BTreeSet { - &self.sorted_action_info_hash_keys.cache_check - } - - pub fn get_queued_actions(&self) -> &BTreeSet { - &self.sorted_action_info_hash_keys.queued - } - - pub fn get_executing_actions(&self) -> &BTreeSet { - 
&self.sorted_action_info_hash_keys.executing - } - - pub fn get_completed_actions(&self) -> &BTreeSet { - &self.sorted_action_info_hash_keys.completed - } - - fn get_sort_map_for_state( - &mut self, - state: &ActionStage, - ) -> &mut BTreeSet { - match state { - ActionStage::Unknown => &mut self.sorted_action_info_hash_keys.unknown, - ActionStage::CacheCheck => &mut self.sorted_action_info_hash_keys.cache_check, - ActionStage::Queued => &mut self.sorted_action_info_hash_keys.queued, - ActionStage::Executing => &mut self.sorted_action_info_hash_keys.executing, - ActionStage::Completed(_) => &mut self.sorted_action_info_hash_keys.completed, - ActionStage::CompletedFromCache(_) => { - &mut self.sorted_action_info_hash_keys.completed_from_cache - } - } - } - - fn insert_sort_map_for_stage( - &mut self, - stage: &ActionStage, - sorted_awaited_action: SortedAwaitedAction, - ) { - let newly_inserted = match stage { - ActionStage::Unknown => self - .sorted_action_info_hash_keys - .unknown - .insert(sorted_awaited_action), - ActionStage::CacheCheck => self - .sorted_action_info_hash_keys - .cache_check - .insert(sorted_awaited_action), - ActionStage::Queued => self - .sorted_action_info_hash_keys - .queued - .insert(sorted_awaited_action), - ActionStage::Executing => self - .sorted_action_info_hash_keys - .executing - .insert(sorted_awaited_action), - ActionStage::Completed(_) => self - .sorted_action_info_hash_keys - .completed - .insert(sorted_awaited_action), - ActionStage::CompletedFromCache(_) => self - .sorted_action_info_hash_keys - .completed_from_cache - .insert(sorted_awaited_action), - }; - if !newly_inserted { - event!( - Level::ERROR, - "Tried to insert an action that was already in the sorted map. This should never happen.", - ); - } - } - - /// Sets the state of the action to the provided `action_state` and notifies all listeners. - /// If the action has no more listeners, returns `false`. 
- pub fn set_action_state( - &mut self, - awaited_action: Arc, - new_action_state: Arc, - ) -> bool { - // We need to first get a lock on the awaited action to ensure - // another operation doesn't update it while we are looking up - // the sorted key. - let sort_info = awaited_action.get_sort_info(); - let old_state = awaited_action.get_current_state(); - - let has_listeners = awaited_action.set_current_state(new_action_state.clone()); - - if !old_state.stage.is_same_stage(&new_action_state.stage) { - let sort_key = sort_info.get_previous_sort_key(); - let btree = self.get_sort_map_for_state(&old_state.stage); - drop(sort_info); - let maybe_sorted_awaited_action = btree.take(&SortedAwaitedAction { - sort_key, - awaited_action, - }); - - let Some(sorted_awaited_action) = maybe_sorted_awaited_action else { - event!( - Level::ERROR, - "sorted_action_info_hash_keys and action_info_hash_key_to_awaited_action are out of sync", - ); - return false; - }; - - self.insert_sort_map_for_stage(&new_action_state.stage, sorted_awaited_action); - } - has_listeners - } - - pub async fn subscribe_or_add_action( - &mut self, - client_operation_id: ClientOperationId, - action_info: Arc, - client_operation_drop_tx: &mpsc::UnboundedSender>, - ) -> Result>, Error> { - // Check to see if the action is already known and subscribe if it is. - let subscription_result = self - .try_subscribe( - &client_operation_id, - &action_info.unique_qualifier, - action_info.priority, - client_operation_drop_tx, - ) - .await - .err_tip(|| "In AwaitedActionDb::subscribe_or_add_action"); - match subscription_result { - Err(err) => return Err(err), - Ok(Some(subscription)) => return Ok(subscription), - Ok(None) => { /* Add item to queue. 
*/ } - } - - let maybe_unique_key = match &action_info.unique_qualifier { - ActionUniqueQualifier::Cachable(unique_key) => Some(unique_key.clone()), - ActionUniqueQualifier::Uncachable(_unique_key) => None, - }; - let (awaited_action, sort_key, subscription) = - AwaitedAction::new_with_subscription(action_info); - let awaited_action = Arc::new(awaited_action); - self.client_operation_to_awaited_action - .insert( - client_operation_id, - Arc::new(ClientAwaitedAction::new( - awaited_action.clone(), - client_operation_drop_tx.clone(), - )), - ) - .await; - // Note: We only put items in the map that are cachable. - if let Some(unique_key) = maybe_unique_key { - self.action_info_hash_key_to_awaited_action - .insert(unique_key, awaited_action.clone()); - } - self.operation_id_to_awaited_action - .insert(awaited_action.get_operation_id(), awaited_action.clone()); - - self.insert_sort_map_for_stage( - &awaited_action.get_current_state().stage, - SortedAwaitedAction { - sort_key, - awaited_action, - }, - ); - Ok(subscription) - } - - async fn try_subscribe( - &mut self, - client_operation_id: &ClientOperationId, - unique_qualifier: &ActionUniqueQualifier, - priority: i32, - client_operation_drop_tx: &mpsc::UnboundedSender>, - ) -> Result>>, Error> { - let unique_key = match unique_qualifier { - ActionUniqueQualifier::Cachable(unique_key) => unique_key, - ActionUniqueQualifier::Uncachable(_unique_key) => return Ok(None), - }; - - let Some(awaited_action) = self.action_info_hash_key_to_awaited_action.get(unique_key) - else { - return Ok(None); // Not currently running. - }; - - // Do not subscribe if the action is already completed, - // this is the responsibility of the CacheLookupScheduler. - // TODO(allad) Once we land the new scheduler onto main, we can remove this check. - // It makes sense to allow users to subscribe to already completed items. - if awaited_action.get_current_state().stage.is_finished() { - return Ok(None); // Already completed. 
- } - let awaited_action = awaited_action.clone(); - if let Some(sort_info_lock) = awaited_action.upgrade_priority(priority) { - let state = awaited_action.get_current_state(); - let maybe_sorted_awaited_action = - self.get_sort_map_for_state(&state.stage) - .take(&SortedAwaitedAction { - sort_key: sort_info_lock.get_previous_sort_key(), - awaited_action: awaited_action.clone(), - }); - let Some(mut sorted_awaited_action) = maybe_sorted_awaited_action else { - return Err(make_err!( - Code::Internal, - "sorted_action_info_hash_keys and action_info_hash_key_to_awaited_action are out of sync" - )); - }; - sorted_awaited_action.sort_key = sort_info_lock.get_new_sort_key(); - self.insert_sort_map_for_stage(&state.stage, sorted_awaited_action); - } - - let subscription = awaited_action.subscribe(); - - self.client_operation_to_awaited_action - .insert( - client_operation_id.clone(), - Arc::new(ClientAwaitedAction::new( - awaited_action, - client_operation_drop_tx.clone(), - )), - ) - .await; - - Ok(Some(subscription)) - } -} diff --git a/nativelink-scheduler/src/memory_scheduler_state/awaited_action_db/client_awaited_action.rs b/nativelink-scheduler/src/memory_scheduler_state/awaited_action_db/client_awaited_action.rs deleted file mode 100644 index 5661256bd..000000000 --- a/nativelink-scheduler/src/memory_scheduler_state/awaited_action_db/client_awaited_action.rs +++ /dev/null @@ -1,82 +0,0 @@ -// Copyright 2024 The NativeLink Authors. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-// See the License for the specific language governing permissions and -// limitations under the License. - -use std::sync::Arc; - -use nativelink_util::evicting_map::LenEntry; -use tokio::sync::mpsc; - -use super::AwaitedAction; - -/// Represents a client that is currently listening to an action. -/// When the client is dropped, it will send the [`AwaitedAction`] to the -/// `client_operation_drop_tx` if there are other cleanups needed. -#[derive(Debug)] -pub(crate) struct ClientAwaitedAction { - /// The awaited action that the client is listening to. - // Note: This is an Option because it is taken when the - // ClientAwaitedAction is dropped, but will never actually be - // None except during the drop. - awaited_action: Option>, - - /// The sender to notify of this struct being dropped. - client_operation_drop_tx: mpsc::UnboundedSender>, -} - -impl ClientAwaitedAction { - pub fn new( - awaited_action: Arc, - client_operation_drop_tx: mpsc::UnboundedSender>, - ) -> Self { - awaited_action.inc_listening_clients(); - Self { - awaited_action: Some(awaited_action), - client_operation_drop_tx, - } - } - - /// Returns the awaited action that the client is listening to. - pub fn awaited_action(&self) -> &Arc { - self.awaited_action - .as_ref() - .expect("AwaitedAction should be present") - } -} - -impl Drop for ClientAwaitedAction { - fn drop(&mut self) { - let awaited_action = self - .awaited_action - .take() - .expect("AwaitedAction should be present"); - awaited_action.dec_listening_clients(); - // If we failed to send it means noone is listening. - let _ = self.client_operation_drop_tx.send(awaited_action); - } -} - -/// Trait to be able to use the EvictingMap with [`ClientAwaitedAction`]. -/// Note: We only use EvictingMap for a time based eviction, which is -/// why the implementation has fixed default values in it. 
-impl LenEntry for ClientAwaitedAction { - #[inline] - fn len(&self) -> usize { - 0 - } - - #[inline] - fn is_empty(&self) -> bool { - true - } -} diff --git a/nativelink-scheduler/src/memory_scheduler_state/awaited_action_db/mod.rs b/nativelink-scheduler/src/memory_scheduler_state/awaited_action_db/mod.rs deleted file mode 100644 index 0c3da873f..000000000 --- a/nativelink-scheduler/src/memory_scheduler_state/awaited_action_db/mod.rs +++ /dev/null @@ -1,23 +0,0 @@ -// Copyright 2024 The NativeLink Authors. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -pub(crate) use awaited_action::AwaitedAction; -pub(crate) use awaited_action_db::AwaitedActionDb; -pub(crate) use sorted_awaited_action::SortedAwaitedAction; - -mod awaited_action; -#[allow(clippy::module_inception)] -mod awaited_action_db; -mod client_awaited_action; -mod sorted_awaited_action; diff --git a/nativelink-scheduler/src/memory_scheduler_state/awaited_action_db/sorted_awaited_action.rs b/nativelink-scheduler/src/memory_scheduler_state/awaited_action_db/sorted_awaited_action.rs deleted file mode 100644 index 18a7d607a..000000000 --- a/nativelink-scheduler/src/memory_scheduler_state/awaited_action_db/sorted_awaited_action.rs +++ /dev/null @@ -1,46 +0,0 @@ -// Copyright 2024 The NativeLink Authors. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -use std::cmp; -use std::sync::Arc; - -use super::awaited_action::{AwaitedAction, AwaitedActionSortKey}; - -#[derive(Debug, Clone)] -pub struct SortedAwaitedAction { - pub sort_key: AwaitedActionSortKey, - pub awaited_action: Arc, -} - -impl PartialEq for SortedAwaitedAction { - fn eq(&self, other: &Self) -> bool { - Arc::ptr_eq(&self.awaited_action, &other.awaited_action) && self.sort_key == other.sort_key - } -} - -impl Eq for SortedAwaitedAction {} - -impl PartialOrd for SortedAwaitedAction { - fn partial_cmp(&self, other: &Self) -> Option { - Some(self.cmp(other)) - } -} - -impl Ord for SortedAwaitedAction { - fn cmp(&self, other: &Self) -> cmp::Ordering { - self.sort_key.cmp(&other.sort_key).then_with(|| { - Arc::as_ptr(&self.awaited_action).cmp(&Arc::as_ptr(&other.awaited_action)) - }) - } -} diff --git a/nativelink-scheduler/src/memory_scheduler_state/client_action_state_result.rs b/nativelink-scheduler/src/memory_scheduler_state/client_action_state_result.rs deleted file mode 100644 index b0469c3f7..000000000 --- a/nativelink-scheduler/src/memory_scheduler_state/client_action_state_result.rs +++ /dev/null @@ -1,70 +0,0 @@ -// Copyright 2024 The NativeLink Authors. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -use std::borrow::Cow; -use std::sync::Arc; - -use async_trait::async_trait; -use nativelink_error::Error; -use nativelink_util::action_messages::{ActionInfo, ActionState}; -use nativelink_util::operation_state_manager::ActionStateResult; -use nativelink_util::task::JoinHandleDropGuard; -use tokio::sync::watch::Receiver; - -pub(crate) struct ClientActionStateResult { - /// The action info for the action. - action_info: Arc, - - /// The receiver for the action state updates. - rx: Receiver>, - - /// Holds a handle to an optional spawn that will be automatically - /// canceled when this struct is dropped. - /// This is primarily used to keep the EvictionMap from dropping the - /// struct while a client is listening for updates. - _maybe_keepalive_spawn: Option>, -} - -impl ClientActionStateResult { - pub fn new( - action_info: Arc, - mut rx: Receiver>, - maybe_keepalive_spawn: Option>, - ) -> Self { - // Marking the initial value as changed for new or existing actions regardless if - // underlying state has changed. This allows for triggering notification after subscription - // without having to use an explicit notification. 
- rx.mark_changed(); - Self { - action_info, - rx, - _maybe_keepalive_spawn: maybe_keepalive_spawn, - } - } -} - -#[async_trait] -impl ActionStateResult for ClientActionStateResult { - async fn as_state(&self) -> Result, Error> { - Ok(self.rx.borrow().clone()) - } - - async fn as_receiver(&self) -> Result>>, Error> { - Ok(Cow::Borrowed(&self.rx)) - } - - async fn as_action_info(&self) -> Result, Error> { - Ok(self.action_info.clone()) - } -} diff --git a/nativelink-scheduler/src/memory_scheduler_state/matching_engine_action_state_result.rs b/nativelink-scheduler/src/memory_scheduler_state/matching_engine_action_state_result.rs deleted file mode 100644 index 4b5948169..000000000 --- a/nativelink-scheduler/src/memory_scheduler_state/matching_engine_action_state_result.rs +++ /dev/null @@ -1,48 +0,0 @@ -// Copyright 2024 The NativeLink Authors. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -use std::borrow::Cow; -use std::sync::Arc; - -use async_trait::async_trait; -use nativelink_error::Error; -use nativelink_util::action_messages::{ActionInfo, ActionState}; -use nativelink_util::operation_state_manager::ActionStateResult; -use tokio::sync::watch; - -use super::awaited_action_db::AwaitedAction; - -pub(crate) struct MatchingEngineActionStateResult { - awaited_action: Arc, -} -impl MatchingEngineActionStateResult { - pub fn new(awaited_action: Arc) -> Self { - Self { awaited_action } - } -} - -#[async_trait] -impl ActionStateResult for MatchingEngineActionStateResult { - async fn as_state(&self) -> Result, Error> { - Ok(self.awaited_action.get_current_state()) - } - - async fn as_receiver(&self) -> Result>>, Error> { - Ok(Cow::Owned(self.awaited_action.subscribe())) - } - - async fn as_action_info(&self) -> Result, Error> { - Ok(self.awaited_action.get_action_info().clone()) - } -} diff --git a/nativelink-scheduler/src/memory_scheduler_state/memory_scheduler_state_manager.rs b/nativelink-scheduler/src/memory_scheduler_state/memory_scheduler_state_manager.rs deleted file mode 100644 index 005c52e1a..000000000 --- a/nativelink-scheduler/src/memory_scheduler_state/memory_scheduler_state_manager.rs +++ /dev/null @@ -1,569 +0,0 @@ -// Copyright 2024 The NativeLink Authors. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -use std::collections::{BTreeSet, VecDeque}; -use std::ops::Bound; -use std::sync::{Arc, Weak}; -use std::time::Duration; - -use async_lock::Mutex; -use async_trait::async_trait; -use futures::stream::{self, unfold}; -use nativelink_config::stores::EvictionPolicy; -use nativelink_error::{make_err, Code, Error, ResultExt}; -use nativelink_util::action_messages::{ - ActionInfo, ActionResult, ActionStage, ActionState, ActionUniqueQualifier, ClientOperationId, - ExecutionMetadata, OperationId, WorkerId, -}; -use nativelink_util::operation_state_manager::{ - ActionStateResult, ActionStateResultStream, ClientStateManager, MatchingEngineStateManager, - OperationFilter, OperationStageFlags, OrderDirection, WorkerStateManager, -}; -use nativelink_util::spawn; -use nativelink_util::task::JoinHandleDropGuard; -use tokio::sync::{mpsc, watch, Notify}; -use tracing::{event, Level}; - -use super::awaited_action_db::{AwaitedAction, AwaitedActionDb, SortedAwaitedAction}; -use super::client_action_state_result::ClientActionStateResult; -use super::matching_engine_action_state_result::MatchingEngineActionStateResult; - -/// How often the owning database will have the AwaitedAction touched -/// to keep it from being evicted. -const KEEPALIVE_DURATION: Duration = Duration::from_secs(10); - -/// Number of client drop events to pull from the stream at a time. -const MAX_CLIENT_DROP_HANDLES_PER_CYCLE: usize = 1024; - -fn apply_filter_predicate(awaited_action: &AwaitedAction, filter: &OperationFilter) -> bool { - // Note: The caller must filter `client_operation_id`. 
- - if let Some(operation_id) = &filter.operation_id { - if operation_id != &awaited_action.get_operation_id() { - return false; - } - } - - if filter.worker_id.is_some() && filter.worker_id != awaited_action.get_worker_id() { - return false; - } - - { - let action_info = awaited_action.get_action_info(); - if let Some(filter_unique_key) = &filter.unique_key { - match &action_info.unique_qualifier { - ActionUniqueQualifier::Cachable(unique_key) => { - if filter_unique_key != unique_key { - return false; - } - } - ActionUniqueQualifier::Uncachable(_) => { - return false; - } - } - } - if let Some(action_digest) = filter.action_digest { - if action_digest != action_info.digest() { - return false; - } - } - } - - { - let last_worker_update_timestamp = awaited_action.get_last_worker_updated_timestamp(); - if let Some(worker_update_before) = filter.worker_update_before { - if worker_update_before < last_worker_update_timestamp { - return false; - } - } - let state = awaited_action.get_current_state(); - if let Some(completed_before) = filter.completed_before { - if state.stage.is_finished() && completed_before < last_worker_update_timestamp { - return false; - } - } - if filter.stages != OperationStageFlags::Any { - let stage_flag = match state.stage { - ActionStage::Unknown => OperationStageFlags::Any, - ActionStage::CacheCheck => OperationStageFlags::CacheCheck, - ActionStage::Queued => OperationStageFlags::Queued, - ActionStage::Executing => OperationStageFlags::Executing, - ActionStage::Completed(_) => OperationStageFlags::Completed, - ActionStage::CompletedFromCache(_) => OperationStageFlags::Completed, - }; - if !filter.stages.intersects(stage_flag) { - return false; - } - } - } - - true -} - -/// Utility struct to create a background task that keeps the client operation id alive. 
-fn make_client_keepalive_spawn( - client_operation_id: ClientOperationId, - inner_weak: Weak>, -) -> JoinHandleDropGuard<()> { - spawn!("client_action_state_result_keepalive", async move { - loop { - tokio::time::sleep(KEEPALIVE_DURATION).await; - let Some(inner) = inner_weak.upgrade() else { - return; // Nothing to do. - }; - let inner = inner.lock().await; - let refresh_success = inner - .action_db - .refresh_client_operation_id(&client_operation_id) - .await; - if !refresh_success { - event! { - Level::ERROR, - ?client_operation_id, - "Client operation id not found in MemorySchedulerStateManager::add_action keepalive" - }; - } - } - }) -} - -/// MemorySchedulerStateManager is responsible for maintaining the state of the scheduler. -/// Scheduler state includes the actions that are queued, active, and recently completed. -/// It also includes the workers that are available to execute actions based on allocation -/// strategy. -struct MemorySchedulerStateManagerImpl { - /// Database for storing the state of all actions. - action_db: AwaitedActionDb, - - /// Notify task<->worker matching engine that work needs to be done. - tasks_change_notify: Arc, - - /// Maximum number of times a job can be retried. - max_job_retries: usize, - - /// Channel to notify when a client operation id is dropped. - client_operation_drop_tx: mpsc::UnboundedSender>, - - /// Task to cleanup client operation ids that are no longer being listened to. - // Note: This has a custom Drop function on it. It should stay alive only while - // the MemorySchedulerStateManager is alive. 
- _client_operation_cleanup_spawn: JoinHandleDropGuard<()>, -} - -impl MemorySchedulerStateManagerImpl { - fn inner_update_operation( - &mut self, - operation_id: &OperationId, - maybe_worker_id: Option<&WorkerId>, - action_stage_result: Result, - ) -> Result<(), Error> { - let awaited_action = self - .action_db - .get_by_operation_id(operation_id) - .ok_or_else(|| { - make_err!( - Code::Internal, - "Could not find action info MemorySchedulerStateManager::update_operation {}", - format!( - "for operation_id: {operation_id}, maybe_worker_id: {maybe_worker_id:?}" - ), - ) - })? - .clone(); - - // Make sure we don't update an action that is already completed. - if awaited_action.get_current_state().stage.is_finished() { - return Err(make_err!( - Code::Internal, - "Action {operation_id:?} is already completed with state {:?} - maybe_worker_id: {:?}", - awaited_action.get_current_state().stage, - maybe_worker_id, - )); - } - - // Make sure the worker id matches the awaited action worker id. - // This might happen if the worker sending the update is not the - // worker that was assigned. - let awaited_action_worker_id = awaited_action.get_worker_id(); - if awaited_action_worker_id.is_some() - && maybe_worker_id.is_some() - && maybe_worker_id != awaited_action_worker_id.as_ref() - { - let err = make_err!( - Code::Internal, - "Worker ids do not match - {:?} != {:?} for {:?}", - maybe_worker_id, - awaited_action_worker_id, - awaited_action, - ); - event!( - Level::ERROR, - ?operation_id, - ?maybe_worker_id, - ?awaited_action_worker_id, - "{}", - err.to_string(), - ); - return Err(err); - } - - let stage = match action_stage_result { - Ok(stage) => stage, - Err(err) => { - // Don't count a backpressure failure as an attempt for an action. 
- let due_to_backpressure = err.code == Code::ResourceExhausted; - if !due_to_backpressure { - awaited_action.inc_attempts(); - } - - if awaited_action.get_attempts() > self.max_job_retries { - ActionStage::Completed(ActionResult { - execution_metadata: ExecutionMetadata { - worker: maybe_worker_id.map_or_else(String::default, |v| v.to_string()), - ..ExecutionMetadata::default() - }, - error: Some(err.clone().merge(make_err!( - Code::Internal, - "Job cancelled because it attempted to execute too many times and failed {}", - format!("for operation_id: {operation_id}, maybe_worker_id: {maybe_worker_id:?}"), - ))), - ..ActionResult::default() - }) - } else { - ActionStage::Queued - } - } - }; - if matches!(stage, ActionStage::Queued) { - // If the action is queued, we need to unset the worker id regardless of - // which worker sent the update. - awaited_action.set_worker_id(None); - } else { - awaited_action.set_worker_id(maybe_worker_id.copied()); - } - let has_listeners = self.action_db.set_action_state( - awaited_action.clone(), - Arc::new(ActionState { - stage, - id: operation_id.clone(), - }), - ); - if !has_listeners { - let action_state = awaited_action.get_current_state(); - event!( - Level::WARN, - ?awaited_action, - ?action_state, - "Action has no more listeners during AwaitedActionDb::set_action_state" - ); - } - - self.tasks_change_notify.notify_one(); - Ok(()) - } - - async fn inner_add_operation( - &mut self, - new_client_operation_id: ClientOperationId, - action_info: Arc, - ) -> Result>, Error> { - let rx = self - .action_db - .subscribe_or_add_action( - new_client_operation_id, - action_info, - &self.client_operation_drop_tx, - ) - .await - .err_tip(|| "In MemorySchedulerStateManager::add_operation")?; - self.tasks_change_notify.notify_one(); - Ok(rx) - } -} - -#[repr(transparent)] -pub struct MemorySchedulerStateManager { - inner: Arc>, -} - -impl MemorySchedulerStateManager { - pub fn new( - eviction_config: &EvictionPolicy, - tasks_change_notify: 
Arc, - max_job_retries: usize, - ) -> Self { - Self { - inner: Arc::new_cyclic(move |weak_self| -> Mutex { - let weak_inner = weak_self.clone(); - let (client_operation_drop_tx, mut client_operation_drop_rx) = - mpsc::unbounded_channel(); - let client_operation_cleanup_spawn = - spawn!("state_manager_client_drop_rx", async move { - let mut dropped_client_ids = - Vec::with_capacity(MAX_CLIENT_DROP_HANDLES_PER_CYCLE); - loop { - dropped_client_ids.clear(); - client_operation_drop_rx - .recv_many( - &mut dropped_client_ids, - MAX_CLIENT_DROP_HANDLES_PER_CYCLE, - ) - .await; - let Some(inner) = weak_inner.upgrade() else { - return; // Nothing to cleanup, our struct is dropped. - }; - let mut inner_mux = inner.lock().await; - inner_mux - .action_db - .on_client_operations_drop(dropped_client_ids.drain(..)); - } - }); - Mutex::new(MemorySchedulerStateManagerImpl { - action_db: AwaitedActionDb::new(eviction_config), - tasks_change_notify, - max_job_retries, - client_operation_drop_tx, - _client_operation_cleanup_spawn: client_operation_cleanup_spawn, - }) - }), - } - } - - async fn inner_filter_operations( - &self, - filter: &OperationFilter, - to_action_state_result: F, - ) -> Result - where - F: Fn(Arc) -> Arc + Send + Sync + 'static, - { - fn get_tree_for_stage( - action_db: &AwaitedActionDb, - stage: OperationStageFlags, - ) -> Option<&BTreeSet> { - match stage { - OperationStageFlags::CacheCheck => Some(action_db.get_cache_check_actions()), - OperationStageFlags::Queued => Some(action_db.get_queued_actions()), - OperationStageFlags::Executing => Some(action_db.get_executing_actions()), - OperationStageFlags::Completed => Some(action_db.get_completed_actions()), - _ => None, - } - } - - let inner = self.inner.lock().await; - - if let Some(operation_id) = &filter.operation_id { - return Ok(inner - .action_db - .get_by_operation_id(operation_id) - .filter(|awaited_action| apply_filter_predicate(awaited_action.as_ref(), filter)) - .cloned() - .map(|awaited_action| -> 
ActionStateResultStream { - Box::pin(stream::once(async move { - to_action_state_result(awaited_action) - })) - }) - .unwrap_or_else(|| Box::pin(stream::empty()))); - } - if let Some(client_operation_id) = &filter.client_operation_id { - return Ok(inner - .action_db - .get_by_client_operation_id(client_operation_id) - .await - .filter(|client_awaited_action| { - apply_filter_predicate(client_awaited_action.awaited_action().as_ref(), filter) - }) - .map(|client_awaited_action| -> ActionStateResultStream { - Box::pin(stream::once(async move { - to_action_state_result(client_awaited_action.awaited_action().clone()) - })) - }) - .unwrap_or_else(|| Box::pin(stream::empty()))); - } - - if get_tree_for_stage(&inner.action_db, filter.stages).is_none() { - let mut all_items: Vec> = inner - .action_db - .get_all_awaited_actions() - .filter(|awaited_action| apply_filter_predicate(awaited_action.as_ref(), filter)) - .cloned() - .collect(); - match filter.order_by_priority_direction { - Some(OrderDirection::Asc) => all_items.sort_unstable_by(|a, b| { - a.get_sort_info() - .get_new_sort_key() - .cmp(&b.get_sort_info().get_new_sort_key()) - }), - Some(OrderDirection::Desc) => all_items.sort_unstable_by(|a, b| { - b.get_sort_info() - .get_new_sort_key() - .cmp(&a.get_sort_info().get_new_sort_key()) - }), - None => {} - } - return Ok(Box::pin(stream::iter( - all_items.into_iter().map(to_action_state_result), - ))); - } - - drop(inner); - - struct State< - F: Fn(Arc) -> Arc + Send + Sync + 'static, - > { - inner: Arc>, - filter: OperationFilter, - buffer: VecDeque, - start_key: Bound, - to_action_state_result: F, - } - let state = State { - inner: self.inner.clone(), - filter: filter.clone(), - buffer: VecDeque::new(), - start_key: Bound::Unbounded, - to_action_state_result, - }; - - const STREAM_BUFF_SIZE: usize = 64; - - Ok(Box::pin(unfold(state, move |mut state| async move { - if let Some(sorted_awaited_action) = state.buffer.pop_front() { - if state.buffer.is_empty() { - 
state.start_key = Bound::Excluded(sorted_awaited_action.clone()); - } - return Some(( - (state.to_action_state_result)(sorted_awaited_action.awaited_action), - state, - )); - } - - let inner = state.inner.lock().await; - - #[allow(clippy::mutable_key_type)] - let btree = get_tree_for_stage(&inner.action_db, state.filter.stages) - .expect("get_tree_for_stage() should have already returned Some but in iteration it returned None"); - - let range = (state.start_key.as_ref(), Bound::Unbounded); - if state.filter.order_by_priority_direction == Some(OrderDirection::Asc) { - btree - .range(range) - .filter(|item| { - apply_filter_predicate(item.awaited_action.as_ref(), &state.filter) - }) - .take(STREAM_BUFF_SIZE) - .for_each(|item| state.buffer.push_back(item.clone())); - } else { - btree - .range(range) - .rev() - .filter(|item| { - apply_filter_predicate(item.awaited_action.as_ref(), &state.filter) - }) - .take(STREAM_BUFF_SIZE) - .for_each(|item| state.buffer.push_back(item.clone())); - } - drop(inner); - let sorted_awaited_action = state.buffer.pop_front()?; - if state.buffer.is_empty() { - state.start_key = Bound::Excluded(sorted_awaited_action.clone()); - } - Some(( - (state.to_action_state_result)(sorted_awaited_action.awaited_action), - state, - )) - }))) - } -} - -#[async_trait] -impl ClientStateManager for MemorySchedulerStateManager { - async fn add_action( - &self, - client_operation_id: ClientOperationId, - action_info: Arc, - ) -> Result, Error> { - let mut inner = self.inner.lock().await; - let rx = inner - .inner_add_operation(client_operation_id.clone(), action_info.clone()) - .await?; - - let inner_weak = Arc::downgrade(&self.inner); - Ok(Arc::new(ClientActionStateResult::new( - action_info, - rx, - Some(make_client_keepalive_spawn(client_operation_id, inner_weak)), - ))) - } - - async fn filter_operations( - &self, - filter: &OperationFilter, - ) -> Result { - let maybe_client_operation_id = filter.client_operation_id.clone(); - let inner_weak = 
Arc::downgrade(&self.inner); - self.inner_filter_operations(filter, move |awaited_action| { - Arc::new(ClientActionStateResult::new( - awaited_action.get_action_info().clone(), - awaited_action.subscribe(), - maybe_client_operation_id - .as_ref() - .map(|client_operation_id| { - make_client_keepalive_spawn(client_operation_id.clone(), inner_weak.clone()) - }), - )) - }) - .await - } -} - -#[async_trait] -impl WorkerStateManager for MemorySchedulerStateManager { - async fn update_operation( - &self, - operation_id: &OperationId, - worker_id: &WorkerId, - action_stage_result: Result, - ) -> Result<(), Error> { - let mut inner = self.inner.lock().await; - inner.inner_update_operation(operation_id, Some(worker_id), action_stage_result) - } -} - -#[async_trait] -impl MatchingEngineStateManager for MemorySchedulerStateManager { - async fn filter_operations( - &self, - filter: &OperationFilter, - ) -> Result { - self.inner_filter_operations(filter, |awaited_action| { - Arc::new(MatchingEngineActionStateResult::new(awaited_action)) - }) - .await - } - - async fn assign_operation( - &self, - operation_id: &OperationId, - worker_id_or_reason_for_unsassign: Result<&WorkerId, Error>, - ) -> Result<(), Error> { - let mut inner = self.inner.lock().await; - - let (maybe_worker_id, stage_result) = match worker_id_or_reason_for_unsassign { - Ok(worker_id) => (Some(worker_id), Ok(ActionStage::Executing)), - Err(err) => (None, Err(err)), - }; - inner.inner_update_operation(operation_id, maybe_worker_id, stage_result) - } -} diff --git a/nativelink-scheduler/src/memory_scheduler_state/mod.rs b/nativelink-scheduler/src/memory_scheduler_state/mod.rs deleted file mode 100644 index a06f17897..000000000 --- a/nativelink-scheduler/src/memory_scheduler_state/mod.rs +++ /dev/null @@ -1,20 +0,0 @@ -// Copyright 2024 The NativeLink Authors. All rights reserved. 
-// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -mod awaited_action_db; -mod client_action_state_result; -mod matching_engine_action_state_result; -mod memory_scheduler_state_manager; - -pub(crate) use memory_scheduler_state_manager::MemorySchedulerStateManager; diff --git a/nativelink-scheduler/src/simple_scheduler.rs b/nativelink-scheduler/src/simple_scheduler.rs index a6b36d3c4..57d0dff0e 100644 --- a/nativelink-scheduler/src/simple_scheduler.rs +++ b/nativelink-scheduler/src/simple_scheduler.rs @@ -16,28 +16,29 @@ use std::pin::Pin; use std::sync::Arc; use async_trait::async_trait; -use futures::{Future, Stream}; +use futures::Future; use nativelink_config::stores::EvictionPolicy; -use nativelink_error::{make_err, Code, Error, ResultExt}; +use nativelink_error::{Error, ResultExt}; use nativelink_util::action_messages::{ ActionInfo, ActionStage, ActionState, ClientOperationId, OperationId, WorkerId, }; use nativelink_util::metrics_utils::Registry; use nativelink_util::operation_state_manager::{ - ActionStateResult, ClientStateManager, MatchingEngineStateManager, OperationFilter, - OperationStageFlags, + ActionStateResult, ActionStateResultStream, ClientStateManager, MatchingEngineStateManager, + OperationFilter, OperationStageFlags, OrderDirection, }; use nativelink_util::spawn; use nativelink_util::task::JoinHandleDropGuard; -use tokio::sync::{watch, Notify}; +use tokio::sync::Notify; use tokio::time::Duration; use tokio_stream::StreamExt; use 
tracing::{event, Level}; use crate::action_scheduler::{ActionListener, ActionScheduler}; use crate::api_worker_scheduler::ApiWorkerScheduler; -use crate::memory_scheduler_state::MemorySchedulerStateManager; +use crate::memory_awaited_action_db::MemoryAwaitedActionDb; use crate::platform_property_manager::PlatformPropertyManager; +use crate::simple_scheduler_state_manager::SimpleSchedulerStateManager; use crate::worker::{Worker, WorkerTimestamp}; use crate::worker_scheduler::WorkerScheduler; @@ -55,19 +56,17 @@ const DEFAULT_MAX_JOB_RETRIES: usize = 3; struct SimpleSchedulerActionListener { client_operation_id: ClientOperationId, - action_state_result: Arc, - maybe_receiver: Option>>, + action_state_result: Box, } impl SimpleSchedulerActionListener { fn new( client_operation_id: ClientOperationId, - action_state_result: Arc, + action_state_result: Box, ) -> Self { Self { client_operation_id, action_state_result, - maybe_receiver: None, } } } @@ -80,29 +79,13 @@ impl ActionListener for SimpleSchedulerActionListener { fn changed( &mut self, ) -> Pin, Error>> + Send + '_>> { - let action_state_result = self.action_state_result.clone(); Box::pin(async move { - let receiver = match &mut self.maybe_receiver { - Some(receiver) => receiver, - None => { - let mut receiver = action_state_result - .as_receiver() - .await - .err_tip(|| "In SimpleSchedulerActionListener::changed getting receiver")? 
- .into_owned(); - receiver.mark_changed(); - self.maybe_receiver = Some(receiver.clone()); - self.maybe_receiver.as_mut().unwrap() - } - }; - receiver.changed().await.map_err(|_| { - make_err!( - Code::Internal, - "Sender hungup in SimpleSchedulerActionListener::changed()" - ) - })?; - let result = receiver.borrow().clone(); - Ok(result) + let action_state = self + .action_state_result + .changed() + .await + .err_tip(|| "In SimpleSchedulerActionListener::changed getting receiver")?; + Ok(action_state) }) } } @@ -111,27 +94,32 @@ impl ActionListener for SimpleSchedulerActionListener { /// the worker nodes. All state on how the workers and actions are interacting /// should be held in this struct. pub struct SimpleScheduler { + /// Manager for matching engine side of the state manager. matching_engine_state_manager: Arc, + + /// Manager for client state of this scheduler. client_state_manager: Arc, + /// Manager for platform of this scheduler. platform_property_manager: Arc, - /// Background task that tries to match actions to workers. If this struct - /// is dropped the spawn will be cancelled as well. - _task_worker_matching_spawn: JoinHandleDropGuard<()>, - /// A `Workers` pool that contains all workers that are available to execute actions in a priority /// order based on the allocation strategy. worker_scheduler: Arc, + + /// Background task that tries to match actions to workers. If this struct + /// is dropped the spawn will be cancelled as well. + _task_worker_matching_spawn: JoinHandleDropGuard<()>, } impl SimpleScheduler { /// Attempts to find a worker to execute an action and begins executing it. - /// If an action is already running that is cacheable it may merge this action - /// with the results and state changes of the already running action. - /// If the task cannot be executed immediately it will be queued for execution - /// based on priority and other metrics. - /// All further updates to the action will be provided through `listener`. 
+ /// If an action is already running that is cacheable it may merge this + /// action with the results and state changes of the already running + /// action. If the task cannot be executed immediately it will be queued + /// for execution based on priority and other metrics. + /// All further updates to the action will be provided through the returned + /// value. async fn add_action( &self, client_operation_id: ClientOperationId, @@ -152,13 +140,11 @@ impl SimpleScheduler { &self, client_operation_id: &ClientOperationId, ) -> Result>>, Error> { - let filter_result = self - .client_state_manager - .filter_operations(&OperationFilter { - client_operation_id: Some(client_operation_id.clone()), - ..Default::default() - }) - .await; + let filter = OperationFilter { + client_operation_id: Some(client_operation_id.clone()), + ..Default::default() + }; + let filter_result = self.client_state_manager.filter_operations(filter).await; let mut stream = filter_result .err_tip(|| "In SimpleScheduler::find_by_client_operation_id getting filter result")?; @@ -171,22 +157,21 @@ impl SimpleScheduler { )))) } - async fn get_queued_operations( - &self, - ) -> Result> + Send>>, Error> - { + async fn get_queued_operations(&self) -> Result { + let filter = OperationFilter { + stages: OperationStageFlags::Queued, + order_by_priority_direction: Some(OrderDirection::Desc), + ..Default::default() + }; self.matching_engine_state_manager - .filter_operations(&OperationFilter { - stages: OperationStageFlags::Queued, - ..Default::default() - }) + .filter_operations(filter) .await .err_tip(|| "In SimpleScheduler::get_queued_operations getting filter result") } - // TODO(blaise.bruer) This is an O(n*m) (aka n^2) algorithm. In theory we can create a map - // of capabilities of each worker and then try and match the actions to the worker using - // the map lookup (ie. map reduce). + // TODO(blaise.bruer) This is an O(n*m) (aka n^2) algorithm. 
In theory we + // can create a map of capabilities of each worker and then try and match + // the actions to the worker using the map lookup (ie. map reduce). async fn do_try_match(&self) -> Result<(), Error> { async fn match_action_to_worker( action_state_result: &dyn ActionStateResult, @@ -262,12 +247,13 @@ impl SimpleScheduler { ) -> (Arc, Arc) { Self::new_with_callback(scheduler_cfg, || { // The cost of running `do_try_match()` is very high, but constant - // in relation to the number of changes that have happened. This means - // that grabbing this lock to process `do_try_match()` should always - // yield to any other tasks that might want the lock. The easiest and - // most fair way to do this is to sleep for a small amount of time. - // Using something like tokio::task::yield_now() does not yield as - // aggresively as we'd like if new futures are scheduled within a future. + // in relation to the number of changes that have happened. This + // means that grabbing this lock to process `do_try_match()` should + // always yield to any other tasks that might want the lock. The + // easiest and most fair way to do this is to sleep for a small + // amount of time. Using something like tokio::task::yield_now() + // does not yield as aggressively as we'd like if new futures are + // scheduled within a future. 
tokio::time::sleep(Duration::from_millis(1)) }) } @@ -302,13 +288,13 @@ impl SimpleScheduler { } let tasks_or_worker_change_notify = Arc::new(Notify::new()); - let state_manager = Arc::new(MemorySchedulerStateManager::new( - &EvictionPolicy { - max_seconds: retain_completed_for_s, - ..Default::default() - }, + let state_manager = Arc::new(SimpleSchedulerStateManager::new( tasks_or_worker_change_notify.clone(), max_job_retries, + MemoryAwaitedActionDb::new(&EvictionPolicy { + max_seconds: retain_completed_for_s, + ..Default::default() + }), )); let worker_scheduler = ApiWorkerScheduler::new( diff --git a/nativelink-scheduler/src/simple_scheduler_state_manager.rs b/nativelink-scheduler/src/simple_scheduler_state_manager.rs new file mode 100644 index 000000000..119e63939 --- /dev/null +++ b/nativelink-scheduler/src/simple_scheduler_state_manager.rs @@ -0,0 +1,433 @@ +// Copyright 2024 The NativeLink Authors. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +use std::ops::Bound; +use std::sync::Arc; + +use async_trait::async_trait; +use futures::{future, stream, StreamExt, TryStreamExt}; +use nativelink_error::{make_err, Code, Error, ResultExt}; +use nativelink_util::action_messages::{ + ActionInfo, ActionResult, ActionStage, ActionState, ActionUniqueQualifier, ClientOperationId, + ExecutionMetadata, OperationId, WorkerId, +}; +use nativelink_util::operation_state_manager::{ + ActionStateResult, ActionStateResultStream, ClientStateManager, MatchingEngineStateManager, + OperationFilter, OperationStageFlags, OrderDirection, WorkerStateManager, +}; +use tokio::sync::Notify; +use tracing::{event, Level}; + +use super::awaited_action_db::{ + AwaitedAction, AwaitedActionDb, AwaitedActionSubscriber, SortedAwaitedActionState, +}; +use crate::memory_awaited_action_db::{ClientActionStateResult, MatchingEngineActionStateResult}; + +/// Simple struct that implements the ActionStateResult trait and always returns an error. +struct ErrorActionStateResult(Error); + +#[async_trait] +impl ActionStateResult for ErrorActionStateResult { + async fn as_state(&self) -> Result, Error> { + Err(self.0.clone()) + } + + async fn changed(&mut self) -> Result, Error> { + Err(self.0.clone()) + } + + async fn as_action_info(&self) -> Result, Error> { + Err(self.0.clone()) + } +} + +fn apply_filter_predicate(awaited_action: &AwaitedAction, filter: &OperationFilter) -> bool { + // Note: The caller must filter `client_operation_id`. 
+ + if let Some(operation_id) = &filter.operation_id { + if operation_id != awaited_action.operation_id() { + return false; + } + } + + if filter.worker_id.is_some() && filter.worker_id != awaited_action.worker_id() { + return false; + } + + { + if let Some(filter_unique_key) = &filter.unique_key { + match &awaited_action.action_info().unique_qualifier { + ActionUniqueQualifier::Cachable(unique_key) => { + if filter_unique_key != unique_key { + return false; + } + } + ActionUniqueQualifier::Uncachable(_) => { + return false; + } + } + } + if let Some(action_digest) = filter.action_digest { + if action_digest != awaited_action.action_info().digest() { + return false; + } + } + } + + { + let last_worker_update_timestamp = awaited_action.last_worker_updated_timestamp(); + if let Some(worker_update_before) = filter.worker_update_before { + if worker_update_before < last_worker_update_timestamp { + return false; + } + } + if let Some(completed_before) = filter.completed_before { + if awaited_action.state().stage.is_finished() + && completed_before < last_worker_update_timestamp + { + return false; + } + } + if filter.stages != OperationStageFlags::Any { + let stage_flag = match awaited_action.state().stage { + ActionStage::Unknown => OperationStageFlags::Any, + ActionStage::CacheCheck => OperationStageFlags::CacheCheck, + ActionStage::Queued => OperationStageFlags::Queued, + ActionStage::Executing => OperationStageFlags::Executing, + ActionStage::Completed(_) => OperationStageFlags::Completed, + ActionStage::CompletedFromCache(_) => OperationStageFlags::Completed, + }; + if !filter.stages.intersects(stage_flag) { + return false; + } + } + } + + true +} + +/// SimpleSchedulerStateManager is responsible for maintaining the state of the scheduler. +/// Scheduler state includes the actions that are queued, active, and recently completed. +/// It also includes the workers that are available to execute actions based on allocation +/// strategy. 
+pub struct SimpleSchedulerStateManager { + /// Database for storing the state of all actions. + action_db: T, + + /// Notify matching engine that work needs to be done. + tasks_change_notify: Arc, + + /// Maximum number of times a job can be retried. + // TODO(allada) This should be a scheduler decorator instead + // of always having it on every SimpleScheduler. + max_job_retries: usize, +} + +impl SimpleSchedulerStateManager { + pub fn new(tasks_change_notify: Arc, max_job_retries: usize, action_db: T) -> Self { + Self { + action_db, + tasks_change_notify, + max_job_retries, + } + } + + async fn inner_update_operation( + &self, + operation_id: &OperationId, + maybe_worker_id: Option<&WorkerId>, + action_stage_result: Result, + ) -> Result<(), Error> { + let maybe_awaited_action_subscriber = self + .action_db + .get_by_operation_id(operation_id) + .await + .err_tip(|| "In MemorySchedulerStateManager::update_operation")?; + let awaited_action_subscriber = match maybe_awaited_action_subscriber { + Some(sub) => sub, + // No action found. It is ok if the action was not found. It probably + // means that the action was dropped, but worker was still processing + // it. + None => return Ok(()), + }; + + let mut awaited_action = awaited_action_subscriber.borrow(); + + // Make sure we don't update an action that is already completed. + if awaited_action.state().stage.is_finished() { + return Err(make_err!( + Code::Internal, + "Action {operation_id:?} is already completed with state {:?} - maybe_worker_id: {:?}", + awaited_action.state().stage, + maybe_worker_id, + )); + } + + // Make sure the worker id matches the awaited action worker id. + // This might happen if the worker sending the update is not the + // worker that was assigned. 
+ if awaited_action.worker_id().is_some() + && maybe_worker_id.is_some() + && maybe_worker_id != awaited_action.worker_id().as_ref() + { + let err = make_err!( + Code::Internal, + "Worker ids do not match - {:?} != {:?} for {:?}", + maybe_worker_id, + awaited_action.worker_id(), + awaited_action, + ); + event!( + Level::ERROR, + ?operation_id, + ?maybe_worker_id, + ?awaited_action, + "{}", + err.to_string(), + ); + return Err(err); + } + + let stage = match action_stage_result { + Ok(stage) => stage, + Err(err) => { + // Don't count a backpressure failure as an attempt for an action. + let due_to_backpressure = err.code == Code::ResourceExhausted; + if !due_to_backpressure { + awaited_action.attempts += 1; + } + + if awaited_action.attempts > self.max_job_retries { + ActionStage::Completed(ActionResult { + execution_metadata: ExecutionMetadata { + worker: maybe_worker_id.map_or_else(String::default, |v| v.to_string()), + ..ExecutionMetadata::default() + }, + error: Some(err.clone().merge(make_err!( + Code::Internal, + "Job cancelled because it attempted to execute too many times and failed {}", + format!("for operation_id: {operation_id}, maybe_worker_id: {maybe_worker_id:?}"), + ))), + ..ActionResult::default() + }) + } else { + ActionStage::Queued + } + } + }; + if matches!(stage, ActionStage::Queued) { + // If the action is queued, we need to unset the worker id regardless of + // which worker sent the update. 
+ awaited_action.set_worker_id(None); + } else { + awaited_action.set_worker_id(maybe_worker_id.copied()); + } + awaited_action.set_state(Arc::new(ActionState { + stage, + id: operation_id.clone(), + })); + self.action_db + .update_awaited_action(awaited_action) + .await + .err_tip(|| "In MemorySchedulerStateManager::update_operation")?; + + self.tasks_change_notify.notify_one(); + Ok(()) + } + + async fn inner_add_operation( + &self, + new_client_operation_id: ClientOperationId, + action_info: Arc, + ) -> Result { + let rx = self + .action_db + .add_action(new_client_operation_id, action_info) + .await + .err_tip(|| "In MemorySchedulerStateManager::add_operation")?; + self.tasks_change_notify.notify_one(); + Ok(rx) + } + + async fn inner_filter_operations<'a, F>( + &'a self, + filter: OperationFilter, + to_action_state_result: F, + ) -> Result, Error> + where + F: Fn(T::Subscriber) -> Box + Send + Sync + 'a, + { + fn sorted_awaited_action_state_for_flags( + stage: OperationStageFlags, + ) -> Option { + match stage { + OperationStageFlags::CacheCheck => Some(SortedAwaitedActionState::CacheCheck), + OperationStageFlags::Queued => Some(SortedAwaitedActionState::Queued), + OperationStageFlags::Executing => Some(SortedAwaitedActionState::Executing), + OperationStageFlags::Completed => Some(SortedAwaitedActionState::Completed), + _ => None, + } + } + + if let Some(operation_id) = &filter.operation_id { + return Ok(self + .action_db + .get_by_operation_id(operation_id) + .await + .err_tip(|| "In MemorySchedulerStateManager::filter_operations")? 
+ .filter(|awaited_action_rx| { + let awaited_action = awaited_action_rx.borrow(); + apply_filter_predicate(&awaited_action, &filter) + }) + .map(|awaited_action| -> ActionStateResultStream { + Box::pin(stream::once(async move { + to_action_state_result(awaited_action) + })) + }) + .unwrap_or_else(|| Box::pin(stream::empty()))); + } + if let Some(client_operation_id) = &filter.client_operation_id { + return Ok(self + .action_db + .get_awaited_action_by_id(client_operation_id) + .await + .err_tip(|| "In MemorySchedulerStateManager::filter_operations")? + .filter(|awaited_action_rx| { + let awaited_action = awaited_action_rx.borrow(); + apply_filter_predicate(&awaited_action, &filter) + }) + .map(|awaited_action| -> ActionStateResultStream { + Box::pin(stream::once(async move { + to_action_state_result(awaited_action) + })) + }) + .unwrap_or_else(|| Box::pin(stream::empty()))); + } + + let Some(sorted_awaited_action_state) = + sorted_awaited_action_state_for_flags(filter.stages) + else { + let mut all_items: Vec = self + .action_db + .get_all_awaited_actions() + .await + .try_filter(|awaited_action_subscriber| { + future::ready(apply_filter_predicate( + &awaited_action_subscriber.borrow(), + &filter, + )) + }) + .try_collect() + .await + .err_tip(|| "In MemorySchedulerStateManager::filter_operations")?; + match filter.order_by_priority_direction { + Some(OrderDirection::Asc) => { + all_items.sort_unstable_by_key(|a| a.borrow().sort_key()) + } + Some(OrderDirection::Desc) => { + all_items.sort_unstable_by_key(|a| std::cmp::Reverse(a.borrow().sort_key())) + } + None => {} + } + return Ok(Box::pin(stream::iter( + all_items.into_iter().map(to_action_state_result), + ))); + }; + + let desc = matches!( + filter.order_by_priority_direction, + Some(OrderDirection::Desc) + ); + let filter = filter.clone(); + let stream = self + .action_db + .get_range_of_actions( + sorted_awaited_action_state, + Bound::Unbounded, + Bound::Unbounded, + desc, + ) + .await + .try_filter(move 
|sub| future::ready(apply_filter_predicate(&sub.borrow(), &filter))) + .map(move |result| -> Box { + result.map_or_else( + |e| -> Box { Box::new(ErrorActionStateResult(e)) }, + |v| -> Box { to_action_state_result(v) }, + ) + }); + Ok(Box::pin(stream)) + } +} + +#[async_trait] +impl ClientStateManager for SimpleSchedulerStateManager { + async fn add_action( + &self, + client_operation_id: ClientOperationId, + action_info: Arc, + ) -> Result, Error> { + let sub = self + .inner_add_operation(client_operation_id.clone(), action_info.clone()) + .await?; + + Ok(Box::new(ClientActionStateResult::new(sub))) + } + + async fn filter_operations<'a>( + &'a self, + filter: OperationFilter, + ) -> Result, Error> { + self.inner_filter_operations(filter, move |rx| Box::new(ClientActionStateResult::new(rx))) + .await + } +} + +#[async_trait] +impl WorkerStateManager for SimpleSchedulerStateManager { + async fn update_operation( + &self, + operation_id: &OperationId, + worker_id: &WorkerId, + action_stage_result: Result, + ) -> Result<(), Error> { + self.inner_update_operation(operation_id, Some(worker_id), action_stage_result) + .await + } +} + +#[async_trait] +impl MatchingEngineStateManager for SimpleSchedulerStateManager { + async fn filter_operations<'a>( + &'a self, + filter: OperationFilter, + ) -> Result, Error> { + self.inner_filter_operations(filter, |rx| { + Box::new(MatchingEngineActionStateResult::new(rx)) + }) + .await + } + + async fn assign_operation( + &self, + operation_id: &OperationId, + worker_id_or_reason_for_unsassign: Result<&WorkerId, Error>, + ) -> Result<(), Error> { + let (maybe_worker_id, stage_result) = match worker_id_or_reason_for_unsassign { + Ok(worker_id) => (Some(worker_id), Ok(ActionStage::Executing)), + Err(err) => (None, Err(err)), + }; + self.inner_update_operation(operation_id, maybe_worker_id, stage_result) + .await + } +} diff --git a/nativelink-scheduler/tests/simple_scheduler_test.rs b/nativelink-scheduler/tests/simple_scheduler_test.rs 
index 13df66aa2..e2be299ac 100644 --- a/nativelink-scheduler/tests/simple_scheduler_test.rs +++ b/nativelink-scheduler/tests/simple_scheduler_test.rs @@ -1582,8 +1582,7 @@ async fn worker_retries_on_internal_error_and_fails_test() -> Result<(), Error> if let Some(real_err) = &mut stage.error { assert!( real_err.to_string().contains("Job cancelled because it attempted to execute too many times and failed"), - "{} did not contain 'Job cancelled because it attempted to execute too many times and failed'", - real_err.to_string(), + "{real_err} did not contain 'Job cancelled because it attempted to execute too many times and failed'", ); *real_err = err; } diff --git a/nativelink-util/BUILD.bazel b/nativelink-util/BUILD.bazel index f6bfc494d..8c8e3d0d7 100644 --- a/nativelink-util/BUILD.bazel +++ b/nativelink-util/BUILD.bazel @@ -19,6 +19,8 @@ rust_library( "src/fastcdc.rs", "src/fs.rs", "src/health_utils.rs", + "src/chunked_stream.rs", + "src/operation_state_manager.rs", "src/lib.rs", "src/metrics_utils.rs", "src/origin_context.rs", @@ -44,10 +46,12 @@ rust_library( "@crates//:bytes", "@crates//:console-subscriber", "@crates//:futures", + "@crates//:bitflags", "@crates//:hex", "@crates//:hyper", "@crates//:lru", "@crates//:parking_lot", + "@crates//:pin-project", "@crates//:pin-project-lite", "@crates//:prometheus-client", "@crates//:prost", diff --git a/nativelink-util/Cargo.toml b/nativelink-util/Cargo.toml index 2adfe7409..58cf80d66 100644 --- a/nativelink-util/Cargo.toml +++ b/nativelink-util/Cargo.toml @@ -13,6 +13,7 @@ async-trait = "0.1.80" bitflags = "2.5.0" blake3 = { version = "1.5.1", features = ["mmap"] } bytes = "1.6.0" +pin-project = "1.1.5" console-subscriber = { version = "0.3.0" } futures = "0.3.30" hex = "0.4.3" diff --git a/nativelink-util/src/action_messages.rs b/nativelink-util/src/action_messages.rs index 089515769..e2e6351c8 100644 --- a/nativelink-util/src/action_messages.rs +++ b/nativelink-util/src/action_messages.rs @@ -104,7 +104,7 @@ impl 
Hash for OperationId { impl OperationId { pub fn new(unique_qualifier: ActionUniqueQualifier) -> Self { Self { - id: uuid::Uuid::new_v4(), + id: Uuid::new_v4(), unique_qualifier, } } diff --git a/nativelink-util/src/chunked_stream.rs b/nativelink-util/src/chunked_stream.rs new file mode 100644 index 000000000..e3665562d --- /dev/null +++ b/nativelink-util/src/chunked_stream.rs @@ -0,0 +1,110 @@ +// Copyright 2024 The NativeLink Authors. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::collections::VecDeque; +use std::ops::Bound; +use std::pin::Pin; +use std::task::{Context, Poll}; + +use futures::{Future, Stream}; +use pin_project::pin_project; + +#[pin_project(project = StreamStateProj)] +enum StreamState { + Future(#[pin] Fut), + Next, +} + +/// Takes a range of keys and a function that returns a future that yields +/// an iterator of key-value pairs. The stream will yield all key-value pairs +/// in the range, in order and buffered. A great use case is where you need +/// to implement Stream, but to access the underlying data requires a lock, +/// but API does not require the data to be in sync with data already received. 
+#[pin_project] +pub struct ChunkedStream +where + K: Ord, + F: FnMut(Bound, Bound, VecDeque) -> Fut, + Fut: Future, Bound), VecDeque)>, E>>, +{ + chunk_fn: F, + buffer: VecDeque, + start_key: Option>, + end_key: Option>, + #[pin] + stream_state: StreamState, +} + +impl ChunkedStream +where + K: Ord, + F: FnMut(Bound, Bound, VecDeque) -> Fut, + Fut: Future, Bound), VecDeque)>, E>>, +{ + pub fn new(start_key: Bound, end_key: Bound, chunk_fn: F) -> Self { + Self { + chunk_fn, + buffer: VecDeque::new(), + start_key: Some(start_key), + end_key: Some(end_key), + stream_state: StreamState::Next, + } + } +} + +impl Stream for ChunkedStream +where + K: Ord, + F: FnMut(Bound, Bound, VecDeque) -> Fut, + Fut: Future, Bound), VecDeque)>, E>>, +{ + type Item = Result; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let mut this = self.project(); + loop { + if let Some(item) = this.buffer.pop_front() { + return Poll::Ready(Some(Ok(item))); + } + match this.stream_state.as_mut().project() { + StreamStateProj::Future(fut) => { + match futures::ready!(fut.poll(cx)) { + Ok(Some(((start, end), mut buffer))) => { + *this.start_key = Some(start); + *this.end_key = Some(end); + std::mem::swap(&mut buffer, this.buffer); + } + Ok(None) => return Poll::Ready(None), // End of stream. + Err(err) => return Poll::Ready(Some(Err(err))), + } + this.stream_state.set(StreamState::Next); + // Loop again. + } + StreamStateProj::Next => { + this.buffer.clear(); + // This trick is used to recycle capacity. + let buffer = std::mem::take(this.buffer); + let start_key = this + .start_key + .take() + .expect("start_key should never be None"); + let end_key = this.end_key.take().expect("end_key should never be None"); + let fut = (this.chunk_fn)(start_key, end_key, buffer); + this.stream_state.set(StreamState::Future(fut)); + // Loop again. 
+ } + } + } + } +} diff --git a/nativelink-util/src/lib.rs b/nativelink-util/src/lib.rs index e735c86d0..d7e4f7bee 100644 --- a/nativelink-util/src/lib.rs +++ b/nativelink-util/src/lib.rs @@ -14,6 +14,7 @@ pub mod action_messages; pub mod buf_channel; +pub mod chunked_stream; pub mod common; pub mod connection_manager; pub mod default_store_key_subscribe; diff --git a/nativelink-util/src/operation_state_manager.rs b/nativelink-util/src/operation_state_manager.rs index 090b6e8b0..c2a18c041 100644 --- a/nativelink-util/src/operation_state_manager.rs +++ b/nativelink-util/src/operation_state_manager.rs @@ -12,20 +12,19 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::borrow::Cow; use std::pin::Pin; use std::sync::Arc; use std::time::SystemTime; -use crate::action_messages::{ - ActionInfo, ActionStage, ActionState, ActionUniqueKey, ClientOperationId, OperationId, WorkerId, -}; -use crate::common::DigestInfo; use async_trait::async_trait; use bitflags::bitflags; use futures::Stream; use nativelink_error::Error; -use tokio::sync::watch; + +use crate::action_messages::{ + ActionInfo, ActionStage, ActionState, ActionUniqueKey, ClientOperationId, OperationId, WorkerId, +}; +use crate::common::DigestInfo; bitflags! { #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] @@ -48,8 +47,8 @@ impl Default for OperationStageFlags { pub trait ActionStateResult: Send + Sync + 'static { // Provides the current state of the action. async fn as_state(&self) -> Result, Error>; - // Subscribes to the state of the action, receiving updates as they are published. - async fn as_receiver(&self) -> Result>>, Error>; + // Waits for the state of the action to change. + async fn changed(&mut self) -> Result, Error>; // Provide result as action info. This behavior will not be supported by all implementations. 
async fn as_action_info(&self) -> Result, Error>; } @@ -93,26 +92,27 @@ pub struct OperationFilter { pub order_by_priority_direction: Option, } -pub type ActionStateResultStream = Pin> + Send>>; +pub type ActionStateResultStream<'a> = + Pin> + Send + 'a>>; #[async_trait] -pub trait ClientStateManager: Sync + Send + 'static { +pub trait ClientStateManager: Sync + Send { /// Add a new action to the queue or joins an existing action. async fn add_action( &self, client_operation_id: ClientOperationId, action_info: Arc, - ) -> Result, Error>; + ) -> Result, Error>; /// Returns a stream of operations that match the filter. - async fn filter_operations( - &self, - filter: &OperationFilter, - ) -> Result; + async fn filter_operations<'a>( + &'a self, + filter: OperationFilter, + ) -> Result, Error>; } #[async_trait] -pub trait WorkerStateManager: Sync + Send + 'static { +pub trait WorkerStateManager: Sync + Send { /// Update that state of an operation. /// The worker must also send periodic updates even if the state /// did not change with a modified timestamp in order to prevent @@ -126,12 +126,12 @@ pub trait WorkerStateManager: Sync + Send + 'static { } #[async_trait] -pub trait MatchingEngineStateManager: Sync + Send + 'static { +pub trait MatchingEngineStateManager: Sync + Send { /// Returns a stream of operations that match the filter. - async fn filter_operations( - &self, - filter: &OperationFilter, - ) -> Result; + async fn filter_operations<'a>( + &'a self, + filter: OperationFilter, + ) -> Result, Error>; /// Assign an operation to a worker or unassign it. async fn assign_operation(