Skip to content

Commit

Permalink
QA fixes for scheduler-v2 (#1092)
Browse files Browse the repository at this point in the history
Introduces various code quality improvements to scheduler-v2
which allow bazel test to pass.

Co-authored-by: Zach Birenbaum <zacharyobirenbaum@gmail.com>
  • Loading branch information
allada and zbirenbaum committed Jul 7, 2024
1 parent f2cea0c commit d70d31d
Show file tree
Hide file tree
Showing 15 changed files with 114 additions and 110 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ pub struct ExecuteResult {
/// that case the server SHOULD infer the digest function using the
/// length of the action digest hash and the digest functions announced
/// in the server's capabilities.
/// TODO!!! THIS IS UNUSED AFAIK.
#[prost(
enumeration = "super::super::super::super::super::build::bazel::remote::execution::v2::digest_function::Value",
tag = "7"
Expand Down Expand Up @@ -159,7 +160,7 @@ pub struct StartExecute {
pub execute_request: ::core::option::Option<
super::super::super::super::super::build::bazel::remote::execution::v2::ExecuteRequest,
>,
/// Id of the operation.
/// / Id of the operation.
#[prost(string, tag = "4")]
pub operation_id: ::prost::alloc::string::String,
/// / The time at which the command was added to the queue to allow population
Expand Down
1 change: 1 addition & 0 deletions nativelink-scheduler/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ rust_library(
"@crates//:scopeguard",
"@crates//:serde",
"@crates//:serde_json",
"@crates//:static_assertions",
"@crates//:tokio",
"@crates//:tokio-stream",
"@crates//:tonic",
Expand Down
9 changes: 5 additions & 4 deletions nativelink-scheduler/src/cache_lookup_scheduler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,10 @@ use tracing::{event, Level};
use crate::action_scheduler::ActionScheduler;
use crate::platform_property_manager::PlatformPropertyManager;

/// A future containing the resolved `ClientOperationId` once it is figured out.
/// This future may be cloned and will always yield the same value once resolved.
type ClientOperationIdFuture = SharedFuture<oneshot::Receiver<ClientOperationId>>;

/// Actions that are having their cache checked or failed cache lookup and are
/// being forwarded upstream. Missing the skip_cache_check actions which are
/// forwarded directly.
Expand Down Expand Up @@ -97,10 +101,7 @@ async fn get_action_from_store(
fn subscribe_to_existing_action(
cache_check_actions: &MutexGuard<CheckActions>,
unique_qualifier: &ActionInfoHashKey,
) -> Option<(
SharedFuture<oneshot::Receiver<ClientOperationId>>,
watch::Receiver<Arc<ActionState>>,
)> {
) -> Option<(ClientOperationIdFuture, watch::Receiver<Arc<ActionState>>)> {
cache_check_actions
.get(unique_qualifier)
.map(|(client_operation_id_rx, rx)| {
Expand Down
20 changes: 10 additions & 10 deletions nativelink-scheduler/src/scheduler_state/awaited_action.rs
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ impl AwaitedAction {
/// will hold a lock preventing anyone else from reading or modifying the
/// sort key until the result is dropped.
#[must_use]
pub fn set_priority<'a>(&'a self, new_priority: i32) -> Option<SortInfoLock<'a>> {
pub fn set_priority(&self, new_priority: i32) -> Option<SortInfoLock> {
let sort_info_lock = self.sort_info.upgradable_read();
if sort_info_lock.priority == new_priority {
return None;
Expand All @@ -173,7 +173,7 @@ impl AwaitedAction {
}

/// Gets the sort info of the action.
pub fn get_sort_info<'a>(&'a self) -> SortInfoLock<'a> {
pub fn get_sort_info(&self) -> SortInfoLock {
let sort_info = self.sort_info.read();
SortInfoLock {
previous_sort_key: sort_info.sort_key,
Expand All @@ -199,7 +199,7 @@ impl AwaitedAction {

/// Gets the worker id that is currently processing this action.
pub fn get_worker_id(&self) -> Option<WorkerId> {
self.worker_id.read().clone()
*self.worker_id.read()
}

/// Sets the worker id that is currently processing this action.
Expand Down Expand Up @@ -244,16 +244,18 @@ impl AwaitedAction {
pub struct AwaitedActionSortKey(u128);

impl AwaitedActionSortKey {
#[rustfmt::skip]
const fn new(priority: i32, insert_timestamp: u64, hash: [u8; 4]) -> Self {
// Shift `new_priority` so [`i32::MIN`] is represented by zero.
// This makes it so any nagative values are positive, but
// maintains ordering.
const MIN_I32: i64 = (i32::MIN as i64).abs();
let priority = ((priority as i64 + MIN_I32) as u32).to_be_bytes();

// Invert our timestamp so the larger the timestamp the lower the number.
// This makes timestamp descending order instead of ascending.
let timestamp = (insert_timestamp ^ u64::MAX).to_be_bytes();

#[cfg_attr(rustfmt, rustfmt_skip)]
AwaitedActionSortKey(u128::from_be_bytes([
priority[0], priority[1], priority[2], priority[3],
timestamp[0], timestamp[1], timestamp[2], timestamp[3],
Expand Down Expand Up @@ -286,7 +288,7 @@ const_assert_eq!(
// Note: `6543210fedcba987` are the inverted bits of `9abcdef012345678`.
// This effectively inverts the priority to now have the highest priority
// be the lowest timestamps.
AwaitedActionSortKey(0x92345678_6543210fedcba987_9abcdef0).0
AwaitedActionSortKey(0x9234_5678_6543_210f_edcb_a987_9abc_def0).0
);
// Ensure the priority is used as the sort key first.
const_assert!(
Expand Down Expand Up @@ -354,11 +356,9 @@ impl MetricsComponent for AwaitedAction {
);
c.publish(
"worker_id",
&format!(
"{}",
self.get_worker_id()
.map_or(String::new(), |v| v.to_string())
),
&self
.get_worker_id()
.map_or(String::new(), |v| v.to_string()),
"The current worker processing the action (if any).",
);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use std::{borrow::Cow, sync::Arc};
use std::borrow::Cow;
use std::sync::Arc;

use async_trait::async_trait;
use nativelink_error::Error;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,16 @@
// See the License for the specific language governing permissions and
// limitations under the License.

use std::{borrow::Cow, sync::Arc};
use std::borrow::Cow;
use std::sync::Arc;

use async_trait::async_trait;
use nativelink_error::Error;
use nativelink_util::action_messages::{ActionInfo, ActionState};
use tokio::sync::watch;

use crate::operation_state_manager::ActionStateResult;

use super::awaited_action::AwaitedAction;
use crate::operation_state_manager::ActionStateResult;

pub struct MatchingEngineActionStateResult {
awaited_action: Arc<AwaitedAction>,
Expand Down
99 changes: 52 additions & 47 deletions nativelink-scheduler/src/scheduler_state/state_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,24 +21,22 @@ use async_lock::Mutex;
use async_trait::async_trait;
use futures::stream::{self, unfold};
use hashbrown::HashMap;
use nativelink_error::{make_err, Code, Error};
use nativelink_error::{make_err, make_input_err, Code, Error, ResultExt};
use nativelink_util::action_messages::{
ActionInfo, ActionInfoHashKey, ActionResult, ActionStage, ActionState, ClientOperationId,
ExecutionMetadata, OperationId, WorkerId,
};
use tokio::sync::Notify;
use tracing::{event, Level};

use super::awaited_action::AwaitedActionSortKey;
use crate::operation_state_manager::{
ActionStateResult, ActionStateResultStream, ClientStateManager, MatchingEngineStateManager,
OperationFilter, OperationStageFlags, OrderDirection, WorkerStateManager,
};
use crate::scheduler_state::awaited_action::AwaitedAction;
use crate::scheduler_state::client_action_state_result::ClientActionStateResult;
use crate::scheduler_state::matching_engine_action_state_result::MatchingEngineActionStateResult;
// use crate::scheduler_state::metrics::Metrics;

use super::awaited_action::AwaitedActionSortKey;

#[derive(Debug, Clone)]
struct SortedAwaitedAction {
Expand Down Expand Up @@ -96,6 +94,7 @@ pub struct AwaitedActionDb {
sorted_action_info_hash_keys: SortedAwaitedActions,
}

#[allow(clippy::mutable_key_type)]
impl AwaitedActionDb {
fn get_by_client_operation_id(
&self,
Expand Down Expand Up @@ -227,9 +226,14 @@ impl AwaitedActionDb {
action_info: ActionInfo,
) -> Arc<ClientActionStateResult> {
// Check to see if the action is already known and subscribe if it is.
let action_info = match self.try_subscribe(&new_client_operation_id, action_info) {
let action_info = match self.try_subscribe(
&new_client_operation_id,
&action_info.unique_qualifier,
action_info.priority,
action_info.skip_cache_lookup,
) {
Ok(subscription) => return subscription,
Err(action_info) => Arc::new(action_info),
Err(_) => Arc::new(action_info),
};

let (awaited_action, sort_key, subscription) =
Expand All @@ -249,30 +253,41 @@ impl AwaitedActionDb {
awaited_action,
},
);
return Arc::new(ClientActionStateResult::new(subscription));
Arc::new(ClientActionStateResult::new(subscription))
}

fn try_subscribe(
&mut self,
client_operation_id: &ClientOperationId,
action_info: ActionInfo,
) -> Result<Arc<ClientActionStateResult>, ActionInfo> {
if action_info.skip_cache_lookup {
return Err(action_info);
unique_qualifier: &ActionInfoHashKey,
priority: i32,
skip_cache_lookup: bool,
) -> Result<Arc<ClientActionStateResult>, Error> {
if skip_cache_lookup {
return Err(make_err!(
Code::InvalidArgument,
"Cannot subscribe to an existing item when skip_cache_lookup is true."
));
}
let Some(awaited_action) = self

let awaited_action = self
.action_info_hash_key_to_awaited_action
.get(&action_info.unique_qualifier)
else {
return Err(action_info);
};
.get(unique_qualifier)
.ok_or(make_input_err!(
"Could not find existing action with name: {}",
unique_qualifier.action_name()
))
.err_tip(|| "In state_manager::try_subscribe")?;

// Do not subscribe if the action is already completed,
// this is the responsibility of the CacheLookupScheduler.
if awaited_action.get_current_state().stage.is_finished() {
return Err(action_info);
return Err(make_input_err!(
"Subscribing an item that is already completed should be handled by CacheLookupScheduler."
));
}
let awaited_action = awaited_action.clone();
if let Some(sort_info_lock) = awaited_action.set_priority(action_info.priority) {
if let Some(sort_info_lock) = awaited_action.set_priority(priority) {
let state = awaited_action.get_current_state();
let maybe_sorted_awaited_action =
self.get_sort_map_for_state(&state.stage)
Expand All @@ -281,13 +296,12 @@ impl AwaitedActionDb {
awaited_action: awaited_action.clone(),
});
let Some(mut sorted_awaited_action) = maybe_sorted_awaited_action else {
event!(
Level::ERROR,
?action_info,
?awaited_action,
"sorted_action_info_hash_keys and action_info_hash_key_to_awaited_action are out of sync",
);
return Err(action_info);
// TODO: Either use event on all of the above error here, but both is overkill.
let err = make_err!(
Code::Internal,
"sorted_action_info_hash_keys and action_info_hash_key_to_awaited_action are out of sync");
event!(Level::ERROR, ?unique_qualifier, ?awaited_action, "{err:?}",);
return Err(err);
};
sorted_awaited_action.sort_key = sort_info_lock.get_new_sort_key();
self.insert_sort_map_for_stage(&state.stage, sorted_awaited_action);
Expand All @@ -309,15 +323,10 @@ pub struct StateManager {

impl StateManager {
#[allow(clippy::too_many_arguments)]
pub fn new(
// metrics: Arc<Metrics>,
tasks_change_notify: Arc<Notify>,
max_job_retries: usize,
) -> Self {
pub fn new(tasks_change_notify: Arc<Notify>, max_job_retries: usize) -> Self {
Self {
inner: Arc::new(Mutex::new(StateManagerImpl {
action_db: AwaitedActionDb::default(),
// metrics,
tasks_change_notify,
max_job_retries,
})),
Expand All @@ -334,10 +343,10 @@ impl StateManager {
Arc::new(MatchingEngineActionStateResult::new(awaited_action))
}

fn get_tree_for_stage<'a>(
action_db: &'a AwaitedActionDb,
fn get_tree_for_stage(
action_db: &AwaitedActionDb,
stage: OperationStageFlags,
) -> Option<&'a BTreeSet<SortedAwaitedAction>> {
) -> Option<&BTreeSet<SortedAwaitedAction>> {
match stage {
OperationStageFlags::CacheCheck => Some(action_db.get_cache_check_actions()),
OperationStageFlags::Queued => Some(action_db.get_queued_actions()),
Expand All @@ -352,7 +361,7 @@ impl StateManager {
if let Some(operation_id) = &filter.operation_id {
return Ok(inner
.action_db
.get_by_operation_id(&operation_id)
.get_by_operation_id(operation_id)
.filter(|awaited_action| filter_check(awaited_action.as_ref(), filter))
.cloned()
.map(|awaited_action| -> ActionStateResultStream {
Expand Down Expand Up @@ -431,6 +440,7 @@ impl StateManager {

let inner = state.inner.lock().await;

#[allow(clippy::mutable_key_type)]
let btree = get_tree_for_stage(&inner.action_db, state.filter.stages)
.expect("get_tree_for_stage() should have already returned Some but in iteration it returned None");

Expand All @@ -450,16 +460,14 @@ impl StateManager {
.for_each(|item| state.buffer.push_back(item.clone()));
}
drop(inner);
let Some(sorted_awaited_action) = state.buffer.pop_front() else {
return None;
};
let sorted_awaited_action = state.buffer.pop_front()?;
if state.buffer.is_empty() {
state.start_key = Bound::Excluded(sorted_awaited_action.clone());
}
return Some((
Some((
to_action_state_result(sorted_awaited_action.awaited_action),
state,
));
))
})))
}
}
Expand All @@ -470,7 +478,6 @@ impl StateManager {
pub(crate) struct StateManagerImpl {
pub(crate) action_db: AwaitedActionDb,

// pub(crate) metrics: Arc<Metrics>,
/// Notify task<->worker matching engine that work needs to be done.
pub(crate) tasks_change_notify: Arc<Notify>,

Expand All @@ -486,10 +493,8 @@ fn filter_check(awaited_action: &AwaitedAction, filter: &OperationFilter) -> boo
}
}

if filter.worker_id.is_some() {
if filter.worker_id != awaited_action.get_worker_id() {
return false;
}
if filter.worker_id.is_some() && filter.worker_id != awaited_action.get_worker_id() {
return false;
}

{
Expand Down Expand Up @@ -621,7 +626,7 @@ impl StateManagerImpl {
// which worker sent the update.
awaited_action.set_worker_id(None);
} else {
awaited_action.set_worker_id(maybe_worker_id.map(|w| w.clone()));
awaited_action.set_worker_id(maybe_worker_id.copied());
}
let has_listeners = self.action_db.set_action_state(
awaited_action.clone(),
Expand Down Expand Up @@ -653,7 +658,7 @@ impl StateManagerImpl {
.action_db
.subscribe_or_add_action(new_client_operation_id, action_info);
self.tasks_change_notify.notify_one();
return Ok(subscription);
Ok(subscription)
}
}

Expand Down
Loading

0 comments on commit d70d31d

Please sign in to comment.