diff --git a/src/thread/manager.rs b/src/thread/manager.rs index 10b6815a391be278bcfa8ba3067f940651991dfe..bb6b87aa61c6e4081276e51906a5bbed665e354a 100644 --- a/src/thread/manager.rs +++ b/src/thread/manager.rs @@ -1,6 +1,6 @@ use std::{ collections::HashSet, - sync::mpsc::{self, Receiver, Sender}, + sync::mpsc::{self, Receiver, SendError, Sender}, thread::{JoinHandle, Thread, ThreadId}, }; @@ -22,6 +22,7 @@ pub enum Error { } /// Reason why a call to [`ThreadManager::join_all`] returned. +#[derive(Debug)] pub enum FinishReason { /// The monitor thread encountered an error. MonitorError { error: ThreadMonitorError }, @@ -29,6 +30,10 @@ pub enum FinishReason { ThreadFinished { thread: Thread }, /// One of more threads panicked. ThreadPanic { threads: Vec<Thread> }, + /// Unable to join with monnitor thread + MonitorJoinError, + /// Monitor was unable to send a notification + MonitorSendError { error: SendError<Thread> }, } /// An abstraction for building simple multi-threaded applications. @@ -135,14 +140,16 @@ impl ThreadManager { }) .collect(); - #[allow(suspicious_double_ref_op)] let monitor_for_monitor_thread = self.monitor; let monitor_join_handle = std::thread::spawn(move || { let watch_result = monitor_for_monitor_thread.watch(Some(&watched_thread_ids)); - monitor_tx_for_spawn.send(std::thread::current()).unwrap(); + let notification_result = monitor_tx_for_spawn.send(std::thread::current()); - watch_result + match notification_result { + Ok(()) => Ok(watch_result), + Err(send_error) => Err(send_error), + } }); // Wait for a thread to finish. @@ -151,12 +158,17 @@ impl ThreadManager { let finished_thread_id = finished_thread.id(); if finished_thread_id == monitor_join_handle.thread().id() { - let watch_result = monitor_join_handle.join().unwrap(); - - return Ok(match watch_result { - Ok(threads) => FinishReason::ThreadPanic { threads }, - Err(error) => FinishReason::MonitorError { error }, - }); + return match monitor_join_handle.join() { + Ok(join_result_result) => match join_result_result { + Ok(watch_result) => Ok(match watch_result { + Ok(threads) => FinishReason::ThreadPanic { threads }, + Err(error) => FinishReason::MonitorError { error }, + }), + + Err(error) => Ok(FinishReason::MonitorSendError { error }), + }, + Err(_err) => Ok(FinishReason::MonitorJoinError), + }; } else if watched_trigger_thread_ids.contains(&finished_thread_id) { return Ok(FinishReason::ThreadFinished { thread: finished_thread, @@ -171,21 +183,26 @@ mod tests { use crate::thread::{manager::FinishReason, monitor::ThreadMonitor}; use lazy_static::lazy_static; - use std::{sync::mpsc, thread, time::Duration}; + use std::{ + sync::mpsc, + thread, + time::Duration, + }; use super::ThreadManager; lazy_static! { - static ref MONITOR: ThreadMonitor = ThreadMonitor::new(); + static ref MONITOR_FOR_FINISH: ThreadMonitor = ThreadMonitor::new(); + static ref MONITOR_FOR_PANIC: ThreadMonitor = ThreadMonitor::new(); } #[test] pub fn test_join_with_panic() { - MONITOR.init().unwrap(); + MONITOR_FOR_PANIC.init().unwrap(); let (tx, rx) = mpsc::channel(); - let mut manager = ThreadManager::new(&MONITOR); + let mut manager = ThreadManager::new(&MONITOR_FOR_PANIC); manager.spawn( move || { @@ -197,27 +214,29 @@ mod tests { ); // Create a separate thread to trigger the panic. - thread::spawn(move || { - thread::sleep(Duration::from_millis(10)); + let trigger_handle = thread::spawn(move || { + thread::sleep(Duration::from_millis(1000)); - tx.send(true).unwrap(); + tx.send(true) }); let join_result = manager.join_all(); + let trigger_result = trigger_handle.join().unwrap(); + assert!(trigger_result.is_ok()); assert!(matches!( - join_result.unwrap(), - FinishReason::ThreadPanic { threads: _ } - )) + join_result, + Ok(FinishReason::ThreadPanic { threads: _ }) + )); } #[test] pub fn test_join_with_finish() { - MONITOR.init().unwrap(); + MONITOR_FOR_FINISH.init().unwrap(); let (tx, rx) = mpsc::channel(); - let mut manager = ThreadManager::new(&MONITOR); + let mut manager = ThreadManager::new(&MONITOR_FOR_FINISH); manager.spawn( move || { @@ -227,17 +246,20 @@ mod tests { ); // Create a separate thread to trigger the finish. - thread::spawn(move || { + let trigger_handle = thread::spawn(move || { thread::sleep(Duration::from_millis(10)); - tx.send(true).unwrap(); + tx.send(true) }); let join_result = manager.join_all(); + let trigger_result = trigger_handle.join().unwrap(); + + assert!(trigger_result.is_ok()); assert!(matches!( - join_result.unwrap(), - FinishReason::ThreadFinished { thread: _ } + join_result, + Ok(FinishReason::ThreadFinished { thread: _ }) )) } } diff --git a/src/thread/monitor.rs b/src/thread/monitor.rs index cc9eb568729241296fd70659cd98b9ac9b67b0bc..074b18015968f35d9d45e78947f87f3a48bce1ed 100644 --- a/src/thread/monitor.rs +++ b/src/thread/monitor.rs @@ -29,6 +29,8 @@ pub enum Error { /// Make sure to call [`ThreadMonitor::init`] before using it. #[error("The monitor was invoked before it was initialized.")] Uninitialized, + #[error("There was an attempt to reinitialize the monitor.")] + Reinitialized, } /// Result type for [`ThreadMonitor`] methods. @@ -72,6 +74,14 @@ impl ThreadMonitor { /// [`set_watched`]: #method.set_watched /// [`watch`]: #method.watch pub fn init(&'static self) -> Result<()> { + let mut state = self.state.lock().map_err(|_| Error::WatchStateLock)?; + + // If we are already initialized, then do nothing. Adding another + // panic hook for the same monitor would cause issues. + if state.initialized { + return Err(Error::Reinitialized); + } + let hook = panic::take_hook(); panic::set_hook(Box::new(move |panic_info| { @@ -94,7 +104,6 @@ impl ThreadMonitor { hook(panic_info) })); - let mut state = self.state.lock().map_err(|_| Error::WatchStateLock)?; state.initialized = true;