// monitor.rs
//! Utilities for monitoring when a thread has panicked.
//!
//! Multi-threaded applications may have to react to panics in individual
//! threads. This module provides an implementation of an abstraction to do
//! exactly this, [`ThreadMonitor`].

use std::{
    collections::{HashMap, HashSet},
    panic,
    sync::{Condvar, Mutex},
    thread::{self, Thread, ThreadId},
};

use thiserror::Error;

/// Error type for [`ThreadMonitor`]
#[derive(Debug, Error)]
pub enum Error {
    /// Unable to adquire a lock of the monitor state.
    #[error("Unable to adquire a lock of the monitor state.")]
    WatchStateLock,
    /// Got more than one active call to [`ThreadMonitor::watch`].
    #[error("There should only be one active call to watch().")]
    MultipleWatches,
    /// Got no threads to watch.
    #[error("Got no threads to watch.")]
    NoWatches,
    /// The monitor was invoked before it was initialized.
    ///
    /// Make sure to call [`ThreadMonitor::init`] before using it.
    #[error("The monitor was invoked before it was initialized.")]
    Uninitialized,
}

/// Result type for [`ThreadMonitor`] methods.
///
/// The error parameter defaults to this module's [`Error`], so signatures in
/// this file can simply be written as `Result<T>`.
pub type Result<T, E = Error> = std::result::Result<T, E>;

/// Bookkeeping shared between the panic hook and the monitor's methods.
///
/// It is always accessed through the mutex owned by `ThreadMonitor`.
struct State {
    /// Threads that panicked while their ID was in `watched`, keyed by ID.
    /// The panic hook inserts entries; `ThreadMonitor::watch` resets the map
    /// when it starts.
    panicked: HashMap<ThreadId, Thread>,
    /// IDs of the threads whose panics are recorded. While this is `None` the
    /// panic hook records nothing.
    watched: Option<HashSet<ThreadId>>,
    /// Set to `true` by `ThreadMonitor::init`; `ThreadMonitor::watch` returns
    /// `Error::Uninitialized` while it is still `false`.
    initialized: bool,
}

/// A thread panic monitor.
///
/// The monitor pairs a mutex-protected [`State`] with a condition variable:
/// the panic hook installed by [`ThreadMonitor::init`] records panics of
/// watched threads in the state and notifies the condition variable, which
/// [`ThreadMonitor::watch`] waits on.
pub struct ThreadMonitor {
    /// Notified by the panic hook each time a watched thread panics.
    condvar: Condvar,
    /// Watched/panicked bookkeeping shared with the panic hook.
    state: Mutex<State>,
}

