//! Utilities for monitoring when a thread has panicked.
//!
//! Multi-threaded applications may have to react to panics in individual
//! threads. This module provides [`ThreadMonitor`], an abstraction for doing
//! exactly that.
use std::{
    collections::{HashMap, HashSet},
    panic,
    sync::{Condvar, Mutex},
    thread::{self, Thread, ThreadId},
};
use thiserror::Error;
/// Error type for [`ThreadMonitor`]
#[derive(Debug, Error)]
pub enum Error {
/// Unable to adquire a lock of the monitor state.
#[error("Unable to adquire a lock of the monitor state.")]
WatchStateLock,
/// Got more than one active call to [`ThreadMonitor::watch`].
#[error("There should only be one active call to watch().")]
MultipleWatches,
/// Got no threads to watch.
#[error("Got no threads to watch.")]
NoWatches,
/// The monitor was invoked before it was initialized.
///
/// Make sure to call [`ThreadMonitor::init`] before using it.
#[error("The monitor was invoked before it was initialized.")]
Uninitialized,
}
/// Result type for [`ThreadMonitor`] methods.
pub type Result<T, E = Error> = std::result::Result<T, E>;
/// Lock-protected bookkeeping shared by a [`ThreadMonitor`] and its panic hook.
struct State {
    /// Set once `ThreadMonitor::init` has installed the panic hook.
    initialized: bool,
    /// IDs of the threads whose panics should be recorded, if a set was given.
    watched: Option<HashSet<ThreadId>>,
    /// Watched threads that have panicked, keyed by thread ID.
    panicked: HashMap<ThreadId, Thread>,
}
/// A thread panic monitor.
///
/// Build one with [`ThreadMonitor::new`], install its panic hook with
/// [`ThreadMonitor::init`] (which needs a `'static` reference), then use
/// [`ThreadMonitor::watch`] to wait for watched threads to panic.
pub struct ThreadMonitor {
    /// Panic records and watch configuration, shared with the panic hook.
    state: Mutex<State>,
    /// Used to wake threads blocked in [`ThreadMonitor::watch`].
    condvar: Condvar,
}
impl ThreadMonitor {
/// Instantiates a new thread monitor.
pub fn new() -> Self {
ThreadMonitor {
condvar: Condvar::new(),
state: Mutex::new(State {
panicked: HashMap::new(),
watched: None,
initialized: false,
}),
/// Initializes the thread monitor.
///
/// This should be done before watching any threads.
///
/// Internally, the monitor will set up a new panic hook, which will be used
/// for detecting panics on the threads being watched.
///
/// The monitor will initially ignore all panics. Use [`set_watched`] or
/// [`watch`] to specify which threads to monitor.
///
/// [`set_watched`]: #method.set_watched
/// [`watch`]: #method.watch
pub fn init(&'static self) -> Result<()> {
let hook = panic::take_hook();
panic::set_hook(Box::new(move |panic_info| {
match self.state.lock() {
Ok(mut state) => {
if let Some(watched) = &state.watched {
let current_thread = thread::current();
// Only notify if the thread ID is being watched.
if watched.contains(¤t_thread.id()) {
state.panicked.insert(current_thread.id(), current_thread);
self.condvar.notify_all();
}
}
}
Err(_) => log::error!("Unable to update map of panicked threads."),
}
hook(panic_info)
}));
let mut state = self.state.lock().map_err(|_| Error::WatchStateLock)?;
state.initialized = true;
Ok(())
/// Set the threads to be watched by this monitor.
///
/// If [`ThreadMonitor::init`] has been called, the monitor will begin
/// recording panics for the specified threads.
pub fn set_watched(&self, thread_ids: HashSet<ThreadId>) -> Result<()> {
let mut state = self.state.lock().map_err(|_| Error::WatchStateLock)?;
state.watched = Some(thread_ids);
/// Watches the provided thread IDs.
///
/// [`ThreadMonitor::init`] has to be called before this method. An
/// [`Uninitialized`] error will be returned if it's not.
///
/// - If an empty set is passed, this function returns immediately.
/// - If a thread set is not passed and one hasn't been set with
/// [`set_watched`], a [`NoWatches`] error will be returned.
/// - If [`set_watched`] was previously called and one of the watched threads
/// already panicked, this function will return immediately.
/// - Otherwise, the monitor will block the current thread until one of the
/// watched threads has a panic.
///
/// [`set_watched`]: #method.set_watched
/// [`NoWatches`]: ./enum.Error.html#variant.NoWatches
/// [`Uninitialized`]: ./enum.Error.html#variant.Uninitialized
pub fn watch(&self, thread_ids: Option<&HashSet<ThreadId>>) -> Result<Vec<Thread>> {
let mut state = self.state.lock().map_err(|_| Error::WatchStateLock)?;
if !state.initialized {
return Err(Error::Uninitialized);
}
let thread_ids = match thread_ids {
Some(thread_ids) => Ok(thread_ids.clone()),
None => match &(state.watched) {
Some(thread_ids) => Ok(thread_ids.clone()),
None => Err(Error::NoWatches),
},
}?;
if thread_ids.is_empty() {
return Ok(vec![]);
state.panicked = HashMap::new();
state.watched = Some(thread_ids);
let mut watched_panicked = vec![];
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
loop {
// Since `state` may have changed, we need to reload the list of thread
// ids, otherwise we would be stuck checking for thread ids that may not
// be watched anymore.
let thread_ids = match &state.watched {
Some(thread_ids) => thread_ids,
None => return Err(Error::NoWatches),
};
if thread_ids.is_empty() {
return Ok(vec![]);
}
for thread_id in thread_ids {
if let Some(thread) = state.panicked.get(thread_id) {
watched_panicked.push(thread.clone().clone());
}
}
if !watched_panicked.is_empty() {
return Ok(watched_panicked);
}
state = self
.condvar
.wait(state)
.map_err(|_| Error::WatchStateLock)?;
}
}
}
impl Default for ThreadMonitor {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
use super::ThreadMonitor;
use lazy_static::lazy_static;
use std::{collections::HashSet, sync::mpsc, thread, time::Duration};
lazy_static! {
static ref MONITOR: ThreadMonitor = ThreadMonitor::new();
}
#[test]
pub fn test_watch() {
MONITOR.init().unwrap();
// Monitoring an empty list of threads should return immediately.
MONITOR.watch(Some(&HashSet::new())).unwrap();
let handle = thread::spawn(move || {
rx.recv().unwrap();
let mut thread_ids = HashSet::new();
thread_ids.insert(handle.thread().id());
thread::sleep(Duration::from_millis(10));
let test_handle = thread::spawn(move || {
let watch_result = MONITOR.watch(Some(&thread_ids)).unwrap();
assert!(!watch_result.is_empty());
assert_eq!(watch_result[0].id(), handle.thread().id());
});
thread::sleep(Duration::from_millis(10));