 28ba87e6c8
			
		
	
	
		28ba87e6c8
		
	
	
	
	
		
			
			# Objective
Fixes #1529
Run bevy_ecs in miri
## Solution
- Don't set thread names when running in miri (see rust-lang/miri/issues/1717)
- Update `event-listener` to `2.5.2` as previous versions have UB that is detected by miri: [event-listener commit](1fa31c553e)
- Ignore memory leaks when running in miri, as they are impossible to track down (see rust-lang/miri/issues/1481)
- Make `table_add_remove_many` test less "many" because miri is really quite slow :)
- Make CI run `RUSTFLAGS="-Zrandomize-layout" MIRIFLAGS="-Zmiri-ignore-leaks -Zmiri-tag-raw-pointers -Zmiri-disable-isolation" cargo +nightly miri test -p bevy_ecs`
		
	
			
		
			
				
	
	
		
			429 lines
		
	
	
		
			15 KiB
		
	
	
	
		
			Rust
		
	
	
	
	
	
			
		
		
	
	
			429 lines
		
	
	
		
			15 KiB
		
	
	
	
		
			Rust
		
	
	
	
	
	
| use std::{
 | |
|     future::Future,
 | |
|     mem,
 | |
|     pin::Pin,
 | |
|     sync::Arc,
 | |
|     thread::{self, JoinHandle},
 | |
| };
 | |
| 
 | |
| use futures_lite::{future, pin};
 | |
| 
 | |
| use crate::Task;
 | |
| 
 | |
/// Builder that collects the optional settings used to construct a [`TaskPool`].
///
/// Every setting starts out unset (`None`); see the individual fields for what
/// is used in that case.
#[derive(Clone, Debug, Default)]
#[must_use]
pub struct TaskPoolBuilder {
    /// Number of worker threads to create. When unset, the logical core count
    /// of the system is used instead.
    num_threads: Option<usize>,
    /// Stack size for each worker thread. When unset, the system default is
    /// kept.
    stack_size: Option<usize>,
    /// Base name for the worker threads, which is helpful when debugging.
    /// When set, threads are named `<thread_name> (<thread_index>)`,
    /// e.g. `MyThreadPool (2)`.
    thread_name: Option<String>,
}
 | |
| 
 | |
| impl TaskPoolBuilder {
 | |
|     /// Creates a new [`TaskPoolBuilder`] instance
 | |
|     pub fn new() -> Self {
 | |
|         Self::default()
 | |
|     }
 | |
| 
 | |
|     /// Override the number of threads created for the pool. If unset, we default to the number
 | |
|     /// of logical cores of the system
 | |
|     pub fn num_threads(mut self, num_threads: usize) -> Self {
 | |
|         self.num_threads = Some(num_threads);
 | |
|         self
 | |
|     }
 | |
| 
 | |
|     /// Override the stack size of the threads created for the pool
 | |
|     pub fn stack_size(mut self, stack_size: usize) -> Self {
 | |
|         self.stack_size = Some(stack_size);
 | |
|         self
 | |
|     }
 | |
| 
 | |
|     /// Override the name of the threads created for the pool. If set, threads will
 | |
|     /// be named `<thread_name> (<thread_index>)`, i.e. `MyThreadPool (2)`
 | |
|     pub fn thread_name(mut self, thread_name: String) -> Self {
 | |
|         self.thread_name = Some(thread_name);
 | |
|         self
 | |
|     }
 | |
| 
 | |
|     /// Creates a new [`TaskPool`] based on the current options.
 | |
|     pub fn build(self) -> TaskPool {
 | |
|         TaskPool::new_internal(
 | |
|             self.num_threads,
 | |
|             self.stack_size,
 | |
|             self.thread_name.as_deref(),
 | |
|         )
 | |
|     }
 | |
| }
 | |
| 
 | |
| #[derive(Debug)]
 | |
| struct TaskPoolInner {
 | |
|     threads: Vec<JoinHandle<()>>,
 | |
|     shutdown_tx: async_channel::Sender<()>,
 | |
| }
 | |
