Skip to content

Commit

Permalink
add TaskPool::spawn_pollable so that the consumer does not need a blo…
Browse files Browse the repository at this point in the history
…ck_on method
  • Loading branch information
pubrrr committed Mar 4, 2022
1 parent b6a647c commit 2d868eb
Show file tree
Hide file tree
Showing 5 changed files with 103 additions and 16 deletions.
2 changes: 2 additions & 0 deletions crates/bevy_tasks/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ license = "MIT OR Apache-2.0"
keywords = ["bevy"]

[dependencies]
bevy_utils = { path = "../bevy_utils", version = "0.6.0" }

futures-lite = "1.4.0"
event-listener = "2.4.0"
async-executor = "1.3.0"
Expand Down
3 changes: 3 additions & 0 deletions crates/bevy_tasks/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@ pub use slice::{ParallelSlice, ParallelSliceMut};
mod task;
pub use task::Task;

mod pollable_task;
pub use pollable_task::PollableTask;

#[cfg(not(target_arch = "wasm32"))]
mod task_pool;
#[cfg(not(target_arch = "wasm32"))]
Expand Down
31 changes: 31 additions & 0 deletions crates/bevy_tasks/src/pollable_task.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
use crate::Task;
use async_channel::{Receiver, TryRecvError};

/// A pollable task whose result readiness can be checked in system functions
/// on every frame update without blocking on a future
#[derive(Debug)]
pub struct PollableTask<T> {
receiver: Receiver<T>,
// this is to keep the task alive
_task: Task<()>,
}

