Support for !Send tasks (#1216)

Support for !Send tasks
This commit is contained in:
Alec Deason 2021-01-18 13:48:28 -08:00 committed by GitHub
parent e7dab0c359
commit 3c5f1f8a80
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 161 additions and 48 deletions

View File

@ -102,6 +102,13 @@ impl TaskPool {
}); });
FakeTask FakeTask
} }
pub fn spawn_local<T>(&self, future: impl Future<Output = T> + 'static) -> FakeTask
where
T: 'static,
{
self.spawn(future)
}
} }
#[derive(Debug)] #[derive(Debug)]

View File

@ -95,6 +95,10 @@ pub struct TaskPool {
} }
impl TaskPool { impl TaskPool {
thread_local! {
static LOCAL_EXECUTOR: async_executor::LocalExecutor<'static> = async_executor::LocalExecutor::new();
}
/// Create a `TaskPool` with the default configuration. /// Create a `TaskPool` with the default configuration.
pub fn new() -> Self { pub fn new() -> Self {
TaskPoolBuilder::new().build() TaskPoolBuilder::new().build()
@ -162,58 +166,63 @@ impl TaskPool {
F: FnOnce(&mut Scope<'scope, T>) + 'scope + Send, F: FnOnce(&mut Scope<'scope, T>) + 'scope + Send,
T: Send + 'static, T: Send + 'static,
{ {
// SAFETY: This function blocks until all futures complete, so this future must return TaskPool::LOCAL_EXECUTOR.with(|local_executor| {
// before this function returns. However, rust has no way of knowing // SAFETY: This function blocks until all futures complete, so this future must return
// this so we must convert to 'static here to appease the compiler as it is unable to // before this function returns. However, rust has no way of knowing
// validate safety.
let executor: &async_executor::Executor = &*self.executor;
let executor: &'scope async_executor::Executor = unsafe { mem::transmute(executor) };
let mut scope = Scope {
executor,
spawned: Vec::new(),
};
f(&mut scope);
if scope.spawned.is_empty() {
Vec::default()
} else if scope.spawned.len() == 1 {
vec![future::block_on(&mut scope.spawned[0])]
} else {
let fut = async move {
let mut results = Vec::with_capacity(scope.spawned.len());
for task in scope.spawned {
results.push(task.await);
}
results
};
// Pin the future on the stack.
pin!(fut);
// SAFETY: This function blocks until all futures complete, so we do not read/write the
// data from futures outside of the 'scope lifetime. However, rust has no way of knowing
// this so we must convert to 'static here to appease the compiler as it is unable to // this so we must convert to 'static here to appease the compiler as it is unable to
// validate safety. // validate safety.
let fut: Pin<&mut (dyn Future<Output = Vec<T>> + Send)> = fut; let executor: &async_executor::Executor = &*self.executor;
let fut: Pin<&'static mut (dyn Future<Output = Vec<T>> + Send + 'static)> = let executor: &'scope async_executor::Executor = unsafe { mem::transmute(executor) };
unsafe { mem::transmute(fut) }; let local_executor: &'scope async_executor::LocalExecutor =
unsafe { mem::transmute(local_executor) };
let mut scope = Scope {
executor,
local_executor,
spawned: Vec::new(),
};
// The thread that calls scope() will participate in driving tasks in the pool forward f(&mut scope);
// until the tasks that are spawned by this scope() call complete. (If the caller of scope()
// happens to be a thread in this thread pool, and we only have one thread in the pool, then if scope.spawned.is_empty() {
// simply calling future::block_on(spawned) would deadlock.) Vec::default()
let mut spawned = self.executor.spawn(fut); } else if scope.spawned.len() == 1 {
loop { vec![future::block_on(&mut scope.spawned[0])]
if let Some(result) = future::block_on(future::poll_once(&mut spawned)) { } else {
break result; let fut = async move {
let mut results = Vec::with_capacity(scope.spawned.len());
for task in scope.spawned {
results.push(task.await);
}
results
};
// Pin the futures on the stack.
pin!(fut);
// SAFETY: This function blocks until all futures complete, so we do not read/write the
// data from futures outside of the 'scope lifetime. However, rust has no way of knowing
// this so we must convert to 'static here to appease the compiler as it is unable to
// validate safety.
let fut: Pin<&mut (dyn Future<Output = Vec<T>>)> = fut;
let fut: Pin<&'static mut (dyn Future<Output = Vec<T>> + 'static)> =
unsafe { mem::transmute(fut) };
// The thread that calls scope() will participate in driving tasks in the pool forward
// until the tasks that are spawned by this scope() call complete. (If the caller of scope()
// happens to be a thread in this thread pool, and we only have one thread in the pool, then
// simply calling future::block_on(spawned) would deadlock.)
let mut spawned = local_executor.spawn(fut);
loop {
if let Some(result) = future::block_on(future::poll_once(&mut spawned)) {
break result;
};
self.executor.try_tick();
local_executor.try_tick();
} }
self.executor.try_tick();
} }
} })
} }
/// Spawns a static future onto the thread pool. The returned Task is a future. It can also be /// Spawns a static future onto the thread pool. The returned Task is a future. It can also be
@ -225,6 +234,13 @@ impl TaskPool {
{ {
Task::new(self.executor.spawn(future)) Task::new(self.executor.spawn(future))
} }
pub fn spawn_local<T>(&self, future: impl Future<Output = T> + 'static) -> Task<T>
where
T: 'static,
{
Task::new(TaskPool::LOCAL_EXECUTOR.with(|executor| executor.spawn(future)))
}
} }
impl Default for TaskPool { impl Default for TaskPool {
@ -236,6 +252,7 @@ impl Default for TaskPool {
#[derive(Debug)] #[derive(Debug)]
pub struct Scope<'scope, T> { pub struct Scope<'scope, T> {
executor: &'scope async_executor::Executor<'scope>, executor: &'scope async_executor::Executor<'scope>,
local_executor: &'scope async_executor::LocalExecutor<'scope>,
spawned: Vec<async_executor::Task<T>>, spawned: Vec<async_executor::Task<T>>,
} }
@ -244,12 +261,20 @@ impl<'scope, T: Send + 'scope> Scope<'scope, T> {
let task = self.executor.spawn(f); let task = self.executor.spawn(f);
self.spawned.push(task); self.spawned.push(task);
} }
pub fn spawn_local<Fut: Future<Output = T> + 'scope>(&mut self, f: Fut) {
let task = self.local_executor.spawn(f);
self.spawned.push(task);
}
} }
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use std::sync::atomic::{AtomicI32, Ordering}; use std::sync::{
atomic::{AtomicBool, AtomicI32, Ordering},
Barrier,
};
#[test] #[test]
pub fn test_spawn() { pub fn test_spawn() {
@ -281,4 +306,85 @@ mod tests {
assert_eq!(outputs.len(), 100); assert_eq!(outputs.len(), 100);
assert_eq!(count.load(Ordering::Relaxed), 100); assert_eq!(count.load(Ordering::Relaxed), 100);
} }
#[test]
pub fn test_mixed_spawn_local_and_spawn() {
let pool = TaskPool::new();
let foo = Box::new(42);
let foo = &*foo;
let local_count = Arc::new(AtomicI32::new(0));
let non_local_count = Arc::new(AtomicI32::new(0));
let outputs = pool.scope(|scope| {
for i in 0..100 {
if i % 2 == 0 {
let count_clone = non_local_count.clone();
scope.spawn(async move {
if *foo != 42 {
panic!("not 42!?!?")
} else {
count_clone.fetch_add(1, Ordering::Relaxed);
*foo
}
});
} else {
let count_clone = local_count.clone();
scope.spawn_local(async move {
if *foo != 42 {
panic!("not 42!?!?")
} else {
count_clone.fetch_add(1, Ordering::Relaxed);
*foo
}
});
}
}
});
for output in &outputs {
assert_eq!(*output, 42);
}
assert_eq!(outputs.len(), 100);
assert_eq!(local_count.load(Ordering::Relaxed), 50);
assert_eq!(non_local_count.load(Ordering::Relaxed), 50);
}
#[test]
pub fn test_thread_locality() {
let pool = Arc::new(TaskPool::new());
let count = Arc::new(AtomicI32::new(0));
let barrier = Arc::new(Barrier::new(101));
let thread_check_failed = Arc::new(AtomicBool::new(false));
for _ in 0..100 {
let inner_barrier = barrier.clone();
let count_clone = count.clone();
let inner_pool = pool.clone();
let inner_thread_check_failed = thread_check_failed.clone();
std::thread::spawn(move || {
inner_pool.scope(|scope| {
let inner_count_clone = count_clone.clone();
scope.spawn(async move {
inner_count_clone.fetch_add(1, Ordering::Release);
});
let spawner = std::thread::current().id();
let inner_count_clone = count_clone.clone();
scope.spawn_local(async move {
inner_count_clone.fetch_add(1, Ordering::Release);
if std::thread::current().id() != spawner {
// NOTE: This check is using an atomic rather than simply panicing the thread to avoid deadlocking the barrier on failure
inner_thread_check_failed.store(true, Ordering::Release);
}
});
});
inner_barrier.wait();
});
}
barrier.wait();
assert!(!thread_check_failed.load(Ordering::Acquire));
assert_eq!(count.load(Ordering::Acquire), 200);
}
} }