| 
 | |
| impl Drop for TaskPoolInner {
 | |
|     fn drop(&mut self) {
 | |
|         self.shutdown_tx.close();
 | |
| 
 | |
|         let panicking = thread::panicking();
 | |
|         for join_handle in self.threads.drain(..) {
 | |
|             let res = join_handle.join();
 | |
|             if !panicking {
 | |
|                 res.expect("Task thread panicked while executing.");
 | |
|             }
 | |
|         }
 | |
|     }
 | |
| }
 | |
| 
 | |
| /// A thread pool for executing tasks. Tasks are futures that are being automatically driven by
 | |
| /// the pool on threads owned by the pool.
 | |
| #[derive(Debug, Clone)]
 | |
| pub struct TaskPool {
 | |
|     /// The executor for the pool
 | |
|     ///
 | |
|     /// This has to be separate from TaskPoolInner because we have to create an Arc<Executor> to
 | |
|     /// pass into the worker threads, and we must create the worker threads before we can create
 | |
|     /// the Vec<Task<T>> contained within TaskPoolInner
 | |
|     executor: Arc<async_executor::Executor<'static>>,
 | |
| 
 | |
|     /// Inner state of the pool
 | |
|     inner: Arc<TaskPoolInner>,
 | |
| }
 | |
| 
 | |
