Add thread create/destroy callbacks to TaskPool (#6561)
# Objective Fix #1991. Allow users to have a bit more control over the creation and finalization of the threads in `TaskPool`. ## Solution Add new methods to `TaskPoolBuilder` that expose callbacks that are called to initialize and finalize each thread in the `TaskPool`. Unlike the proposed solution in #1991, the callback is argument-less. If an an identifier is needed, `std:🧵:current` should provide that information easily. Added a unit test to ensure that they're being called correctly.
This commit is contained in:
		
							parent
							
								
									e8b28547bf
								
							
						
					
					
						commit
						53a5bbe2d5
					
				@ -12,8 +12,18 @@ use futures_lite::{future, pin, FutureExt};
 | 
			
		||||
 | 
			
		||||
use crate::Task;
 | 
			
		||||
 | 
			
		||||
struct CallOnDrop(Option<Arc<dyn Fn() + Send + Sync + 'static>>);
 | 
			
		||||
 | 
			
		||||
impl Drop for CallOnDrop {
 | 
			
		||||
    fn drop(&mut self) {
 | 
			
		||||
        if let Some(call) = self.0.as_ref() {
 | 
			
		||||
            call();
 | 
			
		||||
        }
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/// Used to create a [`TaskPool`]
 | 
			
		||||
#[derive(Debug, Default, Clone)]
 | 
			
		||||
#[derive(Default)]
 | 
			
		||||
#[must_use]
 | 
			
		||||
pub struct TaskPoolBuilder {
 | 
			
		||||
    /// If set, we'll set up the thread pool to use at most `num_threads` threads.
 | 
			
		||||
@ -24,6 +34,9 @@ pub struct TaskPoolBuilder {
 | 
			
		||||
    /// Allows customizing the name of the threads - helpful for debugging. If set, threads will
 | 
			
		||||
    /// be named <thread_name> (<thread_index>), i.e. "MyThreadPool (2)"
 | 
			
		||||
    thread_name: Option<String>,
 | 
			
		||||
 | 
			
		||||
    on_thread_spawn: Option<Arc<dyn Fn() + Send + Sync + 'static>>,
 | 
			
		||||
    on_thread_destroy: Option<Arc<dyn Fn() + Send + Sync + 'static>>,
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
impl TaskPoolBuilder {
 | 
			
		||||
@ -52,13 +65,27 @@ impl TaskPoolBuilder {
 | 
			
		||||
        self
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    /// Sets a callback that is invoked once for every created thread as it starts.
 | 
			
		||||
    ///
 | 
			
		||||
    /// This is called on the thread itself and has access to all thread-local storage.
 | 
			
		||||
    /// This will block running async tasks on the thread until the callback completes.
 | 
			
		||||
    pub fn on_thread_spawn(mut self, f: impl Fn() + Send + Sync + 'static) -> Self {
 | 
			
		||||
        self.on_thread_spawn = Some(Arc::new(f));
 | 
			
		||||
        self
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    /// Sets a callback that is invoked once for every created thread as it terminates.
 | 
			
		||||
    ///
 | 
			
		||||
    /// This is called on the thread itself and has access to all thread-local storage.
 | 
			
		||||
    /// This will block thread termination until the callback completes.
 | 
			
		||||
    pub fn on_thread_destroy(mut self, f: impl Fn() + Send + Sync + 'static) -> Self {
 | 
			
		||||
        self.on_thread_destroy = Some(Arc::new(f));
 | 
			
		||||
        self
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    /// Creates a new [`TaskPool`] based on the current options.
 | 
			
		||||
    pub fn build(self) -> TaskPool {
 | 
			
		||||
        TaskPool::new_internal(
 | 
			
		||||
            self.num_threads,
 | 
			
		||||
            self.stack_size,
 | 
			
		||||
            self.thread_name.as_deref(),
 | 
			
		||||
        )
 | 
			
		||||
        TaskPool::new_internal(self)
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@ -88,36 +115,42 @@ impl TaskPool {
 | 
			
		||||
        TaskPoolBuilder::new().build()
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    fn new_internal(
 | 
			
		||||
        num_threads: Option<usize>,
 | 
			
		||||
        stack_size: Option<usize>,
 | 
			
		||||
        thread_name: Option<&str>,
 | 
			
		||||
    ) -> Self {
 | 
			
		||||
    fn new_internal(builder: TaskPoolBuilder) -> Self {
 | 
			
		||||
        let (shutdown_tx, shutdown_rx) = async_channel::unbounded::<()>();
 | 
			
		||||
 | 
			
		||||
        let executor = Arc::new(async_executor::Executor::new());
 | 
			
		||||
 | 
			
		||||
        let num_threads = num_threads.unwrap_or_else(crate::available_parallelism);
 | 
			
		||||
        let num_threads = builder
 | 
			
		||||
            .num_threads
 | 
			
		||||
            .unwrap_or_else(crate::available_parallelism);
 | 
			
		||||
 | 
			
		||||
        let threads = (0..num_threads)
 | 
			
		||||
            .map(|i| {
 | 
			
		||||
                let ex = Arc::clone(&executor);
 | 
			
		||||
                let shutdown_rx = shutdown_rx.clone();
 | 
			
		||||
 | 
			
		||||
                let thread_name = if let Some(thread_name) = thread_name {
 | 
			
		||||
                let thread_name = if let Some(thread_name) = builder.thread_name.as_deref() {
 | 
			
		||||
                    format!("{thread_name} ({i})")
 | 
			
		||||
                } else {
 | 
			
		||||
                    format!("TaskPool ({i})")
 | 
			
		||||
                };
 | 
			
		||||
                let mut thread_builder = thread::Builder::new().name(thread_name);
 | 
			
		||||
 | 
			
		||||
                if let Some(stack_size) = stack_size {
 | 
			
		||||
                if let Some(stack_size) = builder.stack_size {
 | 
			
		||||
                    thread_builder = thread_builder.stack_size(stack_size);
 | 
			
		||||
                }
 | 
			
		||||
 | 
			
		||||
                let on_thread_spawn = builder.on_thread_spawn.clone();
 | 
			
		||||
                let on_thread_destroy = builder.on_thread_destroy.clone();
 | 
			
		||||
 | 
			
		||||
                thread_builder
 | 
			
		||||
                    .spawn(move || {
 | 
			
		||||
                        TaskPool::LOCAL_EXECUTOR.with(|local_executor| {
 | 
			
		||||
                            if let Some(on_thread_spawn) = on_thread_spawn {
 | 
			
		||||
                                on_thread_spawn();
 | 
			
		||||
                                drop(on_thread_spawn);
 | 
			
		||||
                            }
 | 
			
		||||
                            let _destructor = CallOnDrop(on_thread_destroy);
 | 
			
		||||
                            loop {
 | 
			
		||||
                                let res = std::panic::catch_unwind(|| {
 | 
			
		||||
                                    let tick_forever = async move {
 | 
			
		||||
@ -452,6 +485,57 @@ mod tests {
 | 
			
		||||
        assert_eq!(count.load(Ordering::Relaxed), 100);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    #[test]
 | 
			
		||||
    fn test_thread_callbacks() {
 | 
			
		||||
        let counter = Arc::new(AtomicI32::new(0));
 | 
			
		||||
        let start_counter = counter.clone();
 | 
			
		||||
        {
 | 
			
		||||
            let barrier = Arc::new(Barrier::new(11));
 | 
			
		||||
            let last_barrier = barrier.clone();
 | 
			
		||||
            // Build and immediately drop to terminate
 | 
			
		||||
            let _pool = TaskPoolBuilder::new()
 | 
			
		||||
                .num_threads(10)
 | 
			
		||||
                .on_thread_spawn(move || {
 | 
			
		||||
                    start_counter.fetch_add(1, Ordering::Relaxed);
 | 
			
		||||
                    barrier.clone().wait();
 | 
			
		||||
                })
 | 
			
		||||
                .build();
 | 
			
		||||
            last_barrier.wait();
 | 
			
		||||
            assert_eq!(10, counter.load(Ordering::Relaxed));
 | 
			
		||||
        }
 | 
			
		||||
        assert_eq!(10, counter.load(Ordering::Relaxed));
 | 
			
		||||
        let end_counter = counter.clone();
 | 
			
		||||
        {
 | 
			
		||||
            let _pool = TaskPoolBuilder::new()
 | 
			
		||||
                .num_threads(20)
 | 
			
		||||
                .on_thread_destroy(move || {
 | 
			
		||||
                    end_counter.fetch_sub(1, Ordering::Relaxed);
 | 
			
		||||
                })
 | 
			
		||||
                .build();
 | 
			
		||||
            assert_eq!(10, counter.load(Ordering::Relaxed));
 | 
			
		||||
        }
 | 
			
		||||
        assert_eq!(-10, counter.load(Ordering::Relaxed));
 | 
			
		||||
        let start_counter = counter.clone();
 | 
			
		||||
        let end_counter = counter.clone();
 | 
			
		||||
        {
 | 
			
		||||
            let barrier = Arc::new(Barrier::new(6));
 | 
			
		||||
            let last_barrier = barrier.clone();
 | 
			
		||||
            let _pool = TaskPoolBuilder::new()
 | 
			
		||||
                .num_threads(5)
 | 
			
		||||
                .on_thread_spawn(move || {
 | 
			
		||||
                    start_counter.fetch_add(1, Ordering::Relaxed);
 | 
			
		||||
                    barrier.wait();
 | 
			
		||||
                })
 | 
			
		||||
                .on_thread_destroy(move || {
 | 
			
		||||
                    end_counter.fetch_sub(1, Ordering::Relaxed);
 | 
			
		||||
                })
 | 
			
		||||
                .build();
 | 
			
		||||
            last_barrier.wait();
 | 
			
		||||
            assert_eq!(-5, counter.load(Ordering::Relaxed));
 | 
			
		||||
        }
 | 
			
		||||
        assert_eq!(-10, counter.load(Ordering::Relaxed));
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    #[test]
 | 
			
		||||
    fn test_mixed_spawn_on_scope_and_spawn() {
 | 
			
		||||
        let pool = TaskPool::new();
 | 
			
		||||
 | 
			
		||||
		Loading…
	
		Reference in New Issue
	
	Block a user