Newer
Older
use std::{
collections::HashSet,
sync::mpsc::{self, Receiver, SendError, Sender},
thread::{JoinHandle, Thread, ThreadId},
};
use thiserror::Error;
use crate::thread::monitor::{Error as ThreadMonitorError, ThreadMonitor};
/// Error type for [`ThreadManager`].
#[derive(Error, Debug)]
pub enum Error {
    /// Unable to receive from the monitor channel.
    #[error("Unable to receive from the monitor channel.")]
    RecvNotify,
    /// The internal [`ThreadMonitor`] could not be initialized.
    #[error("Unable to initialize internal monitor: {source}")]
    Monitor {
        /// The underlying error reported by the monitor.
        #[from]
        source: ThreadMonitorError,
    },
}
/// Reason why a call to [`ThreadManager::join_all`] returned.
#[derive(Debug)]
pub enum FinishReason {
    /// The monitor thread encountered an error.
    MonitorError { error: ThreadMonitorError },
    /// A thread marked as `triggers_shutdown` finished.
    ThreadFinished { thread: Thread },
    /// One or more threads panicked.
    ThreadPanic { threads: Vec<Thread> },
    /// Unable to join with the monitor thread.
    MonitorJoinError,
    /// The monitor thread was unable to send its completion notification.
    MonitorSendError { error: SendError<Thread> },
}
/// An abstraction for building simple multi-threaded applications.
///
/// Its goal is to simplify the process of setting up an application composed
/// of a fixed number of threads, adding in support for common tasks such as
/// detecting panics and waiting for some threads to finish.
///
/// It achieves this by focusing on a narrow use-case:
///
/// - The application has a fixed number of threads.
/// - The threads are all started at the same time during initialization.
/// - A panic in any thread is expected to trigger a shut down of the
///   application.
/// - Some threads finishing should trigger a shut down of the application.
///   For example, a web application has a main server thread and a
///   metrics/stats server thread.
///
/// [`ThreadManager`] is a higher level abstraction built on top of
/// [`ThreadMonitor`], which is used to detect and handle panics in other
/// threads. [`ThreadMonitor`] can be used directly if [`ThreadManager`] does
/// not meet a use-case or low level control is needed.
///
/// # Limitations
///
/// - Only threads spawned directly through the manager interface can be
///   monitored. The manager doesn't monitor other threads spawned separately,
///   even if it's done inside one of the original spawned threads.
/// - The return type of the [`JoinHandle`] is expected to be `()`.
pub struct ThreadManager {
    /// Each spawned thread's handle, paired with its `triggers_shutdown` flag.
    join_handles: Vec<(JoinHandle<()>, bool)>,
    /// Monitor used by `join_all` to watch the spawned threads.
    monitor: &'static ThreadMonitor,
    /// Sending half of the notification channel; cloned into every spawned
    /// thread and into the monitor thread.
    monitor_tx: Sender<Thread>,
    /// Receiving half of the notification channel, read by `join_all`.
    monitor_rx: Receiver<Thread>,
}
impl ThreadManager {
/// Instantiates a new manager.
///
/// A static reference to a [`ThreadMonitor`] that has already been
/// initialized is expected.
/// See the documentation of [`ThreadMonitor::init`] for more details.
pub fn new(monitor: &'static ThreadMonitor) -> ThreadManager {
let (monitor_tx, monitor_rx) = mpsc::channel();
ThreadManager {
join_handles: vec![],
monitor,
monitor_rx,
monitor_tx,
}
}
/// Spawns a new thread to be managed.
///
/// Unlike [`std::thread::spawn`], a copy of the [`Thread`] will be
/// returned instead of a [`std::thread::JoinHandle`]. The handle will be
/// used internally by the manager.
pub fn spawn<F>(&mut self, f: F, triggers_shutdown: bool) -> Thread
where
F: FnOnce(),
F: Send + 'static,
{
let monitor_tx_for_spawn = self.monitor_tx.clone();
let join_handle = std::thread::spawn(move || {
monitor_tx_for_spawn.send(std::thread::current()).unwrap();
});
let thread = join_handle.thread().clone();
self.join_handles.push((join_handle, triggers_shutdown));
thread
}
/// Waits for any thread marked as `triggers_shutdown` to finish or for one
/// or more threads to panic.
/// A [`FinishReason`] will be returned providing context into why the
/// manager returned.
pub fn join_all(self) -> Result<FinishReason, Error> {
let monitor_tx_for_spawn = self.monitor_tx.clone();
let watched_thread_ids: HashSet<ThreadId> = self
.join_handles
.iter()
.map(|(join_handle, _)| join_handle.thread().id())
.collect();
let watched_trigger_thread_ids: HashSet<ThreadId> = self
.join_handles
.iter()
.filter_map(|(join_handle, triggers_shutdown)| {
if *triggers_shutdown {
Some(join_handle.thread().id())
} else {
None
}
})
.collect();
let monitor_for_monitor_thread = self.monitor;
let monitor_join_handle = std::thread::spawn(move || {
let watch_result = monitor_for_monitor_thread.watch(Some(&watched_thread_ids));
let notification_result = monitor_tx_for_spawn.send(std::thread::current());
match notification_result {
Ok(()) => Ok(watch_result),
Err(send_error) => Err(send_error),
}
});
// Wait for a thread to finish.
loop {
let finished_thread = self.monitor_rx.recv().map_err(|_| Error::RecvNotify)?;
let finished_thread_id = finished_thread.id();
if finished_thread_id == monitor_join_handle.thread().id() {
return match monitor_join_handle.join() {
Ok(join_result_result) => match join_result_result {
Ok(watch_result) => Ok(match watch_result {
Ok(threads) => FinishReason::ThreadPanic { threads },
Err(error) => FinishReason::MonitorError { error },
}),
Err(error) => Ok(FinishReason::MonitorSendError { error }),
},
Err(_err) => Ok(FinishReason::MonitorJoinError),
};
} else if watched_trigger_thread_ids.contains(&finished_thread_id) {
return Ok(FinishReason::ThreadFinished {
thread: finished_thread,
});
}
}
}
}
#[cfg(test)]
mod tests {
use crate::thread::{manager::FinishReason, monitor::ThreadMonitor};
use lazy_static::lazy_static;
use std::{
sync::mpsc,
thread,
time::Duration,
};
use super::ThreadManager;
// Statics because `ThreadManager::new` requires a `&'static ThreadMonitor`.
// Each test calls `init()` on its own monitor; presumably a monitor is only
// meant to be initialized once — TODO confirm against `ThreadMonitor::init`.
lazy_static! {
static ref MONITOR_FOR_FINISH: ThreadMonitor = ThreadMonitor::new();
static ref MONITOR_FOR_PANIC: ThreadMonitor = ThreadMonitor::new();
}
#[test]
/// A managed thread that panics must make `join_all` report `ThreadPanic`.
pub fn test_join_with_panic() {
    MONITOR_FOR_PANIC.init().unwrap();
    let (tx, rx) = mpsc::channel();
    let mut manager = ThreadManager::new(&MONITOR_FOR_PANIC);
    manager.spawn(
        move || {
            rx.recv().unwrap();
            panic!("Oh no");
        },
        true,
    );
    // Create a separate thread to trigger the panic.
    let trigger_handle = thread::spawn(move || {
        thread::sleep(Duration::from_millis(1000));
        tx.send(true)
    });
    let join_result = manager.join_all();
    let trigger_result = trigger_handle.join().unwrap();
    assert!(trigger_result.is_ok());
    assert!(matches!(
        join_result,
        Ok(FinishReason::ThreadPanic { threads: _ })
    ));
}
#[test]
pub fn test_join_with_finish() {
MONITOR_FOR_FINISH.init().unwrap();
let (tx, rx) = mpsc::channel();
let mut manager = ThreadManager::new(&MONITOR_FOR_FINISH);
manager.spawn(
move || {
rx.recv().unwrap();
},
true,
);
// Create a separate thread to trigger the finish.
let trigger_handle = thread::spawn(move || {
thread::sleep(Duration::from_millis(10));
tx.send(true)
});
let join_result = manager.join_all();
let trigger_result = trigger_handle.join().unwrap();
assert!(trigger_result.is_ok());
join_result,
Ok(FinishReason::ThreadFinished { thread: _ })