Skip to content
Snippets Groups Projects
monitor.rs 6.58 KiB
Newer Older
//! Utilities for monitoring when a thread has panicked.
//!
//! Multi-threaded applications may have to react to panics in individual
//! threads. This module provides [`ThreadMonitor`], an abstraction for doing
//! exactly that.

use thiserror::Error;
use std::{
  collections::{HashMap, HashSet},
  panic,
  sync::{Condvar, Mutex},
  thread::{self, Thread, ThreadId},
};

/// Error type for [`ThreadMonitor`]
#[derive(Debug, Error)]
pub enum Error {
  /// Unable to adquire a lock of the monitor state.
  #[error("Unable to adquire a lock of the monitor state.")]
  WatchStateLock,
  /// Got more than one active call to [`ThreadMonitor::watch`].
  #[error("There should only be one active call to watch().")]
  MultipleWatches,
  /// Got no threads to watch.
  #[error("Got no threads to watch.")]
  NoWatches,
  /// The monitor was invoked before it was initialized.
  ///
  /// Make sure to call [`ThreadMonitor::init`] before using it.
  #[error("The monitor was invoked before it was initialized.")]
  Uninitialized,
}

/// Result type for [`ThreadMonitor`] methods.
///
/// The error type defaults to this module's [`Error`], so `Result<T>` can be
/// written without naming it.
pub type Result<T, E = Error> = std::result::Result<T, E>;

/// Mutable state of a [`ThreadMonitor`], kept behind its mutex.
struct State {
  /// Threads for which the panic hook recorded a panic, keyed by thread ID.
  panicked: HashMap<ThreadId, Thread>,
  /// IDs of the threads whose panics are recorded; `None` until a set has
  /// been provided through `set_watched` or `watch`.
  watched: Option<HashSet<ThreadId>>,
  /// Whether `init` has completed.
  initialized: bool,
}

/// A thread panic monitor.
///
/// Records panics of a chosen set of threads through a panic hook and lets a
/// caller block in [`ThreadMonitor::watch`] until one of them has panicked.
pub struct ThreadMonitor {
  /// Notified by the panic hook when a watched thread panics; waited on by
  /// `watch`.
  condvar: Condvar,
  /// Monitor state shared between the panic hook and the public methods.
  state: Mutex<State>,
}

impl ThreadMonitor {
  /// Instantiates a new thread monitor.
  pub fn new() -> Self {
    ThreadMonitor {
      condvar: Condvar::new(),
      state: Mutex::new(State {
        panicked: HashMap::new(),
        watched: None,
        initialized: false,
      }),
    }
  }

  /// Initializes the thread monitor.
  ///
  /// This should be done before watching any threads.
  ///
  /// Internally, the monitor will set up a new panic hook, which will be used
  /// for detecting panics on the threads being watched.
  ///
  /// The monitor will initially ignore all panics. Use [`set_watched`] or
  /// [`watch`] to specify which threads to monitor.
  ///
  /// [`set_watched`]: #method.set_watched
  /// [`watch`]: #method.watch
  pub fn init(&'static self) -> Result<()> {
    let hook = panic::take_hook();

    panic::set_hook(Box::new(move |panic_info| {
      match self.state.lock() {
        Ok(mut state) => {
          match &state.watched {
            Some(watched) => {
              let current_thread = thread::current();

              // Only notify if the thread ID is being watched.
              if watched.contains(&current_thread.id()) {
                state.panicked.insert(current_thread.id(), current_thread);

                self.condvar.notify_all();
              }
            }
            None => {}
          }
        }
        Err(_) => log::error!("Unable to update map of panicked threads."),
      }

      hook(panic_info)
    }));

    let mut state = self.state.lock().map_err(|_| Error::WatchStateLock)?;

    state.initialized = true;

    Ok(())
  }

  /// Set the threads to be watched by this monitor.
  ///
  /// If [`ThreadMonitor::init`] has been called, the monitor will begin
  /// recording panics for the specified threads.
  pub fn set_watched(&self, thread_ids: HashSet<ThreadId>) -> Result<()> {
    let mut state = self.state.lock().map_err(|_| Error::WatchStateLock)?;

    state.watched = Some(thread_ids);

    Ok(())
  }

