From fff02b5923d0d4d572bacf53663308fe952093a8 Mon Sep 17 00:00:00 2001
From: Eduardo Trujillo <ed@chromabits.com>
Date: Tue, 21 Jan 2025 03:27:31 +0000
Subject: [PATCH] fix(thread): Support concurrent tests, improve init checks

---
 src/thread/manager.rs | 74 ++++++++++++++++++++++++++++---------------
 src/thread/monitor.rs | 11 ++++++-
 2 files changed, 58 insertions(+), 27 deletions(-)

diff --git a/src/thread/manager.rs b/src/thread/manager.rs
index 10b6815..bb6b87a 100644
--- a/src/thread/manager.rs
+++ b/src/thread/manager.rs
@@ -1,6 +1,6 @@
 use std::{
     collections::HashSet,
-    sync::mpsc::{self, Receiver, Sender},
+    sync::mpsc::{self, Receiver, SendError, Sender},
     thread::{JoinHandle, Thread, ThreadId},
 };
 
@@ -22,6 +22,7 @@ pub enum Error {
 }
 
 /// Reason why a call to [`ThreadManager::join_all`] returned.
+#[derive(Debug)]
 pub enum FinishReason {
     /// The monitor thread encountered an error.
     MonitorError { error: ThreadMonitorError },
@@ -29,6 +30,10 @@ pub enum FinishReason {
     ThreadFinished { thread: Thread },
     /// One of more threads panicked.
     ThreadPanic { threads: Vec<Thread> },
+    /// Unable to join with monnitor thread
+    MonitorJoinError,
+    /// Monitor was unable to send a notification
+    MonitorSendError { error: SendError<Thread> },
 }
 
 /// An abstraction for building simple multi-threaded applications.
@@ -135,14 +140,16 @@ impl ThreadManager {
             })
             .collect();
 
-        #[allow(suspicious_double_ref_op)]
         let monitor_for_monitor_thread = self.monitor;
         let monitor_join_handle = std::thread::spawn(move || {
             let watch_result = monitor_for_monitor_thread.watch(Some(&watched_thread_ids));
 
-            monitor_tx_for_spawn.send(std::thread::current()).unwrap();
+            let notification_result = monitor_tx_for_spawn.send(std::thread::current());
 
-            watch_result
+            match notification_result {
+                Ok(()) => Ok(watch_result),
+                Err(send_error) => Err(send_error),
+            }
         });
 
         // Wait for a thread to finish.