| impl TaskPool {
 | |
|     thread_local! {
 | |
|         static LOCAL_EXECUTOR: async_executor::LocalExecutor<'static> = async_executor::LocalExecutor::new();
 | |
|     }
 | |
| 
 | |
|     /// Create a `TaskPool` with the default configuration.
 | |
|     pub fn new() -> Self {
 | |
|         TaskPoolBuilder::new().build()
 | |
|     }
 | |
| 
 | |
|     fn new_internal(
 | |
|         num_threads: Option<usize>,
 | |
|         stack_size: Option<usize>,
 | |
|         thread_name: Option<&str>,
 | |
|     ) -> Self {
 | |
|         let (shutdown_tx, shutdown_rx) = async_channel::unbounded::<()>();
 | |
| 
 | |
|         let executor = Arc::new(async_executor::Executor::new());
 | |
| 
 | |
|         let num_threads = num_threads.unwrap_or_else(num_cpus::get);
 | |
| 
 | |
|         let threads = (0..num_threads)
 | |
|             .map(|i| {
 | |
|                 let ex = Arc::clone(&executor);
 | |
|                 let shutdown_rx = shutdown_rx.clone();
 | |
| 
 | |
|                 // miri does not support setting thread names
 | |
|                 // TODO: change back when https://github.com/rust-lang/miri/issues/1717 is fixed
 | |
|                 #[cfg(not(miri))]
 | |
|                 let mut thread_builder = {
 | |
|                     let thread_name = if let Some(thread_name) = thread_name {
 | |
|                         format!("{} ({})", thread_name, i)
 | |
|                     } else {
 | |
|                         format!("TaskPool ({})", i)
 | |
|                     };
 | |
|                     thread::Builder::new().name(thread_name)
 | |
|                 };
 | |
|                 #[cfg(miri)]
 | |
|                 let mut thread_builder = {
 | |
|                     let _ = i;
 | |
|                     let _ = thread_name;
 | |
|                     thread::Builder::new()
 | |
|                 };
 | |
| 
 | |
|                 if let Some(stack_size) = stack_size {
 | |
|                     thread_builder = thread_builder.stack_size(stack_size);
 | |
|                 }
 | |
| 
 | |
|                 thread_builder
 | |
|                     .spawn(move || {
 | |
|                         let shutdown_future = ex.run(shutdown_rx.recv());
 | |
|                         // Use unwrap_err because we expect a Closed error
 | |
|                         future::block_on(shutdown_future).unwrap_err();
 | |
|                     })
 | |
|                     .expect("Failed to spawn thread.")
 | |
|             })
 | |
|             .collect();
 | |
| 
 | |
|         Self {
 | |
|             executor,
 | |
|             inner: Arc::new(TaskPoolInner {
 | |
|                 threads,
 | |
|                 shutdown_tx,
 | |
|             }),
 | |
|         }
 | |
|     }
 | |
| 
 | |
|     /// Return the number of threads owned by the task pool
 | |
|     pub fn thread_num(&self) -> usize {
 | |
|         self.inner.threads.len()
 | |
|     }
 | |
| 
 | |
|     /// Allows spawning non-`'static` futures on the thread pool. The function takes a callback,
 | |
|     /// passing a scope object into it. The scope object provided to the callback can be used
 | |
|     /// to spawn tasks. This function will await the completion of all tasks before returning.
 | |
|     ///
 | |
|     /// This is similar to `rayon::scope` and `crossbeam::scope`
 | |
|     pub fn scope<'scope, F, T>(&self, f: F) -> Vec<T>
 | |
|     where
 | |
|         F: FnOnce(&mut Scope<'scope, T>) + 'scope + Send,
 | |
|         T: Send + 'static,
 | |
|     {
 | |
|         TaskPool::LOCAL_EXECUTOR.with(|local_executor| {
 | |
|             // SAFETY: This function blocks until all futures complete, so this future must return
 | |
|             // before this function returns. However, rust has no way of knowing
 | |
|             // this so we must convert to 'static here to appease the compiler as it is unable to
 | |
|             // validate safety.
 | |
|             let executor: &async_executor::Executor = &*self.executor;
 | |
|             let executor: &'scope async_executor::Executor = unsafe { mem::transmute(executor) };
 | |
|             let local_executor: &'scope async_executor::LocalExecutor =
 | |
|                 unsafe { mem::transmute(local_executor) };
 | |
|             let mut scope = Scope {
 | |
|                 executor,
 | |
|                 local_executor,
 | |
|                 spawned: Vec::new(),
 | |
|             };
 | |
| 
 | |
|             f(&mut scope);
 | |
| 
 | |
|             if scope.spawned.is_empty() {
 | |
|                 Vec::default()
 | |
|             } else if scope.spawned.len() == 1 {
 | |
|                 vec![future::block_on(&mut scope.spawned[0])]
 | |
|             } else {
 | |
|                 let fut = async move {
 | |
|                     let mut results = Vec::with_capacity(scope.spawned.len());
 | |
|                     for task in scope.spawned {
 | |
|                         results.push(task.await);
 | |
|                     }
 | |
| 
 | |
|                     results
 | |
|                 };
 | |
| 
 | |
|                 // Pin the futures on the stack.
 | |
|                 pin!(fut);
 | |
| 
 | |
|                 // SAFETY: This function blocks until all futures complete, so we do not read/write
 | |
|                 // the data from futures outside of the 'scope lifetime. However,
 | |
|                 // rust has no way of knowing this so we must convert to 'static
 | |
|                 // here to appease the compiler as it is unable to validate safety.
 | |
|                 let fut: Pin<&mut (dyn Future<Output = Vec<T>>)> = fut;
 | |
|                 let fut: Pin<&'static mut (dyn Future<Output = Vec<T>> + 'static)> =
 | |
|                     unsafe { mem::transmute(fut) };
 | |
| 
 | |
|                 // The thread that calls scope() will participate in driving tasks in the pool
 | |
|                 // forward until the tasks that are spawned by this scope() call
 | |
|                 // complete. (If the caller of scope() happens to be a thread in
 | |
|                 // this thread pool, and we only have one thread in the pool, then
 | |
|                 // simply calling future::block_on(spawned) would deadlock.)
 | |
|                 let mut spawned = local_executor.spawn(fut);
 | |
|                 loop {
 | |
|                     if let Some(result) = future::block_on(future::poll_once(&mut spawned)) {
 | |
|                         break result;
 | |
|                     };
 | |
| 
 | |
|                     self.executor.try_tick();
 | |
|                     local_executor.try_tick();
 | |
|                 }
 | |
|             }
 | |
|         })
 | |
|     }
 | |
| 
 | |
|     /// Spawns a static future onto the thread pool. The returned Task is a future. It can also be
 | |
|     /// cancelled and "detached" allowing it to continue running without having to be polled by the
 | |
|     /// end-user.
 | |
|     ///
 | |
|     /// If the provided future is non-`Send`, [`TaskPool::spawn_local`] should be used instead.
 | |
|     pub fn spawn<T>(&self, future: impl Future<Output = T> + Send + 'static) -> Task<T>
 | |
|     where
 | |
|         T: Send + 'static,
 | |
|     {
 | |
|         Task::new(self.executor.spawn(future))
 | |
|     }
 | |
| 
 | |
|     /// Spawns a static future on the thread-local async executor for the current thread. The task
 | |
|     /// will run entirely on the thread the task was spawned on.  The returned Task is a future.
 | |
|     /// It can also be cancelled and "detached" allowing it to continue running without having
 | |
|     /// to be polled by the end-user. Users should generally prefer to use [`TaskPool::spawn`]
 | |
|     /// instead, unless the provided future is not `Send`.
 | |
|     pub fn spawn_local<T>(&self, future: impl Future<Output = T> + 'static) -> Task<T>
 | |
|     where
 | |
|         T: 'static,
 | |
|     {
 | |
|         Task::new(TaskPool::LOCAL_EXECUTOR.with(|executor| executor.spawn(future)))
 | |
|     }
 | |
| }
 | |
| 
 | |
| impl Default for TaskPool {
 | |
|     fn default() -> Self {
 | |
|         Self::new()
 | |
|     }
 | |
| }
 | |
| 
 | |
| /// A `TaskPool` scope for running one or more non-`'static` futures.
 | |
| ///
 | |
| /// For more information, see [`TaskPool::scope`].
 | |
| #[derive(Debug)]
 | |
| pub struct Scope<'scope, T> {
 | |
|     executor: &'scope async_executor::Executor<'scope>,
 | |
|     local_executor: &'scope async_executor::LocalExecutor<'scope>,
 | |
|     spawned: Vec<async_executor::Task<T>>,
 | |
| }
 | |
| 
 | |
| impl<'scope, T: Send + 'scope> Scope<'scope, T> {
 | |
|     /// Spawns a scoped future onto the thread pool. The scope *must* outlive
 | |
|     /// the provided future. The results of the future will be returned as a part of
 | |
|     /// [`TaskPool::scope`]'s return value.
 | |
|     ///
 | |
|     /// If the provided future is non-`Send`, [`Scope::spawn_local`] should be used
 | |
|     /// instead.
 | |
|     ///
 | |
|     /// For more information, see [`TaskPool::scope`].
 | |
|     pub fn spawn<Fut: Future<Output = T> + 'scope + Send>(&mut self, f: Fut) {
 | |
|         let task = self.executor.spawn(f);
 | |
|         self.spawned.push(task);
 | |
|     }
 | |
