Add thread create/destroy callbacks to TaskPool (#6561)

# Objective
Fix #1991. Allow users to have a bit more control over the creation and finalization of the threads in `TaskPool`.

## Solution
Add new methods to `TaskPoolBuilder` that expose callbacks that are called to initialize and finalize each thread in the `TaskPool`.

Unlike the proposed solution in #1991, the callback is argument-less. If an an identifier is needed, `std:🧵:current` should provide that information easily.

Added a unit test to ensure that they're being called correctly.
This commit is contained in:
James Liu 2022-12-20 16:17:02 +00:00
parent e8b28547bf
commit 53a5bbe2d5

View File

@ -12,8 +12,18 @@ use futures_lite::{future, pin, FutureExt};
use crate::Task; use crate::Task;
struct CallOnDrop(Option<Arc<dyn Fn() + Send + Sync + 'static>>);
impl Drop for CallOnDrop {
fn drop(&mut self) {
if let Some(call) = self.0.as_ref() {
call();
}
}
}
/// Used to create a [`TaskPool`] /// Used to create a [`TaskPool`]
#[derive(Debug, Default, Clone)] #[derive(Default)]
#[must_use] #[must_use]
pub struct TaskPoolBuilder { pub struct TaskPoolBuilder {
/// If set, we'll set up the thread pool to use at most `num_threads` threads. /// If set, we'll set up the thread pool to use at most `num_threads` threads.
@ -24,6 +34,9 @@ pub struct TaskPoolBuilder {
/// Allows customizing the name of the threads - helpful for debugging. If set, threads will /// Allows customizing the name of the threads - helpful for debugging. If set, threads will
/// be named <thread_name> (<thread_index>), i.e. "MyThreadPool (2)" /// be named <thread_name> (<thread_index>), i.e. "MyThreadPool (2)"
thread_name: Option<String>, thread_name: Option<String>,
on_thread_spawn: Option<Arc<dyn Fn() + Send + Sync + 'static>>,
on_thread_destroy: Option<Arc<dyn Fn() + Send + Sync + 'static>>,
} }
impl TaskPoolBuilder { impl TaskPoolBuilder {
@ -52,13 +65,27 @@ impl TaskPoolBuilder {
self self
} }
/// Sets a callback that is invoked once for every created thread as it starts.
///
/// This is called on the thread itself and has access to all thread-local storage.
/// This will block running async tasks on the thread until the callback completes.
pub fn on_thread_spawn(mut self, f: impl Fn() + Send + Sync + 'static) -> Self {
self.on_thread_spawn = Some(Arc::new(f));
self
}
/// Sets a callback that is invoked once for every created thread as it terminates.
///
/// This is called on the thread itself and has access to all thread-local storage.
/// This will block thread termination until the callback completes.
pub fn on_thread_destroy(mut self, f: impl Fn() + Send + Sync + 'static) -> Self {
self.on_thread_destroy = Some(Arc::new(f));
self
}
/// Creates a new [`TaskPool`] based on the current options. /// Creates a new [`TaskPool`] based on the current options.
pub fn build(self) -> TaskPool { pub fn build(self) -> TaskPool {
TaskPool::new_internal( TaskPool::new_internal(self)
self.num_threads,
self.stack_size,
self.thread_name.as_deref(),
)
} }
} }
@ -88,36 +115,42 @@ impl TaskPool {
TaskPoolBuilder::new().build() TaskPoolBuilder::new().build()
} }
fn new_internal( fn new_internal(builder: TaskPoolBuilder) -> Self {
num_threads: Option<usize>,
stack_size: Option<usize>,
thread_name: Option<&str>,
) -> Self {
let (shutdown_tx, shutdown_rx) = async_channel::unbounded::<()>(); let (shutdown_tx, shutdown_rx) = async_channel::unbounded::<()>();
let executor = Arc::new(async_executor::Executor::new()); let executor = Arc::new(async_executor::Executor::new());
let num_threads = num_threads.unwrap_or_else(crate::available_parallelism); let num_threads = builder
.num_threads
.unwrap_or_else(crate::available_parallelism);
let threads = (0..num_threads) let threads = (0..num_threads)
.map(|i| { .map(|i| {
let ex = Arc::clone(&executor); let ex = Arc::clone(&executor);
let shutdown_rx = shutdown_rx.clone(); let shutdown_rx = shutdown_rx.clone();
let thread_name = if let Some(thread_name) = thread_name { let thread_name = if let Some(thread_name) = builder.thread_name.as_deref() {
format!("{thread_name} ({i})") format!("{thread_name} ({i})")
} else { } else {
format!("TaskPool ({i})") format!("TaskPool ({i})")
}; };
let mut thread_builder = thread::Builder::new().name(thread_name); let mut thread_builder = thread::Builder::new().name(thread_name);
if let Some(stack_size) = stack_size { if let Some(stack_size) = builder.stack_size {
thread_builder = thread_builder.stack_size(stack_size); thread_builder = thread_builder.stack_size(stack_size);
} }
let on_thread_spawn = builder.on_thread_spawn.clone();
let on_thread_destroy = builder.on_thread_destroy.clone();
thread_builder thread_builder
.spawn(move || { .spawn(move || {
TaskPool::LOCAL_EXECUTOR.with(|local_executor| { TaskPool::LOCAL_EXECUTOR.with(|local_executor| {
if let Some(on_thread_spawn) = on_thread_spawn {
on_thread_spawn();
drop(on_thread_spawn);
}
let _destructor = CallOnDrop(on_thread_destroy);
loop { loop {
let res = std::panic::catch_unwind(|| { let res = std::panic::catch_unwind(|| {
let tick_forever = async move { let tick_forever = async move {
@ -452,6 +485,57 @@ mod tests {
assert_eq!(count.load(Ordering::Relaxed), 100); assert_eq!(count.load(Ordering::Relaxed), 100);
} }
#[test]
fn test_thread_callbacks() {
let counter = Arc::new(AtomicI32::new(0));
let start_counter = counter.clone();
{
let barrier = Arc::new(Barrier::new(11));
let last_barrier = barrier.clone();
// Build and immediately drop to terminate
let _pool = TaskPoolBuilder::new()
.num_threads(10)
.on_thread_spawn(move || {
start_counter.fetch_add(1, Ordering::Relaxed);
barrier.clone().wait();
})
.build();
last_barrier.wait();
assert_eq!(10, counter.load(Ordering::Relaxed));
}
assert_eq!(10, counter.load(Ordering::Relaxed));
let end_counter = counter.clone();
{
let _pool = TaskPoolBuilder::new()
.num_threads(20)
.on_thread_destroy(move || {
end_counter.fetch_sub(1, Ordering::Relaxed);
})
.build();
assert_eq!(10, counter.load(Ordering::Relaxed));
}
assert_eq!(-10, counter.load(Ordering::Relaxed));
let start_counter = counter.clone();
let end_counter = counter.clone();
{
let barrier = Arc::new(Barrier::new(6));
let last_barrier = barrier.clone();
let _pool = TaskPoolBuilder::new()
.num_threads(5)
.on_thread_spawn(move || {
start_counter.fetch_add(1, Ordering::Relaxed);
barrier.wait();
})
.on_thread_destroy(move || {
end_counter.fetch_sub(1, Ordering::Relaxed);
})
.build();
last_barrier.wait();
assert_eq!(-5, counter.load(Ordering::Relaxed));
}
assert_eq!(-10, counter.load(Ordering::Relaxed));
}
#[test] #[test]
fn test_mixed_spawn_on_scope_and_spawn() { fn test_mixed_spawn_on_scope_and_spawn() {
let pool = TaskPool::new(); let pool = TaskPool::new();