@@ -151,12 +158,17 @@ impl ThreadManager {
             let finished_thread_id = finished_thread.id();
 
             if finished_thread_id == monitor_join_handle.thread().id() {
-                let watch_result = monitor_join_handle.join().unwrap();
-
-                return Ok(match watch_result {
-                    Ok(threads) => FinishReason::ThreadPanic { threads },
-                    Err(error) => FinishReason::MonitorError { error },
-                });
+                return match monitor_join_handle.join() {
+                    Ok(join_result_result) => match join_result_result {
+                        Ok(watch_result) => Ok(match watch_result {
+                            Ok(threads) => FinishReason::ThreadPanic { threads },
+                            Err(error) => FinishReason::MonitorError { error },
+                        }),
+
+                        Err(error) => Ok(FinishReason::MonitorSendError { error }),
+                    },
+                    Err(_err) => Ok(FinishReason::MonitorJoinError),
+                };
             } else if watched_trigger_thread_ids.contains(&finished_thread_id) {
                 return Ok(FinishReason::ThreadFinished {
                     thread: finished_thread,
@@ -171,21 +183,26 @@ mod tests {
     use crate::thread::{manager::FinishReason, monitor::ThreadMonitor};
 
     use lazy_static::lazy_static;
-    use std::{sync::mpsc, thread, time::Duration};
+    use std::{
+        sync::mpsc,
+        thread,
+        time::Duration,
+    };
 
     use super::ThreadManager;
 
     lazy_static! {
-        static ref MONITOR: ThreadMonitor = ThreadMonitor::new();
+        static ref MONITOR_FOR_FINISH: ThreadMonitor = ThreadMonitor::new();
+        static ref MONITOR_FOR_PANIC: ThreadMonitor = ThreadMonitor::new();
     }
 
     #[test]
     pub fn test_join_with_panic() {
-        MONITOR.init().unwrap();
+        MONITOR_FOR_PANIC.init().unwrap();
 
         let (tx, rx) = mpsc::channel();
 
-        let mut manager = ThreadManager::new(&MONITOR);
+        let mut manager = ThreadManager::new(&MONITOR_FOR_PANIC);
 
         manager.spawn(
             move || {
@@ -197,27 +214,29 @@ mod tests {
         );
 
         // Create a separate thread to trigger the panic.
-        thread::spawn(move || {
-            thread::sleep(Duration::from_millis(10));
+        let trigger_handle = thread::spawn(move || {
+            thread::sleep(Duration::from_millis(1000));
 
-            tx.send(true).unwrap();
+            tx.send(true)
         });
 
         let join_result = manager.join_all();
+        let trigger_result = trigger_handle.join().unwrap();
 
+        assert!(trigger_result.is_ok());
         assert!(matches!(
-            join_result.unwrap(),
-            FinishReason::ThreadPanic { threads: _ }
-        ))
+            join_result,
+            Ok(FinishReason::ThreadPanic { threads: _ })
+        ));
     }
 
     #[test]
     pub fn test_join_with_finish() {
-        MONITOR.init().unwrap();
+        MONITOR_FOR_FINISH.init().unwrap();
 
         let (tx, rx) = mpsc::channel();
 
-        let mut manager = ThreadManager::new(&MONITOR);
+        let mut manager = ThreadManager::new(&MONITOR_FOR_FINISH);
 
         manager.spawn(
             move || {
@@ -227,17 +246,20 @@ mod tests {
         );
 
         // Create a separate thread to trigger the finish.
-        thread::spawn(move || {
+        let trigger_handle = thread::spawn(move || {
             thread::sleep(Duration::from_millis(10));
 
-            tx.send(true).unwrap();
+            tx.send(true)
         });
 
         let join_result = manager.join_all();
+        let trigger_result = trigger_handle.join().unwrap();
+
+        assert!(trigger_result.is_ok());
 
         assert!(matches!(
-            join_result.unwrap(),
-            FinishReason::ThreadFinished { thread: _ }
+            join_result,
+            Ok(FinishReason::ThreadFinished { thread: _ })
         ))
     }
 }
diff --git a/src/thread/monitor.rs b/src/thread/monitor.rs
index cc9eb56..074b180 100644
--- a/src/thread/monitor.rs
+++ b/src/thread/monitor.rs
@@ -29,6 +29,8 @@ pub enum Error {
     /// Make sure to call [`ThreadMonitor::init`] before using it.
     #[error("The monitor was invoked before it was initialized.")]
     Uninitialized,
+    #[error("There was an attempt to reinitialize the monitor.")]
+    Reinitialized,
 }
 
 /// Result type for [`ThreadMonitor`] methods.
@@ -72,6 +74,14 @@ impl ThreadMonitor {
     /// [`set_watched`]: #method.set_watched
     /// [`watch`]: #method.watch
     pub fn init(&'static self) -> Result<()> {
+        let mut state = self.state.lock().map_err(|_| Error::WatchStateLock)?;
+
+        // If we are already initialized, then do nothing. Adding another
+        // panic hook for the same monitor would cause issues.
+        if state.initialized {
+          return Err(Error::Reinitialized);
+        }
+
         let hook = panic::take_hook();
 
         panic::set_hook(Box::new(move |panic_info| {
@@ -94,7 +104,6 @@ impl ThreadMonitor {
             hook(panic_info)
         }));
 
-        let mut state = self.state.lock().map_err(|_| Error::WatchStateLock)?;
 
         state.initialized = true;
 
-- 
GitLab