1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200
//! Multithreading Utilities.
use mpsc::Sender;
use std::{
sync::{mpsc, Arc, RwLock},
thread::{JoinHandle, Thread},
};
use thiserror::Error;
/// Error type for [`ThreadHandle`].
#[derive(Error, Debug, PartialEq)]
pub enum Error {
/// Unable to obtain a read lock to check the status of a thread.
#[error("Unable to obtain a read lock to check the status of a thread.")]
LockRead,
}
/// Result type for [`ThreadHandle`].
pub type Result<T, E = Error> = std::result::Result<T, E>;
/// A lightweight abstraction over a regular thread that provides an API for
/// determining if a thread has terminated.
///
/// Returned by [`spawn`]. Dropping a `ThreadHandle` drops the inner
/// [`JoinHandle`], which detaches the thread rather than stopping it.
#[derive(Debug)]
pub struct ThreadHandle<T> {
    /// Handle used to wait for the thread and retrieve its result.
    join_handle: JoinHandle<T>,
    /// Flag shared with the spawned thread and with every [`ThreadEndHandle`]
    /// obtained from this handle.
    ended: Arc<RwLock<bool>>,
}

impl<T> ThreadHandle<T> {
    /// Returns a [`ThreadEndHandle`], which can be used for determining if a
    /// thread has ended.
    ///
    /// The returned handle holds its own reference to the shared flag, so it
    /// remains usable after this `ThreadHandle` has been joined or dropped.
    pub fn get_end_handle(&self) -> ThreadEndHandle {
        ThreadEndHandle {
            ended: Arc::clone(&self.ended),
        }
    }

    /// Extracts a handle to the underlying thread.
    pub fn thread(&self) -> &Thread {
        self.join_handle.thread()
    }

    /// Waits for the associated thread to finish and returns its result.
    ///
    /// See [`std::thread::JoinHandle::join`].
    ///
    /// # Errors
    ///
    /// Returns `Err` carrying the panic payload if the thread panicked.
    pub fn join(self) -> std::thread::Result<T> {
        self.join_handle.join()
    }
}
pub struct ThreadEndHandle {
ended: Arc<RwLock<bool>>,
}
impl ThreadEndHandle {
/// Attempts to check if the thread has ended.
///
/// A [`Error::LockRead`] error may be returned if the underlying channel is
/// disconnected.
pub fn has_ended(&self) -> Result<bool> {
let result = self.ended.read().map_err(|_| Error::LockRead)?;
Ok(*result)
}
}
/// Like [`std::thread::spawn`], but returns a [`ThreadHandle`] instead.
///
/// # Examples
///
/// Create ten threads and wait for all threads to finish.
///
/// ```
/// use collective::thread::handle::spawn;
/// use std::{
/// collections::HashMap,
/// sync::{mpsc, Arc, Barrier},
/// };
///
/// let (monitor_tx, monitor_rx) = mpsc::channel();
/// let barrier = Arc::new(Barrier::new(10));
///
/// let mut end_handles = HashMap::new();
///
/// for _ in 0..10 {
/// let bc = barrier.clone();
///
/// let handle = spawn(monitor_tx.clone(), move || {
/// /// Sync all threads.
/// bc.wait();
/// });
///
/// end_handles.insert(handle.thread().id(), handle.get_end_handle());
/// }
///
/// // Loop until we have been notified of every thread ending.
/// loop {
/// let thread = monitor_rx.recv().unwrap();
///
/// end_handles.remove(&thread.id());
///
/// if end_handles.is_empty() {
/// break;
/// }
/// }
/// ```
pub fn spawn<F, T>(notify_sender: Sender<std::thread::Thread>, f: F) -> ThreadHandle<T>
where
F: FnOnce() -> T,
F: Send + 'static,
T: Send + 'static,
{
let ended = Arc::new(RwLock::new(false));
let ended_for_spawn = ended.clone();
let join_handle = std::thread::spawn(move || {
let ended = ended_for_spawn.clone();
let result = f();
let mut ended = ended.write().unwrap();
*ended = true;
// Try to notify. This is a best-effort approach since the receiver may
// be already deallocated in some scenarios (e.g. The application is
// terminating).
let _ = notify_sender.send(std::thread::current());
result
});
ThreadHandle { ended, join_handle }
}
#[cfg(test)]
mod tests {
    use super::spawn;
    use std::{
        collections::HashMap,
        sync::{mpsc, Arc, Barrier},
    };

    /// A single thread reports "not ended" while blocked and "ended" after it
    /// has been released, notified and joined.
    #[test]
    fn test_spawn() {
        let (monitor_tx, monitor_rx) = mpsc::channel();
        let (ready_tx, ready_rx) = mpsc::channel();
        let (end_tx, end_rx) = mpsc::channel();
        let handle = spawn(monitor_tx, move || {
            ready_tx.send(()).unwrap();
            end_rx.recv().unwrap();
        });
        ready_rx.recv().unwrap();
        let end_handle = handle.get_end_handle();
        assert_eq!(end_handle.has_ended(), Ok(false));
        end_tx.send(()).unwrap();
        monitor_rx.recv().unwrap();
        handle.join().unwrap();
        assert_eq!(end_handle.has_ended(), Ok(true));
    }

    /// `join` yields the value returned by the spawned closure.
    #[test]
    fn test_join_returns_closure_value() {
        // Keep the receiver alive so the notification send succeeds.
        let (monitor_tx, _monitor_rx) = mpsc::channel();
        let handle = spawn(monitor_tx, || 42);
        assert_eq!(handle.join().unwrap(), 42);
    }

    /// Every spawned thread sends exactly one notification, and by the time
    /// it arrives the thread's end flag has been set.
    #[test]
    fn test_multiple() {
        let (monitor_tx, monitor_rx) = mpsc::channel();
        // Ten workers plus this thread: nobody proceeds until all are spawned
        // and the "not ended" assertions below have run.
        let barrier = Arc::new(Barrier::new(11));
        let mut end_handles = HashMap::new();
        for _ in 0..10 {
            let bc = barrier.clone();
            let handle = spawn(monitor_tx.clone(), move || {
                bc.wait();
            });
            end_handles.insert(handle.thread().id(), handle.get_end_handle());
        }
        for end_handle in end_handles.values() {
            assert_eq!(end_handle.has_ended(), Ok(false));
        }
        barrier.wait();
        while !end_handles.is_empty() {
            let thread = monitor_rx.recv().unwrap();
            let end_handle = end_handles
                .remove(&thread.id())
                .expect("notification came from a thread spawned by this test");
            // The flag is set before the notification is sent.
            assert_eq!(end_handle.has_ended(), Ok(true));
        }
    }
}