  /// Watches the provided thread IDs.
  ///
  /// [`ThreadMonitor::init`] has to be called before this method. An
  /// [`Uninitialized`] error will be returned if it's not.
  ///
  /// - If an empty set is passed, this function returns immediately.
  /// - If a thread set is not passed and one hasn't been set with
  ///   [`set_watched`], a [`NoWatches`] error will be returned.
  /// - If [`set_watched`] was previously called and one of the watched threads
  ///   already panicked, this function will return immediately.
  /// - Otherwise, the monitor will block the current thread until one of the
  ///   watched threads has a panic.
  ///
  /// [`set_watched`]: #method.set_watched
  /// [`NoWatches`]: ./enum.Error.html#variant.NoWatches
  /// [`Uninitialized`]: ./enum.Error.html#variant.Uninitialized
  pub fn watch(&self, thread_ids: Option<&HashSet<ThreadId>>) -> Result<Vec<Thread>> {
    let mut state = self.state.lock().map_err(|_| Error::WatchStateLock)?;

    if !state.initialized {
      return Err(Error::Uninitialized);
    }

    let thread_ids = match thread_ids {
      Some(thread_ids) => Ok(thread_ids.clone()),
      None => match &(state.watched) {
        Some(thread_ids) => Ok(thread_ids.clone()),
        None => Err(Error::NoWatches),
      },
    }?;

    if thread_ids.is_empty() {
      return Ok(vec![]);
    }

    state.panicked = HashMap::new();
    state.watched = Some(thread_ids);

    let mut watched_panicked = vec![];

    loop {
      // Since `state` may have changed, we need to reload the list of thread
      // ids, otherwise we would be stuck checking for thread ids that may not
      // be watched anymore.
      let thread_ids = match &state.watched {
        Some(thread_ids) => thread_ids,
        None => return Err(Error::NoWatches),
      };

      if thread_ids.is_empty() {
        return Ok(vec![]);
      }

      for thread_id in thread_ids {
        if let Some(thread) = state.panicked.get(thread_id) {
          watched_panicked.push(thread.clone().clone());
        }
      }

      if !watched_panicked.is_empty() {
        return Ok(watched_panicked);
      }

      state = self
        .condvar
        .wait(state)
        .map_err(|_| Error::WatchStateLock)?;
    }
  }
}

impl Default for ThreadMonitor {
  fn default() -> Self {
    Self::new()
  }
}

#[cfg(test)]
mod tests {
  use super::ThreadMonitor;
  use lazy_static::lazy_static;
  use std::{collections::HashSet, iter, sync::mpsc, thread, time::Duration};

  lazy_static! {
    // `ThreadMonitor::init` takes `&'static self`, so the monitor under test
    // has to live in a static.
    static ref MONITOR: ThreadMonitor = ThreadMonitor::new();
  }

  /// Checks that `watch` returns at once for an empty set and otherwise
  /// blocks until a watched thread panics, reporting that thread.
  #[test]
  pub fn test_watch() {
    MONITOR.init().unwrap();

    // Monitoring an empty list of threads should return immediately.
    MONITOR.watch(Some(&HashSet::new())).unwrap();

    let (tx, rx) = mpsc::channel();

    // This thread stays parked on the channel until the test tells it to
    // panic.
    let handle = thread::spawn(move || {
      rx.recv().unwrap();

      panic!("Oh no");
    });

    let thread_ids: HashSet<_> = iter::once(handle.thread().id()).collect();

    // Presumably gives the spawned thread time to reach `recv`.
    thread::sleep(Duration::from_millis(10));

    // `watch` blocks, so it runs on its own thread while this one triggers
    // the panic.
    let test_handle = thread::spawn(move || {
      let watch_result = MONITOR.watch(Some(&thread_ids)).unwrap();

      assert!(!watch_result.is_empty());
      assert_eq!(watch_result[0].id(), handle.thread().id());
    });

    // Presumably gives the watcher time to register the thread IDs before the
    // panic is triggered; a panic of an unwatched thread is not recorded.
    thread::sleep(Duration::from_millis(10));

    tx.send(true).unwrap();

    test_handle.join().unwrap();
  }
}