use std::{ collections::HashSet, sync::mpsc::{self, Receiver, SendError, Sender}, thread::{JoinHandle, Thread, ThreadId}, }; use thiserror::Error; use crate::thread::monitor::{Error as ThreadMonitorError, ThreadMonitor}; /// Error type for [`ThreadManager`]. #[derive(Error, Debug)] pub enum Error { /// Unable to receive from the monitor channel. #[error("Unable to receive from the monitor channel.")] RecvNotify, #[error("Unable to initialize internal monitor: {source}")] Monitor { #[from] source: ThreadMonitorError, }, } /// Reason why a call to [`ThreadManager::join_all`] returned. #[derive(Debug)] pub enum FinishReason { /// The monitor thread encountered an error. MonitorError { error: ThreadMonitorError }, /// A thread marked as `triggers_shutdown` finished. ThreadFinished { thread: Thread }, /// One of more threads panicked. ThreadPanic { threads: Vec<Thread> }, /// Unable to join with monnitor thread MonitorJoinError, /// Monitor was unable to send a notification MonitorSendError { error: SendError<Thread> }, } /// An abstraction for building simple multi-threaded applications. /// /// It's goal is to simplify the process of setting up an application composed /// of a fixed number of threads, adding in support for common tasks such as /// detecting panics and waiting for some threads to finish. /// /// It achieves this by focusing on a narrow use-case: /// /// - The application has a fixed number of threads. /// - The threads are all started at the same time during initialization. /// - A panic in any thread is expected to trigger a shut down of the /// application. /// - Some threads finishing should trigger a shut down of the application. /// /// For example, a web application has a main server thread and a metrics/stats /// server thread. /// /// [`ThreadManager`] is a higher level abstraction built on top of /// [`ThreadMonitor`], which is used to detect and handle panics in other /// threads. [`ThreadMonitor`] can be used directly if [`ThreadManager`] does /// not meet a use-case or low level control is needed. /// /// Limitations: /// /// - Only threads spawned directly through the manager interface can be /// monitored. The manager doesn't monitor other threads spawned separately, /// even if it's done inside one of the original spawned threads. /// - The return type of the [`JoinHandle`] is expected to be `()`. pub struct ThreadManager { join_handles: Vec<(JoinHandle<()>, bool)>, monitor: &'static ThreadMonitor, monitor_tx: Sender<Thread>, monitor_rx: Receiver<Thread>, } impl ThreadManager { /// Instantiates a new manager. /// /// A static reference to a [`ThreadMonitor`] that has already been /// initialized is expected. /// /// See the documentation of [`ThreadMonitor::init`] for more details. pub fn new(monitor: &'static ThreadMonitor) -> ThreadManager { let (monitor_tx, monitor_rx) = mpsc::channel(); ThreadManager { join_handles: vec![], monitor, monitor_rx, monitor_tx, } } /// Spawns a new thread to be managed. /// /// Unlike [`std::thread::spawn`], a copy of the [`Thread`] will be /// returned instead of a [`std::thread::JoinHandle`]. The handle will be /// used internally by the manager. pub fn spawn<F>(&mut self, f: F, triggers_shutdown: bool) -> Thread where F: FnOnce(), F: Send + 'static, { let monitor_tx_for_spawn = self.monitor_tx.clone(); let join_handle = std::thread::spawn(move || { f(); monitor_tx_for_spawn.send(std::thread::current()).unwrap(); }); let thread = join_handle.thread().clone(); self.join_handles.push((join_handle, triggers_shutdown)); thread } /// Waits for any thread marked as `triggers_shutdown` to finish or for one /// or more threads to panic. /// /// A [`FinishReason`] will be returned providing context into why the /// manager returned. pub fn join_all(self) -> Result<FinishReason, Error> { let monitor_tx_for_spawn = self.monitor_tx.clone(); let watched_thread_ids: HashSet<ThreadId> = self .join_handles .iter() .map(|(join_handle, _)| join_handle.thread().id()) .collect(); let watched_trigger_thread_ids: HashSet<ThreadId> = self .join_handles .iter() .filter_map(|(join_handle, triggers_shutdown)| { if *triggers_shutdown { Some(join_handle.thread().id()) } else { None } }) .collect(); let monitor_for_monitor_thread = self.monitor; let monitor_join_handle = std::thread::spawn(move || { let watch_result = monitor_for_monitor_thread.watch(Some(&watched_thread_ids)); let notification_result = monitor_tx_for_spawn.send(std::thread::current()); match notification_result { Ok(()) => Ok(watch_result), Err(send_error) => Err(send_error), } }); // Wait for a thread to finish. loop { let finished_thread = self.monitor_rx.recv().map_err(|_| Error::RecvNotify)?; let finished_thread_id = finished_thread.id(); if finished_thread_id == monitor_join_handle.thread().id() { return match monitor_join_handle.join() { Ok(join_result_result) => match join_result_result { Ok(watch_result) => Ok(match watch_result { Ok(threads) => FinishReason::ThreadPanic { threads }, Err(error) => FinishReason::MonitorError { error }, }), Err(error) => Ok(FinishReason::MonitorSendError { error }), }, Err(_err) => Ok(FinishReason::MonitorJoinError), }; } else if watched_trigger_thread_ids.contains(&finished_thread_id) { return Ok(FinishReason::ThreadFinished { thread: finished_thread, }); } } } } #[cfg(test)] mod tests { use crate::thread::{manager::FinishReason, monitor::ThreadMonitor}; use lazy_static::lazy_static; use std::{sync::mpsc, thread, time::Duration}; use super::ThreadManager; lazy_static! { static ref MONITOR_FOR_FINISH: ThreadMonitor = ThreadMonitor::new(); static ref MONITOR_FOR_PANIC: ThreadMonitor = ThreadMonitor::new(); } #[test] pub fn test_join_with_panic() { MONITOR_FOR_PANIC.init().unwrap(); let (tx, rx) = mpsc::channel(); let mut manager = ThreadManager::new(&MONITOR_FOR_PANIC); manager.spawn( move || { rx.recv().unwrap(); panic!("Oh no"); }, true, ); // Create a separate thread to trigger the panic. let trigger_handle = thread::spawn(move || { thread::sleep(Duration::from_millis(1000)); tx.send(true) }); let join_result = manager.join_all(); let trigger_result = trigger_handle.join().unwrap(); assert!(trigger_result.is_ok()); assert!(matches!( join_result, Ok(FinishReason::ThreadPanic { threads: _ }) )); } #[test] pub fn test_join_with_finish() { MONITOR_FOR_FINISH.init().unwrap(); let (tx, rx) = mpsc::channel(); let mut manager = ThreadManager::new(&MONITOR_FOR_FINISH); manager.spawn( move || { rx.recv().unwrap(); }, true, ); // Create a separate thread to trigger the finish. let trigger_handle = thread::spawn(move || { thread::sleep(Duration::from_millis(10)); tx.send(true) }); let join_result = manager.join_all(); let trigger_result = trigger_handle.join().unwrap(); assert!(trigger_result.is_ok()); assert!(matches!( join_result, Ok(FinishReason::ThreadFinished { thread: _ }) )) } }