From b03081b54501fb2814b9728b8fcc13c1297dae57 Mon Sep 17 00:00:00 2001 From: Eduardo Trujillo <ed@chromabits.com> Date: Tue, 15 Nov 2022 12:44:20 -0800 Subject: [PATCH] refactor(thread): Use thread module from collective crate --- Cargo.lock | 2 +- Cargo.toml | 2 +- src/lib.rs | 1 - src/main.rs | 52 ++++++++-------- src/server.rs | 15 ++--- src/stats.rs | 14 ++--- src/thread.rs | 167 -------------------------------------------------- 7 files changed, 37 insertions(+), 216 deletions(-) delete mode 100644 src/thread.rs diff --git a/Cargo.lock b/Cargo.lock index 800f412..ce654a8 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -933,7 +933,7 @@ dependencies = [ [[package]] name = "collective" version = "0.1.2" -source = "git+https://gitlab.chromabits.com/etcinit/collective.git?rev=d976875136f684e04aa8e5a800d35d5a9c08e480#d976875136f684e04aa8e5a800d35d5a9c08e480" +source = "git+https://gitlab.chromabits.com/etcinit/collective.git?rev=f6f46f690d63f142ad6c5e95dd806d24b9cea6d4#f6f46f690d63f142ad6c5e95dd806d24b9cea6d4" dependencies = [ "clap", "figment", diff --git a/Cargo.toml b/Cargo.toml index 735fcf8..ebaa783 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -46,7 +46,7 @@ console-subscriber = "0.1.8" [dependencies.collective] git = "https://gitlab.chromabits.com/etcinit/collective.git" -rev = "d976875136f684e04aa8e5a800d35d5a9c08e480" +rev = "f6f46f690d63f142ad6c5e95dd806d24b9cea6d4" [dependencies.tokio] version = "1.0" diff --git a/src/lib.rs b/src/lib.rs index 901e286..c9ae179 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -12,4 +12,3 @@ pub mod files; pub mod monitor; pub mod server; pub mod stats; -pub mod thread; diff --git a/src/main.rs b/src/main.rs index a07de9d..d47c14f 100644 --- a/src/main.rs +++ b/src/main.rs @@ -11,6 +11,7 @@ use clap::{Parser, Subcommand}; use collective::{ cli::{AppOpts, ConfigurableAppOpts}, config::ConfigFileFormat, + thread, }; use lazy_static::lazy_static; use monitor::Monitor; @@ -31,7 +32,6 @@ pub mod files; pub mod monitor; pub mod server; pub mod stats; -pub mod thread; lazy_static! { static ref MONITOR: Monitor = Monitor::new(); @@ -157,7 +157,6 @@ async fn serve(config: Arc<config::Config>) -> Result<()> { // Keep track of what threads have been started. let mut server_thread_ids = HashSet::new(); - let mut server_thread_handles = vec![]; // Set up unbundler. let serve_dir = Arc::new(RwLock::new(None)); @@ -166,15 +165,15 @@ async fn serve(config: Arc<config::Config>) -> Result<()> { // Set up main server. let server = Server::new(config.server.clone(), serve_dir); - let (server_handle, server_join_handle, server_thread_handle) = server - .spawn(monitor_tx.clone()) - .await - .map_err(|err| Error::ServeError { - source: Box::new(err), - })?; + let (server_handle, server_thread_handle) = + server + .spawn(monitor_tx.clone()) + .await + .map_err(|err| Error::ServeError { + source: Box::new(err), + })?; - server_thread_ids.insert(server_join_handle.thread().id()); - server_thread_handles.push(server_thread_handle); + server_thread_ids.insert(server_thread_handle.thread().id()); // Set up optional stats server. let mut maybe_stats_server_handle = None; @@ -183,35 +182,34 @@ async fn serve(config: Arc<config::Config>) -> Result<()> { Some(stats_config) => { let stats_server = StatsServer::new(stats_config.clone(), unbundler.clone()); - let (stats_server_handle, stats_join_handle, stats_thread_handle) = stats_server + let (stats_server_handle, stats_thread_handle) = stats_server .spawn(monitor_tx.clone()) .await .context(ServeStats)?; maybe_stats_server_handle = Some(stats_server_handle); - server_thread_ids.insert(stats_join_handle.thread().id()); - server_thread_handles.push(stats_thread_handle); + server_thread_ids.insert(stats_thread_handle.thread().id()); } None => {} } - let (unbundler_join_handle, unbundler_thread_handle) = - thread::spawn(monitor_tx.clone(), move || { - let sys = System::new(); + let unbundler_thread_handle = thread::handle::spawn(monitor_tx.clone(), move || { + let sys = System::new(); - let result = sys - .block_on(async move { unbundler.enter().await }) - .context(Unbundle); + let result = sys + .block_on(async move { unbundler.enter().await }) + .context(Unbundle); - if let Err(e) = result { - error!("Unbundler failed: {:?}", e); - } - }); + if let Err(e) = result { + error!("Unbundler failed: {:?}", e); + } + }); + let unbundler_thread_id = unbundler_thread_handle.thread().id(); - let (_, monitor_thread_handle) = thread::spawn(monitor_tx.clone(), move || { + let monitor_thread_handle = thread::handle::spawn(monitor_tx.clone(), move || { let mut watched_thread_ids = HashSet::new(); - watched_thread_ids.insert(unbundler_join_handle.thread().id()); + watched_thread_ids.insert(unbundler_thread_id); for server_thread_id in server_thread_ids { watched_thread_ids.insert(server_thread_id); @@ -226,11 +224,11 @@ async fn serve(config: Arc<config::Config>) -> Result<()> { loop { monitor_rx.recv().map_err(|_| Error::RecvNotify)?; - if Ok(true) == monitor_thread_handle.has_ended() { + if Ok(true) == monitor_thread_handle.get_end_handle().has_ended() { info!("Stopping servers due to a panic."); break; - } else if Ok(true) == unbundler_thread_handle.has_ended() { + } else if Ok(true) == unbundler_thread_handle.get_end_handle().has_ended() { info!("Stopping servers due to unbundler shutdown."); break; diff --git a/src/server.rs b/src/server.rs index dfcbdbf..3ae0be1 100644 --- a/src/server.rs +++ b/src/server.rs @@ -1,8 +1,6 @@ use crate::{ config::{CompressionConfig, ServerConfig}, files::{path_context::PathContext, Files}, - thread, - thread::ThreadHandle, }; use actix_http::http::uri::InvalidUri; use actix_rt::Runtime; @@ -11,6 +9,7 @@ use actix_web::{ middleware::{self, Condition, Logger}, App, HttpServer, }; +use collective::thread::{self, handle::ThreadHandle}; use snafu::{ResultExt, Snafu}; use std::{ convert::{TryFrom, TryInto}, @@ -19,7 +18,7 @@ use std::{ mpsc::{self, RecvError, Sender}, Arc, }, - thread::{JoinHandle, Thread}, + thread::Thread, }; use tokio::sync::RwLock; @@ -57,7 +56,7 @@ impl Server { pub async fn spawn( self, notify_sender: Sender<Thread>, - ) -> Result<(ServerHandle, JoinHandle<Result<()>>, ThreadHandle)> { + ) -> Result<(ServerHandle, ThreadHandle<Result<()>>)> { log::debug!("Starting server thread"); let (tx, rx) = mpsc::channel(); @@ -83,7 +82,7 @@ impl Server { let path_contexts = Arc::new(path_contexts); - let (join_handle, thread_handle) = thread::spawn(notify_sender, move || { + let thread_handle = thread::handle::spawn(notify_sender, move || { let rt = Runtime::new().unwrap(); let srv = HttpServer::new(move || { @@ -118,10 +117,6 @@ impl Server { Ok(()) }); - Ok(( - rx.recv().context(ChannelReceive)?, - join_handle, - thread_handle, - )) + Ok((rx.recv().context(ChannelReceive)?, thread_handle)) } } diff --git a/src/stats.rs b/src/stats.rs index 2eef1e4..0c513f2 100644 --- a/src/stats.rs +++ b/src/stats.rs @@ -1,4 +1,3 @@ -use crate::thread::{self, ThreadHandle}; use crate::{ bundle::{Unbundler, UnbundlerStatus}, config::StatsConfig, @@ -7,13 +6,14 @@ use actix_rt::Runtime; use actix_web::dev::ServerHandle; use actix_web::web::Data; use actix_web::{middleware::Logger, App, HttpResponse, HttpServer, Responder}; +use collective::thread::{self, handle::ThreadHandle}; use mpsc::{RecvError, SendError, Sender}; use serde::Serialize; use snafu::{ResultExt, Snafu}; use std::{ path::PathBuf, sync::{mpsc, Arc}, - thread::{JoinHandle, Thread}, + thread::Thread, }; #[derive(Debug, Snafu)] @@ -43,10 +43,10 @@ impl StatsServer { pub async fn spawn( self, notify_sender: Sender<Thread>, - ) -> Result<(ServerHandle, JoinHandle<Result<()>>, ThreadHandle)> { + ) -> Result<(ServerHandle, ThreadHandle<Result<()>>)> { let (tx, rx) = mpsc::channel(); - let (join_handle, thread_handle) = thread::spawn(notify_sender, move || { + let thread_handle = thread::handle::spawn(notify_sender, move || { let rt = Runtime::new().unwrap(); let unbundler = self.unbundler.clone(); @@ -73,11 +73,7 @@ impl StatsServer { Ok(()) }); - Ok(( - rx.recv().context(ChannelReceive)?, - join_handle, - thread_handle, - )) + Ok((rx.recv().context(ChannelReceive)?, thread_handle)) } } diff --git a/src/thread.rs b/src/thread.rs deleted file mode 100644 index 8d48c37..0000000 --- a/src/thread.rs +++ /dev/null @@ -1,167 +0,0 @@ -//! Multithreading Utilities. - -use mpsc::Sender; -use snafu::Snafu; -use std::{ - sync::{mpsc, Arc, RwLock}, - thread::JoinHandle, -}; - -#[derive(Snafu, Debug, PartialEq)] -pub enum Error { - // Unable to obtain a read lock to check the status of a thread. - LockRead, -} - -pub type Result<T, E = Error> = std::result::Result<T, E>; - -/// A lightweight abstration over a regular thread that provides an API for -/// determining if a thread has terminated. -pub struct ThreadHandle { - ended: Arc<RwLock<bool>>, -} - -impl ThreadHandle { - /// Attempts to check if the thread has ended. - /// - /// An error may be returned if the underlying channel is disconnected. - pub fn has_ended(&self) -> Result<bool> { - let result = self.ended.read().map_err(|_| Error::LockRead)?; - - Ok(*result) - } -} - -/// Like `std::thread::spawn`, but returns a `ThreadHandle` instead. -/// -/// # Examples -/// -/// Create ten threads and wait for all threads to finish. -/// -/// ``` -/// use espresso::thread::spawn; -/// use std::{ -/// collections::HashMap, -/// sync::{mpsc, Arc, Barrier}, -/// }; -/// -/// let (monitor_tx, monitor_rx) = mpsc::channel(); -/// let barrier = Arc::new(Barrier::new(10)); -/// -/// let mut handles = HashMap::new(); -/// -/// for _ in 0..10 { -/// let bc = barrier.clone(); -/// -/// let (join_handle, thread_handle) = spawn(monitor_tx.clone(), move || { -/// /// Sync all threads. -/// bc.wait(); -/// }); -/// -/// handles.insert(join_handle.thread().id(), thread_handle); -/// } -/// -/// // Loop until we have been notified of every thread ending. -/// loop { -/// let thread = monitor_rx.recv().unwrap(); -/// -/// handles.remove(&thread.id()); -/// -/// if handles.is_empty() { -/// break; -/// } -/// } -/// ``` -pub fn spawn<F, T>( - notify_sender: Sender<std::thread::Thread>, - f: F, -) -> (JoinHandle<T>, ThreadHandle) -where - F: FnOnce() -> T, - F: Send + 'static, - T: Send + 'static, -{ - let ended = Arc::new(RwLock::new(false)); - let ended_for_spawn = ended.clone(); - - let join_handle = std::thread::spawn(move || { - let ended = ended_for_spawn.clone(); - - let result = f(); - - let mut ended = ended.write().unwrap(); - *ended = true; - notify_sender.send(std::thread::current()).unwrap(); - - result - }); - - (join_handle, ThreadHandle { ended }) -} - -#[cfg(test)] -mod tests { - use super::spawn; - use std::{ - collections::HashMap, - sync::{mpsc, Arc, Barrier}, - }; - - #[test] - fn test_spawn() { - let (monitor_tx, monitor_rx) = mpsc::channel(); - let (ready_tx, ready_rx) = mpsc::channel(); - let (end_tx, end_rx) = mpsc::channel(); - - let (join_handle, handle) = spawn(monitor_tx, move || { - ready_tx.send(()).unwrap(); - - end_rx.recv().unwrap(); - }); - - ready_rx.recv().unwrap(); - - assert_eq!((&handle).has_ended(), Ok(false)); - - end_tx.send(()).unwrap(); - - monitor_rx.recv().unwrap(); - join_handle.join().unwrap(); - - assert_eq!(handle.has_ended(), Ok(true)); - } - - #[test] - fn test_multiple() { - let (monitor_tx, monitor_rx) = mpsc::channel(); - let barrier = Arc::new(Barrier::new(11)); - - let mut handles = HashMap::new(); - - for _ in 0..10 { - let bc = barrier.clone(); - - let (join_handle, thread_handle) = spawn(monitor_tx.clone(), move || { - bc.wait(); - }); - - handles.insert(join_handle.thread().id(), thread_handle); - } - - for (_, handle) in &handles { - assert_eq!(handle.has_ended(), Ok(false)); - } - - barrier.wait(); - - loop { - let thread = monitor_rx.recv().unwrap(); - - handles.remove(&thread.id()); - - if handles.is_empty() { - break; - } - } - } -} -- GitLab