diff --git a/turbopack/crates/turbo-tasks/src/invalidation.rs b/turbopack/crates/turbo-tasks/src/invalidation.rs index 9d62cd50cd386..0e88fcc9b6176 100644 --- a/turbopack/crates/turbo-tasks/src/invalidation.rs +++ b/turbopack/crates/turbo-tasks/src/invalidation.rs @@ -3,11 +3,138 @@ use std::{ fmt::Display, hash::{Hash, Hasher}, mem::replace, + sync::{Arc, Weak}, }; +use anyhow::Result; use indexmap::{map::Entry, IndexMap, IndexSet}; +use serde::{de::Visitor, Deserialize, Serialize}; +use tokio::runtime::Handle; -use crate::{magic_any::HasherMut, util::StaticOrArc}; +use crate::{ + magic_any::HasherMut, + manager::{current_task, with_turbo_tasks}, + trace::TraceRawVcs, + util::StaticOrArc, + TaskId, TurboTasksApi, +}; + +/// Get an [`Invalidator`] that can be used to invalidate the current task +/// based on external events. +pub fn get_invalidator() -> Invalidator { + let handle = Handle::current(); + Invalidator { + task: current_task("turbo_tasks::get_invalidator()"), + turbo_tasks: with_turbo_tasks(Arc::downgrade), + handle, + } +} + +pub struct Invalidator { + task: TaskId, + turbo_tasks: Weak, + handle: Handle, +} + +impl Invalidator { + pub fn invalidate(self) { + let Invalidator { + task, + turbo_tasks, + handle, + } = self; + let _ = handle.enter(); + if let Some(turbo_tasks) = turbo_tasks.upgrade() { + turbo_tasks.invalidate(task); + } + } + + pub fn invalidate_with_reason(self, reason: T) { + let Invalidator { + task, + turbo_tasks, + handle, + } = self; + let _ = handle.enter(); + if let Some(turbo_tasks) = turbo_tasks.upgrade() { + turbo_tasks.invalidate_with_reason( + task, + (Arc::new(reason) as Arc).into(), + ); + } + } + + pub fn invalidate_with_static_reason(self, reason: &'static T) { + let Invalidator { + task, + turbo_tasks, + handle, + } = self; + let _ = handle.enter(); + if let Some(turbo_tasks) = turbo_tasks.upgrade() { + turbo_tasks + .invalidate_with_reason(task, (reason as &'static dyn InvalidationReason).into()); + } + } +} + +impl Hash for Invalidator { + fn hash(&self, state: &mut H) { + self.task.hash(state); + } +} + +impl PartialEq for Invalidator { + fn eq(&self, other: &Self) -> bool { + self.task == other.task + } +} + +impl Eq for Invalidator {} + +impl TraceRawVcs for Invalidator { + fn trace_raw_vcs(&self, _context: &mut crate::trace::TraceRawVcsContext) { + // nothing here + } +} + +impl Serialize for Invalidator { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + serializer.serialize_newtype_struct("Invalidator", &self.task) + } +} + +impl<'de> Deserialize<'de> for Invalidator { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + struct V; + + impl<'de> Visitor<'de> for V { + type Value = Invalidator; + + fn expecting(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "an Invalidator") + } + + fn visit_newtype_struct(self, deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + Ok(Invalidator { + task: TaskId::deserialize(deserializer)?, + turbo_tasks: with_turbo_tasks(Arc::downgrade), + handle: tokio::runtime::Handle::current(), + }) + } + } + deserializer.deserialize_newtype_struct("Invalidator", V) + } +} pub trait DynamicEqHash { fn as_any(&self) -> &dyn Any; diff --git a/turbopack/crates/turbo-tasks/src/lib.rs b/turbopack/crates/turbo-tasks/src/lib.rs index eb0c4497535d9..540c74fd4327f 100644 --- a/turbopack/crates/turbo-tasks/src/lib.rs +++ b/turbopack/crates/turbo-tasks/src/lib.rs @@ -83,16 +83,16 @@ pub use completion::{Completion, Completions}; pub use display::ValueToString; pub use id::{ExecutionId, FunctionId, TaskId, TraitTypeId, ValueTypeId, TRANSIENT_TASK_BIT}; pub use invalidation::{ - DynamicEqHash, InvalidationReason, InvalidationReasonKind, InvalidationReasonSet, + get_invalidator, DynamicEqHash, InvalidationReason, InvalidationReasonKind, + InvalidationReasonSet, Invalidator, }; pub use join_iter_ext::{JoinIterExt, TryFlatJoinIterExt, TryJoinIterExt}; pub use magic_any::MagicAny; pub use manager::{ - dynamic_call, dynamic_this_call, emit, get_invalidator, mark_finished, mark_stateful, - prevent_gc, run_once, run_once_with_reason, spawn_blocking, spawn_thread, trait_call, - turbo_tasks, CurrentCellRef, Invalidator, ReadConsistency, TaskPersistence, TurboTasks, - TurboTasksApi, TurboTasksBackendApi, TurboTasksBackendApiExt, TurboTasksCallApi, Unused, - UpdateInfo, + dynamic_call, dynamic_this_call, emit, mark_finished, mark_stateful, prevent_gc, run_once, + run_once_with_reason, spawn_blocking, spawn_thread, trait_call, turbo_tasks, CurrentCellRef, + ReadConsistency, TaskPersistence, TurboTasks, TurboTasksApi, TurboTasksBackendApi, + TurboTasksBackendApiExt, TurboTasksCallApi, Unused, UpdateInfo, }; pub use native_function::{FunctionMeta, NativeFunction}; pub use raw_vc::{CellId, RawVc, ReadRawVcFuture, ResolveTypeError}; diff --git a/turbopack/crates/turbo-tasks/src/manager.rs b/turbopack/crates/turbo-tasks/src/manager.rs index 6367a4139a414..6ae251fe4fe4d 100644 --- a/turbopack/crates/turbo-tasks/src/manager.rs +++ b/turbopack/crates/turbo-tasks/src/manager.rs @@ -2,7 +2,7 @@ use std::{ any::Any, borrow::Cow, future::Future, - hash::{BuildHasherDefault, Hash}, + hash::BuildHasherDefault, mem::take, panic::AssertUnwindSafe, pin::Pin, @@ -18,7 +18,7 @@ use anyhow::{anyhow, Result}; use auto_hash_map::AutoMap; use futures::FutureExt; use rustc_hash::FxHasher; -use serde::{de::Visitor, Deserialize, Serialize}; +use serde::{Deserialize, Serialize}; use tokio::{runtime::Handle, select, task_local}; use tokio_util::task::TaskTracker; use tracing::{info_span, instrument, trace_span, Instrument, Level}; @@ -1513,112 +1513,6 @@ pub(crate) fn current_task(from: &str) -> TaskId { } } -pub struct Invalidator { - task: TaskId, - turbo_tasks: Weak, - handle: Handle, -} - -impl Hash for Invalidator { - fn hash(&self, state: &mut H) { - self.task.hash(state); - } -} - -impl PartialEq for Invalidator { - fn eq(&self, other: &Self) -> bool { - self.task == other.task - } -} - -impl Eq for Invalidator {} - -impl Invalidator { - pub fn invalidate(self) { - let Invalidator { - task, - turbo_tasks, - handle, - } = self; - let _ = handle.enter(); - if let Some(turbo_tasks) = turbo_tasks.upgrade() { - turbo_tasks.invalidate(task); - } - } - - pub fn invalidate_with_reason(self, reason: T) { - let Invalidator { - task, - turbo_tasks, - handle, - } = self; - let _ = handle.enter(); - if let Some(turbo_tasks) = turbo_tasks.upgrade() { - turbo_tasks.invalidate_with_reason( - task, - (Arc::new(reason) as Arc).into(), - ); - } - } - - pub fn invalidate_with_static_reason(self, reason: &'static T) { - let Invalidator { - task, - turbo_tasks, - handle, - } = self; - let _ = handle.enter(); - if let Some(turbo_tasks) = turbo_tasks.upgrade() { - turbo_tasks - .invalidate_with_reason(task, (reason as &'static dyn InvalidationReason).into()); - } - } -} - -impl TraceRawVcs for Invalidator { - fn trace_raw_vcs(&self, _context: &mut crate::trace::TraceRawVcsContext) { - // nothing here - } -} - -impl Serialize for Invalidator { - fn serialize(&self, serializer: S) -> Result - where - S: serde::Serializer, - { - serializer.serialize_newtype_struct("Invalidator", &self.task) - } -} - -impl<'de> Deserialize<'de> for Invalidator { - fn deserialize(deserializer: D) -> Result - where - D: serde::Deserializer<'de>, - { - struct V; - - impl<'de> Visitor<'de> for V { - type Value = Invalidator; - - fn expecting(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - write!(f, "an Invalidator") - } - - fn visit_newtype_struct(self, deserializer: D) -> Result - where - D: serde::Deserializer<'de>, - { - Ok(Invalidator { - task: TaskId::deserialize(deserializer)?, - turbo_tasks: weak_turbo_tasks(), - handle: tokio::runtime::Handle::current(), - }) - } - } - deserializer.deserialize_newtype_struct("Invalidator", V) - } -} - pub async fn run_once( tt: Arc, future: impl Future> + Send + 'static, @@ -1704,10 +1598,6 @@ pub fn with_turbo_tasks(func: impl FnOnce(&Arc) -> T) -> T TURBO_TASKS.with(|arc| func(arc)) } -pub fn weak_turbo_tasks() -> Weak { - TURBO_TASKS.with(Arc::downgrade) -} - pub fn with_turbo_tasks_for_testing( tt: Arc, current_task: TaskId, @@ -1738,17 +1628,6 @@ pub fn current_task_for_testing() -> TaskId { CURRENT_GLOBAL_TASK_STATE.with(|ts| ts.read().unwrap().task_id) } -/// Get an [`Invalidator`] that can be used to invalidate the current task -/// based on external events. -pub fn get_invalidator() -> Invalidator { - let handle = Handle::current(); - Invalidator { - task: current_task("turbo_tasks::get_invalidator()"), - turbo_tasks: weak_turbo_tasks(), - handle, - } -} - /// Marks the current task as finished. This excludes it from waiting for /// strongly consistency. pub fn mark_finished() {