Skip to content
Snippets Groups Projects

Compare revisions

Changes are shown as if the source revision was being merged into the target revision. Learn more about comparing revisions.

Source

Select target project
No results found

Target

Select target project
  • etcinit/collective
1 result
Show changes
Commits on Source (2)
use std::{ use std::{
collections::HashSet, collections::HashSet,
sync::mpsc::{self, Receiver, Sender}, sync::mpsc::{self, Receiver, SendError, Sender},
thread::{JoinHandle, Thread, ThreadId}, thread::{JoinHandle, Thread, ThreadId},
}; };
...@@ -22,6 +22,7 @@ pub enum Error { ...@@ -22,6 +22,7 @@ pub enum Error {
} }
/// Reason why a call to [`ThreadManager::join_all`] returned. /// Reason why a call to [`ThreadManager::join_all`] returned.
#[derive(Debug)]
pub enum FinishReason { pub enum FinishReason {
/// The monitor thread encountered an error. /// The monitor thread encountered an error.
MonitorError { error: ThreadMonitorError }, MonitorError { error: ThreadMonitorError },
...@@ -29,6 +30,10 @@ pub enum FinishReason { ...@@ -29,6 +30,10 @@ pub enum FinishReason {
ThreadFinished { thread: Thread }, ThreadFinished { thread: Thread },
/// One of more threads panicked. /// One of more threads panicked.
ThreadPanic { threads: Vec<Thread> }, ThreadPanic { threads: Vec<Thread> },
/// Unable to join with monnitor thread
MonitorJoinError,
/// Monitor was unable to send a notification
MonitorSendError { error: SendError<Thread> },
} }
/// An abstraction for building simple multi-threaded applications. /// An abstraction for building simple multi-threaded applications.
...@@ -135,14 +140,16 @@ impl ThreadManager { ...@@ -135,14 +140,16 @@ impl ThreadManager {
}) })
.collect(); .collect();
#[allow(suspicious_double_ref_op)]
let monitor_for_monitor_thread = self.monitor; let monitor_for_monitor_thread = self.monitor;
let monitor_join_handle = std::thread::spawn(move || { let monitor_join_handle = std::thread::spawn(move || {
let watch_result = monitor_for_monitor_thread.watch(Some(&watched_thread_ids)); let watch_result = monitor_for_monitor_thread.watch(Some(&watched_thread_ids));
monitor_tx_for_spawn.send(std::thread::current()).unwrap(); let notification_result = monitor_tx_for_spawn.send(std::thread::current());
watch_result match notification_result {
Ok(()) => Ok(watch_result),
Err(send_error) => Err(send_error),
}
}); });
// Wait for a thread to finish. // Wait for a thread to finish.
...@@ -151,12 +158,17 @@ impl ThreadManager { ...@@ -151,12 +158,17 @@ impl ThreadManager {
let finished_thread_id = finished_thread.id(); let finished_thread_id = finished_thread.id();
if finished_thread_id == monitor_join_handle.thread().id() { if finished_thread_id == monitor_join_handle.thread().id() {
let watch_result = monitor_join_handle.join().unwrap(); return match monitor_join_handle.join() {
Ok(join_result_result) => match join_result_result {
return Ok(match watch_result { Ok(watch_result) => Ok(match watch_result {
Ok(threads) => FinishReason::ThreadPanic { threads }, Ok(threads) => FinishReason::ThreadPanic { threads },
Err(error) => FinishReason::MonitorError { error }, Err(error) => FinishReason::MonitorError { error },
}); }),
Err(error) => Ok(FinishReason::MonitorSendError { error }),
},
Err(_err) => Ok(FinishReason::MonitorJoinError),
};
} else if watched_trigger_thread_ids.contains(&finished_thread_id) { } else if watched_trigger_thread_ids.contains(&finished_thread_id) {
return Ok(FinishReason::ThreadFinished { return Ok(FinishReason::ThreadFinished {
thread: finished_thread, thread: finished_thread,
...@@ -176,16 +188,17 @@ mod tests { ...@@ -176,16 +188,17 @@ mod tests {
use super::ThreadManager; use super::ThreadManager;
lazy_static! { lazy_static! {
static ref MONITOR: ThreadMonitor = ThreadMonitor::new(); static ref MONITOR_FOR_FINISH: ThreadMonitor = ThreadMonitor::new();
static ref MONITOR_FOR_PANIC: ThreadMonitor = ThreadMonitor::new();
} }
#[test] #[test]
pub fn test_join_with_panic() { pub fn test_join_with_panic() {
MONITOR.init().unwrap(); MONITOR_FOR_PANIC.init().unwrap();
let (tx, rx) = mpsc::channel(); let (tx, rx) = mpsc::channel();
let mut manager = ThreadManager::new(&MONITOR); let mut manager = ThreadManager::new(&MONITOR_FOR_PANIC);
manager.spawn( manager.spawn(
move || { move || {
...@@ -197,27 +210,29 @@ mod tests { ...@@ -197,27 +210,29 @@ mod tests {
); );
// Create a separate thread to trigger the panic. // Create a separate thread to trigger the panic.
thread::spawn(move || { let trigger_handle = thread::spawn(move || {
thread::sleep(Duration::from_millis(10)); thread::sleep(Duration::from_millis(1000));
tx.send(true).unwrap(); tx.send(true)
}); });
let join_result = manager.join_all(); let join_result = manager.join_all();
let trigger_result = trigger_handle.join().unwrap();
assert!(trigger_result.is_ok());
assert!(matches!( assert!(matches!(
join_result.unwrap(), join_result,
FinishReason::ThreadPanic { threads: _ } Ok(FinishReason::ThreadPanic { threads: _ })
)) ));
} }
#[test] #[test]
pub fn test_join_with_finish() { pub fn test_join_with_finish() {
MONITOR.init().unwrap(); MONITOR_FOR_FINISH.init().unwrap();
let (tx, rx) = mpsc::channel(); let (tx, rx) = mpsc::channel();
let mut manager = ThreadManager::new(&MONITOR); let mut manager = ThreadManager::new(&MONITOR_FOR_FINISH);
manager.spawn( manager.spawn(
move || { move || {
...@@ -227,17 +242,20 @@ mod tests { ...@@ -227,17 +242,20 @@ mod tests {
); );
// Create a separate thread to trigger the finish. // Create a separate thread to trigger the finish.
thread::spawn(move || { let trigger_handle = thread::spawn(move || {
thread::sleep(Duration::from_millis(10)); thread::sleep(Duration::from_millis(10));
tx.send(true).unwrap(); tx.send(true)
}); });
let join_result = manager.join_all(); let join_result = manager.join_all();
let trigger_result = trigger_handle.join().unwrap();
assert!(trigger_result.is_ok());
assert!(matches!( assert!(matches!(
join_result.unwrap(), join_result,
FinishReason::ThreadFinished { thread: _ } Ok(FinishReason::ThreadFinished { thread: _ })
)) ))
} }
} }
...@@ -29,6 +29,8 @@ pub enum Error { ...@@ -29,6 +29,8 @@ pub enum Error {
/// Make sure to call [`ThreadMonitor::init`] before using it. /// Make sure to call [`ThreadMonitor::init`] before using it.
#[error("The monitor was invoked before it was initialized.")] #[error("The monitor was invoked before it was initialized.")]
Uninitialized, Uninitialized,
#[error("There was an attempt to reinitialize the monitor.")]
Reinitialized,
} }
/// Result type for [`ThreadMonitor`] methods. /// Result type for [`ThreadMonitor`] methods.
...@@ -72,6 +74,14 @@ impl ThreadMonitor { ...@@ -72,6 +74,14 @@ impl ThreadMonitor {
/// [`set_watched`]: #method.set_watched /// [`set_watched`]: #method.set_watched
/// [`watch`]: #method.watch /// [`watch`]: #method.watch
pub fn init(&'static self) -> Result<()> { pub fn init(&'static self) -> Result<()> {
let mut state = self.state.lock().map_err(|_| Error::WatchStateLock)?;
// If we are already initialized, then do nothing. Adding another
// panic hook for the same monitor would cause issues.
if state.initialized {
return Err(Error::Reinitialized);
}
let hook = panic::take_hook(); let hook = panic::take_hook();
panic::set_hook(Box::new(move |panic_info| { panic::set_hook(Box::new(move |panic_info| {
...@@ -94,8 +104,6 @@ impl ThreadMonitor { ...@@ -94,8 +104,6 @@ impl ThreadMonitor {
hook(panic_info) hook(panic_info)
})); }));
let mut state = self.state.lock().map_err(|_| Error::WatchStateLock)?;
state.initialized = true; state.initialized = true;
Ok(()) Ok(())
......