Add thread create/destroy callbacks to TaskPool (#6561)
# Objective Fix #1991. Allow users to have a bit more control over the creation and finalization of the threads in `TaskPool`. ## Solution Add new methods to `TaskPoolBuilder` that expose callbacks that are called to initialize and finalize each thread in the `TaskPool`. Unlike the proposed solution in #1991, the callback is argument-less. If an an identifier is needed, `std:🧵:current` should provide that information easily. Added a unit test to ensure that they're being called correctly.
This commit is contained in:
parent
e8b28547bf
commit
53a5bbe2d5
@ -12,8 +12,18 @@ use futures_lite::{future, pin, FutureExt};
|
||||
|
||||
use crate::Task;
|
||||
|
||||
struct CallOnDrop(Option<Arc<dyn Fn() + Send + Sync + 'static>>);
|
||||
|
||||
impl Drop for CallOnDrop {
|
||||
fn drop(&mut self) {
|
||||
if let Some(call) = self.0.as_ref() {
|
||||
call();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Used to create a [`TaskPool`]
|
||||
#[derive(Debug, Default, Clone)]
|
||||
#[derive(Default)]
|
||||
#[must_use]
|
||||
pub struct TaskPoolBuilder {
|
||||
/// If set, we'll set up the thread pool to use at most `num_threads` threads.
|
||||
@ -24,6 +34,9 @@ pub struct TaskPoolBuilder {
|
||||
/// Allows customizing the name of the threads - helpful for debugging. If set, threads will
|
||||
/// be named <thread_name> (<thread_index>), i.e. "MyThreadPool (2)"
|
||||
thread_name: Option<String>,
|
||||
|
||||
on_thread_spawn: Option<Arc<dyn Fn() + Send + Sync + 'static>>,
|
||||
on_thread_destroy: Option<Arc<dyn Fn() + Send + Sync + 'static>>,
|
||||
}
|
||||
|
||||
impl TaskPoolBuilder {
|
||||
@ -52,13 +65,27 @@ impl TaskPoolBuilder {
|
||||
self
|
||||
}
|
||||
|
||||
/// Sets a callback that is invoked once for every created thread as it starts.
|
||||
///
|
||||
/// This is called on the thread itself and has access to all thread-local storage.
|
||||
/// This will block running async tasks on the thread until the callback completes.
|
||||
pub fn on_thread_spawn(mut self, f: impl Fn() + Send + Sync + 'static) -> Self {
|
||||
self.on_thread_spawn = Some(Arc::new(f));
|
||||
self
|
||||
}
|
||||
|
||||
/// Sets a callback that is invoked once for every created thread as it terminates.
|
||||
///
|
||||
/// This is called on the thread itself and has access to all thread-local storage.
|
||||
/// This will block thread termination until the callback completes.
|
||||
pub fn on_thread_destroy(mut self, f: impl Fn() + Send + Sync + 'static) -> Self {
|
||||
self.on_thread_destroy = Some(Arc::new(f));
|
||||
self
|
||||
}
|
||||
|
||||
/// Creates a new [`TaskPool`] based on the current options.
|
||||
pub fn build(self) -> TaskPool {
|
||||
TaskPool::new_internal(
|
||||
self.num_threads,
|
||||
self.stack_size,
|
||||
self.thread_name.as_deref(),
|
||||
)
|
||||
TaskPool::new_internal(self)
|
||||
}
|
||||
}
|
||||
|
||||
@ -88,36 +115,42 @@ impl TaskPool {
|
||||
TaskPoolBuilder::new().build()
|
||||
}
|
||||
|
||||
fn new_internal(
|
||||
num_threads: Option<usize>,
|
||||
stack_size: Option<usize>,
|
||||
thread_name: Option<&str>,
|
||||
) -> Self {
|
||||
fn new_internal(builder: TaskPoolBuilder) -> Self {
|
||||
let (shutdown_tx, shutdown_rx) = async_channel::unbounded::<()>();
|
||||
|
||||
let executor = Arc::new(async_executor::Executor::new());
|
||||
|
||||
let num_threads = num_threads.unwrap_or_else(crate::available_parallelism);
|
||||
let num_threads = builder
|
||||
.num_threads
|
||||
.unwrap_or_else(crate::available_parallelism);
|
||||
|
||||
let threads = (0..num_threads)
|
||||
.map(|i| {
|
||||
let ex = Arc::clone(&executor);
|
||||
let shutdown_rx = shutdown_rx.clone();
|
||||
|
||||
let thread_name = if let Some(thread_name) = thread_name {
|
||||
let thread_name = if let Some(thread_name) = builder.thread_name.as_deref() {
|
||||
format!("{thread_name} ({i})")
|
||||
} else {
|
||||
format!("TaskPool ({i})")
|
||||
};
|
||||
let mut thread_builder = thread::Builder::new().name(thread_name);
|
||||
|
||||
if let Some(stack_size) = stack_size {
|
||||
if let Some(stack_size) = builder.stack_size {
|
||||
thread_builder = thread_builder.stack_size(stack_size);
|
||||
}
|
||||
|
||||
let on_thread_spawn = builder.on_thread_spawn.clone();
|
||||
let on_thread_destroy = builder.on_thread_destroy.clone();
|
||||
|
||||
thread_builder
|
||||
.spawn(move || {
|
||||
TaskPool::LOCAL_EXECUTOR.with(|local_executor| {
|
||||
if let Some(on_thread_spawn) = on_thread_spawn {
|
||||
on_thread_spawn();
|
||||
drop(on_thread_spawn);
|
||||
}
|
||||
let _destructor = CallOnDrop(on_thread_destroy);
|
||||
loop {
|
||||
let res = std::panic::catch_unwind(|| {
|
||||
let tick_forever = async move {
|
||||
@ -452,6 +485,57 @@ mod tests {
|
||||
assert_eq!(count.load(Ordering::Relaxed), 100);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_thread_callbacks() {
|
||||
let counter = Arc::new(AtomicI32::new(0));
|
||||
let start_counter = counter.clone();
|
||||
{
|
||||
let barrier = Arc::new(Barrier::new(11));
|
||||
let last_barrier = barrier.clone();
|
||||
// Build and immediately drop to terminate
|
||||
let _pool = TaskPoolBuilder::new()
|
||||
.num_threads(10)
|
||||
.on_thread_spawn(move || {
|
||||
start_counter.fetch_add(1, Ordering::Relaxed);
|
||||
barrier.clone().wait();
|
||||
})
|
||||
.build();
|
||||
last_barrier.wait();
|
||||
assert_eq!(10, counter.load(Ordering::Relaxed));
|
||||
}
|
||||
assert_eq!(10, counter.load(Ordering::Relaxed));
|
||||
let end_counter = counter.clone();
|
||||
{
|
||||
let _pool = TaskPoolBuilder::new()
|
||||
.num_threads(20)
|
||||
.on_thread_destroy(move || {
|
||||
end_counter.fetch_sub(1, Ordering::Relaxed);
|
||||
})
|
||||
.build();
|
||||
assert_eq!(10, counter.load(Ordering::Relaxed));
|
||||
}
|
||||
assert_eq!(-10, counter.load(Ordering::Relaxed));
|
||||
let start_counter = counter.clone();
|
||||
let end_counter = counter.clone();
|
||||
{
|
||||
let barrier = Arc::new(Barrier::new(6));
|
||||
let last_barrier = barrier.clone();
|
||||
let _pool = TaskPoolBuilder::new()
|
||||
.num_threads(5)
|
||||
.on_thread_spawn(move || {
|
||||
start_counter.fetch_add(1, Ordering::Relaxed);
|
||||
barrier.wait();
|
||||
})
|
||||
.on_thread_destroy(move || {
|
||||
end_counter.fetch_sub(1, Ordering::Relaxed);
|
||||
})
|
||||
.build();
|
||||
last_barrier.wait();
|
||||
assert_eq!(-5, counter.load(Ordering::Relaxed));
|
||||
}
|
||||
assert_eq!(-10, counter.load(Ordering::Relaxed));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_mixed_spawn_on_scope_and_spawn() {
|
||||
let pool = TaskPool::new();
|
||||
|
||||
Loading…
Reference in New Issue
Block a user