impl ThreadMonitor {
    /// Instantiates a new thread monitor.
    pub fn new() -> Self {
        ThreadMonitor {
            condvar: Condvar::new(),
            state: Mutex::new(State {
                panicked: HashMap::new(),
                watched: None,
                initialized: false,
            }),
    /// Initializes the thread monitor.
    ///
    /// This should be done before watching any threads.
    ///
    /// Internally, the monitor will set up a new panic hook, which will be used
    /// for detecting panics on the threads being watched.
    ///
    /// The monitor will initially ignore all panics. Use [`set_watched`] or
    /// [`watch`] to specify which threads to monitor.
    ///
    /// [`set_watched`]: #method.set_watched
    /// [`watch`]: #method.watch
    pub fn init(&'static self) -> Result<()> {
        let hook = panic::take_hook();

        panic::set_hook(Box::new(move |panic_info| {
            match self.state.lock() {
                Ok(mut state) => {
                    if let Some(watched) = &state.watched {
                        let current_thread = thread::current();
                        // Only notify if the thread ID is being watched.
                        if watched.contains(&current_thread.id()) {
                            state.panicked.insert(current_thread.id(), current_thread);
                            self.condvar.notify_all();
                        }
                    }
                }
                Err(_) => log::error!("Unable to update map of panicked threads."),
            }

            hook(panic_info)
        }));
        let mut state = self.state.lock().map_err(|_| Error::WatchStateLock)?;

        state.initialized = true;

        Ok(())
    /// Set the threads to be watched by this monitor.
    ///
    /// If [`ThreadMonitor::init`] has been called, the monitor will begin
    /// recording panics for the specified threads.
    pub fn set_watched(&self, thread_ids: HashSet<ThreadId>) -> Result<()> {
        let mut state = self.state.lock().map_err(|_| Error::WatchStateLock)?;

        state.watched = Some(thread_ids);
    /// Watches the provided thread IDs.
    ///
    /// [`ThreadMonitor::init`] has to be called before this method. An
    /// [`Uninitialized`] error will be returned if it's not.
    ///
    /// - If an empty set is passed, this function returns immediately.
    /// - If a thread set is not passed and one hasn't been set with
    ///   [`set_watched`], a [`NoWatches`] error will be returned.
    /// - If [`set_watched`] was previously called and one of the watched threads
    ///   already panicked, this function will return immediately.
    /// - Otherwise, the monitor will block the current thread until one of the
    ///   watched threads has a panic.
    ///
    /// [`set_watched`]: #method.set_watched
    /// [`NoWatches`]: ./enum.Error.html#variant.NoWatches
    /// [`Uninitialized`]: ./enum.Error.html#variant.Uninitialized
    pub fn watch(&self, thread_ids: Option<&HashSet<ThreadId>>) -> Result<Vec<Thread>> {
        let mut state = self.state.lock().map_err(|_| Error::WatchStateLock)?;

        if !state.initialized {
            return Err(Error::Uninitialized);
        }
        let thread_ids = match thread_ids {
            Some(thread_ids) => Ok(thread_ids.clone()),
            None => match &(state.watched) {
                Some(thread_ids) => Ok(thread_ids.clone()),
                None => Err(Error::NoWatches),
            },
        }?;
        if thread_ids.is_empty() {
            return Ok(vec![]);
        state.panicked = HashMap::new();
        state.watched = Some(thread_ids);

        let mut watched_panicked = vec![];
        loop {
            // Since `state` may have changed, we need to reload the list of thread
            // ids, otherwise we would be stuck checking for thread ids that may not
            // be watched anymore.
            let thread_ids = match &state.watched {
                Some(thread_ids) => thread_ids,
                None => return Err(Error::NoWatches),
            };

            if thread_ids.is_empty() {
                return Ok(vec![]);
            }

            for thread_id in thread_ids {
                if let Some(thread) = state.panicked.get(thread_id) {
                    watched_panicked.push(thread.clone().clone());
                }
            }

            if !watched_panicked.is_empty() {
                return Ok(watched_panicked);
            }

            state = self
                .condvar
                .wait(state)
                .map_err(|_| Error::WatchStateLock)?;
        }
    }
}

impl Default for ThreadMonitor {
    fn default() -> Self {
        Self::new()
    }
    use super::ThreadMonitor;
    use lazy_static::lazy_static;
    use std::{collections::HashSet, sync::mpsc, thread, time::Duration};
    // `ThreadMonitor::init` takes `&'static self`, so the monitor under test
    // lives in a static rather than in a local variable.
    lazy_static! {
        static ref MONITOR: ThreadMonitor = ThreadMonitor::new();
    }
    /// Checks that `watch` returns at once for an empty set and that it
    /// reports a watched thread once that thread panics.
    #[test]
    pub fn test_watch() {
        MONITOR.init().unwrap();

        // Monitoring an empty list of threads should return immediately.
        MONITOR.watch(Some(&HashSet::new())).unwrap();

        // A thread that stays parked until it is told to panic.
        let (trigger, wait_for_trigger) = mpsc::channel();
        let panicker = thread::spawn(move || {
            wait_for_trigger.recv().unwrap();
            panic!("Oh no");
        });
        let watched_ids: HashSet<_> = std::iter::once(panicker.thread().id()).collect();
        thread::sleep(Duration::from_millis(10));

        // The watcher blocks in `watch` until the panic has been recorded.
        let watcher = thread::spawn(move || {
            let panicked = MONITOR.watch(Some(&watched_ids)).unwrap();
            assert!(!panicked.is_empty());
            assert_eq!(panicked[0].id(), panicker.thread().id());
        });
        // Give the watcher time to start waiting before triggering the panic.
        thread::sleep(Duration::from_millis(10));

        trigger.send(true).unwrap();
        watcher.join().unwrap();
    }