//! Utilities for monitoring when a thread has panicked. //! //! Multi-threaded applications may have to react to panics in individual //! threads. This module provides an implementation of an abstraction to do //! exactly this, [`ThreadMonitor`]. use std::{ collections::{HashMap, HashSet}, panic, sync::{Condvar, Mutex}, thread::{self, Thread, ThreadId}, }; use thiserror::Error; /// Error type for [`ThreadMonitor`] #[derive(Debug, Error)] pub enum Error { /// Unable to adquire a lock of the monitor state. #[error("Unable to adquire a lock of the monitor state.")] WatchStateLock, /// Got more than one active call to [`ThreadMonitor::watch`]. #[error("There should only be one active call to watch().")] MultipleWatches, /// Got no threads to watch. #[error("Got no threads to watch.")] NoWatches, /// The monitor was invoked before it was initialized. /// /// Make sure to call [`ThreadMonitor::init`] before using it. #[error("The monitor was invoked before it was initialized.")] Uninitialized, #[error("There was an attempt to reinitialize the monitor.")] Reinitialized, } /// Result type for [`ThreadMonitor`] methods. pub type Result<T, E = Error> = std::result::Result<T, E>; struct State { panicked: HashMap<ThreadId, Thread>, watched: Option<HashSet<ThreadId>>, initialized: bool, } /// A thread panic monitor. pub struct ThreadMonitor { condvar: Condvar, state: Mutex<State>, } impl ThreadMonitor { /// Instantiates a new thread monitor. pub fn new() -> Self { ThreadMonitor { condvar: Condvar::new(), state: Mutex::new(State { panicked: HashMap::new(), watched: None, initialized: false, }), } } /// Initializes the thread monitor. /// /// This should be done before watching any threads. /// /// Internally, the monitor will set up a new panic hook, which will be used /// for detecting panics on the threads being watched. /// /// The monitor will initially ignore all panics. Use [`set_watched`] or /// [`watch`] to specify which threads to monitor. /// /// [`set_watched`]: #method.set_watched /// [`watch`]: #method.watch pub fn init(&'static self) -> Result<()> { let mut state = self.state.lock().map_err(|_| Error::WatchStateLock)?; // If we are already initialized, then do nothing. Adding another // panic hook for the same monitor would cause issues. if state.initialized { return Err(Error::Reinitialized); } let hook = panic::take_hook(); panic::set_hook(Box::new(move |panic_info| { match self.state.lock() { Ok(mut state) => { if let Some(watched) = &state.watched { let current_thread = thread::current(); // Only notify if the thread ID is being watched. if watched.contains(¤t_thread.id()) { state.panicked.insert(current_thread.id(), current_thread); self.condvar.notify_all(); } } } Err(_) => log::error!("Unable to update map of panicked threads."), } hook(panic_info) })); state.initialized = true; Ok(()) } /// Set the threads to be watched by this monitor. /// /// If [`ThreadMonitor::init`] has been called, the monitor will begin /// recording panics for the specified threads. pub fn set_watched(&self, thread_ids: HashSet<ThreadId>) -> Result<()> { let mut state = self.state.lock().map_err(|_| Error::WatchStateLock)?; state.watched = Some(thread_ids); Ok(()) } /// Watches the provided thread IDs. /// /// [`ThreadMonitor::init`] has to be called before this method. An /// [`Uninitialized`] error will be returned if it's not. /// /// - If an empty set is passed, this function returns immediately. /// - If a thread set is not passed and one hasn't been set with /// [`set_watched`], a [`NoWatches`] error will be returned. /// - If [`set_watched`] was previously called and one of the watched threads /// already panicked, this function will return immediately. /// - Otherwise, the monitor will block the current thread until one of the /// watched threads has a panic. /// /// [`set_watched`]: #method.set_watched /// [`NoWatches`]: ./enum.Error.html#variant.NoWatches /// [`Uninitialized`]: ./enum.Error.html#variant.Uninitialized pub fn watch(&self, thread_ids: Option<&HashSet<ThreadId>>) -> Result<Vec<Thread>> { let mut state = self.state.lock().map_err(|_| Error::WatchStateLock)?; if !state.initialized { return Err(Error::Uninitialized); } let thread_ids = match thread_ids { Some(thread_ids) => Ok(thread_ids.clone()), None => match &(state.watched) { Some(thread_ids) => Ok(thread_ids.clone()), None => Err(Error::NoWatches), }, }?; if thread_ids.is_empty() { return Ok(vec![]); } state.panicked = HashMap::new(); state.watched = Some(thread_ids); let mut watched_panicked = vec![]; loop { // Since `state` may have changed, we need to reload the list of thread // ids, otherwise we would be stuck checking for thread ids that may not // be watched anymore. let thread_ids = match &state.watched { Some(thread_ids) => thread_ids, None => return Err(Error::NoWatches), }; if thread_ids.is_empty() { return Ok(vec![]); } for thread_id in thread_ids { if let Some(thread) = state.panicked.get(thread_id) { watched_panicked.push(thread.clone().clone()); } } if !watched_panicked.is_empty() { return Ok(watched_panicked); } state = self .condvar .wait(state) .map_err(|_| Error::WatchStateLock)?; } } } impl Default for ThreadMonitor { fn default() -> Self { Self::new() } } #[cfg(test)] mod tests { use super::ThreadMonitor; use lazy_static::lazy_static; use std::{collections::HashSet, sync::mpsc, thread, time::Duration}; lazy_static! { static ref MONITOR: ThreadMonitor = ThreadMonitor::new(); } #[test] pub fn test_watch() { MONITOR.init().unwrap(); // Monitoring an empty list of threads should return immediately. MONITOR.watch(Some(&HashSet::new())).unwrap(); let (tx, rx) = mpsc::channel(); let handle = thread::spawn(move || { rx.recv().unwrap(); panic!("Oh no"); }); let mut thread_ids = HashSet::new(); thread_ids.insert(handle.thread().id()); thread::sleep(Duration::from_millis(10)); let test_handle = thread::spawn(move || { let watch_result = MONITOR.watch(Some(&thread_ids)).unwrap(); assert!(!watch_result.is_empty()); assert_eq!(watch_result[0].id(), handle.thread().id()); }); thread::sleep(Duration::from_millis(10)); tx.send(true).unwrap(); test_handle.join().unwrap(); } }