| 
 | |
|     /// Spawns a scoped future onto the thread-local executor. The scope *must* outlive
 | |
|     /// the provided future. The results of the future will be returned as a part of
 | |
|     /// [`TaskPool::scope`]'s return value.  Users should generally prefer to use
 | |
|     /// [`Scope::spawn`] instead, unless the provided future is not `Send`.
 | |
|     ///
 | |
|     /// For more information, see [`TaskPool::scope`].
 | |
|     pub fn spawn_local<Fut: Future<Output = T> + 'scope>(&mut self, f: Fut) {
 | |
|         let task = self.local_executor.spawn(f);
 | |
|         self.spawned.push(task);
 | |
|     }
 | |
| }
 | |
| 
 | |
#[cfg(test)]
#[allow(clippy::blacklisted_name)]
mod tests {
    use super::*;
    use std::sync::{
        atomic::{AtomicBool, AtomicI32, Ordering},
        Barrier,
    };

    /// Spawns 100 pool tasks that borrow a non-`'static` value and checks that
    /// every one of them ran and returned that value.
    #[test]
    fn test_spawn() {
        let pool = TaskPool::new();

        // Borrowed from a local box on purpose: the tasks are not `'static`.
        let foo = Box::new(42);
        let foo = &*foo;

        let count = Arc::new(AtomicI32::new(0));

        let outputs = pool.scope(|scope| {
            for _ in 0..100 {
                let counter = Arc::clone(&count);
                scope.spawn(async move {
                    if *foo != 42 {
                        panic!("not 42!?!?");
                    }
                    counter.fetch_add(1, Ordering::Relaxed);
                    *foo
                });
            }
        });

        outputs.iter().for_each(|output| assert_eq!(*output, 42));

        assert_eq!(outputs.len(), 100);
        assert_eq!(count.load(Ordering::Relaxed), 100);
    }