impl<T> PollableTask<T> {
pub(crate) fn new(receiver: Receiver<T>, task: Task<()>) -> Self {
Self {
receiver,
_task: task,
}
}

/// poll to see whether the task finished
pub fn poll(&self) -> Option<T> {
match self.receiver.try_recv() {
Ok(value) => Some(value),
Err(try_error) => match try_error {
TryRecvError::Empty => None,
TryRecvError::Closed => panic!("todo"),
},
}
}
}
59 changes: 54 additions & 5 deletions crates/bevy_tasks/src/task_pool.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use async_channel::bounded;
use std::{
future::Future,
mem,
Expand All @@ -6,9 +7,10 @@ use std::{
thread::{self, JoinHandle},
};

use bevy_utils::tracing::warn;
use futures_lite::{future, pin};

use crate::Task;
use crate::{PollableTask, Task};

/// Used to create a [`TaskPool`]
#[derive(Debug, Default, Clone)]
Expand Down Expand Up @@ -239,6 +241,26 @@ impl TaskPool {
Task::new(self.executor.spawn(future))
}

/// Spawns a static future onto the thread pool. The returned PollableTask is not a future,
/// but can be polled in system functions on every frame update without being blocked on
pub fn spawn_pollable<T>(
&self,
future: impl Future<Output = T> + Send + 'static,
) -> PollableTask<T>
where
T: Send + Sync + 'static,
{
let (sender, receiver) = bounded(1);
let task = self.spawn(async move {
let result = future.await;
match sender.send(result).await {
Ok(_) => {}
Err(_) => warn!("Could not send result of task to receiver"),
}
});
PollableTask::new(receiver, task)
}

/// Spawns a static future on the thread-local async executor for the current thread. The task
/// will run entirely on the thread the task was spawned on. The returned Task is a future.
/// It can also be cancelled and "detached" allowing it to continue running without having
Expand Down Expand Up @@ -298,9 +320,13 @@ impl<'scope, T: Send + 'scope> Scope<'scope, T> {
#[allow(clippy::blacklisted_name)]
mod tests {
use super::*;
use std::sync::{
atomic::{AtomicBool, AtomicI32, Ordering},
Barrier,
use std::{
ops::Range,
sync::{
atomic::{AtomicBool, AtomicI32, Ordering},
Barrier,
},
time::Duration,
};

#[test]
Expand Down Expand Up @@ -402,7 +428,7 @@ mod tests {
scope.spawn_local(async move {
inner_count_clone.fetch_add(1, Ordering::Release);
if std::thread::current().id() != spawner {
// NOTE: This check is using an atomic rather than simply panicing the
// NOTE: This check is using an atomic rather than simply panicking the
// thread to avoid deadlocking the barrier on failure
inner_thread_check_failed.store(true, Ordering::Release);
}
Expand All @@ -415,4 +441,27 @@ mod tests {
assert!(!thread_check_failed.load(Ordering::Acquire));
assert_eq!(count.load(Ordering::Acquire), 200);
}

#[test]
fn test_spawn_pollable() {
let transform_fn = |i| i + 1;

let pool = TaskPool::new();
let nums: Range<u8> = 0..10;

let pollable_tasks = nums
.clone()
.into_iter()
.map(|i| pool.spawn_pollable(async move { transform_fn(i) }))
.collect::<Vec<_>>();

std::thread::sleep(Duration::from_secs_f32(1.0 / 30.0));

for (pollable_task, number) in pollable_tasks.iter().zip(nums) {
let poll_result = pollable_task.poll();

let expected = transform_fn(number);
assert_eq!(Some(expected), poll_result);
}
}
}
24 changes: 13 additions & 11 deletions examples/async_tasks/async_compute.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,5 @@
use bevy::{
prelude::*,
tasks::{AsyncComputeTaskPool, Task},
};
use futures_lite::future;
use bevy::{prelude::*, tasks::AsyncComputeTaskPool};
use bevy_internal::tasks::PollableTask;
use rand::Rng;
use std::time::{Duration, Instant};

Expand All @@ -25,6 +22,11 @@ const NUM_CUBES: u32 = 6;
struct BoxMeshHandle(Handle<Mesh>);
struct BoxMaterialHandle(Handle<StandardMaterial>);

/// PollableTask is not a `Component` itself to prevent multiple systems accessing it accidentally.
/// Instead wrap it in a customer component.
#[derive(Component)]
struct ComputeTransformTask(PollableTask<Transform>);

/// Startup system which runs only once and generates our Box Mesh
/// and Box Material assets, adds them to their respective Asset
/// Resources, and stores their handles as resources so we can access
Expand All @@ -50,7 +52,7 @@ fn spawn_tasks(mut commands: Commands, thread_pool: Res<AsyncComputeTaskPool>) {
for y in 0..NUM_CUBES {
for z in 0..NUM_CUBES {
// Spawn new task on the AsyncComputeTaskPool
let task = thread_pool.spawn(async move {
let task = thread_pool.spawn_pollable(async move {
let mut rng = rand::thread_rng();
let start_time = Instant::now();
let duration = Duration::from_secs_f32(rng.gen_range(0.05..0.2));
Expand All @@ -64,7 +66,7 @@ fn spawn_tasks(mut commands: Commands, thread_pool: Res<AsyncComputeTaskPool>) {
});

// Spawn new entity and add our new task as a component
commands.spawn().insert(task);
commands.spawn().insert(ComputeTransformTask(task));
}
}
}
Expand All @@ -76,12 +78,12 @@ fn spawn_tasks(mut commands: Commands, thread_pool: Res<AsyncComputeTaskPool>) {
/// removes the task component from the entity.
fn handle_tasks(
mut commands: Commands,
mut transform_tasks: Query<(Entity, &mut Task<Transform>)>,
mut transform_tasks: Query<(Entity, &ComputeTransformTask)>,
box_mesh_handle: Res<BoxMeshHandle>,
box_material_handle: Res<BoxMaterialHandle>,
) {
for (entity, mut task) in transform_tasks.iter_mut() {
if let Some(transform) = future::block_on(future::poll_once(&mut *task)) {
for (entity, compute_transform_task) in transform_tasks.iter_mut() {
if let Some(transform) = compute_transform_task.0.poll() {
// Add our new PbrBundle of components to our tagged entity
commands.entity(entity).insert_bundle(PbrBundle {
mesh: box_mesh_handle.0.clone(),
Expand All @@ -91,7 +93,7 @@ fn handle_tasks(
});

// Task is complete, so remove task component from entity
commands.entity(entity).remove::<Task<Transform>>();
commands.entity(entity).remove::<ComputeTransformTask>();
}
}
}
Expand Down

0 comments on commit 2d868eb

Please sign in to comment.