Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(turbo-tasks): Add TaskPersistence enum for task creation functions #68866

Merged
merged 3 commits into from
Aug 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 51 additions & 34 deletions turbopack/crates/turbo-tasks-macros/src/func.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use std::collections::HashSet;

use proc_macro2::{Ident, Span, TokenStream};
use quote::{quote, quote_spanned};
use quote::{quote, quote_spanned, ToTokens};
use syn::{
parenthesized,
parse::{Parse, ParseStream},
Expand All @@ -24,6 +24,8 @@ pub struct TurboFn {
inputs: Vec<Input>,
/// Should we check that the return type contains a `ResolvedValue`?
resolved: Option<Span>,
/// Should this function use `TaskPersistence::LocalCells`?
local_cells: bool,
}

#[derive(Debug)]
Expand Down Expand Up @@ -257,6 +259,7 @@ impl TurboFn {
this,
inputs,
resolved: args.resolved,
local_cells: args.local_cells.is_some(),
})
}

Expand Down Expand Up @@ -301,17 +304,26 @@ impl TurboFn {
}
}

fn inputs(&self) -> Vec<&Ident> {
self.inputs
.iter()
.map(|Input { ident, .. }| ident)
.collect()
fn input_idents(&self) -> impl Iterator<Item = &Ident> {
self.inputs.iter().map(|Input { ident, .. }| ident)
}

pub fn input_types(&self) -> Vec<&Type> {
self.inputs.iter().map(|Input { ty, .. }| ty).collect()
}

pub fn persistence(&self) -> impl ToTokens {
if self.local_cells {
quote! {
turbo_tasks::TaskPersistence::LocalCells
}
} else {
quote! {
turbo_tasks::macro_helpers::get_non_local_persistence_from_inputs(&*inputs)
}
}
}

fn converted_this(&self) -> Option<Expr> {
self.this.as_ref().map(|Input { ty: _, ident }| {
parse_quote! {
Expand Down Expand Up @@ -344,31 +356,33 @@ impl TurboFn {
/// The block of the exposed function for a dynamic dispatch call to the
/// given trait.
pub fn dynamic_block(&self, trait_type_id_ident: &Ident) -> Block {
let Some(converted_this) = self.converted_this() else {
return parse_quote! {
{
unimplemented!("trait methods without self are not yet supported")
}
};
};

let ident = &self.ident;
let output = &self.output;
let assertions = self.get_assertions();
if let Some(converted_this) = self.converted_this() {
let inputs = self.inputs();
parse_quote! {
{
#assertions
let turbo_tasks_transient = #( turbo_tasks::TaskInput::is_transient(&#inputs) ||)* false;
<#output as turbo_tasks::task::TaskOutput>::try_from_raw_vc(
turbo_tasks::trait_call(
*#trait_type_id_ident,
std::borrow::Cow::Borrowed(stringify!(#ident)),
#converted_this,
Box::new((#(#inputs,)*)) as Box<dyn turbo_tasks::MagicAny>,
turbo_tasks_transient,
)
let inputs = self.input_idents();
let persistence = self.persistence();
parse_quote! {
{
#assertions
let inputs = std::boxed::Box::new((#(#inputs,)*));
let persistence = #persistence;
<#output as turbo_tasks::task::TaskOutput>::try_from_raw_vc(
turbo_tasks::trait_call(
*#trait_type_id_ident,
std::borrow::Cow::Borrowed(stringify!(#ident)),
#converted_this,
inputs as std::boxed::Box<dyn turbo_tasks::MagicAny>,
persistence,
)
}
}
} else {
parse_quote! {
{
unimplemented!("trait methods without self are not yet supported")
}
)
}
}
}
Expand All @@ -377,19 +391,21 @@ impl TurboFn {
/// given native function.
pub fn static_block(&self, native_function_id_ident: &Ident) -> Block {
let output = &self.output;
let inputs = self.inputs();
let inputs = self.input_idents();
let persistence = self.persistence();
let assertions = self.get_assertions();
if let Some(converted_this) = self.converted_this() {
parse_quote! {
{
#assertions
let turbo_tasks_transient = #( turbo_tasks::TaskInput::is_transient(&#inputs) ||)* false;
let inputs = std::boxed::Box::new((#(#inputs,)*));
let persistence = #persistence;
<#output as turbo_tasks::task::TaskOutput>::try_from_raw_vc(
turbo_tasks::dynamic_this_call(
*#native_function_id_ident,
#converted_this,
Box::new((#(#inputs,)*)) as Box<dyn turbo_tasks::MagicAny>,
turbo_tasks_transient
inputs as std::boxed::Box<dyn turbo_tasks::MagicAny>,
persistence,
)
)
}
Expand All @@ -398,12 +414,13 @@ impl TurboFn {
parse_quote! {
{
#assertions
let turbo_tasks_transient = #( turbo_tasks::TaskInput::is_transient(&#inputs) ||)* false;
let inputs = std::boxed::Box::new((#(#inputs,)*));
let persistence = #persistence;
<#output as turbo_tasks::task::TaskOutput>::try_from_raw_vc(
turbo_tasks::dynamic_call(
*#native_function_id_ident,
Box::new((#(#inputs,)*)) as Box<dyn turbo_tasks::MagicAny>,
turbo_tasks_transient,
inputs as std::boxed::Box<dyn turbo_tasks::MagicAny>,
persistence,
)
)
}
Expand Down
4 changes: 2 additions & 2 deletions turbopack/crates/turbo-tasks-memory/src/task.rs
Original file line number Diff line number Diff line change
Expand Up @@ -810,7 +810,7 @@ impl Task {
native_fn_id,
*this,
&**arg,
self.id.is_transient(),
self.id.persistence(),
turbo_tasks,
));
drop(entered);
Expand All @@ -833,7 +833,7 @@ impl Task {
name,
*this,
&**arg,
self.id.is_transient(),
self.id.persistence(),
turbo_tasks,
));
drop(entered);
Expand Down
1 change: 1 addition & 0 deletions turbopack/crates/turbo-tasks-memory/tests/local_cell.rs
14 changes: 7 additions & 7 deletions turbopack/crates/turbo-tasks-testing/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@ use turbo_tasks::{
registry,
test_helpers::with_turbo_tasks_for_testing,
util::{SharedError, StaticOrArc},
CellId, ExecutionId, InvalidationReason, MagicAny, RawVc, TaskId, TraitTypeId, TurboTasksApi,
TurboTasksCallApi,
CellId, ExecutionId, InvalidationReason, MagicAny, RawVc, TaskId, TaskPersistence, TraitTypeId,
TurboTasksApi, TurboTasksCallApi,
};

pub use crate::run::{run, run_without_cache_check, Registration};
Expand Down Expand Up @@ -92,7 +92,7 @@ impl TurboTasksCallApi for VcStorage {
&self,
func: turbo_tasks::FunctionId,
arg: Box<dyn MagicAny>,
_is_transient: bool,
_persistence: TaskPersistence,
) -> RawVc {
self.dynamic_call(func, None, arg)
}
Expand All @@ -102,7 +102,7 @@ impl TurboTasksCallApi for VcStorage {
func: turbo_tasks::FunctionId,
this_arg: RawVc,
arg: Box<dyn MagicAny>,
_is_transient: bool,
_persistence: TaskPersistence,
) -> RawVc {
self.dynamic_call(func, Some(this_arg), arg)
}
Expand All @@ -111,7 +111,7 @@ impl TurboTasksCallApi for VcStorage {
&self,
_func: turbo_tasks::FunctionId,
_arg: Box<dyn MagicAny>,
_is_transient: bool,
_persistence: TaskPersistence,
) -> RawVc {
unreachable!()
}
Expand All @@ -121,7 +121,7 @@ impl TurboTasksCallApi for VcStorage {
_func: turbo_tasks::FunctionId,
_this: RawVc,
_arg: Box<dyn MagicAny>,
_is_transient: bool,
_persistence: TaskPersistence,
) -> RawVc {
unreachable!()
}
Expand All @@ -132,7 +132,7 @@ impl TurboTasksCallApi for VcStorage {
_trait_fn_name: Cow<'static, str>,
_this: RawVc,
_arg: Box<dyn MagicAny>,
_is_transient: bool,
_persistence: TaskPersistence,
) -> RawVc {
unreachable!()
}
Expand Down
46 changes: 32 additions & 14 deletions turbopack/crates/turbo-tasks-testing/tests/local_cell.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#![feature(arbitrary_self_types)]

use anyhow::Result;
use turbo_tasks::{
debug::ValueDebug, test_helpers::current_task_for_testing, ResolvedValue, ValueDefault, Vc,
};
Expand All @@ -14,8 +15,8 @@ struct Wrapper(u32);
struct TransparentWrapper(u32);

#[tokio::test]
async fn test_store_and_read() {
run(&REGISTRATION, async {
async fn test_store_and_read() -> Result<()> {
run(&REGISTRATION, || async {
let a: Vc<u32> = Vc::local_cell(42);
assert_eq!(*a.await.unwrap(), 42);

Expand All @@ -24,13 +25,15 @@ async fn test_store_and_read() {

let c = TransparentWrapper(42).local_cell();
assert_eq!(*c.await.unwrap(), 42);

Ok(())
})
.await
}

#[tokio::test]
async fn test_store_and_read_generic() {
run(&REGISTRATION, async {
async fn test_store_and_read_generic() -> Result<()> {
run(&REGISTRATION, || async {
// `Vc<Vec<Vc<T>>>` is stored as `Vc<Vec<Vc<()>>>` and requires special
// transmute handling
let cells: Vc<Vec<Vc<u32>>> =
Expand All @@ -42,6 +45,8 @@ async fn test_store_and_read_generic() {
}

assert_eq!(output, vec![1, 2, 3]);

Ok(())
})
.await
}
Expand All @@ -53,10 +58,12 @@ async fn returns_resolved_local_vc() -> Vc<u32> {
cell.resolve().await.unwrap()
}

#[ignore]
#[tokio::test]
async fn test_return_resolved() {
run(&REGISTRATION, async {
async fn test_return_resolved() -> Result<()> {
run(&REGISTRATION, || async {
assert_eq!(*returns_resolved_local_vc().await.unwrap(), 42);
Ok(())
})
.await
}
Expand All @@ -65,8 +72,8 @@ async fn test_return_resolved() {
trait UnimplementedTrait {}

#[tokio::test]
async fn test_try_resolve_sidecast() {
run(&REGISTRATION, async {
async fn test_try_resolve_sidecast() -> Result<()> {
run(&REGISTRATION, || async {
let trait_vc: Vc<Box<dyn ValueDebug>> = Vc::upcast(Vc::<u32>::local_cell(42));

// `u32` is both a `ValueDebug` and a `ValueDefault`, so this sidecast is valid
Expand All @@ -80,13 +87,15 @@ async fn test_try_resolve_sidecast() {
.await
.unwrap();
assert!(wrongly_sidecast_vc.is_none());

Ok(())
})
.await
}

#[tokio::test]
async fn test_try_resolve_downcast_type() {
run(&REGISTRATION, async {
async fn test_try_resolve_downcast_type() -> Result<()> {
run(&REGISTRATION, || async {
let trait_vc: Vc<Box<dyn ValueDebug>> = Vc::upcast(Vc::<u32>::local_cell(42));

let downcast_vc: Vc<u32> = Vc::try_resolve_downcast_type(trait_vc)
Expand All @@ -98,16 +107,19 @@ async fn test_try_resolve_downcast_type() {
let wrongly_downcast_vc: Option<Vc<i64>> =
Vc::try_resolve_downcast_type(trait_vc).await.unwrap();
assert!(wrongly_downcast_vc.is_none());

Ok(())
})
.await
}

#[tokio::test]
async fn test_get_task_id() {
run(&REGISTRATION, async {
async fn test_get_task_id() -> Result<()> {
run(&REGISTRATION, || async {
// the task id as reported by the RawVc
let vc_task_id = Vc::into_raw(Vc::<()>::local_cell(())).get_task_id();
assert_eq!(vc_task_id, current_task_for_testing());
Ok(())
})
.await
}
Expand Down Expand Up @@ -139,25 +151,31 @@ async fn get_untracked_local_cell() -> Vc<Untracked> {
.unwrap()
}

#[ignore]
#[tokio::test]
#[should_panic(expected = "Local Vcs must only be accessed within their own task")]
async fn test_panics_on_local_cell_escape_read() {
run(&REGISTRATION, async {
run(&REGISTRATION, || async {
get_untracked_local_cell()
.await
.unwrap()
.cell
.await
.unwrap();
Ok(())
})
.await
.unwrap()
}

#[ignore]
#[tokio::test]
#[should_panic(expected = "Local Vcs must only be accessed within their own task")]
async fn test_panics_on_local_cell_escape_get_task_id() {
run(&REGISTRATION, async {
run(&REGISTRATION, || async {
Vc::into_raw(get_untracked_local_cell().await.unwrap().cell).get_task_id();
Ok(())
})
.await
.unwrap()
}
Loading
Loading