    /// Alternates between `spawn` and `spawn_local` and checks that each kind
    /// of task ran exactly 50 times.
    #[test]
    fn test_mixed_spawn_local_and_spawn() {
        let pool = TaskPool::new();

        let foo = Box::new(42);
        let foo = &*foo;

        let local_count = Arc::new(AtomicI32::new(0));
        let non_local_count = Arc::new(AtomicI32::new(0));

        let outputs = pool.scope(|scope| {
            for i in 0..100 {
                match i % 2 {
                    0 => {
                        let counter = Arc::clone(&non_local_count);
                        scope.spawn(async move {
                            if *foo != 42 {
                                panic!("not 42!?!?");
                            }
                            counter.fetch_add(1, Ordering::Relaxed);
                            *foo
                        });
                    }
                    _ => {
                        let counter = Arc::clone(&local_count);
                        scope.spawn_local(async move {
                            if *foo != 42 {
                                panic!("not 42!?!?");
                            }
                            counter.fetch_add(1, Ordering::Relaxed);
                            *foo
                        });
                    }
                }
            }
        });

        outputs.iter().for_each(|output| assert_eq!(*output, 42));

        assert_eq!(outputs.len(), 100);
        assert_eq!(local_count.load(Ordering::Relaxed), 50);
        assert_eq!(non_local_count.load(Ordering::Relaxed), 50);
    }

    /// Enters a scope from 100 different threads and checks that every
    /// `spawn_local` task ran on the thread that spawned it.
    #[test]
    fn test_thread_locality() {
        let pool = Arc::new(TaskPool::new());
        let count = Arc::new(AtomicI32::new(0));
        // 100 spawned threads plus the test thread itself.
        let barrier = Arc::new(Barrier::new(101));
        let thread_check_failed = Arc::new(AtomicBool::new(false));

        for _ in 0..100 {
            let thread_barrier = Arc::clone(&barrier);
            let thread_count = Arc::clone(&count);
            let thread_pool = Arc::clone(&pool);
            let thread_failed = Arc::clone(&thread_check_failed);
            std::thread::spawn(move || {
                thread_pool.scope(|scope| {
                    let pool_task_count = Arc::clone(&thread_count);
                    scope.spawn(async move {
                        pool_task_count.fetch_add(1, Ordering::Release);
                    });

                    let spawner = std::thread::current().id();
                    let local_task_count = Arc::clone(&thread_count);
                    scope.spawn_local(async move {
                        local_task_count.fetch_add(1, Ordering::Release);
                        if std::thread::current().id() != spawner {
                            // NOTE: This check is using an atomic rather than simply panicking
                            // the thread to avoid deadlocking the barrier on failure
                            thread_failed.store(true, Ordering::Release);
                        }
                    });
                });
                thread_barrier.wait();
            });
        }

        barrier.wait();
        assert!(!thread_check_failed.load(Ordering::Acquire));
        assert_eq!(count.load(Ordering::Acquire), 200);
    }
}
 |