diff --git a/Cargo.lock b/Cargo.lock
index 864aed5e5..2405dcc9a 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -723,9 +723,9 @@ dependencies = [
 
 [[package]]
 name = "blake3"
-version = "1.5.2"
+version = "1.5.3"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "3d08263faac5cde2a4d52b513dadb80846023aade56fcd8fc99ba73ba8050e92"
+checksum = "e9ec96fe9a81b5e365f9db71fe00edc4fe4ca2cc7dcb7861f0603012a7caa210"
 dependencies = [
  "arrayref",
  "arrayvec",
@@ -825,9 +825,9 @@ dependencies = [
 
 [[package]]
 name = "cc"
-version = "1.1.2"
+version = "1.1.5"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "47de7e88bbbd467951ae7f5a6f34f70d1b4d9cfce53d5fd70f74ebe118b3db56"
+checksum = "324c74f2155653c90b04f25b2a47a8a631360cb908f92a772695f430c7e31052"
 
 [[package]]
 name = "cfg-if"
@@ -1951,7 +1951,6 @@ version = "0.4.0"
 dependencies = [
  "async-lock",
  "async-trait",
- "bitflags 2.6.0",
  "blake3",
  "futures",
  "hashbrown 0.14.5",
@@ -1971,6 +1970,7 @@ dependencies = [
  "scopeguard",
  "serde",
  "serde_json",
+ "static_assertions",
  "tokio",
  "tokio-stream",
  "tonic",
@@ -1982,6 +1982,8 @@ dependencies = [
 name = "nativelink-service"
 version = "0.4.0"
 dependencies = [
+ "async-lock",
+ "async-trait",
  "bytes",
  "futures",
  "hyper 0.14.30",
@@ -2064,6 +2066,7 @@ version = "0.4.0"
 dependencies = [
  "async-lock",
  "async-trait",
+ "bitflags 2.6.0",
  "blake3",
  "bytes",
  "console-subscriber",
@@ -2078,6 +2081,7 @@ dependencies = [
  "nativelink-macro",
  "nativelink-proto",
  "parking_lot",
+ "pin-project",
  "pin-project-lite",
  "pretty_assertions",
  "prometheus-client",
@@ -2259,7 +2263,7 @@ checksum = "1e401f977ab385c9e4e3ab30627d6f26d00e2c73eef317493c4ec6d468726cf8"
 dependencies = [
  "cfg-if",
  "libc",
- "redox_syscall 0.5.2",
+ "redox_syscall 0.5.3",
  "smallvec",
  "windows-targets 0.52.6",
 ]
@@ -2660,9 +2664,9 @@ dependencies = [
 
 [[package]]
 name = "redox_syscall"
-version = "0.5.2"
+version = "0.5.3"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "c82cf8cff14456045f55ec4241383baeff27af886adb72ffb2162f99911de0fd"
+checksum = "2a908a6e00f1fdd0dfd9c0eb08ce85126f6d8bbda50017e74bc4a4b7d4a926a4"
 dependencies = [
  "bitflags 2.6.0",
 ]
@@ -2962,9 +2966,9 @@ checksum = "f3cb5ba0dc43242ce17de99c180e96db90b235b8a9fdc9543c96d2209116bd9f"
 
 [[package]]
 name = "scc"
-version = "2.1.2"
+version = "2.1.4"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "af947d0ca10a2f3e00c7ec1b515b7c83e5cb3fa62d4c11a64301d9eec54440e9"
+checksum = "a4465c22496331e20eb047ff46e7366455bc01c0c02015c4a376de0b2cd3a1af"
 dependencies = [
  "sdd",
 ]
@@ -2996,9 +3000,9 @@ dependencies = [
 
 [[package]]
 name = "sdd"
-version = "0.2.0"
+version = "1.5.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "b84345e4c9bd703274a082fb80caaa99b7612be48dfaa1dd9266577ec412309d"
+checksum = "1e806d6633ef141556fef75e345275e35652e9c045bbbc21e6ecfce3e9aa2638"
 
 [[package]]
 name = "seahash"
@@ -3022,9 +3026,9 @@ dependencies = [
 
 [[package]]
 name = "security-framework"
-version = "2.11.0"
+version = "2.11.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "c627723fd09706bacdb5cf41499e95098555af3c3c29d014dc3c458ef6be11c0"
+checksum = "897b2245f0b511c87893af39b033e5ca9cce68824c4d7e7630b5a1d339658d02"
 dependencies = [
  "bitflags 2.6.0",
  "core-foundation",
@@ -3035,9 +3039,9 @@ dependencies = [
 
 [[package]]
 name = "security-framework-sys"
-version = "2.11.0"
+version = "2.11.1"
 source = "registry+https://github.com/rust-lang/crates.io-index"
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "317936bbbd05227752583946b9e66d7ce3b489f84e11a94a510b4437fef407d7" +checksum = "75da29fe9b9b08fe9d6b22b5b4bcbc75d8db3aa31e639aa56bb62e9d46bfceaf" dependencies = [ "core-foundation-sys", "libc", @@ -3410,9 +3414,9 @@ checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" [[package]] name = "tokio" -version = "1.38.0" +version = "1.38.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ba4f4a02a7a80d6f274636f0aa95c7e383b912d41fe721a31f29e29698585a4a" +checksum = "eb2caba9f80616f438e09748d5acda951967e1ea58508ef53d9c6402485a46df" dependencies = [ "backtrace", "bytes", diff --git a/nativelink-config/src/schedulers.rs b/nativelink-config/src/schedulers.rs index 76f475734..0af053bdf 100644 --- a/nativelink-config/src/schedulers.rs +++ b/nativelink-config/src/schedulers.rs @@ -98,7 +98,7 @@ pub struct SimpleScheduler { /// a WaitExecution is called after the action has completed. /// Default: 60 (seconds) #[serde(default, deserialize_with = "convert_duration_with_shellexpand")] - pub retain_completed_for_s: u64, + pub retain_completed_for_s: u32, /// Remove workers from pool once the worker has not responded in this /// amount of time in seconds. diff --git a/nativelink-proto/com/github/trace_machina/nativelink/remote_execution/worker_api.proto b/nativelink-proto/com/github/trace_machina/nativelink/remote_execution/worker_api.proto index 99e191dac..09598d780 100644 --- a/nativelink-proto/com/github/trace_machina/nativelink/remote_execution/worker_api.proto +++ b/nativelink-proto/com/github/trace_machina/nativelink/remote_execution/worker_api.proto @@ -99,25 +99,8 @@ message ExecuteResult { /// that initially sent the job as part of the BRE protocol. string instance_name = 6; - /// The original execution digest request for this response. The scheduler knows what it - /// should be, but we do safety checks to ensure it really is the request we expected. - build.bazel.remote.execution.v2.Digest action_digest = 2; - - /// The salt originally sent along with the StartExecute request. This salt is used - /// as a seed for cases where the execution digest should never be cached or merged - /// with other jobs. This salt is added to the hash function used to compute jobs that - /// are running or cached. - uint64 salt = 3; - - // The digest function that was used to compute the action digest - // and all related blobs. - // - // If the digest function used is one of MD5, MURMUR3, SHA1, SHA256, - // SHA384, SHA512, or VSO, the client MAY leave this field unset. In - // that case the server SHOULD infer the digest function using the - // length of the action digest hash and the digest functions announced - // in the server's capabilities. - build.bazel.remote.execution.v2.DigestFunction.Value digest_function = 7; + /// The operation ID that was executed. + string operation_id = 8; /// The actual response data. oneof result { @@ -131,7 +114,7 @@ message ExecuteResult { google.rpc.Status internal_error = 5; } - reserved 8; // NextId. + reserved 9; // NextId. } /// Result sent back from the server when a node connects. @@ -141,10 +124,10 @@ message ConnectionResult { reserved 2; // NextId. } -/// Request to kill a running action sent from the scheduler to a worker. -message KillActionRequest { - /// The the hex encoded unique qualifier for the action to be killed. - string action_id = 1; +/// Request to kill a running operation sent from the scheduler to a worker. 
+message KillOperationRequest {
+  /// The operation id for the operation to be killed.
+  string operation_id = 1;
 
   reserved 2; // NextId.
 }
 
 /// Communication from the scheduler to the worker.
@@ -169,8 +152,8 @@ message UpdateForWorker {
     /// The worker may discard any outstanding work that is being executed.
     google.protobuf.Empty disconnect = 4;
 
-    /// Instructs the worker to kill a specific running action.
-    KillActionRequest kill_action_request = 5;
+    /// Instructs the worker to kill a specific running operation.
+    KillOperationRequest kill_operation_request = 5;
   }
   reserved 6; // NextId.
 }
 
@@ -179,14 +162,14 @@ message StartExecute {
   /// The action information used to execute job.
   build.bazel.remote.execution.v2.ExecuteRequest execute_request = 1;
 
-  /// See documentation in ExecuteResult::salt.
-  uint64 salt = 2;
+  /// Id of the operation.
+  string operation_id = 4;
 
   /// The time at which the command was added to the queue to allow population
   /// of the ActionResult.
   google.protobuf.Timestamp queued_timestamp = 3;
 
-  reserved 4; // NextId.
+  reserved 5; // NextId.
 }
 
 /// This is a special message used to save actions into the CAS that can be used
diff --git a/nativelink-proto/genproto/com.github.trace_machina.nativelink.remote_execution.pb.rs b/nativelink-proto/genproto/com.github.trace_machina.nativelink.remote_execution.pb.rs
index d4e9eae70..268b5e3ce 100644
--- a/nativelink-proto/genproto/com.github.trace_machina.nativelink.remote_execution.pb.rs
+++ b/nativelink-proto/genproto/com.github.trace_machina.nativelink.remote_execution.pb.rs
@@ -60,31 +60,9 @@ pub struct ExecuteResult {
     /// / that initially sent the job as part of the BRE protocol.
     #[prost(string, tag = "6")]
     pub instance_name: ::prost::alloc::string::String,
-    /// / The original execution digest request for this response. The scheduler knows what it
-    /// / should be, but we do safety checks to ensure it really is the request we expected.
-    #[prost(message, optional, tag = "2")]
-    pub action_digest: ::core::option::Option<
-        super::super::super::super::super::build::bazel::remote::execution::v2::Digest,
-    >,
-    /// / The salt originally sent along with the StartExecute request. This salt is used
-    /// / as a seed for cases where the execution digest should never be cached or merged
-    /// / with other jobs. This salt is added to the hash function used to compute jobs that
-    /// / are running or cached.
-    #[prost(uint64, tag = "3")]
-    pub salt: u64,
-    /// The digest function that was used to compute the action digest
-    /// and all related blobs.
-    ///
-    /// If the digest function used is one of MD5, MURMUR3, SHA1, SHA256,
-    /// SHA384, SHA512, or VSO, the client MAY leave this field unset. In
-    /// that case the server SHOULD infer the digest function using the
-    /// length of the action digest hash and the digest functions announced
-    /// in the server's capabilities.
-    #[prost(
-        enumeration = "super::super::super::super::super::build::bazel::remote::execution::v2::digest_function::Value",
-        tag = "7"
-    )]
-    pub digest_function: i32,
+    /// / The operation ID that was executed.
+    #[prost(string, tag = "8")]
+    pub operation_id: ::prost::alloc::string::String,
     /// / The actual response data.
     #[prost(oneof = "execute_result::Result", tags = "4, 5")]
     pub result: ::core::option::Option<execute_result::Result>,
@@ -116,13 +94,13 @@ pub struct ConnectionResult {
     #[prost(string, tag = "1")]
     pub worker_id: ::prost::alloc::string::String,
 }
-/// / Request to kill a running action sent from the scheduler to a worker.
+/// / Request to kill a running operation sent from the scheduler to a worker.
 #[allow(clippy::derive_partial_eq_without_eq)]
 #[derive(Clone, PartialEq, ::prost::Message)]
-pub struct KillActionRequest {
-    /// / The the hex encoded unique qualifier for the action to be killed.
+pub struct KillOperationRequest {
+    /// / The operation id for the operation to be killed.
     #[prost(string, tag = "1")]
-    pub action_id: ::prost::alloc::string::String,
+    pub operation_id: ::prost::alloc::string::String,
 }
 /// / Communication from the scheduler to the worker.
 #[allow(clippy::derive_partial_eq_without_eq)]
@@ -155,9 +133,9 @@ pub mod update_for_worker {
         /// / The worker may discard any outstanding work that is being executed.
         #[prost(message, tag = "4")]
         Disconnect(()),
-        /// / Instructs the worker to kill a specific running action.
+        /// / Instructs the worker to kill a specific running operation.
         #[prost(message, tag = "5")]
-        KillActionRequest(super::KillActionRequest),
+        KillOperationRequest(super::KillOperationRequest),
     }
 }
 #[allow(clippy::derive_partial_eq_without_eq)]
@@ -168,9 +146,9 @@ pub struct StartExecute {
     pub execute_request: ::core::option::Option<
         super::super::super::super::super::build::bazel::remote::execution::v2::ExecuteRequest,
     >,
-    /// / See documentation in ExecuteResult::salt.
-    #[prost(uint64, tag = "2")]
-    pub salt: u64,
+    /// / Id of the operation.
+    #[prost(string, tag = "4")]
+    pub operation_id: ::prost::alloc::string::String,
     /// / The time at which the command was added to the queue to allow population
     /// / of the ActionResult.
     #[prost(message, optional, tag = "3")]
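For orientation, a minimal sketch of populating the regenerated message. The helper name is invented and the import path assumes the usual `nativelink_proto` module layout for this generated file; field names follow the generated struct above:

use nativelink_proto::build::bazel::remote::execution::v2::ExecuteRequest;
use nativelink_proto::com::github::trace_machina::nativelink::remote_execution::StartExecute;

// Hypothetical helper: the caller supplies the ExecuteRequest and operation id.
fn make_start_execute(execute_request: ExecuteRequest, operation_id: &str) -> StartExecute {
    StartExecute {
        execute_request: Some(execute_request),
        // A single string operation id replaces the old `salt` + `action_digest` pair.
        operation_id: operation_id.to_string(),
        queued_timestamp: Some(prost_types::Timestamp::from(std::time::SystemTime::now())),
    }
}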
features = ["sync"] } tonic = { version = "0.11.0", features = ["gzip", "tls"] } tracing = "0.1.40" -bitflags = "2.5.0" redis = { version = "0.25.2", features = ["aio", "tokio", "json"] } serde = "1.0.203" redis-macros = "0.3.0" serde_json = "1.0.117" +static_assertions = "1.1.0" [dev-dependencies] nativelink-macro = { path = "../nativelink-macro" } diff --git a/nativelink-scheduler/src/action_scheduler.rs b/nativelink-scheduler/src/action_scheduler.rs index 5460209d7..7e7ac5f24 100644 --- a/nativelink-scheduler/src/action_scheduler.rs +++ b/nativelink-scheduler/src/action_scheduler.rs @@ -12,16 +12,29 @@ // See the License for the specific language governing permissions and // limitations under the License. +use std::pin::Pin; use std::sync::Arc; use async_trait::async_trait; +use futures::Future; use nativelink_error::Error; -use nativelink_util::action_messages::{ActionInfo, ActionInfoHashKey, ActionState}; +use nativelink_util::action_messages::{ActionInfo, ActionState, ClientOperationId}; use nativelink_util::metrics_utils::Registry; -use tokio::sync::watch; use crate::platform_property_manager::PlatformPropertyManager; +/// ActionListener interface is responsible for interfacing with clients +/// that are interested in the state of an action. +pub trait ActionListener: Sync + Send + Unpin { + /// Returns the client operation id. + fn client_operation_id(&self) -> &ClientOperationId; + + /// Waits for the action state to change. + fn changed( + &mut self, + ) -> Pin, Error>> + Send + '_>>; +} + /// ActionScheduler interface is responsible for interactions between the scheduler /// and action related operations. #[async_trait] @@ -35,17 +48,15 @@ pub trait ActionScheduler: Sync + Send + Unpin { /// Adds an action to the scheduler for remote execution. async fn add_action( &self, + client_operation_id: ClientOperationId, action_info: ActionInfo, - ) -> Result>, Error>; + ) -> Result>, Error>; /// Find an existing action by its name. - async fn find_existing_action( + async fn find_by_client_operation_id( &self, - unique_qualifier: &ActionInfoHashKey, - ) -> Option>>; - - /// Cleans up the cache of recently completed actions. - async fn clean_recently_completed_actions(&self); + client_operation_id: &ClientOperationId, + ) -> Result>>, Error>; /// Register the metrics for the action scheduler. fn register_metrics(self: Arc, _registry: &mut Registry) {} diff --git a/nativelink-scheduler/src/api_worker_scheduler.rs b/nativelink-scheduler/src/api_worker_scheduler.rs new file mode 100644 index 000000000..9eda38012 --- /dev/null +++ b/nativelink-scheduler/src/api_worker_scheduler.rs @@ -0,0 +1,477 @@ +// Copyright 2024 The NativeLink Authors. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
diff --git a/nativelink-scheduler/src/api_worker_scheduler.rs b/nativelink-scheduler/src/api_worker_scheduler.rs
new file mode 100644
index 000000000..9eda38012
--- /dev/null
+++ b/nativelink-scheduler/src/api_worker_scheduler.rs
@@ -0,0 +1,477 @@
+// Copyright 2024 The NativeLink Authors. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+use std::collections::HashMap;
+use std::sync::Arc;
+
+use async_lock::Mutex;
+use lru::LruCache;
+use nativelink_config::schedulers::WorkerAllocationStrategy;
+use nativelink_error::{error_if, make_err, make_input_err, Code, Error, ResultExt};
+use nativelink_util::action_messages::{ActionInfo, ActionStage, OperationId, WorkerId};
+use nativelink_util::metrics_utils::{Collector, CollectorState, MetricsComponent, Registry};
+use nativelink_util::operation_state_manager::WorkerStateManager;
+use nativelink_util::platform_properties::{PlatformProperties, PlatformPropertyValue};
+use tokio::sync::Notify;
+use tonic::async_trait;
+use tracing::{event, Level};
+
+use crate::platform_property_manager::PlatformPropertyManager;
+use crate::worker::{Worker, WorkerTimestamp, WorkerUpdate};
+use crate::worker_scheduler::WorkerScheduler;
+
+/// A collection of workers that are available to run tasks.
+struct ApiWorkerSchedulerImpl {
+    /// A `LruCache` of workers available based on `allocation_strategy`.
+    workers: LruCache<WorkerId, Worker>,
+
+    /// The worker state manager.
+    worker_state_manager: Arc<dyn WorkerStateManager>,
+    /// The allocation strategy for workers.
+    allocation_strategy: WorkerAllocationStrategy,
+    /// A channel to notify the matching engine that the worker pool has changed.
+    worker_change_notify: Arc<Notify>,
+}
+
+impl ApiWorkerSchedulerImpl {
+    /// Refreshes the lifetime of the worker with the given timestamp.
+    fn refresh_lifetime(
+        &mut self,
+        worker_id: &WorkerId,
+        timestamp: WorkerTimestamp,
+    ) -> Result<(), Error> {
+        let worker = self.workers.peek_mut(worker_id).ok_or_else(|| {
+            make_input_err!(
+                "Worker not found in worker map in refresh_lifetime() {}",
+                worker_id
+            )
+        })?;
+        error_if!(
+            worker.last_update_timestamp > timestamp,
+            "Worker already had a timestamp of {}, but tried to update it with {}",
+            worker.last_update_timestamp,
+            timestamp
+        );
+        worker.last_update_timestamp = timestamp;
+        Ok(())
+    }
+
+    /// Adds a worker to the pool.
+    /// Note: This function will not do any task matching.
+    fn add_worker(&mut self, worker: Worker) -> Result<(), Error> {
+        let worker_id = worker.id;
+        self.workers.put(worker_id, worker);
+
+        // Worker is not cloneable, and we do not want to send the initial connection results until
+        // we have added it to the map, or we might get some strange race conditions due to the way
+        // the multi-threaded runtime works.
+        let worker = self.workers.peek_mut(&worker_id).unwrap();
+        let res = worker
+            .send_initial_connection_result()
+            .err_tip(|| "Failed to send initial connection result to worker");
+        if let Err(err) = &res {
+            event!(
+                Level::ERROR,
+                ?worker_id,
+                ?err,
+                "Worker connection appears to have been closed while adding to pool"
+            );
+        }
+        self.worker_change_notify.notify_one();
+        res
+    }
+
+    /// Removes worker from pool.
+    /// Note: The caller is responsible for any rescheduling of any tasks that might be
+    /// running.
+    fn remove_worker(&mut self, worker_id: &WorkerId) -> Option<Worker> {
+        let result = self.workers.pop(worker_id);
+        self.worker_change_notify.notify_one();
+        result
+    }
+
+    /// Sets if the worker is draining or not.
+    async fn set_drain_worker(
+        &mut self,
+        worker_id: &WorkerId,
+        is_draining: bool,
+    ) -> Result<(), Error> {
+        let worker = self
+            .workers
+            .get_mut(worker_id)
+            .err_tip(|| format!("Worker {worker_id} doesn't exist in the pool"))?;
+        worker.is_draining = is_draining;
+        self.worker_change_notify.notify_one();
+        Ok(())
+    }
+
+    fn inner_find_worker_for_action(
+        &self,
+        platform_properties: &PlatformProperties,
+    ) -> Option<WorkerId> {
+        let mut workers_iter = self.workers.iter();
+        let workers_iter = match self.allocation_strategy {
+            // Use rfind to get the least recently used that satisfies the properties.
+            WorkerAllocationStrategy::least_recently_used => workers_iter.rfind(|(_, w)| {
+                w.can_accept_work() && platform_properties.is_satisfied_by(&w.platform_properties)
+            }),
+            // Use find to get the most recently used that satisfies the properties.
+            WorkerAllocationStrategy::most_recently_used => workers_iter.find(|(_, w)| {
+                w.can_accept_work() && platform_properties.is_satisfied_by(&w.platform_properties)
+            }),
+        };
+        workers_iter.map(|(_, w)| &w.id).copied()
+    }
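The allocation strategies above lean on the iteration order of the `lru` crate, which walks entries from most recently used to least recently used. A standalone sketch (worker names are illustrative):

use lru::LruCache;

fn main() {
    let mut workers = LruCache::<&str, ()>::unbounded();
    workers.put("worker-a", ());
    workers.put("worker-b", ());
    workers.put("worker-c", ());

    // `iter()` yields the most recently used entry first, so `find` implements
    // most_recently_used and `rfind` implements least_recently_used.
    let mru = workers.iter().find(|_| true).map(|(id, _)| *id);
    let lru = workers.iter().rfind(|_| true).map(|(id, _)| *id);
    assert_eq!(mru, Some("worker-c"));
    assert_eq!(lru, Some("worker-a"));
}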
+    async fn update_action(
+        &mut self,
+        worker_id: &WorkerId,
+        operation_id: &OperationId,
+        action_stage: Result<ActionStage, Error>,
+    ) -> Result<(), Error> {
+        let worker = self.workers.get_mut(worker_id).err_tip(|| {
+            format!("Worker {worker_id} does not exist in SimpleScheduler::update_action")
+        })?;
+
+        // Ensure the worker is supposed to be running the operation.
+        if !worker.running_action_infos.contains_key(operation_id) {
+            let err = make_err!(
+                Code::Internal,
+                "Operation {operation_id} should not be running on worker {worker_id} in SimpleScheduler::update_action"
+            );
+            return Result::<(), _>::Err(err.clone())
+                .merge(self.immediate_evict_worker(worker_id, err).await);
+        }
+
+        // Update the operation in the worker state manager.
+        {
+            let update_operation_res = self
+                .worker_state_manager
+                .update_operation(operation_id, worker_id, action_stage.clone())
+                .await
+                .err_tip(|| "in update_operation on SimpleScheduler::update_action");
+            if let Err(err) = update_operation_res {
+                event!(
+                    Level::ERROR,
+                    ?operation_id,
+                    ?worker_id,
+                    ?err,
+                    "Failed to update_operation on update_action"
+                );
+                return Err(err);
+            }
+        }
+
+        // We are done if the action is not finished or there was an error.
+        let is_finished = action_stage
+            .as_ref()
+            .map_or_else(|_| true, |action_stage| action_stage.is_finished());
+        if !is_finished {
+            return Ok(());
+        }
+
+        // Clear this action from the current worker if finished.
+        let complete_action_res = {
+            let was_paused = !worker.can_accept_work();
+
+            // Note: We need to run this before dealing with backpressure logic.
+            let complete_action_res = worker.complete_action(operation_id);
+
+            let due_to_backpressure = action_stage
+                .as_ref()
+                .map_or_else(|e| e.code == Code::ResourceExhausted, |_| false);
+            // Only pause if there's an action still waiting that will unpause.
+            if (was_paused || due_to_backpressure) && worker.has_actions() {
+                worker.is_paused = true;
+            }
+            complete_action_res
+        };
+
+        self.worker_change_notify.notify_one();
+
+        complete_action_res
+    }
+
+    /// Notifies the specified worker to run the given action and handles errors by evicting
+    /// the worker if the notification fails.
+    async fn worker_notify_run_action(
+        &mut self,
+        worker_id: WorkerId,
+        operation_id: OperationId,
+        action_info: Arc<ActionInfo>,
+    ) -> Result<(), Error> {
+        if let Some(worker) = self.workers.get_mut(&worker_id) {
+            let notify_worker_result =
+                worker.notify_update(WorkerUpdate::RunAction((operation_id, action_info.clone())));
+
+            if notify_worker_result.is_err() {
+                event!(
+                    Level::WARN,
+                    ?worker_id,
+                    ?action_info,
+                    ?notify_worker_result,
+                    "Worker command failed, removing worker",
+                );
+
+                let err = make_err!(
+                    Code::Internal,
+                    "Worker command failed, removing worker {worker_id} -- {notify_worker_result:?}",
+                );
+
+                return Result::<(), _>::Err(err.clone())
+                    .merge(self.immediate_evict_worker(&worker_id, err).await);
+            }
+        } else {
+            event!(
+                Level::WARN,
+                ?worker_id,
+                ?operation_id,
+                ?action_info,
+                "Worker not found in worker map in worker_notify_run_action"
+            );
+        }
+        Ok(())
+    }
+
+    /// Evicts the worker from the pool and puts items back into the queue if anything was being executed on it.
+    async fn immediate_evict_worker(
+        &mut self,
+        worker_id: &WorkerId,
+        err: Error,
+    ) -> Result<(), Error> {
+        let mut result = Ok(());
+        if let Some(mut worker) = self.remove_worker(worker_id) {
+            // We don't care if we fail to send message to worker, this is only a best attempt.
+            let _ = worker.notify_update(WorkerUpdate::Disconnect);
+            for (operation_id, _) in worker.running_action_infos.drain() {
+                result = result.merge(
+                    self.worker_state_manager
+                        .update_operation(&operation_id, worker_id, Err(err.clone()))
+                        .await,
+                );
+            }
+        }
+        // Note: Calling this many times is very cheap, it'll only trigger `do_try_match` once.
+        // TODO(allada) This should be moved to inside the Workers struct.
+        self.worker_change_notify.notify_one();
+        result
+    }
+}
+
+pub struct ApiWorkerScheduler {
+    inner: Mutex<ApiWorkerSchedulerImpl>,
+    platform_property_manager: Arc<PlatformPropertyManager>,
+
+    /// Timeout in seconds; a worker that has not responded within this window is evicted.
+    worker_timeout_s: u64,
+}
+
+impl ApiWorkerScheduler {
+    pub fn new(
+        worker_state_manager: Arc<dyn WorkerStateManager>,
+        platform_property_manager: Arc<PlatformPropertyManager>,
+        allocation_strategy: WorkerAllocationStrategy,
+        worker_change_notify: Arc<Notify>,
+        worker_timeout_s: u64,
+    ) -> Arc<Self> {
+        Arc::new(Self {
+            inner: Mutex::new(ApiWorkerSchedulerImpl {
+                workers: LruCache::unbounded(),
+                worker_state_manager,
+                allocation_strategy,
+                worker_change_notify,
+            }),
+            platform_property_manager,
+            worker_timeout_s,
+        })
+    }
+
+    pub async fn worker_notify_run_action(
+        &self,
+        worker_id: WorkerId,
+        operation_id: OperationId,
+        action_info: Arc<ActionInfo>,
+    ) -> Result<(), Error> {
+        let mut inner = self.inner.lock().await;
+        inner
+            .worker_notify_run_action(worker_id, operation_id, action_info)
+            .await
+    }
+
+    /// Attempts to find a worker that is capable of running this action.
+    // TODO(blaise.bruer) This algorithm is not very efficient. Simple testing using a tree-like
+    // structure showed worse performance on a 10_000 worker * 7 properties * 1000 queued tasks
+    // simulation of worst cases in a single threaded environment.
+    pub async fn find_worker_for_action(
+        &self,
+        platform_properties: &PlatformProperties,
+    ) -> Option<WorkerId> {
+        let inner = self.inner.lock().await;
+        inner.inner_find_worker_for_action(platform_properties)
+    }
+
+    /// Checks to see if the worker exists in the worker pool. Should only be used in unit tests.
+    #[must_use]
+    pub async fn contains_worker_for_test(&self, worker_id: &WorkerId) -> bool {
+        let inner = self.inner.lock().await;
+        inner.workers.contains(worker_id)
+    }
+
+    /// A unit test function used to send the keep alive message to the worker from the server.
+    pub async fn send_keep_alive_to_worker_for_test(
+        &self,
+        worker_id: &WorkerId,
+    ) -> Result<(), Error> {
+        let mut inner = self.inner.lock().await;
+        let worker = inner.workers.get_mut(worker_id).ok_or_else(|| {
+            make_input_err!("WorkerId '{}' does not exist in workers map", worker_id)
+        })?;
+        worker.keep_alive()
+    }
+}
+
+#[async_trait]
+impl WorkerScheduler for ApiWorkerScheduler {
+    fn get_platform_property_manager(&self) -> &PlatformPropertyManager {
+        self.platform_property_manager.as_ref()
+    }
+
+    async fn add_worker(&self, worker: Worker) -> Result<(), Error> {
+        let mut inner = self.inner.lock().await;
+        let worker_id = worker.id;
+        let result = inner
+            .add_worker(worker)
+            .err_tip(|| "Error while adding worker, removing from pool");
+        if let Err(err) = result {
+            return Result::<(), _>::Err(err.clone())
+                .merge(inner.immediate_evict_worker(&worker_id, err).await);
+        }
+        Ok(())
+    }
+
+    async fn update_action(
+        &self,
+        worker_id: &WorkerId,
+        operation_id: &OperationId,
+        action_stage: Result<ActionStage, Error>,
+    ) -> Result<(), Error> {
+        let mut inner = self.inner.lock().await;
+        inner
+            .update_action(worker_id, operation_id, action_stage)
+            .await
+    }
+
+    async fn worker_keep_alive_received(
+        &self,
+        worker_id: &WorkerId,
+        timestamp: WorkerTimestamp,
+    ) -> Result<(), Error> {
+        let mut inner = self.inner.lock().await;
+        inner
+            .refresh_lifetime(worker_id, timestamp)
+            .err_tip(|| "Error refreshing lifetime in worker_keep_alive_received()")
+    }
+
+    async fn remove_worker(&self, worker_id: &WorkerId) -> Result<(), Error> {
+        let mut inner = self.inner.lock().await;
+        inner
+            .immediate_evict_worker(
+                worker_id,
+                make_err!(Code::Internal, "Received request to remove worker"),
+            )
+            .await
+    }
+
+    async fn remove_timedout_workers(&self, now_timestamp: WorkerTimestamp) -> Result<(), Error> {
+        let mut inner = self.inner.lock().await;
+
+        let mut result = Ok(());
+        // Items should be sorted based on last_update_timestamp, so we don't need to iterate the entire
+        // map most of the time.
+        let worker_ids_to_remove: Vec<WorkerId> = inner
+            .workers
+            .iter()
+            .rev()
+            .map_while(|(worker_id, worker)| {
+                if worker.last_update_timestamp <= now_timestamp - self.worker_timeout_s {
+                    Some(*worker_id)
+                } else {
+                    None
+                }
+            })
+            .collect();
+        for worker_id in &worker_ids_to_remove {
+            event!(
+                Level::WARN,
+                ?worker_id,
+                "Worker timed out, removing from pool"
+            );
+            result = result.merge(
+                inner
+                    .immediate_evict_worker(
+                        worker_id,
+                        make_err!(
+                            Code::Internal,
+                            "Worker {worker_id} timed out, removing from pool"
+                        ),
+                    )
+                    .await,
+            );
+        }
+
+        result
+    }
+
+    async fn set_drain_worker(&self, worker_id: &WorkerId, is_draining: bool) -> Result<(), Error> {
+        let mut inner = self.inner.lock().await;
+        inner.set_drain_worker(worker_id, is_draining).await
+    }
+
+    fn register_metrics(self: Arc<Self>, registry: &mut Registry) {
+        self.inner
+            .lock_blocking()
+            .worker_state_manager
+            .clone()
+            .register_metrics(registry);
+        registry.register_collector(Box::new(Collector::new(&self)));
+    }
+}
+
+impl MetricsComponent for ApiWorkerScheduler {
+    fn gather_metrics(&self, c: &mut CollectorState) {
+        let inner = self.inner.lock_blocking();
+        let mut props = HashMap::<&String, u64>::new();
+        for (_worker_id, worker) in inner.workers.iter() {
+            c.publish_with_labels(
+                "workers",
+                worker,
+                "",
+                vec![("worker_id".into(), worker.id.to_string().into())],
+            );
+            for (property, prop_value) in &worker.platform_properties.properties {
+                let current_value = props.get(&property).unwrap_or(&0);
+                if let PlatformPropertyValue::Minimum(worker_value) = prop_value {
+                    props.insert(property, *current_value + *worker_value);
+                }
+            }
+        }
+        for (property, prop_value) in props {
+            c.publish(
+                &format!("{property}_available_properties"),
+                &prop_value,
+                format!("Total sum of available properties for {property}"),
+            );
+        }
+    }
+}
diff --git a/nativelink-scheduler/src/awaited_action_db/awaited_action.rs b/nativelink-scheduler/src/awaited_action_db/awaited_action.rs
new file mode 100644
index 000000000..eff7b3e01
--- /dev/null
+++ b/nativelink-scheduler/src/awaited_action_db/awaited_action.rs
@@ -0,0 +1,191 @@
+// Copyright 2024 The NativeLink Authors. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+use std::sync::Arc;
+use std::time::{SystemTime, UNIX_EPOCH};
+
+use nativelink_util::action_messages::{
+    ActionInfo, ActionStage, ActionState, OperationId, WorkerId,
+};
+use static_assertions::{assert_eq_size, const_assert, const_assert_eq};
+
+/// The version of the awaited action.
+/// This number will always increment by one each time
+/// the action is updated.
+#[derive(Debug, Clone, Copy)]
+struct AwaitedActionVersion(u64);
+
+/// An action that is being awaited on and last known state.
+#[derive(Debug, Clone)]
+pub struct AwaitedAction {
+    /// The current version of the action.
+    version: AwaitedActionVersion,
+
+    /// The action that is being awaited on.
+    action_info: Arc<ActionInfo>,
+
+    /// The operation id of the action.
+    operation_id: OperationId,
+
+    /// The current sort key used to order the actions.
+    sort_key: AwaitedActionSortKey,
+
+    /// The time the action was last updated.
+    last_worker_updated_timestamp: SystemTime,
+
+    /// Worker that is currently running this action, None if unassigned.
+    worker_id: Option<WorkerId>,
+
+    /// The current state of the action.
+    state: Arc<ActionState>,
+
+    /// Number of attempts the job has been tried.
+    pub attempts: usize,
+}
+
+impl AwaitedAction {
+    pub fn new(operation_id: OperationId, action_info: Arc<ActionInfo>) -> Self {
+        let stage = ActionStage::Queued;
+        let sort_key = AwaitedActionSortKey::new_with_unique_key(
+            action_info.priority,
+            &action_info.insert_timestamp,
+        );
+        let state = Arc::new(ActionState {
+            stage,
+            id: operation_id.clone(),
+        });
+        Self {
+            version: AwaitedActionVersion(0),
+            action_info,
+            operation_id,
+            sort_key,
+            attempts: 0,
+            last_worker_updated_timestamp: SystemTime::now(),
+            worker_id: None,
+            state,
+        }
+    }
+
+    pub fn version(&self) -> u64 {
+        self.version.0
+    }
+
+    pub fn increment_version(&mut self) {
+        self.version = AwaitedActionVersion(self.version.0 + 1);
+    }
+
+    pub fn action_info(&self) -> &Arc<ActionInfo> {
+        &self.action_info
+    }
+
+    pub fn operation_id(&self) -> &OperationId {
+        &self.operation_id
+    }
+
+    pub fn sort_key(&self) -> AwaitedActionSortKey {
+        self.sort_key
+    }
+
+    pub fn state(&self) -> &Arc<ActionState> {
+        &self.state
+    }
+
+    pub fn worker_id(&self) -> Option<WorkerId> {
+        self.worker_id
+    }
+
+    pub fn last_worker_updated_timestamp(&self) -> SystemTime {
+        self.last_worker_updated_timestamp
+    }
+
+    /// Sets the worker id that is currently processing this action.
+    pub fn set_worker_id(&mut self, new_maybe_worker_id: Option<WorkerId>) {
+        if self.worker_id != new_maybe_worker_id {
+            self.worker_id = new_maybe_worker_id;
+            self.last_worker_updated_timestamp = SystemTime::now();
+        }
+    }
+
+    /// Sets the current state of the action and updates the last
+    /// worker updated timestamp.
+    pub fn set_state(&mut self, mut state: Arc<ActionState>) {
+        std::mem::swap(&mut self.state, &mut state);
+        self.last_worker_updated_timestamp = SystemTime::now();
+    }
+}
+
+/// The key used to sort the awaited actions.
+///
+/// The rules for sorting are as follows:
+/// 1. priority of the action
+/// 2. insert order of the action (lower = higher priority)
+/// 3. (mostly random hash based on the action info)
+#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
+#[repr(transparent)]
+pub struct AwaitedActionSortKey(u64);
+
+impl AwaitedActionSortKey {
+    #[rustfmt::skip]
+    const fn new(priority: i32, insert_timestamp: u32) -> Self {
+        // Shift `new_priority` so [`i32::MIN`] is represented by zero.
+        // This makes it so any negative values are positive, but
+        // maintains ordering.
+        const MIN_I32: i64 = (i32::MIN as i64).abs();
+        let priority = ((priority as i64 + MIN_I32) as u32).to_be_bytes();
+
+        // Invert our timestamp so the larger the timestamp the lower the number.
+        // This makes timestamp descending order instead of ascending.
+        let timestamp = (insert_timestamp ^ u32::MAX).to_be_bytes();
+
+        AwaitedActionSortKey(u64::from_be_bytes([
+            priority[0], priority[1], priority[2], priority[3],
+            timestamp[0], timestamp[1], timestamp[2], timestamp[3],
+        ]))
+    }
+
+    fn new_with_unique_key(priority: i32, insert_timestamp: &SystemTime) -> Self {
+        let timestamp = insert_timestamp
+            .duration_since(UNIX_EPOCH)
+            .unwrap()
+            .as_secs() as u32;
+        Self::new(priority, timestamp)
+    }
+}
+
+// Ensure the size of the sort key is the same as a `u64`.
+assert_eq_size!(AwaitedActionSortKey, u64);
+
+const_assert_eq!(
+    AwaitedActionSortKey::new(0x1234_5678, 0x9abc_def0).0,
+    // Note: Result has 0x12345678 + 0x80000000 = 0x92345678 because we need
+    // to shift the `i32::MIN` value to be represented by zero.
+    // Note: `6543210f` are the inverted bits of `9abcdef0`.
+    // This effectively inverts the timestamp so the largest timestamp now
+    // sorts as the lowest value.
+    AwaitedActionSortKey(0x9234_5678_6543_210f).0
+);
+// Ensure the priority is used as the sort key first.
+const_assert!(
+    AwaitedActionSortKey::new(i32::MAX, 0).0 > AwaitedActionSortKey::new(i32::MAX - 1, 0).0
+);
+const_assert!(AwaitedActionSortKey::new(i32::MAX - 1, 0).0 > AwaitedActionSortKey::new(1, 0).0);
+const_assert!(AwaitedActionSortKey::new(1, 0).0 > AwaitedActionSortKey::new(0, 0).0);
+const_assert!(AwaitedActionSortKey::new(0, 0).0 > AwaitedActionSortKey::new(-1, 0).0);
+const_assert!(AwaitedActionSortKey::new(-1, 0).0 > AwaitedActionSortKey::new(i32::MIN + 1, 0).0);
+const_assert!(
+    AwaitedActionSortKey::new(i32::MIN + 1, 0).0 > AwaitedActionSortKey::new(i32::MIN, 0).0
+);
+
+// Ensure the insert timestamp is used as the sort key second.
+const_assert!(AwaitedActionSortKey::new(0, u32::MIN).0 > AwaitedActionSortKey::new(0, u32::MAX).0);
diff --git a/nativelink-scheduler/src/awaited_action_db/mod.rs b/nativelink-scheduler/src/awaited_action_db/mod.rs
new file mode 100644
index 000000000..1d3cc623d
--- /dev/null
+++ b/nativelink-scheduler/src/awaited_action_db/mod.rs
@@ -0,0 +1,121 @@
+// Copyright 2024 The NativeLink Authors. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+use std::cmp;
+use std::ops::Bound;
+use std::sync::Arc;
+
+pub use awaited_action::{AwaitedAction, AwaitedActionSortKey};
+use futures::{Future, Stream};
+use nativelink_error::Error;
+use nativelink_util::action_messages::{ActionInfo, ClientOperationId, OperationId};
+use nativelink_util::metrics_utils::MetricsComponent;
+
+mod awaited_action;
+
+/// A simple enum to represent the state of an AwaitedAction.
+#[derive(Debug, Clone, Copy)]
+pub enum SortedAwaitedActionState {
+    CacheCheck,
+    Queued,
+    Executing,
+    Completed,
+}
+
+/// A struct pointing to an AwaitedAction that can be sorted.
+#[derive(Debug, Clone)]
+pub struct SortedAwaitedAction {
+    pub sort_key: AwaitedActionSortKey,
+    pub operation_id: OperationId,
+}
+
+impl PartialEq for SortedAwaitedAction {
+    fn eq(&self, other: &Self) -> bool {
+        self.sort_key == other.sort_key && self.operation_id == other.operation_id
+    }
+}
+
+impl Eq for SortedAwaitedAction {}
+
+impl PartialOrd for SortedAwaitedAction {
+    fn partial_cmp(&self, other: &Self) -> Option<cmp::Ordering> {
+        Some(self.cmp(other))
+    }
+}
+
+impl Ord for SortedAwaitedAction {
+    fn cmp(&self, other: &Self) -> cmp::Ordering {
+        self.sort_key
+            .cmp(&other.sort_key)
+            .then_with(|| self.operation_id.cmp(&other.operation_id))
+    }
+}
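This `Ord` implementation (sort key first, operation id as tie breaker) exists so the actions can live in an ordered container and serve cursor-style range scans, like `get_range_of_actions` below. A standalone sketch with plain std types; the key values and ids are made up:

use std::collections::BTreeSet;
use std::ops::Bound;

fn main() {
    // Stand-in for (AwaitedActionSortKey, OperationId): sort_key orders the
    // entries, operation_id keeps distinct operations from colliding.
    let mut queued: BTreeSet<(u64, &str)> = BTreeSet::new();
    queued.insert((0x9234_5678_6543_210f, "op-1"));
    queued.insert((0x9234_5678_6543_210f, "op-2"));
    queued.insert((0xffff_ffff_0000_0000, "op-3"));

    // Resume a scan strictly after a cursor, like a paged range query.
    let cursor = (0x9234_5678_6543_210f, "op-1");
    let rest: Vec<_> = queued
        .range((Bound::Excluded(cursor), Bound::Unbounded))
        .collect();
    assert_eq!(
        rest,
        [&(0x9234_5678_6543_210f, "op-2"), &(0xffff_ffff_0000_0000, "op-3")]
    );
}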
+/// Subscriber that can be used to monitor when AwaitedActions change.
+pub trait AwaitedActionSubscriber: Send + Sync + Sized + 'static {
+    /// Wait for AwaitedAction to change.
+    fn changed(&mut self) -> impl Future<Output = Result<AwaitedAction, Error>> + Send;
+
+    /// Get the current awaited action.
+    fn borrow(&self) -> AwaitedAction;
+}
+
+/// A trait that defines the interface for an AwaitedActionDb.
+pub trait AwaitedActionDb: Send + Sync + MetricsComponent + 'static {
+    type Subscriber: AwaitedActionSubscriber;
+
+    /// Get the AwaitedAction by the client operation id.
+    fn get_awaited_action_by_id(
+        &self,
+        client_operation_id: &ClientOperationId,
+    ) -> impl Future<Output = Result<Option<Self::Subscriber>, Error>> + Send + Sync;
+
+    /// Get all AwaitedActions. This call should be avoided as much as possible.
+    fn get_all_awaited_actions(
+        &self,
+    ) -> impl Future<Output = impl Stream<Item = Result<Self::Subscriber, Error>> + Send + Sync>
+           + Send
+           + Sync;
+
+    /// Get the AwaitedAction by the operation id.
+    fn get_by_operation_id(
+        &self,
+        operation_id: &OperationId,
+    ) -> impl Future<Output = Result<Option<Self::Subscriber>, Error>> + Send + Sync;
+
+    /// Get a range of AwaitedActions of a specific state in sorted order.
+    fn get_range_of_actions(
+        &self,
+        state: SortedAwaitedActionState,
+        start: Bound<SortedAwaitedAction>,
+        end: Bound<SortedAwaitedAction>,
+        desc: bool,
+    ) -> impl Future<Output = impl Stream<Item = Result<Self::Subscriber, Error>> + Send + Sync>
+           + Send
+           + Sync;
+
+    /// Process a changed AwaitedAction and notify any listeners.
+    fn update_awaited_action(
+        &self,
+        new_awaited_action: AwaitedAction,
+    ) -> impl Future<Output = Result<(), Error>> + Send + Sync;
+
+    /// Add (or join) an action to the AwaitedActionDb and subscribe
+    /// to changes.
+    fn add_action(
+        &self,
+        client_operation_id: ClientOperationId,
+        action_info: Arc<ActionInfo>,
+    ) -> impl Future<Output = Result<Self::Subscriber, Error>> + Send + Sync;
+}
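A minimal sketch of a subscriber backed by tokio's watch channel. This is illustrative only, not the `memory_awaited_action_db` implementation added elsewhere in this change; the struct name and wiring are invented:

use nativelink_error::{make_err, Code, Error};
use nativelink_scheduler::awaited_action_db::{AwaitedAction, AwaitedActionSubscriber};
use tokio::sync::watch;

// Hypothetical subscriber: each AwaitedAction update is published on a watch
// channel and cloned out to callers on demand.
struct WatchSubscriber {
    rx: watch::Receiver<AwaitedAction>,
}

impl AwaitedActionSubscriber for WatchSubscriber {
    async fn changed(&mut self) -> Result<AwaitedAction, Error> {
        self.rx
            .changed()
            .await
            .map_err(|e| make_err!(Code::Internal, "AwaitedAction sender dropped - {e:?}"))?;
        Ok(self.rx.borrow_and_update().clone())
    }

    fn borrow(&self) -> AwaitedAction {
        self.rx.borrow().clone()
    }
}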
diff --git a/nativelink-scheduler/src/cache_lookup_scheduler.rs b/nativelink-scheduler/src/cache_lookup_scheduler.rs
index 18c7a7f4e..2c402f589 100644
--- a/nativelink-scheduler/src/cache_lookup_scheduler.rs
+++ b/nativelink-scheduler/src/cache_lookup_scheduler.rs
@@ -13,38 +13,44 @@
 // limitations under the License.
 
 use std::collections::HashMap;
+use std::pin::Pin;
 use std::sync::Arc;
 
 use async_trait::async_trait;
-use futures::stream::StreamExt;
-use nativelink_error::Error;
+use futures::Future;
+use nativelink_error::{make_err, Code, Error, ResultExt};
 use nativelink_proto::build::bazel::remote::execution::v2::{
     ActionResult as ProtoActionResult, GetActionResultRequest,
 };
 use nativelink_store::ac_utils::get_and_decode_digest;
 use nativelink_store::grpc_store::GrpcStore;
 use nativelink_util::action_messages::{
-    ActionInfo, ActionInfoHashKey, ActionResult, ActionStage, ActionState, OperationId,
+    ActionInfo, ActionStage, ActionState, ActionUniqueKey, ActionUniqueQualifier,
+    ClientOperationId, OperationId,
 };
 use nativelink_util::background_spawn;
 use nativelink_util::common::DigestInfo;
 use nativelink_util::digest_hasher::DigestHasherFunc;
-use nativelink_util::store_trait::{Store, StoreLike};
+use nativelink_util::store_trait::Store;
 use parking_lot::{Mutex, MutexGuard};
 use scopeguard::guard;
-use tokio::select;
-use tokio::sync::watch;
-use tokio_stream::wrappers::WatchStream;
+use tokio::sync::oneshot;
 use tonic::Request;
 use tracing::{event, Level};
 
-use crate::action_scheduler::ActionScheduler;
+use crate::action_scheduler::{ActionListener, ActionScheduler};
 use crate::platform_property_manager::PlatformPropertyManager;
 
 /// Actions that are having their cache checked or failed cache lookup and are
 /// being forwarded upstream. Missing the skip_cache_check actions which are
 /// forwarded directly.
-type CheckActions = HashMap<ActionInfoHashKey, Arc<watch::Sender<Arc<ActionState>>>>;
+type CheckActions = HashMap<
+    ActionUniqueKey,
+    Vec<(
+        ClientOperationId,
+        oneshot::Sender<Result<Pin<Box<dyn ActionListener>>, Error>>,
+    )>,
+>;
 
 pub struct CacheLookupScheduler {
     /// A reference to the AC to find existing actions in.
@@ -54,7 +60,7 @@ pub struct CacheLookupScheduler {
     /// in the action cache.
     action_scheduler: Arc<dyn ActionScheduler>,
     /// Actions that are currently performing a CacheCheck.
-    cache_check_actions: Arc<Mutex<CheckActions>>,
+    inflight_cache_checks: Arc<Mutex<CheckActions>>,
 }
 
 async fn get_action_from_store(
@@ -62,7 +68,7 @@ async fn get_action_from_store(
     action_digest: DigestInfo,
     instance_name: String,
     digest_function: DigestHasherFunc,
-) -> Option<ProtoActionResult> {
+) -> Result<ProtoActionResult, Error> {
     // If we are a GrpcStore we shortcut here, as this is a special store.
     if let Some(grpc_store) = ac_store.downcast_ref::<GrpcStore>(Some(action_digest.into())) {
         let action_result_request = GetActionResultRequest {
@@ -77,27 +83,42 @@ async fn get_action_from_store(
             .get_action_result(Request::new(action_result_request))
             .await
             .map(|response| response.into_inner())
-            .ok()
     } else {
-        get_and_decode_digest::<ProtoActionResult>(ac_store, action_digest.into())
-            .await
-            .ok()
+        get_and_decode_digest::<ProtoActionResult>(ac_store, action_digest.into()).await
     }
 }
 
+/// Future for when ActionListeners are known.
+type ActionListenerOneshot = oneshot::Receiver<Result<Pin<Box<dyn ActionListener>>, Error>>;
+
 fn subscribe_to_existing_action(
-    cache_check_actions: &MutexGuard<CheckActions>,
-    unique_qualifier: &ActionInfoHashKey,
-) -> Option<watch::Receiver<Arc<ActionState>>> {
-    cache_check_actions.get(unique_qualifier).map(|tx| {
-        let current_value = tx.borrow();
-        // Subscribe marks the current value as seen, so we have to
-        // re-send it to all receivers.
-        // TODO: Fix this when fixed upstream tokio-rs/tokio#5871
-        let rx = tx.subscribe();
-        let _ = tx.send(current_value.clone());
-        rx
-    })
+    inflight_cache_checks: &mut MutexGuard<CheckActions>,
+    unique_qualifier: &ActionUniqueKey,
+    client_operation_id: &ClientOperationId,
+) -> Option<ActionListenerOneshot> {
+    inflight_cache_checks
+        .get_mut(unique_qualifier)
+        .map(|oneshots| {
+            let (tx, rx) = oneshot::channel();
+            oneshots.push((client_operation_id.clone(), tx));
+            rx
+        })
+}
+
+struct CachedActionListener {
+    client_operation_id: ClientOperationId,
+    action_state: Arc<ActionState>,
+}
+
+impl ActionListener for CachedActionListener {
+    fn client_operation_id(&self) -> &ClientOperationId {
+        &self.client_operation_id
+    }
+
+    fn changed(
+        &mut self,
+    ) -> Pin<Box<dyn Future<Output = Result<Arc<ActionState>, Error>> + Send + '_>> {
+        Box::pin(async { Ok(self.action_state.clone()) })
+    }
 }
 
 impl CacheLookupScheduler {
@@ -105,7 +126,7 @@ impl CacheLookupScheduler {
         Ok(Self {
             ac_store,
             action_scheduler,
-            cache_check_actions: Default::default(),
+            inflight_cache_checks: Default::default(),
        })
    }
 }
@@ -123,117 +144,167 @@ impl ActionScheduler for CacheLookupScheduler {
 
     async fn add_action(
         &self,
+        client_operation_id: ClientOperationId,
         action_info: ActionInfo,
-    ) -> Result<watch::Receiver<Arc<ActionState>>, Error> {
-        let id = OperationId::new(action_info.unique_qualifier.clone());
-        if action_info.skip_cache_lookup {
-            // Cache lookup skipped, forward to the upstream.
-            return self.action_scheduler.add_action(action_info).await;
-        }
-        let mut current_state = Arc::new(ActionState {
-            id,
-            stage: ActionStage::CacheCheck,
-        });
-        let (tx, rx) = watch::channel(current_state.clone());
-        let tx = Arc::new(tx);
-        let scope_guard = {
-            let mut cache_check_actions = self.cache_check_actions.lock();
-            // Check this isn't a duplicate request first.
-            if let Some(rx) =
-                subscribe_to_existing_action(&cache_check_actions, &action_info.unique_qualifier)
-            {
-                return Ok(rx);
+    ) -> Result<Pin<Box<dyn ActionListener>>, Error> {
+        let unique_key = match &action_info.unique_qualifier {
+            ActionUniqueQualifier::Cachable(unique_key) => unique_key.clone(),
+            ActionUniqueQualifier::Uncachable(_) => {
+                // Cache lookup skipped, forward to the upstream.
+                return self
+                    .action_scheduler
+                    .add_action(client_operation_id, action_info)
+                    .await;
             }
-            cache_check_actions.insert(action_info.unique_qualifier.clone(), tx.clone());
-            // In the event we loose the reference to our `scope_guard`, it will remove
-            // the action from the cache_check_actions map.
-            let cache_check_actions = self.cache_check_actions.clone();
-            let unique_qualifier = action_info.unique_qualifier.clone();
-            guard((), move |_| {
-                cache_check_actions.lock().remove(&unique_qualifier);
+        };
+
+        let cache_check_result = {
+            // Check this isn't a duplicate request first.
+            let mut inflight_cache_checks = self.inflight_cache_checks.lock();
+            subscribe_to_existing_action(
+                &mut inflight_cache_checks,
+                &unique_key,
+                &client_operation_id,
+            )
+            .ok_or_else(move || {
+                let (action_listener_tx, action_listener_rx) = oneshot::channel();
+                inflight_cache_checks.insert(
+                    unique_key.clone(),
+                    vec![(client_operation_id, action_listener_tx)],
+                );
+                // In the event we lose the reference to our `scope_guard`, it will remove
+                // the action from the inflight_cache_checks map.
+                let inflight_cache_checks = self.inflight_cache_checks.clone();
+                (
+                    action_listener_rx,
+                    guard((), move |_| {
+                        inflight_cache_checks.lock().remove(&unique_key);
+                    }),
+                )
             })
         };
+        let (action_listener_rx, scope_guard) = match cache_check_result {
+            Ok(action_listener_fut) => {
+                let action_listener = action_listener_fut.await.map_err(|_| {
+                    make_err!(
+                        Code::Internal,
+                        "ActionListener tx hung up in CacheLookupScheduler::add_action"
+                    )
+                })?;
+                return action_listener;
+            }
+            Err(client_tx_and_scope_guard) => client_tx_and_scope_guard,
+        };
 
         let ac_store = self.ac_store.clone();
         let action_scheduler = self.action_scheduler.clone();
+        let inflight_cache_checks = self.inflight_cache_checks.clone();
         // We need this spawn because we are returning a stream and this spawn will populate the stream's data.
         background_spawn!("cache_lookup_scheduler_add_action", async move {
-            // If our spawn ever dies, we will remove the action from the cache_check_actions map.
+            // If our spawn ever dies, we will remove the action from the inflight_cache_checks map.
             let _scope_guard = scope_guard;
 
+            let unique_key = match &action_info.unique_qualifier {
+                ActionUniqueQualifier::Cachable(unique_key) => unique_key,
+                ActionUniqueQualifier::Uncachable(unique_key) => {
+                    event!(
+                        Level::ERROR,
+                        ?action_info,
+                        "ActionInfo::unique_qualifier should be ActionUniqueQualifier::Cachable()"
+                    );
+                    unique_key
+                }
+            };
+
             // Perform cache check.
-            let action_digest = current_state.action_digest();
-            let instance_name = action_info.instance_name().clone();
-            if let Some(action_result) = get_action_from_store(
+            let instance_name = action_info.unique_qualifier.instance_name().clone();
+            let maybe_action_result = get_action_from_store(
                 &ac_store,
-                *action_digest,
+                action_info.unique_qualifier.digest(),
                 instance_name,
-                current_state.id.unique_qualifier.digest_function,
+                action_info.unique_qualifier.digest_function(),
             )
-            .await
-            {
-                match ac_store.has(*action_digest).await {
-                    Ok(Some(_)) => {
-                        Arc::make_mut(&mut current_state).stage =
-                            ActionStage::CompletedFromCache(action_result);
-                        let _ = tx.send(current_state);
-                        return;
-                    }
-                    Err(err) => {
-                        event!(
-                            Level::WARN,
-                            ?err,
-                            "Error while calling `has` on `ac_store` in `CacheLookupScheduler`'s `add_action` function"
-                        );
-                    }
-                    _ => {}
-                }
-            }
-            // Not in cache, forward to upstream and proxy state.
-            match action_scheduler.add_action(action_info).await {
-                Ok(rx) => {
-                    let mut watch_stream = WatchStream::new(rx);
-                    loop {
-                        select!(
-                            Some(action_state) = watch_stream.next() => {
-                                if tx.send(action_state).is_err() {
-                                    break;
-                                }
-                            }
-                            _ = tx.closed() => {
-                                break;
-                            }
-                        )
+            .await;
+            match maybe_action_result {
+                Ok(action_result) => {
+                    let maybe_pending_txs = {
+                        let mut inflight_cache_checks = inflight_cache_checks.lock();
+                        // We are ready to resolve the in-flight actions. We remove the
+                        // in-flight actions from the map.
+                        inflight_cache_checks.remove(unique_key)
+                    };
+                    let Some(pending_txs) = maybe_pending_txs else {
+                        return; // Nobody is waiting for this action anymore.
+                    };
+                    let action_state = Arc::new(ActionState {
+                        id: OperationId::new(action_info.unique_qualifier.clone()),
+                        stage: ActionStage::CompletedFromCache(action_result),
+                    });
+                    for (client_operation_id, pending_tx) in pending_txs {
+                        // Ignore errors here, as the other end may have hung up.
+                        let _ = pending_tx.send(Ok(Box::pin(CachedActionListener {
+                            client_operation_id,
+                            action_state: action_state.clone(),
+                        })));
                     }
+                    return;
                 }
                 Err(err) => {
-                    Arc::make_mut(&mut current_state).stage =
-                        ActionStage::Completed(ActionResult {
-                            error: Some(err),
-                            ..Default::default()
-                        });
-                    let _ = tx.send(current_state);
+                    // NotFound errors just mean we need to execute our action.
+                    if err.code != Code::NotFound {
+                        let err = err.append("In CacheLookupScheduler::add_action");
+                        let maybe_pending_txs = {
+                            let mut inflight_cache_checks = inflight_cache_checks.lock();
+                            // We are ready to resolve the in-flight actions. We remove the
+                            // in-flight actions from the map.
+                            inflight_cache_checks.remove(unique_key)
+                        };
+                        let Some(pending_txs) = maybe_pending_txs else {
+                            return; // Nobody is waiting for this action anymore.
+                        };
+                        for (_client_operation_id, pending_tx) in pending_txs {
+                            // Ignore errors here, as the other end may have hung up.
+                            let _ = pending_tx.send(Err(err.clone()));
+                        }
+                        return;
+                    }
                 }
             }
+
+            let maybe_pending_txs = {
+                let mut inflight_cache_checks = inflight_cache_checks.lock();
+                inflight_cache_checks.remove(unique_key)
+            };
+            let Some(pending_txs) = maybe_pending_txs else {
+                return; // No one is waiting for this action anymore.
+            };
+
+            for (client_operation_id, pending_tx) in pending_txs {
+                // Ignore errors here, as the other end may have hung up.
+                let _ = pending_tx.send(
+                    action_scheduler
+                        .add_action(client_operation_id, action_info.clone())
+                        .await,
+                );
+            }
         });
-        Ok(rx)
+        action_listener_rx
+            .await
+            .map_err(|_| {
+                make_err!(
+                    Code::Internal,
+                    "ActionListener tx hung up in CacheLookupScheduler::add_action"
+                )
+            })?
+            .err_tip(|| "In CacheLookupScheduler::add_action")
     }
 
-    async fn find_existing_action(
+    async fn find_by_client_operation_id(
         &self,
-        unique_qualifier: &ActionInfoHashKey,
-    ) -> Option<watch::Receiver<Arc<ActionState>>> {
-        {
-            let cache_check_actions = self.cache_check_actions.lock();
-            if let Some(rx) = subscribe_to_existing_action(&cache_check_actions, unique_qualifier) {
-                return Some(rx);
-            }
-        }
-        // Cache skipped may be in the upstream scheduler.
+        client_operation_id: &ClientOperationId,
+    ) -> Result<Option<Pin<Box<dyn ActionListener>>>, Error> {
         self.action_scheduler
-            .find_existing_action(unique_qualifier)
+            .find_by_client_operation_id(client_operation_id)
             .await
     }
-
-    async fn clean_recently_completed_actions(&self) {}
 }
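The in-flight map above implements a classic single-flight pattern: the first caller for a key does the work, and every later caller parks a oneshot sender until the result fans out. A self-contained sketch of the same idea with plain tokio primitives (all names invented for illustration):

use std::collections::HashMap;
use std::sync::{Arc, Mutex};

use tokio::sync::oneshot;

type Inflight = Arc<Mutex<HashMap<String, Vec<oneshot::Sender<String>>>>>;

// First caller for a key becomes the driver and fans the result out to every
// waiter that parked a oneshot sender in the meantime.
async fn single_flight(inflight: Inflight, key: String) -> String {
    let waiter = {
        let mut map = inflight.lock().unwrap();
        if let Some(pending) = map.get_mut(&key) {
            let (tx, rx) = oneshot::channel();
            pending.push(tx);
            Some(rx)
        } else {
            map.insert(key.clone(), Vec::new());
            None
        }
    };
    if let Some(rx) = waiter {
        return rx.await.expect("driver hung up");
    }
    let result = format!("looked-up-{key}"); // stand-in for the AC lookup
    for tx in inflight.lock().unwrap().remove(&key).unwrap_or_default() {
        let _ = tx.send(result.clone()); // the waiter may have gone away
    }
    result
}

#[tokio::main]
async fn main() {
    let inflight = Inflight::default();
    let (a, b) = tokio::join!(
        single_flight(inflight.clone(), "action-1".into()),
        single_flight(inflight.clone(), "action-1".into()),
    );
    assert_eq!(a, b);
}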
diff --git a/nativelink-scheduler/src/default_action_listener.rs b/nativelink-scheduler/src/default_action_listener.rs
new file mode 100644
index 000000000..ec399790a
--- /dev/null
+++ b/nativelink-scheduler/src/default_action_listener.rs
@@ -0,0 +1,77 @@
+// Copyright 2024 The NativeLink Authors. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+use std::pin::Pin;
+use std::sync::Arc;
+
+use futures::Future;
+use nativelink_error::{make_err, Code, Error};
+use nativelink_util::action_messages::{ActionState, ClientOperationId};
+use tokio::sync::watch;
+
+use crate::action_scheduler::ActionListener;
+
+/// Simple implementation of ActionListener using tokio's watch.
+pub struct DefaultActionListener {
+    client_operation_id: ClientOperationId,
+    action_state: watch::Receiver<Arc<ActionState>>,
+}
+
+impl DefaultActionListener {
+    pub fn new(
+        client_operation_id: ClientOperationId,
+        mut action_state: watch::Receiver<Arc<ActionState>>,
+    ) -> Self {
+        action_state.mark_changed();
+        Self {
+            client_operation_id,
+            action_state,
+        }
+    }
+
+    pub async fn changed(&mut self) -> Result<Arc<ActionState>, Error> {
+        self.action_state.changed().await.map_or_else(
+            |e| {
+                Err(make_err!(
+                    Code::Internal,
+                    "Sender of ActionState went away unexpectedly - {e:?}"
+                ))
+            },
+            |()| Ok(self.action_state.borrow_and_update().clone()),
+        )
+    }
+}
+
+impl ActionListener for DefaultActionListener {
+    fn client_operation_id(&self) -> &ClientOperationId {
+        &self.client_operation_id
+    }
+
+    fn changed(
+        &mut self,
+    ) -> Pin<Box<dyn Future<Output = Result<Arc<ActionState>, Error>> + Send + '_>> {
+        Box::pin(self.changed())
+    }
+}
+
+impl Clone for DefaultActionListener {
+    fn clone(&self) -> Self {
+        let mut action_state = self.action_state.clone();
+        action_state.mark_changed();
+        Self {
+            client_operation_id: self.client_operation_id.clone(),
+            action_state,
+        }
+    }
+}
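A small usage sketch; the channel wiring and module path are assumptions. Because `new` calls `mark_changed`, the first `changed().await` resolves immediately with the current state rather than blocking until the next transition:

use std::sync::Arc;

use nativelink_error::Error;
use nativelink_scheduler::default_action_listener::DefaultActionListener;
use nativelink_util::action_messages::{ActionStage, ActionState, ClientOperationId, OperationId};
use tokio::sync::watch;

async fn demo(
    client_operation_id: ClientOperationId,
    operation_id: OperationId,
) -> Result<(), Error> {
    let (tx, rx) = watch::channel(Arc::new(ActionState {
        id: operation_id,
        stage: ActionStage::Queued,
    }));
    let mut listener = DefaultActionListener::new(client_operation_id, rx);

    // Yields the Queued state immediately thanks to mark_changed().
    let first = listener.changed().await?;
    assert!(matches!(first.stage, ActionStage::Queued));

    // Dropping the sender ends the stream; later `changed()` calls would error.
    drop(tx);
    Ok(())
}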
@@ -109,24 +105,8 @@ fn inner_scheduler_factory( visited_schedulers.insert(worker_scheduler_uintptr); worker_scheduler.clone().register_metrics(scheduler_metrics); } - worker_scheduler.clone().register_metrics(scheduler_metrics); } } Ok(scheduler) } - -fn start_cleanup_timer(action_scheduler: &Arc) { - let weak_scheduler = Arc::downgrade(action_scheduler); - background_spawn!("default_scheduler_factory_cleanup_timer", async move { - let mut ticker = interval(Duration::from_secs(1)); - loop { - ticker.tick().await; - match weak_scheduler.upgrade() { - Some(scheduler) => scheduler.clean_recently_completed_actions().await, - // If we fail to upgrade, our service is probably destroyed, so return. - None => return, - } - } - }); -} diff --git a/nativelink-scheduler/src/grpc_scheduler.rs b/nativelink-scheduler/src/grpc_scheduler.rs index 646cb4e89..8cc8b7775 100644 --- a/nativelink-scheduler/src/grpc_scheduler.rs +++ b/nativelink-scheduler/src/grpc_scheduler.rs @@ -14,6 +14,7 @@ use std::collections::HashMap; use std::future::Future; +use std::pin::Pin; use std::sync::Arc; use std::time::Duration; @@ -28,9 +29,12 @@ use nativelink_proto::build::bazel::remote::execution::v2::{ }; use nativelink_proto::google::longrunning::Operation; use nativelink_util::action_messages::{ - ActionInfo, ActionInfoHashKey, ActionState, DEFAULT_EXECUTION_PRIORITY, + ActionInfo, ActionState, ActionUniqueKey, ActionUniqueQualifier, ClientOperationId, + OperationId, DEFAULT_EXECUTION_PRIORITY, }; +use nativelink_util::common::DigestInfo; use nativelink_util::connection_manager::ConnectionManager; +use nativelink_util::digest_hasher::DigestHasherFunc; use nativelink_util::retry::{Retrier, RetryResult}; use nativelink_util::{background_spawn, tls_utils}; use parking_lot::Mutex; @@ -42,7 +46,8 @@ use tokio::time::sleep; use tonic::{Request, Streaming}; use tracing::{event, Level}; -use crate::action_scheduler::ActionScheduler; +use crate::action_scheduler::{ActionListener, ActionScheduler}; +use crate::default_action_listener::DefaultActionListener; use crate::platform_property_manager::PlatformPropertyManager; pub struct GrpcScheduler { @@ -112,13 +117,26 @@ impl GrpcScheduler { async fn stream_state( mut result_stream: Streaming, - ) -> Result>, Error> { + ) -> Result>, Error> { if let Some(initial_response) = result_stream .message() .await .err_tip(|| "Receiving response from upstream scheduler")? { - let (tx, rx) = watch::channel(Arc::new(initial_response.try_into()?)); + let client_operation_id = + ClientOperationId::from_raw_string(initial_response.name.clone()); + // The operation_id is not actually needed here; it is just a placeholder so we can reuse the existing object. + // The only thing that actually matters is the client_operation_id.
+ let operation_id = + OperationId::new(ActionUniqueQualifier::Uncachable(ActionUniqueKey { + instance_name: "dummy_instance_name".to_string(), + digest_function: DigestHasherFunc::Sha256, + digest: DigestInfo::zero_digest(), + })); + let action_state = + ActionState::try_from_operation(initial_response, operation_id.clone()) + .err_tip(|| "In GrpcScheduler::stream_state")?; + let (tx, rx) = watch::channel(Arc::new(action_state)); background_spawn!("grpc_scheduler_stream_state", async move { loop { select!( @@ -135,7 +153,8 @@ impl GrpcScheduler { let Ok(Some(response)) = response else { return; }; - match response.try_into() { + let maybe_action_state = ActionState::try_from_operation(response, operation_id.clone()); + match maybe_action_state { Ok(response) => { if let Err(err) = tx.send(Arc::new(response)) { event!( @@ -158,7 +177,10 @@ impl GrpcScheduler { ) } }); - return Ok(rx); + return Ok(Box::pin(DefaultActionListener::new( + client_operation_id, + rx, + ))); } Err(make_err!( Code::Internal, @@ -218,8 +240,9 @@ impl ActionScheduler for GrpcScheduler { async fn add_action( &self, + _client_operation_id: ClientOperationId, action_info: ActionInfo, - ) -> Result>, Error> { + ) -> Result>, Error> { let execution_policy = if action_info.priority == DEFAULT_EXECUTION_PRIORITY { None } else { @@ -227,16 +250,20 @@ impl ActionScheduler for GrpcScheduler { priority: action_info.priority, }) }; + let skip_cache_lookup = match action_info.unique_qualifier { + ActionUniqueQualifier::Cachable(_) => false, + ActionUniqueQualifier::Uncachable(_) => true, + }; let request = ExecuteRequest { instance_name: action_info.instance_name().clone(), - skip_cache_lookup: action_info.skip_cache_lookup, + skip_cache_lookup, action_digest: Some(action_info.digest().into()), execution_policy, // TODO: Get me from the original request, not very important as we ignore it. results_cache_policy: None, digest_function: action_info .unique_qualifier - .digest_function + .digest_function() .proto_digest_func() .into(), }; @@ -257,12 +284,12 @@ impl ActionScheduler for GrpcScheduler { Self::stream_state(result_stream).await } - async fn find_existing_action( + async fn find_by_client_operation_id( &self, - unique_qualifier: &ActionInfoHashKey, - ) -> Option>> { + client_operation_id: &ClientOperationId, + ) -> Result>>, Error> { let request = WaitExecutionRequest { - name: unique_qualifier.action_name(), + name: client_operation_id.to_string(), }; let result_stream = self .perform_request(request, |request| async move { @@ -270,7 +297,7 @@ impl ActionScheduler for GrpcScheduler { .connection_manager .connection() .await - .err_tip(|| "in find_existing_action()")?; + .err_tip(|| "in find_by_client_operation_id()")?; ExecutionClient::new(channel) .wait_execution(Request::new(request)) .await @@ -279,17 +306,15 @@ impl ActionScheduler for GrpcScheduler { .and_then(|result_stream| Self::stream_state(result_stream.into_inner())) .await; match result_stream { - Ok(result_stream) => Some(result_stream), + Ok(result_stream) => Ok(Some(result_stream)), Err(err) => { event!( Level::WARN, ?err, "Error looking up action with upstream scheduler" ); - None + Ok(None) } } } - - async fn clean_recently_completed_actions(&self) {} } diff --git a/nativelink-scheduler/src/lib.rs b/nativelink-scheduler/src/lib.rs index 0818ce059..32b6a2f63 100644 --- a/nativelink-scheduler/src/lib.rs +++ b/nativelink-scheduler/src/lib.rs @@ -13,15 +13,16 @@ // limitations under the License. 
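With `ActionScheduler::add_action` now taking a `ClientOperationId` and returning an `ActionListener` rather than a bare watch receiver, a caller's wait loop reduces to polling `changed()` until the stage is terminal. A hypothetical sketch, not part of this patch: `impl ActionListener` stands in for the pinned trait object that `add_action` actually returns, and the import paths are the ones used elsewhere in this diff.

```rust
use std::sync::Arc;

use nativelink_error::Error;
use nativelink_util::action_messages::ActionState;

use crate::action_scheduler::ActionListener;

// Hypothetical helper: poll the listener handed back by add_action()
// until the action reaches a terminal stage, then return that state.
async fn wait_for_completion(
    mut listener: impl ActionListener,
) -> Result<Arc<ActionState>, Error> {
    loop {
        // changed() resolves with the latest ActionState; the first call
        // resolves immediately thanks to mark_changed() in the listener.
        let state = listener.changed().await?;
        if state.stage.is_finished() {
            return Ok(state);
        }
    }
}
```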
pub mod action_scheduler; +pub mod api_worker_scheduler; +mod awaited_action_db; pub mod cache_lookup_scheduler; +pub mod default_action_listener; pub mod default_scheduler_factory; pub mod grpc_scheduler; -pub mod operation_state_manager; +mod memory_awaited_action_db; pub mod platform_property_manager; pub mod property_modifier_scheduler; -pub mod redis_action_stage; -pub mod redis_operation_state; -pub mod scheduler_state; pub mod simple_scheduler; +mod simple_scheduler_state_manager; pub mod worker; pub mod worker_scheduler; diff --git a/nativelink-scheduler/src/memory_awaited_action_db.rs b/nativelink-scheduler/src/memory_awaited_action_db.rs new file mode 100644 index 000000000..10992108c --- /dev/null +++ b/nativelink-scheduler/src/memory_awaited_action_db.rs @@ -0,0 +1,987 @@ +// Copyright 2024 The NativeLink Authors. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::collections::{BTreeMap, BTreeSet, HashMap}; +use std::ops::{Bound, RangeBounds}; +use std::sync::Arc; +use std::time::{Duration, Instant, SystemTime}; + +use async_lock::Mutex; +use async_trait::async_trait; +use futures::{FutureExt, Stream}; +use nativelink_config::stores::EvictionPolicy; +use nativelink_error::{error_if, make_err, Code, Error, ResultExt}; +use nativelink_util::action_messages::{ + ActionInfo, ActionStage, ActionState, ActionUniqueKey, ActionUniqueQualifier, + ClientOperationId, OperationId, +}; +use nativelink_util::chunked_stream::ChunkedStream; +use nativelink_util::evicting_map::{EvictingMap, LenEntry}; +use nativelink_util::metrics_utils::{CollectorState, MetricsComponent}; +use nativelink_util::operation_state_manager::ActionStateResult; +use nativelink_util::spawn; +use nativelink_util::task::JoinHandleDropGuard; +use tokio::sync::{mpsc, watch}; +use tracing::{event, Level}; + +use crate::awaited_action_db::{ + AwaitedAction, AwaitedActionDb, AwaitedActionSubscriber, SortedAwaitedAction, + SortedAwaitedActionState, +}; + +/// Number of events to process per cycle. +const MAX_ACTION_EVENTS_RX_PER_CYCLE: usize = 1024; + +/// Duration to wait before sending client keep-alive messages. +const CLIENT_KEEPALIVE_DURATION: Duration = Duration::from_secs(10); + +/// Represents a client that is currently listening to an action. +/// When the client is dropped, it notifies `drop_tx` with its +/// [`OperationId`] so that any remaining cleanup can run. +#[derive(Debug)] +struct ClientAwaitedAction { + /// The OperationId that the client is listening to. + operation_id: OperationId, + + /// The sender to notify of this struct being dropped. + drop_tx: mpsc::UnboundedSender, +} + +impl ClientAwaitedAction { + pub fn new(operation_id: OperationId, drop_tx: mpsc::UnboundedSender) -> Self { + Self { + operation_id, + drop_tx, + } + } + + pub fn operation_id(&self) -> &OperationId { + &self.operation_id + } +} + +impl Drop for ClientAwaitedAction { + fn drop(&mut self) { + // If we fail to send, it means no one is listening. + let _ = self.drop_tx.send(ActionEvent::ClientDroppedOperation( + self.operation_id.clone(), + )); + } +}
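`ClientAwaitedAction` is pure RAII: the only way the database learns that a client went away is this `Drop` impl, and the send is deliberately best-effort. Reduced to its essentials (generic names, not the patch's types), the pattern looks like this:

```rust
use tokio::sync::mpsc;

// A registration handle: dropping it tells the owner which entry to
// clean up. If the receiver is already gone, nobody needs the event,
// so the send error is ignored on purpose.
struct Registration {
    id: u64,
    drop_tx: mpsc::UnboundedSender<u64>,
}

impl Drop for Registration {
    fn drop(&mut self) {
        let _ = self.drop_tx.send(self.id);
    }
}

#[tokio::main]
async fn main() {
    let (drop_tx, mut drop_rx) = mpsc::unbounded_channel();
    {
        let _registration = Registration { id: 42, drop_tx };
    } // <- Drop fires here.
    assert_eq!(drop_rx.recv().await, Some(42));
}
```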
+/// Trait to be able to use the EvictingMap with [`ClientAwaitedAction`]. +/// Note: We only use EvictingMap for a time based eviction, which is +/// why the implementation has fixed default values in it. +impl LenEntry for ClientAwaitedAction { + #[inline] + fn len(&self) -> usize { + 0 + } + + #[inline] + fn is_empty(&self) -> bool { + true + } +} + +/// Events the AwaitedActionDb needs to process. +pub(crate) enum ActionEvent { + /// A client has sent a keep-alive message. + ClientKeepAlive(ClientOperationId), + /// A client has dropped its handle to the given OperationId. + ClientDroppedOperation(OperationId), +} + +/// Information required to track an individual client's +/// keep-alive config and state. +struct ClientKeepAlive { + /// The client operation id. + client_operation_id: ClientOperationId, + /// The last time a keep-alive was sent. + last_keep_alive: Instant, + /// The sender used to notify the database of keep-alive events. + drop_tx: mpsc::UnboundedSender, +} + +/// Subscriber that can be used to monitor when AwaitedActions change. +pub struct MemoryAwaitedActionSubscriber { + /// The receiver to listen for changes. + awaited_action_rx: watch::Receiver, + /// The client operation id and keep-alive information. + client_operation_info: Option, +} + +impl MemoryAwaitedActionSubscriber { + pub fn new(mut awaited_action_rx: watch::Receiver) -> Self { + awaited_action_rx.mark_changed(); + Self { + awaited_action_rx, + client_operation_info: None, + } + } + + pub fn new_with_client( + mut awaited_action_rx: watch::Receiver, + client_operation_id: ClientOperationId, + drop_tx: mpsc::UnboundedSender, + ) -> Self { + awaited_action_rx.mark_changed(); + Self { + awaited_action_rx, + client_operation_info: Some(ClientKeepAlive { + client_operation_id, + last_keep_alive: Instant::now(), + drop_tx, + }), + } + } +} + +impl AwaitedActionSubscriber for MemoryAwaitedActionSubscriber { + async fn changed(&mut self) -> Result { + { + let changed_fut = self.awaited_action_rx.changed().map(|r| { + r.map_err(|e| { + make_err!( + Code::Internal, + "Failed to wait for awaited action to change {e:?}" + ) + }) + }); + let Some(client_keep_alive) = self.client_operation_info.as_mut() else { + changed_fut.await?; + return Ok(self.awaited_action_rx.borrow().clone()); + }; + tokio::pin!(changed_fut); + loop { + if client_keep_alive.last_keep_alive.elapsed() > CLIENT_KEEPALIVE_DURATION { + client_keep_alive.last_keep_alive = Instant::now(); + // Failing to send just means our receiver dropped. + let _ = client_keep_alive.drop_tx.send(ActionEvent::ClientKeepAlive( + client_keep_alive.client_operation_id.clone(), + )); + } + tokio::select! { + result = &mut changed_fut => { + result?; + break; + } + _ = tokio::time::sleep(CLIENT_KEEPALIVE_DURATION) => { + // If we haven't received any updates for a while, we should + // let the database know that we are still listening to prevent + // the action from being dropped. + } + + } + } + } + Ok(self.awaited_action_rx.borrow().clone()) + }
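The loop above interleaves two concerns: wait for the watch channel to change, but wake up on a `CLIENT_KEEPALIVE_DURATION` cadence to tell the database the client is still there. A distilled version of that select pattern (simplified types; the real code also pins the change future across iterations):

```rust
use std::time::Duration;

use tokio::sync::{mpsc, watch};

// Wait for `rx` to change, emitting a keep-alive event whenever a full
// `period` elapses without an update.
async fn changed_with_keepalive(
    rx: &mut watch::Receiver<u64>,
    keepalive_tx: &mpsc::UnboundedSender<()>,
    period: Duration,
) -> Result<u64, watch::error::RecvError> {
    loop {
        tokio::select! {
            result = rx.changed() => {
                result?;
                return Ok(*rx.borrow_and_update());
            }
            _ = tokio::time::sleep(period) => {
                // Timer won the race: signal liveness, then keep waiting.
                let _ = keepalive_tx.send(());
            }
        }
    }
}
```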
+ + fn borrow(&self) -> AwaitedAction { + self.awaited_action_rx.borrow().clone() + } +} + +pub struct MatchingEngineActionStateResult { + awaited_action_sub: T, +} +impl MatchingEngineActionStateResult { + pub fn new(awaited_action_sub: T) -> Self { + Self { awaited_action_sub } + } +} + +#[async_trait] +impl ActionStateResult for MatchingEngineActionStateResult { + async fn as_state(&self) -> Result, Error> { + Ok(self.awaited_action_sub.borrow().state().clone()) + } + + async fn changed(&mut self) -> Result, Error> { + let awaited_action = self.awaited_action_sub.changed().await.map_err(|e| { + make_err!( + Code::Internal, + "Failed to wait for awaited action to change {e:?}" + ) + })?; + Ok(awaited_action.state().clone()) + } + + async fn as_action_info(&self) -> Result, Error> { + Ok(self.awaited_action_sub.borrow().action_info().clone()) + } +} + +pub(crate) struct ClientActionStateResult { + inner: MatchingEngineActionStateResult, +} + +impl ClientActionStateResult { + pub fn new(sub: T) -> Self { + Self { + inner: MatchingEngineActionStateResult::new(sub), + } + } +} + +#[async_trait] +impl ActionStateResult for ClientActionStateResult { + async fn as_state(&self) -> Result, Error> { + self.inner.as_state().await + } + + async fn changed(&mut self) -> Result, Error> { + self.inner.changed().await + } + + async fn as_action_info(&self) -> Result, Error> { + self.inner.as_action_info().await + } +} + +/// A struct that is used to keep the developer from trying to +/// return early from a function. +struct NoEarlyReturn; + +#[derive(Default)] +struct SortedAwaitedActions { + unknown: BTreeSet, + cache_check: BTreeSet, + queued: BTreeSet, + executing: BTreeSet, + completed: BTreeSet, +} + +impl SortedAwaitedActions { + fn btree_for_state(&mut self, state: &ActionStage) -> &mut BTreeSet { + match state { + ActionStage::Unknown => &mut self.unknown, + ActionStage::CacheCheck => &mut self.cache_check, + ActionStage::Queued => &mut self.queued, + ActionStage::Executing => &mut self.executing, + ActionStage::Completed(_) => &mut self.completed, + ActionStage::CompletedFromCache(_) => &mut self.completed, + } + } + + fn insert_sort_map_for_stage( + &mut self, + stage: &ActionStage, + sorted_awaited_action: SortedAwaitedAction, + ) -> Result<(), Error> { + let newly_inserted = match stage { + ActionStage::Unknown => self.unknown.insert(sorted_awaited_action.clone()), + ActionStage::CacheCheck => self.cache_check.insert(sorted_awaited_action.clone()), + ActionStage::Queued => self.queued.insert(sorted_awaited_action.clone()), + ActionStage::Executing => self.executing.insert(sorted_awaited_action.clone()), + ActionStage::Completed(_) => self.completed.insert(sorted_awaited_action.clone()), + ActionStage::CompletedFromCache(_) => { + self.completed.insert(sorted_awaited_action.clone()) + } + }; + if !newly_inserted { + return Err(make_err!( + Code::Internal, + "Tried to insert an action that was already in the sorted map. This should never happen.
{:?} - {:?}", + stage, + sorted_awaited_action + )); + } + Ok(()) + } + + fn process_state_changes( + &mut self, + old_awaited_action: &AwaitedAction, + new_awaited_action: &AwaitedAction, + ) -> Result<(), Error> { + let btree = self.btree_for_state(&old_awaited_action.state().stage); + let maybe_sorted_awaited_action = btree.take(&SortedAwaitedAction { + sort_key: old_awaited_action.sort_key(), + operation_id: new_awaited_action.operation_id().clone(), + }); + + let Some(sorted_awaited_action) = maybe_sorted_awaited_action else { + return Err(make_err!( + Code::Internal, + "sorted_action_info_hash_keys and action_info_hash_key_to_awaited_action are out of sync - {} - {:?}", + new_awaited_action.operation_id(), + new_awaited_action, + )); + }; + + self.insert_sort_map_for_stage(&new_awaited_action.state().stage, sorted_awaited_action) + .err_tip(|| "In AwaitedActionDb::update_awaited_action")?; + Ok(()) + } +} + +/// The database for storing the state of all actions. +pub struct AwaitedActionDbImpl { + /// A lookup table to lookup the state of an action by its client operation id. + client_operation_to_awaited_action: + EvictingMap, SystemTime>, + + /// A lookup table to lookup the state of an action by its worker operation id. + operation_id_to_awaited_action: BTreeMap>, + + /// A lookup table to lookup the state of an action by its unique qualifier. + action_info_hash_key_to_awaited_action: HashMap, + + /// A sorted set of [`AwaitedAction`]s. A wrapper is used to perform sorting + /// based on the [`AwaitedActionSortKey`] of the [`AwaitedAction`]. + /// + /// See [`AwaitedActionSortKey`] for more information on the ordering. + sorted_action_info_hash_keys: SortedAwaitedActions, + + /// The number of connected clients for each operation id. + connected_clients_for_operation_id: HashMap, + + /// Where to send notifications about important events related to actions. + action_event_tx: mpsc::UnboundedSender, +} + +impl AwaitedActionDbImpl { + async fn get_awaited_action_by_id( + &self, + client_operation_id: &ClientOperationId, + ) -> Result, Error> { + let maybe_client_awaited_action = self + .client_operation_to_awaited_action + .get(client_operation_id) + .await; + let client_awaited_action = match maybe_client_awaited_action { + Some(client_awaited_action) => client_awaited_action, + None => return Ok(None), + }; + + self.operation_id_to_awaited_action + .get(client_awaited_action.operation_id()) + .map(|tx| Some(MemoryAwaitedActionSubscriber::new(tx.subscribe()))) + .ok_or_else(|| { + make_err!( + Code::Internal, + "Failed to get client operation id {client_operation_id:?}" + ) + }) + } + + /// Processes action events that need to be handled by the database. + async fn handle_action_events( + &mut self, + action_events: impl IntoIterator, + ) -> NoEarlyReturn { + for drop_action in action_events.into_iter() { + match drop_action { + ActionEvent::ClientDroppedOperation(operation_id) => { + // Cleanup operation_id_to_awaited_action. 
+ let Some(tx) = self.operation_id_to_awaited_action.remove(&operation_id) else { + event!( + Level::ERROR, + ?operation_id, + "operation_id_to_awaited_action does not have operation_id" + ); + continue; + }; + let connected_clients = match self + .connected_clients_for_operation_id + .remove(&operation_id) + { + Some(connected_clients) => connected_clients - 1, + None => { + event!( + Level::ERROR, + ?operation_id, + "connected_clients_for_operation_id does not have operation_id" + ); + 0 + } + }; + // Note: It is rare to have more than one client listening + // to the same action, so we assume that we are the last + // client and insert it back into the map if we detect that + // there are still clients listening (ie: the happy path + // is operation.connected_clients == 0). + if connected_clients != 0 { + self.operation_id_to_awaited_action + .insert(operation_id.clone(), tx); + self.connected_clients_for_operation_id + .insert(operation_id, connected_clients); + continue; + } + let awaited_action = tx.borrow().clone(); + // Cleanup action_info_hash_key_to_awaited_action if it was marked cached. + match &awaited_action.action_info().unique_qualifier { + ActionUniqueQualifier::Cachable(action_key) => { + let maybe_awaited_action = self + .action_info_hash_key_to_awaited_action + .remove(action_key); + if !awaited_action.state().stage.is_finished() + && maybe_awaited_action.is_none() + { + event!( + Level::ERROR, + ?operation_id, + ?awaited_action, + ?action_key, + "action_info_hash_key_to_awaited_action and operation_id_to_awaited_action are out of sync", + ); + } + } + ActionUniqueQualifier::Uncachable(_action_key) => { + // This Operation should not be in the hash_key map. + } + } + + // Cleanup sorted_awaited_action. + let sort_key = awaited_action.sort_key(); + let sort_btree_for_state = self + .sorted_action_info_hash_keys + .btree_for_state(&awaited_action.state().stage); + + let maybe_sorted_awaited_action = + sort_btree_for_state.take(&SortedAwaitedAction { + sort_key, + operation_id: operation_id.clone(), + }); + if maybe_sorted_awaited_action.is_none() { + event!( + Level::ERROR, + ?operation_id, + ?sort_key, + "Expected maybe_sorted_awaited_action to have {sort_key:?}", + ); + } + } + ActionEvent::ClientKeepAlive(client_id) => { + let maybe_size = self + .client_operation_to_awaited_action + .size_for_key(&client_id) + .await; + if maybe_size.is_none() { + event!( + Level::ERROR, + ?client_id, + "client_operation_to_awaited_action does not have client_id", + ); + } + } + } + } + NoEarlyReturn + } + + fn get_awaited_actions_range( + &self, + start: Bound<&OperationId>, + end: Bound<&OperationId>, + ) -> impl Iterator { + self.operation_id_to_awaited_action + .range((start, end)) + .map(|(operation_id, tx)| { + ( + operation_id, + MemoryAwaitedActionSubscriber::new(tx.subscribe()), + ) + }) + } + + fn get_by_operation_id( + &self, + operation_id: &OperationId, + ) -> Option { + self.operation_id_to_awaited_action + .get(operation_id) + .map(|tx| MemoryAwaitedActionSubscriber::new(tx.subscribe())) + } + + fn get_range_of_actions<'a, 'b>( + &'a self, + state: SortedAwaitedActionState, + range: impl RangeBounds + 'b, + ) -> impl DoubleEndedIterator< + Item = Result<(&'a SortedAwaitedAction, MemoryAwaitedActionSubscriber), Error>, + > + 'a { + let btree = match state { + SortedAwaitedActionState::CacheCheck => &self.sorted_action_info_hash_keys.cache_check, + SortedAwaitedActionState::Queued => &self.sorted_action_info_hash_keys.queued, + SortedAwaitedActionState::Executing => 
&self.sorted_action_info_hash_keys.executing, + SortedAwaitedActionState::Completed => &self.sorted_action_info_hash_keys.completed, + }; + btree.range(range).map(|sorted_awaited_action| { + let operation_id = &sorted_awaited_action.operation_id; + self.get_by_operation_id(operation_id) + .ok_or_else(|| { + make_err!( + Code::Internal, + "Failed to get operation id {}", + operation_id + ) + }) + .map(|subscriber| (sorted_awaited_action, subscriber)) + }) + } + + fn process_state_changes_for_hash_key_map( + action_info_hash_key_to_awaited_action: &mut HashMap, + new_awaited_action: &AwaitedAction, + ) -> Result<(), Error> { + // Do not allow future subscribes if the action is already completed, + // this is the responsibility of the CacheLookupScheduler. + // TODO(allad) Once we land the new scheduler onto main, we can remove this check. + // It makes sense to allow users to subscribe to already completed items. + // This can be changed to `.is_error()` later. + if !new_awaited_action.state().stage.is_finished() { + return Ok(()); + } + match &new_awaited_action.action_info().unique_qualifier { + ActionUniqueQualifier::Cachable(action_key) => { + let maybe_awaited_action = + action_info_hash_key_to_awaited_action.remove(action_key); + match maybe_awaited_action { + Some(removed_operation_id) => { + if &removed_operation_id != new_awaited_action.operation_id() { + event!( + Level::ERROR, + ?removed_operation_id, + ?new_awaited_action, + ?action_key, + "action_info_hash_key_to_awaited_action and operation_id_to_awaited_action are out of sync", + ); + } + } + None => { + event!( + Level::ERROR, + ?new_awaited_action, + ?action_key, + "action_info_hash_key_to_awaited_action out of sync, it should have had the unique_key", + ); + } + } + Ok(()) + } + ActionUniqueQualifier::Uncachable(_action_key) => { + // If we are not cachable, the action should not be in the + // hash_key map, so we don't need to process anything in + // action_info_hash_key_to_awaited_action. + Ok(()) + } + } + } + + fn update_awaited_action(&mut self, new_awaited_action: AwaitedAction) -> Result<(), Error> { + let tx = self + .operation_id_to_awaited_action + .get(new_awaited_action.operation_id()) + .ok_or_else(|| { + make_err!( + Code::Internal, + "OperationId does not exist in map in AwaitedActionDb::update_awaited_action" + ) + })?; + { + // Note: It's important to drop old_awaited_action before we call + // send_replace or we will have a deadlock. + let old_awaited_action = tx.borrow(); + + // Do not process changes if the action version is not in sync with + // what the sender based the update on. 
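The version gate below is plain optimistic concurrency: an update is only accepted if it was derived from the version we currently hold, otherwise the caller must re-read and retry. In miniature (illustrative, with a plain integer standing in for `AwaitedAction`):

```rust
// Accept `new_version` only if it is exactly one ahead of `current`,
// i.e. the writer based its change on the state we still hold.
fn check_version(current: u64, new_version: u64) -> Result<(), String> {
    if current + 1 != new_version {
        // Mirrors the ABORTED semantics: re-read the latest value,
        // re-apply the change, and retry.
        return Err(format!(
            "version mismatch: expected {}, got {new_version}",
            current + 1
        ));
    }
    Ok(())
}
```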
+ if old_awaited_action.version() + 1 != new_awaited_action.version() { + return Err(make_err!( + // From: https://grpc.github.io/grpc/core/md_doc_statuscodes.html + // Use ABORTED if the client should retry at a higher level + // (e.g., when a client-specified test-and-set fails, + // indicating the client should restart a read-modify-write + // sequence) + Code::Aborted, + "{} Expected {:?} but got {:?} for operation_id {:?} - {:?}", + "Tried to update an awaited action with an incorrect version.", + old_awaited_action.version() + 1, + new_awaited_action.version(), + old_awaited_action, + new_awaited_action, + )); + } + + error_if!( + old_awaited_action.action_info().unique_qualifier + != new_awaited_action.action_info().unique_qualifier, + "Unique key changed for operation_id {:?} - {:?} - {:?}", + new_awaited_action.operation_id(), + old_awaited_action.action_info(), + new_awaited_action.action_info(), + ); + let is_same_stage = old_awaited_action + .state() + .stage + .is_same_stage(&new_awaited_action.state().stage); + + if !is_same_stage { + self.sorted_action_info_hash_keys + .process_state_changes(&old_awaited_action, &new_awaited_action)?; + Self::process_state_changes_for_hash_key_map( + &mut self.action_info_hash_key_to_awaited_action, + &new_awaited_action, + )?; + } + } + + // Notify all listeners of the new state and ignore if no one is listening. + // Note: Do not use `.send()` as it will not update the state if all listeners + // are dropped. + let _ = tx.send_replace(new_awaited_action); + + Ok(()) + } + + /// Creates a new [`ClientAwaitedAction`] and a [`watch::Receiver`] to + /// listen for changes. We don't do this in-line because it is important + /// to ALWAYS construct a [`ClientAwaitedAction`] before inserting it into + /// the map. Failing to do so may result in memory leaks. This is because + /// [`ClientAwaitedAction`] implements a drop function that will trigger + /// cleanup of the other maps on drop. + fn make_client_awaited_action( + &mut self, + operation_id: OperationId, + awaited_action: AwaitedAction, + ) -> (Arc, watch::Receiver) { + let (tx, rx) = watch::channel(awaited_action); + let client_awaited_action = Arc::new(ClientAwaitedAction::new( + operation_id.clone(), + self.action_event_tx.clone(), + )); + self.operation_id_to_awaited_action + .insert(operation_id.clone(), tx); + self.connected_clients_for_operation_id + .insert(operation_id.clone(), 1); + (client_awaited_action, rx) + } + + async fn add_action( + &mut self, + client_operation_id: ClientOperationId, + action_info: Arc, + ) -> Result { + // Check to see if the action is already known and subscribe if it is. + let subscription_result = self + .try_subscribe( + &client_operation_id, + &action_info.unique_qualifier, + action_info.priority, + ) + .await + .err_tip(|| "In AwaitedActionDb::subscribe_or_add_action"); + match subscription_result { + Err(err) => return Err(err), + Ok(Some(subscription)) => return Ok(subscription), + Ok(None) => { /* Add item to queue. 
*/ } + } + + let maybe_unique_key = match &action_info.unique_qualifier { + ActionUniqueQualifier::Cachable(unique_key) => Some(unique_key.clone()), + ActionUniqueQualifier::Uncachable(_unique_key) => None, + }; + let operation_id = OperationId::new(action_info.unique_qualifier.clone()); + let awaited_action = AwaitedAction::new(operation_id.clone(), action_info); + debug_assert!( + ActionStage::Queued == awaited_action.state().stage, + "Expected action to be queued" + ); + let sort_key = awaited_action.sort_key(); + + let (client_awaited_action, rx) = + self.make_client_awaited_action(operation_id.clone(), awaited_action); + + self.client_operation_to_awaited_action + .insert(client_operation_id.clone(), client_awaited_action) + .await; + + // Note: We only put items in the map that are cachable. + if let Some(unique_key) = maybe_unique_key { + let old_value = self + .action_info_hash_key_to_awaited_action + .insert(unique_key, operation_id.clone()); + if let Some(old_value) = old_value { + event!( + Level::ERROR, + ?operation_id, + ?old_value, + "action_info_hash_key_to_awaited_action already has unique_key" + ); + } + } + + self.sorted_action_info_hash_keys + .insert_sort_map_for_stage( + &ActionStage::Queued, + SortedAwaitedAction { + sort_key, + operation_id, + }, + ) + .err_tip(|| "In AwaitedActionDb::subscribe_or_add_action")?; + + Ok(MemoryAwaitedActionSubscriber::new_with_client( + rx, + client_operation_id, + self.action_event_tx.clone(), + )) + } + + async fn try_subscribe( + &mut self, + client_operation_id: &ClientOperationId, + unique_qualifier: &ActionUniqueQualifier, + // TODO(allada) To simplify the scheduler 2024 refactor, we + // removed the ability to upgrade priorities of actions. + // we should add priority upgrades back in. + _priority: i32, + ) -> Result, Error> { + let unique_key = match unique_qualifier { + ActionUniqueQualifier::Cachable(unique_key) => unique_key, + ActionUniqueQualifier::Uncachable(_unique_key) => return Ok(None), + }; + + let Some(operation_id) = self.action_info_hash_key_to_awaited_action.get(unique_key) else { + return Ok(None); // Not currently running. + }; + + let Some(tx) = self.operation_id_to_awaited_action.get(operation_id) else { + return Err(make_err!( + Code::Internal, + "operation_id_to_awaited_action and action_info_hash_key_to_awaited_action are out of sync for {unique_key:?} - {operation_id}" + )); + }; + + error_if!( + tx.borrow().state().stage.is_finished(), + "Tried to subscribe to a completed action but it already finished. This should never happen. 
{:?}", + tx.borrow() + ); + + let maybe_connected_clients = self + .connected_clients_for_operation_id + .get_mut(operation_id); + let Some(connected_clients) = maybe_connected_clients else { + return Err(make_err!( + Code::Internal, + "connected_clients_for_operation_id and operation_id_to_awaited_action are out of sync for {unique_key:?} - {operation_id}" + )); + }; + *connected_clients += 1; + + let subscription = tx.subscribe(); + + self.client_operation_to_awaited_action + .insert( + client_operation_id.clone(), + Arc::new(ClientAwaitedAction::new( + operation_id.clone(), + self.action_event_tx.clone(), + )), + ) + .await; + + Ok(Some(MemoryAwaitedActionSubscriber::new(subscription))) + } +} + +pub struct MemoryAwaitedActionDb { + inner: Arc>, + _handle_awaited_action_events: JoinHandleDropGuard<()>, +} + +impl MemoryAwaitedActionDb { + pub fn new(eviction_config: &EvictionPolicy) -> Self { + let (action_event_tx, mut action_event_rx) = mpsc::unbounded_channel(); + let inner = Arc::new(Mutex::new(AwaitedActionDbImpl { + client_operation_to_awaited_action: EvictingMap::new( + eviction_config, + SystemTime::now(), + ), + operation_id_to_awaited_action: BTreeMap::new(), + action_info_hash_key_to_awaited_action: HashMap::new(), + sorted_action_info_hash_keys: SortedAwaitedActions::default(), + connected_clients_for_operation_id: HashMap::new(), + action_event_tx, + })); + let weak_inner = Arc::downgrade(&inner); + Self { + inner, + _handle_awaited_action_events: spawn!("handle_awaited_action_events", async move { + let mut dropped_operation_ids = Vec::with_capacity(MAX_ACTION_EVENTS_RX_PER_CYCLE); + loop { + dropped_operation_ids.clear(); + action_event_rx + .recv_many(&mut dropped_operation_ids, MAX_ACTION_EVENTS_RX_PER_CYCLE) + .await; + let Some(inner) = weak_inner.upgrade() else { + return; // Nothing to cleanup, our struct is dropped. 
+ }; + let mut inner = inner.lock().await; + inner + .handle_action_events(dropped_operation_ids.drain(..)) + .await; + } + }), + } + } +} + +impl AwaitedActionDb for MemoryAwaitedActionDb { + type Subscriber = MemoryAwaitedActionSubscriber; + + async fn get_awaited_action_by_id( + &self, + client_operation_id: &ClientOperationId, + ) -> Result, Error> { + self.inner + .lock() + .await + .get_awaited_action_by_id(client_operation_id) + .await + } + + async fn get_all_awaited_actions(&self) -> impl Stream> { + ChunkedStream::new( + Bound::Unbounded, + Bound::Unbounded, + move |start, end, mut output| async move { + let inner = self.inner.lock().await; + let mut maybe_new_start = None; + + for (operation_id, item) in + inner.get_awaited_actions_range(start.as_ref(), end.as_ref()) + { + output.push_back(item); + maybe_new_start = Some(operation_id); + } + + Ok(maybe_new_start + .map(|new_start| ((Bound::Excluded(new_start.clone()), end), output))) + }, + ) + } + + async fn get_by_operation_id( + &self, + operation_id: &OperationId, + ) -> Result, Error> { + Ok(self.inner.lock().await.get_by_operation_id(operation_id)) + } + + async fn get_range_of_actions( + &self, + state: SortedAwaitedActionState, + start: Bound, + end: Bound, + desc: bool, + ) -> impl Stream> + Send + Sync { + ChunkedStream::new(start, end, move |start, end, mut output| async move { + let inner = self.inner.lock().await; + let mut done = true; + let mut new_start = start.as_ref(); + let mut new_end = end.as_ref(); + + let iterator = inner.get_range_of_actions(state, (start.as_ref(), end.as_ref())); + // TODO(allada) This should probably use the `.left()/right()` pattern, + // but that doesn't exist in the std or any libraries we use. + if desc { + for result in iterator.rev() { + let (sorted_awaited_action, item) = + result.err_tip(|| "In AwaitedActionDb::get_range_of_actions")?; + output.push_back(item); + new_end = Bound::Excluded(sorted_awaited_action); + done = false; + } + } else { + for result in iterator { + let (sorted_awaited_action, item) = + result.err_tip(|| "In AwaitedActionDb::get_range_of_actions")?; + output.push_back(item); + new_start = Bound::Excluded(sorted_awaited_action); + done = false; + } + } + if done { + return Ok(None); + } + Ok(Some(((new_start.cloned(), new_end.cloned()), output))) + }) + } + + async fn update_awaited_action(&self, new_awaited_action: AwaitedAction) -> Result<(), Error> { + self.inner + .lock() + .await + .update_awaited_action(new_awaited_action) + } + + async fn add_action( + &self, + client_operation_id: ClientOperationId, + action_info: Arc, + ) -> Result { + self.inner + .lock() + .await + .add_action(client_operation_id, action_info) + .await + } +} + +impl MetricsComponent for MemoryAwaitedActionDb { + fn gather_metrics(&self, c: &mut CollectorState) { + let inner = self.inner.lock_blocking(); + c.publish( + "action_state_unknown_total", + &inner.sorted_action_info_hash_keys.unknown.len(), + "Number of actions with the current state of unknown.", + ); + c.publish( + "action_state_cache_check_total", + &inner.sorted_action_info_hash_keys.cache_check.len(), + "Number of actions with the current state of cache_check.", + ); + c.publish( + "action_state_queued_total", + &inner.sorted_action_info_hash_keys.queued.len(), + "Number of actions with the current state of queued.", + ); + c.publish( + "action_state_executing_total", + &inner.sorted_action_info_hash_keys.executing.len(), + "Number of actions with the current state of executing.", + ); + c.publish(
"action_state_completed_total", + &inner.sorted_action_info_hash_keys.completed.len(), + "Number of actions wih the current state of completed.", + ); + // TODO(allada) This is legacy and should be removed in the future. + c.publish( + "active_actions_total", + &inner.sorted_action_info_hash_keys.executing.len(), + "(LEGACY) The number of running actions.", + ); + // TODO(allada) This is legacy and should be removed in the future. + c.publish( + "queued_actions_total", + &inner.sorted_action_info_hash_keys.queued.len(), + "(LEGACY) The number actions in the queue.", + ); + } +} diff --git a/nativelink-scheduler/src/property_modifier_scheduler.rs b/nativelink-scheduler/src/property_modifier_scheduler.rs index c7b827426..c06bfc610 100644 --- a/nativelink-scheduler/src/property_modifier_scheduler.rs +++ b/nativelink-scheduler/src/property_modifier_scheduler.rs @@ -14,17 +14,17 @@ use std::collections::hash_map::Entry; use std::collections::HashMap; +use std::pin::Pin; use std::sync::Arc; use async_trait::async_trait; use nativelink_config::schedulers::{PropertyModification, PropertyType}; use nativelink_error::{Error, ResultExt}; -use nativelink_util::action_messages::{ActionInfo, ActionInfoHashKey, ActionState}; +use nativelink_util::action_messages::{ActionInfo, ClientOperationId}; use nativelink_util::metrics_utils::Registry; use parking_lot::Mutex; -use tokio::sync::watch; -use crate::action_scheduler::ActionScheduler; +use crate::action_scheduler::{ActionListener, ActionScheduler}; use crate::platform_property_manager::PlatformPropertyManager; pub struct PropertyModifierScheduler { @@ -90,10 +90,11 @@ impl ActionScheduler for PropertyModifierScheduler { async fn add_action( &self, + client_operation_id: ClientOperationId, mut action_info: ActionInfo, - ) -> Result>, Error> { + ) -> Result>, Error> { let platform_property_manager = self - .get_platform_property_manager(&action_info.unique_qualifier.instance_name) + .get_platform_property_manager(action_info.unique_qualifier.instance_name()) .await .err_tip(|| "In PropertyModifierScheduler::add_action")?; for modification in &self.modifications { @@ -111,18 +112,18 @@ impl ActionScheduler for PropertyModifierScheduler { } }; } - self.scheduler.add_action(action_info).await + self.scheduler + .add_action(client_operation_id, action_info) + .await } - async fn find_existing_action( + async fn find_by_client_operation_id( &self, - unique_qualifier: &ActionInfoHashKey, - ) -> Option>> { - self.scheduler.find_existing_action(unique_qualifier).await - } - - async fn clean_recently_completed_actions(&self) { - self.scheduler.clean_recently_completed_actions().await + client_operation_id: &ClientOperationId, + ) -> Result>>, Error> { + self.scheduler + .find_by_client_operation_id(client_operation_id) + .await } // Register metrics for the underlying ActionScheduler. diff --git a/nativelink-scheduler/src/redis_action_stage.rs b/nativelink-scheduler/src/redis_action_stage.rs deleted file mode 100644 index 3176c7324..000000000 --- a/nativelink-scheduler/src/redis_action_stage.rs +++ /dev/null @@ -1,78 +0,0 @@ -// Copyright 2024 The NativeLink Authors. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. 
-// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -use nativelink_error::{make_input_err, Error, ResultExt}; -use nativelink_util::action_messages::{ActionResult, ActionStage}; -use serde::{Deserialize, Serialize}; - -use crate::operation_state_manager::OperationStageFlags; - -#[derive(PartialEq, Debug, Clone, Serialize, Deserialize)] -pub enum RedisOperationStage { - CacheCheck, - Queued, - Executing, - Completed(ActionResult), - CompletedFromCache(ActionResult), -} - -impl RedisOperationStage { - pub fn as_state_flag(&self) -> OperationStageFlags { - match self { - Self::CacheCheck => OperationStageFlags::CacheCheck, - Self::Executing => OperationStageFlags::Executing, - Self::Queued => OperationStageFlags::Queued, - Self::Completed(_) => OperationStageFlags::Completed, - Self::CompletedFromCache(_) => OperationStageFlags::Completed, - } - } -} - -impl TryFrom for RedisOperationStage { - type Error = Error; - fn try_from(stage: ActionStage) -> Result { - match stage { - ActionStage::CacheCheck => Ok(RedisOperationStage::CacheCheck), - ActionStage::Queued => Ok(RedisOperationStage::Queued), - ActionStage::Executing => Ok(RedisOperationStage::Executing), - ActionStage::Completed(result) => Ok(RedisOperationStage::Completed(result)), - ActionStage::CompletedFromCache(proto_result) => { - let decoded = ActionResult::try_from(proto_result) - .err_tip(|| "In RedisOperationStage::try_from::")?; - Ok(RedisOperationStage::Completed(decoded)) - } - ActionStage::Unknown => Err(make_input_err!("ActionStage conversion to RedisOperationStage failed with Error - Unknown is not a valid OperationStage")), - } - } -} - -impl From for ActionStage { - fn from(stage: RedisOperationStage) -> ActionStage { - match stage { - RedisOperationStage::CacheCheck => ActionStage::CacheCheck, - RedisOperationStage::Queued => ActionStage::Queued, - RedisOperationStage::Executing => ActionStage::Executing, - RedisOperationStage::Completed(result) => ActionStage::Completed(result), - RedisOperationStage::CompletedFromCache(result) => { - ActionStage::CompletedFromCache(result.into()) - } - } - } -} - -impl From<&RedisOperationStage> for ActionStage { - fn from(stage: &RedisOperationStage) -> Self { - stage.clone().into() - } -} diff --git a/nativelink-scheduler/src/redis_operation_state.rs b/nativelink-scheduler/src/redis_operation_state.rs deleted file mode 100644 index 5dd4c13d0..000000000 --- a/nativelink-scheduler/src/redis_operation_state.rs +++ /dev/null @@ -1,465 +0,0 @@ -// Copyright 2024 The NativeLink Authors. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
- -use std::str::FromStr; -use std::sync::Arc; -use std::time::SystemTime; - -use futures::{join, StreamExt}; -use nativelink_error::{make_input_err, Error, ResultExt}; -use nativelink_store::redis_store::RedisStore; -use nativelink_util::action_messages::{ - ActionInfo, ActionInfoHashKey, ActionStage, ActionState, OperationId, WorkerId, -}; -use nativelink_util::buf_channel::make_buf_channel_pair; -use nativelink_util::spawn; -use nativelink_util::store_trait::{StoreDriver, StoreLike, StoreSubscription}; -use nativelink_util::task::JoinHandleDropGuard; -use redis::aio::{ConnectionLike, ConnectionManager}; -use redis::{AsyncCommands, Pipeline}; -use redis_macros::{FromRedisValue, ToRedisArgs}; -use serde::{Deserialize, Serialize}; -use tokio::sync::watch; -use tonic::async_trait; -use tracing::{event, Level}; - -use crate::operation_state_manager::{ - ActionStateResult, ActionStateResultStream, ClientStateManager, MatchingEngineStateManager, - OperationFilter, WorkerStateManager, -}; -use crate::redis_action_stage::RedisOperationStage; - -#[inline] -fn build_action_key(unique_qualifier: &ActionInfoHashKey) -> String { - format!("actions:{}", unique_qualifier.action_name()) -} - -#[inline] -fn build_operations_key(operation_id: &OperationId) -> String { - format!("operations:{operation_id}") -} - -pub struct RedisOperationState { - rx: watch::Receiver>, - inner: Arc, - _join_handle: JoinHandleDropGuard<()>, -} - -impl RedisOperationState { - fn new( - inner: Arc, - mut operation_subscription: Box, - ) -> Self { - let (tx, rx) = watch::channel(inner.as_state()); - - let _join_handle = spawn!("redis_subscription_watcher", async move { - loop { - let Ok(item) = operation_subscription.changed().await else { - // This might occur if the store subscription is dropped - // or if there is an error fetching the data. - return; - }; - let (mut data_tx, mut data_rx) = make_buf_channel_pair(); - let (get_res, data_res) = join!( - // We use async move because we want to transfer ownership of data_tx into the closure. - // That way if join! selects data_rx.consume(None) because get fails, - // data_tx goes out of scope and will be dropped. - async move { item.get(&mut data_tx).await }, - data_rx.consume(None) - ); - - let res = get_res - .merge(data_res) - .and_then(|data| { - RedisOperation::from_slice(&data[..]) - .err_tip(|| "Error while Publishing RedisSubscription") - }) - .map(|redis_operation| { - tx.send_modify(move |cur_state| *cur_state = redis_operation.as_state()) - }); - if let Err(e) = res { - // TODO: Refactor API to allow error to be propogated to client. 
- event!( - Level::ERROR, - ?e, - "Error During Redis Operation Subscription", - ); - return; - } - } - }); - Self { - rx, - _join_handle, - inner, - } - } -} - -#[async_trait] -impl ActionStateResult for RedisOperationState { - async fn as_state(&self) -> Result, Error> { - Ok(Arc::new(ActionState::from(self.inner.as_ref()))) - } - - async fn as_receiver(&self) -> Result<&'_ watch::Receiver>, Error> { - Ok(&self.rx) - } - - async fn as_action_info(&self) -> Result, Error> { - Ok(Arc::new(self.inner.info.clone())) - } -} - -#[derive(Serialize, Deserialize, Clone, Debug, ToRedisArgs, FromRedisValue)] -pub struct RedisOperation { - operation_id: OperationId, - info: ActionInfo, - worker_id: Option, - stage: RedisOperationStage, - last_worker_update: Option, - last_client_update: Option, - last_error: Option, - completed_at: Option, -} - -impl RedisOperation { - pub fn as_json(&self) -> String { - serde_json::json!(&self).to_string() - } - - pub fn from_slice(s: &[u8]) -> Result { - serde_json::from_slice(s).map_err(|e| { - make_input_err!("Create RedisOperation from slice failed with Error - {e:?}") - }) - } - - pub fn new(info: ActionInfo, operation_id: OperationId) -> Self { - Self { - operation_id, - info, - worker_id: None, - stage: RedisOperationStage::CacheCheck, - last_worker_update: None, - last_client_update: None, - last_error: None, - completed_at: None, - } - } - - pub fn from_existing(existing: RedisOperation, operation_id: OperationId) -> Self { - Self { - operation_id, - info: existing.info, - worker_id: existing.worker_id, - stage: existing.stage, - last_worker_update: existing.last_worker_update, - last_client_update: existing.last_client_update, - last_error: existing.last_error, - completed_at: existing.completed_at, - } - } - - pub fn as_state(&self) -> Arc { - let action_state = ActionState { - stage: self.stage.clone().into(), - id: self.operation_id.clone(), - }; - Arc::new(action_state) - } - - pub fn unique_qualifier(&self) -> &ActionInfoHashKey { - &self.operation_id.unique_qualifier - } - - pub fn matches_filter(&self, filter: &OperationFilter) -> bool { - // If the filter value is None, we can match anything and return true. - // If the filter value is Some and the value is None, it can't be a match so we return false. - // If both values are Some, we compare to determine if there is a match. 
- let matches_stage_filter = filter.stages.contains(self.stage.as_state_flag()); - if !matches_stage_filter { - return false; - } - - let matches_operation_filter = filter - .operation_id - .as_ref() - .map_or(true, |id| &self.operation_id == id); - if !matches_operation_filter { - return false; - } - - let matches_worker_filter = self.worker_id == filter.worker_id; - if !matches_worker_filter { - return false; - }; - - let matches_digest_filter = filter - .action_digest - .map_or(true, |digest| self.unique_qualifier().digest == digest); - if !matches_digest_filter { - return false; - }; - - let matches_completed_before = filter.completed_before.map_or(true, |before| { - self.completed_at - .map_or(false, |completed_at| completed_at < before) - }); - if !matches_completed_before { - return false; - }; - - let matches_last_update = filter.last_client_update_before.map_or(true, |before| { - self.last_client_update - .map_or(false, |last_update| last_update < before) - }); - if !matches_last_update { - return false; - }; - - true - } -} - -impl FromStr for RedisOperation { - type Err = Error; - fn from_str(s: &str) -> Result { - serde_json::from_str(s).map_err(|e| { - make_input_err!( - "Decode string {s} to RedisOperation failed with error: {}", - e.to_string() - ) - }) - } -} - -impl From<&RedisOperation> for ActionState { - fn from(value: &RedisOperation) -> Self { - ActionState { - id: value.operation_id.clone(), - stage: value.stage.clone().into(), - } - } -} - -pub struct RedisStateManager< - T: ConnectionLike + Unpin + Clone + Send + Sync + 'static = ConnectionManager, -> { - store: Arc>, -} - -impl RedisStateManager { - pub fn new(store: Arc>) -> Self { - Self { store } - } - - pub async fn get_conn(&self) -> Result { - self.store.get_conn().await - } - - async fn list<'a, V>( - &self, - prefix: &str, - handler: impl Fn(String, String) -> Result, - ) -> Result, Error> - where - V: Send + Sync, - { - let mut con = self - .get_conn() - .await - .err_tip(|| "In RedisStateManager::list")?; - let ids_iter = con - .scan_match::<&str, String>(prefix) - .await - .err_tip(|| "In RedisStateManager::list")?; - let keys = ids_iter.collect::>().await; - let raw_values: Vec = con - .get(&keys) - .await - .err_tip(|| "In RedisStateManager::list")?; - keys.into_iter() - .zip(raw_values.into_iter()) - .map(|(k, v)| handler(k, v)) - .collect() - } - - async fn inner_add_action( - &self, - action_info: ActionInfo, - ) -> Result, Error> { - let operation_id = OperationId::new(action_info.unique_qualifier.clone()); - let mut con = self - .get_conn() - .await - .err_tip(|| "In RedisStateManager::inner_add_action")?; - let action_key = build_action_key(&operation_id.unique_qualifier); - // TODO: List API call to find existing actions. - let mut existing_operations: Vec = Vec::new(); - let operation = match existing_operations.pop() { - Some(existing_operation) => { - let operations_key = build_operations_key(&existing_operation); - let operation: RedisOperation = con - .get(operations_key) - .await - .err_tip(|| "In RedisStateManager::inner_add_action")?; - RedisOperation::from_existing(operation.clone(), operation_id.clone()) - } - None => RedisOperation::new(action_info, operation_id.clone()), - }; - - let operation_key = build_operations_key(&operation_id); - - // The values being stored in redis are pretty small so we can do our uploads as oneshots. 
-        // We do not parallelize these uploads since we should always upload an operation followed by the action,
-        let store = self.store.as_store_driver_pin();
-        store
-            .update_oneshot(operation_key.clone().into(), operation.as_json().into())
-            .await
-            .err_tip(|| "In RedisStateManager::inner_add_action")?;
-        store
-            .update_oneshot(action_key.into(), operation_id.to_string().into())
-            .await
-            .err_tip(|| "In RedisStateManager::inner_add_action")?;
-
-        let store_subscription = self.store.clone().subscribe(operation_key.into()).await;
-        let state = RedisOperationState::new(Arc::new(operation), store_subscription);
-        Ok(Arc::new(state))
-    }
-
-    async fn inner_filter_operations(
-        &self,
-        filter: OperationFilter,
-    ) -> Result<ActionStateResultStream, Error> {
-        let handler = &|k: String, v: String| -> Result<(String, Arc<RedisOperation>), Error> {
-            let operation = Arc::new(
-                RedisOperation::from_str(&v)
-                    .err_tip(|| "In RedisStateManager::inner_filter_operations")?,
-            );
-            Ok((k, operation))
-        };
-        let existing_operations: Vec<(String, Arc<RedisOperation>)> = self
-            .list("operations:*", &handler)
-            .await
-            .err_tip(|| "In RedisStateManager::inner_filter_operations")?;
-        let mut v: Vec<Arc<dyn ActionStateResult>> = Vec::new();
-        for (key, operation) in existing_operations.into_iter() {
-            if operation.matches_filter(&filter) {
-                let store_subscription = self.store.clone().subscribe(key.into()).await;
-                v.push(Arc::new(RedisOperationState::new(
-                    operation,
-                    store_subscription,
-                )));
-            }
-        }
-        Ok(Box::pin(futures::stream::iter(v)))
-    }
-
-    async fn inner_update_operation(
-        &self,
-        operation_id: OperationId,
-        worker_id: Option<WorkerId>,
-        action_stage: Result<ActionStage, Error>,
-    ) -> Result<(), Error> {
-        let store = self.store.as_store_driver_pin();
-        let key = format!("operations:{operation_id}");
-        let operation_bytes_res = &store.get_part_unchunked(key.clone().into(), 0, None).await;
-        let Ok(operation_bytes) = operation_bytes_res else {
-            return Err(make_input_err!("Received request to update operation {operation_id}, but operation does not exist."));
-        };
-
-        let mut operation = RedisOperation::from_slice(&operation_bytes[..])
-            .err_tip(|| "In RedisStateManager::inner_update_operation")?;
-        match action_stage {
-            Ok(stage) => {
-                operation.stage = stage
-                    .try_into()
-                    .err_tip(|| "In RedisStateManager::inner_update_operation")?;
-            }
-            Err(e) => operation.last_error = Some(e),
-        }
-
-        operation.worker_id = worker_id;
-        store
-            .update_oneshot(key.into(), operation.as_json().into())
-            .await
-    }
-
-    // TODO: This should be done through store but API endpoint does not exist yet.
-    async fn inner_remove_operation(&self, operation_id: OperationId) -> Result<(), Error> {
-        let mut con = self
-            .get_conn()
-            .await
-            .err_tip(|| "In RedisStateManager::inner_remove_operation")?;
-        let mut pipe = Pipeline::new();
-        Ok(pipe
-            .del(format!("operations:{operation_id}"))
-            .query_async(&mut con)
-            .await?)
-    }
-}
-
-#[async_trait]
-impl ClientStateManager for RedisStateManager {
-    async fn add_action(
-        &mut self,
-        action_info: ActionInfo,
-    ) -> Result<Arc<dyn ActionStateResult>, Error> {
-        self.inner_add_action(action_info).await
-    }
-
-    async fn filter_operations(
-        &self,
-        filter: OperationFilter,
-    ) -> Result<ActionStateResultStream, Error> {
-        self.inner_filter_operations(filter).await
-    }
-}
-
-#[async_trait]
-impl WorkerStateManager for RedisStateManager {
-    async fn update_operation(
-        &mut self,
-        operation_id: OperationId,
-        worker_id: WorkerId,
-        action_stage: Result<ActionStage, Error>,
-    ) -> Result<(), Error> {
-        self.inner_update_operation(operation_id, Some(worker_id), action_stage)
-            .await
-    }
-}
-
-#[async_trait]
-impl MatchingEngineStateManager for RedisStateManager {
-    async fn filter_operations(
-        &self,
-        filter: OperationFilter,
-    ) -> Result<ActionStateResultStream, Error> {
-        self.inner_filter_operations(filter).await
-    }
-
-    async fn update_operation(
-        &mut self,
-        operation_id: OperationId,
-        worker_id: Option<WorkerId>,
-        action_stage: Result<ActionStage, Error>,
-    ) -> Result<(), Error> {
-        self.inner_update_operation(operation_id, worker_id, action_stage)
-            .await
-    }
-
-    async fn remove_operation(&self, operation_id: OperationId) -> Result<(), Error> {
-        self.inner_remove_operation(operation_id).await
-    }
-}
diff --git a/nativelink-scheduler/src/scheduler_state/awaited_action.rs b/nativelink-scheduler/src/scheduler_state/awaited_action.rs
deleted file mode 100644
index bca2ef489..000000000
--- a/nativelink-scheduler/src/scheduler_state/awaited_action.rs
+++ /dev/null
@@ -1,67 +0,0 @@
-// Copyright 2024 The NativeLink Authors. All rights reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//    http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-use std::sync::Arc;
-
-use nativelink_error::Error;
-use nativelink_util::action_messages::{ActionInfo, ActionState, WorkerId};
-use nativelink_util::metrics_utils::{CollectorState, MetricsComponent};
-use tokio::sync::watch;
-
-/// An action that is being awaited on and last known state.
-pub struct AwaitedAction {
-    /// The action that is being awaited on.
-    pub(crate) action_info: Arc<ActionInfo>,
-
-    /// The current state of the action.
-    pub(crate) current_state: Arc<ActionState>,
-
-    /// The channel to notify subscribers of state changes when updated, completed or retrying.
-    pub(crate) notify_channel: watch::Sender<Arc<ActionState>>,
-
-    /// Number of attempts the job has been tried.
-    pub(crate) attempts: usize,
-
-    /// Possible last error set by the worker. If empty and attempts is set, it may be due to
-    /// something like a worker timeout.
-    pub(crate) last_error: Option<Error>,
-
-    /// Worker that is currently running this action, None if unassigned.
-    pub(crate) worker_id: Option<WorkerId>,
-}
-
-impl MetricsComponent for AwaitedAction {
-    fn gather_metrics(&self, c: &mut CollectorState) {
-        c.publish(
-            "action_digest",
-            &self.action_info.unique_qualifier.action_name(),
-            "The digest of the action.",
-        );
-        c.publish(
-            "current_state",
-            self.current_state.as_ref(),
-            "The current stage of the action.",
-        );
-        c.publish(
-            "attempts",
-            &self.attempts,
-            "The number of attempts this action has tried.",
-        );
-        c.publish(
-            "last_error",
-            &format!("{:?}", self.last_error),
-            "The last error this action caused from a retry (if any).",
-        );
-    }
-}
diff --git a/nativelink-scheduler/src/scheduler_state/client_action_state_result.rs b/nativelink-scheduler/src/scheduler_state/client_action_state_result.rs
deleted file mode 100644
index d1044b0a7..000000000
--- a/nativelink-scheduler/src/scheduler_state/client_action_state_result.rs
+++ /dev/null
@@ -1,51 +0,0 @@
-// Copyright 2024 The NativeLink Authors. All rights reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//    http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-use std::sync::Arc;
-
-use async_trait::async_trait;
-use nativelink_error::Error;
-use nativelink_util::action_messages::{ActionInfo, ActionState};
-use tokio::sync::watch::Receiver;
-
-use crate::operation_state_manager::ActionStateResult;
-
-pub(crate) struct ClientActionStateResult {
-    rx: Receiver<Arc<ActionState>>,
-}
-
-impl ClientActionStateResult {
-    pub(crate) fn new(mut rx: Receiver<Arc<ActionState>>) -> Self {
-        // Marking the initial value as changed for new or existing actions regardless if
-        // underlying state has changed. This allows for triggering notification after subscription
-        // without having to use an explicit notification.
-        rx.mark_changed();
-        Self { rx }
-    }
-}
-
-#[async_trait]
-impl ActionStateResult for ClientActionStateResult {
-    async fn as_state(&self) -> Result<Arc<ActionState>, Error> {
-        Ok(self.rx.borrow().clone())
-    }
-
-    async fn as_receiver(&self) -> Result<&'_ Receiver<Arc<ActionState>>, Error> {
-        Ok(&self.rx)
-    }
-
-    async fn as_action_info(&self) -> Result<Arc<ActionInfo>, Error> {
-        unimplemented!()
-    }
-}
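The ClientActionStateResult just removed leans on a detail of tokio's watch channel: a receiver created by subscribe() normally wakes only on the next send, so the constructor calls mark_changed() to make the very first changed().await resolve immediately with the value already in the channel. A minimal, self-contained sketch of just that behavior, using plain string values rather than NativeLink's ActionState type:

use tokio::sync::watch;

#[tokio::main]
async fn main() {
    let (tx, mut rx) = watch::channel("queued");

    // A fresh receiver would normally wait for the *next* send. Marking it
    // changed lets the initial value flow through the same wait-then-read path.
    rx.mark_changed();
    rx.changed().await.unwrap(); // Resolves immediately.
    assert_eq!(*rx.borrow_and_update(), "queued");

    // Later updates are observed through the identical loop.
    tx.send("executing").unwrap();
    rx.changed().await.unwrap();
    assert_eq!(*rx.borrow_and_update(), "executing");
}

This is why a client can subscribe to an already-running action and still receive a state notification without any explicit extra signal.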
diff --git a/nativelink-scheduler/src/scheduler_state/completed_action.rs b/nativelink-scheduler/src/scheduler_state/completed_action.rs
deleted file mode 100644
index f69f10d1a..000000000
--- a/nativelink-scheduler/src/scheduler_state/completed_action.rs
+++ /dev/null
@@ -1,72 +0,0 @@
-// Copyright 2024 The NativeLink Authors. All rights reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//    http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-use std::borrow::Borrow;
-use std::hash::{Hash, Hasher};
-use std::sync::Arc;
-use std::time::SystemTime;
-
-use nativelink_util::action_messages::{ActionInfoHashKey, ActionState, OperationId};
-use nativelink_util::metrics_utils::{CollectorState, MetricsComponent};
-
-/// A completed action that has no listeners.
-pub struct CompletedAction {
-    /// The time the action was completed.
-    pub(crate) completed_time: SystemTime,
-    /// The current state of the action when it was completed.
-    pub(crate) state: Arc<ActionState>,
-}
-
-impl Hash for CompletedAction {
-    fn hash<H: Hasher>(&self, state: &mut H) {
-        OperationId::hash(&self.state.id, state);
-    }
-}
-
-impl PartialEq for CompletedAction {
-    fn eq(&self, other: &Self) -> bool {
-        OperationId::eq(&self.state.id, &other.state.id)
-    }
-}
-
-impl Eq for CompletedAction {}
-
-impl Borrow<OperationId> for CompletedAction {
-    #[inline]
-    fn borrow(&self) -> &OperationId {
-        &self.state.id
-    }
-}
-
-impl Borrow<ActionInfoHashKey> for CompletedAction {
-    #[inline]
-    fn borrow(&self) -> &ActionInfoHashKey {
-        &self.state.id.unique_qualifier
-    }
-}
-
-impl MetricsComponent for CompletedAction {
-    fn gather_metrics(&self, c: &mut CollectorState) {
-        c.publish(
-            "completed_timestamp",
-            &self.completed_time,
-            "The timestamp this action was completed",
-        );
-        c.publish(
-            "current_state",
-            self.state.as_ref(),
-            "The current stage of the action.",
-        );
-    }
-}
diff --git a/nativelink-scheduler/src/scheduler_state/matching_engine_action_state_result.rs b/nativelink-scheduler/src/scheduler_state/matching_engine_action_state_result.rs
deleted file mode 100644
index 0c6a4c74c..000000000
--- a/nativelink-scheduler/src/scheduler_state/matching_engine_action_state_result.rs
+++ /dev/null
@@ -1,53 +0,0 @@
-// Copyright 2024 The NativeLink Authors. All rights reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//    http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-use std::sync::Arc;
-
-use async_trait::async_trait;
-use nativelink_error::Error;
-use nativelink_util::action_messages::{ActionInfo, ActionState};
-use tokio::sync::watch;
-
-use crate::operation_state_manager::ActionStateResult;
-
-pub struct MatchingEngineActionStateResult {
-    action_info: Arc<ActionInfo>,
-    action_state: watch::Receiver<Arc<ActionState>>,
-}
-impl MatchingEngineActionStateResult {
-    pub(crate) fn new(
-        action_info: Arc<ActionInfo>,
-        action_state: watch::Receiver<Arc<ActionState>>,
-    ) -> Self {
-        Self {
-            action_info,
-            action_state,
-        }
-    }
-}
-
-#[async_trait]
-impl ActionStateResult for MatchingEngineActionStateResult {
-    async fn as_state(&self) -> Result<Arc<ActionState>, Error> {
-        Ok(self.action_state.borrow().clone())
-    }
-
-    async fn as_receiver(&self) -> Result<&'_ watch::Receiver<Arc<ActionState>>, Error> {
-        Ok(&self.action_state)
-    }
-
-    async fn as_action_info(&self) -> Result<Arc<ActionInfo>, Error> {
-        Ok(self.action_info.clone())
-    }
-}
diff --git a/nativelink-scheduler/src/scheduler_state/metrics.rs b/nativelink-scheduler/src/scheduler_state/metrics.rs
deleted file mode 100644
index e9cfe60c5..000000000
--- a/nativelink-scheduler/src/scheduler_state/metrics.rs
+++ /dev/null
@@ -1,143 +0,0 @@
-// Copyright 2024 The NativeLink Authors. All rights reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//    http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-use nativelink_util::metrics_utils::{CollectorState, CounterWithTime};
-
-#[derive(Default)]
-pub(crate) struct Metrics {
-    pub(crate) add_action_joined_running_action: CounterWithTime,
-    pub(crate) add_action_joined_queued_action: CounterWithTime,
-    pub(crate) add_action_new_action_created: CounterWithTime,
-    pub(crate) update_action_missing_action_result: CounterWithTime,
-    pub(crate) update_action_from_wrong_worker: CounterWithTime,
-    pub(crate) update_action_no_more_listeners: CounterWithTime,
-    pub(crate) update_action_with_internal_error: CounterWithTime,
-    pub(crate) update_action_with_internal_error_no_action: CounterWithTime,
-    pub(crate) update_action_with_internal_error_backpressure: CounterWithTime,
-    pub(crate) update_action_with_internal_error_from_wrong_worker: CounterWithTime,
-    pub(crate) workers_evicted: CounterWithTime,
-    pub(crate) workers_evicted_with_running_action: CounterWithTime,
-    pub(crate) retry_action: CounterWithTime,
-    pub(crate) retry_action_max_attempts_reached: CounterWithTime,
-    pub(crate) retry_action_no_more_listeners: CounterWithTime,
-    pub(crate) retry_action_but_action_missing: CounterWithTime,
-}
-
-impl Metrics {
-    pub fn gather_metrics(&self, c: &mut CollectorState) {
-        {
-            c.publish_with_labels(
-                "add_action",
-                &self.add_action_joined_running_action,
-                "Stats about add_action().",
-                vec![("result".into(), "joined_running_action".into())],
-            );
-            c.publish_with_labels(
-                "add_action",
-                &self.add_action_joined_queued_action,
-                "Stats about add_action().",
-                vec![("result".into(), "joined_queued_action".into())],
-            );
-            c.publish_with_labels(
-                "add_action",
-                &self.add_action_new_action_created,
-                "Stats about add_action().",
-                vec![("result".into(), "new_action_created".into())],
-            );
-        }
-        {
-            c.publish_with_labels(
-                "update_action_errors",
-                &self.update_action_missing_action_result,
-                "Stats about errors when worker sends update_action() to scheduler. These errors are not complete, just the most common.",
-                vec![("result".into(), "missing_action_result".into())],
-            );
-            c.publish_with_labels(
-                "update_action_errors",
-                &self.update_action_from_wrong_worker,
-                "Stats about errors when worker sends update_action() to scheduler. These errors are not complete, just the most common.",
-                vec![("result".into(), "from_wrong_worker".into())],
-            );
-            c.publish_with_labels(
-                "update_action_errors",
-                &self.update_action_no_more_listeners,
-                "Stats about errors when worker sends update_action() to scheduler. These errors are not complete, just the most common.",
-                vec![("result".into(), "no_more_listeners".into())],
-            );
-        }
-        {
-            c.publish(
-                "update_action_with_internal_error",
-                &self.update_action_with_internal_error,
-                "The number of times update_action_with_internal_error was triggered.",
-            );
-            c.publish_with_labels(
-                "update_action_with_internal_error_errors",
-                &self.update_action_with_internal_error_no_action,
-                "Stats about what errors caused update_action_with_internal_error() in scheduler.",
-                vec![("result".into(), "no_action".into())],
-            );
-            c.publish_with_labels(
-                "update_action_with_internal_error_errors",
-                &self.update_action_with_internal_error_backpressure,
-                "Stats about what errors caused update_action_with_internal_error() in scheduler.",
-                vec![("result".into(), "backpressure".into())],
-            );
-            c.publish_with_labels(
-                "update_action_with_internal_error_errors",
-                &self.update_action_with_internal_error_from_wrong_worker,
-                "Stats about what errors caused update_action_with_internal_error() in scheduler.",
-                vec![("result".into(), "from_wrong_worker".into())],
-            );
-        }
-        {
-            c.publish(
-                "workers_evicted_total",
-                &self.workers_evicted,
-                "The number of workers evicted from scheduler.",
-            );
-            c.publish(
-                "workers_evicted_with_running_action",
-                &self.workers_evicted_with_running_action,
-                "The number of jobs cancelled because worker was evicted from scheduler.",
-            );
-        }
-        {
-            c.publish_with_labels(
-                "retry_action",
-                &self.retry_action,
-                "Stats about retry_action().",
-                vec![("result".into(), "success".into())],
-            );
-            c.publish_with_labels(
-                "retry_action",
-                &self.retry_action_max_attempts_reached,
-                "Stats about retry_action().",
-                vec![("result".into(), "max_attempts_reached".into())],
-            );
-            c.publish_with_labels(
-                "retry_action",
-                &self.retry_action_no_more_listeners,
-                "Stats about retry_action().",
-                vec![("result".into(), "no_more_listeners".into())],
-            );
-            c.publish_with_labels(
-                "retry_action",
-                &self.retry_action_but_action_missing,
-                "Stats about retry_action().",
-                vec![("result".into(), "action_missing".into())],
-            );
-        }
-    }
-}
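The metrics module above reports every outcome of one operation under a shared metric name and distinguishes the cases with a "result" label, instead of minting a separate metric per outcome. A rough sketch of that shape, with plain integers standing in for CounterWithTime and a Prometheus-like text rendering that is purely illustrative (not what CollectorState actually emits):

use std::collections::BTreeMap;

/// One metric family ("retry_action"), keyed by its "result" label value.
#[derive(Default)]
struct RetryActionMetrics {
    counters: BTreeMap<&'static str, u64>,
}

impl RetryActionMetrics {
    fn inc(&mut self, result: &'static str) {
        *self.counters.entry(result).or_insert(0) += 1;
    }

    /// Renders lines such as: retry_action{result="success"} 3
    fn render(&self) -> String {
        self.counters
            .iter()
            .map(|(result, count)| format!("retry_action{{result=\"{result}\"}} {count}\n"))
            .collect()
    }
}

fn main() {
    let mut metrics = RetryActionMetrics::default();
    metrics.inc("success");
    metrics.inc("max_attempts_reached");
    print!("{}", metrics.render());
}

Keeping one family per call site makes the outcomes directly comparable in a dashboard query instead of scattered across unrelated metric names.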
diff --git a/nativelink-scheduler/src/scheduler_state/mod.rs b/nativelink-scheduler/src/scheduler_state/mod.rs
deleted file mode 100644
index 359f4f063..000000000
--- a/nativelink-scheduler/src/scheduler_state/mod.rs
+++ /dev/null
@@ -1,21 +0,0 @@
-// Copyright 2024 The NativeLink Authors. All rights reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//    http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-pub(crate) mod awaited_action;
-pub(crate) mod client_action_state_result;
-pub(crate) mod completed_action;
-pub(crate) mod matching_engine_action_state_result;
-pub(crate) mod metrics;
-pub(crate) mod state_manager;
-pub(crate) mod workers;
diff --git a/nativelink-scheduler/src/scheduler_state/state_manager.rs b/nativelink-scheduler/src/scheduler_state/state_manager.rs
deleted file mode 100644
index 8dd0def9c..000000000
--- a/nativelink-scheduler/src/scheduler_state/state_manager.rs
+++ /dev/null
@@ -1,742 +0,0 @@
-// Copyright 2024 The NativeLink Authors. All rights reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//    http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-use std::cmp;
-use std::collections::BTreeMap;
-use std::sync::Arc;
-use std::time::SystemTime;
-
-use async_trait::async_trait;
-use futures::stream;
-use hashbrown::{HashMap, HashSet};
-use nativelink_error::{make_err, make_input_err, Code, Error, ResultExt};
-use nativelink_util::action_messages::{
-    ActionInfo, ActionInfoHashKey, ActionResult, ActionStage, ActionState, ExecutionMetadata,
-    OperationId, WorkerId,
-};
-use tokio::sync::watch::error::SendError;
-use tokio::sync::{watch, Notify};
-use tracing::{event, Level};
-
-use crate::operation_state_manager::{
-    ActionStateResult, ActionStateResultStream, ClientStateManager, MatchingEngineStateManager,
-    OperationFilter, WorkerStateManager,
-};
-use crate::scheduler_state::awaited_action::AwaitedAction;
-use crate::scheduler_state::client_action_state_result::ClientActionStateResult;
-use crate::scheduler_state::completed_action::CompletedAction;
-use crate::scheduler_state::matching_engine_action_state_result::MatchingEngineActionStateResult;
-use crate::scheduler_state::metrics::Metrics;
-use crate::scheduler_state::workers::Workers;
-use crate::worker::WorkerUpdate;
-
-#[repr(transparent)]
-pub(crate) struct StateManager {
-    pub inner: StateManagerImpl,
-}
-
-impl StateManager {
-    #[allow(clippy::too_many_arguments)]
-    pub(crate) fn new(
-        queued_actions_set: HashSet<Arc<ActionInfo>>,
-        queued_actions: BTreeMap<Arc<ActionInfo>, AwaitedAction>,
-        workers: Workers,
-        active_actions: HashMap<Arc<ActionInfo>, AwaitedAction>,
-        recently_completed_actions: HashSet<CompletedAction>,
-        metrics: Arc<Metrics>,
-        max_job_retries: usize,
-        tasks_or_workers_change_notify: Arc<Notify>,
-    ) -> Self {
-        Self {
-            inner: StateManagerImpl {
-                queued_actions_set,
-                queued_actions,
-                workers,
-                active_actions,
-                recently_completed_actions,
-                metrics,
-                max_job_retries,
-                tasks_or_workers_change_notify,
-            },
-        }
-    }
-
-    fn immediate_evict_worker(&mut self, worker_id: &WorkerId, err: Error) {
-        if let Some(mut worker) = self.inner.workers.remove_worker(worker_id) {
-            self.inner.metrics.workers_evicted.inc();
-            // We don't care if we fail to send message to worker, this is only a best attempt.
-            let _ = worker.notify_update(WorkerUpdate::Disconnect);
-            // We create a temporary Vec to avoid doubt about a possible code
-            // path touching the worker.running_action_infos elsewhere.
-            for action_info in worker.running_action_infos.drain() {
-                self.inner.metrics.workers_evicted_with_running_action.inc();
-                self.retry_action(&action_info, worker_id, err.clone());
-            }
-            // Note: Calling this multiple times is very cheap, it'll only trigger `do_try_match` once.
-            self.inner.tasks_or_workers_change_notify.notify_one();
-        }
-    }
-
-    fn retry_action(&mut self, action_info: &Arc<ActionInfo>, worker_id: &WorkerId, err: Error) {
-        match self.inner.active_actions.remove(action_info) {
-            Some(running_action) => {
-                let mut awaited_action = running_action;
-                let send_result = if awaited_action.attempts >= self.inner.max_job_retries {
-                    self.inner.metrics.retry_action_max_attempts_reached.inc();
-                    Arc::make_mut(&mut awaited_action.current_state).stage = ActionStage::Completed(ActionResult {
-                        execution_metadata: ExecutionMetadata {
-                            worker: format!("{worker_id}"),
-                            ..ExecutionMetadata::default()
-                        },
-                        error: Some(err.merge(make_err!(
-                            Code::Internal,
-                            "Job cancelled because it attempted to execute too many times and failed"
-                        ))),
-                        ..ActionResult::default()
-                    });
-                    awaited_action
-                        .notify_channel
-                        .send(awaited_action.current_state.clone())
-                    // Do not put the action back in the queue here, as this action attempted to run too many
-                    // times.
-                } else {
-                    self.inner.metrics.retry_action.inc();
-                    Arc::make_mut(&mut awaited_action.current_state).stage = ActionStage::Queued;
-                    let send_result = awaited_action
-                        .notify_channel
-                        .send(awaited_action.current_state.clone());
-                    self.inner.queued_actions_set.insert(action_info.clone());
-                    self.inner
-                        .queued_actions
-                        .insert(action_info.clone(), awaited_action);
-                    send_result
-                };
-
-                if send_result.is_err() {
-                    self.inner.metrics.retry_action_no_more_listeners.inc();
-                    // Don't remove this task, instead we keep them around for a bit just in case
-                    // the client disconnected and will reconnect and ask for same job to be executed
-                    // again.
-                    event!(
-                        Level::WARN,
-                        ?action_info,
-                        ?worker_id,
-                        "Action has no more listeners during evict_worker()"
-                    );
-                }
-            }
-            None => {
-                self.inner.metrics.retry_action_but_action_missing.inc();
-                event!(
-                    Level::ERROR,
-                    ?action_info,
-                    ?worker_id,
-                    "Worker stated it was running an action, but it was not in the active_actions"
-                );
-            }
-        }
-    }
-}
-
-/// StateManager is responsible for maintaining the state of the scheduler. Scheduler state
-/// includes the actions that are queued, active, and recently completed. It also includes the
-/// workers that are available to execute actions based on allocation strategy.
-pub(crate) struct StateManagerImpl {
-    // TODO(adams): Move `queued_actions_set` and `queued_actions` into a single struct that
-    // provides a unified interface for interacting with the two containers.
-
-    // Important: `queued_actions_set` and `queued_actions` are two containers that provide
-    // different search and sort capabilities. We are using the two different containers to
-    // optimize different use cases. `HashSet` is used to look up actions in O(1) time. The
-    // `BTreeMap` is used to sort actions in O(log n) time based on priority and timestamp.
-    // These two fields must be kept in-sync, so if you modify one, you likely need to modify the
-    // other.
-    /// A `HashSet` of all actions that are queued. A hashset is used to find actions that are queued
-    /// in O(1) time. This set allows us to find and join on new actions onto already existing
-    /// (or queued) actions where insert timestamp of queued actions is not known. Using an
-    /// additional `HashSet` will prevent us from having to iterate the `BTreeMap` to find actions.
-    ///
-    /// Important: `queued_actions_set` and `queued_actions` must be kept in sync.
-    pub(crate) queued_actions_set: HashSet<Arc<ActionInfo>>,
-
-    /// A BTreeMap of sorted actions that are primarily based on priority and insert timestamp.
-    /// `ActionInfo` implements `Ord` that defines the `cmp` function for order. Using a BTreeMap
-    /// gives us to sorted actions that are queued in O(log n) time.
-    ///
-    /// Important: `queued_actions_set` and `queued_actions` must be kept in sync.
-    pub(crate) queued_actions: BTreeMap<Arc<ActionInfo>, AwaitedAction>,
-
-    /// A `Workers` pool that contains all workers that are available to execute actions in a priority
-    /// order based on the allocation strategy.
-    pub(crate) workers: Workers,
-
-    /// A map of all actions that are active. A hashmap is used to find actions that are active in
-    /// O(1) time. The key is the `ActionInfo` struct. The value is the `AwaitedAction` struct.
-    pub(crate) active_actions: HashMap<Arc<ActionInfo>, AwaitedAction>,
-
-    /// These actions completed recently but had no listener, they might have
-    /// completed while the caller was thinking about calling wait_execution, so
-    /// keep their completion state around for a while to send back.
-    /// TODO(#192) Revisit if this is the best way to handle recently completed actions.
-    pub(crate) recently_completed_actions: HashSet<CompletedAction>,
-
-    pub(crate) metrics: Arc<Metrics>,
-
-    /// Default times a job can retry before failing.
-    pub(crate) max_job_retries: usize,
-
-    /// Notify task<->worker matching engine that work needs to be done.
-    pub(crate) tasks_or_workers_change_notify: Arc<Notify>,
-}
-
-impl StateManager {
-    /// Modifies the `stage` of `current_state` within `AwaitedAction`. Sends notification channel
-    /// the new state.
-    ///
-    ///
-    /// # Discussion
-    ///
-    /// The use of `Arc::make_mut` is potentially dangerous because it clones the data and
-    /// invalidates all weak references to it. However, in this context, it is considered
-    /// safe because the data is going to be re-sent back out. The primary reason for using
-    /// `Arc` is to reduce the number of copies, not to enforce read-only access. This approach
-    /// ensures that all downstream components receive the same pointer. If an update occurs
-    /// while another thread is operating on the data, it is acceptable, since the other thread
-    /// will receive another update with the new version.
-    ///
-    pub(crate) fn mutate_stage(
-        awaited_action: &mut AwaitedAction,
-        action_stage: ActionStage,
-    ) -> Result<(), SendError<Arc<ActionState>>> {
-        Arc::make_mut(&mut awaited_action.current_state).stage = action_stage;
-        awaited_action
-            .notify_channel
-            .send(awaited_action.current_state.clone())
-    }
-
-    /// Modifies the `priority` of `action_info` within `ActionInfo`.
-    ///
-    fn mutate_priority(action_info: &mut Arc<ActionInfo>, priority: i32) {
-        Arc::make_mut(action_info).priority = priority;
-    }
-
-    /// Updates the `last_error` field of the provided `AwaitedAction` and sends the current state
-    /// to the notify channel.
-    ///
-    fn mutate_last_error(
-        awaited_action: &mut AwaitedAction,
-        last_error: Error,
-    ) -> Result<(), SendError<Arc<ActionState>>> {
-        awaited_action.last_error = Some(last_error);
-        awaited_action
-            .notify_channel
-            .send(awaited_action.current_state.clone())
-    }
-
-    /// Notifies the specified worker to run the given action and handles errors by evicting
-    /// the worker if the notification fails.
-    ///
-    /// # Note
-    ///
-    /// Intended utility function for matching engine.
-    ///
-    /// # Errors
-    ///
-    /// This function will return an error if the notification to the worker fails, and in that case,
-    /// the worker will be immediately evicted from the system.
-    ///
-    async fn worker_notify_run_action(
-        &mut self,
-        worker_id: WorkerId,
-        action_info: Arc<ActionInfo>,
-    ) -> Result<(), Error> {
-        if let Some(worker) = self.inner.workers.workers.get_mut(&worker_id) {
-            let notify_worker_result =
-                worker.notify_update(WorkerUpdate::RunAction(action_info.clone()));
-
-            if notify_worker_result.is_err() {
-                event!(
-                    Level::WARN,
-                    ?worker_id,
-                    ?action_info,
-                    ?notify_worker_result,
-                    "Worker command failed, removing worker",
-                );
-
-                let err = make_err!(
-                    Code::Internal,
-                    "Worker command failed, removing worker {worker_id} -- {notify_worker_result:?}",
-                );
-
-                self.immediate_evict_worker(&worker_id, err.clone());
-                return Err(err);
-            }
-        }
-        Ok(())
-    }
-
-    /// Sets the action stage for the given `AwaitedAction` based on the result of the provided
-    /// `action_stage`. If the `action_stage` is an error, it updates the `last_error` field
-    /// and logs a warning.
-    ///
-    /// # Note
-    ///
-    /// Intended utility function for matching engine.
-    ///
-    /// # Errors
-    ///
-    /// This function will return an error if updating the state of the `awaited_action` fails.
-    ///
-    async fn worker_set_action_stage(
-        awaited_action: &mut AwaitedAction,
-        action_stage: Result<ActionStage, Error>,
-        worker_id: WorkerId,
-    ) -> Result<(), SendError<Arc<ActionState>>> {
-        match action_stage {
-            Ok(action_stage) => StateManager::mutate_stage(awaited_action, action_stage),
-            Err(e) => {
-                event!(
-                    Level::WARN,
-                    ?worker_id,
-                    "Action stage setting error during do_try_match()"
-                );
-                StateManager::mutate_last_error(awaited_action, e)
-            }
-        }
-    }
-
-    /// Marks the specified action as active, assigns it to the given worker, and updates the
-    /// action stage. This function removes the action from the queue, updates the action's state
-    /// or error, and inserts it into the set of active actions.
-    ///
-    /// # Note
-    ///
-    /// Intended utility function for matching engine.
-    ///
-    /// # Errors
-    ///
-    /// This function will return an error if it fails to update the action's state or if any other
-    /// error occurs during the process.
-    ///
-    async fn worker_set_as_active(
-        &mut self,
-        action_info: Arc<ActionInfo>,
-        worker_id: WorkerId,
-        action_stage: Result<ActionStage, Error>,
-    ) -> Result<(), Error> {
-        if let Some((action_info, mut awaited_action)) =
-            self.inner.queued_actions.remove_entry(action_info.as_ref())
-        {
-            assert!(
-                self.inner.queued_actions_set.remove(&action_info),
-                "queued_actions_set should always have same keys as queued_actions"
-            );
-
-            awaited_action.worker_id = Some(worker_id);
-
-            let send_result =
-                StateManager::worker_set_action_stage(&mut awaited_action, action_stage, worker_id)
-                    .await;
-
-            if send_result.is_err() {
-                event!(
-                    Level::WARN,
-                    ?action_info,
-                    ?worker_id,
-                    "Action has no more listeners during do_try_match()"
-                );
-            }
-
-            awaited_action.attempts += 1;
-            self.inner
-                .active_actions
-                .insert(action_info, awaited_action);
-            Ok(())
-        } else {
-            Err(make_err!(
-                Code::Internal,
-                "Action not found in queued_actions_set or queued_actions"
-            ))
-        }
-    }
-
-    fn update_action_with_internal_error(
-        &mut self,
-        worker_id: &WorkerId,
-        action_info_hash_key: ActionInfoHashKey,
-        err: Error,
-    ) {
-        self.inner.metrics.update_action_with_internal_error.inc();
-        let Some((action_info, mut running_action)) = self
-            .inner
-            .active_actions
-            .remove_entry(&action_info_hash_key)
-        else {
-            self.inner
-                .metrics
-                .update_action_with_internal_error_no_action
-                .inc();
-            event!(
-                Level::ERROR,
-                ?action_info_hash_key,
-                ?worker_id,
-                "Could not find action info in active actions"
-            );
-            return;
-        };
-
-        let due_to_backpressure = err.code == Code::ResourceExhausted;
-        // Don't count a backpressure failure as an attempt for an action.
-        if due_to_backpressure {
-            self.inner
-                .metrics
-                .update_action_with_internal_error_backpressure
-                .inc();
-            running_action.attempts -= 1;
-        }
-        let Some(running_action_worker_id) = running_action.worker_id else {
-            event!(
-                Level::ERROR,
-                ?action_info_hash_key,
-                ?worker_id,
-                "Got a result from a worker that should not be running the action, Removing worker. Expected action to be unassigned got worker",
-            );
-            return;
-        };
-        if running_action_worker_id == *worker_id {
-            // Don't set the error on an action that's running somewhere else.
-            event!(
-                Level::WARN,
-                ?action_info_hash_key,
-                ?worker_id,
-                ?running_action_worker_id,
-                ?err,
-                "Internal worker error",
-            );
-            running_action.last_error = Some(err.clone());
-        } else {
-            self.inner
-                .metrics
-                .update_action_with_internal_error_from_wrong_worker
-                .inc();
-        }
-
-        // Now put it back. retry_action() needs it to be there to send errors properly.
-        self.inner
-            .active_actions
-            .insert(action_info.clone(), running_action);
-
-        // Clear this action from the current worker.
-        if let Some(worker) = self.inner.workers.workers.get_mut(worker_id) {
-            let was_paused = !worker.can_accept_work();
-            // This unpauses, but since we're completing with an error, don't
-            // unpause unless all actions have completed.
-            worker.complete_action(&action_info);
-            // Only pause if there's an action still waiting that will unpause.
-            if (was_paused || due_to_backpressure) && worker.has_actions() {
-                worker.is_paused = true;
-            }
-        }
-
-        // Re-queue the action or fail on max attempts.
-        self.retry_action(&action_info, worker_id, err);
-        self.inner.tasks_or_workers_change_notify.notify_one();
-    }
-}
-
-#[async_trait]
-impl ClientStateManager for StateManager {
-    async fn add_action(
-        &mut self,
-        action_info: ActionInfo,
-    ) -> Result<Arc<dyn ActionStateResult>, Error> {
-        // Check to see if the action is running, if it is and cacheable, merge the actions.
-        if let Some(running_action) = self.inner.active_actions.get_mut(&action_info) {
-            self.inner.metrics.add_action_joined_running_action.inc();
-            self.inner.tasks_or_workers_change_notify.notify_one();
-            return Ok(Arc::new(ClientActionStateResult::new(
-                running_action.notify_channel.subscribe(),
-            )));
-        }
-
-        // Check to see if the action is queued, if it is and cacheable, merge the actions.
-        if let Some(mut arc_action_info) = self.inner.queued_actions_set.take(&action_info) {
-            let (original_action_info, queued_action) = self
-                .inner
-                .queued_actions
-                .remove_entry(&arc_action_info)
-                .err_tip(|| "Internal error queued_actions and queued_actions_set should match")?;
-            self.inner.metrics.add_action_joined_queued_action.inc();
-
-            let new_priority = cmp::max(original_action_info.priority, action_info.priority);
-            drop(original_action_info); // This increases the chance Arc::make_mut won't copy.
-
-            // In the event our task is higher priority than the one already scheduled, increase
-            // the priority of the scheduled one.
-            StateManager::mutate_priority(&mut arc_action_info, new_priority);
-
-            let result = Arc::new(ClientActionStateResult::new(
-                queued_action.notify_channel.subscribe(),
-            ));
-
-            // Even if we fail to send our action to the client, we need to add this action back to the
-            // queue because it was remove earlier.
-            self.inner
-                .queued_actions
-                .insert(arc_action_info.clone(), queued_action);
-            self.inner.queued_actions_set.insert(arc_action_info);
-            self.inner.tasks_or_workers_change_notify.notify_one();
-            return Ok(result);
-        }
-
-        self.inner.metrics.add_action_new_action_created.inc();
-        // Action needs to be added to queue or is not cacheable.
-        let action_info = Arc::new(action_info);
-
-        let operation_id = OperationId::new(action_info.unique_qualifier.clone());
-
-        let current_state = Arc::new(ActionState {
-            stage: ActionStage::Queued,
-            id: operation_id,
-        });
-
-        let (tx, rx) = watch::channel(current_state.clone());
-
-        self.inner.queued_actions_set.insert(action_info.clone());
-        self.inner.queued_actions.insert(
-            action_info.clone(),
-            AwaitedAction {
-                action_info,
-                current_state,
-                notify_channel: tx,
-                attempts: 0,
-                last_error: None,
-                worker_id: None,
-            },
-        );
-        self.inner.tasks_or_workers_change_notify.notify_one();
-        return Ok(Arc::new(ClientActionStateResult::new(rx)));
-    }
-
-    async fn filter_operations(
-        &self,
-        filter: OperationFilter,
-    ) -> Result<ActionStateResultStream, Error> {
-        // TODO(adams): Build out a proper filter for other fields for state, at the moment
-        // this only supports the unique qualifier.
-        let unique_qualifier = &filter
-            .unique_qualifier
-            .err_tip(|| "No unique qualifier provided")?;
-        let maybe_awaited_action = self
-            .inner
-            .queued_actions_set
-            .get(unique_qualifier)
-            .and_then(|action_info| self.inner.queued_actions.get(action_info))
-            .or_else(|| self.inner.active_actions.get(unique_qualifier));
-
-        let Some(awaited_action) = maybe_awaited_action else {
-            return Ok(Box::pin(stream::empty()));
-        };
-
-        let rx = awaited_action.notify_channel.subscribe();
-        let action_result: [Arc<dyn ActionStateResult>; 1] =
-            [Arc::new(ClientActionStateResult::new(rx))];
-        Ok(Box::pin(stream::iter(action_result)))
-    }
-}
-
-#[async_trait]
-impl WorkerStateManager for StateManager {
-    async fn update_operation(
-        &mut self,
-        operation_id: OperationId,
-        worker_id: WorkerId,
-        action_stage: Result<ActionStage, Error>,
-    ) -> Result<(), Error> {
-        match action_stage {
-            Ok(action_stage) => {
-                let action_info_hash_key = operation_id.unique_qualifier;
-                if !action_stage.has_action_result() {
-                    self.inner.metrics.update_action_missing_action_result.inc();
-                    event!(
-                        Level::ERROR,
-                        ?action_info_hash_key,
-                        ?worker_id,
-                        ?action_stage,
-                        "Worker sent error while updating action. Removing worker"
-                    );
-                    let err = make_err!(
-                        Code::Internal,
-                        "Worker '{worker_id}' set the action_stage of running action {action_info_hash_key:?} to {action_stage:?}. Removing worker.",
-                    );
-                    self.immediate_evict_worker(&worker_id, err.clone());
-                    return Err(err);
-                }
-
-                let (action_info, mut running_action) = self
-                    .inner
-                    .active_actions
-                    .remove_entry(&action_info_hash_key)
-                    .err_tip(|| {
-                        format!("Could not find action info in active actions : {action_info_hash_key:?}")
-                    })?;
-
-                if running_action.worker_id != Some(worker_id) {
-                    self.inner.metrics.update_action_from_wrong_worker.inc();
-                    let err = match running_action.worker_id {
-                        Some(running_action_worker_id) => make_err!(
-                            Code::Internal,
-                            "Got a result from a worker that should not be running the action, Removing worker. Expected worker {running_action_worker_id} got worker {worker_id}",
-                        ),
-                        None => make_err!(
-                            Code::Internal,
-                            "Got a result from a worker that should not be running the action, Removing worker. Expected action to be unassigned got worker {worker_id}",
-                        ),
-                    };
-                    event!(
-                        Level::ERROR,
-                        ?action_info,
-                        ?worker_id,
-                        ?running_action.worker_id,
-                        ?err,
-                        "Got a result from a worker that should not be running the action, Removing worker"
-                    );
-                    // First put it back in our active_actions or we will drop the task.
-                    self.inner
-                        .active_actions
-                        .insert(action_info, running_action);
-                    self.immediate_evict_worker(&worker_id, err.clone());
-                    return Err(err);
-                }
-
-                let send_result = StateManager::mutate_stage(&mut running_action, action_stage);
-
-                if !running_action.current_state.stage.is_finished() {
-                    if send_result.is_err() {
-                        self.inner.metrics.update_action_no_more_listeners.inc();
-                        event!(
-                            Level::WARN,
-                            ?action_info,
-                            ?worker_id,
-                            "Action has no more listeners during update_action()"
-                        );
-                    }
-                    // If the operation is not finished it means the worker is still working on it, so put it
-                    // back or else we will lose track of the task.
-                    self.inner
-                        .active_actions
-                        .insert(action_info, running_action);
-
-                    self.inner.tasks_or_workers_change_notify.notify_one();
-                    return Ok(());
-                }
-
-                // Keep in case this is asked for soon.
-                self.inner
-                    .recently_completed_actions
-                    .insert(CompletedAction {
-                        completed_time: SystemTime::now(),
-                        state: running_action.current_state,
-                    });
-
-                let worker = self
-                    .inner
-                    .workers
-                    .workers
-                    .get_mut(&worker_id)
-                    .ok_or_else(|| {
-                        make_input_err!("WorkerId '{}' does not exist in workers map", worker_id)
-                    })?;
-                worker.complete_action(&action_info);
-                self.inner.tasks_or_workers_change_notify.notify_one();
-                Ok(())
-            }
-            Err(e) => {
-                self.update_action_with_internal_error(
-                    &worker_id,
-                    operation_id.unique_qualifier,
-                    e.clone(),
-                );
-                return Err(e);
-            }
-        }
-    }
-}
-
-#[async_trait]
-impl MatchingEngineStateManager for StateManager {
-    async fn filter_operations(
-        &self,
-        _filter: OperationFilter, // TODO(adam): reference filter
-    ) -> Result<ActionStateResultStream, Error> {
-        // TODO(adams): use OperationFilter vs directly encoding it.
-        let action_infos =
-            self.inner
-                .queued_actions
-                .iter()
-                .rev()
-                .map(|(action_info, awaited_action)| {
-                    let cloned_action_info = action_info.clone();
-                    Arc::new(MatchingEngineActionStateResult::new(
-                        cloned_action_info,
-                        awaited_action.notify_channel.subscribe(),
-                    )) as Arc<dyn ActionStateResult>
-                });
-
-        let action_infos: Vec<Arc<dyn ActionStateResult>> = action_infos.collect();
-        Ok(Box::pin(stream::iter(action_infos)))
-    }
-
-    async fn update_operation(
-        &mut self,
-        operation_id: OperationId,
-        worker_id: Option<WorkerId>,
-        action_stage: Result<ActionStage, Error>,
-    ) -> Result<(), Error> {
-        if let Some(action_info) = self
-            .inner
-            .queued_actions_set
-            .get(&operation_id.unique_qualifier)
-        {
-            if let Some(worker_id) = worker_id {
-                let action_info = action_info.clone();
-                self.worker_notify_run_action(worker_id, action_info.clone())
-                    .await?;
-                self.worker_set_as_active(action_info, worker_id, action_stage)
-                    .await?;
-            } else {
-                event!(
-                    Level::WARN,
-                    ?operation_id,
-                    ?worker_id,
-                    "No worker found in do_try_match()"
-                );
-            }
-        } else {
-            event!(
-                Level::WARN,
-                ?operation_id,
-                ?worker_id,
-                "No action info found in do_try_match()"
-            );
-        }
-
-        Ok(())
-    }
-
-    async fn remove_operation(&self, _operation_id: OperationId) -> Result<(), Error> {
-        todo!()
-    }
-}
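A recurring structure in the state manager deleted above is the queued_actions_set / queued_actions pairing: a hash index for O(1) membership and join-on-existing-action checks, plus an ordered map for priority-order scans, both keyed by the same Arc so each entry is stored once. A minimal sketch of that invariant with simplified stand-in types (the real key is Arc<ActionInfo>, whose own Ord encodes priority and insert timestamp):

use std::collections::{BTreeMap, HashSet};
use std::sync::Arc;

// Stand-in key: (priority, insert timestamp). Its derived Ord drives the
// BTreeMap ordering, just as ActionInfo's Ord does in the scheduler.
type Key = Arc<(i32, u64)>;

#[derive(Default)]
struct Queue {
    set: HashSet<Key>,          // O(1): "is this action already queued?"
    map: BTreeMap<Key, String>, // O(log n): iterate in priority order.
}

impl Queue {
    // Every mutation touches both containers so they never drift apart,
    // which is exactly the "must be kept in sync" invariant noted above.
    fn insert(&mut self, key: Key, payload: String) {
        self.set.insert(key.clone());
        self.map.insert(key, payload);
    }

    fn remove(&mut self, key: &Key) -> Option<String> {
        self.set.remove(key);
        self.map.remove(key)
    }

    fn contains(&self, key: &Key) -> bool {
        self.set.contains(key)
    }
}

Because both containers hold clones of one Arc, the duplication cost is a pointer, not a second copy of the action metadata.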
diff --git a/nativelink-scheduler/src/scheduler_state/workers.rs b/nativelink-scheduler/src/scheduler_state/workers.rs
deleted file mode 100644
index 25e78e2bb..000000000
--- a/nativelink-scheduler/src/scheduler_state/workers.rs
+++ /dev/null
@@ -1,114 +0,0 @@
-// Copyright 2024 The NativeLink Authors. All rights reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//    http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-use lru::LruCache;
-use nativelink_config::schedulers::WorkerAllocationStrategy;
-use nativelink_error::{error_if, make_input_err, Error, ResultExt};
-use nativelink_util::action_messages::WorkerId;
-use nativelink_util::platform_properties::PlatformProperties;
-use tracing::{event, Level};
-
-use crate::worker::{Worker, WorkerTimestamp};
-
-/// A collection of workers that are available to run tasks.
-pub struct Workers {
-    /// A `LruCache` of workers availabled based on `allocation_strategy`.
-    pub(crate) workers: LruCache<WorkerId, Worker>,
-    /// The allocation strategy for workers.
-    pub(crate) allocation_strategy: WorkerAllocationStrategy,
-}
-
-impl Workers {
-    pub(crate) fn new(allocation_strategy: WorkerAllocationStrategy) -> Self {
-        Self {
-            workers: LruCache::unbounded(),
-            allocation_strategy,
-        }
-    }
-
-    /// Refreshes the lifetime of the worker with the given timestamp.
-    pub(crate) fn refresh_lifetime(
-        &mut self,
-        worker_id: &WorkerId,
-        timestamp: WorkerTimestamp,
-    ) -> Result<(), Error> {
-        let worker = self.workers.get_mut(worker_id).ok_or_else(|| {
-            make_input_err!(
-                "Worker not found in worker map in refresh_lifetime() {}",
-                worker_id
-            )
-        })?;
-        error_if!(
-            worker.last_update_timestamp > timestamp,
-            "Worker already had a timestamp of {}, but tried to update it with {}",
-            worker.last_update_timestamp,
-            timestamp
-        );
-        worker.last_update_timestamp = timestamp;
-        Ok(())
-    }
-
-    /// Adds a worker to the pool.
-    /// Note: This function will not do any task matching.
-    pub(crate) fn add_worker(&mut self, worker: Worker) -> Result<(), Error> {
-        let worker_id = worker.id;
-        self.workers.put(worker_id, worker);
-
-        // Worker is not cloneable, and we do not want to send the initial connection results until
-        // we have added it to the map, or we might get some strange race conditions due to the way
-        // the multi-threaded runtime works.
-        let worker = self.workers.peek_mut(&worker_id).unwrap();
-        let res = worker
-            .send_initial_connection_result()
-            .err_tip(|| "Failed to send initial connection result to worker");
-        if let Err(err) = &res {
-            event!(
-                Level::ERROR,
-                ?worker_id,
-                ?err,
-                "Worker connection appears to have been closed while adding to pool"
-            );
-        }
-        res
-    }
-
-    /// Removes worker from pool.
-    /// Note: The caller is responsible for any rescheduling of any tasks that might be
-    /// running.
-    pub(crate) fn remove_worker(&mut self, worker_id: &WorkerId) -> Option<Worker> {
-        self.workers.pop(worker_id)
-    }
-
-    // Attempts to find a worker that is capable of running this action.
-    // TODO(blaise.bruer) This algorithm is not very efficient. Simple testing using a tree-like
-    // structure showed worse performance on a 10_000 worker * 7 properties * 1000 queued tasks
-    // simulation of worst cases in a single threaded environment.
-    pub(crate) fn find_worker_for_action(
-        &self,
-        platform_properties: &PlatformProperties,
-    ) -> Option<WorkerId> {
-        let mut workers_iter = self.workers.iter();
-        let workers_iter = match self.allocation_strategy {
-            // Use rfind to get the least recently used that satisfies the properties.
-            WorkerAllocationStrategy::least_recently_used => workers_iter.rfind(|(_, w)| {
-                w.can_accept_work() && platform_properties.is_satisfied_by(&w.platform_properties)
-            }),
-            // Use find to get the most recently used that satisfies the properties.
-            WorkerAllocationStrategy::most_recently_used => workers_iter.find(|(_, w)| {
-                w.can_accept_work() && platform_properties.is_satisfied_by(&w.platform_properties)
-            }),
-        };
-        workers_iter.map(|(_, w)| &w.id).copied()
-    }
-}
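Both allocation strategies in the deleted find_worker_for_action() reduce to scanning the LruCache's most-recently-used-first iterator from opposite ends: find() returns the most recently used eligible worker, rfind() the least recently used. A self-contained sketch of that selection over a plain slice ordered most-recent-first (hypothetical simplified Worker type, not the crate's):

#[derive(Debug)]
struct Worker {
    id: u64,
    can_accept_work: bool,
}

fn find_worker(workers_mru_first: &[Worker], prefer_most_recent: bool) -> Option<u64> {
    let mut iter = workers_mru_first.iter();
    let found = if prefer_most_recent {
        // Scan from the front: first eligible worker is the most recently used.
        iter.find(|w| w.can_accept_work)
    } else {
        // Scan from the back: first eligible worker is the least recently used.
        iter.rfind(|w| w.can_accept_work)
    };
    found.map(|w| w.id)
}

fn main() {
    let workers = vec![
        Worker { id: 1, can_accept_work: true }, // most recently used
        Worker { id: 2, can_accept_work: true }, // least recently used
    ];
    assert_eq!(find_worker(&workers, true), Some(1));
    assert_eq!(find_worker(&workers, false), Some(2));
}

Least-recently-used spreads load across the pool, while most-recently-used keeps a small hot set of workers busy so idle ones can be scaled down.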
diff --git a/nativelink-scheduler/src/simple_scheduler.rs b/nativelink-scheduler/src/simple_scheduler.rs
index 4a651fe5a..618b2077f 100644
--- a/nativelink-scheduler/src/simple_scheduler.rs
+++ b/nativelink-scheduler/src/simple_scheduler.rs
@@ -12,43 +12,34 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-use std::collections::BTreeMap;
 use std::pin::Pin;
-use std::sync::atomic::{AtomicU64, Ordering};
 use std::sync::Arc;
-use std::time::{Instant, SystemTime};
 
-use async_lock::{Mutex, MutexGuard};
 use async_trait::async_trait;
-use futures::{Future, Stream};
-use hashbrown::{HashMap, HashSet};
-use nativelink_error::{make_err, make_input_err, Code, Error, ResultExt};
+use futures::Future;
+use nativelink_config::stores::EvictionPolicy;
+use nativelink_error::{Error, ResultExt};
 use nativelink_util::action_messages::{
-    ActionInfo, ActionInfoHashKey, ActionResult, ActionStage, ActionState, ExecutionMetadata,
-    OperationId, WorkerId,
+    ActionInfo, ActionStage, ActionState, ClientOperationId, OperationId, WorkerId,
 };
-use nativelink_util::metrics_utils::{
-    AsyncCounterWrapper, Collector, CollectorState, CounterWithTime, FuncCounterWrapper,
-    MetricsComponent, Registry,
+use nativelink_util::metrics_utils::Registry;
+use nativelink_util::operation_state_manager::{
+    ActionStateResult, ActionStateResultStream, ClientStateManager, MatchingEngineStateManager,
+    OperationFilter, OperationStageFlags, OrderDirection,
 };
-use nativelink_util::platform_properties::PlatformPropertyValue;
 use nativelink_util::spawn;
 use nativelink_util::task::JoinHandleDropGuard;
-use tokio::sync::{watch, Notify};
+use tokio::sync::Notify;
 use tokio::time::Duration;
 use tokio_stream::StreamExt;
 use tracing::{event, Level};
 
-use crate::action_scheduler::ActionScheduler;
-use crate::operation_state_manager::{
-    ActionStateResult, ClientStateManager, MatchingEngineStateManager, OperationFilter,
-    OperationStageFlags, WorkerStateManager,
-};
+use crate::action_scheduler::{ActionListener, ActionScheduler};
+use crate::api_worker_scheduler::ApiWorkerScheduler;
+use crate::memory_awaited_action_db::MemoryAwaitedActionDb;
 use crate::platform_property_manager::PlatformPropertyManager;
-use crate::scheduler_state::metrics::Metrics as SchedulerMetrics;
-use crate::scheduler_state::state_manager::StateManager;
-use crate::scheduler_state::workers::Workers;
-use crate::worker::{Worker, WorkerTimestamp, WorkerUpdate};
+use crate::simple_scheduler_state_manager::SimpleSchedulerStateManager;
+use crate::worker::{Worker, WorkerTimestamp};
 use crate::worker_scheduler::WorkerScheduler;
 
 /// Default timeout for workers in seconds.
@@ -57,326 +48,212 @@ const DEFAULT_WORKER_TIMEOUT_S: u64 = 5;
 
 /// Default timeout for recently completed actions in seconds.
 /// If this changes, remember to change the documentation in the config.
-const DEFAULT_RETAIN_COMPLETED_FOR_S: u64 = 60;
+const DEFAULT_RETAIN_COMPLETED_FOR_S: u32 = 60;
 
 /// Default times a job can retry before failing.
 /// If this changes, remember to change the documentation in the config.
 const DEFAULT_MAX_JOB_RETRIES: usize = 3;
 
-struct SimpleSchedulerImpl {
-    /// The manager responsible for holding the state of actions and workers.
-    state_manager: StateManager,
-    /// The duration that actions are kept in recently_completed_actions for.
-    retain_completed_for: Duration,
-    /// Timeout of how long to evict workers if no response in this given amount of time in seconds.
-    worker_timeout_s: u64,
-    /// Default times a job can retry before failing.
-    max_job_retries: usize,
-    metrics: Arc<Metrics>,
+struct SimpleSchedulerActionListener {
+    client_operation_id: ClientOperationId,
+    action_state_result: Box<dyn ActionStateResult>,
 }
 
-impl SimpleSchedulerImpl {
-    /// Attempts to find a worker to execute an action and begins executing it.
-    /// If an action is already running that is cacheable it may merge this action
-    /// with the results and state changes of the already running action.
-    /// If the task cannot be executed immediately it will be queued for execution
-    /// based on priority and other metrics.
-    /// All further updates to the action will be provided through `listener`.
-    async fn add_action(
-        &mut self,
-        action_info: ActionInfo,
-    ) -> Result<watch::Receiver<Arc<ActionState>>, Error> {
-        let add_action_result = self.state_manager.add_action(action_info).await?;
-        add_action_result.as_receiver().await.cloned()
+impl SimpleSchedulerActionListener {
+    fn new(
+        client_operation_id: ClientOperationId,
+        action_state_result: Box<dyn ActionStateResult>,
+    ) -> Self {
+        Self {
+            client_operation_id,
+            action_state_result,
+        }
     }
+}
 
-    fn clean_recently_completed_actions(&mut self) {
-        let expiry_time = SystemTime::now()
-            .checked_sub(self.retain_completed_for)
-            .unwrap();
-        self.state_manager
-            .inner
-            .recently_completed_actions
-            .retain(|action| action.completed_time > expiry_time);
+impl ActionListener for SimpleSchedulerActionListener {
+    fn client_operation_id(&self) -> &ClientOperationId {
+        &self.client_operation_id
     }
 
-    fn find_recently_completed_action(
-        &self,
-        unique_qualifier: &ActionInfoHashKey,
-    ) -> Option<watch::Receiver<Arc<ActionState>>> {
-        self.state_manager
-            .inner
-            .recently_completed_actions
-            .get(unique_qualifier)
-            .map(|action| watch::channel(action.state.clone()).1)
+    fn changed(
+        &mut self,
+    ) -> Pin<Box<dyn Future<Output = Result<Arc<ActionState>, Error>> + Send + '_>> {
+        Box::pin(async move {
+            let action_state = self
+                .action_state_result
+                .changed()
+                .await
+                .err_tip(|| "In SimpleSchedulerActionListener::changed getting receiver")?;
+            Ok(action_state)
+        })
     }
+}
 
-    async fn find_existing_action(
-        &self,
-        unique_qualifier: &ActionInfoHashKey,
-    ) -> Option<watch::Receiver<Arc<ActionState>>> {
-        let filter_result = <StateManager as ClientStateManager>::filter_operations(
-            &self.state_manager,
-            OperationFilter {
-                stages: OperationStageFlags::Any,
-                operation_id: None,
-                worker_id: None,
-                action_digest: None,
-                worker_update_before: None,
-                completed_before: None,
-                last_client_update_before: None,
-                unique_qualifier: Some(unique_qualifier.clone()),
-                order_by: None,
-            },
-        )
-        .await;
-
-        let mut stream = filter_result.ok()?;
-        if let Some(result) = stream.next().await {
-            result.as_receiver().await.ok().cloned()
-        } else {
-            None
-        }
-    }
+/// Engine used to manage the queued/running tasks and relationship with
+/// the worker nodes. All state on how the workers and actions are interacting
+/// should be held in this struct.
+pub struct SimpleScheduler {
+    /// Manager for matching engine side of the state manager.
+    matching_engine_state_manager: Arc<dyn MatchingEngineStateManager>,
 
-    fn retry_action(&mut self, action_info: &Arc<ActionInfo>, worker_id: &WorkerId, err: Error) {
-        match self.state_manager.inner.active_actions.remove(action_info) {
-            Some(running_action) => {
-                let mut awaited_action = running_action;
-                let send_result = if awaited_action.attempts >= self.max_job_retries {
-                    self.metrics.retry_action_max_attempts_reached.inc();
-
-                    StateManager::mutate_stage(&mut awaited_action, ActionStage::Completed(ActionResult {
-                        execution_metadata: ExecutionMetadata {
-                            worker: format!("{worker_id}"),
-                            ..ExecutionMetadata::default()
-                        },
-                        error: Some(err.merge(make_err!(
-                            Code::Internal,
-                            "Job cancelled because it attempted to execute too many times and failed"
-                        ))),
-                        ..ActionResult::default()
-                    }))
-                    // Do not put the action back in the queue here, as this action attempted to run too many
-                    // times.
-                } else {
-                    self.metrics.retry_action.inc();
-                    let send_result =
-                        StateManager::mutate_stage(&mut awaited_action, ActionStage::Queued);
-                    self.state_manager
-                        .inner
-                        .queued_actions_set
-                        .insert(action_info.clone());
-                    self.state_manager
-                        .inner
-                        .queued_actions
-                        .insert(action_info.clone(), awaited_action);
-                    send_result
-                };
-
-                if send_result.is_err() {
-                    self.metrics.retry_action_no_more_listeners.inc();
-                    // Don't remove this task, instead we keep them around for a bit just in case
-                    // the client disconnected and will reconnect and ask for same job to be executed
-                    // again.
-                    event!(
-                        Level::WARN,
-                        ?action_info,
-                        ?worker_id,
-                        "Action has no more listeners during evict_worker()"
-                    );
-                }
-            }
-            None => {
-                self.metrics.retry_action_but_action_missing.inc();
-                event!(
-                    Level::ERROR,
-                    ?action_info,
-                    ?worker_id,
-                    "Worker stated it was running an action, but it was not in the active_actions"
-                );
-            }
-        }
-    }
+    /// Manager for client state of this scheduler.
+    client_state_manager: Arc<dyn ClientStateManager>,
 
-    /// Evicts the worker from the pool and puts items back into the queue if anything was being executed on it.
-    fn immediate_evict_worker(&mut self, worker_id: &WorkerId, err: Error) {
-        if let Some(mut worker) = self.state_manager.inner.workers.remove_worker(worker_id) {
-            self.metrics.workers_evicted.inc();
-            // We don't care if we fail to send message to worker, this is only a best attempt.
-            let _ = worker.notify_update(WorkerUpdate::Disconnect);
-            // We create a temporary Vec to avoid doubt about a possible code
-            // path touching the worker.running_action_infos elsewhere.
-            for action_info in worker.running_action_infos.drain() {
-                self.metrics.workers_evicted_with_running_action.inc();
-                self.retry_action(&action_info, worker_id, err.clone());
-            }
-        }
-        // Note: Calling this many time is very cheap, it'll only trigger `do_try_match` once.
-        self.state_manager
-            .inner
-            .tasks_or_workers_change_notify
-            .notify_one();
-    }
+    /// Manager for platform of this scheduler.
+    platform_property_manager: Arc<PlatformPropertyManager>,
+
+    /// A `Workers` pool that contains all workers that are available to execute actions in a priority
+    /// order based on the allocation strategy.
+    worker_scheduler: Arc<ApiWorkerScheduler>,
+
+    /// Background task that tries to match actions to workers. If this struct
+    /// is dropped the spawn will be cancelled as well.
+    _task_worker_matching_spawn: JoinHandleDropGuard<()>,
+}
 
-    /// Sets if the worker is draining or not.
-    fn set_drain_worker(&mut self, worker_id: WorkerId, is_draining: bool) -> Result<(), Error> {
-        let worker = self
-            .state_manager
-            .inner
-            .workers
-            .workers
-            .get_mut(&worker_id)
-            .err_tip(|| format!("Worker {worker_id} doesn't exist in the pool"))?;
-        self.metrics.workers_drained.inc();
-        worker.is_draining = is_draining;
-        self.state_manager
-            .inner
-            .tasks_or_workers_change_notify
-            .notify_one();
-        Ok(())
-    }
 
+impl SimpleScheduler {
+    /// Attempts to find a worker to execute an action and begins executing it.
+    /// If an action is already running that is cacheable it may merge this
+    /// action with the results and state changes of the already running
+    /// action. If the task cannot be executed immediately it will be queued
+    /// for execution based on priority and other metrics.
+    /// All further updates to the action will be provided through the returned
+    /// value.
+ async fn add_action( + &self, + client_operation_id: ClientOperationId, + action_info: Arc, + ) -> Result>, Error> { + let add_action_result = self + .client_state_manager + .add_action(client_operation_id.clone(), action_info) + .await?; + + Ok(Box::pin(SimpleSchedulerActionListener::new( + client_operation_id, + add_action_result, + ))) } - async fn get_queued_operations( + async fn find_by_client_operation_id( &self, - ) -> Result> + Send>>, Error> - { - ::filter_operations( - &self.state_manager, - OperationFilter { - stages: OperationStageFlags::Queued, - operation_id: None, - worker_id: None, - action_digest: None, - worker_update_before: None, - completed_before: None, - last_client_update_before: None, - unique_qualifier: None, - order_by: None, - }, - ) - .await + client_operation_id: &ClientOperationId, + ) -> Result>>, Error> { + let filter = OperationFilter { + client_operation_id: Some(client_operation_id.clone()), + ..Default::default() + }; + let filter_result = self.client_state_manager.filter_operations(filter).await; + + let mut stream = filter_result + .err_tip(|| "In SimpleScheduler::find_by_client_operation_id getting filter result")?; + let Some(action_state_result) = stream.next().await else { + return Ok(None); + }; + Ok(Some(Box::pin(SimpleSchedulerActionListener::new( + client_operation_id.clone(), + action_state_result, + )))) } - // TODO(blaise.bruer) This is an O(n*m) (aka n^2) algorithm. In theory we can create a map - // of capabilities of each worker and then try and match the actions to the worker using - // the map lookup (ie. map reduce). - async fn do_try_match(&mut self) { - // TODO(blaise.bruer) This is a bit difficult because of how rust's borrow checker gets in - // the way. We need to conditionally remove items from the `queued_action`. Rust is working - // to add `drain_filter`, which would in theory solve this problem, but because we need - // to iterate the items in reverse it becomes more difficult (and it is currently an - // unstable feature [see: https://github.com/rust-lang/rust/issues/70530]). 
- - let action_state_results = self.get_queued_operations().await; - - match action_state_results { - Ok(mut stream) => { - while let Some(action_state_result) = stream.next().await { - let as_state_result = action_state_result.as_state().await; - let Ok(state) = as_state_result else { - let _ = as_state_result.inspect_err(|err| { - event!( - Level::ERROR, - ?err, - "Failed to get action_info from as_state_result stream" - ); - }); - continue; - }; - let action_state_result = action_state_result.as_action_info().await; - let Ok(action_info) = action_state_result else { - let _ = action_state_result.inspect_err(|err| { - event!( - Level::ERROR, - ?err, - "Failed to get action_info from action_state_results stream" - ); - }); - continue; - }; - - let maybe_worker_id: Option = { - self.state_manager - .inner - .workers - .find_worker_for_action(&action_info.platform_properties) - }; - - let operation_id = state.id.clone(); - let ret = ::update_operation( - &mut self.state_manager, - operation_id.clone(), - maybe_worker_id, - Ok(ActionStage::Executing), - ) - .await; - - if let Err(e) = ret { - event!( - Level::ERROR, - ?e, - "update operation failed for {}", - operation_id - ); - } + async fn get_queued_operations(&self) -> Result { + let filter = OperationFilter { + stages: OperationStageFlags::Queued, + order_by_priority_direction: Some(OrderDirection::Desc), + ..Default::default() + }; + self.matching_engine_state_manager + .filter_operations(filter) + .await + .err_tip(|| "In SimpleScheduler::get_queued_operations getting filter result") + } + + // TODO(blaise.bruer) This is an O(n*m) (aka n^2) algorithm. In theory we + // can create a map of capabilities of each worker and then try and match + // the actions to the worker using the map lookup (ie. map reduce). + async fn do_try_match(&self) -> Result<(), Error> { + async fn match_action_to_worker( + action_state_result: &dyn ActionStateResult, + workers: &ApiWorkerScheduler, + matching_engine_state_manager: &dyn MatchingEngineStateManager, + ) -> Result<(), Error> { + let action_info = action_state_result + .as_action_info() + .await + .err_tip(|| "Failed to get action_info from as_action_info_result stream")?; + + // Try to find a worker for the action. + let worker_id = { + let platform_properties = &action_info.platform_properties; + match workers.find_worker_for_action(platform_properties).await { + Some(worker_id) => worker_id, + // If we could not find a worker for the action, + // we have nothing to do. + None => return Ok(()), } - } - Err(e) => { - event!(Level::ERROR, ?e, "stream error in do_try_match"); + }; + + // Extract the operation_id from the action_state. + let operation_id = { + let action_state = action_state_result + .as_state() + .await + .err_tip(|| "Failed to get action_info from as_state_result stream")?; + action_state.id.clone() + }; + + // Tell the matching engine that the operation is being assigned to a worker. + matching_engine_state_manager + .assign_operation(&operation_id, Ok(&worker_id)) + .await + .err_tip(|| "Failed to assign operation in do_try_match")?; + + // Notify the worker to run the action. 
+        {
+            workers
+                .worker_notify_run_action(worker_id, operation_id, action_info)
+                .await
+                .err_tip(|| {
+                    "Failed to run worker_notify_run_action in SimpleScheduler::do_try_match"
+                })
         }
     }

-    async fn update_action(
-        &mut self,
-        worker_id: &WorkerId,
-        action_info_hash_key: ActionInfoHashKey,
-        action_stage: Result<ActionStage, Error>,
-    ) -> Result<(), Error> {
-        let update_operation_result = ::update_operation(
-            &mut self.state_manager,
-            OperationId::new(action_info_hash_key.clone()),
-            *worker_id,
-            action_stage,
-        )
-        .await;
-        if let Err(e) = &update_operation_result {
-            event!(
-                Level::ERROR,
-                ?action_info_hash_key,
-                ?worker_id,
-                ?e,
-                "Failed to update_operation on update_action"
+        let mut result = Ok(());
+
+        let mut stream = self
+            .get_queued_operations()
+            .await
+            .err_tip(|| "Failed to get queued operations in do_try_match")?;
+
+        while let Some(action_state_result) = stream.next().await {
+            result = result.merge(
+                match_action_to_worker(
+                    action_state_result.as_ref(),
+                    self.worker_scheduler.as_ref(),
+                    self.matching_engine_state_manager.as_ref(),
+                )
+                .await,
             );
         }
-        update_operation_result
+        result
     }
 }

-/// Engine used to manage the queued/running tasks and relationship with
-/// the worker nodes. All state on how the workers and actions are interacting
-/// should be held in this struct.
-pub struct SimpleScheduler {
-    inner: Arc<Mutex<SimpleSchedulerImpl>>,
-    platform_property_manager: Arc<PlatformPropertyManager>,
-    metrics: Arc<Metrics>,
-    // Triggers `drop()`` call if scheduler is dropped.
-    _task_worker_matching_future: JoinHandleDropGuard<()>,
-}
-
 impl SimpleScheduler {
-    #[inline]
-    #[must_use]
-    pub fn new(scheduler_cfg: &nativelink_config::schedulers::SimpleScheduler) -> Self {
+    pub fn new(
+        scheduler_cfg: &nativelink_config::schedulers::SimpleScheduler,
+    ) -> (Arc<SimpleScheduler>, Arc<ApiWorkerScheduler>) {
         Self::new_with_callback(scheduler_cfg, || {
             // The cost of running `do_try_match()` is very high, but constant
-            // in relation to the number of changes that have happened. This means
-            // that grabbing this lock to process `do_try_match()` should always
-            // yield to any other tasks that might want the lock. The easiest and
-            // most fair way to do this is to sleep for a small amount of time.
-            // Using something like tokio::task::yield_now() does not yield as
-            // aggresively as we'd like if new futures are scheduled within a future.
+            // in relation to the number of changes that have happened. This
+            // means that grabbing this lock to process `do_try_match()` should
+            // always yield to any other tasks that might want the lock. The
+            // easiest and most fair way to do this is to sleep for a small
+            // amount of time. Using something like tokio::task::yield_now()
+            // does not yield as aggressively as we'd like if new futures are
+            // scheduled within a future.
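// --- [Editorial sketch; illustrative only, not part of the patch] -----------
// The comment above describes the general pattern: park on a `Notify` until
// scheduler state changes, run one expensive-but-bounded matching pass, then
// sleep briefly so any task waiting on the same locks gets a fair turn. A
// standalone sketch of that loop (all names here are illustrative, not
// NativeLink API):

use std::sync::Arc;
use std::time::Duration;

use tokio::sync::Notify;

async fn run_matching_loop<F, Fut>(change_notify: Arc<Notify>, mut do_one_pass: F)
where
    F: FnMut() -> Fut,
    Fut: std::future::Future<Output = ()>,
{
    loop {
        // Park until a task or worker change is signalled.
        change_notify.notified().await;
        // One full matching pass: costly, but constant per batch of changes.
        do_one_pass().await;
        // Sleeping (rather than tokio::task::yield_now()) yields more
        // aggressively when the pass scheduled new futures of its own.
        tokio::time::sleep(Duration::from_millis(1)).await;
    }
}
// ---------------------------------------------------------------------------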
tokio::time::sleep(Duration::from_millis(1)) }) } @@ -387,7 +264,7 @@ impl SimpleScheduler { >( scheduler_cfg: &nativelink_config::schedulers::SimpleScheduler, on_matching_engine_run: F, - ) -> Self { + ) -> (Arc, Arc) { let platform_property_manager = Arc::new(PlatformPropertyManager::new( scheduler_cfg .supported_platform_properties @@ -410,102 +287,56 @@ impl SimpleScheduler { max_job_retries = DEFAULT_MAX_JOB_RETRIES; } - let tasks_or_workers_change_notify = Arc::new(Notify::new()); - let state_manager = StateManager::new( - HashSet::new(), - BTreeMap::new(), - Workers::new(scheduler_cfg.allocation_strategy), - HashMap::new(), - HashSet::new(), - Arc::new(SchedulerMetrics::default()), + let tasks_or_worker_change_notify = Arc::new(Notify::new()); + let state_manager = SimpleSchedulerStateManager::new( + tasks_or_worker_change_notify.clone(), max_job_retries, - tasks_or_workers_change_notify.clone(), + MemoryAwaitedActionDb::new(&EvictionPolicy { + max_seconds: retain_completed_for_s, + ..Default::default() + }), ); - let metrics = Arc::new(Metrics::default()); - let metrics_for_do_try_match = metrics.clone(); - let inner = Arc::new(Mutex::new(SimpleSchedulerImpl { - state_manager, - retain_completed_for: Duration::new(retain_completed_for_s, 0), + + let worker_scheduler = ApiWorkerScheduler::new( + state_manager.clone(), + platform_property_manager.clone(), + scheduler_cfg.allocation_strategy, + tasks_or_worker_change_notify.clone(), worker_timeout_s, - max_job_retries, - metrics: metrics.clone(), - })); - let weak_inner = Arc::downgrade(&inner); - Self { - inner, - platform_property_manager, - _task_worker_matching_future: spawn!( - "simple_scheduler_task_worker_matching", - async move { + ); + + let worker_scheduler_clone = worker_scheduler.clone(); + + let action_scheduler = Arc::new_cyclic(move |weak_self| -> Self { + let weak_inner = weak_self.clone(); + let task_worker_matching_spawn = + spawn!("simple_scheduler_task_worker_matching", async move { // Break out of the loop only when the inner is dropped. loop { - tasks_or_workers_change_notify.notified().await; - match weak_inner.upgrade() { - // Note: According to `parking_lot` documentation, the default - // `Mutex` implementation is eventual fairness, so we don't - // really need to worry about this thread taking the lock - // starving other threads too much. - Some(inner_mux) => { - let mut inner = inner_mux.lock().await; - let timer = metrics_for_do_try_match.do_try_match.begin_timer(); - inner.do_try_match().await; - timer.measure(); - } + tasks_or_worker_change_notify.notified().await; + let result = match weak_inner.upgrade() { + Some(scheduler) => scheduler.do_try_match().await, // If the inner went away it means the scheduler is shutting // down, so we need to resolve our future. None => return, }; + if let Err(err) = result { + event!(Level::ERROR, ?err, "Error while running do_try_match"); + } + on_matching_engine_run().await; } // Unreachable. - } - ), - metrics, - } - } - - /// Checks to see if the worker exists in the worker pool. Should only be used in unit tests. - #[must_use] - pub async fn contains_worker_for_test(&self, worker_id: &WorkerId) -> bool { - let inner = self.get_inner_lock().await; - inner - .state_manager - .inner - .workers - .workers - .contains(worker_id) - } - - /// A unit test function used to send the keep alive message to the worker from the server. 
- pub async fn send_keep_alive_to_worker_for_test( - &self, - worker_id: &WorkerId, - ) -> Result<(), Error> { - let mut inner = self.get_inner_lock().await; - let worker = inner - .state_manager - .inner - .workers - .workers - .get_mut(worker_id) - .ok_or_else(|| { - make_input_err!("WorkerId '{}' does not exist in workers map", worker_id) - })?; - worker.keep_alive() - } - - async fn get_inner_lock(&self) -> MutexGuard<'_, SimpleSchedulerImpl> { - // We don't use one of the wrappers because we only want to capture the time spent, - // nothing else beacuse this is a hot path. - let start = Instant::now(); - let lock: MutexGuard = self.inner.lock().await; - self.metrics - .lock_stall_time - .fetch_add(start.elapsed().as_nanos() as u64, Ordering::Relaxed); - self.metrics - .lock_stall_time_counter - .fetch_add(1, Ordering::Relaxed); - lock + }); + SimpleScheduler { + matching_engine_state_manager: state_manager.clone(), + client_state_manager: state_manager.clone(), + worker_scheduler, + platform_property_manager, + _task_worker_matching_spawn: task_worker_matching_spawn, + } + }); + (action_scheduler, worker_scheduler_clone) } } @@ -520,82 +351,52 @@ impl ActionScheduler for SimpleScheduler { async fn add_action( &self, + client_operation_id: ClientOperationId, action_info: ActionInfo, - ) -> Result>, Error> { - let mut inner = self.get_inner_lock().await; - self.metrics - .add_action - .wrap(inner.add_action(action_info)) + ) -> Result>, Error> { + self.add_action(client_operation_id, Arc::new(action_info)) .await } - async fn find_existing_action( + async fn find_by_client_operation_id( &self, - unique_qualifier: &ActionInfoHashKey, - ) -> Option>> { - let inner = self.get_inner_lock().await; - let result = inner - .find_existing_action(unique_qualifier) - .await - .or_else(|| inner.find_recently_completed_action(unique_qualifier)); - if result.is_some() { - self.metrics.existing_actions_found.inc(); - } else { - self.metrics.existing_actions_not_found.inc(); - } - result - } - - async fn clean_recently_completed_actions(&self) { - self.get_inner_lock() + client_operation_id: &ClientOperationId, + ) -> Result>>, Error> { + let maybe_receiver = self + .find_by_client_operation_id(client_operation_id) .await - .clean_recently_completed_actions(); - self.metrics.clean_recently_completed_actions.inc() + .err_tip(|| { + format!("Error while finding action with client id: {client_operation_id:?}") + })?; + Ok(maybe_receiver) } fn register_metrics(self: Arc, registry: &mut Registry) { - registry.register_collector(Box::new(Collector::new(&self))); + self.client_state_manager.clone().register_metrics(registry); + self.matching_engine_state_manager + .clone() + .register_metrics(registry); } } #[async_trait] impl WorkerScheduler for SimpleScheduler { fn get_platform_property_manager(&self) -> &PlatformPropertyManager { - self.platform_property_manager.as_ref() + self.worker_scheduler.get_platform_property_manager() } async fn add_worker(&self, worker: Worker) -> Result<(), Error> { - let worker_id = worker.id; - let mut inner = self.get_inner_lock().await; - self.metrics.add_worker.wrap(move || { - let res = inner - .state_manager - .inner - .workers - .add_worker(worker) - .err_tip(|| "Error while adding worker, removing from pool"); - if let Err(err) = &res { - inner.immediate_evict_worker(&worker_id, err.clone()); - } - inner - .state_manager - .inner - .tasks_or_workers_change_notify - .notify_one(); - res - }) + self.worker_scheduler.add_worker(worker).await } async fn update_action( &self, 
worker_id: &WorkerId, - action_info_hash_key: ActionInfoHashKey, + operation_id: &OperationId, action_stage: Result, ) -> Result<(), Error> { - let mut inner = self.get_inner_lock().await; - self.metrics - .update_action - .wrap(inner.update_action(worker_id, action_info_hash_key, action_stage)) + self.worker_scheduler + .update_action(worker_id, operation_id, action_stage) .await } @@ -604,322 +405,24 @@ impl WorkerScheduler for SimpleScheduler { worker_id: &WorkerId, timestamp: WorkerTimestamp, ) -> Result<(), Error> { - let mut inner = self.get_inner_lock().await; - inner - .state_manager - .inner - .workers - .refresh_lifetime(worker_id, timestamp) - .err_tip(|| "Error refreshing lifetime in worker_keep_alive_received()") + self.worker_scheduler + .worker_keep_alive_received(worker_id, timestamp) + .await } - async fn remove_worker(&self, worker_id: WorkerId) { - let mut inner = self.get_inner_lock().await; - inner.immediate_evict_worker( - &worker_id, - make_err!(Code::Internal, "Received request to remove worker"), - ); + async fn remove_worker(&self, worker_id: &WorkerId) -> Result<(), Error> { + self.worker_scheduler.remove_worker(worker_id).await } async fn remove_timedout_workers(&self, now_timestamp: WorkerTimestamp) -> Result<(), Error> { - let mut inner = self.get_inner_lock().await; - self.metrics.remove_timedout_workers.wrap(move || { - // Items should be sorted based on last_update_timestamp, so we don't need to iterate the entire - // map most of the time. - let worker_ids_to_remove: Vec = inner - .state_manager - .inner - .workers - .workers - .iter() - .rev() - .map_while(|(worker_id, worker)| { - if worker.last_update_timestamp <= now_timestamp - inner.worker_timeout_s { - Some(*worker_id) - } else { - None - } - }) - .collect(); - for worker_id in &worker_ids_to_remove { - event!( - Level::WARN, - ?worker_id, - "Worker timed out, removing from pool" - ); - inner.immediate_evict_worker( - worker_id, - make_err!( - Code::Internal, - "Worker {worker_id} timed out, removing from pool" - ), - ); - } - - Ok(()) - }) - } - - async fn set_drain_worker(&self, worker_id: WorkerId, is_draining: bool) -> Result<(), Error> { - let mut inner = self.get_inner_lock().await; - inner.set_drain_worker(worker_id, is_draining) - } - - fn register_metrics(self: Arc, _registry: &mut Registry) { - // We do not register anything here because we only want to register metrics - // once and we rely on the `ActionScheduler::register_metrics()` to do that. - } -} - -impl MetricsComponent for SimpleScheduler { - fn gather_metrics(&self, c: &mut CollectorState) { - self.metrics.gather_metrics(c); - { - // We use the raw lock because we dont gather stats about gathering stats. 
- let inner = self.inner.lock_blocking(); - inner.state_manager.inner.metrics.gather_metrics(c); - c.publish( - "queued_actions_total", - &inner.state_manager.inner.queued_actions.len(), - "The number actions in the queue.", - ); - c.publish( - "workers_total", - &inner.state_manager.inner.workers.workers.len(), - "The number workers active.", - ); - c.publish( - "active_actions_total", - &inner.state_manager.inner.active_actions.len(), - "The number of running actions.", - ); - c.publish( - "recently_completed_actions_total", - &inner.state_manager.inner.recently_completed_actions.len(), - "The number of recently completed actions in the buffer.", - ); - c.publish( - "retain_completed_for_seconds", - &inner.retain_completed_for, - "The duration completed actions are retained for.", - ); - c.publish( - "worker_timeout_seconds", - &inner.worker_timeout_s, - "The configured timeout if workers have not responded for a while.", - ); - c.publish( - "max_job_retries", - &inner.max_job_retries, - "The amount of times a job is allowed to retry from an internal error before it is dropped.", - ); - let mut props = HashMap::<&String, u64>::new(); - for (_worker_id, worker) in inner.state_manager.inner.workers.workers.iter() { - c.publish_with_labels( - "workers", - worker, - "", - vec![("worker_id".into(), worker.id.to_string().into())], - ); - for (property, prop_value) in &worker.platform_properties.properties { - let current_value = props.get(&property).unwrap_or(&0); - if let PlatformPropertyValue::Minimum(worker_value) = prop_value { - props.insert(property, *current_value + *worker_value); - } - } - } - for (property, prop_value) in props { - c.publish( - &format!("{property}_available_properties"), - &prop_value, - format!("Total sum of available properties for {property}"), - ); - } - for (_, active_action) in inner.state_manager.inner.active_actions.iter() { - let action_name = active_action - .action_info - .unique_qualifier - .action_name() - .into(); - let worker_id_str = match active_action.worker_id { - Some(id) => id.to_string(), - None => "Unassigned".to_string(), - }; - c.publish_with_labels( - "active_actions", - active_action, - "", - vec![ - ("worker_id".into(), worker_id_str.into()), - ("digest".into(), action_name), - ], - ); - } - // Note: We don't publish queued_actions because it can be very large. - // Note: We don't publish recently completed actions because it can be very large. 
- } + self.worker_scheduler + .remove_timedout_workers(now_timestamp) + .await } -} - -#[derive(Default)] -struct Metrics { - add_action: AsyncCounterWrapper, - existing_actions_found: CounterWithTime, - existing_actions_not_found: CounterWithTime, - clean_recently_completed_actions: CounterWithTime, - remove_timedout_workers: FuncCounterWrapper, - update_action: AsyncCounterWrapper, - update_action_with_internal_error: CounterWithTime, - update_action_with_internal_error_no_action: CounterWithTime, - update_action_with_internal_error_backpressure: CounterWithTime, - update_action_with_internal_error_from_wrong_worker: CounterWithTime, - workers_evicted: CounterWithTime, - workers_evicted_with_running_action: CounterWithTime, - workers_drained: CounterWithTime, - retry_action: CounterWithTime, - retry_action_max_attempts_reached: CounterWithTime, - retry_action_no_more_listeners: CounterWithTime, - retry_action_but_action_missing: CounterWithTime, - add_worker: FuncCounterWrapper, - timedout_workers: CounterWithTime, - lock_stall_time: AtomicU64, - lock_stall_time_counter: AtomicU64, - do_try_match: AsyncCounterWrapper, -} -impl Metrics { - fn gather_metrics(&self, c: &mut CollectorState) { - c.publish( - "add_action", - &self.add_action, - "The number of times add_action was called.", - ); - c.publish_with_labels( - "find_existing_action", - &self.existing_actions_found, - "The number of times existing_actions_found had an action found.", - vec![("result".into(), "found".into())], - ); - c.publish_with_labels( - "find_existing_action", - &self.existing_actions_not_found, - "The number of times existing_actions_found had an action not found.", - vec![("result".into(), "not_found".into())], - ); - c.publish( - "clean_recently_completed_actions", - &self.clean_recently_completed_actions, - "The number of times clean_recently_completed_actions was triggered.", - ); - c.publish( - "remove_timedout_workers", - &self.remove_timedout_workers, - "The number of times remove_timedout_workers was triggered.", - ); - { - c.publish_with_labels( - "update_action", - &self.update_action, - "Stats about errors when worker sends update_action() to scheduler.", - vec![("result".into(), "missing_action_result".into())], - ); - } - c.publish( - "update_action_with_internal_error", - &self.update_action_with_internal_error, - "The number of times update_action_with_internal_error was triggered.", - ); - { - c.publish_with_labels( - "update_action_with_internal_error_errors", - &self.update_action_with_internal_error_no_action, - "Stats about what errors caused update_action_with_internal_error() in scheduler.", - vec![("result".into(), "no_action".into())], - ); - c.publish_with_labels( - "update_action_with_internal_error_errors", - &self.update_action_with_internal_error_backpressure, - "Stats about what errors caused update_action_with_internal_error() in scheduler.", - vec![("result".into(), "backpressure".into())], - ); - c.publish_with_labels( - "update_action_with_internal_error_errors", - &self.update_action_with_internal_error_from_wrong_worker, - "Stats about what errors caused update_action_with_internal_error() in scheduler.", - vec![("result".into(), "from_wrong_worker".into())], - ); - } - c.publish( - "workers_evicted_total", - &self.workers_evicted, - "The number of workers evicted from scheduler.", - ); - c.publish( - "workers_evicted_with_running_action", - &self.workers_evicted_with_running_action, - "The number of jobs cancelled because worker was evicted from scheduler.", - ); - c.publish( 
- "workers_drained_total", - &self.workers_drained, - "The number of workers drained from scheduler.", - ); - { - c.publish_with_labels( - "retry_action", - &self.retry_action, - "Stats about retry_action().", - vec![("result".into(), "success".into())], - ); - c.publish_with_labels( - "retry_action", - &self.retry_action_max_attempts_reached, - "Stats about retry_action().", - vec![("result".into(), "max_attempts_reached".into())], - ); - c.publish_with_labels( - "retry_action", - &self.retry_action_no_more_listeners, - "Stats about retry_action().", - vec![("result".into(), "no_more_listeners".into())], - ); - c.publish_with_labels( - "retry_action", - &self.retry_action_but_action_missing, - "Stats about retry_action().", - vec![("result".into(), "action_missing".into())], - ); - } - c.publish( - "add_worker", - &self.add_worker, - "Stats about add_worker() being called on the scheduler.", - ); - c.publish( - "timedout_workers", - &self.timedout_workers, - "The number of workers that timed out.", - ); - c.publish( - "lock_stall_time_nanos_total", - &self.lock_stall_time, - "The total number of nanos spent waiting on the lock in the scheduler.", - ); - c.publish( - "lock_stall_time_total", - &self.lock_stall_time_counter, - "The number of times a lock request was made in the scheduler.", - ); - c.publish( - "lock_stall_time_avg_nanos", - &(self.lock_stall_time.load(Ordering::Relaxed) - / self.lock_stall_time_counter.load(Ordering::Relaxed)), - "The average time the scheduler stalled waiting on the lock to release in nanos.", - ); - c.publish( - "matching_engine", - &self.do_try_match, - "The job<->worker matching engine stats. This is a very expensive operation, so it is not run every time (often called do_try_match).", - ); + async fn set_drain_worker(&self, worker_id: &WorkerId, is_draining: bool) -> Result<(), Error> { + self.worker_scheduler + .set_drain_worker(worker_id, is_draining) + .await } } diff --git a/nativelink-scheduler/src/simple_scheduler_state_manager.rs b/nativelink-scheduler/src/simple_scheduler_state_manager.rs new file mode 100644 index 000000000..f8bd2fc4d --- /dev/null +++ b/nativelink-scheduler/src/simple_scheduler_state_manager.rs @@ -0,0 +1,480 @@ +// Copyright 2024 The NativeLink Authors. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
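// --- [Editorial sketch; illustrative only, not part of the patch] -----------
// The state manager introduced in this file commits updates optimistically:
// read the current awaited action, mutate a copy, bump its version, and retry
// when the write reports `Code::Aborted` (another writer committed first), up
// to a bounded number of attempts. The shape of that loop, factored into a
// hypothetical helper under those assumptions:

use nativelink_error::{make_err, Code, Error};

async fn with_version_retries<F, Fut>(max_retries: usize, mut try_commit: F) -> Result<(), Error>
where
    F: FnMut() -> Fut,
    Fut: std::future::Future<Output = Result<(), Error>>,
{
    let mut last_err = None;
    for _ in 0..max_retries {
        match try_commit().await {
            Ok(()) => return Ok(()),
            // Aborted signals a stale version: safe to re-read and try again.
            Err(err) if err.code == Code::Aborted => last_err = Some(err),
            // Any other failure is terminal; retrying would not help.
            Err(err) => return Err(err),
        }
    }
    Err(last_err.unwrap_or_else(|| {
        make_err!(Code::Internal, "No error set after update retries")
    }))
}
// ---------------------------------------------------------------------------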
+
+use std::ops::Bound;
+use std::sync::Arc;
+
+use async_trait::async_trait;
+use futures::{future, stream, StreamExt, TryStreamExt};
+use nativelink_error::{make_err, Code, Error, ResultExt};
+use nativelink_util::action_messages::{
+    ActionInfo, ActionResult, ActionStage, ActionState, ActionUniqueQualifier, ClientOperationId,
+    ExecutionMetadata, OperationId, WorkerId,
+};
+use nativelink_util::metrics_utils::{Collector, CollectorState, MetricsComponent, Registry};
+use nativelink_util::operation_state_manager::{
+    ActionStateResult, ActionStateResultStream, ClientStateManager, MatchingEngineStateManager,
+    OperationFilter, OperationStageFlags, OrderDirection, WorkerStateManager,
+};
+use tokio::sync::Notify;
+use tracing::{event, Level};
+
+use super::awaited_action_db::{
+    AwaitedAction, AwaitedActionDb, AwaitedActionSubscriber, SortedAwaitedActionState,
+};
+use crate::memory_awaited_action_db::{ClientActionStateResult, MatchingEngineActionStateResult};
+
+/// Maximum number of times an update to the database
+/// can fail before giving up.
+const MAX_UPDATE_RETRIES: usize = 5;
+
+/// Simple struct that implements the ActionStateResult trait and always returns an error.
+struct ErrorActionStateResult(Error);
+
+#[async_trait]
+impl ActionStateResult for ErrorActionStateResult {
+    async fn as_state(&self) -> Result<Arc<ActionState>, Error> {
+        Err(self.0.clone())
+    }
+
+    async fn changed(&mut self) -> Result<Arc<ActionState>, Error> {
+        Err(self.0.clone())
+    }
+
+    async fn as_action_info(&self) -> Result<Arc<ActionInfo>, Error> {
+        Err(self.0.clone())
+    }
+}
+
+fn apply_filter_predicate(awaited_action: &AwaitedAction, filter: &OperationFilter) -> bool {
+    // Note: The caller must filter `client_operation_id`.
+
+    if let Some(operation_id) = &filter.operation_id {
+        if operation_id != awaited_action.operation_id() {
+            return false;
+        }
+    }
+
+    if filter.worker_id.is_some() && filter.worker_id != awaited_action.worker_id() {
+        return false;
+    }
+
+    {
+        if let Some(filter_unique_key) = &filter.unique_key {
+            match &awaited_action.action_info().unique_qualifier {
+                ActionUniqueQualifier::Cachable(unique_key) => {
+                    if filter_unique_key != unique_key {
+                        return false;
+                    }
+                }
+                ActionUniqueQualifier::Uncachable(_) => {
+                    return false;
+                }
+            }
+        }
+        if let Some(action_digest) = filter.action_digest {
+            if action_digest != awaited_action.action_info().digest() {
+                return false;
+            }
+        }
+    }
+
+    {
+        let last_worker_update_timestamp = awaited_action.last_worker_updated_timestamp();
+        if let Some(worker_update_before) = filter.worker_update_before {
+            if worker_update_before < last_worker_update_timestamp {
+                return false;
+            }
+        }
+        if let Some(completed_before) = filter.completed_before {
+            if awaited_action.state().stage.is_finished()
+                && completed_before < last_worker_update_timestamp
+            {
+                return false;
+            }
+        }
+        if filter.stages != OperationStageFlags::Any {
+            let stage_flag = match awaited_action.state().stage {
+                ActionStage::Unknown => OperationStageFlags::Any,
+                ActionStage::CacheCheck => OperationStageFlags::CacheCheck,
+                ActionStage::Queued => OperationStageFlags::Queued,
+                ActionStage::Executing => OperationStageFlags::Executing,
+                ActionStage::Completed(_) => OperationStageFlags::Completed,
+                ActionStage::CompletedFromCache(_) => OperationStageFlags::Completed,
+            };
+            if !filter.stages.intersects(stage_flag) {
+                return false;
+            }
+        }
+    }
+
+    true
+}
+
+/// SimpleSchedulerStateManager is responsible for maintaining the state of the scheduler.
+/// Scheduler state includes the actions that are queued, active, and recently completed.
+/// It also includes the workers that are available to execute actions based on allocation
+/// strategy.
+pub struct SimpleSchedulerStateManager<T: AwaitedActionDb> {
+    /// Database for storing the state of all actions.
+    action_db: T,
+
+    /// Notify matching engine that work needs to be done.
+    tasks_change_notify: Arc<Notify>,
+
+    /// Maximum number of times a job can be retried.
+    // TODO(allada) This should be a scheduler decorator instead
+    // of always having it on every SimpleScheduler.
+    max_job_retries: usize,
+}
+
+impl<T: AwaitedActionDb> SimpleSchedulerStateManager<T> {
+    pub fn new(
+        tasks_change_notify: Arc<Notify>,
+        max_job_retries: usize,
+        action_db: T,
+    ) -> Arc<Self> {
+        Arc::new(Self {
+            action_db,
+            tasks_change_notify,
+            max_job_retries,
+        })
+    }
+
+    async fn inner_update_operation(
+        &self,
+        operation_id: &OperationId,
+        maybe_worker_id: Option<&WorkerId>,
+        action_stage_result: Result<ActionStage, Error>,
+    ) -> Result<(), Error> {
+        let mut last_err = None;
+        for _ in 0..MAX_UPDATE_RETRIES {
+            let maybe_awaited_action_subscriber = self
+                .action_db
+                .get_by_operation_id(operation_id)
+                .await
+                .err_tip(|| "In MemorySchedulerStateManager::update_operation")?;
+            let awaited_action_subscriber = match maybe_awaited_action_subscriber {
+                Some(sub) => sub,
+                // No action found. It is ok if the action was not found. It probably
+                // means that the action was dropped, but the worker was still
+                // processing it.
+                None => return Ok(()),
+            };
+
+            let mut awaited_action = awaited_action_subscriber.borrow();
+
+            // Make sure we don't update an action that is already completed.
+            if awaited_action.state().stage.is_finished() {
+                return Err(make_err!(
+                    Code::Internal,
+                    "Action {operation_id:?} is already completed with state {:?} - maybe_worker_id: {:?}",
+                    awaited_action.state().stage,
+                    maybe_worker_id,
+                ));
+            }
+
+            // Make sure the worker id matches the awaited action worker id.
+            // This might happen if the worker sending the update is not the
+            // worker that was assigned.
+            if awaited_action.worker_id().is_some()
+                && maybe_worker_id.is_some()
+                && maybe_worker_id != awaited_action.worker_id().as_ref()
+            {
+                let err = make_err!(
+                    Code::Internal,
+                    "Worker ids do not match - {:?} != {:?} for {:?}",
+                    maybe_worker_id,
+                    awaited_action.worker_id(),
+                    awaited_action,
+                );
+                event!(
+                    Level::ERROR,
+                    ?operation_id,
+                    ?maybe_worker_id,
+                    ?awaited_action,
+                    "{}",
+                    err.to_string(),
+                );
+                return Err(err);
+            }
+
+            let stage = match &action_stage_result {
+                Ok(stage) => stage.clone(),
+                Err(err) => {
+                    // Don't count a backpressure failure as an attempt for an action.
+                    let due_to_backpressure = err.code == Code::ResourceExhausted;
+                    if !due_to_backpressure {
+                        awaited_action.attempts += 1;
+                    }
+
+                    if awaited_action.attempts > self.max_job_retries {
+                        ActionStage::Completed(ActionResult {
+                            execution_metadata: ExecutionMetadata {
+                                worker: maybe_worker_id.map_or_else(String::default, |v| v.to_string()),
+                                ..ExecutionMetadata::default()
+                            },
+                            error: Some(err.clone().merge(make_err!(
+                                Code::Internal,
+                                "Job cancelled because it attempted to execute too many times and failed {}",
+                                format!("for operation_id: {operation_id}, maybe_worker_id: {maybe_worker_id:?}"),
+                            ))),
+                            ..ActionResult::default()
+                        })
+                    } else {
+                        ActionStage::Queued
+                    }
+                }
+            };
+            if matches!(stage, ActionStage::Queued) {
+                // If the action is queued, we need to unset the worker id regardless of
+                // which worker sent the update.
+                awaited_action.set_worker_id(None);
+            } else {
+                awaited_action.set_worker_id(maybe_worker_id.copied());
+            }
+            awaited_action.set_state(Arc::new(ActionState {
+                stage,
+                id: operation_id.clone(),
+            }));
+            awaited_action.increment_version();
+
+            let update_action_result = self
+                .action_db
+                .update_awaited_action(awaited_action)
+                .await
+                .err_tip(|| "In MemorySchedulerStateManager::update_operation");
+            if let Err(err) = update_action_result {
+                // We use Aborted to signal that the action was not
+                // updated due to the data being set was not the latest
+                // but can be retried.
+                if err.code == Code::Aborted {
+                    last_err = Some(err);
+                    continue;
+                } else {
+                    return Err(err);
+                }
+            }
+
+            self.tasks_change_notify.notify_one();
+            return Ok(());
+        }
+        match last_err {
+            Some(err) => Err(err),
+            None => Err(make_err!(
+                Code::Internal,
+                "Failed to update action after {} retries with no error set",
+                MAX_UPDATE_RETRIES,
+            )),
+        }
+    }
+
+    async fn inner_add_operation(
+        &self,
+        new_client_operation_id: ClientOperationId,
+        action_info: Arc<ActionInfo>,
+    ) -> Result<T::Subscriber, Error> {
+        let rx = self
+            .action_db
+            .add_action(new_client_operation_id, action_info)
+            .await
+            .err_tip(|| "In MemorySchedulerStateManager::add_operation")?;
+        self.tasks_change_notify.notify_one();
+        Ok(rx)
+    }
+
+    async fn inner_filter_operations<'a, F>(
+        &'a self,
+        filter: OperationFilter,
+        to_action_state_result: F,
+    ) -> Result<ActionStateResultStream<'a>, Error>
+    where
+        F: Fn(T::Subscriber) -> Box<dyn ActionStateResult> + Send + Sync + 'a,
+    {
+        fn sorted_awaited_action_state_for_flags(
+            stage: OperationStageFlags,
+        ) -> Option<SortedAwaitedActionState> {
+            match stage {
+                OperationStageFlags::CacheCheck => Some(SortedAwaitedActionState::CacheCheck),
+                OperationStageFlags::Queued => Some(SortedAwaitedActionState::Queued),
+                OperationStageFlags::Executing => Some(SortedAwaitedActionState::Executing),
+                OperationStageFlags::Completed => Some(SortedAwaitedActionState::Completed),
+                _ => None,
+            }
+        }
+
+        if let Some(operation_id) = &filter.operation_id {
+            return Ok(self
+                .action_db
+                .get_by_operation_id(operation_id)
+                .await
+                .err_tip(|| "In MemorySchedulerStateManager::filter_operations")?
+                .filter(|awaited_action_rx| {
+                    let awaited_action = awaited_action_rx.borrow();
+                    apply_filter_predicate(&awaited_action, &filter)
+                })
+                .map(|awaited_action| -> ActionStateResultStream {
+                    Box::pin(stream::once(async move {
+                        to_action_state_result(awaited_action)
+                    }))
+                })
+                .unwrap_or_else(|| Box::pin(stream::empty())));
+        }
+        if let Some(client_operation_id) = &filter.client_operation_id {
+            return Ok(self
+                .action_db
+                .get_awaited_action_by_id(client_operation_id)
+                .await
+                .err_tip(|| "In MemorySchedulerStateManager::filter_operations")?
+                .filter(|awaited_action_rx| {
+                    let awaited_action = awaited_action_rx.borrow();
+                    apply_filter_predicate(&awaited_action, &filter)
+                })
+                .map(|awaited_action| -> ActionStateResultStream {
+                    Box::pin(stream::once(async move {
+                        to_action_state_result(awaited_action)
+                    }))
+                })
+                .unwrap_or_else(|| Box::pin(stream::empty())));
+        }
+
+        let Some(sorted_awaited_action_state) =
+            sorted_awaited_action_state_for_flags(filter.stages)
+        else {
+            let mut all_items: Vec<T::Subscriber> = self
+                .action_db
+                .get_all_awaited_actions()
+                .await
+                .try_filter(|awaited_action_subscriber| {
+                    future::ready(apply_filter_predicate(
+                        &awaited_action_subscriber.borrow(),
+                        &filter,
+                    ))
+                })
+                .try_collect()
+                .await
+                .err_tip(|| "In MemorySchedulerStateManager::filter_operations")?;
+            match filter.order_by_priority_direction {
+                Some(OrderDirection::Asc) => {
+                    all_items.sort_unstable_by_key(|a| a.borrow().sort_key())
+                }
+                Some(OrderDirection::Desc) => {
+                    all_items.sort_unstable_by_key(|a| std::cmp::Reverse(a.borrow().sort_key()))
+                }
+                None => {}
+            }
+            return Ok(Box::pin(stream::iter(
+                all_items.into_iter().map(to_action_state_result),
+            )));
+        };
+
+        let desc = matches!(
+            filter.order_by_priority_direction,
+            Some(OrderDirection::Desc)
+        );
+        let filter = filter.clone();
+        let stream = self
+            .action_db
+            .get_range_of_actions(
+                sorted_awaited_action_state,
+                Bound::Unbounded,
+                Bound::Unbounded,
+                desc,
+            )
+            .await
+            .try_filter(move |sub| future::ready(apply_filter_predicate(&sub.borrow(), &filter)))
+            .map(move |result| -> Box<dyn ActionStateResult> {
+                result.map_or_else(
+                    |e| -> Box<dyn ActionStateResult> { Box::new(ErrorActionStateResult(e)) },
+                    |v| -> Box<dyn ActionStateResult> { to_action_state_result(v) },
+                )
+            });
+        Ok(Box::pin(stream))
+    }
+}
+
+#[async_trait]
+impl<T: AwaitedActionDb> ClientStateManager for SimpleSchedulerStateManager<T> {
+    async fn add_action(
+        &self,
+        client_operation_id: ClientOperationId,
+        action_info: Arc<ActionInfo>,
+    ) -> Result<Box<dyn ActionStateResult>, Error> {
+        let sub = self
+            .inner_add_operation(client_operation_id.clone(), action_info.clone())
+            .await?;
+
+        Ok(Box::new(ClientActionStateResult::new(sub)))
+    }
+
+    async fn filter_operations<'a>(
+        &'a self,
+        filter: OperationFilter,
+    ) -> Result<ActionStateResultStream<'a>, Error> {
+        self.inner_filter_operations(filter, move |rx| Box::new(ClientActionStateResult::new(rx)))
+            .await
+    }
+}
+
+#[async_trait]
+impl<T: AwaitedActionDb> WorkerStateManager for SimpleSchedulerStateManager<T> {
+    async fn update_operation(
+        &self,
+        operation_id: &OperationId,
+        worker_id: &WorkerId,
+        action_stage_result: Result<ActionStage, Error>,
+    ) -> Result<(), Error> {
+        self.inner_update_operation(operation_id, Some(worker_id), action_stage_result)
+            .await
+    }
+}
+
+#[async_trait]
+impl<T: AwaitedActionDb> MatchingEngineStateManager for SimpleSchedulerStateManager<T> {
+    async fn filter_operations<'a>(
+        &'a self,
+        filter: OperationFilter,
+    ) -> Result<ActionStateResultStream<'a>, Error> {
+        self.inner_filter_operations(filter, |rx| {
+            Box::new(MatchingEngineActionStateResult::new(rx))
+        })
+        .await
+    }
+
+    async fn assign_operation(
+        &self,
+        operation_id: &OperationId,
+        worker_id_or_reason_for_unassign: Result<&WorkerId, Error>,
+    ) -> Result<(), Error> {
+        let (maybe_worker_id, stage_result) = match worker_id_or_reason_for_unassign {
+            Ok(worker_id) => (Some(worker_id), Ok(ActionStage::Executing)),
+            Err(err) => (None, Err(err)),
+        };
+        self.inner_update_operation(operation_id, maybe_worker_id, stage_result)
+            .await
+    }
+
+    /// Register metrics with the registry.
+    fn register_metrics(self: Arc<Self>, registry: &mut Registry) {
+        // TODO(allada) We only register the metrics in one of the components instead of
+        // all three because it's a bit tricky to separate the metrics for each component.
+        registry.register_collector(Box::new(Collector::new(&self)));
+    }
+}
+
+impl<T: AwaitedActionDb> MetricsComponent for SimpleSchedulerStateManager<T> {
+    fn gather_metrics(&self, c: &mut CollectorState) {
+        c.publish("", &self.action_db, "");
+    }
+}
diff --git a/nativelink-scheduler/src/worker.rs b/nativelink-scheduler/src/worker.rs
index 475e9deac..dc9879f93 100644
--- a/nativelink-scheduler/src/worker.rs
+++ b/nativelink-scheduler/src/worker.rs
@@ -12,7 +12,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-use std::collections::HashSet;
+use std::collections::HashMap;
 use std::hash::{Hash, Hasher};
 use std::sync::Arc;
 use std::time::{SystemTime, UNIX_EPOCH};
@@ -21,7 +21,7 @@ use nativelink_error::{make_err, Code, Error, ResultExt};
 use nativelink_proto::com::github::trace_machina::nativelink::remote_execution::{
     update_for_worker, ConnectionResult, StartExecute, UpdateForWorker,
 };
-use nativelink_util::action_messages::{ActionInfo, WorkerId};
+use nativelink_util::action_messages::{ActionInfo, OperationId, WorkerId};
 use nativelink_util::metrics_utils::{
     CollectorState, CounterWithTime, FuncCounterWrapper, MetricsComponent,
 };
@@ -33,7 +33,7 @@ pub type WorkerTimestamp = u64;
 /// Notifications to send worker about a requested state change.
 pub enum WorkerUpdate {
     /// Requests that the worker begin executing this action.
-    RunAction(Arc<ActionInfo>),
+    RunAction((OperationId, Arc<ActionInfo>)),
 
     /// Request that the worker is no longer in the pool and may discard any jobs.
     Disconnect,
@@ -52,7 +52,7 @@ pub struct Worker {
     pub tx: UnboundedSender<UpdateForWorker>,
 
     /// The action info of the running actions on the worker
-    pub running_action_infos: HashSet<Arc<ActionInfo>>,
+    pub running_action_infos: HashMap<OperationId, Arc<ActionInfo>>,
 
     /// Timestamp of last time this worker had been communicated with.
     // Warning: Do not update this timestamp without updating the placement of the worker in
@@ -108,7 +108,7 @@ impl Worker {
             id,
             platform_properties,
             tx,
-            running_action_infos: HashSet::new(),
+            running_action_infos: HashMap::new(),
             last_update_timestamp: timestamp,
             is_paused: false,
             is_draining: false,
@@ -140,7 +140,9 @@ impl Worker {
     /// Notifies the worker of a requested state change.
pub fn notify_update(&mut self, worker_update: WorkerUpdate) -> Result<(), Error> { match worker_update { - WorkerUpdate::RunAction(action_info) => self.run_action(action_info), + WorkerUpdate::RunAction((operation_id, action_info)) => { + self.run_action(operation_id, action_info) + } WorkerUpdate::Disconnect => { self.metrics.notify_disconnect.inc(); send_msg_to_worker(&mut self.tx, update_for_worker::Update::Disconnect(())) @@ -157,13 +159,18 @@ impl Worker { }) } - fn run_action(&mut self, action_info: Arc) -> Result<(), Error> { + fn run_action( + &mut self, + operation_id: OperationId, + action_info: Arc, + ) -> Result<(), Error> { let tx = &mut self.tx; let worker_platform_properties = &mut self.platform_properties; let running_action_infos = &mut self.running_action_infos; self.metrics.run_action.wrap(move || { let action_info_clone = action_info.as_ref().clone(); - running_action_infos.insert(action_info.clone()); + let operation_id_string = operation_id.to_string(); + running_action_infos.insert(operation_id, action_info.clone()); reduce_platform_properties( worker_platform_properties, &action_info.platform_properties, @@ -172,18 +179,24 @@ impl Worker { tx, update_for_worker::Update::StartAction(StartExecute { execute_request: Some(action_info_clone.into()), - salt: *action_info.salt(), + operation_id: operation_id_string, queued_timestamp: Some(action_info.insert_timestamp.into()), }), ) }) } - pub fn complete_action(&mut self, action_info: &Arc) { - self.running_action_infos.remove(action_info); + pub(crate) fn complete_action(&mut self, operation_id: &OperationId) -> Result<(), Error> { + let action_info = self.running_action_infos.remove(operation_id).err_tip(|| { + format!( + "Worker {} tried to complete operation {} that was not running", + self.id, operation_id + ) + })?; self.restore_platform_properties(&action_info.platform_properties); self.is_paused = false; self.metrics.actions_completed.inc(); + Ok(()) } pub fn has_actions(&self) -> bool { @@ -277,8 +290,8 @@ impl MetricsComponent for Worker { "If this worker is draining.", vec![("worker_id".into(), format!("{}", self.id).into())], ); - for action_info in self.running_action_infos.iter() { - let action_name = action_info.unique_qualifier.action_name().to_string(); + for action_info in self.running_action_infos.values() { + let action_name = action_info.unique_qualifier.to_string(); c.publish_with_labels( "timeout", &action_info.timeout, @@ -303,12 +316,6 @@ impl MetricsComponent for Worker { "When this action was created.", vec![("digest".into(), action_name.clone().into())], ); - c.publish_with_labels( - "skip_cache_lookup", - &action_info.skip_cache_lookup, - "Weather this action should skip cache lookup.", - vec![("digest".into(), action_name.clone().into())], - ); } for (prop_name, prop_type_and_value) in &self.platform_properties.properties { match prop_type_and_value { diff --git a/nativelink-scheduler/src/worker_scheduler.rs b/nativelink-scheduler/src/worker_scheduler.rs index 74c35908e..a9317189c 100644 --- a/nativelink-scheduler/src/worker_scheduler.rs +++ b/nativelink-scheduler/src/worker_scheduler.rs @@ -16,7 +16,7 @@ use std::sync::Arc; use async_trait::async_trait; use nativelink_error::Error; -use nativelink_util::action_messages::{ActionInfoHashKey, ActionStage, WorkerId}; +use nativelink_util::action_messages::{ActionStage, OperationId, WorkerId}; use nativelink_util::metrics_utils::Registry; use crate::platform_property_manager::PlatformPropertyManager; @@ -36,7 +36,7 @@ pub trait WorkerScheduler: 
Sync + Send + Unpin { async fn update_action( &self, worker_id: &WorkerId, - action_info_hash_key: ActionInfoHashKey, + operation_id: &OperationId, action_stage: Result, ) -> Result<(), Error>; @@ -48,14 +48,14 @@ pub trait WorkerScheduler: Sync + Send + Unpin { ) -> Result<(), Error>; /// Removes worker from pool and reschedule any tasks that might be running on it. - async fn remove_worker(&self, worker_id: WorkerId); + async fn remove_worker(&self, worker_id: &WorkerId) -> Result<(), Error>; /// Removes timed out workers from the pool. This is called periodically by an /// external source. async fn remove_timedout_workers(&self, now_timestamp: WorkerTimestamp) -> Result<(), Error>; /// Sets if the worker is draining or not. - async fn set_drain_worker(&self, worker_id: WorkerId, is_draining: bool) -> Result<(), Error>; + async fn set_drain_worker(&self, worker_id: &WorkerId, is_draining: bool) -> Result<(), Error>; /// Register the metrics for the worker scheduler. fn register_metrics(self: Arc, _registry: &mut Registry) {} diff --git a/nativelink-scheduler/tests/action_messages_test.rs b/nativelink-scheduler/tests/action_messages_test.rs index 9eddaabe1..f5cac5582 100644 --- a/nativelink-scheduler/tests/action_messages_test.rs +++ b/nativelink-scheduler/tests/action_messages_test.rs @@ -12,9 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::collections::{BTreeSet, HashMap}; -use std::sync::Arc; -use std::time::{Duration, SystemTime}; +use std::collections::HashMap; +use std::time::SystemTime; use nativelink_error::Error; use nativelink_macro::nativelink_test; @@ -22,37 +21,28 @@ use nativelink_proto::build::bazel::remote::execution::v2::ExecuteResponse; use nativelink_proto::google::longrunning::{operation, Operation}; use nativelink_proto::google::rpc::Status; use nativelink_util::action_messages::{ - ActionInfo, ActionInfoHashKey, ActionResult, ActionStage, ActionState, ExecutionMetadata, - OperationId, + ActionResult, ActionStage, ActionState, ActionUniqueKey, ActionUniqueQualifier, + ClientOperationId, ExecutionMetadata, OperationId, }; use nativelink_util::common::DigestInfo; use nativelink_util::digest_hasher::DigestHasherFunc; -use nativelink_util::platform_properties::PlatformProperties; use pretty_assertions::assert_eq; -const NOW_TIME: u64 = 10000; - -fn make_system_time(add_time: u64) -> SystemTime { - SystemTime::UNIX_EPOCH - .checked_add(Duration::from_secs(NOW_TIME + add_time)) - .unwrap() -} - #[nativelink_test] async fn action_state_any_url_test() -> Result<(), Error> { - let unique_qualifier = ActionInfoHashKey { + let unique_qualifier = ActionUniqueQualifier::Cachable(ActionUniqueKey { instance_name: "foo_instance".to_string(), digest_function: DigestHasherFunc::Sha256, digest: DigestInfo::new([1u8; 32], 5), - salt: 0, - }; - let id = OperationId::new(unique_qualifier); + }); + let client_id = ClientOperationId::new(unique_qualifier.clone()); + let operation_id = OperationId::new(unique_qualifier); let action_state = ActionState { - id, + id: operation_id.clone(), // Result is only populated if has_action_result. 
stage: ActionStage::Completed(ActionResult::default()), }; - let operation: Operation = action_state.clone().into(); + let operation: Operation = action_state.as_operation(client_id); match &operation.result { Some(operation::Result::Response(any)) => assert_eq!( @@ -62,7 +52,7 @@ async fn action_state_any_url_test() -> Result<(), Error> { other => panic!("Expected Some(Result(Any)), got: {other:?}"), } - let action_state_round_trip: ActionState = operation.try_into()?; + let action_state_round_trip = ActionState::try_from_operation(operation, operation_id)?; assert_eq!(action_state, action_state_round_trip); Ok(()) @@ -101,115 +91,3 @@ async fn execute_response_status_message_is_some_on_success_test() -> Result<(), Ok(()) } - -#[nativelink_test] -async fn highest_priority_action_first() -> Result<(), Error> { - const INSTANCE_NAME: &str = "foobar_instance_name"; - - let high_priority_action = Arc::new(ActionInfo { - command_digest: DigestInfo::new([0u8; 32], 0), - input_root_digest: DigestInfo::new([0u8; 32], 0), - timeout: Duration::MAX, - platform_properties: PlatformProperties { - properties: HashMap::new(), - }, - priority: 1000, - load_timestamp: SystemTime::UNIX_EPOCH, - insert_timestamp: SystemTime::UNIX_EPOCH, - unique_qualifier: ActionInfoHashKey { - instance_name: INSTANCE_NAME.to_string(), - digest_function: DigestHasherFunc::Sha256, - digest: DigestInfo::new([0u8; 32], 0), - salt: 0, - }, - skip_cache_lookup: true, - }); - let lowest_priority_action = Arc::new(ActionInfo { - command_digest: DigestInfo::new([0u8; 32], 0), - input_root_digest: DigestInfo::new([0u8; 32], 0), - timeout: Duration::MAX, - platform_properties: PlatformProperties { - properties: HashMap::new(), - }, - priority: 0, - load_timestamp: SystemTime::UNIX_EPOCH, - insert_timestamp: SystemTime::UNIX_EPOCH, - unique_qualifier: ActionInfoHashKey { - instance_name: INSTANCE_NAME.to_string(), - digest_function: DigestHasherFunc::Sha256, - digest: DigestInfo::new([1u8; 32], 0), - salt: 0, - }, - skip_cache_lookup: true, - }); - let mut action_set = BTreeSet::>::new(); - action_set.insert(lowest_priority_action.clone()); - action_set.insert(high_priority_action.clone()); - - assert_eq!( - vec![high_priority_action, lowest_priority_action], - action_set - .iter() - .rev() - .cloned() - .collect::>>() - ); - - Ok(()) -} - -#[nativelink_test] -async fn equal_priority_earliest_first() -> Result<(), Error> { - const INSTANCE_NAME: &str = "foobar_instance_name"; - - let first_action = Arc::new(ActionInfo { - command_digest: DigestInfo::new([0u8; 32], 0), - input_root_digest: DigestInfo::new([0u8; 32], 0), - timeout: Duration::MAX, - platform_properties: PlatformProperties { - properties: HashMap::new(), - }, - priority: 0, - load_timestamp: SystemTime::UNIX_EPOCH, - insert_timestamp: SystemTime::UNIX_EPOCH, - unique_qualifier: ActionInfoHashKey { - instance_name: INSTANCE_NAME.to_string(), - digest_function: DigestHasherFunc::Sha256, - digest: DigestInfo::new([0u8; 32], 0), - salt: 0, - }, - skip_cache_lookup: true, - }); - let current_action = Arc::new(ActionInfo { - command_digest: DigestInfo::new([0u8; 32], 0), - input_root_digest: DigestInfo::new([0u8; 32], 0), - timeout: Duration::MAX, - platform_properties: PlatformProperties { - properties: HashMap::new(), - }, - priority: 0, - load_timestamp: SystemTime::UNIX_EPOCH, - insert_timestamp: make_system_time(0), - unique_qualifier: ActionInfoHashKey { - instance_name: INSTANCE_NAME.to_string(), - digest_function: DigestHasherFunc::Sha256, - digest: 
DigestInfo::new([1u8; 32], 0), - salt: 0, - }, - skip_cache_lookup: true, - }); - let mut action_set = BTreeSet::>::new(); - action_set.insert(current_action.clone()); - action_set.insert(first_action.clone()); - - assert_eq!( - vec![first_action, current_action], - action_set - .iter() - .rev() - .cloned() - .collect::>>() - ); - - Ok(()) -} diff --git a/nativelink-scheduler/tests/cache_lookup_scheduler_test.rs b/nativelink-scheduler/tests/cache_lookup_scheduler_test.rs index 709313078..ad39312d3 100644 --- a/nativelink-scheduler/tests/cache_lookup_scheduler_test.rs +++ b/nativelink-scheduler/tests/cache_lookup_scheduler_test.rs @@ -27,10 +27,12 @@ use nativelink_macro::nativelink_test; use nativelink_proto::build::bazel::remote::execution::v2::ActionResult as ProtoActionResult; use nativelink_scheduler::action_scheduler::ActionScheduler; use nativelink_scheduler::cache_lookup_scheduler::CacheLookupScheduler; +use nativelink_scheduler::default_action_listener::DefaultActionListener; use nativelink_scheduler::platform_property_manager::PlatformPropertyManager; use nativelink_store::memory_store::MemoryStore; use nativelink_util::action_messages::{ - ActionInfoHashKey, ActionResult, ActionStage, ActionState, OperationId, + ActionResult, ActionStage, ActionState, ActionUniqueKey, ActionUniqueQualifier, + ClientOperationId, OperationId, }; use nativelink_util::common::DigestInfo; use nativelink_util::digest_hasher::DigestHasherFunc; @@ -85,42 +87,55 @@ async fn platform_property_manager_call_passed() -> Result<(), Error> { #[nativelink_test] async fn add_action_handles_skip_cache() -> Result<(), Error> { let context = make_cache_scheduler()?; - let action_info = make_base_action_info(UNIX_EPOCH); + let action_info = make_base_action_info(UNIX_EPOCH, DigestInfo::zero_digest()); let action_result = ProtoActionResult::from(ActionResult::default()); context .ac_store - .update_oneshot(*action_info.digest(), action_result.encode_to_vec().into()) + .update_oneshot(action_info.digest(), action_result.encode_to_vec().into()) .await?; let (_forward_watch_channel_tx, forward_watch_channel_rx) = watch::channel(Arc::new(ActionState { id: OperationId::new(action_info.unique_qualifier.clone()), stage: ActionStage::Queued, })); + let ActionUniqueQualifier::Cachable(action_key) = action_info.unique_qualifier.clone() else { + panic!("This test should be testing when item was cached first"); + }; let mut skip_cache_action = action_info.clone(); - skip_cache_action.skip_cache_lookup = true; + skip_cache_action.unique_qualifier = ActionUniqueQualifier::Uncachable(action_key); + let client_operation_id = ClientOperationId::new(action_info.unique_qualifier.clone()); let _ = join!( - context.cache_scheduler.add_action(skip_cache_action), + context + .cache_scheduler + .add_action(client_operation_id.clone(), skip_cache_action), context .mock_scheduler - .expect_add_action(Ok(forward_watch_channel_rx)) + .expect_add_action(Ok(Box::pin(DefaultActionListener::new( + client_operation_id, + forward_watch_channel_rx + )))) ); Ok(()) } #[nativelink_test] -async fn find_existing_action_call_passed() -> Result<(), Error> { +async fn find_by_client_operation_id_call_passed() -> Result<(), Error> { let context = make_cache_scheduler()?; - let action_name = ActionInfoHashKey { - instance_name: "instance".to_string(), - digest_function: DigestHasherFunc::Sha256, - digest: DigestInfo::new([8; 32], 1), - salt: 1000, - }; - let (actual_result, actual_action_name) = join!( - 
context.cache_scheduler.find_existing_action(&action_name), - context.mock_scheduler.expect_find_existing_action(None), + let client_operation_id = + ClientOperationId::new(ActionUniqueQualifier::Uncachable(ActionUniqueKey { + instance_name: "instance".to_string(), + digest_function: DigestHasherFunc::Sha256, + digest: DigestInfo::new([8; 32], 1), + })); + let (actual_result, actual_client_id) = join!( + context + .cache_scheduler + .find_by_client_operation_id(&client_operation_id), + context + .mock_scheduler + .expect_find_by_client_operation_id(Ok(None)), ); - assert_eq!(true, actual_result.is_none()); - assert_eq!(action_name, actual_action_name); + assert_eq!(true, actual_result.unwrap().is_none()); + assert_eq!(client_operation_id, actual_client_id); Ok(()) } diff --git a/nativelink-scheduler/tests/property_modifier_scheduler_test.rs b/nativelink-scheduler/tests/property_modifier_scheduler_test.rs index 96af1fb27..56118504f 100644 --- a/nativelink-scheduler/tests/property_modifier_scheduler_test.rs +++ b/nativelink-scheduler/tests/property_modifier_scheduler_test.rs @@ -26,9 +26,13 @@ use nativelink_config::schedulers::{PlatformPropertyAddition, PropertyModificati use nativelink_error::Error; use nativelink_macro::nativelink_test; use nativelink_scheduler::action_scheduler::ActionScheduler; +use nativelink_scheduler::default_action_listener::DefaultActionListener; use nativelink_scheduler::platform_property_manager::PlatformPropertyManager; use nativelink_scheduler::property_modifier_scheduler::PropertyModifierScheduler; -use nativelink_util::action_messages::{ActionInfoHashKey, ActionStage, ActionState, OperationId}; +use nativelink_util::action_messages::{ + ActionStage, ActionState, ActionUniqueKey, ActionUniqueQualifier, ClientOperationId, + OperationId, +}; use nativelink_util::common::DigestInfo; use nativelink_util::digest_hasher::DigestHasherFunc; use nativelink_util::platform_properties::PlatformPropertyValue; @@ -66,7 +70,7 @@ async fn add_action_adds_property() -> Result<(), Error> { name: name.clone(), value: value.clone(), })]); - let action_info = make_base_action_info(UNIX_EPOCH); + let action_info = make_base_action_info(UNIX_EPOCH, DigestInfo::zero_digest()); let (_forward_watch_channel_tx, forward_watch_channel_rx) = watch::channel(Arc::new(ActionState { id: OperationId::new(action_info.unique_qualifier.clone()), @@ -76,15 +80,22 @@ async fn add_action_adds_property() -> Result<(), Error> { name.clone(), PropertyType::exact, )]))); - let (_, _, action_info) = join!( - context.modifier_scheduler.add_action(action_info), + let client_operation_id = ClientOperationId::new(action_info.unique_qualifier.clone()); + let (_, _, (passed_client_operation_id, action_info)) = join!( + context + .modifier_scheduler + .add_action(client_operation_id.clone(), action_info), context .mock_scheduler .expect_get_platform_property_manager(Ok(platform_property_manager)), context .mock_scheduler - .expect_add_action(Ok(forward_watch_channel_rx)), + .expect_add_action(Ok(Box::pin(DefaultActionListener::new( + client_operation_id.clone(), + forward_watch_channel_rx + )))), ); + assert_eq!(client_operation_id, passed_client_operation_id); assert_eq!( HashMap::from([(name, PlatformPropertyValue::Exact(value))]), action_info.platform_properties.properties @@ -102,7 +113,7 @@ async fn add_action_overwrites_property() -> Result<(), Error> { name: name.clone(), value: replaced_value.clone(), })]); - let mut action_info = make_base_action_info(UNIX_EPOCH); + let mut action_info = 
make_base_action_info(UNIX_EPOCH, DigestInfo::zero_digest()); action_info .platform_properties .properties @@ -116,15 +127,22 @@ async fn add_action_overwrites_property() -> Result<(), Error> { name.clone(), PropertyType::exact, )]))); - let (_, _, action_info) = join!( - context.modifier_scheduler.add_action(action_info), + let client_operation_id = ClientOperationId::new(action_info.unique_qualifier.clone()); + let (_, _, (passed_client_operation_id, action_info)) = join!( + context + .modifier_scheduler + .add_action(client_operation_id.clone(), action_info), context .mock_scheduler .expect_get_platform_property_manager(Ok(platform_property_manager)), context .mock_scheduler - .expect_add_action(Ok(forward_watch_channel_rx)), + .expect_add_action(Ok(Box::pin(DefaultActionListener::new( + client_operation_id.clone(), + forward_watch_channel_rx + )))), ); + assert_eq!(client_operation_id, passed_client_operation_id); assert_eq!( HashMap::from([(name, PlatformPropertyValue::Exact(replaced_value))]), action_info.platform_properties.properties @@ -143,7 +161,7 @@ async fn add_action_property_added_after_remove() -> Result<(), Error> { value: value.clone(), }), ]); - let action_info = make_base_action_info(UNIX_EPOCH); + let action_info = make_base_action_info(UNIX_EPOCH, DigestInfo::zero_digest()); let (_forward_watch_channel_tx, forward_watch_channel_rx) = watch::channel(Arc::new(ActionState { id: OperationId::new(action_info.unique_qualifier.clone()), @@ -153,15 +171,22 @@ async fn add_action_property_added_after_remove() -> Result<(), Error> { name.clone(), PropertyType::exact, )]))); - let (_, _, action_info) = join!( - context.modifier_scheduler.add_action(action_info), + let client_operation_id = ClientOperationId::new(action_info.unique_qualifier.clone()); + let (_, _, (passed_client_operation_id, action_info)) = join!( + context + .modifier_scheduler + .add_action(client_operation_id.clone(), action_info), context .mock_scheduler .expect_get_platform_property_manager(Ok(platform_property_manager)), context .mock_scheduler - .expect_add_action(Ok(forward_watch_channel_rx)), + .expect_add_action(Ok(Box::pin(DefaultActionListener::new( + client_operation_id.clone(), + forward_watch_channel_rx + )))), ); + assert_eq!(client_operation_id, passed_client_operation_id); assert_eq!( HashMap::from([(name, PlatformPropertyValue::Exact(value))]), action_info.platform_properties.properties @@ -180,7 +205,7 @@ async fn add_action_property_remove_after_add() -> Result<(), Error> { }), PropertyModification::remove(name.clone()), ]); - let action_info = make_base_action_info(UNIX_EPOCH); + let action_info = make_base_action_info(UNIX_EPOCH, DigestInfo::zero_digest()); let (_forward_watch_channel_tx, forward_watch_channel_rx) = watch::channel(Arc::new(ActionState { id: OperationId::new(action_info.unique_qualifier.clone()), @@ -190,15 +215,22 @@ async fn add_action_property_remove_after_add() -> Result<(), Error> { name, PropertyType::exact, )]))); - let (_, _, action_info) = join!( - context.modifier_scheduler.add_action(action_info), + let client_operation_id = ClientOperationId::new(action_info.unique_qualifier.clone()); + let (_, _, (passed_client_operation_id, action_info)) = join!( + context + .modifier_scheduler + .add_action(client_operation_id.clone(), action_info), context .mock_scheduler .expect_get_platform_property_manager(Ok(platform_property_manager)), context .mock_scheduler - .expect_add_action(Ok(forward_watch_channel_rx)), + .expect_add_action(Ok(Box::pin(DefaultActionListener::new( 
+ client_operation_id.clone(), + forward_watch_channel_rx + )))), ); + assert_eq!(client_operation_id, passed_client_operation_id); assert_eq!( HashMap::from([]), action_info.platform_properties.properties @@ -211,7 +243,7 @@ async fn add_action_property_remove() -> Result<(), Error> { let name = "name".to_string(); let value = "value".to_string(); let context = make_modifier_scheduler(vec![PropertyModification::remove(name.clone())]); - let mut action_info = make_base_action_info(UNIX_EPOCH); + let mut action_info = make_base_action_info(UNIX_EPOCH, DigestInfo::zero_digest()); action_info .platform_properties .properties @@ -222,15 +254,22 @@ async fn add_action_property_remove() -> Result<(), Error> { stage: ActionStage::Queued, })); let platform_property_manager = Arc::new(PlatformPropertyManager::new(HashMap::new())); - let (_, _, action_info) = join!( - context.modifier_scheduler.add_action(action_info), + let client_operation_id = ClientOperationId::new(action_info.unique_qualifier.clone()); + let (_, _, (passed_client_operation_id, action_info)) = join!( + context + .modifier_scheduler + .add_action(client_operation_id.clone(), action_info), context .mock_scheduler .expect_get_platform_property_manager(Ok(platform_property_manager)), context .mock_scheduler - .expect_add_action(Ok(forward_watch_channel_rx)), + .expect_add_action(Ok(Box::pin(DefaultActionListener::new( + client_operation_id.clone(), + forward_watch_channel_rx + )))), ); + assert_eq!(client_operation_id, passed_client_operation_id); assert_eq!( HashMap::from([]), action_info.platform_properties.properties @@ -239,22 +278,23 @@ async fn add_action_property_remove() -> Result<(), Error> { } #[nativelink_test] -async fn find_existing_action_call_passed() -> Result<(), Error> { +async fn find_by_client_operation_id_call_passed() -> Result<(), Error> { let context = make_modifier_scheduler(vec![]); - let action_name = ActionInfoHashKey { + let operation_id = ClientOperationId::new(ActionUniqueQualifier::Uncachable(ActionUniqueKey { instance_name: "instance".to_string(), digest_function: DigestHasherFunc::Sha256, digest: DigestInfo::new([8; 32], 1), - salt: 1000, - }; - let (actual_result, actual_action_name) = join!( + })); + let (actual_result, actual_operation_id) = join!( context .modifier_scheduler - .find_existing_action(&action_name), - context.mock_scheduler.expect_find_existing_action(None), + .find_by_client_operation_id(&operation_id), + context + .mock_scheduler + .expect_find_by_client_operation_id(Ok(None)), ); - assert_eq!(true, actual_result.is_none()); - assert_eq!(action_name, actual_action_name); + assert_eq!(true, actual_result.unwrap().is_none()); + assert_eq!(operation_id, actual_operation_id); Ok(()) } diff --git a/nativelink-scheduler/tests/simple_scheduler_test.rs b/nativelink-scheduler/tests/simple_scheduler_test.rs index 1b4bfcb41..e2be299ac 100644 --- a/nativelink-scheduler/tests/simple_scheduler_test.rs +++ b/nativelink-scheduler/tests/simple_scheduler_test.rs @@ -13,29 +13,33 @@ // limitations under the License. 
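
A pattern worth calling out in these test updates: every `add_action` call site now mints a `ClientOperationId` up front, passes it in, and asserts that the scheduler hands the very same ID back on the returned listener. A minimal sketch of that round-trip contract, using stand-in types rather than the real nativelink ones (the `Scheduler` and `Listener` shapes below, and the string-backed ID, are illustrative assumptions):

// Stand-in types; the real ClientOperationId and ActionListener live in
// nativelink-util and nativelink-scheduler.
#[derive(Clone, Debug, PartialEq, Eq)]
struct ClientOperationId(String);

struct Listener {
    client_operation_id: ClientOperationId,
}

impl Listener {
    // Mirrors the `client_operation_id()` accessor the tests call.
    fn client_operation_id(&self) -> &ClientOperationId {
        &self.client_operation_id
    }
}

struct Scheduler;

impl Scheduler {
    // The caller-chosen ID is threaded through and must come back unchanged.
    fn add_action(&self, client_operation_id: ClientOperationId) -> Listener {
        Listener {
            client_operation_id,
        }
    }
}

fn main() {
    let scheduler = Scheduler;
    let id = ClientOperationId("client-chosen-id".to_string());
    let listener = scheduler.add_action(id.clone());
    // The round-trip the tests assert via
    // `assert_eq!(client_operation_id, passed_client_operation_id)`.
    assert_eq!(&id, listener.client_operation_id());
}

The ID being client-chosen is what lets a caller later re-attach via `find_by_client_operation_id` instead of re-deriving a hash key, which the tests below exercise.
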
use std::collections::HashMap; +use std::pin::Pin; use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::Arc; use std::time::{Duration, SystemTime, UNIX_EPOCH}; +use futures::poll; +use futures::task::Poll; use nativelink_error::{make_err, Code, Error, ResultExt}; use nativelink_macro::nativelink_test; use nativelink_proto::build::bazel::remote::execution::v2::{digest_function, ExecuteRequest}; use nativelink_proto::com::github::trace_machina::nativelink::remote_execution::{ update_for_worker, ConnectionResult, StartExecute, UpdateForWorker, }; -use nativelink_scheduler::action_scheduler::ActionScheduler; +use nativelink_scheduler::action_scheduler::{ActionListener, ActionScheduler}; use nativelink_scheduler::simple_scheduler::SimpleScheduler; use nativelink_scheduler::worker::Worker; use nativelink_scheduler::worker_scheduler::WorkerScheduler; use nativelink_util::action_messages::{ - ActionInfoHashKey, ActionResult, ActionStage, ActionState, DirectoryInfo, ExecutionMetadata, - FileInfo, NameOrPath, OperationId, SymlinkInfo, WorkerId, INTERNAL_ERROR_EXIT_CODE, + ActionResult, ActionStage, ActionState, ActionUniqueKey, ActionUniqueQualifier, + ClientOperationId, DirectoryInfo, ExecutionMetadata, FileInfo, NameOrPath, OperationId, + SymlinkInfo, WorkerId, INTERNAL_ERROR_EXIT_CODE, }; use nativelink_util::common::DigestInfo; use nativelink_util::digest_hasher::DigestHasherFunc; use nativelink_util::platform_properties::{PlatformProperties, PlatformPropertyValue}; use pretty_assertions::assert_eq; -use tokio::sync::{mpsc, watch}; +use tokio::sync::mpsc; use utils::scheduler_utils::{make_base_action_info, INSTANCE_NAME}; use uuid::Uuid; @@ -43,6 +47,45 @@ mod utils { pub(crate) mod scheduler_utils; } +fn update_eq(expected: UpdateForWorker, actual: UpdateForWorker, ignore_id: bool) -> bool { + let Some(expected_update) = expected.update else { + return actual.update.is_none(); + }; + let Some(actual_update) = actual.update else { + return false; + }; + match actual_update { + update_for_worker::Update::Disconnect(()) => { + matches!(expected_update, update_for_worker::Update::Disconnect(())) + } + update_for_worker::Update::KeepAlive(()) => { + matches!(expected_update, update_for_worker::Update::KeepAlive(())) + } + update_for_worker::Update::StartAction(actual_update) => match expected_update { + update_for_worker::Update::StartAction(mut expected_update) => { + if ignore_id { + expected_update + .operation_id + .clone_from(&actual_update.operation_id); + } + expected_update == actual_update + } + _ => false, + }, + update_for_worker::Update::KillOperationRequest(actual_update) => match expected_update { + update_for_worker::Update::KillOperationRequest(expected_update) => { + expected_update == actual_update + } + _ => false, + }, + update_for_worker::Update::ConnectionResult(actual_update) => match expected_update { + update_for_worker::Update::ConnectionResult(expected_update) => { + expected_update == actual_update + } + _ => false, + }, + } +} async fn verify_initial_connection_message( worker_id: WorkerId, rx: &mut mpsc::UnboundedReceiver, @@ -88,11 +131,11 @@ async fn setup_action( action_digest: DigestInfo, platform_properties: PlatformProperties, insert_timestamp: SystemTime, -) -> Result>, Error> { - let mut action_info = make_base_action_info(insert_timestamp); +) -> Result>, Error> { + let mut action_info = make_base_action_info(insert_timestamp, action_digest); action_info.platform_properties = platform_properties; - action_info.unique_qualifier.digest = action_digest; - 
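
The `update_eq` helper above exists because operation IDs are now minted inside the scheduler, so a byte-for-byte `assert_eq!` on `UpdateForWorker` would be flaky. The same normalize-then-compare idea, reduced to a toy message type (this struct is a stand-in for illustration, not the generated proto):

#[derive(Clone, Debug, PartialEq)]
struct StartExecute {
    operation_id: String,
    action_digest: String,
}

// Compare two messages, optionally blanking the randomly generated ID,
// the same way `update_eq(expected, actual, /*ignore_id=*/ true)` does.
fn start_execute_eq(mut expected: StartExecute, actual: &StartExecute, ignore_id: bool) -> bool {
    if ignore_id {
        expected.operation_id.clone_from(&actual.operation_id);
    }
    &expected == actual
}

fn main() {
    let expected = StartExecute {
        operation_id: "Unknown Generated internally".to_string(),
        action_digest: "9999.../512".to_string(),
    };
    let actual = StartExecute {
        operation_id: "some-random-uuid".to_string(),
        ..expected.clone()
    };
    assert!(start_execute_eq(expected.clone(), &actual, true));
    assert!(!start_execute_eq(expected, &actual, false));
}
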
let result = scheduler.add_action(action_info).await; + let client_id = ClientOperationId::new(action_info.unique_qualifier.clone()); + let result = scheduler.add_action(client_id, action_info).await; tokio::task::yield_now().await; // Allow task<->worker matcher to run. result } @@ -103,7 +146,7 @@ const WORKER_TIMEOUT_S: u64 = 100; async fn basic_add_action_with_one_worker_test() -> Result<(), Error> { let worker_id: WorkerId = WorkerId(Uuid::new_v4()); - let scheduler = SimpleScheduler::new_with_callback( + let (scheduler, _worker_scheduler) = SimpleScheduler::new_with_callback( &nativelink_config::schedulers::SimpleScheduler::default(), || async move {}, ); @@ -112,13 +155,14 @@ async fn basic_add_action_with_one_worker_test() -> Result<(), Error> { let mut rx_from_worker = setup_new_worker(&scheduler, worker_id, PlatformProperties::default()).await?; let insert_timestamp = make_system_time(1); - let mut client_rx = setup_action( + let mut action_listener = setup_action( &scheduler, action_digest, PlatformProperties::default(), insert_timestamp, ) - .await?; + .await + .unwrap(); { // Worker should have been sent an execute command. @@ -126,21 +170,21 @@ async fn basic_add_action_with_one_worker_test() -> Result<(), Error> { update: Some(update_for_worker::Update::StartAction(StartExecute { execute_request: Some(ExecuteRequest { instance_name: INSTANCE_NAME.to_string(), - skip_cache_lookup: true, action_digest: Some(action_digest.into()), digest_function: digest_function::Value::Sha256.into(), ..Default::default() }), - salt: 0, + operation_id: "Unknown Generated internally".to_string(), queued_timestamp: Some(insert_timestamp.into()), })), }; let msg_for_worker = rx_from_worker.recv().await.unwrap(); - assert_eq!(msg_for_worker, expected_msg_for_worker); + // Operation ID is random so we ignore it. + assert!(update_eq(expected_msg_for_worker, msg_for_worker, true)); } { // Client should get notification saying it's being executed. - let action_state = client_rx.borrow_and_update(); + let action_state = action_listener.changed().await.unwrap(); let expected_action_state = ActionState { // Name is a random string, so we ignore it and just make it the same. id: action_state.id.clone(), @@ -156,7 +200,7 @@ async fn basic_add_action_with_one_worker_test() -> Result<(), Error> { async fn find_executing_action() -> Result<(), Error> { let worker_id: WorkerId = WorkerId(Uuid::new_v4()); - let scheduler = SimpleScheduler::new_with_callback( + let (scheduler, _worker_scheduler) = SimpleScheduler::new_with_callback( &nativelink_config::schedulers::SimpleScheduler::default(), || async move {}, ); @@ -165,21 +209,23 @@ async fn find_executing_action() -> Result<(), Error> { let mut rx_from_worker = setup_new_worker(&scheduler, worker_id, PlatformProperties::default()).await?; let insert_timestamp = make_system_time(1); - let client_rx = setup_action( + let action_listener = setup_action( &scheduler, action_digest, PlatformProperties::default(), insert_timestamp, ) - .await?; + .await + .unwrap(); + let client_operation_id = action_listener.client_operation_id().clone(); // Drop our receiver and look up a new one. 
- let unique_qualifier = client_rx.borrow().id.unique_qualifier.clone(); - drop(client_rx); - let mut client_rx = scheduler - .find_existing_action(&unique_qualifier) + drop(action_listener); + let mut action_listener = scheduler + .find_by_client_operation_id(&client_operation_id) .await - .err_tip(|| "Action not found")?; + .expect("Action not found") + .unwrap(); { // Worker should have been sent an execute command. @@ -187,21 +233,21 @@ async fn find_executing_action() -> Result<(), Error> { update: Some(update_for_worker::Update::StartAction(StartExecute { execute_request: Some(ExecuteRequest { instance_name: INSTANCE_NAME.to_string(), - skip_cache_lookup: true, action_digest: Some(action_digest.into()), digest_function: digest_function::Value::Sha256.into(), ..Default::default() }), - salt: 0, + operation_id: "Unknown Generated internally".to_string(), queued_timestamp: Some(insert_timestamp.into()), })), }; let msg_for_worker = rx_from_worker.recv().await.unwrap(); - assert_eq!(msg_for_worker, expected_msg_for_worker); + // Operation ID is random so we ignore it. + assert!(update_eq(expected_msg_for_worker, msg_for_worker, true)); } { // Client should get notification saying it's being executed. - let action_state = client_rx.borrow_and_update(); + let action_state = action_listener.changed().await.unwrap(); let expected_action_state = ActionState { // Name is a random string, so we ignore it and just make it the same. id: action_state.id.clone(), @@ -217,7 +263,7 @@ async fn find_executing_action() -> Result<(), Error> { async fn remove_worker_reschedules_multiple_running_job_test() -> Result<(), Error> { let worker_id1: WorkerId = WorkerId(Uuid::new_v4()); let worker_id2: WorkerId = WorkerId(Uuid::new_v4()); - let scheduler = SimpleScheduler::new_with_callback( + let (scheduler, _worker_scheduler) = SimpleScheduler::new_with_callback( &nativelink_config::schedulers::SimpleScheduler { worker_timeout_s: WORKER_TIMEOUT_S, ..Default::default() @@ -230,7 +276,7 @@ async fn remove_worker_reschedules_multiple_running_job_test() -> Result<(), Err let mut rx_from_worker1 = setup_new_worker(&scheduler, worker_id1, PlatformProperties::default()).await?; let insert_timestamp1 = make_system_time(1); - let mut client_rx1 = setup_action( + let mut client1_action_listener = setup_action( &scheduler, action_digest1, PlatformProperties::default(), @@ -238,7 +284,7 @@ async fn remove_worker_reschedules_multiple_running_job_test() -> Result<(), Err ) .await?; let insert_timestamp2 = make_system_time(2); - let mut client_rx2 = setup_action( + let mut client2_action_listener = setup_action( &scheduler, action_digest2, PlatformProperties::default(), @@ -246,83 +292,87 @@ async fn remove_worker_reschedules_multiple_running_job_test() -> Result<(), Err ) .await?; - let unique_qualifier = ActionInfoHashKey { - instance_name: "".to_string(), - digest_function: DigestHasherFunc::Sha256, - digest: DigestInfo::zero_digest(), - salt: 0, - }; - - let id = OperationId::new(unique_qualifier); - let mut expected_action_state1 = ActionState { - // Name is a random string, so we ignore it and just make it the same. - id: id.clone(), - stage: ActionStage::Executing, - }; - let mut expected_action_state2 = ActionState { - // Name is a random string, so we ignore it and just make it the same. 
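
Throughout the rewritten tests, `client_rx.borrow_and_update()` on a raw `watch::Receiver` becomes `action_listener.changed().await` on a `Pin<Box<dyn ActionListener>>`. Given that `DefaultActionListener::new` above takes a client operation ID plus a watch receiver, a listener of roughly this shape would suffice; the sketch below is a reduced, hypothetical version (no ID, and `Option` in place of nativelink's `Error`):

use std::sync::Arc;

use tokio::sync::watch;

#[derive(Clone, Debug, PartialEq)]
enum Stage {
    Queued,
    Executing,
}

// Reduced stand-in for an action listener: callers await state transitions
// instead of polling a watch receiver by hand.
struct WatchListener {
    rx: watch::Receiver<Arc<Stage>>,
}

impl WatchListener {
    // Resolves when the scheduler publishes a new state; yields None once
    // the sending side is dropped.
    async fn changed(&mut self) -> Option<Arc<Stage>> {
        self.rx.changed().await.ok()?;
        Some(self.rx.borrow_and_update().clone())
    }
}

#[tokio::main]
async fn main() {
    let (tx, rx) = watch::channel(Arc::new(Stage::Queued));
    let mut listener = WatchListener { rx };
    tx.send(Arc::new(Stage::Executing)).unwrap();
    assert_eq!(*listener.changed().await.unwrap(), Stage::Executing);
}

One behavioral consequence the tests lean on: `changed()` only resolves on a transition, so asserting a stage now implicitly waits for the scheduler to publish it rather than reading whatever value happens to be current.
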
- id, - stage: ActionStage::Executing, + let mut expected_start_execute_for_worker1 = StartExecute { + execute_request: Some(ExecuteRequest { + instance_name: INSTANCE_NAME.to_string(), + action_digest: Some(action_digest1.into()), + digest_function: digest_function::Value::Sha256.into(), + ..Default::default() + }), + operation_id: "WILL BE SET BELOW".to_string(), + queued_timestamp: Some(insert_timestamp1.into()), }; - let execution_request_for_worker1 = UpdateForWorker { - update: Some(update_for_worker::Update::StartAction(StartExecute { - execute_request: Some(ExecuteRequest { - instance_name: INSTANCE_NAME.to_string(), - skip_cache_lookup: true, - action_digest: Some(action_digest1.into()), - digest_function: digest_function::Value::Sha256.into(), - ..Default::default() - }), - salt: 0, - queued_timestamp: Some(insert_timestamp1.into()), - })), + let mut expected_start_execute_for_worker2 = StartExecute { + execute_request: Some(ExecuteRequest { + instance_name: INSTANCE_NAME.to_string(), + action_digest: Some(action_digest2.into()), + digest_function: digest_function::Value::Sha256.into(), + ..Default::default() + }), + operation_id: "WILL BE SET BELOW".to_string(), + queued_timestamp: Some(insert_timestamp2.into()), }; - { - // Worker1 should now see execution request. - let msg_for_worker = rx_from_worker1.recv().await.unwrap(); - assert_eq!(msg_for_worker, execution_request_for_worker1); - } - let execution_request_for_worker2 = UpdateForWorker { - update: Some(update_for_worker::Update::StartAction(StartExecute { - execute_request: Some(ExecuteRequest { - instance_name: INSTANCE_NAME.to_string(), - skip_cache_lookup: true, - action_digest: Some(action_digest2.into()), - digest_function: digest_function::Value::Sha256.into(), - ..Default::default() - }), - salt: 0, - queued_timestamp: Some(insert_timestamp2.into()), - })), + let operation_id1 = { + // Worker1 should now see first execution request. + let update_for_worker = rx_from_worker1 + .recv() + .await + .expect("Worker terminated stream") + .update + .expect("`update` should be set on UpdateForWorker"); + let (operation_id, rx_start_execute) = match update_for_worker { + update_for_worker::Update::StartAction(start_execute) => ( + OperationId::try_from(start_execute.operation_id.as_str()).unwrap(), + start_execute, + ), + v => panic!("Expected StartAction, got : {v:?}"), + }; + expected_start_execute_for_worker1.operation_id = operation_id.to_string(); + assert_eq!(expected_start_execute_for_worker1, rx_start_execute); + operation_id }; - { + let operation_id2 = { // Worker1 should now see second execution request. - let msg_for_worker = rx_from_worker1.recv().await.unwrap(); - assert_eq!(msg_for_worker, execution_request_for_worker2); - } + let update_for_worker = rx_from_worker1 + .recv() + .await + .expect("Worker terminated stream") + .update + .expect("`update` should be set on UpdateForWorker"); + let (operation_id, rx_start_execute) = match update_for_worker { + update_for_worker::Update::StartAction(start_execute) => ( + OperationId::try_from(start_execute.operation_id.as_str()).unwrap(), + start_execute, + ), + v => panic!("Expected StartAction, got : {v:?}"), + }; + expected_start_execute_for_worker2.operation_id = operation_id.to_string(); + assert_eq!(expected_start_execute_for_worker2, rx_start_execute); + operation_id + }; // Add a second worker that can take jobs if the first dies. 
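
The two near-identical blocks above pull the scheduler-minted operation ID out of the first `StartAction` each worker receives, patch it into the expected message, and only then compare. Factored out on a toy update enum (stand-ins for the generated proto types), the pattern is:

#[derive(Debug, PartialEq)]
struct StartExecute {
    operation_id: String,
}

#[allow(dead_code)]
#[derive(Debug)]
enum Update {
    KeepAlive(()),
    StartAction(StartExecute),
}

// Panics loudly on anything but StartAction so a mis-ordered message
// stream fails the test immediately, mirroring the `v => panic!(...)` arms.
fn expect_start_action(update: Update) -> (String, StartExecute) {
    match update {
        Update::StartAction(start_execute) => {
            (start_execute.operation_id.clone(), start_execute)
        }
        v => panic!("Expected StartAction, got : {v:?}"),
    }
}

fn main() {
    let update = Update::StartAction(StartExecute {
        operation_id: "op-1234".to_string(),
    });
    let (operation_id, rx_start_execute) = expect_start_action(update);

    let mut expected = StartExecute {
        operation_id: "WILL BE SET BELOW".to_string(),
    };
    // Copy the random ID into the expectation, then compare everything else.
    expected.operation_id.clone_from(&operation_id);
    assert_eq!(expected, rx_start_execute);
}
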
let mut rx_from_worker2 = setup_new_worker(&scheduler, worker_id2, PlatformProperties::default()).await?; { + let expected_action_stage = ActionStage::Executing; // Client should get notification saying it's being executed. - let action_state = client_rx1.borrow_and_update(); + let action_state = client1_action_listener.changed().await.unwrap(); // We now know the name of the action so populate it. - expected_action_state1.id = action_state.id.clone(); - assert_eq!(action_state.as_ref(), &expected_action_state1); + assert_eq!(&action_state.stage, &expected_action_stage); } { + let expected_action_stage = ActionStage::Executing; // Client should get notification saying it's being executed. - let action_state = client_rx2.borrow_and_update(); + let action_state = client2_action_listener.changed().await.unwrap(); // We now know the name of the action so populate it. - expected_action_state2.id = action_state.id.clone(); - assert_eq!(action_state.as_ref(), &expected_action_state2); + assert_eq!(&action_state.stage, &expected_action_stage); } // Now remove worker. - scheduler.remove_worker(worker_id1).await; + let _ = scheduler.remove_worker(&worker_id1).await; tokio::task::yield_now().await; // Allow task<->worker matcher to run. { @@ -336,26 +386,44 @@ async fn remove_worker_reschedules_multiple_running_job_test() -> Result<(), Err ); } { + let expected_action_stage = ActionStage::Executing; // Client should get notification saying it's being executed. - let action_state = client_rx1.borrow_and_update(); - expected_action_state1.stage = ActionStage::Executing; - assert_eq!(action_state.as_ref(), &expected_action_state1); + let action_state = client1_action_listener.changed().await.unwrap(); + // We now know the name of the action so populate it. + assert_eq!(&action_state.stage, &expected_action_stage); } { + let expected_action_stage = ActionStage::Executing; // Client should get notification saying it's being executed. - let action_state = client_rx2.borrow_and_update(); - expected_action_state2.stage = ActionStage::Executing; - assert_eq!(action_state.as_ref(), &expected_action_state2); + let action_state = client2_action_listener.changed().await.unwrap(); + // We now know the name of the action so populate it. + assert_eq!(&action_state.stage, &expected_action_stage); } { // Worker2 should now see execution request. let msg_for_worker = rx_from_worker2.recv().await.unwrap(); - assert_eq!(msg_for_worker, execution_request_for_worker1); + expected_start_execute_for_worker1.operation_id = operation_id1.to_string(); + assert_eq!( + msg_for_worker, + UpdateForWorker { + update: Some(update_for_worker::Update::StartAction( + expected_start_execute_for_worker1 + )), + } + ); } { // Worker2 should now see execution request. 
let msg_for_worker = rx_from_worker2.recv().await.unwrap(); - assert_eq!(msg_for_worker, execution_request_for_worker2); + expected_start_execute_for_worker2.operation_id = operation_id2.to_string(); + assert_eq!( + msg_for_worker, + UpdateForWorker { + update: Some(update_for_worker::Update::StartAction( + expected_start_execute_for_worker2 + )), + } + ); } Ok(()) @@ -365,7 +433,7 @@ async fn remove_worker_reschedules_multiple_running_job_test() -> Result<(), Err async fn set_drain_worker_pauses_and_resumes_worker_test() -> Result<(), Error> { let worker_id: WorkerId = WorkerId(Uuid::new_v4()); - let scheduler = SimpleScheduler::new_with_callback( + let (scheduler, _worker_scheduler) = SimpleScheduler::new_with_callback( &nativelink_config::schedulers::SimpleScheduler::default(), || async move {}, ); @@ -374,7 +442,7 @@ async fn set_drain_worker_pauses_and_resumes_worker_test() -> Result<(), Error> let mut rx_from_worker = setup_new_worker(&scheduler, worker_id, PlatformProperties::default()).await?; let insert_timestamp = make_system_time(1); - let mut client_rx = setup_action( + let mut action_listener = setup_action( &scheduler, action_digest, PlatformProperties::default(), @@ -382,23 +450,29 @@ async fn set_drain_worker_pauses_and_resumes_worker_test() -> Result<(), Error> ) .await?; - { + let _operation_id = { // Other tests check full data. We only care if we got StartAction. - match rx_from_worker.recv().await.unwrap().update { - Some(update_for_worker::Update::StartAction(_)) => { /* Success */ } + let operation_id = match rx_from_worker.recv().await.unwrap().update { + Some(update_for_worker::Update::StartAction(start_execute)) => { + OperationId::try_from(start_execute.operation_id.as_str()).unwrap() + } v => panic!("Expected StartAction, got : {v:?}"), - } + }; // Other tests check full data. We only care if client thinks we are Executing. - assert_eq!(client_rx.borrow_and_update().stage, ActionStage::Executing); - } + assert_eq!( + action_listener.changed().await.unwrap().stage, + ActionStage::Executing + ); + operation_id + }; // Set the worker draining. - scheduler.set_drain_worker(worker_id, true).await?; + scheduler.set_drain_worker(&worker_id, true).await?; tokio::task::yield_now().await; let action_digest = DigestInfo::new([88u8; 32], 512); let insert_timestamp = make_system_time(14); - let mut client_rx = setup_action( + let mut action_listener = setup_action( &scheduler, action_digest, PlatformProperties::default(), @@ -408,7 +482,7 @@ async fn set_drain_worker_pauses_and_resumes_worker_test() -> Result<(), Error> { // Client should get notification saying it's been queued. - let action_state = client_rx.borrow_and_update(); + let action_state = action_listener.changed().await.unwrap(); let expected_action_state = ActionState { // Name is a random string, so we ignore it and just make it the same. id: action_state.id.clone(), @@ -418,12 +492,12 @@ async fn set_drain_worker_pauses_and_resumes_worker_test() -> Result<(), Error> } // Set the worker not draining. - scheduler.set_drain_worker(worker_id, false).await?; + scheduler.set_drain_worker(&worker_id, false).await?; tokio::task::yield_now().await; { // Client should get notification saying it's being executed. - let action_state = client_rx.borrow_and_update(); + let action_state = action_listener.changed().await.unwrap(); let expected_action_state = ActionState { // Name is a random string, so we ignore it and just make it the same. 
id: action_state.id.clone(), @@ -440,7 +514,7 @@ async fn worker_should_not_queue_if_properties_dont_match_test() -> Result<(), E let worker_id1: WorkerId = WorkerId(Uuid::new_v4()); let worker_id2: WorkerId = WorkerId(Uuid::new_v4()); - let scheduler = SimpleScheduler::new_with_callback( + let (scheduler, _worker_scheduler) = SimpleScheduler::new_with_callback( &nativelink_config::schedulers::SimpleScheduler::default(), || async move {}, ); @@ -459,7 +533,7 @@ async fn worker_should_not_queue_if_properties_dont_match_test() -> Result<(), E let mut rx_from_worker1 = setup_new_worker(&scheduler, worker_id1, platform_properties.clone()).await?; let insert_timestamp = make_system_time(1); - let mut client_rx = setup_action( + let mut action_listener = setup_action( &scheduler, action_digest, worker_properties.clone(), @@ -469,7 +543,7 @@ async fn worker_should_not_queue_if_properties_dont_match_test() -> Result<(), E { // Client should get notification saying it's been queued. - let action_state = client_rx.borrow_and_update(); + let action_state = action_listener.changed().await.unwrap(); let expected_action_state = ActionState { // Name is a random string, so we ignore it and just make it the same. id: action_state.id.clone(), @@ -485,21 +559,20 @@ async fn worker_should_not_queue_if_properties_dont_match_test() -> Result<(), E update: Some(update_for_worker::Update::StartAction(StartExecute { execute_request: Some(ExecuteRequest { instance_name: INSTANCE_NAME.to_string(), - skip_cache_lookup: true, action_digest: Some(action_digest.into()), digest_function: digest_function::Value::Sha256.into(), ..Default::default() }), - salt: 0, + operation_id: "Unknown Generated internally".to_string(), queued_timestamp: Some(insert_timestamp.into()), })), }; let msg_for_worker = rx_from_worker2.recv().await.unwrap(); - assert_eq!(msg_for_worker, expected_msg_for_worker); + assert!(update_eq(expected_msg_for_worker, msg_for_worker, true)); } { // Client should get notification saying it's being executed. - let action_state = client_rx.borrow_and_update(); + let action_state = action_listener.changed().await.unwrap(); let expected_action_state = ActionState { // Name is a random string, so we ignore it and just make it the same. 
id: action_state.id.clone(), @@ -521,18 +594,17 @@ async fn worker_should_not_queue_if_properties_dont_match_test() -> Result<(), E async fn cacheable_items_join_same_action_queued_test() -> Result<(), Error> { let worker_id: WorkerId = WorkerId(Uuid::new_v4()); - let scheduler = SimpleScheduler::new_with_callback( + let (scheduler, _worker_scheduler) = SimpleScheduler::new_with_callback( &nativelink_config::schedulers::SimpleScheduler::default(), || async move {}, ); let action_digest = DigestInfo::new([99u8; 32], 512); - let unique_qualifier = ActionInfoHashKey { + let unique_qualifier = ActionUniqueQualifier::Cachable(ActionUniqueKey { instance_name: "".to_string(), digest: DigestInfo::zero_digest(), digest_function: DigestHasherFunc::Sha256, - salt: 0, - }; + }); let id = OperationId::new(unique_qualifier); let mut expected_action_state = ActionState { id, @@ -541,14 +613,14 @@ async fn cacheable_items_join_same_action_queued_test() -> Result<(), Error> { let insert_timestamp1 = make_system_time(1); let insert_timestamp2 = make_system_time(2); - let mut client1_rx = setup_action( + let mut client1_action_listener = setup_action( &scheduler, action_digest, PlatformProperties::default(), insert_timestamp1, ) .await?; - let mut client2_rx = setup_action( + let mut client2_action_listener = setup_action( &scheduler, action_digest, PlatformProperties::default(), @@ -558,8 +630,8 @@ async fn cacheable_items_join_same_action_queued_test() -> Result<(), Error> { { // Clients should get notification saying it's been queued. - let action_state1 = client1_rx.borrow_and_update(); - let action_state2 = client2_rx.borrow_and_update(); + let action_state1 = client1_action_listener.changed().await.unwrap(); + let action_state2 = client2_action_listener.changed().await.unwrap(); // Name is random so we set force it to be the same. expected_action_state.id = action_state1.id.clone(); assert_eq!(action_state1.as_ref(), &expected_action_state); @@ -575,17 +647,17 @@ async fn cacheable_items_join_same_action_queued_test() -> Result<(), Error> { update: Some(update_for_worker::Update::StartAction(StartExecute { execute_request: Some(ExecuteRequest { instance_name: INSTANCE_NAME.to_string(), - skip_cache_lookup: true, action_digest: Some(action_digest.into()), digest_function: digest_function::Value::Sha256.into(), ..Default::default() }), - salt: 0, + operation_id: "Unknown Generated internally".to_string(), queued_timestamp: Some(insert_timestamp1.into()), })), }; let msg_for_worker = rx_from_worker.recv().await.unwrap(); - assert_eq!(msg_for_worker, expected_msg_for_worker); + // Operation ID is random so we ignore it. + assert!(update_eq(expected_msg_for_worker, msg_for_worker, true)); } // Action should now be executing. @@ -594,11 +666,11 @@ async fn cacheable_items_join_same_action_queued_test() -> Result<(), Error> { // Both client1 and client2 should be receiving the same updates. // Most importantly the `name` (which is random) will be the same. assert_eq!( - client1_rx.borrow_and_update().as_ref(), + client1_action_listener.changed().await.unwrap().as_ref(), &expected_action_state ); assert_eq!( - client2_rx.borrow_and_update().as_ref(), + client2_action_listener.changed().await.unwrap().as_ref(), &expected_action_state ); } @@ -606,7 +678,7 @@ async fn cacheable_items_join_same_action_queued_test() -> Result<(), Error> { { // Now if another action is requested it should also join with executing action. 
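
For context on what `cacheable_items_join_same_action_queued_test` exercises: `add_action` calls whose `ActionUniqueQualifier::Cachable` keys match are merged onto one underlying operation, so every attached listener observes the same state (and the same operation ID). A minimal sketch of that dedup-by-key idea, assuming nothing more than a `HashMap` of watch senders (the real scheduler's bookkeeping is certainly more involved):

use std::collections::HashMap;
use std::sync::Arc;

use tokio::sync::watch;

#[derive(Clone, Debug, Hash, PartialEq, Eq)]
struct UniqueKey {
    instance_name: String,
    digest: String,
}

#[derive(Clone, Debug, PartialEq)]
enum Stage {
    Queued,
    Executing,
}

#[derive(Default)]
struct Scheduler {
    operations: HashMap<UniqueKey, watch::Sender<Arc<Stage>>>,
}

impl Scheduler {
    // A second request with the same cachable key joins the existing
    // operation instead of queueing a duplicate.
    fn add_action(&mut self, key: UniqueKey) -> watch::Receiver<Arc<Stage>> {
        self.operations
            .entry(key)
            .or_insert_with(|| watch::channel(Arc::new(Stage::Queued)).0)
            .subscribe()
    }
}

fn main() {
    let mut scheduler = Scheduler::default();
    let key = UniqueKey {
        instance_name: String::new(),
        digest: "99..99/512".to_string(),
    };
    let rx1 = scheduler.add_action(key.clone());
    let rx2 = scheduler.add_action(key.clone());
    scheduler.operations[&key]
        .send(Arc::new(Stage::Executing))
        .unwrap();
    // Both clients observe the same shared state.
    assert_eq!(*rx1.borrow(), *rx2.borrow());
    assert_eq!(**rx1.borrow(), Stage::Executing);
}

The `Uncachable` variant seen earlier is, presumably, the opt-out for actions that must never be merged this way.
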
let insert_timestamp3 = make_system_time(2); - let mut client3_rx = setup_action( + let mut client3_action_listener = setup_action( &scheduler, action_digest, PlatformProperties::default(), @@ -614,7 +686,7 @@ async fn cacheable_items_join_same_action_queued_test() -> Result<(), Error> { ) .await?; assert_eq!( - client3_rx.borrow_and_update().as_ref(), + client3_action_listener.changed().await.unwrap().as_ref(), &expected_action_state ); } @@ -625,7 +697,7 @@ async fn cacheable_items_join_same_action_queued_test() -> Result<(), Error> { #[nativelink_test] async fn worker_disconnects_does_not_schedule_for_execution_test() -> Result<(), Error> { let worker_id: WorkerId = WorkerId(Uuid::new_v4()); - let scheduler = SimpleScheduler::new_with_callback( + let (scheduler, _worker_scheduler) = SimpleScheduler::new_with_callback( &nativelink_config::schedulers::SimpleScheduler::default(), || async move {}, ); @@ -638,7 +710,7 @@ async fn worker_disconnects_does_not_schedule_for_execution_test() -> Result<(), drop(rx_from_worker); let insert_timestamp = make_system_time(1); - let mut client_rx = setup_action( + let mut action_listener = setup_action( &scheduler, action_digest, PlatformProperties::default(), @@ -647,7 +719,7 @@ async fn worker_disconnects_does_not_schedule_for_execution_test() -> Result<(), .await?; { // Client should get notification saying it's being queued not executed. - let action_state = client_rx.borrow_and_update(); + let action_state = action_listener.changed().await.unwrap(); let expected_action_state = ActionState { // Name is a random string, so we ignore it and just make it the same. id: action_state.id.clone(), @@ -663,7 +735,7 @@ async fn worker_disconnects_does_not_schedule_for_execution_test() -> Result<(), async fn worker_timesout_reschedules_running_job_test() -> Result<(), Error> { let worker_id1: WorkerId = WorkerId(Uuid::new_v4()); let worker_id2: WorkerId = WorkerId(Uuid::new_v4()); - let scheduler = SimpleScheduler::new_with_callback( + let (scheduler, _worker_scheduler) = SimpleScheduler::new_with_callback( &nativelink_config::schedulers::SimpleScheduler { worker_timeout_s: WORKER_TIMEOUT_S, ..Default::default() @@ -676,7 +748,7 @@ async fn worker_timesout_reschedules_running_job_test() -> Result<(), Error> { let mut rx_from_worker1 = setup_new_worker(&scheduler, worker_id1, PlatformProperties::default()).await?; let insert_timestamp = make_system_time(1); - let mut client_rx = setup_action( + let mut action_listener = setup_action( &scheduler, action_digest, PlatformProperties::default(), @@ -688,44 +760,49 @@ async fn worker_timesout_reschedules_running_job_test() -> Result<(), Error> { let mut rx_from_worker2 = setup_new_worker(&scheduler, worker_id2, PlatformProperties::default()).await?; - let unique_qualifier = ActionInfoHashKey { - instance_name: "".to_string(), - digest: DigestInfo::zero_digest(), - digest_function: DigestHasherFunc::Sha256, - salt: 0, - }; - let id = OperationId::new(unique_qualifier); - let mut expected_action_state = ActionState { - id, - stage: ActionStage::Executing, - }; - - let execution_request_for_worker = UpdateForWorker { - update: Some(update_for_worker::Update::StartAction(StartExecute { - execute_request: Some(ExecuteRequest { - instance_name: INSTANCE_NAME.to_string(), - skip_cache_lookup: true, - action_digest: Some(action_digest.into()), - digest_function: digest_function::Value::Sha256.into(), - ..Default::default() - }), - salt: 0, - queued_timestamp: Some(insert_timestamp.into()), - })), + let mut start_execute = 
StartExecute { + execute_request: Some(ExecuteRequest { + instance_name: INSTANCE_NAME.to_string(), + action_digest: Some(action_digest.into()), + digest_function: digest_function::Value::Sha256.into(), + ..Default::default() + }), + operation_id: "UNKNOWN HERE, WE WILL SET IT LATER".to_string(), + queued_timestamp: Some(insert_timestamp.into()), }; - { + let operation_id = { // Worker1 should now see execution request. let msg_for_worker = rx_from_worker1.recv().await.unwrap(); - assert_eq!(msg_for_worker, execution_request_for_worker); - } + let operation_id = if let update_for_worker::Update::StartAction(start_execute) = + msg_for_worker.update.as_ref().unwrap() + { + start_execute.operation_id.clone() + } else { + panic!("Expected StartAction, got : {msg_for_worker:?}"); + }; + start_execute.operation_id.clone_from(&operation_id); + assert_eq!( + msg_for_worker, + UpdateForWorker { + update: Some(update_for_worker::Update::StartAction( + start_execute.clone() + )), + } + ); + OperationId::try_from(operation_id.as_str()).unwrap() + }; { // Client should get notification saying it's being executed. - let action_state = client_rx.borrow_and_update(); - // We now know the name of the action so populate it. - expected_action_state.id = action_state.id.clone(); - assert_eq!(action_state.as_ref(), &expected_action_state); + let action_state = action_listener.changed().await.unwrap(); + assert_eq!( + action_state.as_ref(), + &ActionState { + id: operation_id.clone(), + stage: ActionStage::Executing, + } + ); } // Keep worker 2 alive. @@ -750,14 +827,26 @@ async fn worker_timesout_reschedules_running_job_test() -> Result<(), Error> { } { // Client should get notification saying it's being executed. - let action_state = client_rx.borrow_and_update(); - expected_action_state.stage = ActionStage::Executing; - assert_eq!(action_state.as_ref(), &expected_action_state); + let action_state = action_listener.changed().await.unwrap(); + assert_eq!( + action_state.as_ref(), + &ActionState { + id: operation_id.clone(), + stage: ActionStage::Executing, + } + ); } { // Worker2 should now see execution request. let msg_for_worker = rx_from_worker2.recv().await.unwrap(); - assert_eq!(msg_for_worker, execution_request_for_worker); + assert_eq!( + msg_for_worker, + UpdateForWorker { + update: Some(update_for_worker::Update::StartAction( + start_execute.clone() + )), + } + ); } Ok(()) @@ -767,7 +856,7 @@ async fn worker_timesout_reschedules_running_job_test() -> Result<(), Error> { async fn update_action_sends_completed_result_to_client_test() -> Result<(), Error> { let worker_id: WorkerId = WorkerId(Uuid::new_v4()); - let scheduler = SimpleScheduler::new_with_callback( + let (scheduler, _worker_scheduler) = SimpleScheduler::new_with_callback( &nativelink_config::schedulers::SimpleScheduler::default(), || async move {}, ); @@ -776,7 +865,7 @@ async fn update_action_sends_completed_result_to_client_test() -> Result<(), Err let mut rx_from_worker = setup_new_worker(&scheduler, worker_id, PlatformProperties::default()).await?; let insert_timestamp = make_system_time(1); - let mut client_rx = setup_action( + let mut action_listener = setup_action( &scheduler, action_digest, PlatformProperties::default(), @@ -784,22 +873,21 @@ async fn update_action_sends_completed_result_to_client_test() -> Result<(), Err ) .await?; - { + let operation_id = { // Other tests check full data. We only care if we got StartAction. 
match rx_from_worker.recv().await.unwrap().update { - Some(update_for_worker::Update::StartAction(_)) => { /* Success */ } + Some(update_for_worker::Update::StartAction(start_execute)) => { + // Other tests check full data. We only care if client thinks we are Executing. + assert_eq!( + action_listener.changed().await.unwrap().stage, + ActionStage::Executing + ); + start_execute.operation_id + } v => panic!("Expected StartAction, got : {v:?}"), } - // Other tests check full data. We only care if client thinks we are Executing. - assert_eq!(client_rx.borrow_and_update().stage, ActionStage::Executing); - } - - let action_info_hash_key = ActionInfoHashKey { - instance_name: INSTANCE_NAME.to_string(), - digest_function: DigestHasherFunc::Sha256, - digest: action_digest, - salt: 0, }; + let action_result = ActionResult { output_files: vec![FileInfo { name_or_path: NameOrPath::Name("hello".to_string()), @@ -840,14 +928,14 @@ async fn update_action_sends_completed_result_to_client_test() -> Result<(), Err scheduler .update_action( &worker_id, - action_info_hash_key, + &OperationId::try_from(operation_id.as_str())?, Ok(ActionStage::Completed(action_result.clone())), ) .await?; { // Client should get notification saying it has been completed. - let action_state = client_rx.borrow_and_update(); + let action_state = action_listener.changed().await.unwrap(); let expected_action_state = ActionState { // Name is a random string, so we ignore it and just make it the same. id: action_state.id.clone(), @@ -855,14 +943,6 @@ async fn update_action_sends_completed_result_to_client_test() -> Result<(), Err }; assert_eq!(action_state.as_ref(), &expected_action_state); } - { - // Update info for the action should now be closed (notification happens through Err). - let result = client_rx.changed().await; - assert!( - result.is_err(), - "Expected result to be an error : {result:?}" - ); - } Ok(()) } @@ -871,7 +951,7 @@ async fn update_action_sends_completed_result_to_client_test() -> Result<(), Err async fn update_action_sends_completed_result_after_disconnect() -> Result<(), Error> { let worker_id: WorkerId = WorkerId(Uuid::new_v4()); - let scheduler = SimpleScheduler::new_with_callback( + let (scheduler, _worker_scheduler) = SimpleScheduler::new_with_callback( &nativelink_config::schedulers::SimpleScheduler::default(), || async move {}, ); @@ -880,7 +960,7 @@ async fn update_action_sends_completed_result_after_disconnect() -> Result<(), E let mut rx_from_worker = setup_new_worker(&scheduler, worker_id, PlatformProperties::default()).await?; let insert_timestamp = make_system_time(1); - let client_rx = setup_action( + let action_listener = setup_action( &scheduler, action_digest, PlatformProperties::default(), @@ -888,24 +968,21 @@ async fn update_action_sends_completed_result_after_disconnect() -> Result<(), E ) .await?; + let client_id = action_listener.client_operation_id().clone(); + // Drop our receiver and don't reconnect until completed. - let unique_qualifier = client_rx.borrow().id.unique_qualifier.clone(); - drop(client_rx); + drop(action_listener); - { + let operation_id = { // Other tests check full data. We only care if we got StartAction. 
- match rx_from_worker.recv().await.unwrap().update { - Some(update_for_worker::Update::StartAction(_)) => { /* Success */ } + let operation_id = match rx_from_worker.recv().await.unwrap().update { + Some(update_for_worker::Update::StartAction(exec)) => exec.operation_id, v => panic!("Expected StartAction, got : {v:?}"), - } - } - - let action_info_hash_key = ActionInfoHashKey { - instance_name: INSTANCE_NAME.to_string(), - digest_function: DigestHasherFunc::Sha256, - digest: action_digest, - salt: 0, + }; + // Other tests check full data. We only care if client thinks we are Executing. + OperationId::try_from(operation_id.as_str())? }; + let action_result = ActionResult { output_files: vec![FileInfo { name_or_path: NameOrPath::Name("hello".to_string()), @@ -946,19 +1023,20 @@ async fn update_action_sends_completed_result_after_disconnect() -> Result<(), E scheduler .update_action( &worker_id, - action_info_hash_key, + &operation_id, Ok(ActionStage::Completed(action_result.clone())), ) .await?; // Now look up a channel after the action has completed. - let mut client_rx = scheduler - .find_existing_action(&unique_qualifier) + let mut action_listener = scheduler + .find_by_client_operation_id(&client_id) .await - .err_tip(|| "Action not found")?; + .unwrap() + .expect("Action not found"); { // Client should get notification saying it has been completed. - let action_state = client_rx.borrow_and_update(); + let action_state = action_listener.changed().await.unwrap(); let expected_action_state = ActionState { // Name is a random string, so we ignore it and just make it the same. id: action_state.id.clone(), @@ -975,7 +1053,7 @@ async fn update_action_with_wrong_worker_id_errors_test() -> Result<(), Error> { let good_worker_id: WorkerId = WorkerId(Uuid::new_v4()); let rogue_worker_id: WorkerId = WorkerId(Uuid::new_v4()); - let scheduler = SimpleScheduler::new_with_callback( + let (scheduler, _worker_scheduler) = SimpleScheduler::new_with_callback( &nativelink_config::schedulers::SimpleScheduler::default(), || async move {}, ); @@ -984,7 +1062,7 @@ async fn update_action_with_wrong_worker_id_errors_test() -> Result<(), Error> { let mut rx_from_worker = setup_new_worker(&scheduler, good_worker_id, PlatformProperties::default()).await?; let insert_timestamp = make_system_time(1); - let mut client_rx = setup_action( + let mut action_listener = setup_action( &scheduler, action_digest, PlatformProperties::default(), @@ -999,15 +1077,18 @@ async fn update_action_with_wrong_worker_id_errors_test() -> Result<(), Error> { v => panic!("Expected StartAction, got : {v:?}"), } // Other tests check full data. We only care if client thinks we are Executing. 
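
Note the shift in how completions are reported across these tests: `update_action` call sites no longer rebuild an `ActionInfoHashKey` from instance name, digest function, digest, and salt; they hand back the `OperationId` the worker received in `StartExecute`. A reduced sketch of why keying on the handed-out ID simplifies the bookkeeping, assuming a bare `HashMap` (types illustrative):

use std::collections::HashMap;

#[derive(Clone, Debug, Hash, PartialEq, Eq)]
struct OperationId(String);

#[derive(Debug, PartialEq)]
enum Stage {
    Executing,
    Completed,
}

struct Scheduler {
    running: HashMap<OperationId, Stage>,
}

impl Scheduler {
    // Completion is routed by the exact ID handed out in StartExecute, so
    // there is no digest/salt reconstruction and no chance of a mismatched
    // rebuild on the worker side.
    fn update_action(&mut self, operation_id: &OperationId, stage: Stage) -> Result<(), String> {
        match self.running.get_mut(operation_id) {
            Some(slot) => {
                *slot = stage;
                Ok(())
            }
            None => Err(format!("{operation_id:?} is not running")),
        }
    }
}

fn main() {
    let id = OperationId("op-1234".to_string());
    let mut scheduler = Scheduler {
        running: HashMap::from([(id.clone(), Stage::Executing)]),
    };
    scheduler.update_action(&id, Stage::Completed).unwrap();
    assert_eq!(scheduler.running[&id], Stage::Completed);
}
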
- assert_eq!(client_rx.borrow_and_update().stage, ActionStage::Executing); + assert_eq!( + action_listener.changed().await.unwrap().stage, + ActionStage::Executing + ); } + let _ = setup_new_worker(&scheduler, rogue_worker_id, PlatformProperties::default()).await?; - let action_info_hash_key = ActionInfoHashKey { + let action_info_hash_key = ActionUniqueQualifier::Cachable(ActionUniqueKey { instance_name: INSTANCE_NAME.to_string(), digest_function: DigestHasherFunc::Sha256, digest: action_digest, - salt: 0, - }; + }); let action_result = ActionResult { output_files: Vec::default(), output_folders: Vec::default(), @@ -1035,14 +1116,13 @@ async fn update_action_with_wrong_worker_id_errors_test() -> Result<(), Error> { let update_action_result = scheduler .update_action( &rogue_worker_id, - action_info_hash_key, + &OperationId::new(action_info_hash_key), Ok(ActionStage::Completed(action_result.clone())), ) .await; { - const EXPECTED_ERR: &str = - "Got a result from a worker that should not be running the action"; + const EXPECTED_ERR: &str = "should not be running on worker"; // Our request should have sent an error back. assert!( update_action_result.is_err(), @@ -1058,8 +1138,8 @@ async fn update_action_with_wrong_worker_id_errors_test() -> Result<(), Error> { { // Ensure client did not get notified. assert_eq!( - client_rx.has_changed().unwrap(), - false, + poll!(action_listener.changed()), + Poll::Pending, "Client should not have been notified of event" ); } @@ -1071,18 +1151,17 @@ async fn update_action_with_wrong_worker_id_errors_test() -> Result<(), Error> { async fn does_not_crash_if_operation_joined_then_relaunched() -> Result<(), Error> { let worker_id: WorkerId = WorkerId(Uuid::new_v4()); - let scheduler = SimpleScheduler::new_with_callback( + let (scheduler, _worker_scheduler) = SimpleScheduler::new_with_callback( &nativelink_config::schedulers::SimpleScheduler::default(), || async move {}, ); let action_digest = DigestInfo::new([99u8; 32], 512); - let unique_qualifier = ActionInfoHashKey { + let unique_qualifier = ActionUniqueQualifier::Cachable(ActionUniqueKey { instance_name: "".to_string(), digest: DigestInfo::zero_digest(), digest_function: DigestHasherFunc::Sha256, - salt: 0, - }; + }); let id = OperationId::new(unique_qualifier); let mut expected_action_state = ActionState { id, @@ -1090,7 +1169,7 @@ async fn does_not_crash_if_operation_joined_then_relaunched() -> Result<(), Erro }; let insert_timestamp = make_system_time(1); - let mut client_rx = setup_action( + let mut action_listener = setup_action( &scheduler, action_digest, PlatformProperties::default(), @@ -1106,26 +1185,27 @@ async fn does_not_crash_if_operation_joined_then_relaunched() -> Result<(), Erro update: Some(update_for_worker::Update::StartAction(StartExecute { execute_request: Some(ExecuteRequest { instance_name: INSTANCE_NAME.to_string(), - skip_cache_lookup: true, action_digest: Some(action_digest.into()), digest_function: digest_function::Value::Sha256.into(), ..Default::default() }), - salt: 0, + operation_id: "Unknown Generated internally".to_string(), queued_timestamp: Some(insert_timestamp.into()), })), }; let msg_for_worker = rx_from_worker.recv().await.unwrap(); - assert_eq!(msg_for_worker, expected_msg_for_worker); + // Operation ID is random so we ignore it. + assert!(update_eq(expected_msg_for_worker, msg_for_worker, true)); } - { + let operation_id = { // Client should get notification saying it's being executed. 
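
The wrong-worker test below also swaps `watch::Receiver::has_changed()` for polling the listener future exactly once and asserting `Poll::Pending`. A self-contained illustration of that `futures::poll!` technique on a plain watch channel (the `Box::pin` is only needed here because tokio's `changed()` future is not `Unpin`; the listener future in the test evidently is, since `poll!` accepts it directly):

use futures::poll;
use tokio::sync::watch;

#[tokio::main]
async fn main() {
    let (tx, mut rx) = watch::channel(0u32);

    // Nothing has been published, so one poll of `changed()` must be
    // Pending; this is how the test asserts "client was not notified"
    // without blocking on a future that may never resolve.
    assert!(poll!(Box::pin(rx.changed())).is_pending());

    tx.send(1).unwrap();
    assert!(poll!(Box::pin(rx.changed())).is_ready());
}
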
- let action_state = client_rx.borrow_and_update(); + let action_state = action_listener.changed().await.unwrap(); // We now know the name of the action so populate it. expected_action_state.id = action_state.id.clone(); assert_eq!(action_state.as_ref(), &expected_action_state); - } + action_state.id.clone() + }; let action_result = ActionResult { output_files: Vec::default(), @@ -1155,12 +1235,7 @@ async fn does_not_crash_if_operation_joined_then_relaunched() -> Result<(), Erro scheduler .update_action( &worker_id, - ActionInfoHashKey { - instance_name: INSTANCE_NAME.to_string(), - digest_function: DigestHasherFunc::Sha256, - digest: action_digest, - salt: 0, - }, + &operation_id, Ok(ActionStage::Completed(action_result.clone())), ) .await?; @@ -1169,7 +1244,7 @@ async fn does_not_crash_if_operation_joined_then_relaunched() -> Result<(), Erro // Action should now be executing. expected_action_state.stage = ActionStage::Completed(action_result.clone()); assert_eq!( - client_rx.borrow_and_update().as_ref(), + action_listener.changed().await.unwrap().as_ref(), &expected_action_state ); } @@ -1179,7 +1254,7 @@ async fn does_not_crash_if_operation_joined_then_relaunched() -> Result<(), Erro { let insert_timestamp = make_system_time(1); - let mut client_rx = setup_action( + let mut action_listener = setup_action( &scheduler, action_digest, PlatformProperties::default(), @@ -1188,7 +1263,7 @@ async fn does_not_crash_if_operation_joined_then_relaunched() -> Result<(), Erro .await?; // We didn't disconnect our worker, so it will have scheduled it to the worker. expected_action_state.stage = ActionStage::Executing; - let action_state = client_rx.borrow_and_update(); + let action_state = action_listener.changed().await.unwrap(); // The name of the action changed (since it's a new action), so update it. expected_action_state.id = action_state.id.clone(); assert_eq!(action_state.as_ref(), &expected_action_state); @@ -1203,7 +1278,7 @@ async fn does_not_crash_if_operation_joined_then_relaunched() -> Result<(), Erro async fn run_two_jobs_on_same_worker_with_platform_properties_restrictions() -> Result<(), Error> { let worker_id: WorkerId = WorkerId(Uuid::new_v4()); - let scheduler = SimpleScheduler::new_with_callback( + let (scheduler, _worker_scheduler) = SimpleScheduler::new_with_callback( &nativelink_config::schedulers::SimpleScheduler::default(), || async move {}, ); @@ -1216,7 +1291,7 @@ async fn run_two_jobs_on_same_worker_with_platform_properties_restrictions() -> let mut rx_from_worker = setup_new_worker(&scheduler, worker_id, platform_properties.clone()).await?; let insert_timestamp1 = make_system_time(1); - let mut client1_rx = setup_action( + let mut client1_action_listener = setup_action( &scheduler, action_digest1, platform_properties.clone(), @@ -1224,7 +1299,7 @@ async fn run_two_jobs_on_same_worker_with_platform_properties_restrictions() -> ) .await?; let insert_timestamp2 = make_system_time(1); - let mut client2_rx = setup_action( + let mut client2_action_listener = setup_action( &scheduler, action_digest2, platform_properties, @@ -1236,12 +1311,15 @@ async fn run_two_jobs_on_same_worker_with_platform_properties_restrictions() -> Some(update_for_worker::Update::StartAction(_)) => { /* Success */ } v => panic!("Expected StartAction, got : {v:?}"), } - { + let (operation_id1, operation_id2) = { + let state_1 = client1_action_listener.changed().await.unwrap(); + let state_2 = client2_action_listener.changed().await.unwrap(); // First client should be in an Executing state. 
- assert_eq!(client1_rx.borrow_and_update().stage, ActionStage::Executing); + assert_eq!(state_1.stage, ActionStage::Executing); // Second client should be in a queued state. - assert_eq!(client2_rx.borrow_and_update().stage, ActionStage::Queued); - } + assert_eq!(state_2.stage, ActionStage::Queued); + (state_1.id.clone(), state_2.id.clone()) + }; let action_result = ActionResult { output_files: Vec::default(), @@ -1272,25 +1350,14 @@ async fn run_two_jobs_on_same_worker_with_platform_properties_restrictions() -> scheduler .update_action( &worker_id, - ActionInfoHashKey { - instance_name: INSTANCE_NAME.to_string(), - digest_function: DigestHasherFunc::Sha256, - digest: action_digest1, - salt: 0, - }, + &operation_id1, Ok(ActionStage::Completed(action_result.clone())), ) .await?; - // Ensure client did not get notified. - assert!( - client1_rx.changed().await.is_ok(), - "Client should have been notified of event" - ); - { // First action should now be completed. - let action_state = client1_rx.borrow_and_update(); + let action_state = client1_action_listener.changed().await.unwrap(); let mut expected_action_state = ActionState { // Name is a random string, so we ignore it and just make it the same. id: action_state.id.clone(), @@ -1311,26 +1378,24 @@ async fn run_two_jobs_on_same_worker_with_platform_properties_restrictions() -> v => panic!("Expected StartAction, got : {v:?}"), } // Other tests check full data. We only care if client thinks we are Executing. - assert_eq!(client2_rx.borrow_and_update().stage, ActionStage::Executing); + assert_eq!( + client2_action_listener.changed().await.unwrap().stage, + ActionStage::Executing + ); } // Tell scheduler our second task is completed. scheduler .update_action( &worker_id, - ActionInfoHashKey { - instance_name: INSTANCE_NAME.to_string(), - digest_function: DigestHasherFunc::Sha256, - digest: action_digest2, - salt: 0, - }, + &operation_id2, Ok(ActionStage::Completed(action_result.clone())), ) .await?; { // Our second client should be notified it completed. - let action_state = client2_rx.borrow_and_update(); + let action_state = client2_action_listener.changed().await.unwrap(); let mut expected_action_state = ActionState { // Name is a random string, so we ignore it and just make it the same. id: action_state.id.clone(), @@ -1349,7 +1414,7 @@ async fn run_two_jobs_on_same_worker_with_platform_properties_restrictions() -> async fn run_jobs_in_the_order_they_were_queued() -> Result<(), Error> { let worker_id: WorkerId = WorkerId(Uuid::new_v4()); - let scheduler = SimpleScheduler::new_with_callback( + let (scheduler, _worker_scheduler) = SimpleScheduler::new_with_callback( &nativelink_config::schedulers::SimpleScheduler::default(), || async move {}, ); @@ -1363,7 +1428,7 @@ async fn run_jobs_in_the_order_they_were_queued() -> Result<(), Error> { // This is queued after the next one (even though it's placed in the map // first), so it should execute second. 
let insert_timestamp2 = make_system_time(2); - let mut client2_rx = setup_action( + let mut client2_action_listener = setup_action( &scheduler, action_digest2, platform_properties.clone(), @@ -1371,7 +1436,7 @@ async fn run_jobs_in_the_order_they_were_queued() -> Result<(), Error> { ) .await?; let insert_timestamp1 = make_system_time(1); - let mut client1_rx = setup_action( + let mut client1_action_listener = setup_action( &scheduler, action_digest1, platform_properties.clone(), @@ -1388,9 +1453,15 @@ async fn run_jobs_in_the_order_they_were_queued() -> Result<(), Error> { } { // First client should be in an Executing state. - assert_eq!(client1_rx.borrow_and_update().stage, ActionStage::Executing); + assert_eq!( + client1_action_listener.changed().await.unwrap().stage, + ActionStage::Executing + ); // Second client should be in a queued state. - assert_eq!(client2_rx.borrow_and_update().stage, ActionStage::Queued); + assert_eq!( + client2_action_listener.changed().await.unwrap().stage, + ActionStage::Queued + ); } Ok(()) @@ -1400,9 +1471,9 @@ async fn run_jobs_in_the_order_they_were_queued() -> Result<(), Error> { async fn worker_retries_on_internal_error_and_fails_test() -> Result<(), Error> { let worker_id: WorkerId = WorkerId(Uuid::new_v4()); - let scheduler = SimpleScheduler::new_with_callback( + let (scheduler, _worker_scheduler) = SimpleScheduler::new_with_callback( &nativelink_config::schedulers::SimpleScheduler { - max_job_retries: 2, + max_job_retries: 1, ..Default::default() }, || async move {}, @@ -1412,7 +1483,7 @@ async fn worker_retries_on_internal_error_and_fails_test() -> Result<(), Error> let mut rx_from_worker = setup_new_worker(&scheduler, worker_id, PlatformProperties::default()).await?; let insert_timestamp = make_system_time(1); - let mut client_rx = setup_action( + let mut action_listener = setup_action( &scheduler, action_digest, PlatformProperties::default(), @@ -1420,33 +1491,31 @@ async fn worker_retries_on_internal_error_and_fails_test() -> Result<(), Error> ) .await?; - { + let operation_id = { // Other tests check full data. We only care if we got StartAction. - match rx_from_worker.recv().await.unwrap().update { - Some(update_for_worker::Update::StartAction(_)) => { /* Success */ } + let operation_id = match rx_from_worker.recv().await.unwrap().update { + Some(update_for_worker::Update::StartAction(exec)) => exec.operation_id, v => panic!("Expected StartAction, got : {v:?}"), - } + }; // Other tests check full data. We only care if client thinks we are Executing. - assert_eq!(client_rx.borrow_and_update().stage, ActionStage::Executing); - } - - let action_info_hash_key = ActionInfoHashKey { - instance_name: INSTANCE_NAME.to_string(), - digest_function: DigestHasherFunc::Sha256, - digest: action_digest, - salt: 0, + assert_eq!( + action_listener.changed().await.unwrap().stage, + ActionStage::Executing + ); + OperationId::try_from(operation_id.as_str())? }; + let _ = scheduler .update_action( &worker_id, - action_info_hash_key.clone(), + &operation_id, Err(make_err!(Code::Internal, "Some error")), ) .await; { // Client should get notification saying it has been queued again. - let action_state = client_rx.borrow_and_update(); + let action_state = action_listener.changed().await.unwrap(); let expected_action_state = ActionState { // Name is a random string, so we ignore it and just make it the same. 
id: action_state.id.clone(), @@ -1465,18 +1534,21 @@ async fn worker_retries_on_internal_error_and_fails_test() -> Result<(), Error> v => panic!("Expected StartAction, got : {v:?}"), } // Other tests check full data. We only care if client thinks we are Executing. - assert_eq!(client_rx.borrow_and_update().stage, ActionStage::Executing); + assert_eq!( + action_listener.changed().await.unwrap().stage, + ActionStage::Executing + ); } let err = make_err!(Code::Internal, "Some error"); // Send internal error from worker again. let _ = scheduler - .update_action(&worker_id, action_info_hash_key, Err(err.clone())) + .update_action(&worker_id, &operation_id, Err(err.clone())) .await; { // Client should get notification saying it has been queued again. - let action_state = client_rx.borrow_and_update(); + let action_state = action_listener.changed().await.unwrap(); let expected_action_state = ActionState { // Name is a random string, so we ignore it and just make it the same. id: action_state.id.clone(), @@ -1501,14 +1573,23 @@ async fn worker_retries_on_internal_error_and_fails_test() -> Result<(), Error> output_upload_completed_timestamp: SystemTime::UNIX_EPOCH, }, server_logs: HashMap::default(), - error: Some(err.merge(make_err!( - Code::Internal, - "Job cancelled because it attempted to execute too many times and failed" - ))), + error: Some(err.clone()), message: String::new(), }), }; - assert_eq!(action_state.as_ref(), &expected_action_state); + let mut received_state = action_state.as_ref().clone(); + if let ActionStage::Completed(stage) = &mut received_state.stage { + if let Some(real_err) = &mut stage.error { + assert!( + real_err.to_string().contains("Job cancelled because it attempted to execute too many times and failed"), + "{real_err} did not contain 'Job cancelled because it attempted to execute too many times and failed'", + ); + *real_err = err; + } + } else { + panic!("Expected Completed, got : {:?}", action_state.stage); + }; + assert_eq!(received_state, expected_action_state); } Ok(()) @@ -1533,7 +1614,7 @@ async fn ensure_scheduler_drops_inner_spawn() -> Result<(), Error> { // Since the inner spawn owns this callback, we can use the callback to know if the // inner spawn was dropped because our callback would be dropped, which dropps our // DropChecker. - let scheduler = SimpleScheduler::new_with_callback( + let (scheduler, _worker_scheduler) = SimpleScheduler::new_with_callback( &nativelink_config::schedulers::SimpleScheduler::default(), move || { // This will ensure dropping happens if this function is ever dropped. 
@@ -1558,7 +1639,7 @@ async fn ensure_task_or_worker_change_notification_received_test() -> Result<(), let worker_id1: WorkerId = WorkerId(Uuid::new_v4()); let worker_id2: WorkerId = WorkerId(Uuid::new_v4()); - let scheduler = SimpleScheduler::new_with_callback( + let (scheduler, _worker_scheduler) = SimpleScheduler::new_with_callback( &nativelink_config::schedulers::SimpleScheduler::default(), || async move {}, ); @@ -1566,7 +1647,7 @@ async fn ensure_task_or_worker_change_notification_received_test() -> Result<(), let mut rx_from_worker1 = setup_new_worker(&scheduler, worker_id1, PlatformProperties::default()).await?; - let mut client_rx = setup_action( + let mut action_listener = setup_action( &scheduler, action_digest, PlatformProperties::default(), @@ -1577,25 +1658,24 @@ async fn ensure_task_or_worker_change_notification_received_test() -> Result<(), let mut rx_from_worker2 = setup_new_worker(&scheduler, worker_id2, PlatformProperties::default()).await?; - { + let operation_id = { // Other tests check full data. We only care if we got StartAction. - match rx_from_worker1.recv().await.unwrap().update { - Some(update_for_worker::Update::StartAction(_)) => { /* Success */ } + let operation_id = match rx_from_worker1.recv().await.unwrap().update { + Some(update_for_worker::Update::StartAction(exec)) => exec.operation_id, v => panic!("Expected StartAction, got : {v:?}"), - } + }; // Other tests check full data. We only care if client thinks we are Executing. - assert_eq!(client_rx.borrow_and_update().stage, ActionStage::Executing); - } + assert_eq!( + action_listener.changed().await.unwrap().stage, + ActionStage::Executing + ); + OperationId::try_from(operation_id.as_str())? + }; let _ = scheduler .update_action( &worker_id1, - ActionInfoHashKey { - instance_name: INSTANCE_NAME.to_string(), - digest_function: DigestHasherFunc::Sha256, - digest: action_digest, - salt: 0, - }, + &operation_id, Err(make_err!(Code::NotFound, "Some error")), ) .await; @@ -1610,7 +1690,10 @@ async fn ensure_task_or_worker_change_notification_received_test() -> Result<(), .await .err_tip(|| "worker went away")?; // Other tests check full data. We only care if client thinks we are Executing. - assert_eq!(client_rx.borrow_and_update().stage, ActionStage::Executing); + assert_eq!( + action_listener.changed().await.unwrap().stage, + ActionStage::Executing + ); } Ok(()) diff --git a/nativelink-scheduler/tests/utils/mock_scheduler.rs b/nativelink-scheduler/tests/utils/mock_scheduler.rs index 9afd1dd6b..f02a85f13 100644 --- a/nativelink-scheduler/tests/utils/mock_scheduler.rs +++ b/nativelink-scheduler/tests/utils/mock_scheduler.rs @@ -12,26 +12,27 @@ // See the License for the specific language governing permissions and // limitations under the License. 
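The test doubles in this file (and `MockWorkerStateManager` later in this diff) share one rendezvous pattern: the mocked trait method pushes its arguments onto an unbounded "calls" channel, then parks on a "responses" channel until the test's `expect_*` helper records the call and supplies a canned return value. A minimal sketch of the pattern, with a single mocked method and hypothetical names (not part of this change):

```rust
use std::sync::Arc;

use tokio::sync::{mpsc, Mutex};

// Hypothetical one-method mock following the same call/response shape.
struct MockEcho {
    tx_call: mpsc::UnboundedSender<String>,
    rx_call: Arc<Mutex<mpsc::UnboundedReceiver<String>>>,
    tx_resp: mpsc::UnboundedSender<Result<String, String>>,
    rx_resp: Arc<Mutex<mpsc::UnboundedReceiver<Result<String, String>>>>,
}

impl MockEcho {
    fn new() -> Self {
        let (tx_call, rx_call) = mpsc::unbounded_channel();
        let (tx_resp, rx_resp) = mpsc::unbounded_channel();
        Self {
            tx_call,
            rx_call: Arc::new(Mutex::new(rx_call)),
            tx_resp,
            rx_resp: Arc::new(Mutex::new(rx_resp)),
        }
    }

    // Invoked by the code under test: record the call, await the canned reply.
    async fn echo(&self, input: String) -> Result<String, String> {
        self.tx_call.send(input).expect("Could not send request to mpsc");
        self.rx_resp.lock().await.recv().await.expect("Could not receive msg in mpsc")
    }

    // Invoked by the test body: returns the recorded argument and unblocks `echo`.
    async fn expect_echo(&self, result: Result<String, String>) -> String {
        let req = self.rx_call.lock().await.recv().await.expect("Could not receive msg in mpsc");
        self.tx_resp.send(result).expect("Could not send response to mpsc");
        req
    }
}
```

Both sides must run concurrently (for example under `tokio::join!`), which is exactly how `execution_response_success_test` later pairs the server call with `expect_update_operation`.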
+use std::pin::Pin; use std::sync::Arc; use async_trait::async_trait; use nativelink_error::{make_input_err, Error}; -use nativelink_scheduler::action_scheduler::ActionScheduler; +use nativelink_scheduler::action_scheduler::{ActionListener, ActionScheduler}; use nativelink_scheduler::platform_property_manager::PlatformPropertyManager; -use nativelink_util::action_messages::{ActionInfo, ActionInfoHashKey, ActionState}; -use tokio::sync::{mpsc, watch, Mutex}; +use nativelink_util::action_messages::{ActionInfo, ClientOperationId}; +use tokio::sync::{mpsc, Mutex}; #[allow(clippy::large_enum_variant)] enum ActionSchedulerCalls { GetPlatformPropertyManager(String), - AddAction(ActionInfo), - FindExistingAction(ActionInfoHashKey), + AddAction((ClientOperationId, ActionInfo)), + FindExistingAction(ClientOperationId), } enum ActionSchedulerReturns { GetPlatformPropertyManager(Result, Error>), - AddAction(Result>, Error>), - FindExistingAction(Option>>), + AddAction(Result>, Error>), + FindExistingAction(Result>>, Error>), } pub struct MockActionScheduler { @@ -81,8 +82,8 @@ impl MockActionScheduler { pub async fn expect_add_action( &self, - result: Result>, Error>, - ) -> ActionInfo { + result: Result>, Error>, + ) -> (ClientOperationId, ActionInfo) { let mut rx_call_lock = self.rx_call.lock().await; let ActionSchedulerCalls::AddAction(req) = rx_call_lock .recv() @@ -98,17 +99,17 @@ impl MockActionScheduler { req } - pub async fn expect_find_existing_action( + pub async fn expect_find_by_client_operation_id( &self, - result: Option>>, - ) -> ActionInfoHashKey { + result: Result>>, Error>, + ) -> ClientOperationId { let mut rx_call_lock = self.rx_call.lock().await; let ActionSchedulerCalls::FindExistingAction(req) = rx_call_lock .recv() .await .expect("Could not receive msg in mpsc") else { - panic!("Got incorrect call waiting for find_existing_action") + panic!("Got incorrect call waiting for find_by_client_operation_id") }; self.tx_resp .send(ActionSchedulerReturns::FindExistingAction(result)) @@ -142,10 +143,14 @@ impl ActionScheduler for MockActionScheduler { async fn add_action( &self, + client_operation_id: ClientOperationId, action_info: ActionInfo, - ) -> Result>, Error> { + ) -> Result>, Error> { self.tx_call - .send(ActionSchedulerCalls::AddAction(action_info)) + .send(ActionSchedulerCalls::AddAction(( + client_operation_id, + action_info, + ))) .expect("Could not send request to mpsc"); let mut rx_resp_lock = self.rx_resp.lock().await; match rx_resp_lock @@ -158,13 +163,13 @@ impl ActionScheduler for MockActionScheduler { } } - async fn find_existing_action( + async fn find_by_client_operation_id( &self, - unique_qualifier: &ActionInfoHashKey, - ) -> Option>> { + client_operation_id: &ClientOperationId, + ) -> Result>>, Error> { self.tx_call .send(ActionSchedulerCalls::FindExistingAction( - unique_qualifier.clone(), + client_operation_id.clone(), )) .expect("Could not send request to mpsc"); let mut rx_resp_lock = self.rx_resp.lock().await; @@ -174,9 +179,7 @@ impl ActionScheduler for MockActionScheduler { .expect("Could not receive msg in mpsc") { ActionSchedulerReturns::FindExistingAction(result) => result, - _ => panic!("Expected find_existing_action return value"), + _ => panic!("Expected find_by_client_operation_id return value"), } } - - async fn clean_recently_completed_actions(&self) {} } diff --git a/nativelink-scheduler/tests/utils/scheduler_utils.rs b/nativelink-scheduler/tests/utils/scheduler_utils.rs index b98c4cdf8..8d9635b9c 100644 --- 
a/nativelink-scheduler/tests/utils/scheduler_utils.rs +++ b/nativelink-scheduler/tests/utils/scheduler_utils.rs @@ -15,14 +15,17 @@ use std::collections::HashMap; use std::time::{Duration, SystemTime, UNIX_EPOCH}; -use nativelink_util::action_messages::{ActionInfo, ActionInfoHashKey}; +use nativelink_util::action_messages::{ActionInfo, ActionUniqueKey, ActionUniqueQualifier}; use nativelink_util::common::DigestInfo; use nativelink_util::digest_hasher::DigestHasherFunc; use nativelink_util::platform_properties::PlatformProperties; pub const INSTANCE_NAME: &str = "foobar_instance_name"; -pub fn make_base_action_info(insert_timestamp: SystemTime) -> ActionInfo { +pub fn make_base_action_info( + insert_timestamp: SystemTime, + action_digest: DigestInfo, +) -> ActionInfo { ActionInfo { command_digest: DigestInfo::new([0u8; 32], 0), input_root_digest: DigestInfo::new([0u8; 32], 0), @@ -33,12 +36,10 @@ pub fn make_base_action_info(insert_timestamp: SystemTime) -> ActionInfo { priority: 0, load_timestamp: UNIX_EPOCH, insert_timestamp, - unique_qualifier: ActionInfoHashKey { + unique_qualifier: ActionUniqueQualifier::Cachable(ActionUniqueKey { instance_name: INSTANCE_NAME.to_string(), digest_function: DigestHasherFunc::Sha256, - digest: DigestInfo::new([0u8; 32], 0), - salt: 0, - }, - skip_cache_lookup: false, + digest: action_digest, + }), } } diff --git a/nativelink-service/BUILD.bazel b/nativelink-service/BUILD.bazel index 57e53aa6f..f8f47072f 100644 --- a/nativelink-service/BUILD.bazel +++ b/nativelink-service/BUILD.bazel @@ -55,6 +55,7 @@ rust_test_suite( ], proc_macro_deps = [ "//nativelink-macro", + "@crates//:async-trait", ], deps = [ "//nativelink-config", @@ -64,6 +65,7 @@ rust_test_suite( "//nativelink-service", "//nativelink-store", "//nativelink-util", + "@crates//:async-lock", "@crates//:bytes", "@crates//:futures", "@crates//:hyper", diff --git a/nativelink-service/Cargo.toml b/nativelink-service/Cargo.toml index 18d889eeb..983d0f5d7 100644 --- a/nativelink-service/Cargo.toml +++ b/nativelink-service/Cargo.toml @@ -28,6 +28,8 @@ uuid = { version = "1.8.0", features = ["v4"] } [dev-dependencies] nativelink-macro = { path = "../nativelink-macro" } +async-trait = "0.1.80" +async-lock = "3.3.0" hyper = "0.14.28" maplit = "1.0.2" pretty_assertions = "1.4.0" diff --git a/nativelink-service/src/execution_server.rs b/nativelink-service/src/execution_server.rs index a42f0ef03..b89d80583 100644 --- a/nativelink-service/src/execution_server.rs +++ b/nativelink-service/src/execution_server.rs @@ -17,7 +17,8 @@ use std::pin::Pin; use std::sync::Arc; use std::time::{Duration, SystemTime, UNIX_EPOCH}; -use futures::{Stream, StreamExt}; +use futures::stream::unfold; +use futures::Stream; use nativelink_config::cas_server::{ExecutionConfig, InstanceName}; use nativelink_error::{make_input_err, Error, ResultExt}; use nativelink_proto::build::bazel::remote::execution::v2::execution_server::{ @@ -27,22 +28,47 @@ use nativelink_proto::build::bazel::remote::execution::v2::{ Action, Command, ExecuteRequest, WaitExecutionRequest, }; use nativelink_proto::google::longrunning::Operation; -use nativelink_scheduler::action_scheduler::ActionScheduler; +use nativelink_scheduler::action_scheduler::{ActionListener, ActionScheduler}; use nativelink_store::ac_utils::get_and_decode_digest; use nativelink_store::store_manager::StoreManager; use nativelink_util::action_messages::{ - ActionInfo, ActionInfoHashKey, ActionState, OperationId, DEFAULT_EXECUTION_PRIORITY, + ActionInfo, ActionUniqueKey, ActionUniqueQualifier, 
ClientOperationId, + DEFAULT_EXECUTION_PRIORITY, }; use nativelink_util::common::DigestInfo; use nativelink_util::digest_hasher::{make_ctx_for_hash_func, DigestHasherFunc}; use nativelink_util::platform_properties::PlatformProperties; use nativelink_util::store_trait::Store; -use rand::{thread_rng, Rng}; -use tokio::sync::watch; -use tokio_stream::wrappers::WatchStream; use tonic::{Request, Response, Status}; use tracing::{error_span, event, instrument, Level}; +type InstanceInfoName = String; + +struct NativelinkClientOperationId { + instance_name: InstanceInfoName, + client_operation_id: ClientOperationId, +} + +impl NativelinkClientOperationId { + fn from_name(name: &str) -> Result { + let (instance_name, name) = name + .split_once('/') + .err_tip(|| "Expected instance_name and name to be separated by '/'")?; + Ok(Self { + instance_name: instance_name.to_string(), + client_operation_id: ClientOperationId::from_raw_string(name.to_string()), + }) + } + + fn into_string(self) -> String { + format!( + "{}/{}", + self.instance_name, + self.client_operation_id.into_string() + ) + } +} + struct InstanceInfo { scheduler: Arc, cas_store: Store, @@ -112,6 +138,17 @@ impl InstanceInfo { } } + let action_key = ActionUniqueKey { + instance_name, + digest_function, + digest: action_digest, + }; + let unique_qualifier = if skip_cache_lookup { + ActionUniqueQualifier::Uncachable(action_key) + } else { + ActionUniqueQualifier::Cachable(action_key) + }; + Ok(ActionInfo { command_digest, input_root_digest, @@ -120,17 +157,7 @@ impl InstanceInfo { priority, load_timestamp: UNIX_EPOCH, insert_timestamp: SystemTime::now(), - unique_qualifier: ActionInfoHashKey { - instance_name, - digest_function, - digest: action_digest, - salt: if action.do_not_cache { - thread_rng().gen::() - } else { - 0 - }, - }, - skip_cache_lookup, + unique_qualifier, }) } } @@ -139,7 +166,7 @@ pub struct ExecutionServer { instance_infos: HashMap, } -type ExecuteStream = Pin> + Send + Sync + 'static>>; +type ExecuteStream = Pin> + Send + 'static>>; impl ExecutionServer { pub fn new( @@ -179,11 +206,42 @@ impl ExecutionServer { Server::new(self) } - fn to_execute_stream(receiver: watch::Receiver>) -> Response { - let receiver_stream = Box::pin(WatchStream::new(receiver).map(|action_update| { - event!(Level::INFO, ?action_update, "Execute Resp Stream",); - Ok(Into::::into(action_update.as_ref().clone())) - })); + fn to_execute_stream( + nl_client_operation_id: NativelinkClientOperationId, + action_listener: Pin>, + ) -> Response { + let client_operation_id_string = nl_client_operation_id.into_string(); + let receiver_stream = Box::pin(unfold( + Some(action_listener), + move |maybe_action_listener| { + let client_operation_id_string = client_operation_id_string.clone(); + async move { + let mut action_listener = maybe_action_listener?; + match action_listener.changed().await { + Ok(action_update) => { + event!(Level::INFO, ?action_update, "Execute Resp Stream"); + let client_operation_id = ClientOperationId::from_raw_string( + client_operation_id_string.clone(), + ); + // If the action is finished we won't be sending any more updates. 
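An aside on the streaming shape above: `to_execute_stream` threads the listener through `futures::stream::unfold` as loop state and yields `None` as the next seed once a terminal stage is observed. A standalone sketch of that idea, assuming a toy `Listener` type and dropping the error-forwarding that the real code keeps:

```rust
use futures::{stream, StreamExt};

// Toy stand-in for the boxed ActionListener used above.
struct Listener {
    updates: Vec<u32>,
}

impl Listener {
    // Mimics `changed().await`: yields the next update, if any.
    async fn changed(&mut self) -> Option<u32> {
        self.updates.pop()
    }
}

async fn demo() {
    let listener = Listener { updates: vec![3, 2, 1] };
    // The listener itself is the unfold state; yielding `None` as the next
    // state ends the stream after a terminal update, just like above.
    let updates = stream::unfold(Some(listener), |maybe_listener| async move {
        let mut listener = maybe_listener?;
        let update = listener.changed().await?;
        let is_finished = update == 3; // Stand-in for `stage.is_finished()`.
        Some((update, (!is_finished).then_some(listener)))
    });
    assert_eq!(updates.collect::<Vec<_>>().await, vec![1, 2, 3]);
}
```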
+ let maybe_action_listener = if action_update.stage.is_finished() { + None + } else { + Some(action_listener) + }; + Some(( + Ok(action_update.as_operation(client_operation_id)), + maybe_action_listener, + )) + } + Err(err) => { + event!(Level::ERROR, ?err, "Error in action_listener stream"); + Some((Err(err.into()), None)) + } + } + } + }, + )); tonic::Response::new(receiver_stream) } @@ -213,7 +271,7 @@ impl ExecutionServer { get_and_decode_digest::<Action>(&instance_info.cas_store, digest.into()).await?; let action_info = instance_info .build_action_info( - instance_name, + instance_name.clone(), digest, &action, priority, @@ -225,38 +283,52 @@ impl ExecutionServer { ) .await?; - let rx = instance_info + let action_listener = instance_info .scheduler - .add_action(action_info) + .add_action( + ClientOperationId::new(action_info.unique_qualifier.clone()), + action_info, + ) .await .err_tip(|| "Failed to schedule task")?; - Ok(Self::to_execute_stream(rx)) + Ok(Self::to_execute_stream( + NativelinkClientOperationId { + instance_name, + client_operation_id: action_listener.client_operation_id().clone(), + }, + action_listener, + )) } async fn inner_wait_execution( &self, request: Request<WaitExecutionRequest>, ) -> Result<Response<ExecuteStream>, Status> { - let operation_id = OperationId::try_from(request.into_inner().name.as_str()) - .err_tip(|| "Decoding operation name into OperationId")?; - let Some(instance_info) = self - .instance_infos - .get(&operation_id.unique_qualifier.instance_name) - else { + let (instance_name, client_operation_id) = + NativelinkClientOperationId::from_name(&request.into_inner().name) + .map(|v| (v.instance_name, v.client_operation_id)) + .err_tip(|| "Failed to parse operation_id in ExecutionServer::wait_execution")?; + let Some(instance_info) = self.instance_infos.get(&instance_name) else { return Err(Status::not_found(format!( - "No scheduler with the instance name {}", - operation_id.unique_qualifier.instance_name + "No scheduler with the instance name {instance_name}" ))); }; let Some(rx) = instance_info .scheduler - .find_existing_action(&operation_id.unique_qualifier) + .find_by_client_operation_id(&client_operation_id) .await + .err_tip(|| "Error running find_by_client_operation_id in ExecutionServer::wait_execution")?
else { return Err(Status::not_found("Failed to find existing task")); }; - Ok(Self::to_execute_stream(rx)) + Ok(Self::to_execute_stream( + NativelinkClientOperationId { + instance_name, + client_operation_id, + }, + rx, + )) } } diff --git a/nativelink-service/src/worker_api_server.rs b/nativelink-service/src/worker_api_server.rs index 2a4a81ead..6061d6d55 100644 --- a/nativelink-service/src/worker_api_server.rs +++ b/nativelink-service/src/worker_api_server.rs @@ -27,11 +27,10 @@ use nativelink_proto::com::github::trace_machina::nativelink::remote_execution:: use nativelink_proto::com::github::trace_machina::nativelink::remote_execution::{ execute_result, ExecuteResult, GoingAwayRequest, KeepAliveRequest, SupportedProperties, UpdateForWorker, }; -use nativelink_scheduler::worker::{Worker}; +use nativelink_scheduler::worker::Worker; use nativelink_scheduler::worker_scheduler::WorkerScheduler; use nativelink_util::background_spawn; -use nativelink_util::action_messages::{ActionInfoHashKey, WorkerId}; -use nativelink_util::common::DigestInfo; +use nativelink_util::action_messages::{OperationId, WorkerId}; use nativelink_util::platform_properties::PlatformProperties; use tokio::sync::mpsc; use tokio::time::interval; @@ -188,7 +187,10 @@ impl WorkerApiServer { going_away_request: GoingAwayRequest, ) -> Result, Error> { let worker_id: WorkerId = going_away_request.worker_id.try_into()?; - self.scheduler.remove_worker(worker_id).await; + self.scheduler + .remove_worker(&worker_id) + .await + .err_tip(|| "While calling WorkerApiServer::inner_going_away")?; Ok(Response::new(())) } @@ -196,21 +198,11 @@ impl WorkerApiServer { &self, execute_result: ExecuteResult, ) -> Result, Error> { - let digest_function = execute_result - .digest_function() - .try_into() - .err_tip(|| "In inner_execution_response")?; let worker_id: WorkerId = execute_result.worker_id.try_into()?; - let action_digest: DigestInfo = execute_result - .action_digest - .err_tip(|| "Expected action_digest to exist")? 
- .try_into()?; - let action_info_hash_key = ActionInfoHashKey { - instance_name: execute_result.instance_name, - digest_function, - digest: action_digest, - salt: execute_result.salt, - }; + let operation_id = + OperationId::try_from(execute_result.operation_id.as_str()).err_tip(|| { + "Failed to convert operation_id in WorkerApiServer::inner_execution_response" + })?; match execute_result .result .err_tip(|| "Expected result to exist in ExecuteResult")? { execute_result::Result::ExecuteResponse(finished_result) => { let action_stage = finished_result .try_into() .err_tip(|| "Failed to convert ExecuteResponse into an ActionStage")?; self.scheduler - .update_action(&worker_id, action_info_hash_key, Ok(action_stage)) + .update_action(&worker_id, &operation_id, Ok(action_stage)) .await - .err_tip(|| format!("Failed to update_action {action_digest:?}"))?; + .err_tip(|| format!("Failed to update operation {operation_id:?}"))?; } execute_result::Result::InternalError(e) => { self.scheduler - .update_action(&worker_id, action_info_hash_key, Err(e.into())) + .update_action(&worker_id, &operation_id, Err(e.into())) .await - .err_tip(|| format!("Failed to update_action {action_digest:?}"))?; + .err_tip(|| format!("Failed to update operation {operation_id:?}"))?; } } Ok(Response::new(())) diff --git a/nativelink-service/tests/worker_api_server_test.rs b/nativelink-service/tests/worker_api_server_test.rs index 4d5f3d84f..e0c28cef5 100644 --- a/nativelink-service/tests/worker_api_server_test.rs +++ b/nativelink-service/tests/worker_api_server_test.rs @@ -16,7 +16,10 @@ use std::collections::HashMap; use std::sync::{Arc, Mutex}; use std::time::{Duration, SystemTime, UNIX_EPOCH}; +use async_lock::Mutex as AsyncMutex; +use async_trait::async_trait; use nativelink_config::cas_server::WorkerApiConfig; +use nativelink_config::schedulers::WorkerAllocationStrategy; use nativelink_error::{Error, ResultExt}; use nativelink_macro::nativelink_test; use nativelink_proto::build::bazel::remote::execution::v2::{ @@ -28,23 +31,99 @@ use nativelink_proto::com::github::trace_machina::nativelink::remote_execution:: execute_result, update_for_worker, ExecuteResult, KeepAliveRequest, SupportedProperties, }; use nativelink_proto::google::rpc::Status as ProtoStatus; -use nativelink_scheduler::action_scheduler::ActionScheduler; -use nativelink_scheduler::simple_scheduler::SimpleScheduler; +use nativelink_scheduler::api_worker_scheduler::ApiWorkerScheduler; +use nativelink_scheduler::platform_property_manager::PlatformPropertyManager; use nativelink_scheduler::worker_scheduler::WorkerScheduler; use nativelink_service::worker_api_server::{ConnectWorkerStream, NowFn, WorkerApiServer}; -use nativelink_util::action_messages::{ActionInfo, ActionInfoHashKey, ActionStage, WorkerId}; +use nativelink_util::action_messages::{ + ActionInfo, ActionStage, ActionUniqueKey, ActionUniqueQualifier, OperationId, WorkerId, +}; use nativelink_util::common::DigestInfo; use nativelink_util::digest_hasher::DigestHasherFunc; +use nativelink_util::operation_state_manager::WorkerStateManager; use nativelink_util::platform_properties::PlatformProperties; use pretty_assertions::assert_eq; +use tokio::join; +use tokio::sync::{mpsc, Notify}; use tokio_stream::StreamExt; use tonic::Request; const BASE_NOW_S: u64 = 10; const BASE_WORKER_TIMEOUT_S: u64 = 100; +#[derive(Debug)] +enum WorkerStateManagerCalls { + UpdateOperation((OperationId, WorkerId, Result<ActionStage, Error>)), +} + +#[derive(Debug)] +enum WorkerStateManagerReturns { + UpdateOperation(Result<(), Error>), +} + +struct MockWorkerStateManager { + rx_call: Arc<AsyncMutex<mpsc::UnboundedReceiver<WorkerStateManagerCalls>>>, + tx_call: mpsc::UnboundedSender<WorkerStateManagerCalls>, + rx_resp: Arc<AsyncMutex<mpsc::UnboundedReceiver<WorkerStateManagerReturns>>>, + tx_resp:
mpsc::UnboundedSender, +} + +impl MockWorkerStateManager { + pub fn new() -> Self { + let (tx_call, rx_call) = mpsc::unbounded_channel(); + let (tx_resp, rx_resp) = mpsc::unbounded_channel(); + Self { + rx_call: Arc::new(AsyncMutex::new(rx_call)), + tx_call, + rx_resp: Arc::new(AsyncMutex::new(rx_resp)), + tx_resp, + } + } + + pub async fn expect_update_operation( + &self, + result: Result<(), Error>, + ) -> (OperationId, WorkerId, Result) { + let mut rx_call_lock = self.rx_call.lock().await; + let recv = rx_call_lock.recv(); + let WorkerStateManagerCalls::UpdateOperation(req) = + recv.await.expect("Could not receive msg in mpsc"); + self.tx_resp + .send(WorkerStateManagerReturns::UpdateOperation(result)) + .expect("Could not send request to mpsc"); + req + } +} + +#[async_trait] +impl WorkerStateManager for MockWorkerStateManager { + async fn update_operation( + &self, + operation_id: &OperationId, + worker_id: &WorkerId, + action_stage: Result, + ) -> Result<(), Error> { + self.tx_call + .send(WorkerStateManagerCalls::UpdateOperation(( + operation_id.clone(), + *worker_id, + action_stage, + ))) + .expect("Could not send request to mpsc"); + let mut rx_resp_lock = self.rx_resp.lock().await; + match rx_resp_lock + .recv() + .await + .expect("Could not receive msg in mpsc") + { + WorkerStateManagerReturns::UpdateOperation(result) => result, + } + } +} + struct TestContext { - scheduler: Arc, + scheduler: Arc, + state_manager: Arc, worker_api_server: WorkerApiServer, connection_worker_stream: ConnectWorkerStream, worker_id: WorkerId, @@ -57,12 +136,16 @@ fn static_now_fn() -> Result { async fn setup_api_server(worker_timeout: u64, now_fn: NowFn) -> Result { const SCHEDULER_NAME: &str = "DUMMY_SCHEDULE_NAME"; - let scheduler = Arc::new(SimpleScheduler::new( - &nativelink_config::schedulers::SimpleScheduler { - worker_timeout_s: worker_timeout, - ..Default::default() - }, - )); + let platform_property_manager = Arc::new(PlatformPropertyManager::new(HashMap::new())); + let tasks_or_worker_change_notify = Arc::new(Notify::new()); + let state_manager = Arc::new(MockWorkerStateManager::new()); + let scheduler = ApiWorkerScheduler::new( + state_manager.clone(), + platform_property_manager, + WorkerAllocationStrategy::default(), + tasks_or_worker_change_notify, + worker_timeout, + ); let mut schedulers: HashMap> = HashMap::new(); schedulers.insert(SCHEDULER_NAME.to_string(), scheduler.clone()); @@ -107,6 +190,7 @@ async fn setup_api_server(worker_timeout: u64, now_fn: NowFn) -> Result Result<(), Box SystemTime { #[nativelink_test] pub async fn execution_response_success_test() -> Result<(), Box> { - let test_context = setup_api_server(BASE_WORKER_TIMEOUT_S, Box::new(static_now_fn)).await?; + let mut test_context = setup_api_server(BASE_WORKER_TIMEOUT_S, Box::new(static_now_fn)).await?; - const SALT: u64 = 5; let action_digest = DigestInfo::new([7u8; 32], 123); let instance_name = "instance_name".to_string(); - let action_info = ActionInfo { + let unique_qualifier = ActionUniqueQualifier::Uncachable(ActionUniqueKey { + instance_name: instance_name.clone(), + digest_function: DigestHasherFunc::Sha256, + digest: action_digest, + }); + let action_info = Arc::new(ActionInfo { command_digest: DigestInfo::new([0u8; 32], 0), input_root_digest: DigestInfo::new([0u8; 32], 0), timeout: Duration::MAX, @@ -300,15 +389,18 @@ pub async fn execution_response_success_test() -> Result<(), Box Result<(), Box { - drop(action_state); - // Note: `.changed()` might be triggered twice, since the first trigger - // might be 
Queued and the second will always be Executing, but there's no - // guarantee that the first trigger will be Queued. - client_action_state_receiver.changed().await?; - client_action_state_receiver.borrow() - } - _ => client_action_state_receiver.borrow(), - }; - assert_eq!(action_state.stage, ActionStage::Executing); - } - // Now send the result of our execution to the scheduler. - test_context - .worker_api_server - .execution_response(Request::new(result.clone())) - .await?; + let update_for_worker = test_context + .connection_worker_stream + .next() + .await + .expect("Worker stream ended early")? + .update + .expect("Expected update field to be populated"); + let update_for_worker::Update::StartAction(start_execute) = update_for_worker else { + panic!("Expected StartAction message"); + }; + assert_eq!(result.operation_id, start_execute.operation_id); { - // Check the result that the client would have received. - client_action_state_receiver.changed().await?; - let client_given_state = client_action_state_receiver.borrow(); - let execute_response = - if let execute_result::Result::ExecuteResponse(v) = result.result.unwrap() { - v - } else { - panic!("Expected type to be ExecuteResponse"); - }; - - assert_eq!( - client_given_state.stage, - execute_response.clone().try_into()? + // Ensure our state manager got the same result as the server. + let (execution_response_result, (operation_id, worker_id, client_given_state)) = join!( + test_context + .worker_api_server + .execution_response(Request::new(result.clone())), + test_context.state_manager.expect_update_operation(Ok(())), ); + execution_response_result.unwrap(); - // We just checked if conversion from ExecuteResponse into ActionStage was an exact mach. - // Now check if we cast the ActionStage into an ExecuteResponse we get the exact same struct. 
- assert_eq!(execute_response, client_given_state.stage.clone().into()); + assert_eq!(operation_id, expected_operation_id); + assert_eq!(worker_id, test_context.worker_id); + assert_eq!(client_given_state, Ok(execute_response.clone().try_into()?)); + assert_eq!(execute_response, client_given_state.unwrap().into()); } Ok(()) } diff --git a/nativelink-store/tests/cas_utils_test.rs b/nativelink-store/tests/cas_utils_test.rs index 497bfb58a..ae84a1ff7 100644 --- a/nativelink-store/tests/cas_utils_test.rs +++ b/nativelink-store/tests/cas_utils_test.rs @@ -23,7 +23,7 @@ fn sha256_is_zero_digest() { packed_hash: Sha256::new().finalize().into(), size_bytes: 0, }; - assert!(is_zero_digest(&digest)); + assert!(is_zero_digest(digest)); } #[test] @@ -34,7 +34,7 @@ fn sha256_is_non_zero_digest() { packed_hash: hasher.finalize().into(), size_bytes: 1, }; - assert!(!is_zero_digest(&digest)); + assert!(!is_zero_digest(digest)); } #[test] @@ -43,7 +43,7 @@ fn blake_is_zero_digest() { packed_hash: Blake3::new().finalize().into(), size_bytes: 0, }; - assert!(is_zero_digest(&digest)); + assert!(is_zero_digest(digest)); } #[test] @@ -54,5 +54,5 @@ fn blake_is_non_zero_digest() { packed_hash: hasher.finalize().into(), size_bytes: 1, }; - assert!(!is_zero_digest(&digest)); + assert!(!is_zero_digest(digest)); } diff --git a/nativelink-util/BUILD.bazel b/nativelink-util/BUILD.bazel index aae3c1cb5..7175719ca 100644 --- a/nativelink-util/BUILD.bazel +++ b/nativelink-util/BUILD.bazel @@ -11,6 +11,7 @@ rust_library( srcs = [ "src/action_messages.rs", "src/buf_channel.rs", + "src/chunked_stream.rs", "src/common.rs", "src/connection_manager.rs", "src/default_store_key_subscribe.rs", @@ -21,6 +22,7 @@ rust_library( "src/health_utils.rs", "src/lib.rs", "src/metrics_utils.rs", + "src/operation_state_manager.rs", "src/origin_context.rs", "src/platform_properties.rs", "src/proto_stream_utils.rs", @@ -40,6 +42,7 @@ rust_library( "//nativelink-error", "//nativelink-proto", "@crates//:async-lock", + "@crates//:bitflags", "@crates//:blake3", "@crates//:bytes", "@crates//:console-subscriber", @@ -49,6 +52,7 @@ rust_library( "@crates//:hyper-util", "@crates//:lru", "@crates//:parking_lot", + "@crates//:pin-project", "@crates//:pin-project-lite", "@crates//:prometheus-client", "@crates//:prost", diff --git a/nativelink-util/Cargo.toml b/nativelink-util/Cargo.toml index c15a014ef..a11c6a5e6 100644 --- a/nativelink-util/Cargo.toml +++ b/nativelink-util/Cargo.toml @@ -10,8 +10,10 @@ nativelink-proto = { path = "../nativelink-proto" } async-lock = "3.3.0" async-trait = "0.1.80" +bitflags = "2.5.0" blake3 = { version = "1.5.1", features = ["mmap"] } bytes = "1.6.0" +pin-project = "1.1.5" console-subscriber = { version = "0.3.0" } futures = "0.3.30" hex = "0.4.3" diff --git a/nativelink-util/src/action_messages.rs b/nativelink-util/src/action_messages.rs index b1b185b23..e2e6351c8 100644 --- a/nativelink-util/src/action_messages.rs +++ b/nativelink-util/src/action_messages.rs @@ -12,14 +12,11 @@ // See the License for the specific language governing permissions and // limitations under the License. 
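The `action_messages.rs` changes below split identity in two: `OperationId` names the underlying operation, while each attached client is minted its own `ClientOperationId`. A hedged sketch of that relationship using only types from this diff:

```rust
use nativelink_util::action_messages::{
    ActionUniqueKey, ActionUniqueQualifier, ClientOperationId,
};
use nativelink_util::common::DigestInfo;
use nativelink_util::digest_hasher::DigestHasherFunc;

fn demo() {
    let qualifier = ActionUniqueQualifier::Cachable(ActionUniqueKey {
        instance_name: "main".to_string(),
        digest_function: DigestHasherFunc::Sha256,
        digest: DigestInfo::new([0u8; 32], 0),
    });
    // Each call embeds a freshly generated UUID, so two clients joining the
    // same underlying action still get distinct client-facing ids.
    let client_a = ClientOperationId::new(qualifier.clone());
    let client_b = ClientOperationId::new(qualifier);
    assert_ne!(client_a, client_b);
}
```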
-use std::borrow::Borrow; use std::cmp::Ordering; use std::collections::HashMap; use std::hash::{Hash, Hasher}; -use std::sync::Arc; use std::time::{Duration, SystemTime}; -use blake3::Hasher as Blake3Hasher; use nativelink_error::{error_if, make_input_err, Error, ResultExt}; use nativelink_proto::build::bazel::remote::execution::v2::{ execution_stage, Action, ActionResult as ProtoActionResult, ExecuteOperationMetadata, @@ -43,39 +40,74 @@ use crate::platform_properties::PlatformProperties; /// Default priority remote execution jobs will get when not provided. pub const DEFAULT_EXECUTION_PRIORITY: i32 = 0; -pub type WorkerTimestamp = u64; +/// Exit code sent if there is an internal error. +pub const INTERNAL_ERROR_EXIT_CODE: i32 = -178; + +/// Holds an id that is unique to the client for a requested operation. +/// Each client should be issued a unique id even if they are attached +/// to the same underlying operation. +#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] +pub struct ClientOperationId(String); -#[derive(Clone, PartialEq, Eq, Serialize, Deserialize)] +impl ClientOperationId { + pub fn new(unique_qualifier: ActionUniqueQualifier) -> Self { + Self(OperationId::new(unique_qualifier).to_string()) + } + + pub fn from_raw_string(name: String) -> Self { + Self(name) + } + + pub fn into_string(self) -> String { + self.0 + } +} + +impl std::fmt::Display for ClientOperationId { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_fmt(format_args!("{}", self.0.clone())) + } +} + +#[derive(Clone, Serialize, Deserialize)] pub struct OperationId { - pub unique_qualifier: ActionInfoHashKey, + pub unique_qualifier: ActionUniqueQualifier, pub id: Uuid, } -// TODO: Eventually we should make this it's own hash rather than delegate to ActionInfoHashKey. +impl PartialEq for OperationId { + fn eq(&self, other: &Self) -> bool { + self.id.eq(&other.id) + } +} + +impl Eq for OperationId {} + +impl PartialOrd for OperationId { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl Ord for OperationId { + fn cmp(&self, other: &Self) -> Ordering { + self.id.cmp(&other.id) + } +} + impl Hash for OperationId { fn hash(&self, state: &mut H) { - ActionInfoHashKey::hash(&self.unique_qualifier, state) + self.id.hash(state) } } impl OperationId { - pub fn new(unique_qualifier: ActionInfoHashKey) -> Self { + pub fn new(unique_qualifier: ActionUniqueQualifier) -> Self { Self { - id: uuid::Uuid::new_v4(), + id: Uuid::new_v4(), unique_qualifier, } } - - /// Utility function used to make a unique hash of the digest including the salt. - pub fn get_hash(&self) -> [u8; 32] { - self.unique_qualifier.get_hash() - } - - /// Returns the salt used for cache busting/hashing. - #[inline] - pub fn action_name(&self) -> String { - self.unique_qualifier.action_name() - } } impl TryFrom<&str> for OperationId { @@ -84,7 +116,7 @@ impl TryFrom<&str> for OperationId { /// Attempts to convert a string slice into an `OperationId`. /// /// The input string `value` is expected to be in the format: - /// `//-//`. + /// `//-//`. 
/// /// # Parameters /// @@ -105,7 +137,7 @@ impl TryFrom<&str> for OperationId { /// /// ```no_run /// use nativelink_util::action_messages::OperationId; - /// let operation_id_str = "main/SHA256/4a0885a39d5ba8da3123c02ff56b73196a8b23fd3c835e1446e74a3a3ff4313f-211/0/19b16cf8-a1ad-4948-aaac-b6f4eb7fca52"; + /// let operation_id_str = "main/SHA256/4a0885a39d5ba8da3123c02ff56b73196a8b23fd3c835e1446e74a3a3ff4313f-211/u/19b16cf8-a1ad-4948-aaac-b6f4eb7fca52"; /// let operation_id = OperationId::try_from(operation_id_str); /// ``` /// @@ -119,30 +151,41 @@ impl TryFrom<&str> for OperationId { .err_tip(|| format!("Invalid OperationId unique_qualifier / id fragment - {value}"))?; let (instance_name, rest) = unique_qualifier .split_once('/') - .err_tip(|| format!("Invalid ActionInfoHashKey instance name fragment - {value}"))?; + .err_tip(|| format!("Invalid UniqueQualifier instance name fragment - {value}"))?; let (digest_function, rest) = rest .split_once('/') - .err_tip(|| format!("Invalid ActionInfoHashKey digest function fragment - {value}"))?; + .err_tip(|| format!("Invalid UniqueQualifier digest function fragment - {value}"))?; let (digest_hash, rest) = rest .split_once('-') - .err_tip(|| format!("Invalid ActionInfoHashKey digest hash fragment - {value}"))?; - let (digest_size, salt) = rest + .err_tip(|| format!("Invalid UniqueQualifier digest hash fragment - {value}"))?; + let (digest_size, cachable) = rest .split_once('/') - .err_tip(|| format!("Invalid ActionInfoHashKey digest size fragment - {value}"))?; + .err_tip(|| format!("Invalid UniqueQualifier digest size fragment - {value}"))?; let digest = DigestInfo::try_new( digest_hash, digest_size .parse::() - .err_tip(|| format!("Invalid ActionInfoHashKey size value fragment - {value}"))?, + .err_tip(|| format!("Invalid UniqueQualifier size value fragment - {value}"))?, ) .err_tip(|| format!("Invalid DigestInfo digest hash - {value}"))?; - let salt = u64::from_str_radix(salt, 16) - .err_tip(|| format!("Invalid ActionInfoHashKey salt hex conversion - {value}"))?; - let unique_qualifier = ActionInfoHashKey { + let cachable = match cachable { + "u" => false, + "c" => true, + _ => { + return Err(make_input_err!( + "Invalid UniqueQualifier cachable value fragment - {value}" + )); + } + }; + let unique_key = ActionUniqueKey { instance_name: instance_name.to_string(), digest_function: digest_function.try_into()?, digest, - salt, + }; + let unique_qualifier = if cachable { + ActionUniqueQualifier::Cachable(unique_key) + } else { + ActionUniqueQualifier::Uncachable(unique_key) }; let id = Uuid::parse_str(id).map_err(|e| make_input_err!("Failed to parse {e} as uuid"))?; Ok(Self { @@ -154,26 +197,18 @@ impl TryFrom<&str> for OperationId { impl std::fmt::Display for OperationId { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.write_fmt(format_args!( - "{}/{}", - self.unique_qualifier.action_name(), - self.id - )) + f.write_fmt(format_args!("{}/{}", self.unique_qualifier, self.id)) } } impl std::fmt::Debug for OperationId { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.write_fmt(format_args!( - "{}:{}", - self.unique_qualifier.action_name(), - self.id - )) + std::fmt::Display::fmt(&self, f) } } /// Unique id of worker. 
-#[derive(Eq, PartialEq, Hash, Copy, Clone, Serialize, Deserialize)] +#[derive(Default, Eq, PartialEq, Hash, Copy, Clone, Serialize, Deserialize)] pub struct WorkerId(pub Uuid); impl std::fmt::Display for WorkerId { @@ -186,9 +221,7 @@ impl std::fmt::Display for WorkerId { impl std::fmt::Debug for WorkerId { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - let mut buf = Uuid::encode_buffer(); - let worker_id_str = self.0.hyphenated().encode_lower(&mut buf); - f.write_fmt(format_args!("{worker_id_str}")) + std::fmt::Display::fmt(&self, f) } } @@ -197,58 +230,76 @@ impl TryFrom for WorkerId { fn try_from(s: String) -> Result { match Uuid::parse_str(&s) { Err(e) => Err(make_input_err!( - "Failed to convert string to WorkerId : {} : {:?}", - s, - e + "Failed to convert string to WorkerId : {s} : {e:?}", )), Ok(my_uuid) => Ok(WorkerId(my_uuid)), } } } + +/// Holds the information needed to uniquely identify an action +/// and if it is cachable or not. +#[derive(Debug, Clone, Hash, PartialEq, Eq, Serialize, Deserialize)] +pub enum ActionUniqueQualifier { + /// The action is cachable. + Cachable(ActionUniqueKey), + /// The action is uncachable. + Uncachable(ActionUniqueKey), +} + +impl ActionUniqueQualifier { + /// Get the instance_name of the action. + pub const fn instance_name(&self) -> &String { + match self { + Self::Cachable(action) => &action.instance_name, + Self::Uncachable(action) => &action.instance_name, + } + } + + /// Get the digest function of the action. + pub const fn digest_function(&self) -> DigestHasherFunc { + match self { + Self::Cachable(action) => action.digest_function, + Self::Uncachable(action) => action.digest_function, + } + } + + /// Get the digest of the action. + pub const fn digest(&self) -> DigestInfo { + match self { + Self::Cachable(action) => action.digest, + Self::Uncachable(action) => action.digest, + } + } +} + +impl std::fmt::Display for ActionUniqueQualifier { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let (cachable, unique_key) = match self { + Self::Cachable(action) => (true, action), + Self::Uncachable(action) => (false, action), + }; + f.write_fmt(format_args!( + "{}/{}/{}-{}/{}", + unique_key.instance_name, + unique_key.digest_function, + unique_key.digest.hash_str(), + unique_key.digest.size_bytes, + if cachable { 'c' } else { 'u' }, + )) + } +} + /// This is a utility struct used to make it easier to match `ActionInfos` in a /// `HashMap` without needing to construct an entire `ActionInfo`. -/// Since the hashing only needs the digest and salt we can just alias them here -/// and point the original `ActionInfo` structs to reference these structs for -/// it's hashing functions. -#[derive(Debug, Clone, Serialize, Deserialize)] -pub struct ActionInfoHashKey { +#[derive(Debug, Clone, Eq, PartialEq, Hash, Serialize, Deserialize)] +pub struct ActionUniqueKey { /// Name of instance group this action belongs to. pub instance_name: String, /// The digest function this action expects. pub digest_function: DigestHasherFunc, /// Digest of the underlying `Action`. pub digest: DigestInfo, - /// Salt that can be filled with a random number to ensure no `ActionInfo` will be a match - /// to another `ActionInfo` in the scheduler. When caching is wanted this value is usually - /// zero. - pub salt: u64, -} - -impl ActionInfoHashKey { - /// Utility function used to make a unique hash of the digest including the salt. 
- pub fn get_hash(&self) -> [u8; 32] { - Blake3Hasher::new() - .update(self.instance_name.as_bytes()) - .update(&i32::from(self.digest_function.proto_digest_func()).to_le_bytes()) - .update(&self.digest.packed_hash[..]) - .update(&self.digest.size_bytes.to_le_bytes()) - .update(&self.salt.to_le_bytes()) - .finalize() - .into() - } - - /// Returns the salt used for cache busting/hashing. - #[inline] - pub fn action_name(&self) -> String { - format!( - "{}/{}/{}-{}/{:X}", - self.instance_name, - self.digest_function, - self.digest.hash_str(), - self.digest.size_bytes, - self.salt - ) - } } /// Information needed to execute an action. This struct is used over bazel's proto `Action` @@ -272,47 +323,43 @@ pub struct ActionInfo { pub load_timestamp: SystemTime, /// When this action was created. pub insert_timestamp: SystemTime, - - /// Info used to uniquely identify this ActionInfo. Normally the hash function would just - /// use the fields it needs and you wouldn't need to separate them, however we have a use - /// case where we sometimes want to lookup an entry in a HashMap, but we don't have the - /// info to construct an entire ActionInfo. In such case we construct only a ActionInfoHashKey - /// then use that object to lookup the entry in the map. The root problem is that HashMap - /// requires `ActionInfo :Borrow` in order for this to work, which means - /// we need to be able to return a &ActionInfoHashKey from ActionInfo, but since we cannot - /// return a temporary reference we must have an object tied to ActionInfo's lifetime and - /// return it's reference. - pub unique_qualifier: ActionInfoHashKey, - - /// Whether to try looking up this action in the cache. - pub skip_cache_lookup: bool, + /// Info used to uniquely identify this ActionInfo and if it is cachable. + /// This is primarily used to join actions/operations together using this key. + pub unique_qualifier: ActionUniqueQualifier, } impl ActionInfo { #[inline] pub const fn instance_name(&self) -> &String { - &self.unique_qualifier.instance_name + self.unique_qualifier.instance_name() } /// Returns the underlying digest of the `Action`. #[inline] - pub const fn digest(&self) -> &DigestInfo { - &self.unique_qualifier.digest - } - - /// Returns the salt used for cache busting/hashing. - #[inline] - pub const fn salt(&self) -> &u64 { - &self.unique_qualifier.salt + pub const fn digest(&self) -> DigestInfo { + self.unique_qualifier.digest() } - pub fn try_from_action_and_execute_request_with_salt( + pub fn try_from_action_and_execute_request( execute_request: ExecuteRequest, action: Action, - salt: u64, load_timestamp: SystemTime, queued_timestamp: SystemTime, ) -> Result { + let unique_key = ActionUniqueKey { + instance_name: execute_request.instance_name, + digest_function: DigestHasherFunc::try_from(execute_request.digest_function) + .err_tip(|| format!("Could not find digest_function in try_from_action_and_execute_request {:?}", execute_request.digest_function))?, + digest: execute_request + .action_digest + .err_tip(|| "Expected action_digest to exist on ExecuteRequest")? 
+ .try_into()?, + }; + let unique_qualifier = if execute_request.skip_cache_lookup { + ActionUniqueQualifier::Uncachable(unique_key) + } else { + ActionUniqueQualifier::Cachable(unique_key) + }; Ok(Self { command_digest: action .command_digest @@ -328,20 +375,13 @@ impl ActionInfo { .try_into() .map_err(|_| make_input_err!("Failed convert proto duration to system duration"))?, platform_properties: action.platform.unwrap_or_default().into(), - priority: execute_request.execution_policy.unwrap_or_default().priority, + priority: execute_request + .execution_policy + .unwrap_or_default() + .priority, load_timestamp, insert_timestamp: queued_timestamp, - unique_qualifier: ActionInfoHashKey { - instance_name: execute_request.instance_name, - digest_function: DigestHasherFunc::try_from(execute_request.digest_function) - .err_tip(|| format!("Could not find digest_function in try_from_action_and_execute_request_with_salt {:?}", execute_request.digest_function))?, - digest: execute_request - .action_digest - .err_tip(|| "Expected action_digest to exist on ExecuteRequest")? - .try_into()?, - salt, - }, - skip_cache_lookup: execute_request.skip_cache_lookup, + unique_qualifier, }) } } @@ -349,92 +389,21 @@ impl ActionInfo { impl From for ExecuteRequest { fn from(val: ActionInfo) -> Self { let digest = val.digest().into(); + let (skip_cache_lookup, unique_qualifier) = match val.unique_qualifier { + ActionUniqueQualifier::Cachable(unique_qualifier) => (false, unique_qualifier), + ActionUniqueQualifier::Uncachable(unique_qualifier) => (true, unique_qualifier), + }; Self { - instance_name: val.unique_qualifier.instance_name, + instance_name: unique_qualifier.instance_name, action_digest: Some(digest), - skip_cache_lookup: true, // The worker should never cache lookup. - execution_policy: None, // Not used in the worker. + skip_cache_lookup, + execution_policy: None, // Not used in the worker. results_cache_policy: None, // Not used in the worker. - digest_function: val - .unique_qualifier - .digest_function - .proto_digest_func() - .into(), + digest_function: unique_qualifier.digest_function.proto_digest_func().into(), } } } -// Note: Hashing, Eq, and Ord matching on this struct is unique. Normally these functions -// must play well with each other, but in our case the following rules apply: -// * Hash - Hashing must be unique on the exact command being run and must never match -// when do_not_cache is enabled, but must be consistent between identical data -// hashes. -// * Eq - Same as hash. -// * Ord - Used when sorting `ActionInfo` together. The only major sorting is priority and -// insert_timestamp, everything else is undefined, but must be deterministic. -impl Hash for ActionInfo { - fn hash(&self, state: &mut H) { - ActionInfoHashKey::hash(&self.unique_qualifier, state); - } -} - -impl PartialEq for ActionInfo { - fn eq(&self, other: &Self) -> bool { - ActionInfoHashKey::eq(&self.unique_qualifier, &other.unique_qualifier) - } -} - -impl Eq for ActionInfo {} - -impl Ord for ActionInfo { - fn cmp(&self, other: &Self) -> Ordering { - // Want the highest priority on top, but the lowest insert_timestamp. 
- self.priority - .cmp(&other.priority) - .then_with(|| other.insert_timestamp.cmp(&self.insert_timestamp)) - .then_with(|| self.salt().cmp(other.salt())) - .then_with(|| self.digest().size_bytes.cmp(&other.digest().size_bytes)) - .then_with(|| self.digest().packed_hash.cmp(&other.digest().packed_hash)) - .then_with(|| { - self.unique_qualifier - .digest_function - .cmp(&other.unique_qualifier.digest_function) - }) - } -} - -impl PartialOrd for ActionInfo { - fn partial_cmp(&self, other: &Self) -> Option { - Some(self.cmp(other)) - } -} - -impl Borrow for Arc { - #[inline] - fn borrow(&self) -> &ActionInfoHashKey { - &self.unique_qualifier - } -} - -impl Hash for ActionInfoHashKey { - fn hash(&self, state: &mut H) { - // Digest is unique, so hashing it is all we need. - self.digest_function.hash(state); - self.digest.hash(state); - self.salt.hash(state); - } -} - -impl PartialEq for ActionInfoHashKey { - fn eq(&self, other: &Self) -> bool { - self.digest == other.digest - && self.salt == other.salt - && self.digest_function == other.digest_function - } -} - -impl Eq for ActionInfoHashKey {} - /// Simple utility struct to determine if a string is representing a full path or /// just the name of the file. /// This is in order to be able to reuse the same struct instead of building different @@ -728,9 +697,6 @@ impl TryFrom for ExecutionMetadata { } } -/// Exit code sent if there is an internal error. -pub const INTERNAL_ERROR_EXIT_CODE: i32 = -178; - /// Represents the results of an execution. /// This struct must be 100% compatible with `ActionResult` in `remote_execution.proto`. #[derive(Eq, PartialEq, Debug, Clone, Serialize, Deserialize)] @@ -813,6 +779,20 @@ impl ActionStage { pub const fn is_finished(&self) -> bool { self.has_action_result() } + + /// Returns if the stage enum is the same as the other stage enum, but + /// does not compare the values of the enum. + pub const fn is_same_stage(&self, other: &Self) -> bool { + matches!( + (self, other), + (Self::Unknown, Self::Unknown) + | (Self::CacheCheck, Self::CacheCheck) + | (Self::Queued, Self::Queued) + | (Self::Executing, Self::Executing) + | (Self::Completed(_), Self::Completed(_)) + | (Self::CompletedFromCache(_), Self::CompletedFromCache(_)) + ) + } } impl MetricsComponent for ActionStage { @@ -1093,10 +1073,19 @@ where } } -impl TryFrom for ActionState { - type Error = Error; +/// Current state of the action. +/// This must be 100% compatible with `Operation` in `google/longrunning/operations.proto`. +#[derive(PartialEq, Debug, Clone)] +pub struct ActionState { + pub stage: ActionStage, + pub id: OperationId, +} - fn try_from(operation: Operation) -> Result { +impl ActionState { + pub fn try_from_operation( + operation: Operation, + operation_id: OperationId, + ) -> Result { let metadata = from_any::( &operation .metadata @@ -1135,51 +1124,23 @@ impl TryFrom for ActionState { } }; - // NOTE: This will error if we are forwarding an operation from - // one remote execution system to another that does not use our operation name - // format (ie: very unlikely, but possible). - let id = OperationId::try_from(operation.name.as_str())?; - Ok(Self { id, stage }) - } -} - -/// Current state of the action. -/// This must be 100% compatible with `Operation` in `google/longrunning/operations.proto`. 
-#[derive(PartialEq, Debug, Clone)] -pub struct ActionState { - pub stage: ActionStage, - pub id: OperationId, -} - -impl ActionState { - #[inline] - pub fn unique_qualifier(&self) -> &ActionInfoHashKey { - &self.id.unique_qualifier - } - #[inline] - pub fn action_digest(&self) -> &DigestInfo { - &self.id.unique_qualifier.digest - } -} - -impl MetricsComponent for ActionState { - fn gather_metrics(&self, c: &mut CollectorState) { - c.publish("stage", &self.stage, ""); + Ok(Self { + id: operation_id, + stage, + }) } -} -impl From<ActionState> for Operation { - fn from(val: ActionState) -> Self { - let stage = Into::<execution_stage::Value>::into(&val.stage) as i32; - let name = val.id.to_string(); + pub fn as_operation(&self, client_operation_id: ClientOperationId) -> Operation { + let stage = Into::<execution_stage::Value>::into(&self.stage) as i32; + let name = client_operation_id.into_string(); - let result = if val.stage.has_action_result() { - let execute_response: ExecuteResponse = val.stage.into(); + let result = if self.stage.has_action_result() { + let execute_response: ExecuteResponse = self.stage.clone().into(); Some(LongRunningResult::Response(to_any(&execute_response))) } else { None }; - let digest = Some(val.id.unique_qualifier.digest.into()); + let digest = Some(self.id.unique_qualifier.digest().into()); let metadata = ExecuteOperationMetadata { stage, @@ -1190,7 +1151,7 @@ impl From<ActionState> for Operation { partial_execution_metadata: None, }; - Self { + Operation { name, metadata: Some(to_any(&metadata)), done: result.is_some(), @@ -1198,3 +1159,9 @@ } } } + +impl MetricsComponent for ActionState { + fn gather_metrics(&self, c: &mut CollectorState) { + c.publish("stage", &self.stage, ""); + } +} diff --git a/nativelink-util/src/chunked_stream.rs b/nativelink-util/src/chunked_stream.rs new file mode 100644 index 000000000..e3665562d --- /dev/null +++ b/nativelink-util/src/chunked_stream.rs @@ -0,0 +1,110 @@ +// Copyright 2024 The NativeLink Authors. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::collections::VecDeque; +use std::ops::Bound; +use std::pin::Pin; +use std::task::{Context, Poll}; + +use futures::{Future, Stream}; +use pin_project::pin_project; + +#[pin_project(project = StreamStateProj)] +enum StreamState<Fut> { + Future(#[pin] Fut), + Next, +} + +/// Takes a range of keys and a function that returns a future that yields +/// an iterator of key-value pairs. The stream will yield all key-value pairs +/// in the range, in order and buffered. A great use case is when you need +/// to implement Stream, accessing the underlying data requires holding a lock, +/// and the API does not require the data to stay in sync with data already received.
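A hedged usage sketch of `ChunkedStream` under that contract: each call to the chunk function reacquires a lock, copies out one bounded page, and hands back the advanced range. The `parking_lot` mutex, page size of 16, and `u64` keys here are illustrative only, not part of this change:

```rust
use std::collections::{BTreeMap, VecDeque};
use std::ops::Bound;
use std::sync::Arc;

use futures::StreamExt;
use nativelink_error::Error;
use nativelink_util::chunked_stream::ChunkedStream;
use parking_lot::Mutex;

async fn stream_values(map: Arc<Mutex<BTreeMap<u64, String>>>) -> Result<Vec<String>, Error> {
    let stream = ChunkedStream::new(Bound::Unbounded, Bound::Unbounded, move |start, end, mut buf| {
        let map = map.clone();
        async move {
            let mut last_key = None;
            {
                // Hold the lock only long enough to copy out one page.
                let guard = map.lock();
                for (key, value) in guard.range((start, end)).take(16) {
                    last_key = Some(*key);
                    buf.push_back(value.clone());
                }
            }
            Ok::<_, Error>(match last_key {
                // Advance past the last key so the next page resumes there.
                Some(key) => Some(((Bound::Excluded(key), end), buf)),
                // Nothing left in range: end the stream.
                None => None,
            })
        }
    });
    // ChunkedStream is not Unpin (it stores the in-flight future), so pin it
    // to the stack before iterating.
    let mut stream = std::pin::pin!(stream);
    let mut out = Vec::new();
    while let Some(item) = stream.next().await {
        out.push(item?);
    }
    Ok(out)
}
```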
+#[pin_project] +pub struct ChunkedStream<K, T, E, F, Fut> +where + K: Ord, + F: FnMut(Bound<K>, Bound<K>, VecDeque<T>) -> Fut, + Fut: Future<Output = Result<Option<((Bound<K>, Bound<K>), VecDeque<T>)>, E>>, +{ + chunk_fn: F, + buffer: VecDeque<T>, + start_key: Option<Bound<K>>, + end_key: Option<Bound<K>>, + #[pin] + stream_state: StreamState<Fut>, +} + +impl<K, T, E, F, Fut> ChunkedStream<K, T, E, F, Fut> +where + K: Ord, + F: FnMut(Bound<K>, Bound<K>, VecDeque<T>) -> Fut, + Fut: Future<Output = Result<Option<((Bound<K>, Bound<K>), VecDeque<T>)>, E>>, +{ + pub fn new(start_key: Bound<K>, end_key: Bound<K>, chunk_fn: F) -> Self { + Self { + chunk_fn, + buffer: VecDeque::new(), + start_key: Some(start_key), + end_key: Some(end_key), + stream_state: StreamState::Next, + } + } +} + +impl<K, T, E, F, Fut> Stream for ChunkedStream<K, T, E, F, Fut> +where + K: Ord, + F: FnMut(Bound<K>, Bound<K>, VecDeque<T>) -> Fut, + Fut: Future<Output = Result<Option<((Bound<K>, Bound<K>), VecDeque<T>)>, E>>, +{ + type Item = Result<T, E>; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> { + let mut this = self.project(); + loop { + if let Some(item) = this.buffer.pop_front() { + return Poll::Ready(Some(Ok(item))); + } + match this.stream_state.as_mut().project() { + StreamStateProj::Future(fut) => { + match futures::ready!(fut.poll(cx)) { + Ok(Some(((start, end), mut buffer))) => { + *this.start_key = Some(start); + *this.end_key = Some(end); + std::mem::swap(&mut buffer, this.buffer); + } + Ok(None) => return Poll::Ready(None), // End of stream. + Err(err) => return Poll::Ready(Some(Err(err))), + } + this.stream_state.set(StreamState::Next); + // Loop again. + } + StreamStateProj::Next => { + this.buffer.clear(); + // This trick is used to recycle capacity. + let buffer = std::mem::take(this.buffer); + let start_key = this + .start_key + .take() + .expect("start_key should never be None"); + let end_key = this.end_key.take().expect("end_key should never be None"); + let fut = (this.chunk_fn)(start_key, end_key, buffer); + this.stream_state.set(StreamState::Future(fut)); + // Loop again. + } + } + } + } +} diff --git a/nativelink-util/src/lib.rs b/nativelink-util/src/lib.rs index 2811a4d68..6b47b714a 100644 --- a/nativelink-util/src/lib.rs +++ b/nativelink-util/src/lib.rs @@ -14,6 +14,7 @@ pub mod action_messages; pub mod buf_channel; +pub mod chunked_stream; pub mod common; pub mod connection_manager; pub mod default_store_key_subscribe; @@ -23,6 +24,7 @@ pub mod fastcdc; pub mod fs; pub mod health_utils; pub mod metrics_utils; +pub mod operation_state_manager; pub mod origin_context; pub mod platform_properties; pub mod proto_stream_utils; diff --git a/nativelink-scheduler/src/operation_state_manager.rs b/nativelink-util/src/operation_state_manager.rs similarity index 56% rename from nativelink-scheduler/src/operation_state_manager.rs rename to nativelink-util/src/operation_state_manager.rs index 2b7184d3f..cb1b331e3 100644 --- a/nativelink-scheduler/src/operation_state_manager.rs +++ b/nativelink-util/src/operation_state_manager.rs @@ -20,14 +20,15 @@ use async_trait::async_trait; use bitflags::bitflags; use futures::Stream; use nativelink_error::Error; -use nativelink_util::action_messages::{ - ActionInfo, ActionInfoHashKey, ActionStage, ActionState, OperationId, WorkerId, +use prometheus_client::registry::Registry; + +use crate::action_messages::{ + ActionInfo, ActionStage, ActionState, ActionUniqueKey, ClientOperationId, OperationId, WorkerId, }; -use nativelink_util::common::DigestInfo; -use tokio::sync::watch; +use crate::common::DigestInfo; bitflags!
{ - #[derive(Debug, Clone, PartialEq, Eq, Hash)] + #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub struct OperationStageFlags: u32 { const CacheCheck = 1 << 1; const Queued = 1 << 2; @@ -37,24 +38,39 @@ bitflags! { } } +impl Default for OperationStageFlags { + fn default() -> Self { + Self::Any + } +} + #[async_trait] pub trait ActionStateResult: Send + Sync + 'static { // Provides the current state of the action. async fn as_state(&self) -> Result, Error>; - // Subscribes to the state of the action, receiving updates as they are published. - async fn as_receiver(&self) -> Result<&'_ watch::Receiver>, Error>; + // Waits for the state of the action to change. + async fn changed(&mut self) -> Result, Error>; // Provide result as action info. This behavior will not be supported by all implementations. - // TODO(adams): Expectation is this to experimental and removed in the future. async fn as_action_info(&self) -> Result, Error>; } +/// The direction in which the results are ordered. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum OrderDirection { + Asc, + Desc, +} + /// The filters used to query operations from the state manager. -#[derive(Debug, Clone, PartialEq, Eq, Hash)] +#[derive(Default, Debug, Clone, PartialEq, Eq, Hash)] pub struct OperationFilter { // TODO(adams): create rust builder pattern? /// The stage(s) that the operation must be in. pub stages: OperationStageFlags, + /// The client operation id. + pub client_operation_id: Option, + /// The operation id. pub operation_id: Option, @@ -70,80 +86,67 @@ pub struct OperationFilter { /// The operation must have been completed before this time. pub completed_before: Option, - /// The operation must have it's last client update before this time. - pub last_client_update_before: Option, - /// The unique key for filtering specific action results. - pub unique_qualifier: Option, - - /// The order by in which results are returned by the filter operation. - pub order_by: Option, -} - -#[derive(Debug, Clone, PartialEq, Eq, Hash)] -pub enum OperationFields { - Priority, - Timestamp, -} + pub unique_key: Option, -/// The order in which results are returned by the filter operation. -#[derive(Debug, Clone, PartialEq, Eq, Hash)] -pub struct OrderBy { - /// The fields to order by, each field is ordered in the order they are provided. - pub fields: Vec, - /// The order of the fields, true for descending, false for ascending. - pub desc: bool, + /// If the results should be ordered by priority and in which direction. + pub order_by_priority_direction: Option, } -pub type ActionStateResultStream = Pin> + Send>>; +pub type ActionStateResultStream<'a> = + Pin> + Send + 'a>>; #[async_trait] -pub trait ClientStateManager { +pub trait ClientStateManager: Sync + Send { /// Add a new action to the queue or joins an existing action. async fn add_action( - &mut self, - action_info: ActionInfo, - ) -> Result, Error>; + &self, + client_operation_id: ClientOperationId, + action_info: Arc, + ) -> Result, Error>; /// Returns a stream of operations that match the filter. - async fn filter_operations( - &self, + async fn filter_operations<'a>( + &'a self, filter: OperationFilter, - ) -> Result; + ) -> Result, Error>; + + /// Register metrics with the registry. + fn register_metrics(self: Arc, _registry: &mut Registry) {} } #[async_trait] -pub trait WorkerStateManager { +pub trait WorkerStateManager: Sync + Send { /// Update that state of an operation. 
 
 #[async_trait]
-pub trait WorkerStateManager {
+pub trait WorkerStateManager: Sync + Send {
     /// Update that state of an operation.
     /// The worker must also send periodic updates even if the state
     /// did not change with a modified timestamp in order to prevent
     /// the operation from being considered stale and being rescheduled.
     async fn update_operation(
-        &mut self,
-        operation_id: OperationId,
-        worker_id: WorkerId,
+        &self,
+        operation_id: &OperationId,
+        worker_id: &WorkerId,
         action_stage: Result<ActionStage, Error>,
     ) -> Result<(), Error>;
+
+    /// Register metrics with the registry.
+    fn register_metrics(self: Arc<Self>, _registry: &mut Registry) {}
 }
 
 #[async_trait]
-pub trait MatchingEngineStateManager {
+pub trait MatchingEngineStateManager: Sync + Send {
     /// Returns a stream of operations that match the filter.
-    async fn filter_operations(
-        &self,
+    async fn filter_operations<'a>(
+        &'a self,
         filter: OperationFilter,
-    ) -> Result<ActionStateResultStream, Error>;
+    ) -> Result<ActionStateResultStream<'a>, Error>;
 
-    /// Update that state of an operation.
-    async fn update_operation(
-        &mut self,
-        operation_id: OperationId,
-        worker_id: Option<WorkerId>,
-        action_stage: Result<ActionStage, Error>,
+    /// Assign an operation to a worker or unassign it.
+    async fn assign_operation(
+        &self,
+        operation_id: &OperationId,
+        worker_id_or_reason_for_unsassign: Result<&WorkerId, Error>,
     ) -> Result<(), Error>;
 
-    /// Remove an operation from the state manager.
-    /// It is important to use this function to remove operations
-    /// that are no longer needed to prevent memory leaks.
-    async fn remove_operation(&self, operation_id: OperationId) -> Result<(), Error>;
+    /// Register metrics with the registry.
+    fn register_metrics(self: Arc<Self>, _registry: &mut Registry) {}
 }
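The new `assign_operation` signature folds assign and unassign into one call: `Ok(&WorkerId)` assigns, `Err(reason)` unassigns. A sketch of how a matching loop might combine it with `filter_operations`; the function name, the counting, and the error text are illustrative, not from the patch:

```rust
use futures::StreamExt;
use nativelink_error::{make_err, Code, Error};
use nativelink_util::action_messages::{OperationId, WorkerId};
use nativelink_util::operation_state_manager::{
    MatchingEngineStateManager, OperationFilter, OperationStageFlags,
};

async fn assign_or_unassign(
    state_manager: &dyn MatchingEngineStateManager,
    operation_id: &OperationId,
    maybe_worker_id: Option<&WorkerId>,
) -> Result<(), Error> {
    // Stream everything still queued; here it is only counted, to show the
    // shape of the ActionStateResultStream.
    let queued = state_manager
        .filter_operations(OperationFilter {
            stages: OperationStageFlags::Queued,
            ..Default::default()
        })
        .await?
        .count()
        .await;
    tracing::info!(queued, "operations waiting for a worker");

    match maybe_worker_id {
        // Hand the operation to the chosen worker.
        Some(worker_id) => state_manager.assign_operation(operation_id, Ok(worker_id)).await,
        // No worker: unassign, carrying the reason in the Err variant.
        None => {
            state_manager
                .assign_operation(
                    operation_id,
                    Err(make_err!(Code::Internal, "No worker available")),
                )
                .await
        }
    }
}
```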
diff --git a/nativelink-util/tests/operation_id_tests.rs b/nativelink-util/tests/operation_id_tests.rs
index e1e8b5e30..a2513c8f7 100644
--- a/nativelink-util/tests/operation_id_tests.rs
+++ b/nativelink-util/tests/operation_id_tests.rs
@@ -19,22 +19,28 @@ use pretty_assertions::assert_eq;
 
 #[nativelink_test]
 async fn parse_operation_id() -> Result<(), Error> {
-    let operation_id = OperationId::try_from("main/SHA256/4a0885a39d5ba8da3123c02ff56b73196a8b23fd3c835e1446e74a3a3ff4313f-211/0/19b16cf8-a1ad-4948-aaac-b6f4eb7fca52").unwrap();
-    assert_eq!(
-        operation_id.to_string(),
-        "main/SHA256/4a0885a39d5ba8da3123c02ff56b73196a8b23fd3c835e1446e74a3a3ff4313f-211/0/19b16cf8-a1ad-4948-aaac-b6f4eb7fca52");
-    assert_eq!(
-        operation_id.action_name(),
-        "main/SHA256/4a0885a39d5ba8da3123c02ff56b73196a8b23fd3c835e1446e74a3a3ff4313f-211/0"
-    );
-    assert_eq!(
-        operation_id.id.to_string(),
-        "19b16cf8-a1ad-4948-aaac-b6f4eb7fca52"
-    );
-    assert_eq!(
-        hex::encode(operation_id.get_hash()),
-        "5a36f0db39e27667c4b91937cd29c1df8799ba468f2de6810c6865be05517644"
-    );
+    {
+        // Check no cached.
+        let operation_id = OperationId::try_from("main/SHA256/4a0885a39d5ba8da3123c02ff56b73196a8b23fd3c835e1446e74a3a3ff4313f-211/u/19b16cf8-a1ad-4948-aaac-b6f4eb7fca52").unwrap();
+        assert_eq!(
+            operation_id.to_string(),
+            "main/SHA256/4a0885a39d5ba8da3123c02ff56b73196a8b23fd3c835e1446e74a3a3ff4313f-211/u/19b16cf8-a1ad-4948-aaac-b6f4eb7fca52");
+        assert_eq!(
+            operation_id.id.to_string(),
+            "19b16cf8-a1ad-4948-aaac-b6f4eb7fca52"
+        );
+    }
+    {
+        // Check cached.
+        let operation_id = OperationId::try_from("main/SHA256/4a0885a39d5ba8da3123c02ff56b73196a8b23fd3c835e1446e74a3a3ff4313f-211/c/19b16cf8-a1ad-4948-aaac-b6f4eb7fca52").unwrap();
+        assert_eq!(
+            operation_id.to_string(),
+            "main/SHA256/4a0885a39d5ba8da3123c02ff56b73196a8b23fd3c835e1446e74a3a3ff4313f-211/c/19b16cf8-a1ad-4948-aaac-b6f4eb7fca52");
+        assert_eq!(
+            operation_id.id.to_string(),
+            "19b16cf8-a1ad-4948-aaac-b6f4eb7fca52"
+        );
+    }
 
     Ok(())
 }
 
@@ -53,7 +59,7 @@ async fn parse_empty_failure() -> Result<(), Error> {
     assert_eq!(operation_id.messages.len(), 1);
     assert_eq!(
         operation_id.messages[0],
-        "Invalid ActionInfoHashKey instance name fragment - /"
+        "Invalid UniqueQualifier instance name fragment - /"
    );
 
     let operation_id = OperationId::try_from("main").err().unwrap();
@@ -64,7 +70,7 @@
         "Invalid OperationId unique_qualifier / id fragment - main"
     );
 
-    let operation_id = OperationId::try_from("main/nohashfn/4a0885a39d5ba8da3123c02ff56b73196a8b23fd3c835e1446e74a3a3ff4313f-211/0/19b16cf8-a1ad-4948-aaac-b6f4eb7fca52").err().unwrap();
+    let operation_id = OperationId::try_from("main/nohashfn/4a0885a39d5ba8da3123c02ff56b73196a8b23fd3c835e1446e74a3a3ff4313f-211/u/19b16cf8-a1ad-4948-aaac-b6f4eb7fca52").err().unwrap();
     assert_eq!(operation_id.code, Code::InvalidArgument);
     assert_eq!(operation_id.messages.len(), 1);
     assert_eq!(
@@ -73,7 +79,7 @@
     );
 
     let operation_id =
-        OperationId::try_from("main/SHA256/badhash-211/0/19b16cf8-a1ad-4948-aaac-b6f4eb7fca52")
+        OperationId::try_from("main/SHA256/badhash-211/u/19b16cf8-a1ad-4948-aaac-b6f4eb7fca52")
             .err()
             .unwrap();
     assert_eq!(operation_id.messages.len(), 3);
@@ -82,37 +88,35 @@
     assert_eq!(operation_id.messages[1], "Invalid sha256 hash: badhash");
     assert_eq!(
         operation_id.messages[2],
-        "Invalid DigestInfo digest hash - main/SHA256/badhash-211/0/19b16cf8-a1ad-4948-aaac-b6f4eb7fca52"
+        "Invalid DigestInfo digest hash - main/SHA256/badhash-211/u/19b16cf8-a1ad-4948-aaac-b6f4eb7fca52"
     );
 
-    let operation_id = OperationId::try_from("main/SHA256/4a0885a39d5ba8da3123c02ff56b73196a8b23fd3c835e1446e74a3a3ff4313f-/0/19b16cf8-a1ad-4948-aaac-b6f4eb7fca52").err().unwrap();
+    let operation_id = OperationId::try_from("main/SHA256/4a0885a39d5ba8da3123c02ff56b73196a8b23fd3c835e1446e74a3a3ff4313f-/u/19b16cf8-a1ad-4948-aaac-b6f4eb7fca52").err().unwrap();
     assert_eq!(operation_id.messages.len(), 2);
     assert_eq!(operation_id.code, Code::InvalidArgument);
     assert_eq!(
         operation_id.messages[0],
         "cannot parse integer from empty string"
     );
-    assert_eq!(operation_id.messages[1], "Invalid ActionInfoHashKey size value fragment - main/SHA256/4a0885a39d5ba8da3123c02ff56b73196a8b23fd3c835e1446e74a3a3ff4313f-/0/19b16cf8-a1ad-4948-aaac-b6f4eb7fca52");
+    assert_eq!(operation_id.messages[1], "Invalid UniqueQualifier size value fragment - main/SHA256/4a0885a39d5ba8da3123c02ff56b73196a8b23fd3c835e1446e74a3a3ff4313f-/u/19b16cf8-a1ad-4948-aaac-b6f4eb7fca52");
 
-    let operation_id = OperationId::try_from("main/SHA256/4a0885a39d5ba8da3123c02ff56b73196a8b23fd3c835e1446e74a3a3ff4313f--211/0/19b16cf8-a1ad-4948-aaac-b6f4eb7fca52").err().unwrap();
+    let operation_id = OperationId::try_from("main/SHA256/4a0885a39d5ba8da3123c02ff56b73196a8b23fd3c835e1446e74a3a3ff4313f--211/u/19b16cf8-a1ad-4948-aaac-b6f4eb7fca52").err().unwrap();
     assert_eq!(operation_id.code, Code::InvalidArgument);
     assert_eq!(operation_id.messages.len(), 2);
     assert_eq!(operation_id.messages[0], "invalid digit found in string");
-    assert_eq!(operation_id.messages[1], "Invalid ActionInfoHashKey size value fragment - main/SHA256/4a0885a39d5ba8da3123c02ff56b73196a8b23fd3c835e1446e74a3a3ff4313f--211/0/19b16cf8-a1ad-4948-aaac-b6f4eb7fca52");
+    assert_eq!(operation_id.messages[1], "Invalid UniqueQualifier size value fragment - main/SHA256/4a0885a39d5ba8da3123c02ff56b73196a8b23fd3c835e1446e74a3a3ff4313f--211/u/19b16cf8-a1ad-4948-aaac-b6f4eb7fca52");
 
     let operation_id = OperationId::try_from("main/SHA256/4a0885a39d5ba8da3123c02ff56b73196a8b23fd3c835e1446e74a3a3ff4313f-211/x/19b16cf8-a1ad-4948-aaac-b6f4eb7fca52").err().unwrap();
-    assert_eq!(operation_id.messages.len(), 2);
+    assert_eq!(operation_id.messages.len(), 1);
     assert_eq!(operation_id.code, Code::InvalidArgument);
-    assert_eq!(operation_id.messages[0], "invalid digit found in string");
-    assert_eq!(operation_id.messages[1], "Invalid ActionInfoHashKey salt hex conversion - main/SHA256/4a0885a39d5ba8da3123c02ff56b73196a8b23fd3c835e1446e74a3a3ff4313f-211/x/19b16cf8-a1ad-4948-aaac-b6f4eb7fca52");
+    assert_eq!(operation_id.messages[0], "Invalid UniqueQualifier cachable value fragment - main/SHA256/4a0885a39d5ba8da3123c02ff56b73196a8b23fd3c835e1446e74a3a3ff4313f-211/x/19b16cf8-a1ad-4948-aaac-b6f4eb7fca52");
 
     let operation_id = OperationId::try_from("main/SHA256/4a0885a39d5ba8da3123c02ff56b73196a8b23fd3c835e1446e74a3a3ff4313f-211/-10/19b16cf8-a1ad-4948-aaac-b6f4eb7fca52").err().unwrap();
-    assert_eq!(operation_id.messages.len(), 2);
+    assert_eq!(operation_id.messages.len(), 1);
     assert_eq!(operation_id.code, Code::InvalidArgument);
-    assert_eq!(operation_id.messages[0], "invalid digit found in string");
-    assert_eq!(operation_id.messages[1], "Invalid ActionInfoHashKey salt hex conversion - main/SHA256/4a0885a39d5ba8da3123c02ff56b73196a8b23fd3c835e1446e74a3a3ff4313f-211/-10/19b16cf8-a1ad-4948-aaac-b6f4eb7fca52");
+    assert_eq!(operation_id.messages[0], "Invalid UniqueQualifier cachable value fragment - main/SHA256/4a0885a39d5ba8da3123c02ff56b73196a8b23fd3c835e1446e74a3a3ff4313f-211/-10/19b16cf8-a1ad-4948-aaac-b6f4eb7fca52");
 
-    let operation_id = OperationId::try_from("main/SHA256/4a0885a39d5ba8da3123c02ff56b73196a8b23fd3c835e1446e74a3a3ff4313f-211/0/baduuid").err().unwrap();
+    let operation_id = OperationId::try_from("main/SHA256/4a0885a39d5ba8da3123c02ff56b73196a8b23fd3c835e1446e74a3a3ff4313f-211/u/baduuid").err().unwrap();
     assert_eq!(operation_id.messages.len(), 1);
     assert_eq!(operation_id.code, Code::InvalidArgument);
     assert_eq!(operation_id.messages[0], "Failed to parse invalid character: expected an optional prefix of `urn:uuid:` followed by [0-9a-fA-F-], found `u` at 4 as uuid");
@@ -124,7 +128,7 @@
         .unwrap();
     assert_eq!(operation_id.messages.len(), 1);
     assert_eq!(operation_id.code, Code::Internal);
-    assert_eq!(operation_id.messages[0], "Invalid ActionInfoHashKey digest size fragment - main/SHA256/4a0885a39d5ba8da3123c02ff56b73196a8b23fd3c835e1446e74a3a3ff4313f-211/0");
+    assert_eq!(operation_id.messages[0], "Invalid UniqueQualifier digest size fragment - main/SHA256/4a0885a39d5ba8da3123c02ff56b73196a8b23fd3c835e1446e74a3a3ff4313f-211/0");
 
     Ok(())
 }
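For reference, the operation-id format these tests pin down is `<instance>/<digest_fn>/<hash>-<size>/<cachable>/<uuid>`, where the fourth fragment is now `u` (not cachable) or `c` (cachable) instead of the old numeric salt. A tiny round-trip sketch built from the same fixtures:

```rust
use nativelink_util::action_messages::OperationId;

fn operation_id_round_trip() {
    let text = "main/SHA256/4a0885a39d5ba8da3123c02ff56b73196a8b23fd3c835e1446e74a3a3ff4313f-211/c/19b16cf8-a1ad-4948-aaac-b6f4eb7fca52";
    let operation_id = OperationId::try_from(text).expect("valid operation id");
    // Display round-trips the full string; `id` is just the UUID fragment.
    assert_eq!(operation_id.to_string(), text);
    assert_eq!(
        operation_id.id.to_string(),
        "19b16cf8-a1ad-4948-aaac-b6f4eb7fca52"
    );
}
```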
diff --git a/nativelink-worker/src/local_worker.rs b/nativelink-worker/src/local_worker.rs
index 7b4237d15..28fcc0b51 100644
--- a/nativelink-worker/src/local_worker.rs
+++ b/nativelink-worker/src/local_worker.rs
@@ -211,20 +211,29 @@ impl<'a, T: WorkerApiClientTrait, U: RunningActionsManager> LocalWorkerImpl<'a,
                 Update::KeepAlive(()) => {
                     self.metrics.keep_alives_received.inc();
                 }
-                Update::KillActionRequest(kill_action_request) => {
-                    let mut action_id = [0u8; 32];
-                    hex::decode_to_slice(kill_action_request.action_id, &mut action_id as &mut [u8])
-                        .map_err(|e| make_input_err!(
-                            "KillActionRequest failed to decode ActionId hex with error {}",
-                            e
-                        ))?;
-
-                    if let Err(err) = self.running_actions_manager.kill_action(action_id).await {
+                Update::KillOperationRequest(kill_operation_request) => {
+                    let operation_id_res = kill_operation_request
+                        .operation_id
+                        .as_str()
+                        .try_into();
+                    let operation_id = match operation_id_res {
+                        Ok(operation_id) => operation_id,
+                        Err(err) => {
+                            event!(
+                                Level::ERROR,
+                                ?kill_operation_request,
+                                ?err,
+                                "Failed to convert string to operation_id"
+                            );
+                            continue;
+                        }
+                    };
+                    if let Err(err) = self.running_actions_manager.kill_operation(&operation_id).await {
                         event!(
                             Level::ERROR,
-                            action_id = hex::encode(action_id),
+                            ?operation_id,
                             ?err,
-                            "Failed to send kill request for action"
+                            "Failed to send kill request for operation"
                         );
                     };
                 }
@@ -232,7 +241,7 @@ impl<'a, T: WorkerApiClientTrait, U: RunningActionsManager> LocalWorkerImpl<'a,
                     self.metrics.start_actions_received.inc();
                     let execute_request = start_execute.execute_request.as_ref();
-                    let salt = start_execute.salt;
+                    let operation_id = start_execute.operation_id.clone();
                     let maybe_instance_name = execute_request.map(|v| v.instance_name.clone());
                     let action_digest = execute_request.and_then(|v| v.action_digest.clone());
                     let digest_hasher = execute_request
@@ -257,7 +266,7 @@ impl<'a, T: WorkerApiClientTrait, U: RunningActionsManager> LocalWorkerImpl<'a,
                         .and_then(|action| {
                             event!(
                                 Level::INFO,
-                                action_id = hex::encode(action.get_action_id()),
+                                operation_id = ?action.get_operation_id(),
                                 "Received request to run action"
                             );
                             action
@@ -303,9 +312,7 @@ impl<'a, T: WorkerApiClientTrait, U: RunningActionsManager> LocalWorkerImpl<'a,
                                     ExecuteResult {
                                         worker_id,
                                         instance_name,
-                                        action_digest,
-                                        salt,
-                                        digest_function: digest_hasher.proto_digest_func().into(),
+                                        operation_id,
                                         result: Some(execute_result::Result::ExecuteResponse(action_stage.into())),
                                     }
                                 )
@@ -316,9 +323,7 @@ impl<'a, T: WorkerApiClientTrait, U: RunningActionsManager> LocalWorkerImpl<'a,
                             grpc_client.execution_response(ExecuteResult {
                                 worker_id,
                                 instance_name,
-                                action_digest,
-                                salt,
-                                digest_function: digest_hasher.proto_digest_func().into(),
+                                operation_id,
                                 result: Some(execute_result::Result::InternalError(e.into())),
                             }).await.err_tip(|| "Error calling execution_response with error")?;
                         },
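With the digest/salt/digest-function triple gone from responses, the worker reports results keyed solely by the operation id it received in `StartExecute`. A hedged sketch of the two response shapes used above, factored into a hypothetical helper (the function itself is not part of the patch):

```rust
use nativelink_error::Error;
use nativelink_proto::com::github::trace_machina::nativelink::remote_execution::{
    execute_result, ExecuteResult,
};

fn make_execute_result(
    worker_id: String,
    instance_name: String,
    operation_id: String,
    outcome: Result<execute_result::Result, Error>,
) -> ExecuteResult {
    ExecuteResult {
        worker_id,
        instance_name,
        // The scheduler now correlates the response by operation id alone;
        // action_digest, salt, and digest_function no longer travel back.
        operation_id,
        result: Some(match outcome {
            Ok(result) => result,
            // Infrastructure failures are reported as a google.rpc.Status.
            Err(err) => execute_result::Result::InternalError(err.into()),
        }),
    }
}
```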
diff --git a/nativelink-worker/src/running_actions_manager.rs b/nativelink-worker/src/running_actions_manager.rs
index a22910065..ec01256df 100644
--- a/nativelink-worker/src/running_actions_manager.rs
+++ b/nativelink-worker/src/running_actions_manager.rs
@@ -56,7 +56,7 @@
 use nativelink_store::filesystem_store::{FileEntry, FilesystemStore};
 use nativelink_store::grpc_store::GrpcStore;
 use nativelink_util::action_messages::{
     to_execute_response, ActionInfo, ActionResult, DirectoryInfo, ExecutionMetadata, FileInfo,
-    NameOrPath, SymlinkInfo,
+    NameOrPath, OperationId, SymlinkInfo,
 };
 use nativelink_util::common::{fs, DigestInfo};
 use nativelink_util::digest_hasher::{DigestHasher, DigestHasherFunc};
@@ -78,8 +78,6 @@
 use tonic::Request;
 use tracing::{enabled, event, Level};
 use uuid::Uuid;
 
-pub type ActionId = [u8; 32];
-
 /// For simplicity we use a fixed exit code for cases when our program is terminated
 /// due to a signal.
 const EXIT_CODE_FOR_SIGNAL: i32 = 9;
@@ -531,7 +529,7 @@ async fn process_side_channel_file(
 
 async fn do_cleanup(
     running_actions_manager: &RunningActionsManagerImpl,
-    action_id: &ActionId,
+    operation_id: &OperationId,
     action_directory: &str,
 ) -> Result<(), Error> {
     event!(Level::INFO, "Worker cleaning up");
@@ -539,10 +537,10 @@ async fn do_cleanup(
     let remove_dir_result = fs::remove_dir_all(action_directory)
         .await
         .err_tip(|| format!("Could not remove working directory {action_directory}"));
-    if let Err(err) = running_actions_manager.cleanup_action(action_id) {
+    if let Err(err) = running_actions_manager.cleanup_action(operation_id) {
         event!(
             Level::ERROR,
-            action_id = hex::encode(action_id),
+            ?operation_id,
             ?err,
             "Error cleaning up action"
         );
@@ -551,7 +549,7 @@ async fn do_cleanup(
     if let Err(err) = remove_dir_result {
         event!(
             Level::ERROR,
-            action_id = hex::encode(action_id),
+            ?operation_id,
             ?err,
             "Error removing working directory"
         );
@@ -562,7 +560,7 @@
 pub trait RunningAction: Sync + Send + Sized + Unpin + 'static {
     /// Returns the action id of the action.
-    fn get_action_id(&self) -> &ActionId;
+    fn get_operation_id(&self) -> &OperationId;
 
     /// Anything that needs to execute before the actions is actually executed should happen here.
     fn prepare_action(self: Arc<Self>) -> impl Future<Output = Result<Arc<Self>, Error>> + Send;
@@ -611,7 +609,7 @@ struct RunningActionImplState {
 }
 
 pub struct RunningActionImpl {
-    action_id: ActionId,
+    operation_id: OperationId,
     action_directory: String,
     work_directory: String,
     action_info: ActionInfo,
@@ -624,7 +622,7 @@ pub struct RunningActionImpl {
 impl RunningActionImpl {
     fn new(
         execution_metadata: ExecutionMetadata,
-        action_id: ActionId,
+        operation_id: OperationId,
         action_directory: String,
         action_info: ActionInfo,
         timeout: Duration,
@@ -633,7 +631,7 @@
         let work_directory = format!("{}/{}", action_directory, "work");
         let (kill_channel_tx, kill_channel_rx) = oneshot::channel();
         Self {
-            action_id,
+            operation_id,
             action_directory,
             work_directory,
             action_info,
@@ -988,14 +986,14 @@ impl RunningActionImpl {
             if let Err(err) = child_process_guard.start_kill() {
                 event!(
                     Level::ERROR,
-                    action_id = hex::encode(self.action_id),
+                    operation_id = ?self.operation_id,
                     ?err,
                     "Could not kill process",
                 );
             } else {
                 event!(
                     Level::ERROR,
-                    action_id = hex::encode(self.action_id),
+                    operation_id = ?self.operation_id,
                     "Could not get child process id, maybe already dead?",
                 );
             }
@@ -1034,7 +1032,7 @@ impl RunningActionImpl {
             )
         };
         let cas_store = self.running_actions_manager.cas_store.as_ref();
-        let hasher = self.action_info.unique_qualifier.digest_function;
+        let hasher = self.action_info.unique_qualifier.digest_function();
         enum OutputType {
             None,
             File(FileInfo),
@@ -1250,23 +1248,23 @@ impl Drop for RunningActionImpl {
         if self.did_cleanup.load(Ordering::Acquire) {
             return;
         }
+        let operation_id = self.operation_id.clone();
         event!(
             Level::ERROR,
-            action_id = hex::encode(self.action_id),
+            ?operation_id,
             "RunningActionImpl did not cleanup. This is a violation of the requirements, will attempt to do it in the background."
         );
         let running_actions_manager = self.running_actions_manager.clone();
-        let action_id = self.action_id;
         let action_directory = self.action_directory.clone();
         background_spawn!("running_action_impl_drop", async move {
             let Err(err) =
-                do_cleanup(&running_actions_manager, &action_id, &action_directory).await
+                do_cleanup(&running_actions_manager, &operation_id, &action_directory).await
             else {
                 return;
             };
             event!(
                 Level::ERROR,
-                action_id = hex::encode(action_id),
+                ?operation_id,
                 ?action_directory,
                 ?err,
                 "Error cleaning up action"
@@ -1276,8 +1274,8 @@ impl Drop for RunningActionImpl {
 }
 
 impl RunningAction for RunningActionImpl {
-    fn get_action_id(&self) -> &ActionId {
-        &self.action_id
+    fn get_operation_id(&self) -> &OperationId {
+        &self.operation_id
     }
 
     async fn prepare_action(self: Arc<Self>) -> Result<Arc<Self>, Error> {
@@ -1311,7 +1309,7 @@
             .wrap(async move {
                 let result = do_cleanup(
                     &self.running_actions_manager,
-                    &self.action_id,
+                    &self.operation_id,
                     &self.action_directory,
                 )
                 .await;
@@ -1352,7 +1350,10 @@ pub trait RunningActionsManager: Sync + Send + Sized + Unpin + 'static {
 
     fn kill_all(&self) -> impl Future<Output = ()> + Send;
 
-    fn kill_action(&self, action_id: ActionId) -> impl Future<Output = Result<(), Error>> + Send;
+    fn kill_operation(
+        &self,
+        operation_id: &OperationId,
+    ) -> impl Future<Output = Result<(), Error>> + Send;
 
     fn metrics(&self) -> &Arc<Metrics>;
 }
@@ -1643,7 +1644,7 @@ pub struct RunningActionsManagerImpl {
     upload_action_results: UploadActionResults,
     max_action_timeout: Duration,
     timeout_handled_externally: bool,
-    running_actions: Mutex<HashMap<ActionId, Weak<RunningActionImpl>>>,
+    running_actions: Mutex<HashMap<OperationId, Weak<RunningActionImpl>>>,
     // Note: We don't use Notify because we need to support a .wait_for()-like function, which
     // Notify does not support.
     action_done_tx: watch::Sender<()>,
@@ -1699,11 +1700,10 @@ impl RunningActionsManagerImpl {
 
     fn make_action_directory<'a>(
         &'a self,
-        action_id: &'a ActionId,
+        operation_id: &'a OperationId,
     ) -> impl Future<Output = Result<String, Error>> + 'a {
         self.metrics.make_action_directory.wrap(async move {
-            let action_directory =
-                format!("{}/{}", self.root_action_directory, hex::encode(action_id));
+            let action_directory = format!("{}/{}", self.root_action_directory, operation_id.id);
             fs::create_dir(&action_directory)
                 .await
                 .err_tip(|| format!("Error creating action directory {action_directory}"))?;
@@ -1730,10 +1730,9 @@
             get_and_decode_digest::<Action>(self.cas_store.as_ref(), action_digest.into())
                 .await
                 .err_tip(|| "During start_action")?;
-        let action_info = ActionInfo::try_from_action_and_execute_request_with_salt(
+        let action_info = ActionInfo::try_from_action_and_execute_request(
             execute_request,
             action,
-            start_execute.salt,
             load_start_timestamp,
             queued_timestamp,
         )
@@ -1742,10 +1741,10 @@
         })
     }
 
-    fn cleanup_action(&self, action_id: &ActionId) -> Result<(), Error> {
+    fn cleanup_action(&self, operation_id: &OperationId) -> Result<(), Error> {
         let mut running_actions = self.running_actions.lock();
-        let result = running_actions.remove(action_id).err_tip(|| {
-            format!("Expected action id '{action_id:?}' to exist in RunningActionsManagerImpl")
+        let result = running_actions.remove(operation_id).err_tip(|| {
+            format!("Expected action id '{operation_id:?}' to exist in RunningActionsManagerImpl")
         });
         // No need to copy anything, we just are telling the receivers an event happened.
         self.action_done_tx.send_modify(|_| {});
@@ -1754,11 +1753,11 @@
 
     // Note: We do not capture metrics on this call, only `.kill_all()`.
     // Important: When the future returns the process may still be running.
-    async fn kill_action(action: Arc<RunningActionImpl>) {
+    async fn kill_operation(action: Arc<RunningActionImpl>) {
         event!(
             Level::WARN,
-            action_id = ?hex::encode(action.action_id),
-            "Sending kill to running action",
+            operation_id = ?action.operation_id,
+            "Sending kill to running operation",
         );
         let kill_channel_tx = {
             let mut action_state = action.state.lock();
@@ -1768,8 +1767,8 @@
         if kill_channel_tx.send(()).is_err() {
             event!(
                 Level::ERROR,
-                action_id = ?hex::encode(action.action_id),
-                "Error sending kill to running action",
+                operation_id = ?action.operation_id,
+                "Error sending kill to running operation",
             );
         }
     }
@@ -1792,14 +1791,18 @@ impl RunningActionsManager for RunningActionsManagerImpl {
             .clone()
             .and_then(|time| time.try_into().ok())
             .unwrap_or(SystemTime::UNIX_EPOCH);
+        let operation_id: OperationId = start_execute
+            .operation_id
+            .as_str()
+            .try_into()
+            .err_tip(|| "Could not convert to operation_id in RunningActionsManager::create_and_add_action")?;
         let action_info = self.create_action_info(start_execute, queued_timestamp).await?;
         event!(
             Level::INFO,
             ?action_info,
             "Worker received action",
         );
-        let action_id = action_info.unique_qualifier.get_hash();
-        let action_directory = self.make_action_directory(&action_id).await?;
+        let action_directory = self.make_action_directory(&operation_id).await?;
         let execution_metadata = ExecutionMetadata {
             worker: worker_id,
             queued_timestamp: action_info.insert_timestamp,
@@ -1827,7 +1830,7 @@
         }
         let running_action = Arc::new(RunningActionImpl::new(
             execution_metadata,
-            action_id,
+            operation_id.clone(),
             action_directory,
             action_info,
             timeout,
@@ -1835,7 +1838,7 @@
         ));
         {
             let mut running_actions = self.running_actions.lock();
-            running_actions.insert(action_id, Arc::downgrade(&running_action));
+            running_actions.insert(operation_id, Arc::downgrade(&running_action));
         }
         Ok(running_action)
     })
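Worth noting from the hunks above: `running_actions` now maps `OperationId` to `Weak` references, so a kill only reaches actions that are still alive, and it is `cleanup_action` pruning the entry that lets the last strong reference drop. A reduced sketch of the lookup, with a stand-in type for `RunningActionImpl`:

```rust
use std::collections::HashMap;
use std::sync::{Arc, Weak};

use nativelink_util::action_messages::OperationId;
use parking_lot::Mutex;

struct Running; // Hypothetical stand-in for RunningActionImpl.

// None means either an unknown id or an action that already finished and was
// pruned from the map; kill_operation turns that into an input error.
fn find_running(
    map: &Mutex<HashMap<OperationId, Weak<Running>>>,
    operation_id: &OperationId,
) -> Option<Arc<Running>> {
    map.lock().get(operation_id).and_then(Weak::upgrade)
}
```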
@@ -1858,17 +1861,15 @@
         .await
     }
 
-    async fn kill_action(&self, action_id: ActionId) -> Result<(), Error> {
+    async fn kill_operation(&self, operation_id: &OperationId) -> Result<(), Error> {
         let running_action = {
             let running_actions = self.running_actions.lock();
             running_actions
-                .get(&action_id)
+                .get(operation_id)
                 .and_then(|action| action.upgrade())
-                .ok_or_else(|| {
-                    make_input_err!("Failed to get running action {}", hex::encode(action_id))
-                })?
+                .ok_or_else(|| make_input_err!("Failed to get running action {operation_id}"))?
         };
-        Self::kill_action(running_action).await;
+        Self::kill_operation(running_action).await;
         Ok(())
     }
 
@@ -1877,15 +1878,15 @@ impl RunningActionsManager for RunningActionsManagerImpl {
         self.metrics
             .kill_all
             .wrap_no_capture_result(async move {
-                let kill_actions: Vec<Arc<RunningActionImpl>> = {
+                let kill_operations: Vec<Arc<RunningActionImpl>> = {
                     let running_actions = self.running_actions.lock();
                     running_actions
                         .iter()
-                        .filter_map(|(_action_id, action)| action.upgrade())
+                        .filter_map(|(_operation_id, action)| action.upgrade())
                         .collect()
                 };
-                for action in kill_actions {
-                    Self::kill_action(action).await;
+                for action in kill_operations {
+                    Self::kill_operation(action).await;
                 }
             })
             .await;
diff --git a/nativelink-worker/tests/local_worker_test.rs b/nativelink-worker/tests/local_worker_test.rs
index aef1c6e6a..673579b75 100644
--- a/nativelink-worker/tests/local_worker_test.rs
+++ b/nativelink-worker/tests/local_worker_test.rs
@@ -32,18 +32,18 @@ mod utils {
 use nativelink_config::cas_server::{LocalWorkerConfig, WorkerProperty};
 use nativelink_error::{make_err, make_input_err, Code, Error};
 use nativelink_macro::nativelink_test;
-use nativelink_proto::build::bazel::remote::execution::v2::digest_function;
 use nativelink_proto::build::bazel::remote::execution::v2::platform::Property;
 use nativelink_proto::com::github::trace_machina::nativelink::remote_execution::update_for_worker::Update;
 use nativelink_proto::com::github::trace_machina::nativelink::remote_execution::{
-    execute_result, ConnectionResult, ExecuteResult, KillActionRequest, StartExecute,
+    execute_result, ConnectionResult, ExecuteResult, KillOperationRequest, StartExecute,
     SupportedProperties, UpdateForWorker,
 };
 use nativelink_store::fast_slow_store::FastSlowStore;
 use nativelink_store::filesystem_store::FilesystemStore;
 use nativelink_store::memory_store::MemoryStore;
 use nativelink_util::action_messages::{
-    ActionInfo, ActionInfoHashKey, ActionResult, ActionStage, ExecutionMetadata,
+    ActionInfo, ActionResult, ActionStage, ActionUniqueKey, ActionUniqueQualifier,
+    ExecutionMetadata, OperationId,
 };
 use nativelink_util::common::{encode_stream_proto, fs, DigestInfo};
 use nativelink_util::digest_hasher::DigestHasherFunc;
@@ -195,8 +195,6 @@ async fn kill_all_called_on_disconnect() -> Result<(), Box<dyn std::error::Error
 
 #[nativelink_test]
 async fn blake3_digest_function_registerd_properly() -> Result<(), Box<dyn std::error::Error>> {
-    const SALT: u64 = 1000;
-
     let mut test_context = setup_local_worker(HashMap::new()).await;
     let streaming_response = test_context.maybe_streaming_response.take().unwrap();
@@ -233,13 +231,11 @@ async fn blake3_digest_function_registerd_properly() -> Result<(), Box<dyn std::
 
 #[nativelink_test]
 async fn simple_worker_start_action_test() -> Result<(), Box<dyn std::error::Error>> {
-    const SALT: u64 = 1000;
-
     let mut test_context = setup_local_worker(HashMap::new()).await;
     let streaming_response = test_context.maybe_streaming_response.take().unwrap();
@@ -319,13 +313,11 @@ async fn simple_worker_start_action_test() -> Result<(), Box<dyn std::error::Err
 
 #[nativelink_test]
 async fn kill_action_request_kills_action() -> Result<(), Box<dyn std::error::Error>> {
-    const SALT: u64 = 1000;
-
     let mut test_context = setup_local_worker(HashMap::new()).await;
     let streaming_response = test_context.maybe_streaming_response.take().unwrap();
@@ -677,22 +660,21 @@ async fn kill_action_request_kills_action() -> Result<(), Box<dyn std::error::Er
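For context on the kill path exercised by `kill_action_request_kills_action`: the scheduler-to-worker message now carries an operation id string rather than a hex-encoded action id. A sketch, assuming the usual prost-generated `update` oneof field on `UpdateForWorker`:

```rust
use nativelink_proto::com::github::trace_machina::nativelink::remote_execution::{
    update_for_worker::Update, KillOperationRequest, UpdateForWorker,
};

fn make_kill_update(operation_id: &str) -> UpdateForWorker {
    UpdateForWorker {
        update: Some(Update::KillOperationRequest(KillOperationRequest {
            // The worker parses this back into an OperationId before calling
            // RunningActionsManager::kill_operation.
            operation_id: operation_id.to_string(),
        })),
    }
}
```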
diff --git a/nativelink-worker/tests/running_actions_manager_test.rs b/nativelink-worker/tests/running_actions_manager_test.rs
--- a/nativelink-worker/tests/running_actions_manager_test.rs
+++ b/nativelink-worker/tests/running_actions_manager_test.rs
     previous_time
 }
 
+fn make_operation_id(execute_request: &ExecuteRequest) -> OperationId {
+    let unique_qualifier = ActionUniqueQualifier::Cachable(ActionUniqueKey {
+        instance_name: execute_request.instance_name.clone(),
+        digest_function: execute_request.digest_function.try_into().unwrap(),
+        digest: execute_request
+            .action_digest
+            .clone()
+            .unwrap()
+            .try_into()
+            .unwrap(),
+    });
+    OperationId::new(unique_qualifier)
+}
+
 #[nativelink_test]
 async fn download_to_directory_file_download_test() -> Result<(), Box<dyn std::error::Error>> {
     const FILE1_NAME: &str = "file1.txt";
@@ -443,7 +459,6 @@ async fn ensure_output_files_full_directories_are_created_no_working_directory_t
     },
     )?);
     {
-        const SALT: u64 = 55;
         let command = Command {
             arguments: vec!["touch".to_string(), "./some/path/test.txt".to_string()],
             output_files: vec!["some/path/test.txt".to_string()],
@@ -487,16 +502,18 @@ async fn ensure_output_files_full_directories_are_created_no_working_directory_t
         )
         .await?;
 
+        let execute_request = ExecuteRequest {
+            action_digest: Some(action_digest.into()),
+            ..Default::default()
+        };
+        let operation_id = make_operation_id(&execute_request).to_string();
+
         let running_action = running_actions_manager
             .create_and_add_action(
                 WORKER_ID.to_string(),
                 StartExecute {
-                    execute_request: Some(ExecuteRequest {
-                        action_digest: Some(action_digest.into()),
-                        digest_function: ProtoDigestFunction::Sha256.into(),
-                        ..Default::default()
-                    }),
-                    salt: SALT,
+                    execute_request: Some(execute_request),
+                    operation_id,
                     queued_timestamp: None,
                 },
             )
@@ -557,7 +574,6 @@ async fn ensure_output_files_full_directories_are_created_test(
     },
     )?);
     {
-        const SALT: u64 = 55;
         let working_directory = "some_cwd";
         let command = Command {
             arguments: vec!["touch".to_string(), "./some/path/test.txt".to_string()],
             output_files: vec!["some_cwd/some/path/test.txt".to_string()],
@@ -603,16 +619,18 @@ async fn ensure_output_files_full_directories_are_created_test(
         )
         .await?;
 
+        let execute_request = ExecuteRequest {
+            action_digest: Some(action_digest.into()),
+            ..Default::default()
+        };
+        let operation_id = make_operation_id(&execute_request).to_string();
+
         let running_action = running_actions_manager
             .create_and_add_action(
                 WORKER_ID.to_string(),
                 StartExecute {
-                    execute_request: Some(ExecuteRequest {
-                        action_digest: Some(action_digest.into()),
-                        digest_function: ProtoDigestFunction::Sha256.into(),
-                        ..Default::default()
-                    }),
-                    salt: SALT,
+                    execute_request: Some(execute_request),
+                    operation_id,
                     queued_timestamp: None,
                 },
             )
@@ -673,7 +691,6 @@ async fn blake3_upload_files() -> Result<(), Box<dyn std::error::Error>> {
     },
     )?);
     let action_result = {
-        const SALT: u64 = 55;
         #[cfg(target_family = "unix")]
         let arguments = vec![
             "sh".to_string(),
@@ -734,16 +751,19 @@ async fn blake3_upload_files() -> Result<(), Box<dyn std::error::Error>> {
         )
         .await?;
 
+        let execute_request = ExecuteRequest {
+            action_digest: Some(action_digest.into()),
+            digest_function: ProtoDigestFunction::Blake3.into(),
+            ..Default::default()
+        };
+        let operation_id = make_operation_id(&execute_request).to_string();
+
         let running_action_impl = running_actions_manager
             .create_and_add_action(
                 WORKER_ID.to_string(),
                 StartExecute {
-                    execute_request: Some(ExecuteRequest {
-                        action_digest: Some(action_digest.into()),
-                        digest_function: ProtoDigestFunction::Blake3.into(),
-                        ..Default::default()
-                    }),
-                    salt: SALT,
+                    execute_request: Some(execute_request),
+                    operation_id,
                     queued_timestamp: None,
                 },
             )
@@ -844,7 +864,6 @@ async fn upload_files_from_above_cwd_test() -> Result<(), Box<dyn std::error::Err
     )?);
     let queued_timestamp = make_system_time(1000);
     let action_result = {
-        const SALT: u64 = 55;
         let command = Command {
             arguments: vec![
                 "sh".to_string(),
@@ -1060,16 +1080,18 @@ async fn upload_dir_and_symlink_test() -> Result<(), Box<dyn std::error::Error>>
         )
         .await?;
 
+        let execute_request = ExecuteRequest {
+            action_digest: Some(action_digest.into()),
+            ..Default::default()
+        };
+        let operation_id = make_operation_id(&execute_request).to_string();
+
         let running_action_impl = running_actions_manager
             .create_and_add_action(
                 WORKER_ID.to_string(),
                 StartExecute {
-                    execute_request: Some(ExecuteRequest {
-                        action_digest: Some(action_digest.into()),
-                        digest_function: ProtoDigestFunction::Sha256.into(),
-                        ..Default::default()
-                    }),
-                    salt: SALT,
+                    execute_request: Some(execute_request),
+                    operation_id,
                     queued_timestamp: Some(queued_timestamp.into()),
                 },
             )
@@ -1223,7 +1245,6 @@ async fn cleanup_happens_on_job_failure() -> Result<(), Box<dyn std::error::Error
 
 #[nativelink_test]
 async fn kill_ends_action() -> Result<(), Box<dyn std::error::Error>> {
     const WORKER_ID: &str = "foo_worker_id";
-    const SALT: u64 = 55;
 
     let (_, _, cas_store, ac_store) = setup_stores().await?;
     let root_action_directory = make_temp_path("root_action_directory");
@@ -1383,17 +1405,19 @@ async fn kill_ends_action() -> Result<(), Box<dyn std::error::Error>> {
     )
     .await?;
 
+    let execute_request = ExecuteRequest {
+        action_digest: Some(action_digest.into()),
+        ..Default::default()
+    };
+    let operation_id = make_operation_id(&execute_request).to_string();
+
     let running_action_impl = running_actions_manager
         .clone()
         .create_and_add_action(
             WORKER_ID.to_string(),
             StartExecute {
-                execute_request: Some(ExecuteRequest {
-                    action_digest: Some(action_digest.into()),
-                    digest_function: ProtoDigestFunction::Sha256.into(),
-                    ..Default::default()
-                }),
-                salt: SALT,
+                execute_request: Some(execute_request),
+                operation_id,
                 queued_timestamp: Some(make_system_time(1000).into()),
             },
         )
@@ -1445,7 +1469,6 @@
 echo | set /p="Wrapper script did run" 1>&2
 exit 0
 ";
     const WORKER_ID: &str = "foo_worker_id";
-    const SALT: u64 = 66;
     const EXPECTED_STDOUT: &str = "Action did run";
 
     let (_, _, cas_store, ac_store) = setup_stores().await?;
@@ -1528,17 +1551,19 @@
 exit 0
     )
     .await?;
 
+    let execute_request = ExecuteRequest {
+        action_digest: Some(action_digest.into()),
+        ..Default::default()
+    };
+    let operation_id = make_operation_id(&execute_request).to_string();
+
     let running_action_impl = running_actions_manager
         .clone()
         .create_and_add_action(
             WORKER_ID.to_string(),
             StartExecute {
-                execute_request: Some(ExecuteRequest {
-                    action_digest: Some(action_digest.into()),
-                    digest_function: ProtoDigestFunction::Sha256.into(),
-                    ..Default::default()
-                }),
-                salt: SALT,
+                execute_request: Some(execute_request),
+                operation_id,
                 queued_timestamp: Some(make_system_time(1000).into()),
             },
         )
@@ -1587,7 +1612,6 @@
 echo | set /p="Wrapper script did run with property %PROPERTY% %VALUE% %INNER_T
 exit 0
 ";
     const WORKER_ID: &str = "foo_worker_id";
-    const SALT: u64 = 66;
     const EXPECTED_STDOUT: &str = "Action did run";
 
     let (_, _, cas_store, ac_store) = setup_stores().await?;
@@ -1694,17 +1718,19 @@
 exit 0
     )
     .await?;
 
+    let execute_request = ExecuteRequest {
+        action_digest: Some(action_digest.into()),
+        ..Default::default()
+    };
+    let operation_id = make_operation_id(&execute_request).to_string();
+
     let running_action_impl = running_actions_manager
         .clone()
         .create_and_add_action(
             WORKER_ID.to_string(),
             StartExecute {
-                execute_request: Some(ExecuteRequest {
-                    action_digest: Some(action_digest.into()),
-                    digest_function: ProtoDigestFunction::Sha256.into(),
-                    ..Default::default()
-                }),
-                salt: SALT,
+                execute_request: Some(execute_request),
+                operation_id,
                 queued_timestamp: Some(make_system_time(1000).into()),
             },
         )
@@ -1751,7 +1777,6 @@
 echo | set /p={"failure":"timeout"} 1>&2 > %SIDE_CHANNEL_FILE%
 exit 1
 ";
     const WORKER_ID: &str = "foo_worker_id";
-    const SALT: u64 = 66;
 
     let (_, _, cas_store, ac_store) = setup_stores().await?;
     let root_action_directory =
make_temp_path("root_action_directory"); @@ -1833,17 +1858,19 @@ exit 1 ) .await?; + let execute_request = ExecuteRequest { + action_digest: Some(action_digest.into()), + ..Default::default() + }; + let operation_id = make_operation_id(&execute_request).to_string(); + let running_action_impl = running_actions_manager .clone() .create_and_add_action( WORKER_ID.to_string(), StartExecute { - execute_request: Some(ExecuteRequest { - action_digest: Some(action_digest.into()), - digest_function: ProtoDigestFunction::Sha256.into(), - ..Default::default() - }), - salt: SALT, + execute_request: Some(execute_request), + operation_id, queued_timestamp: Some(make_system_time(1000).into()), }, ) @@ -2346,16 +2373,18 @@ async fn ensure_worker_timeout_chooses_correct_values() -> Result<(), Box Result<(), Box Result<(), Box Result<(), Box> { ) .await?; + let execute_request = ExecuteRequest { + action_digest: Some(action_digest.into()), + ..Default::default() + }; + let operation_id = make_operation_id(&execute_request).to_string(); + let execute_results_fut = running_actions_manager .create_and_add_action( WORKER_ID.to_string(), StartExecute { - execute_request: Some(ExecuteRequest { - action_digest: Some(action_digest.into()), - digest_function: ProtoDigestFunction::Sha256.into(), - ..Default::default() - }), - salt: 0, + execute_request: Some(execute_request), + operation_id, queued_timestamp: Some(make_system_time(1000).into()), }, ) @@ -2740,18 +2775,20 @@ async fn kill_all_waits_for_all_tasks_to_finish() -> Result<(), Box Result<(), Box> { ) .await?; + let execute_request = ExecuteRequest { + action_digest: Some(action_digest.into()), + ..Default::default() + }; + let operation_id = make_operation_id(&execute_request).to_string(); + let running_action_impl = running_actions_manager .create_and_add_action( WORKER_ID.to_string(), StartExecute { - execute_request: Some(ExecuteRequest { - action_digest: Some(action_digest.into()), - digest_function: ProtoDigestFunction::Sha256.into(), - ..Default::default() - }), + execute_request: Some(execute_request), + operation_id, ..Default::default() }, ) @@ -2946,7 +2986,6 @@ async fn action_directory_contents_are_cleaned() -> Result<(), Box Result<(), Box Result<(), Box> { }, )?); let action_result = { - const SALT: u64 = 55; #[cfg(target_family = "unix")] let arguments = vec![ "sh".to_string(), @@ -3110,16 +3150,18 @@ async fn upload_with_single_permit() -> Result<(), Box> { ) .await?; + let execute_request = ExecuteRequest { + action_digest: Some(action_digest.into()), + ..Default::default() + }; + let operation_id = make_operation_id(&execute_request).to_string(); + let running_action_impl = running_actions_manager .create_and_add_action( WORKER_ID.to_string(), StartExecute { - execute_request: Some(ExecuteRequest { - action_digest: Some(action_digest.into()), - digest_function: ProtoDigestFunction::Sha256.into(), - ..Default::default() - }), - salt: SALT, + execute_request: Some(execute_request), + operation_id, queued_timestamp: None, }, ) @@ -3196,7 +3238,6 @@ async fn upload_with_single_permit() -> Result<(), Box> { async fn running_actions_manager_respects_action_timeout() -> Result<(), Box> { const WORKER_ID: &str = "foo_worker_id"; - const SALT: u64 = 66; let (_, _, cas_store, ac_store) = setup_stores().await?; let root_action_directory = make_temp_path("root_work_directory"); @@ -3282,17 +3323,19 @@ async fn running_actions_manager_respects_action_timeout() -> Result<(), Box>, tx_kill_all: mpsc::UnboundedSender<()>, - rx_kill_action: Mutex>, - 
-    tx_kill_action: mpsc::UnboundedSender<ActionId>,
+    rx_kill_operation: Mutex<mpsc::UnboundedReceiver<OperationId>>,
+    tx_kill_operation: mpsc::UnboundedSender<OperationId>,
     metrics: Arc<Metrics>,
 }
@@ -61,7 +59,7 @@ impl MockRunningActionsManager {
         let (tx_call, rx_call) = mpsc::unbounded_channel();
         let (tx_resp, rx_resp) = mpsc::unbounded_channel();
         let (tx_kill_all, rx_kill_all) = mpsc::unbounded_channel();
-        let (tx_kill_action, rx_kill_action) = mpsc::unbounded_channel();
+        let (tx_kill_operation, rx_kill_operation) = mpsc::unbounded_channel();
         Self {
             rx_call: Mutex::new(rx_call),
             tx_call,
@@ -69,8 +67,8 @@ impl MockRunningActionsManager {
             tx_resp,
             rx_kill_all: Mutex::new(rx_kill_all),
             tx_kill_all,
-            rx_kill_action: Mutex::new(rx_kill_action),
-            tx_kill_action,
+            rx_kill_operation: Mutex::new(rx_kill_operation),
+            tx_kill_operation,
             metrics: Arc::new(Metrics::default()),
         }
     }
@@ -116,9 +114,9 @@ impl MockRunningActionsManager {
             .expect("Could not receive msg in mpsc");
     }
 
-    pub async fn expect_kill_action(&self) -> ActionId {
-        let mut rx_kill_action_lock = self.rx_kill_action.lock().await;
-        rx_kill_action_lock
+    pub async fn expect_kill_operation(&self) -> OperationId {
+        let mut rx_kill_operation_lock = self.rx_kill_operation.lock().await;
+        rx_kill_operation_lock
             .recv()
             .await
             .expect("Could not receive msg in mpsc")
@@ -165,9 +163,9 @@ impl RunningActionsManager for MockRunningActionsManager {
         Ok(())
     }
 
-    async fn kill_action(&self, action_id: ActionId) -> Result<(), Error> {
-        self.tx_kill_action
-            .send(action_id)
+    async fn kill_operation(&self, operation_id: &OperationId) -> Result<(), Error> {
+        self.tx_kill_operation
+            .send(operation_id.clone())
             .expect("Could not send request to mpsc");
         Ok(())
     }
@@ -344,7 +342,7 @@ impl MockRunningAction {
 }
 
 impl RunningAction for MockRunningAction {
-    fn get_action_id(&self) -> &ActionId {
+    fn get_operation_id(&self) -> &OperationId {
         unreachable!("not implemented for tests");
     }
diff --git a/src/bin/nativelink.rs b/src/bin/nativelink.rs
index 49142778f..c05c62b16 100644
--- a/src/bin/nativelink.rs
+++ b/src/bin/nativelink.rs
@@ -529,7 +529,7 @@ async fn inner_main(
             })?
             .clone()
             .set_drain_worker(
-                WorkerId::try_from(worker_id.clone())?,
+                &WorkerId::try_from(worker_id.clone())?,
                 is_draining,
             )
             .await?;