// Source: manager.rs (7.35 KiB)
use std::{
    collections::HashSet,
    sync::mpsc::{self, Receiver, Sender},
    thread::{JoinHandle, Thread, ThreadId},
};

use thiserror::Error;

use crate::thread::monitor::{Error as ThreadMonitorError, ThreadMonitor};

/// Error type for [`ThreadManager`].
#[derive(Error, Debug)]
pub enum Error {
    /// Unable to receive from the monitor channel.
    #[error("Unable to receive from the monitor channel.")]
    RecvNotify,
    #[error("Unable to initialize internal monitor: {source}")]
    Monitor {
        #[from]
        source: ThreadMonitorError,
    },
}

/// Reason why a call to [`ThreadManager::join_all`] returned.
pub enum FinishReason {
    /// The monitor thread encountered an error.
    MonitorError { error: ThreadMonitorError },
    /// A thread marked as `triggers_shutdown` finished.
    ThreadFinished { thread: Thread },
    /// One of more threads panicked.
    ThreadPanic { threads: Vec<Thread> },
}

/// An abstraction for building simple multi-threaded applications.
///
/// Its goal is to simplify the process of setting up an application composed
/// of a fixed number of threads, adding in support for common tasks such as
/// detecting panics and waiting for some threads to finish.
///
/// It achieves this by focusing on a narrow use-case:
///
/// - The application has a fixed number of threads.
/// - The threads are all started at the same time during initialization.
/// - A panic in any thread is expected to trigger a shut down of the
///   application.
/// - Some threads finishing should trigger a shut down of the application.
///
/// For example, a web application has a main server thread and a metrics/stats
/// server thread.
///
/// [`ThreadManager`] is a higher level abstraction built on top of
/// [`ThreadMonitor`], which is used to detect and handle panics in other
/// threads. [`ThreadMonitor`] can be used directly if [`ThreadManager`] does
/// not meet a use-case or low level control is needed.
///
/// Limitations:
///
/// - Only threads spawned directly through the manager interface can be
///   monitored. The manager doesn't monitor other threads spawned separately,
///   even if it's done inside one of the original spawned threads.
/// - The return type of the [`JoinHandle`] is expected to be `()`.
pub struct ThreadManager {
    /// Join handle of every thread spawned through the manager, paired with
    /// the `triggers_shutdown` flag it was spawned with.
    join_handles: Vec<(JoinHandle<()>, bool)>,
    /// Monitor handed to the watcher thread started by `join_all`.
    monitor: &'static ThreadMonitor,
    /// Sending half of the completion channel; cloned into every spawned
    /// thread so it can announce that it has finished.
    monitor_tx: Sender<Thread>,
    /// Receiving half of the completion channel, read by `join_all`.
    monitor_rx: Receiver<Thread>,
}

impl ThreadManager {
    /// Instantiates a new manager.
    ///
    /// A static reference to a [`ThreadMonitor`] that has already been
    /// initialized is expected.
    /// See the documentation of [`ThreadMonitor::init`] for more details.
    pub fn new(monitor: &'static ThreadMonitor) -> ThreadManager {
        // Every managed thread (and the watcher thread started by `join_all`)
        // reports its completion through this channel.
        let (monitor_tx, monitor_rx) = mpsc::channel();

        ThreadManager {
            join_handles: Vec::new(),
            monitor,
            monitor_tx,
            monitor_rx,
        }
    }

    /// Spawns a new thread to be managed.
    ///
    /// Unlike [`std::thread::spawn`], a copy of the [`Thread`] will be
    /// returned instead of a [`std::thread::JoinHandle`]. The handle will be
    /// used internally by the manager.
    ///
    /// When `f` returns, the thread notifies the manager that it finished.
    /// If `triggers_shutdown` is `true`, that notification makes
    /// [`ThreadManager::join_all`] return. No notification is sent when `f`
    /// panics, because unwinding skips it.
    pub fn spawn<F>(&mut self, f: F, triggers_shutdown: bool) -> Thread
    where
        F: FnOnce(),
        F: Send + 'static,
    {
        let finished_tx = self.monitor_tx.clone();

        let join_handle = std::thread::spawn(move || {
            f();

            // `send` only fails when the receiving half no longer exists,
            // i.e. the manager was dropped or `join_all` already returned.
            // Nobody is listening any more, so the notification is skipped
            // instead of turning a clean exit into a panic.
            let _ = finished_tx.send(std::thread::current());
        });

        let thread = join_handle.thread().clone();

        self.join_handles.push((join_handle, triggers_shutdown));

        thread
    }

