Reworked parallel executor to not block (#437)
Reworked parallel executor to not block
This commit is contained in:
parent
8677e36681
commit
8b3553002d
@ -52,7 +52,7 @@ impl Default for DefaultTaskPoolOptions {
|
||||
fn default() -> Self {
|
||||
DefaultTaskPoolOptions {
|
||||
// By default, use however many cores are available on the system
|
||||
min_total_threads: 4, // TODO(#408): set `min_total_threads` back to `1`
|
||||
min_total_threads: 1,
|
||||
max_total_threads: std::usize::MAX,
|
||||
|
||||
// Use 25% of cores for IO, at least 1, no more than 4
|
||||
|
@ -18,7 +18,7 @@ bevy_hecs = { path = "hecs", features = ["macros", "serialize"], version = "0.1"
|
||||
bevy_tasks = { path = "../bevy_tasks", version = "0.1" }
|
||||
bevy_utils = { path = "../bevy_utils", version = "0.1" }
|
||||
rand = "0.7.2"
|
||||
crossbeam-channel = "0.4.2"
|
||||
fixedbitset = "0.3.0"
|
||||
downcast-rs = "1.1.1"
|
||||
parking_lot = "0.10"
|
||||
log = { version = "0.4", features = ["release_max_level_info"] }
|
||||
|
@ -4,7 +4,7 @@ use crate::{
|
||||
system::{ArchetypeAccess, System, ThreadLocalExecution, TypeAccess},
|
||||
};
|
||||
use bevy_hecs::{ArchetypesGeneration, World};
|
||||
use crossbeam_channel::{Receiver, Sender};
|
||||
use bevy_tasks::{ComputeTaskPool, CountdownEvent, TaskPool};
|
||||
use fixedbitset::FixedBitSet;
|
||||
use parking_lot::Mutex;
|
||||
use std::{ops::Range, sync::Arc};
|
||||
@ -52,6 +52,7 @@ impl ParallelExecutor {
|
||||
}
|
||||
for (stage_name, executor_stage) in schedule.stage_order.iter().zip(self.stages.iter_mut())
|
||||
{
|
||||
log::trace!("run stage {:?}", stage_name);
|
||||
if let Some(stage_systems) = schedule.stages.get_mut(stage_name) {
|
||||
executor_stage.run(world, resources, stage_systems, schedule_changed);
|
||||
}
|
||||
@ -69,68 +70,64 @@ impl ParallelExecutor {
|
||||
pub struct ExecutorStage {
|
||||
/// each system's set of dependencies
|
||||
system_dependencies: Vec<FixedBitSet>,
|
||||
/// count of each system's dependencies
|
||||
system_dependency_count: Vec<usize>,
|
||||
/// Countdown of finished dependencies, used to trigger the next system
|
||||
ready_events: Vec<Option<CountdownEvent>>,
|
||||
/// When a system finishes, it will decrement the countdown events of all dependents
|
||||
ready_events_of_dependents: Vec<Vec<CountdownEvent>>,
|
||||
/// each system's dependents (the systems that can't run until this system has run)
|
||||
system_dependents: Vec<Vec<usize>>,
|
||||
/// stores the indices of thread local systems in this stage, which are used during stage.prepare()
|
||||
thread_local_system_indices: Vec<usize>,
|
||||
next_thread_local_index: usize,
|
||||
/// the currently finished systems
|
||||
finished_systems: FixedBitSet,
|
||||
running_systems: FixedBitSet,
|
||||
|
||||
sender: Sender<usize>,
|
||||
receiver: Receiver<usize>,
|
||||
/// When archetypes change a counter is bumped - we cache the state of that counter when it was
|
||||
/// last read here so that we can detect when archetypes are changed
|
||||
last_archetypes_generation: ArchetypesGeneration,
|
||||
}
|
||||
|
||||
impl Default for ExecutorStage {
|
||||
fn default() -> Self {
|
||||
let (sender, receiver) = crossbeam_channel::unbounded();
|
||||
Self {
|
||||
system_dependents: Default::default(),
|
||||
system_dependency_count: Default::default(),
|
||||
ready_events: Default::default(),
|
||||
ready_events_of_dependents: Default::default(),
|
||||
system_dependencies: Default::default(),
|
||||
thread_local_system_indices: Default::default(),
|
||||
next_thread_local_index: 0,
|
||||
finished_systems: Default::default(),
|
||||
running_systems: Default::default(),
|
||||
sender,
|
||||
receiver,
|
||||
last_archetypes_generation: ArchetypesGeneration(u64::MAX), // MAX forces prepare to run the first time
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
enum RunReadyResult {
|
||||
Ok,
|
||||
ThreadLocalReady(usize),
|
||||
}
|
||||
|
||||
enum RunReadyType {
|
||||
Range(Range<usize>),
|
||||
Dependents(usize),
|
||||
}
|
||||
|
||||
impl ExecutorStage {
|
||||
/// Sets up state to run the next "batch" of systems. Each batch contains 0..n systems and
|
||||
/// optionally a thread local system at the end. After this function runs, a bunch of state
|
||||
/// in self will be populated for systems in this batch. Returns the range of systems
|
||||
/// that we prepared, up to but NOT including the thread local system that MIGHT be at the end
|
||||
/// of the range
|
||||
pub fn prepare_to_next_thread_local(
|
||||
&mut self,
|
||||
world: &World,
|
||||
systems: &[Arc<Mutex<Box<dyn System>>>],
|
||||
schedule_changed: bool,
|
||||
) {
|
||||
let (prepare_system_start_index, last_thread_local_index) =
|
||||
if self.next_thread_local_index == 0 {
|
||||
(0, None)
|
||||
} else {
|
||||
// start right after the last thread local system
|
||||
(
|
||||
self.thread_local_system_indices[self.next_thread_local_index - 1] + 1,
|
||||
Some(self.thread_local_system_indices[self.next_thread_local_index - 1]),
|
||||
)
|
||||
};
|
||||
next_thread_local_index: usize,
|
||||
) -> Range<usize> {
|
||||
// Find the first system in this batch and (if there is one) the thread local system that
|
||||
// ends it.
|
||||
let (prepare_system_start_index, last_thread_local_index) = if next_thread_local_index == 0
|
||||
{
|
||||
(0, None)
|
||||
} else {
|
||||
// start right after the last thread local system
|
||||
(
|
||||
self.thread_local_system_indices[next_thread_local_index - 1] + 1,
|
||||
Some(self.thread_local_system_indices[next_thread_local_index - 1]),
|
||||
)
|
||||
};
|
||||
|
||||
let prepare_system_index_range = if let Some(index) = self
|
||||
.thread_local_system_indices
|
||||
.get(self.next_thread_local_index)
|
||||
.get(next_thread_local_index)
|
||||
{
|
||||
// if there is an upcoming thread local system, prepare up to (and including) it
|
||||
prepare_system_start_index..(*index + 1)
|
||||
@ -206,73 +203,136 @@ impl ExecutorStage {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Now that system_dependents and system_dependencies is populated, update
|
||||
// system_dependency_count and ready_events
|
||||
for system_index in prepare_system_index_range.clone() {
|
||||
// Count all dependencies to update system_dependency_count
|
||||
assert!(!self.system_dependencies[system_index].contains(system_index));
|
||||
let dependency_count = self.system_dependencies[system_index].count_ones(..);
|
||||
self.system_dependency_count[system_index] = dependency_count;
|
||||
|
||||
// If dependency count > 0, allocate a ready_event
|
||||
self.ready_events[system_index] = match self.system_dependency_count[system_index] {
|
||||
0 => None,
|
||||
dependency_count => Some(CountdownEvent::new(dependency_count as isize)),
|
||||
}
|
||||
}
|
||||
|
||||
// Now that ready_events are created, we can build ready_events_of_dependents
|
||||
for system_index in prepare_system_index_range.clone() {
|
||||
for dependent_system in &self.system_dependents[system_index] {
|
||||
self.ready_events_of_dependents[system_index].push(
|
||||
self.ready_events[*dependent_system]
|
||||
.as_ref()
|
||||
.expect("A dependent task should have a non-None ready event")
|
||||
.clone(),
|
||||
);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Reset the countdown events for this range of systems. Resetting is required even if the
|
||||
// schedule didn't change
|
||||
self.reset_system_ready_events(prepare_system_index_range);
|
||||
}
|
||||
|
||||
self.next_thread_local_index += 1;
|
||||
if let Some(index) = self
|
||||
.thread_local_system_indices
|
||||
.get(next_thread_local_index)
|
||||
{
|
||||
// if there is an upcoming thread local system, prepare up to (and NOT including) it
|
||||
prepare_system_start_index..(*index)
|
||||
} else {
|
||||
// if there are no upcoming thread local systems, prepare everything right now
|
||||
prepare_system_start_index..systems.len()
|
||||
}
|
||||
}
|
||||
|
||||
fn run_ready_systems<'run>(
|
||||
&mut self,
|
||||
fn reset_system_ready_events(&mut self, prepare_system_index_range: Range<usize>) {
|
||||
for system_index in prepare_system_index_range {
|
||||
let dependency_count = self.system_dependency_count[system_index];
|
||||
if dependency_count > 0 {
|
||||
self.ready_events[system_index]
|
||||
.as_ref()
|
||||
.expect("A system with >0 dependency count should have a non-None ready event")
|
||||
.reset(dependency_count as isize)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Runs the non-thread-local systems in the given prepared_system_range range
|
||||
pub fn run_systems(
|
||||
&self,
|
||||
world: &World,
|
||||
resources: &Resources,
|
||||
systems: &[Arc<Mutex<Box<dyn System>>>],
|
||||
run_ready_type: RunReadyType,
|
||||
scope: &mut bevy_tasks::Scope<'run, ()>,
|
||||
world: &'run World,
|
||||
resources: &'run Resources,
|
||||
) -> RunReadyResult {
|
||||
// produce a system index iterator based on the passed in RunReadyType
|
||||
let mut all;
|
||||
let mut dependents;
|
||||
let system_index_iter: &mut dyn Iterator<Item = usize> = match run_ready_type {
|
||||
RunReadyType::Range(range) => {
|
||||
all = range;
|
||||
&mut all
|
||||
}
|
||||
RunReadyType::Dependents(system_index) => {
|
||||
dependents = self.system_dependents[system_index].iter().cloned();
|
||||
&mut dependents
|
||||
}
|
||||
};
|
||||
|
||||
let mut systems_currently_running = false;
|
||||
for system_index in system_index_iter {
|
||||
// if this system has already been run, don't try to run it again
|
||||
if self.running_systems.contains(system_index) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// if all system dependencies are finished, queue up the system to run
|
||||
if self.system_dependencies[system_index].is_subset(&self.finished_systems) {
|
||||
prepared_system_range: Range<usize>,
|
||||
compute_pool: &TaskPool,
|
||||
) {
|
||||
// Generate tasks for systems in the given range and block until they are complete
|
||||
log::trace!("running systems {:?}", prepared_system_range);
|
||||
compute_pool.scope(|scope| {
|
||||
let start_system_index = prepared_system_range.start;
|
||||
for system_index in prepared_system_range {
|
||||
let system = systems[system_index].clone();
|
||||
|
||||
// handle thread local system
|
||||
{
|
||||
let system = system.lock();
|
||||
if let ThreadLocalExecution::Immediate = system.thread_local_execution() {
|
||||
if systems_currently_running {
|
||||
// if systems are currently running, we can't run this thread local system yet
|
||||
continue;
|
||||
} else {
|
||||
// if no systems are running, return this thread local system to be run exclusively
|
||||
return RunReadyResult::ThreadLocalReady(system_index);
|
||||
log::trace!(
|
||||
"prepare {} {} with {} dependents and {} dependencies",
|
||||
system_index,
|
||||
system.lock().name(),
|
||||
self.system_dependents[system_index].len(),
|
||||
self.system_dependencies[system_index].count_ones(..)
|
||||
);
|
||||
|
||||
for dependency in self.system_dependencies[system_index].ones() {
|
||||
log::trace!(" * Depends on {}", systems[dependency].lock().name());
|
||||
}
|
||||
|
||||
// This event will be awaited, preventing the task from starting until all
|
||||
// our dependencies finish running
|
||||
let ready_event = &self.ready_events[system_index];
|
||||
|
||||
// Clear any dependencies on systems before this range of systems. We know at this
|
||||
// point everything before start_system_index is finished, and our ready_event did
|
||||
// not exist to be decremented until we started processing this range
|
||||
if start_system_index != 0 {
|
||||
if let Some(ready_event) = ready_event.as_ref() {
|
||||
for dependency in self.system_dependencies[system_index].ones() {
|
||||
log::trace!(" * Depends on {}", dependency);
|
||||
if dependency < start_system_index {
|
||||
ready_event.decrement();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// handle multi-threaded system
|
||||
let sender = self.sender.clone();
|
||||
self.running_systems.insert(system_index);
|
||||
let world_ref = &*world;
|
||||
let resources_ref = &*resources;
|
||||
|
||||
let trigger_events = &self.ready_events_of_dependents[system_index];
|
||||
|
||||
// Spawn the task
|
||||
scope.spawn(async move {
|
||||
let mut system = system.lock();
|
||||
system.run(world, resources);
|
||||
sender.send(system_index).unwrap();
|
||||
// Wait until our dependencies are done
|
||||
if let Some(ready_event) = ready_event {
|
||||
ready_event.listen().await;
|
||||
}
|
||||
|
||||
// Execute the system - in a scope to ensure the system lock is dropped before
|
||||
// triggering dependents
|
||||
{
|
||||
let mut system = system.lock();
|
||||
log::trace!("run {}", system.name());
|
||||
system.run(world_ref, resources_ref);
|
||||
}
|
||||
|
||||
// Notify dependents that this task is done
|
||||
for trigger_event in trigger_events {
|
||||
trigger_event.decrement();
|
||||
}
|
||||
});
|
||||
|
||||
systems_currently_running = true;
|
||||
}
|
||||
}
|
||||
|
||||
RunReadyResult::Ok
|
||||
});
|
||||
}
|
||||
|
||||
pub fn run(
|
||||
@ -283,22 +343,27 @@ impl ExecutorStage {
|
||||
schedule_changed: bool,
|
||||
) {
|
||||
let start_archetypes_generation = world.archetypes_generation();
|
||||
let compute_pool = resources
|
||||
.get_cloned::<bevy_tasks::ComputeTaskPool>()
|
||||
.unwrap();
|
||||
let compute_pool = resources.get_cloned::<ComputeTaskPool>().unwrap();
|
||||
|
||||
// if the schedule has changed, clear executor state / fill it with new defaults
|
||||
// This is mostly zeroing out a bunch of arrays parallel to the systems array. They will get
|
||||
// repopulated by prepare_to_next_thread_local() calls
|
||||
if schedule_changed {
|
||||
self.system_dependencies.clear();
|
||||
self.system_dependencies
|
||||
.resize_with(systems.len(), || FixedBitSet::with_capacity(systems.len()));
|
||||
|
||||
self.system_dependency_count.clear();
|
||||
self.system_dependency_count.resize(systems.len(), 0);
|
||||
|
||||
self.thread_local_system_indices = Vec::new();
|
||||
|
||||
self.system_dependents.clear();
|
||||
self.system_dependents.resize(systems.len(), Vec::new());
|
||||
|
||||
self.finished_systems.grow(systems.len());
|
||||
self.running_systems.grow(systems.len());
|
||||
self.ready_events.resize(systems.len(), None);
|
||||
self.ready_events_of_dependents
|
||||
.resize(systems.len(), Vec::new());
|
||||
|
||||
for (system_index, system) in systems.iter().enumerate() {
|
||||
let system = system.lock();
|
||||
@ -308,76 +373,67 @@ impl ExecutorStage {
|
||||
}
|
||||
}
|
||||
|
||||
self.next_thread_local_index = 0;
|
||||
self.prepare_to_next_thread_local(world, systems, schedule_changed);
|
||||
// index of the next thread local system in thread_local_system_indices (always incremented by one
|
||||
// when prepare_to_next_thread_local is called; we prepared up to index 0 above)
|
||||
let mut next_thread_local_index = 0;
|
||||
|
||||
self.finished_systems.clear();
|
||||
self.running_systems.clear();
|
||||
|
||||
let mut run_ready_result = RunReadyResult::Ok;
|
||||
let run_ready_system_index_range =
|
||||
if let Some(index) = self.thread_local_system_indices.get(0) {
|
||||
// if there is an upcoming thread local system, run up to (and including) it
|
||||
0..(*index + 1)
|
||||
} else {
|
||||
// if there are no upcoming thread local systems, run everything right now
|
||||
0..systems.len()
|
||||
};
|
||||
|
||||
compute_pool.scope(|scope| {
|
||||
run_ready_result = self.run_ready_systems(
|
||||
{
|
||||
// Prepare all systems up to and including the first thread local system. This will return
|
||||
// the range of systems to run, up to but NOT including the next thread local
|
||||
let prepared_system_range = self.prepare_to_next_thread_local(
|
||||
world,
|
||||
systems,
|
||||
RunReadyType::Range(run_ready_system_index_range),
|
||||
scope,
|
||||
schedule_changed,
|
||||
next_thread_local_index,
|
||||
);
|
||||
|
||||
// Run everything up to the thread local system
|
||||
self.run_systems(
|
||||
world,
|
||||
resources,
|
||||
systems,
|
||||
prepared_system_range,
|
||||
&*compute_pool,
|
||||
);
|
||||
});
|
||||
}
|
||||
|
||||
loop {
|
||||
// if all systems in the stage are finished, break out of the loop
|
||||
if self.finished_systems.count_ones(..) == systems.len() {
|
||||
// Bail if we have no more thread local systems
|
||||
if next_thread_local_index >= self.thread_local_system_indices.len() {
|
||||
break;
|
||||
}
|
||||
|
||||
if let RunReadyResult::ThreadLocalReady(thread_local_index) = run_ready_result {
|
||||
// Run the thread local system at the end of the range of systems we just processed
|
||||
let thread_local_system_index =
|
||||
self.thread_local_system_indices[next_thread_local_index];
|
||||
{
|
||||
// if a thread local system is ready to run, run it exclusively on the main thread
|
||||
let mut system = systems[thread_local_index].lock();
|
||||
self.running_systems.insert(thread_local_index);
|
||||
let mut system = systems[thread_local_system_index].lock();
|
||||
log::trace!("running thread local system {}", system.name());
|
||||
system.run(world, resources);
|
||||
system.run_thread_local(world, resources);
|
||||
self.finished_systems.insert(thread_local_index);
|
||||
self.sender.send(thread_local_index).unwrap();
|
||||
|
||||
self.prepare_to_next_thread_local(world, systems, schedule_changed);
|
||||
|
||||
run_ready_result = RunReadyResult::Ok;
|
||||
} else {
|
||||
// wait for a system to finish, then run its dependents
|
||||
compute_pool.scope(|scope| {
|
||||
loop {
|
||||
// if all systems in the stage are finished, break out of the loop
|
||||
if self.finished_systems.count_ones(..) == systems.len() {
|
||||
break;
|
||||
}
|
||||
|
||||
let finished_system = self.receiver.recv().unwrap();
|
||||
self.finished_systems.insert(finished_system);
|
||||
run_ready_result = self.run_ready_systems(
|
||||
systems,
|
||||
RunReadyType::Dependents(finished_system),
|
||||
scope,
|
||||
world,
|
||||
resources,
|
||||
);
|
||||
|
||||
// if the next ready system is thread local, break out of this loop/bevy_tasks scope so it can be run
|
||||
if let RunReadyResult::ThreadLocalReady(_) = run_ready_result {
|
||||
break;
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
// Now that the previous thread local system has run, time to advance to the next one
|
||||
next_thread_local_index += 1;
|
||||
|
||||
// Prepare all systems up to and including the next thread local system. This will
|
||||
// return the range of systems to run, up to but NOT including the next thread local
|
||||
let run_ready_system_index_range = self.prepare_to_next_thread_local(
|
||||
world,
|
||||
systems,
|
||||
schedule_changed,
|
||||
next_thread_local_index,
|
||||
);
|
||||
|
||||
log::trace!("running systems {:?}", run_ready_system_index_range);
|
||||
self.run_systems(
|
||||
world,
|
||||
resources,
|
||||
systems,
|
||||
run_ready_system_index_range,
|
||||
&*compute_pool,
|
||||
);
|
||||
}
|
||||
|
||||
// "flush"
|
||||
@ -410,11 +466,11 @@ mod tests {
|
||||
use bevy_tasks::{ComputeTaskPool, TaskPool};
|
||||
use fixedbitset::FixedBitSet;
|
||||
use parking_lot::Mutex;
|
||||
use std::sync::Arc;
|
||||
use std::{collections::HashSet, sync::Arc};
|
||||
|
||||
#[derive(Default)]
|
||||
struct Counter {
|
||||
count: Arc<Mutex<usize>>,
|
||||
struct CompletedSystems {
|
||||
completed_systems: Arc<Mutex<HashSet<&'static str>>>,
|
||||
}
|
||||
|
||||
#[test]
|
||||
@ -483,7 +539,7 @@ mod tests {
|
||||
let mut world = World::new();
|
||||
let mut resources = Resources::default();
|
||||
resources.insert(ComputeTaskPool(TaskPool::default()));
|
||||
resources.insert(Counter::default());
|
||||
resources.insert(CompletedSystems::default());
|
||||
resources.insert(1.0f64);
|
||||
resources.insert(2isize);
|
||||
|
||||
@ -496,30 +552,51 @@ mod tests {
|
||||
schedule.add_stage("B"); // thread local
|
||||
schedule.add_stage("C"); // resources
|
||||
|
||||
// A system names
|
||||
const READ_U32_SYSTEM_NAME: &str = "read_u32";
|
||||
const WRITE_FLOAT_SYSTEM_NAME: &str = "write_float";
|
||||
const READ_U32_WRITE_U64_SYSTEM_NAME: &str = "read_u32_write_u64";
|
||||
const READ_U64_SYSTEM_NAME: &str = "read_u64";
|
||||
|
||||
// B system names
|
||||
const WRITE_U64_SYSTEM_NAME: &str = "write_u64";
|
||||
const THREAD_LOCAL_SYSTEM_SYSTEM_NAME: &str = "thread_local_system";
|
||||
const WRITE_F32_SYSTEM_NAME: &str = "write_f32";
|
||||
|
||||
// C system names
|
||||
const READ_F64_RES_SYSTEM_NAME: &str = "read_f64_res";
|
||||
const READ_ISIZE_RES_SYSTEM_NAME: &str = "read_isize_res";
|
||||
const READ_ISIZE_WRITE_F64_RES_SYSTEM_NAME: &str = "read_isize_write_f64_res";
|
||||
const WRITE_F64_RES_SYSTEM_NAME: &str = "write_f64_res";
|
||||
|
||||
// A systems
|
||||
|
||||
fn read_u32(counter: Res<Counter>, _query: Query<&u32>) {
|
||||
let mut count = counter.count.lock();
|
||||
assert!(*count < 2, "should be one of the first two systems to run");
|
||||
*count += 1;
|
||||
fn read_u32(completed_systems: Res<CompletedSystems>, _query: Query<&u32>) {
|
||||
let mut completed_systems = completed_systems.completed_systems.lock();
|
||||
assert!(!completed_systems.contains(READ_U32_WRITE_U64_SYSTEM_NAME));
|
||||
completed_systems.insert(READ_U32_SYSTEM_NAME);
|
||||
}
|
||||
|
||||
fn write_float(counter: Res<Counter>, _query: Query<&f32>) {
|
||||
let mut count = counter.count.lock();
|
||||
assert!(*count < 2, "should be one of the first two systems to run");
|
||||
*count += 1;
|
||||
fn write_float(completed_systems: Res<CompletedSystems>, _query: Query<&f32>) {
|
||||
let mut completed_systems = completed_systems.completed_systems.lock();
|
||||
completed_systems.insert(WRITE_FLOAT_SYSTEM_NAME);
|
||||
}
|
||||
|
||||
fn read_u32_write_u64(counter: Res<Counter>, _query: Query<(&u32, &mut u64)>) {
|
||||
let mut count = counter.count.lock();
|
||||
assert_eq!(*count, 2, "should always be the 3rd system to run");
|
||||
*count += 1;
|
||||
fn read_u32_write_u64(
|
||||
completed_systems: Res<CompletedSystems>,
|
||||
_query: Query<(&u32, &mut u64)>,
|
||||
) {
|
||||
let mut completed_systems = completed_systems.completed_systems.lock();
|
||||
assert!(completed_systems.contains(READ_U32_SYSTEM_NAME));
|
||||
assert!(!completed_systems.contains(READ_U64_SYSTEM_NAME));
|
||||
completed_systems.insert(READ_U32_WRITE_U64_SYSTEM_NAME);
|
||||
}
|
||||
|
||||
fn read_u64(counter: Res<Counter>, _query: Query<&u64>) {
|
||||
let mut count = counter.count.lock();
|
||||
assert_eq!(*count, 3, "should always be the 4th system to run");
|
||||
*count += 1;
|
||||
fn read_u64(completed_systems: Res<CompletedSystems>, _query: Query<&u64>) {
|
||||
let mut completed_systems = completed_systems.completed_systems.lock();
|
||||
assert!(completed_systems.contains(READ_U32_WRITE_U64_SYSTEM_NAME));
|
||||
assert!(!completed_systems.contains(WRITE_U64_SYSTEM_NAME));
|
||||
completed_systems.insert(READ_U64_SYSTEM_NAME);
|
||||
}
|
||||
|
||||
schedule.add_system_to_stage("A", read_u32.system());
|
||||
@ -529,23 +606,28 @@ mod tests {
|
||||
|
||||
// B systems
|
||||
|
||||
fn write_u64(counter: Res<Counter>, _query: Query<&mut u64>) {
|
||||
let mut count = counter.count.lock();
|
||||
assert_eq!(*count, 4, "should always be the 5th system to run");
|
||||
*count += 1;
|
||||
fn write_u64(completed_systems: Res<CompletedSystems>, _query: Query<&mut u64>) {
|
||||
let mut completed_systems = completed_systems.completed_systems.lock();
|
||||
assert!(completed_systems.contains(READ_U64_SYSTEM_NAME));
|
||||
assert!(!completed_systems.contains(THREAD_LOCAL_SYSTEM_SYSTEM_NAME));
|
||||
assert!(!completed_systems.contains(WRITE_F32_SYSTEM_NAME));
|
||||
completed_systems.insert(WRITE_U64_SYSTEM_NAME);
|
||||
}
|
||||
|
||||
fn thread_local_system(_world: &mut World, resources: &mut Resources) {
|
||||
let counter = resources.get::<Counter>().unwrap();
|
||||
let mut count = counter.count.lock();
|
||||
assert_eq!(*count, 5, "should always be the 6th system to run");
|
||||
*count += 1;
|
||||
let completed_systems = resources.get::<CompletedSystems>().unwrap();
|
||||
let mut completed_systems = completed_systems.completed_systems.lock();
|
||||
assert!(completed_systems.contains(WRITE_U64_SYSTEM_NAME));
|
||||
assert!(!completed_systems.contains(WRITE_F32_SYSTEM_NAME));
|
||||
completed_systems.insert(THREAD_LOCAL_SYSTEM_SYSTEM_NAME);
|
||||
}
|
||||
|
||||
fn write_f32(counter: Res<Counter>, _query: Query<&mut f32>) {
|
||||
let mut count = counter.count.lock();
|
||||
assert_eq!(*count, 6, "should always be the 7th system to run");
|
||||
*count += 1;
|
||||
fn write_f32(completed_systems: Res<CompletedSystems>, _query: Query<&mut f32>) {
|
||||
let mut completed_systems = completed_systems.completed_systems.lock();
|
||||
assert!(completed_systems.contains(WRITE_U64_SYSTEM_NAME));
|
||||
assert!(completed_systems.contains(THREAD_LOCAL_SYSTEM_SYSTEM_NAME));
|
||||
assert!(!completed_systems.contains(READ_F64_RES_SYSTEM_NAME));
|
||||
completed_systems.insert(WRITE_F32_SYSTEM_NAME);
|
||||
}
|
||||
|
||||
schedule.add_system_to_stage("B", write_u64.system());
|
||||
@ -554,38 +636,35 @@ mod tests {
|
||||
|
||||
// C systems
|
||||
|
||||
fn read_f64_res(counter: Res<Counter>, _f64_res: Res<f64>) {
|
||||
let mut count = counter.count.lock();
|
||||
assert!(
|
||||
7 == *count || *count == 8,
|
||||
"should always be the 8th or 9th system to run"
|
||||
);
|
||||
*count += 1;
|
||||
fn read_f64_res(completed_systems: Res<CompletedSystems>, _f64_res: Res<f64>) {
|
||||
let mut completed_systems = completed_systems.completed_systems.lock();
|
||||
assert!(completed_systems.contains(WRITE_F32_SYSTEM_NAME));
|
||||
assert!(!completed_systems.contains(READ_ISIZE_WRITE_F64_RES_SYSTEM_NAME));
|
||||
assert!(!completed_systems.contains(WRITE_F64_RES_SYSTEM_NAME));
|
||||
completed_systems.insert(READ_F64_RES_SYSTEM_NAME);
|
||||
}
|
||||
|
||||
fn read_isize_res(counter: Res<Counter>, _isize_res: Res<isize>) {
|
||||
let mut count = counter.count.lock();
|
||||
assert!(
|
||||
7 == *count || *count == 8,
|
||||
"should always be the 8th or 9th system to run"
|
||||
);
|
||||
*count += 1;
|
||||
fn read_isize_res(completed_systems: Res<CompletedSystems>, _isize_res: Res<isize>) {
|
||||
let mut completed_systems = completed_systems.completed_systems.lock();
|
||||
completed_systems.insert(READ_ISIZE_RES_SYSTEM_NAME);
|
||||
}
|
||||
|
||||
fn read_isize_write_f64_res(
|
||||
counter: Res<Counter>,
|
||||
completed_systems: Res<CompletedSystems>,
|
||||
_isize_res: Res<isize>,
|
||||
_f64_res: ResMut<f64>,
|
||||
) {
|
||||
let mut count = counter.count.lock();
|
||||
assert_eq!(*count, 9, "should always be the 10th system to run");
|
||||
*count += 1;
|
||||
let mut completed_systems = completed_systems.completed_systems.lock();
|
||||
assert!(completed_systems.contains(READ_F64_RES_SYSTEM_NAME));
|
||||
assert!(!completed_systems.contains(WRITE_F64_RES_SYSTEM_NAME));
|
||||
completed_systems.insert(READ_ISIZE_WRITE_F64_RES_SYSTEM_NAME);
|
||||
}
|
||||
|
||||
fn write_f64_res(counter: Res<Counter>, _f64_res: ResMut<f64>) {
|
||||
let mut count = counter.count.lock();
|
||||
assert_eq!(*count, 10, "should always be the 11th system to run");
|
||||
*count += 1;
|
||||
fn write_f64_res(completed_systems: Res<CompletedSystems>, _f64_res: ResMut<f64>) {
|
||||
let mut completed_systems = completed_systems.completed_systems.lock();
|
||||
assert!(completed_systems.contains(READ_F64_RES_SYSTEM_NAME));
|
||||
assert!(completed_systems.contains(READ_ISIZE_WRITE_F64_RES_SYSTEM_NAME));
|
||||
completed_systems.insert(WRITE_F64_RES_SYSTEM_NAME);
|
||||
}
|
||||
|
||||
schedule.add_system_to_stage("C", read_f64_res.system());
|
||||
@ -660,18 +739,38 @@ mod tests {
|
||||
]
|
||||
);
|
||||
|
||||
let counter = resources.get::<Counter>().unwrap();
|
||||
let completed_systems = resources.get::<CompletedSystems>().unwrap();
|
||||
assert_eq!(
|
||||
*counter.count.lock(),
|
||||
completed_systems.completed_systems.lock().len(),
|
||||
11,
|
||||
"counter should have been incremented once for each system"
|
||||
"completed_systems should have been incremented once for each system"
|
||||
);
|
||||
}
|
||||
|
||||
// Stress test the "clean start" case
|
||||
for _ in 0..1000 {
|
||||
let mut executor = ParallelExecutor::default();
|
||||
run_executor_and_validate(&mut executor, &mut schedule, &mut world, &mut resources);
|
||||
resources
|
||||
.get::<CompletedSystems>()
|
||||
.unwrap()
|
||||
.completed_systems
|
||||
.lock()
|
||||
.clear();
|
||||
}
|
||||
|
||||
// Stress test the "continue running" case
|
||||
let mut executor = ParallelExecutor::default();
|
||||
run_executor_and_validate(&mut executor, &mut schedule, &mut world, &mut resources);
|
||||
// run again (with counter reset) to ensure executor works correctly across runs
|
||||
*resources.get::<Counter>().unwrap().count.lock() = 0;
|
||||
run_executor_and_validate(&mut executor, &mut schedule, &mut world, &mut resources);
|
||||
for _ in 0..1000 {
|
||||
// run again (with completed_systems reset) to ensure executor works correctly across runs
|
||||
resources
|
||||
.get::<CompletedSystems>()
|
||||
.unwrap()
|
||||
.completed_systems
|
||||
.lock()
|
||||
.clear();
|
||||
run_executor_and_validate(&mut executor, &mut schedule, &mut world, &mut resources);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -15,3 +15,4 @@ multitask = "0.2"
|
||||
num_cpus = "1"
|
||||
parking = "1"
|
||||
pollster = "0.2"
|
||||
event-listener = "2.4.0"
|
124
crates/bevy_tasks/src/countdown_event.rs
Normal file
124
crates/bevy_tasks/src/countdown_event.rs
Normal file
@ -0,0 +1,124 @@
|
||||
use event_listener::Event;
|
||||
use std::sync::{
|
||||
atomic::{AtomicIsize, Ordering},
|
||||
Arc,
|
||||
};
|
||||
|
||||
#[derive(Debug)]
|
||||
struct CountdownEventInner {
|
||||
/// Async primitive that can be awaited and signalled. We fire it when counter hits 0.
|
||||
event: Event,
|
||||
|
||||
/// The number of decrements remaining
|
||||
counter: AtomicIsize,
|
||||
}
|
||||
|
||||
/// A counter that starts with an initial count `n`. Once it is decremented `n` times, it will be
|
||||
/// "ready". Call `listen` to get a future that can be awaited.
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct CountdownEvent {
|
||||
inner: Arc<CountdownEventInner>,
|
||||
}
|
||||
|
||||
impl CountdownEvent {
|
||||
/// Creates a CountdownEvent that must be decremented `n` times for listeners to be
|
||||
/// signalled
|
||||
pub fn new(n: isize) -> Self {
|
||||
let inner = CountdownEventInner {
|
||||
event: Event::new(),
|
||||
counter: AtomicIsize::new(n),
|
||||
};
|
||||
|
||||
CountdownEvent {
|
||||
inner: Arc::new(inner),
|
||||
}
|
||||
}
|
||||
|
||||
/// Decrement the counter by one. If this is the Nth call, trigger all listeners
|
||||
pub fn decrement(&self) {
|
||||
// If we are the last decrementer, notify listeners
|
||||
let value = self.inner.counter.fetch_sub(1, Ordering::AcqRel);
|
||||
if value <= 1 {
|
||||
self.inner.event.notify(std::usize::MAX);
|
||||
|
||||
// Reset to 0 - wrapping an isize negative seems unlikely but should probably do it
|
||||
// anyways.
|
||||
self.inner.counter.store(0, Ordering::Release);
|
||||
}
|
||||
}
|
||||
|
||||
/// Resets the counter. Any listens following this point will not be notified until decrement
|
||||
/// is called N times
|
||||
pub fn reset(&self, n: isize) {
|
||||
self.inner.counter.store(n, Ordering::Release);
|
||||
}
|
||||
|
||||
/// Awaits decrement being called N times
|
||||
pub async fn listen(&self) {
|
||||
let mut listener = None;
|
||||
|
||||
// The complexity here is due to Event not necessarily signalling awaits that are placed
|
||||
// after the await is called. So we must check the counter AFTER taking a listener.
|
||||
loop {
|
||||
// We're done, break
|
||||
if self.inner.counter.load(Ordering::Acquire) <= 0 {
|
||||
break;
|
||||
}
|
||||
|
||||
match listener.take() {
|
||||
None => {
|
||||
listener = Some(self.inner.event.listen());
|
||||
}
|
||||
Some(l) => {
|
||||
l.await;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Listening on a countdown that has already been fully decremented must finish without
/// blocking.
#[test]
pub fn countdown_event_ready_after() {
    let event = CountdownEvent::new(2);
    for _ in 0..2 {
        event.decrement();
    }
    pollster::block_on(event.listen());
}
|
||||
|
||||
/// A listener that starts waiting while one decrement is still outstanding must be released by
/// that final decrement.
#[test]
pub fn countdown_event_ready() {
    let event = CountdownEvent::new(2);
    event.decrement();

    let waiter = {
        let event = event.clone();
        std::thread::spawn(move || pollster::block_on(event.listen()))
    };

    // Pause to give the new thread time to start blocking (ugly hack)
    let grace_period = std::time::Duration::from_millis(100);
    std::thread::sleep(grace_period);

    event.decrement();
    waiter.join().unwrap();
}
|
||||
|
||||
/// Checks the `Event` behavior that `CountdownEvent` relies on: after a notify-all has been
/// delivered, newly registered listeners block again until the next notification.
#[test]
pub fn event_resets_if_listeners_are_cleared() {
    let event = Event::new();

    // Register a single listener, notify everyone, and let that listener complete
    let first = event.listen();
    event.notify(std::usize::MAX);
    pollster::block_on(first);

    // Listeners registered after that notification start out un-notified
    let second = event.listen();
    let third = event.listen();

    // Verify that we are still blocked
    let woke_early = second.wait_timeout(std::time::Duration::from_millis(10));
    assert_eq!(false, woke_early);

    // Notify all and verify the remaining listener is notified
    event.notify(std::usize::MAX);
    pollster::block_on(third);
}
|
@ -10,6 +10,9 @@ pub use task_pool::{Scope, TaskPool, TaskPoolBuilder};
|
||||
mod usages;
|
||||
pub use usages::{AsyncComputeTaskPool, ComputeTaskPool, IOTaskPool};
|
||||
|
||||
mod countdown_event;
|
||||
pub use countdown_event::CountdownEvent;
|
||||
|
||||
pub mod prelude {
|
||||
pub use crate::{
|
||||
slice::{ParallelSlice, ParallelSliceMut},
|
||||
|
Loading…
Reference in New Issue
Block a user