parent
							
								
									e7dab0c359
								
							
						
					
					
						commit
						3c5f1f8a80
					
				@ -102,6 +102,13 @@ impl TaskPool {
 | 
				
			|||||||
        });
 | 
					        });
 | 
				
			||||||
        FakeTask
 | 
					        FakeTask
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    pub fn spawn_local<T>(&self, future: impl Future<Output = T> + 'static) -> FakeTask
 | 
				
			||||||
 | 
					    where
 | 
				
			||||||
 | 
					        T: 'static,
 | 
				
			||||||
 | 
					    {
 | 
				
			||||||
 | 
					        self.spawn(future)
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
#[derive(Debug)]
 | 
					#[derive(Debug)]
 | 
				
			||||||
 | 
				
			|||||||
@ -95,6 +95,10 @@ pub struct TaskPool {
 | 
				
			|||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
impl TaskPool {
 | 
					impl TaskPool {
 | 
				
			||||||
 | 
					    thread_local! {
 | 
				
			||||||
 | 
					        static LOCAL_EXECUTOR: async_executor::LocalExecutor<'static> = async_executor::LocalExecutor::new();
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    /// Create a `TaskPool` with the default configuration.
 | 
					    /// Create a `TaskPool` with the default configuration.
 | 
				
			||||||
    pub fn new() -> Self {
 | 
					    pub fn new() -> Self {
 | 
				
			||||||
        TaskPoolBuilder::new().build()
 | 
					        TaskPoolBuilder::new().build()
 | 
				
			||||||
@ -162,15 +166,18 @@ impl TaskPool {
 | 
				
			|||||||
        F: FnOnce(&mut Scope<'scope, T>) + 'scope + Send,
 | 
					        F: FnOnce(&mut Scope<'scope, T>) + 'scope + Send,
 | 
				
			||||||
        T: Send + 'static,
 | 
					        T: Send + 'static,
 | 
				
			||||||
    {
 | 
					    {
 | 
				
			||||||
 | 
					        TaskPool::LOCAL_EXECUTOR.with(|local_executor| {
 | 
				
			||||||
            // SAFETY: This function blocks until all futures complete, so this future must return
 | 
					            // SAFETY: This function blocks until all futures complete, so this future must return
 | 
				
			||||||
            // before this function returns. However, rust has no way of knowing
 | 
					            // before this function returns. However, rust has no way of knowing
 | 
				
			||||||
            // this so we must convert to 'static here to appease the compiler as it is unable to
 | 
					            // this so we must convert to 'static here to appease the compiler as it is unable to
 | 
				
			||||||
            // validate safety.
 | 
					            // validate safety.
 | 
				
			||||||
            let executor: &async_executor::Executor = &*self.executor;
 | 
					            let executor: &async_executor::Executor = &*self.executor;
 | 
				
			||||||
            let executor: &'scope async_executor::Executor = unsafe { mem::transmute(executor) };
 | 
					            let executor: &'scope async_executor::Executor = unsafe { mem::transmute(executor) };
 | 
				
			||||||
 | 
					            let local_executor: &'scope async_executor::LocalExecutor =
 | 
				
			||||||
 | 
					                unsafe { mem::transmute(local_executor) };
 | 
				
			||||||
            let mut scope = Scope {
 | 
					            let mut scope = Scope {
 | 
				
			||||||
                executor,
 | 
					                executor,
 | 
				
			||||||
 | 
					                local_executor,
 | 
				
			||||||
                spawned: Vec::new(),
 | 
					                spawned: Vec::new(),
 | 
				
			||||||
            };
 | 
					            };
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -190,30 +197,32 @@ impl TaskPool {
 | 
				
			|||||||
                    results
 | 
					                    results
 | 
				
			||||||
                };
 | 
					                };
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            // Pin the future on the stack.
 | 
					                // Pin the futures on the stack.
 | 
				
			||||||
                pin!(fut);
 | 
					                pin!(fut);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
                // SAFETY: This function blocks until all futures complete, so we do not read/write the
 | 
					                // SAFETY: This function blocks until all futures complete, so we do not read/write the
 | 
				
			||||||
                // data from futures outside of the 'scope lifetime. However, rust has no way of knowing
 | 
					                // data from futures outside of the 'scope lifetime. However, rust has no way of knowing
 | 
				
			||||||
                // this so we must convert to 'static here to appease the compiler as it is unable to
 | 
					                // this so we must convert to 'static here to appease the compiler as it is unable to
 | 
				
			||||||
                // validate safety.
 | 
					                // validate safety.
 | 
				
			||||||
            let fut: Pin<&mut (dyn Future<Output = Vec<T>> + Send)> = fut;
 | 
					                let fut: Pin<&mut (dyn Future<Output = Vec<T>>)> = fut;
 | 
				
			||||||
            let fut: Pin<&'static mut (dyn Future<Output = Vec<T>> + Send + 'static)> =
 | 
					                let fut: Pin<&'static mut (dyn Future<Output = Vec<T>> + 'static)> =
 | 
				
			||||||
                    unsafe { mem::transmute(fut) };
 | 
					                    unsafe { mem::transmute(fut) };
 | 
				
			||||||
 | 
					
 | 
				
			||||||
                // The thread that calls scope() will participate in driving tasks in the pool forward
 | 
					                // The thread that calls scope() will participate in driving tasks in the pool forward
 | 
				
			||||||
                // until the tasks that are spawned by this scope() call complete. (If the caller of scope()
 | 
					                // until the tasks that are spawned by this scope() call complete. (If the caller of scope()
 | 
				
			||||||
                // happens to be a thread in this thread pool, and we only have one thread in the pool, then
 | 
					                // happens to be a thread in this thread pool, and we only have one thread in the pool, then
 | 
				
			||||||
                // simply calling future::block_on(spawned) would deadlock.)
 | 
					                // simply calling future::block_on(spawned) would deadlock.)
 | 
				
			||||||
            let mut spawned = self.executor.spawn(fut);
 | 
					                let mut spawned = local_executor.spawn(fut);
 | 
				
			||||||
                loop {
 | 
					                loop {
 | 
				
			||||||
                    if let Some(result) = future::block_on(future::poll_once(&mut spawned)) {
 | 
					                    if let Some(result) = future::block_on(future::poll_once(&mut spawned)) {
 | 
				
			||||||
                        break result;
 | 
					                        break result;
 | 
				
			||||||
                }
 | 
					                    };
 | 
				
			||||||
 | 
					
 | 
				
			||||||
                    self.executor.try_tick();
 | 
					                    self.executor.try_tick();
 | 
				
			||||||
 | 
					                    local_executor.try_tick();
 | 
				
			||||||
                }
 | 
					                }
 | 
				
			||||||
            }
 | 
					            }
 | 
				
			||||||
 | 
					        })
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    /// Spawns a static future onto the thread pool. The returned Task is a future. It can also be
 | 
					    /// Spawns a static future onto the thread pool. The returned Task is a future. It can also be
 | 
				
			||||||
@ -225,6 +234,13 @@ impl TaskPool {
 | 
				
			|||||||
    {
 | 
					    {
 | 
				
			||||||
        Task::new(self.executor.spawn(future))
 | 
					        Task::new(self.executor.spawn(future))
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    pub fn spawn_local<T>(&self, future: impl Future<Output = T> + 'static) -> Task<T>
 | 
				
			||||||
 | 
					    where
 | 
				
			||||||
 | 
					        T: 'static,
 | 
				
			||||||
 | 
					    {
 | 
				
			||||||
 | 
					        Task::new(TaskPool::LOCAL_EXECUTOR.with(|executor| executor.spawn(future)))
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
impl Default for TaskPool {
 | 
					impl Default for TaskPool {
 | 
				
			||||||
@ -236,6 +252,7 @@ impl Default for TaskPool {
 | 
				
			|||||||
#[derive(Debug)]
 | 
					#[derive(Debug)]
 | 
				
			||||||
pub struct Scope<'scope, T> {
 | 
					pub struct Scope<'scope, T> {
 | 
				
			||||||
    executor: &'scope async_executor::Executor<'scope>,
 | 
					    executor: &'scope async_executor::Executor<'scope>,
 | 
				
			||||||
 | 
					    local_executor: &'scope async_executor::LocalExecutor<'scope>,
 | 
				
			||||||
    spawned: Vec<async_executor::Task<T>>,
 | 
					    spawned: Vec<async_executor::Task<T>>,
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@ -244,12 +261,20 @@ impl<'scope, T: Send + 'scope> Scope<'scope, T> {
 | 
				
			|||||||
        let task = self.executor.spawn(f);
 | 
					        let task = self.executor.spawn(f);
 | 
				
			||||||
        self.spawned.push(task);
 | 
					        self.spawned.push(task);
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    pub fn spawn_local<Fut: Future<Output = T> + 'scope>(&mut self, f: Fut) {
 | 
				
			||||||
 | 
					        let task = self.local_executor.spawn(f);
 | 
				
			||||||
 | 
					        self.spawned.push(task);
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
#[cfg(test)]
 | 
					#[cfg(test)]
 | 
				
			||||||
mod tests {
 | 
					mod tests {
 | 
				
			||||||
    use super::*;
 | 
					    use super::*;
 | 
				
			||||||
    use std::sync::atomic::{AtomicI32, Ordering};
 | 
					    use std::sync::{
 | 
				
			||||||
 | 
					        atomic::{AtomicBool, AtomicI32, Ordering},
 | 
				
			||||||
 | 
					        Barrier,
 | 
				
			||||||
 | 
					    };
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    #[test]
 | 
					    #[test]
 | 
				
			||||||
    pub fn test_spawn() {
 | 
					    pub fn test_spawn() {
 | 
				
			||||||
@ -281,4 +306,85 @@ mod tests {
 | 
				
			|||||||
        assert_eq!(outputs.len(), 100);
 | 
					        assert_eq!(outputs.len(), 100);
 | 
				
			||||||
        assert_eq!(count.load(Ordering::Relaxed), 100);
 | 
					        assert_eq!(count.load(Ordering::Relaxed), 100);
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    #[test]
 | 
				
			||||||
 | 
					    pub fn test_mixed_spawn_local_and_spawn() {
 | 
				
			||||||
 | 
					        let pool = TaskPool::new();
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        let foo = Box::new(42);
 | 
				
			||||||
 | 
					        let foo = &*foo;
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        let local_count = Arc::new(AtomicI32::new(0));
 | 
				
			||||||
 | 
					        let non_local_count = Arc::new(AtomicI32::new(0));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        let outputs = pool.scope(|scope| {
 | 
				
			||||||
 | 
					            for i in 0..100 {
 | 
				
			||||||
 | 
					                if i % 2 == 0 {
 | 
				
			||||||
 | 
					                    let count_clone = non_local_count.clone();
 | 
				
			||||||
 | 
					                    scope.spawn(async move {
 | 
				
			||||||
 | 
					                        if *foo != 42 {
 | 
				
			||||||
 | 
					                            panic!("not 42!?!?")
 | 
				
			||||||
 | 
					                        } else {
 | 
				
			||||||
 | 
					                            count_clone.fetch_add(1, Ordering::Relaxed);
 | 
				
			||||||
 | 
					                            *foo
 | 
				
			||||||
 | 
					                        }
 | 
				
			||||||
 | 
					                    });
 | 
				
			||||||
 | 
					                } else {
 | 
				
			||||||
 | 
					                    let count_clone = local_count.clone();
 | 
				
			||||||
 | 
					                    scope.spawn_local(async move {
 | 
				
			||||||
 | 
					                        if *foo != 42 {
 | 
				
			||||||
 | 
					                            panic!("not 42!?!?")
 | 
				
			||||||
 | 
					                        } else {
 | 
				
			||||||
 | 
					                            count_clone.fetch_add(1, Ordering::Relaxed);
 | 
				
			||||||
 | 
					                            *foo
 | 
				
			||||||
 | 
					                        }
 | 
				
			||||||
 | 
					                    });
 | 
				
			||||||
 | 
					                }
 | 
				
			||||||
 | 
					            }
 | 
				
			||||||
 | 
					        });
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        for output in &outputs {
 | 
				
			||||||
 | 
					            assert_eq!(*output, 42);
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        assert_eq!(outputs.len(), 100);
 | 
				
			||||||
 | 
					        assert_eq!(local_count.load(Ordering::Relaxed), 50);
 | 
				
			||||||
 | 
					        assert_eq!(non_local_count.load(Ordering::Relaxed), 50);
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    #[test]
 | 
				
			||||||
 | 
					    pub fn test_thread_locality() {
 | 
				
			||||||
 | 
					        let pool = Arc::new(TaskPool::new());
 | 
				
			||||||
 | 
					        let count = Arc::new(AtomicI32::new(0));
 | 
				
			||||||
 | 
					        let barrier = Arc::new(Barrier::new(101));
 | 
				
			||||||
 | 
					        let thread_check_failed = Arc::new(AtomicBool::new(false));
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        for _ in 0..100 {
 | 
				
			||||||
 | 
					            let inner_barrier = barrier.clone();
 | 
				
			||||||
 | 
					            let count_clone = count.clone();
 | 
				
			||||||
 | 
					            let inner_pool = pool.clone();
 | 
				
			||||||
 | 
					            let inner_thread_check_failed = thread_check_failed.clone();
 | 
				
			||||||
 | 
					            std::thread::spawn(move || {
 | 
				
			||||||
 | 
					                inner_pool.scope(|scope| {
 | 
				
			||||||
 | 
					                    let inner_count_clone = count_clone.clone();
 | 
				
			||||||
 | 
					                    scope.spawn(async move {
 | 
				
			||||||
 | 
					                        inner_count_clone.fetch_add(1, Ordering::Release);
 | 
				
			||||||
 | 
					                    });
 | 
				
			||||||
 | 
					                    let spawner = std::thread::current().id();
 | 
				
			||||||
 | 
					                    let inner_count_clone = count_clone.clone();
 | 
				
			||||||
 | 
					                    scope.spawn_local(async move {
 | 
				
			||||||
 | 
					                        inner_count_clone.fetch_add(1, Ordering::Release);
 | 
				
			||||||
 | 
					                        if std::thread::current().id() != spawner {
 | 
				
			||||||
 | 
					                            // NOTE: This check is using an atomic rather than simply panicing the thread to avoid deadlocking the barrier on failure
 | 
				
			||||||
 | 
					                            inner_thread_check_failed.store(true, Ordering::Release);
 | 
				
			||||||
 | 
					                        }
 | 
				
			||||||
 | 
					                    });
 | 
				
			||||||
 | 
					                });
 | 
				
			||||||
 | 
					                inner_barrier.wait();
 | 
				
			||||||
 | 
					            });
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					        barrier.wait();
 | 
				
			||||||
 | 
					        assert!(!thread_check_failed.load(Ordering::Acquire));
 | 
				
			||||||
 | 
					        assert_eq!(count.load(Ordering::Acquire), 200);
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
				
			|||||||
		Loading…
	
		Reference in New Issue
	
	Block a user