    /// Waits for any thread marked as `triggers_shutdown` to finish or for one
    /// or more threads to panic.
    /// A [`FinishReason`] will be returned providing context into why the
    /// manager returned.
    ///
    /// The manager is consumed. Threads that are still running when this
    /// method returns are not joined.
    ///
    /// # Errors
    ///
    /// Returns [`Error::RecvNotify`] if receiving from the internal
    /// notification channel fails.
    ///
    /// # Panics
    ///
    /// Panics if the internal watcher thread reported completion but joining
    /// it fails.
    pub fn join_all(self) -> Result<FinishReason, Error> {
        let watcher_tx = self.monitor_tx.clone();

        // Collect, in a single pass, the ids of all managed threads and of
        // the subset whose completion must end the wait.
        let mut watched_thread_ids: HashSet<ThreadId> =
            HashSet::with_capacity(self.join_handles.len());
        let mut watched_trigger_thread_ids: HashSet<ThreadId> = HashSet::new();
        for (join_handle, triggers_shutdown) in &self.join_handles {
            let thread_id = join_handle.thread().id();
            watched_thread_ids.insert(thread_id);
            if *triggers_shutdown {
                watched_trigger_thread_ids.insert(thread_id);
            }
        }

        // A shared reference is `Copy`; copying it is all that is needed to
        // move the `'static` monitor into the watcher thread.
        let monitor: &'static ThreadMonitor = self.monitor;
        let monitor_join_handle = std::thread::spawn(move || {
            let watch_result = monitor.watch(Some(&watched_thread_ids));

            // If `join_all` already returned (a trigger thread finished
            // first) the receiver is gone; there is nobody left to notify.
            let _ = watcher_tx.send(std::thread::current());

            watch_result
        });
        let monitor_thread_id = monitor_join_handle.thread().id();

        // Wait for a thread to finish.
        loop {
            let finished_thread = self.monitor_rx.recv().map_err(|_| Error::RecvNotify)?;
            let finished_thread_id = finished_thread.id();

            if finished_thread_id == monitor_thread_id {
                // The watcher only notifies after `watch` has returned, so
                // joining it here yields its result without blocking for long.
                let watch_result = monitor_join_handle
                    .join()
                    .expect("monitor thread notified completion and must be joinable");

                return Ok(match watch_result {
                    Ok(threads) => FinishReason::ThreadPanic { threads },
                    Err(error) => FinishReason::MonitorError { error },
                });
            }

            if watched_trigger_thread_ids.contains(&finished_thread_id) {
                return Ok(FinishReason::ThreadFinished {
                    thread: finished_thread,
                });
            }

            // A thread that does not trigger shutdown finished: keep waiting.
        }
    }
}

#[cfg(test)]
mod tests {
    use crate::thread::{manager::FinishReason, monitor::ThreadMonitor};

    use lazy_static::lazy_static;
    use std::{sync::mpsc, thread, time::Duration};

    use super::ThreadManager;

    lazy_static! {
        static ref MONITOR: ThreadMonitor = ThreadMonitor::new();
    }

    /// Starts a detached helper thread that sleeps for 10ms and then sends
    /// `true` on `release`, unblocking the managed thread under test.
    fn release_after_delay(release: mpsc::Sender<bool>) {
        thread::spawn(move || {
            thread::sleep(Duration::from_millis(10));

            release.send(true).unwrap();
        });
    }

    /// A managed thread that panics must make `join_all` report
    /// `FinishReason::ThreadPanic`.
    #[test]
    pub fn test_join_with_panic() {
        MONITOR.init().unwrap();

        let (release_tx, release_rx) = mpsc::channel();

        let mut manager = ThreadManager::new(&MONITOR);

        // The managed thread blocks until released, then panics.
        manager.spawn(
            move || {
                release_rx.recv().unwrap();

                panic!("Oh no");
            },
            true,
        );

        // Trigger the panic from a separate thread.
        release_after_delay(release_tx);

        let reason = manager.join_all().unwrap();

        assert!(matches!(reason, FinishReason::ThreadPanic { .. }));
    }

    /// A managed `triggers_shutdown` thread that returns normally must make
    /// `join_all` report `FinishReason::ThreadFinished`.
    #[test]
    pub fn test_join_with_finish() {
        MONITOR.init().unwrap();

        let (release_tx, release_rx) = mpsc::channel();

        let mut manager = ThreadManager::new(&MONITOR);

        // The managed thread blocks until released, then returns.
        manager.spawn(
            move || {
                release_rx.recv().unwrap();
            },
            true,
        );

        // Trigger the finish from a separate thread.
        release_after_delay(release_tx);

        let reason = manager.join_all().unwrap();

        assert!(matches!(reason, FinishReason::ThreadFinished { .. }));
    }
}