Add thread create/destroy callbacks to TaskPool (#6561)
# Objective Fix #1991. Allow users to have a bit more control over the creation and finalization of the threads in `TaskPool`. ## Solution Add new methods to `TaskPoolBuilder` that expose callbacks that are called to initialize and finalize each thread in the `TaskPool`. Unlike the proposed solution in #1991, the callback is argument-less. If an an identifier is needed, `std:🧵:current` should provide that information easily. Added a unit test to ensure that they're being called correctly.
This commit is contained in:
parent
e8b28547bf
commit
53a5bbe2d5
@ -12,8 +12,18 @@ use futures_lite::{future, pin, FutureExt};
|
|||||||
|
|
||||||
use crate::Task;
|
use crate::Task;
|
||||||
|
|
||||||
|
struct CallOnDrop(Option<Arc<dyn Fn() + Send + Sync + 'static>>);
|
||||||
|
|
||||||
|
impl Drop for CallOnDrop {
|
||||||
|
fn drop(&mut self) {
|
||||||
|
if let Some(call) = self.0.as_ref() {
|
||||||
|
call();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Used to create a [`TaskPool`]
|
/// Used to create a [`TaskPool`]
|
||||||
#[derive(Debug, Default, Clone)]
|
#[derive(Default)]
|
||||||
#[must_use]
|
#[must_use]
|
||||||
pub struct TaskPoolBuilder {
|
pub struct TaskPoolBuilder {
|
||||||
/// If set, we'll set up the thread pool to use at most `num_threads` threads.
|
/// If set, we'll set up the thread pool to use at most `num_threads` threads.
|
||||||
@ -24,6 +34,9 @@ pub struct TaskPoolBuilder {
|
|||||||
/// Allows customizing the name of the threads - helpful for debugging. If set, threads will
|
/// Allows customizing the name of the threads - helpful for debugging. If set, threads will
|
||||||
/// be named <thread_name> (<thread_index>), i.e. "MyThreadPool (2)"
|
/// be named <thread_name> (<thread_index>), i.e. "MyThreadPool (2)"
|
||||||
thread_name: Option<String>,
|
thread_name: Option<String>,
|
||||||
|
|
||||||
|
on_thread_spawn: Option<Arc<dyn Fn() + Send + Sync + 'static>>,
|
||||||
|
on_thread_destroy: Option<Arc<dyn Fn() + Send + Sync + 'static>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl TaskPoolBuilder {
|
impl TaskPoolBuilder {
|
||||||
@ -52,13 +65,27 @@ impl TaskPoolBuilder {
|
|||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Sets a callback that is invoked once for every created thread as it starts.
|
||||||
|
///
|
||||||
|
/// This is called on the thread itself and has access to all thread-local storage.
|
||||||
|
/// This will block running async tasks on the thread until the callback completes.
|
||||||
|
pub fn on_thread_spawn(mut self, f: impl Fn() + Send + Sync + 'static) -> Self {
|
||||||
|
self.on_thread_spawn = Some(Arc::new(f));
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Sets a callback that is invoked once for every created thread as it terminates.
|
||||||
|
///
|
||||||
|
/// This is called on the thread itself and has access to all thread-local storage.
|
||||||
|
/// This will block thread termination until the callback completes.
|
||||||
|
pub fn on_thread_destroy(mut self, f: impl Fn() + Send + Sync + 'static) -> Self {
|
||||||
|
self.on_thread_destroy = Some(Arc::new(f));
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
/// Creates a new [`TaskPool`] based on the current options.
|
/// Creates a new [`TaskPool`] based on the current options.
|
||||||
pub fn build(self) -> TaskPool {
|
pub fn build(self) -> TaskPool {
|
||||||
TaskPool::new_internal(
|
TaskPool::new_internal(self)
|
||||||
self.num_threads,
|
|
||||||
self.stack_size,
|
|
||||||
self.thread_name.as_deref(),
|
|
||||||
)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -88,36 +115,42 @@ impl TaskPool {
|
|||||||
TaskPoolBuilder::new().build()
|
TaskPoolBuilder::new().build()
|
||||||
}
|
}
|
||||||
|
|
||||||
fn new_internal(
|
fn new_internal(builder: TaskPoolBuilder) -> Self {
|
||||||
num_threads: Option<usize>,
|
|
||||||
stack_size: Option<usize>,
|
|
||||||
thread_name: Option<&str>,
|
|
||||||
) -> Self {
|
|
||||||
let (shutdown_tx, shutdown_rx) = async_channel::unbounded::<()>();
|
let (shutdown_tx, shutdown_rx) = async_channel::unbounded::<()>();
|
||||||
|
|
||||||
let executor = Arc::new(async_executor::Executor::new());
|
let executor = Arc::new(async_executor::Executor::new());
|
||||||
|
|
||||||
let num_threads = num_threads.unwrap_or_else(crate::available_parallelism);
|
let num_threads = builder
|
||||||
|
.num_threads
|
||||||
|
.unwrap_or_else(crate::available_parallelism);
|
||||||
|
|
||||||
let threads = (0..num_threads)
|
let threads = (0..num_threads)
|
||||||
.map(|i| {
|
.map(|i| {
|
||||||
let ex = Arc::clone(&executor);
|
let ex = Arc::clone(&executor);
|
||||||
let shutdown_rx = shutdown_rx.clone();
|
let shutdown_rx = shutdown_rx.clone();
|
||||||
|
|
||||||
let thread_name = if let Some(thread_name) = thread_name {
|
let thread_name = if let Some(thread_name) = builder.thread_name.as_deref() {
|
||||||
format!("{thread_name} ({i})")
|
format!("{thread_name} ({i})")
|
||||||
} else {
|
} else {
|
||||||
format!("TaskPool ({i})")
|
format!("TaskPool ({i})")
|
||||||
};
|
};
|
||||||
let mut thread_builder = thread::Builder::new().name(thread_name);
|
let mut thread_builder = thread::Builder::new().name(thread_name);
|
||||||
|
|
||||||
if let Some(stack_size) = stack_size {
|
if let Some(stack_size) = builder.stack_size {
|
||||||
thread_builder = thread_builder.stack_size(stack_size);
|
thread_builder = thread_builder.stack_size(stack_size);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
let on_thread_spawn = builder.on_thread_spawn.clone();
|
||||||
|
let on_thread_destroy = builder.on_thread_destroy.clone();
|
||||||
|
|
||||||
thread_builder
|
thread_builder
|
||||||
.spawn(move || {
|
.spawn(move || {
|
||||||
TaskPool::LOCAL_EXECUTOR.with(|local_executor| {
|
TaskPool::LOCAL_EXECUTOR.with(|local_executor| {
|
||||||
|
if let Some(on_thread_spawn) = on_thread_spawn {
|
||||||
|
on_thread_spawn();
|
||||||
|
drop(on_thread_spawn);
|
||||||
|
}
|
||||||
|
let _destructor = CallOnDrop(on_thread_destroy);
|
||||||
loop {
|
loop {
|
||||||
let res = std::panic::catch_unwind(|| {
|
let res = std::panic::catch_unwind(|| {
|
||||||
let tick_forever = async move {
|
let tick_forever = async move {
|
||||||
@ -452,6 +485,57 @@ mod tests {
|
|||||||
assert_eq!(count.load(Ordering::Relaxed), 100);
|
assert_eq!(count.load(Ordering::Relaxed), 100);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_thread_callbacks() {
|
||||||
|
let counter = Arc::new(AtomicI32::new(0));
|
||||||
|
let start_counter = counter.clone();
|
||||||
|
{
|
||||||
|
let barrier = Arc::new(Barrier::new(11));
|
||||||
|
let last_barrier = barrier.clone();
|
||||||
|
// Build and immediately drop to terminate
|
||||||
|
let _pool = TaskPoolBuilder::new()
|
||||||
|
.num_threads(10)
|
||||||
|
.on_thread_spawn(move || {
|
||||||
|
start_counter.fetch_add(1, Ordering::Relaxed);
|
||||||
|
barrier.clone().wait();
|
||||||
|
})
|
||||||
|
.build();
|
||||||
|
last_barrier.wait();
|
||||||
|
assert_eq!(10, counter.load(Ordering::Relaxed));
|
||||||
|
}
|
||||||
|
assert_eq!(10, counter.load(Ordering::Relaxed));
|
||||||
|
let end_counter = counter.clone();
|
||||||
|
{
|
||||||
|
let _pool = TaskPoolBuilder::new()
|
||||||
|
.num_threads(20)
|
||||||
|
.on_thread_destroy(move || {
|
||||||
|
end_counter.fetch_sub(1, Ordering::Relaxed);
|
||||||
|
})
|
||||||
|
.build();
|
||||||
|
assert_eq!(10, counter.load(Ordering::Relaxed));
|
||||||
|
}
|
||||||
|
assert_eq!(-10, counter.load(Ordering::Relaxed));
|
||||||
|
let start_counter = counter.clone();
|
||||||
|
let end_counter = counter.clone();
|
||||||
|
{
|
||||||
|
let barrier = Arc::new(Barrier::new(6));
|
||||||
|
let last_barrier = barrier.clone();
|
||||||
|
let _pool = TaskPoolBuilder::new()
|
||||||
|
.num_threads(5)
|
||||||
|
.on_thread_spawn(move || {
|
||||||
|
start_counter.fetch_add(1, Ordering::Relaxed);
|
||||||
|
barrier.wait();
|
||||||
|
})
|
||||||
|
.on_thread_destroy(move || {
|
||||||
|
end_counter.fetch_sub(1, Ordering::Relaxed);
|
||||||
|
})
|
||||||
|
.build();
|
||||||
|
last_barrier.wait();
|
||||||
|
assert_eq!(-5, counter.load(Ordering::Relaxed));
|
||||||
|
}
|
||||||
|
assert_eq!(-10, counter.load(Ordering::Relaxed));
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_mixed_spawn_on_scope_and_spawn() {
|
fn test_mixed_spawn_on_scope_and_spawn() {
|
||||||
let pool = TaskPool::new();
|
let pool = TaskPool::new();
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user