From 012ae07dc87d43a7437bd4200b4a7f04bf99dbf3 Mon Sep 17 00:00:00 2001 From: James Liu Date: Thu, 9 Jun 2022 02:43:24 +0000 Subject: [PATCH] Add global init and get accessors for all newtyped TaskPools (#2250) Right now, a direct reference to the target TaskPool is required to launch tasks on the pools, despite the three newtyped pools (AsyncComputeTaskPool, ComputeTaskPool, and IoTaskPool) effectively acting as global instances. The need to pass a TaskPool reference adds notable friction to spawning subtasks within existing tasks. Possible use cases for this may include chaining tasks within the same pool like spawning separate send/receive I/O tasks after waiting on a network connection to be established, or allowing cross-pool dependent tasks like starting dependent multi-frame computations following a long I/O load. Other task execution runtimes provide static access to spawning tasks (i.e. `tokio::spawn`), which is notably easier to use than the reference passing required by `bevy_tasks` right now. This PR makes does the following: * Adds `*TaskPool::init` which initializes a `OnceCell`'ed with a provided TaskPool. Failing if the pool has already been initialized. * Adds `*TaskPool::get` which fetches the initialized global pool of the respective type or panics. This generally should not be an issue in normal Bevy use, as the pools are initialized before they are accessed. * Updated default task pool initialization to either pull the global handles and save them as resources, or if they are already initialized, pull the a cloned global handle as the resource. This should make it notably easier to build more complex task hierarchies for dependent tasks. It should also make writing bevy-adjacent, but not strictly bevy-only plugin crates easier, as the global pools ensure it's all running on the same threads. One alternative considered is keeping a thread-local reference to the pool for all threads in each pool to enable the same `tokio::spawn` interface. This would spawn tasks on the same pool that a task is currently running in. However this potentially leads to potential footgun situations where long running blocking tasks run on `ComputeTaskPool`. --- .../bevy_ecs/ecs_bench_suite/heavy_compute.rs | 3 +- crates/bevy_app/src/app.rs | 12 +- crates/bevy_asset/src/asset_server.rs | 18 +- crates/bevy_asset/src/debug_asset_server.rs | 12 +- crates/bevy_asset/src/lib.rs | 7 +- crates/bevy_asset/src/loader.rs | 8 - crates/bevy_core/src/lib.rs | 2 +- crates/bevy_core/src/task_pool_options.rs | 28 +-- crates/bevy_ecs/src/lib.rs | 4 +- crates/bevy_ecs/src/query/state.rs | 205 ++++++++---------- .../src/schedule/executor_parallel.rs | 5 +- crates/bevy_ecs/src/system/query.rs | 8 +- crates/bevy_gltf/Cargo.toml | 1 + crates/bevy_gltf/src/loader.rs | 4 +- crates/bevy_tasks/Cargo.toml | 3 +- crates/bevy_tasks/src/task_pool.rs | 47 ++-- crates/bevy_tasks/src/usages.rs | 71 +++++- examples/asset/custom_asset_io.rs | 6 +- examples/async_tasks/async_compute.rs | 3 +- 19 files changed, 227 insertions(+), 220 deletions(-) diff --git a/benches/benches/bevy_ecs/ecs_bench_suite/heavy_compute.rs b/benches/benches/bevy_ecs/ecs_bench_suite/heavy_compute.rs index 71a8e6cc6a..169f32f065 100644 --- a/benches/benches/bevy_ecs/ecs_bench_suite/heavy_compute.rs +++ b/benches/benches/bevy_ecs/ecs_bench_suite/heavy_compute.rs @@ -18,6 +18,8 @@ pub struct Benchmark(World, Box>); impl Benchmark { pub fn new() -> Self { + ComputeTaskPool::init(TaskPool::default); + let mut world = World::default(); world.spawn_batch((0..1000).map(|_| { @@ -39,7 +41,6 @@ impl Benchmark { }); } - world.insert_resource(ComputeTaskPool(TaskPool::default())); let mut system = IntoSystem::into_system(sys); system.initialize(&mut world); system.update_archetype_component_access(&world); diff --git a/crates/bevy_app/src/app.rs b/crates/bevy_app/src/app.rs index d587344d65..cdf1984c8b 100644 --- a/crates/bevy_app/src/app.rs +++ b/crates/bevy_app/src/app.rs @@ -10,7 +10,6 @@ use bevy_ecs::{ system::Resource, world::World, }; -use bevy_tasks::{AsyncComputeTaskPool, ComputeTaskPool, IoTaskPool}; use bevy_utils::{tracing::debug, HashMap}; use std::fmt::Debug; @@ -863,18 +862,9 @@ impl App { pub fn add_sub_app( &mut self, label: impl AppLabel, - mut app: App, + app: App, sub_app_runner: impl Fn(&mut World, &mut App) + 'static, ) -> &mut Self { - if let Some(pool) = self.world.get_resource::() { - app.world.insert_resource(pool.clone()); - } - if let Some(pool) = self.world.get_resource::() { - app.world.insert_resource(pool.clone()); - } - if let Some(pool) = self.world.get_resource::() { - app.world.insert_resource(pool.clone()); - } self.sub_apps.insert( Box::new(label), SubApp { diff --git a/crates/bevy_asset/src/asset_server.rs b/crates/bevy_asset/src/asset_server.rs index 1612b406b4..bc7defc361 100644 --- a/crates/bevy_asset/src/asset_server.rs +++ b/crates/bevy_asset/src/asset_server.rs @@ -7,7 +7,7 @@ use crate::{ use anyhow::Result; use bevy_ecs::system::{Res, ResMut}; use bevy_log::warn; -use bevy_tasks::TaskPool; +use bevy_tasks::IoTaskPool; use bevy_utils::{Entry, HashMap, Uuid}; use crossbeam_channel::TryRecvError; use parking_lot::{Mutex, RwLock}; @@ -56,7 +56,6 @@ pub struct AssetServerInternal { loaders: RwLock>>, extension_to_loader_index: RwLock>, handle_to_path: Arc>>>, - task_pool: TaskPool, } /// Loads assets from the filesystem on background threads @@ -66,11 +65,11 @@ pub struct AssetServer { } impl AssetServer { - pub fn new(source_io: T, task_pool: TaskPool) -> Self { - Self::with_boxed_io(Box::new(source_io), task_pool) + pub fn new(source_io: T) -> Self { + Self::with_boxed_io(Box::new(source_io)) } - pub fn with_boxed_io(asset_io: Box, task_pool: TaskPool) -> Self { + pub fn with_boxed_io(asset_io: Box) -> Self { AssetServer { server: Arc::new(AssetServerInternal { loaders: Default::default(), @@ -79,7 +78,6 @@ impl AssetServer { asset_ref_counter: Default::default(), handle_to_path: Default::default(), asset_lifecycles: Default::default(), - task_pool, asset_io, }), } @@ -315,7 +313,6 @@ impl AssetServer { &self.server.asset_ref_counter.channel, self.asset_io(), version, - &self.server.task_pool, ); if let Err(err) = asset_loader @@ -377,8 +374,7 @@ impl AssetServer { pub(crate) fn load_untracked(&self, asset_path: AssetPath<'_>, force: bool) -> HandleId { let server = self.clone(); let owned_path = asset_path.to_owned(); - self.server - .task_pool + IoTaskPool::get() .spawn(async move { if let Err(err) = server.load_async(owned_path, force).await { warn!("{}", err); @@ -620,8 +616,8 @@ mod test { fn setup(asset_path: impl AsRef) -> AssetServer { use crate::FileAssetIo; - - AssetServer::new(FileAssetIo::new(asset_path, false), Default::default()) + IoTaskPool::init(Default::default); + AssetServer::new(FileAssetIo::new(asset_path, false)) } #[test] diff --git a/crates/bevy_asset/src/debug_asset_server.rs b/crates/bevy_asset/src/debug_asset_server.rs index e53646de95..b7e3c71c85 100644 --- a/crates/bevy_asset/src/debug_asset_server.rs +++ b/crates/bevy_asset/src/debug_asset_server.rs @@ -58,14 +58,14 @@ impl Default for HandleMap { impl Plugin for DebugAssetServerPlugin { fn build(&self, app: &mut bevy_app::App) { + IoTaskPool::init(|| { + TaskPoolBuilder::default() + .num_threads(2) + .thread_name("Debug Asset Server IO Task Pool".to_string()) + .build() + }); let mut debug_asset_app = App::new(); debug_asset_app - .insert_resource(IoTaskPool( - TaskPoolBuilder::default() - .num_threads(2) - .thread_name("Debug Asset Server IO Task Pool".to_string()) - .build(), - )) .insert_resource(AssetServerSettings { asset_folder: "crates".to_string(), watch_for_changes: true, diff --git a/crates/bevy_asset/src/lib.rs b/crates/bevy_asset/src/lib.rs index 870f100d10..b5ba1a1854 100644 --- a/crates/bevy_asset/src/lib.rs +++ b/crates/bevy_asset/src/lib.rs @@ -30,7 +30,6 @@ pub use path::*; use bevy_app::{prelude::Plugin, App}; use bevy_ecs::schedule::{StageLabel, SystemStage}; -use bevy_tasks::IoTaskPool; /// The names of asset stages in an App Schedule #[derive(Debug, Hash, PartialEq, Eq, Clone, StageLabel)] @@ -82,12 +81,8 @@ pub fn create_platform_default_asset_io(app: &mut App) -> Box { impl Plugin for AssetPlugin { fn build(&self, app: &mut App) { if !app.world.contains_resource::() { - let task_pool = app.world.resource::().0.clone(); - let source = create_platform_default_asset_io(app); - - let asset_server = AssetServer::with_boxed_io(source, task_pool); - + let asset_server = AssetServer::with_boxed_io(source); app.insert_resource(asset_server); } diff --git a/crates/bevy_asset/src/loader.rs b/crates/bevy_asset/src/loader.rs index 5a5de9b8c1..5d6b87d838 100644 --- a/crates/bevy_asset/src/loader.rs +++ b/crates/bevy_asset/src/loader.rs @@ -5,7 +5,6 @@ use crate::{ use anyhow::Result; use bevy_ecs::system::{Res, ResMut}; use bevy_reflect::{TypeUuid, TypeUuidDynamic}; -use bevy_tasks::TaskPool; use bevy_utils::{BoxedFuture, HashMap}; use crossbeam_channel::{Receiver, Sender}; use downcast_rs::{impl_downcast, Downcast}; @@ -84,7 +83,6 @@ pub struct LoadContext<'a> { pub(crate) labeled_assets: HashMap, BoxedLoadedAsset>, pub(crate) path: &'a Path, pub(crate) version: usize, - pub(crate) task_pool: &'a TaskPool, } impl<'a> LoadContext<'a> { @@ -93,7 +91,6 @@ impl<'a> LoadContext<'a> { ref_change_channel: &'a RefChangeChannel, asset_io: &'a dyn AssetIo, version: usize, - task_pool: &'a TaskPool, ) -> Self { Self { ref_change_channel, @@ -101,7 +98,6 @@ impl<'a> LoadContext<'a> { labeled_assets: Default::default(), version, path, - task_pool, } } @@ -144,10 +140,6 @@ impl<'a> LoadContext<'a> { asset_metas } - pub fn task_pool(&self) -> &TaskPool { - self.task_pool - } - pub fn asset_io(&self) -> &dyn AssetIo { self.asset_io } diff --git a/crates/bevy_core/src/lib.rs b/crates/bevy_core/src/lib.rs index f4630bc668..7c7d193564 100644 --- a/crates/bevy_core/src/lib.rs +++ b/crates/bevy_core/src/lib.rs @@ -30,7 +30,7 @@ impl Plugin for CorePlugin { .get_resource::() .cloned() .unwrap_or_default() - .create_default_pools(&mut app.world); + .create_default_pools(); app.register_type::().register_type::(); diff --git a/crates/bevy_core/src/task_pool_options.rs b/crates/bevy_core/src/task_pool_options.rs index 19c9dad5bf..152489b7cf 100644 --- a/crates/bevy_core/src/task_pool_options.rs +++ b/crates/bevy_core/src/task_pool_options.rs @@ -1,4 +1,3 @@ -use bevy_ecs::world::World; use bevy_tasks::{AsyncComputeTaskPool, ComputeTaskPool, IoTaskPool, TaskPoolBuilder}; use bevy_utils::tracing::trace; @@ -93,14 +92,14 @@ impl DefaultTaskPoolOptions { } /// Inserts the default thread pools into the given resource map based on the configured values - pub fn create_default_pools(&self, world: &mut World) { + pub fn create_default_pools(&self) { let total_threads = bevy_tasks::logical_core_count().clamp(self.min_total_threads, self.max_total_threads); trace!("Assigning {} cores to default task pools", total_threads); let mut remaining_threads = total_threads; - if !world.contains_resource::() { + { // Determine the number of IO threads we will use let io_threads = self .io @@ -109,15 +108,15 @@ impl DefaultTaskPoolOptions { trace!("IO Threads: {}", io_threads); remaining_threads = remaining_threads.saturating_sub(io_threads); - world.insert_resource(IoTaskPool( + IoTaskPool::init(|| { TaskPoolBuilder::default() .num_threads(io_threads) .thread_name("IO Task Pool".to_string()) - .build(), - )); + .build() + }); } - if !world.contains_resource::() { + { // Determine the number of async compute threads we will use let async_compute_threads = self .async_compute @@ -126,15 +125,15 @@ impl DefaultTaskPoolOptions { trace!("Async Compute Threads: {}", async_compute_threads); remaining_threads = remaining_threads.saturating_sub(async_compute_threads); - world.insert_resource(AsyncComputeTaskPool( + AsyncComputeTaskPool::init(|| { TaskPoolBuilder::default() .num_threads(async_compute_threads) .thread_name("Async Compute Task Pool".to_string()) - .build(), - )); + .build() + }); } - if !world.contains_resource::() { + { // Determine the number of compute threads we will use // This is intentionally last so that an end user can specify 1.0 as the percent let compute_threads = self @@ -142,12 +141,13 @@ impl DefaultTaskPoolOptions { .get_number_of_threads(remaining_threads, total_threads); trace!("Compute Threads: {}", compute_threads); - world.insert_resource(ComputeTaskPool( + + ComputeTaskPool::init(|| { TaskPoolBuilder::default() .num_threads(compute_threads) .thread_name("Compute Task Pool".to_string()) - .build(), - )); + .build() + }); } } } diff --git a/crates/bevy_ecs/src/lib.rs b/crates/bevy_ecs/src/lib.rs index 56ac5db9ac..30a373a9d5 100644 --- a/crates/bevy_ecs/src/lib.rs +++ b/crates/bevy_ecs/src/lib.rs @@ -375,8 +375,8 @@ mod tests { #[test] fn par_for_each_dense() { + ComputeTaskPool::init(TaskPool::default); let mut world = World::new(); - world.insert_resource(ComputeTaskPool(TaskPool::default())); let e1 = world.spawn().insert(A(1)).id(); let e2 = world.spawn().insert(A(2)).id(); let e3 = world.spawn().insert(A(3)).id(); @@ -397,8 +397,8 @@ mod tests { #[test] fn par_for_each_sparse() { + ComputeTaskPool::init(TaskPool::default); let mut world = World::new(); - world.insert_resource(ComputeTaskPool(TaskPool::default())); let e1 = world.spawn().insert(SparseStored(1)).id(); let e2 = world.spawn().insert(SparseStored(2)).id(); let e3 = world.spawn().insert(SparseStored(3)).id(); diff --git a/crates/bevy_ecs/src/query/state.rs b/crates/bevy_ecs/src/query/state.rs index 64a80313c6..a01fc4019d 100644 --- a/crates/bevy_ecs/src/query/state.rs +++ b/crates/bevy_ecs/src/query/state.rs @@ -10,18 +10,17 @@ use crate::{ storage::TableId, world::{World, WorldId}, }; -use bevy_tasks::{ComputeTaskPool, TaskPool}; +use bevy_tasks::ComputeTaskPool; #[cfg(feature = "trace")] use bevy_utils::tracing::Instrument; use fixedbitset::FixedBitSet; -use std::{borrow::Borrow, fmt, ops::Deref}; +use std::{borrow::Borrow, fmt}; use super::{QueryFetch, QueryItem, QueryManyIter, ROQueryFetch, ROQueryItem}; /// Provides scoped access to a [`World`] state according to a given [`WorldQuery`] and query filter. pub struct QueryState { world_id: WorldId, - task_pool: Option, pub(crate) archetype_generation: ArchetypeGeneration, pub(crate) matched_tables: FixedBitSet, pub(crate) matched_archetypes: FixedBitSet, @@ -62,9 +61,6 @@ impl QueryState { let mut state = Self { world_id: world.id(), - task_pool: world - .get_resource::() - .map(|task_pool| task_pool.deref().clone()), archetype_generation: ArchetypeGeneration::initial(), matched_table_ids: Vec::new(), matched_archetype_ids: Vec::new(), @@ -754,8 +750,8 @@ impl QueryState { /// write-queries. /// /// # Panics - /// The [`ComputeTaskPool`] resource must be added to the `World` before using this method. If using this from a query - /// that is being initialized and run from the ECS scheduler, this should never panic. + /// The [`ComputeTaskPool`] is not initialized. If using this from a query that is being + /// initialized and run from the ECS scheduler, this should never panic. #[inline] pub fn par_for_each<'w, FN: Fn(ROQueryItem<'w, Q>) + Send + Sync + Clone>( &mut self, @@ -779,8 +775,8 @@ impl QueryState { /// Runs `func` on each query result in parallel. /// /// # Panics - /// The [`ComputeTaskPool`] resource must be added to the `World` before using this method. If using this from a query - /// that is being initialized and run from the ECS scheduler, this should never panic. + /// The [`ComputeTaskPool`] is not initialized. If using this from a query that is being + /// initialized and run from the ECS scheduler, this should never panic. #[inline] pub fn par_for_each_mut<'w, FN: Fn(QueryItem<'w, Q>) + Send + Sync + Clone>( &mut self, @@ -806,8 +802,8 @@ impl QueryState { /// This can only be called for read-only queries. /// /// # Panics - /// [`ComputeTaskPool`] was not stored in the world at initialzation. If using this from a query - /// that is being initialized and run from the ECS scheduler, this should never panic. + /// The [`ComputeTaskPool`] is not initialized. If using this from a query that is being + /// initialized and run from the ECS scheduler, this should never panic. /// /// # Safety /// @@ -922,8 +918,8 @@ impl QueryState { /// iter() method, but cannot be chained like a normal [`Iterator`]. /// /// # Panics - /// [`ComputeTaskPool`] was not stored in the world at initialzation. If using this from a query - /// that is being initialized and run from the ECS scheduler, this should never panic. + /// The [`ComputeTaskPool`] is not initialized. If using this from a query that is being + /// initialized and run from the ECS scheduler, this should never panic. /// /// # Safety /// @@ -945,106 +941,95 @@ impl QueryState { ) { // NOTE: If you are changing query iteration code, remember to update the following places, where relevant: // QueryIter, QueryIterationCursor, QueryState::for_each_unchecked_manual, QueryState::many_for_each_unchecked_manual, QueryState::par_for_each_unchecked_manual - self.task_pool - .as_ref() - .expect("Cannot iterate query in parallel. No ComputeTaskPool initialized.") - .scope(|scope| { - if QF::IS_DENSE && >::IS_DENSE { - let tables = &world.storages().tables; - for table_id in &self.matched_table_ids { - let table = &tables[*table_id]; - let mut offset = 0; - while offset < table.len() { - let func = func.clone(); - let len = batch_size.min(table.len() - offset); - let task = async move { - let mut fetch = QF::init( - world, - &self.fetch_state, - last_change_tick, - change_tick, - ); - let mut filter = as Fetch>::init( - world, - &self.filter_state, - last_change_tick, - change_tick, - ); - let tables = &world.storages().tables; - let table = &tables[*table_id]; - fetch.set_table(&self.fetch_state, table); - filter.set_table(&self.filter_state, table); - for table_index in offset..offset + len { - if !filter.table_filter_fetch(table_index) { - continue; - } - let item = fetch.table_fetch(table_index); - func(item); - } - }; - #[cfg(feature = "trace")] - let span = bevy_utils::tracing::info_span!( - "par_for_each", - query = std::any::type_name::(), - filter = std::any::type_name::(), - count = len, + ComputeTaskPool::get().scope(|scope| { + if QF::IS_DENSE && >::IS_DENSE { + let tables = &world.storages().tables; + for table_id in &self.matched_table_ids { + let table = &tables[*table_id]; + let mut offset = 0; + while offset < table.len() { + let func = func.clone(); + let len = batch_size.min(table.len() - offset); + let task = async move { + let mut fetch = + QF::init(world, &self.fetch_state, last_change_tick, change_tick); + let mut filter = as Fetch>::init( + world, + &self.filter_state, + last_change_tick, + change_tick, ); - #[cfg(feature = "trace")] - let task = task.instrument(span); - scope.spawn(task); - offset += batch_size; - } - } - } else { - let archetypes = &world.archetypes; - for archetype_id in &self.matched_archetype_ids { - let mut offset = 0; - let archetype = &archetypes[*archetype_id]; - while offset < archetype.len() { - let func = func.clone(); - let len = batch_size.min(archetype.len() - offset); - let task = async move { - let mut fetch = QF::init( - world, - &self.fetch_state, - last_change_tick, - change_tick, - ); - let mut filter = as Fetch>::init( - world, - &self.filter_state, - last_change_tick, - change_tick, - ); - let tables = &world.storages().tables; - let archetype = &world.archetypes[*archetype_id]; - fetch.set_archetype(&self.fetch_state, archetype, tables); - filter.set_archetype(&self.filter_state, archetype, tables); - - for archetype_index in offset..offset + len { - if !filter.archetype_filter_fetch(archetype_index) { - continue; - } - func(fetch.archetype_fetch(archetype_index)); + let tables = &world.storages().tables; + let table = &tables[*table_id]; + fetch.set_table(&self.fetch_state, table); + filter.set_table(&self.filter_state, table); + for table_index in offset..offset + len { + if !filter.table_filter_fetch(table_index) { + continue; } - }; - - #[cfg(feature = "trace")] - let span = bevy_utils::tracing::info_span!( - "par_for_each", - query = std::any::type_name::(), - filter = std::any::type_name::(), - count = len, - ); - #[cfg(feature = "trace")] - let task = task.instrument(span); - - scope.spawn(task); - offset += batch_size; - } + let item = fetch.table_fetch(table_index); + func(item); + } + }; + #[cfg(feature = "trace")] + let span = bevy_utils::tracing::info_span!( + "par_for_each", + query = std::any::type_name::(), + filter = std::any::type_name::(), + count = len, + ); + #[cfg(feature = "trace")] + let task = task.instrument(span); + scope.spawn(task); + offset += batch_size; } } - }); + } else { + let archetypes = &world.archetypes; + for archetype_id in &self.matched_archetype_ids { + let mut offset = 0; + let archetype = &archetypes[*archetype_id]; + while offset < archetype.len() { + let func = func.clone(); + let len = batch_size.min(archetype.len() - offset); + let task = async move { + let mut fetch = + QF::init(world, &self.fetch_state, last_change_tick, change_tick); + let mut filter = as Fetch>::init( + world, + &self.filter_state, + last_change_tick, + change_tick, + ); + let tables = &world.storages().tables; + let archetype = &world.archetypes[*archetype_id]; + fetch.set_archetype(&self.fetch_state, archetype, tables); + filter.set_archetype(&self.filter_state, archetype, tables); + + for archetype_index in offset..offset + len { + if !filter.archetype_filter_fetch(archetype_index) { + continue; + } + func(fetch.archetype_fetch(archetype_index)); + } + }; + + #[cfg(feature = "trace")] + let span = bevy_utils::tracing::info_span!( + "par_for_each", + query = std::any::type_name::(), + filter = std::any::type_name::(), + count = len, + ); + #[cfg(feature = "trace")] + let task = task.instrument(span); + + scope.spawn(task); + offset += batch_size; + } + } + } + }); } /// Runs `func` on each query result for the given [`World`] and list of [`Entity`]'s, where the last change and diff --git a/crates/bevy_ecs/src/schedule/executor_parallel.rs b/crates/bevy_ecs/src/schedule/executor_parallel.rs index 149d8d02bc..c82924b0e2 100644 --- a/crates/bevy_ecs/src/schedule/executor_parallel.rs +++ b/crates/bevy_ecs/src/schedule/executor_parallel.rs @@ -123,10 +123,7 @@ impl ParallelSystemExecutor for ParallelExecutor { } } - let compute_pool = world - .get_resource_or_insert_with(|| ComputeTaskPool(TaskPool::default())) - .clone(); - compute_pool.scope(|scope| { + ComputeTaskPool::init(TaskPool::default).scope(|scope| { self.prepare_systems(scope, systems, world); let parallel_executor = async { // All systems have been ran if there are no queued or running systems. diff --git a/crates/bevy_ecs/src/system/query.rs b/crates/bevy_ecs/src/system/query.rs index 484d119e41..0bfe2711dc 100644 --- a/crates/bevy_ecs/src/system/query.rs +++ b/crates/bevy_ecs/src/system/query.rs @@ -587,8 +587,8 @@ impl<'w, 's, Q: WorldQuery, F: WorldQuery> Query<'w, 's, Q, F> { ///* `f` - The function to run on each item in the query /// /// # Panics - /// The [`ComputeTaskPool`] resource must be added to the `World` before using this method. If using this from a query - /// that is being initialized and run from the ECS scheduler, this should never panic. + /// The [`ComputeTaskPool`] is not initialized. If using this from a query that is being + /// initialized and run from the ECS scheduler, this should never panic. /// /// [`ComputeTaskPool`]: bevy_tasks::prelude::ComputeTaskPool #[inline] @@ -615,8 +615,8 @@ impl<'w, 's, Q: WorldQuery, F: WorldQuery> Query<'w, 's, Q, F> { /// See [`Self::par_for_each`] for more details. /// /// # Panics - /// [`ComputeTaskPool`] was not stored in the world at initialzation. If using this from a query - /// that is being initialized and run from the ECS scheduler, this should never panic. + /// The [`ComputeTaskPool`] is not initialized. If using this from a query that is being + /// initialized and run from the ECS scheduler, this should never panic. /// /// [`ComputeTaskPool`]: bevy_tasks::prelude::ComputeTaskPool #[inline] diff --git a/crates/bevy_gltf/Cargo.toml b/crates/bevy_gltf/Cargo.toml index c35ac483c0..5d9ab53ad2 100644 --- a/crates/bevy_gltf/Cargo.toml +++ b/crates/bevy_gltf/Cargo.toml @@ -24,6 +24,7 @@ bevy_reflect = { path = "../bevy_reflect", version = "0.8.0-dev", features = ["b bevy_render = { path = "../bevy_render", version = "0.8.0-dev" } bevy_scene = { path = "../bevy_scene", version = "0.8.0-dev" } bevy_transform = { path = "../bevy_transform", version = "0.8.0-dev" } +bevy_tasks = { path = "../bevy_tasks", version = "0.8.0-dev" } bevy_utils = { path = "../bevy_utils", version = "0.8.0-dev" } # other diff --git a/crates/bevy_gltf/src/loader.rs b/crates/bevy_gltf/src/loader.rs index c699a58945..a5976530e7 100644 --- a/crates/bevy_gltf/src/loader.rs +++ b/crates/bevy_gltf/src/loader.rs @@ -29,6 +29,7 @@ use bevy_render::{ view::VisibleEntities, }; use bevy_scene::Scene; +use bevy_tasks::IoTaskPool; use bevy_transform::{components::Transform, TransformBundle}; use bevy_utils::{HashMap, HashSet}; @@ -410,8 +411,7 @@ async fn load_gltf<'a, 'b>( } } else { #[cfg(not(target_arch = "wasm32"))] - load_context - .task_pool() + IoTaskPool::get() .scope(|scope| { gltf.textures().for_each(|gltf_texture| { let linear_textures = &linear_textures; diff --git a/crates/bevy_tasks/Cargo.toml b/crates/bevy_tasks/Cargo.toml index 7b83b9bc34..06a0da4569 100644 --- a/crates/bevy_tasks/Cargo.toml +++ b/crates/bevy_tasks/Cargo.toml @@ -13,7 +13,8 @@ futures-lite = "1.4.0" event-listener = "2.5.2" async-executor = "1.3.0" async-channel = "1.4.2" -num_cpus = "1.0.1" +num_cpus = "1" +once_cell = "1.7" [target.'cfg(target_arch = "wasm32")'.dependencies] wasm-bindgen-futures = "0.4" diff --git a/crates/bevy_tasks/src/task_pool.rs b/crates/bevy_tasks/src/task_pool.rs index ebd6ba6b41..1d0f86e7cb 100644 --- a/crates/bevy_tasks/src/task_pool.rs +++ b/crates/bevy_tasks/src/task_pool.rs @@ -60,29 +60,9 @@ impl TaskPoolBuilder { } } -#[derive(Debug)] -struct TaskPoolInner { - threads: Vec>, - shutdown_tx: async_channel::Sender<()>, -} - -impl Drop for TaskPoolInner { - fn drop(&mut self) { - self.shutdown_tx.close(); - - let panicking = thread::panicking(); - for join_handle in self.threads.drain(..) { - let res = join_handle.join(); - if !panicking { - res.expect("Task thread panicked while executing."); - } - } - } -} - /// A thread pool for executing tasks. Tasks are futures that are being automatically driven by /// the pool on threads owned by the pool. -#[derive(Debug, Clone)] +#[derive(Debug)] pub struct TaskPool { /// The executor for the pool /// @@ -92,7 +72,8 @@ pub struct TaskPool { executor: Arc>, /// Inner state of the pool - inner: Arc, + threads: Vec>, + shutdown_tx: async_channel::Sender<()>, } impl TaskPool { @@ -155,16 +136,14 @@ impl TaskPool { Self { executor, - inner: Arc::new(TaskPoolInner { - threads, - shutdown_tx, - }), + threads, + shutdown_tx, } } /// Return the number of threads owned by the task pool pub fn thread_num(&self) -> usize { - self.inner.threads.len() + self.threads.len() } /// Allows spawning non-`'static` futures on the thread pool. The function takes a callback, @@ -268,6 +247,20 @@ impl Default for TaskPool { } } +impl Drop for TaskPool { + fn drop(&mut self) { + self.shutdown_tx.close(); + + let panicking = thread::panicking(); + for join_handle in self.threads.drain(..) { + let res = join_handle.join(); + if !panicking { + res.expect("Task thread panicked while executing."); + } + } + } +} + /// A `TaskPool` scope for running one or more non-`'static` futures. /// /// For more information, see [`TaskPool::scope`]. diff --git a/crates/bevy_tasks/src/usages.rs b/crates/bevy_tasks/src/usages.rs index 923c1a7eb4..419d842f47 100644 --- a/crates/bevy_tasks/src/usages.rs +++ b/crates/bevy_tasks/src/usages.rs @@ -11,12 +11,35 @@ //! for consumption. (likely via channels) use super::TaskPool; +use once_cell::sync::OnceCell; use std::ops::Deref; +static COMPUTE_TASK_POOL: OnceCell = OnceCell::new(); +static ASYNC_COMPUTE_TASK_POOL: OnceCell = OnceCell::new(); +static IO_TASK_POOL: OnceCell = OnceCell::new(); + /// A newtype for a task pool for CPU-intensive work that must be completed to deliver the next /// frame -#[derive(Clone, Debug)] -pub struct ComputeTaskPool(pub TaskPool); +#[derive(Debug)] +pub struct ComputeTaskPool(TaskPool); + +impl ComputeTaskPool { + /// Initializes the global [`ComputeTaskPool`] instance. + pub fn init(f: impl FnOnce() -> TaskPool) -> &'static Self { + COMPUTE_TASK_POOL.get_or_init(|| Self(f())) + } + + /// Gets the global [`ComputeTaskPool`] instance. + /// + /// # Panics + /// Panics if no pool has been initialized yet. + pub fn get() -> &'static Self { + COMPUTE_TASK_POOL.get().expect( + "A ComputeTaskPool has not been initialized yet. Please call \ + ComputeTaskPool::init beforehand.", + ) + } +} impl Deref for ComputeTaskPool { type Target = TaskPool; @@ -27,8 +50,26 @@ impl Deref for ComputeTaskPool { } /// A newtype for a task pool for CPU-intensive work that may span across multiple frames -#[derive(Clone, Debug)] -pub struct AsyncComputeTaskPool(pub TaskPool); +#[derive(Debug)] +pub struct AsyncComputeTaskPool(TaskPool); + +impl AsyncComputeTaskPool { + /// Initializes the global [`AsyncComputeTaskPool`] instance. + pub fn init(f: impl FnOnce() -> TaskPool) -> &'static Self { + ASYNC_COMPUTE_TASK_POOL.get_or_init(|| Self(f())) + } + + /// Gets the global [`AsyncComputeTaskPool`] instance. + /// + /// # Panics + /// Panics if no pool has been initialized yet. + pub fn get() -> &'static Self { + ASYNC_COMPUTE_TASK_POOL.get().expect( + "A AsyncComputeTaskPool has not been initialized yet. Please call \ + AsyncComputeTaskPool::init beforehand.", + ) + } +} impl Deref for AsyncComputeTaskPool { type Target = TaskPool; @@ -40,8 +81,26 @@ impl Deref for AsyncComputeTaskPool { /// A newtype for a task pool for IO-intensive work (i.e. tasks that spend very little time in a /// "woken" state) -#[derive(Clone, Debug)] -pub struct IoTaskPool(pub TaskPool); +#[derive(Debug)] +pub struct IoTaskPool(TaskPool); + +impl IoTaskPool { + /// Initializes the global [`IoTaskPool`] instance. + pub fn init(f: impl FnOnce() -> TaskPool) -> &'static Self { + IO_TASK_POOL.get_or_init(|| Self(f())) + } + + /// Gets the global [`IoTaskPool`] instance. + /// + /// # Panics + /// Panics if no pool has been initialized yet. + pub fn get() -> &'static Self { + IO_TASK_POOL.get().expect( + "A IoTaskPool has not been initialized yet. Please call \ + IoTaskPool::init beforehand.", + ) + } +} impl Deref for IoTaskPool { type Target = TaskPool; diff --git a/examples/asset/custom_asset_io.rs b/examples/asset/custom_asset_io.rs index f601ebad24..7a58750a1c 100644 --- a/examples/asset/custom_asset_io.rs +++ b/examples/asset/custom_asset_io.rs @@ -51,10 +51,6 @@ struct CustomAssetIoPlugin; impl Plugin for CustomAssetIoPlugin { fn build(&self, app: &mut App) { - // must get a hold of the task pool in order to create the asset server - - let task_pool = app.world.resource::().0.clone(); - let asset_io = { // the platform default asset io requires a reference to the app // builder to find its configuration @@ -68,7 +64,7 @@ impl Plugin for CustomAssetIoPlugin { // the asset server is constructed and added the resource manager - app.insert_resource(AssetServer::new(asset_io, task_pool)); + app.insert_resource(AssetServer::new(asset_io)); } } diff --git a/examples/async_tasks/async_compute.rs b/examples/async_tasks/async_compute.rs index e50e3ad982..e01ebd4c9b 100644 --- a/examples/async_tasks/async_compute.rs +++ b/examples/async_tasks/async_compute.rs @@ -50,7 +50,8 @@ struct ComputeTransform(Task); /// work that potentially spans multiple frames/ticks. A separate /// system, `handle_tasks`, will poll the spawned tasks on subsequent /// frames/ticks, and use the results to spawn cubes -fn spawn_tasks(mut commands: Commands, thread_pool: Res) { +fn spawn_tasks(mut commands: Commands) { + let thread_pool = AsyncComputeTaskPool::get(); for x in 0..NUM_CUBES { for y in 0..NUM_CUBES { for z in 0..NUM_CUBES {