From b03081b54501fb2814b9728b8fcc13c1297dae57 Mon Sep 17 00:00:00 2001
From: Eduardo Trujillo <ed@chromabits.com>
Date: Tue, 15 Nov 2022 12:44:20 -0800
Subject: [PATCH] refactor(thread): Use thread module from collective crate

---
 Cargo.lock    |   2 +-
 Cargo.toml    |   2 +-
 src/lib.rs    |   1 -
 src/main.rs   |  52 ++++++++--------
 src/server.rs |  15 ++---
 src/stats.rs  |  14 ++---
 src/thread.rs | 167 --------------------------------------------------
 7 files changed, 37 insertions(+), 216 deletions(-)
 delete mode 100644 src/thread.rs

diff --git a/Cargo.lock b/Cargo.lock
index 800f412..ce654a8 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -933,7 +933,7 @@ dependencies = [
 [[package]]
 name = "collective"
 version = "0.1.2"
-source = "git+https://gitlab.chromabits.com/etcinit/collective.git?rev=d976875136f684e04aa8e5a800d35d5a9c08e480#d976875136f684e04aa8e5a800d35d5a9c08e480"
+source = "git+https://gitlab.chromabits.com/etcinit/collective.git?rev=f6f46f690d63f142ad6c5e95dd806d24b9cea6d4#f6f46f690d63f142ad6c5e95dd806d24b9cea6d4"
 dependencies = [
  "clap",
  "figment",
diff --git a/Cargo.toml b/Cargo.toml
index 735fcf8..ebaa783 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -46,7 +46,7 @@ console-subscriber = "0.1.8"
 
 [dependencies.collective]
 git = "https://gitlab.chromabits.com/etcinit/collective.git"
-rev = "d976875136f684e04aa8e5a800d35d5a9c08e480"
+rev = "f6f46f690d63f142ad6c5e95dd806d24b9cea6d4"
 
 [dependencies.tokio]
 version = "1.0"
diff --git a/src/lib.rs b/src/lib.rs
index 901e286..c9ae179 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -12,4 +12,3 @@ pub mod files;
 pub mod monitor;
 pub mod server;
 pub mod stats;
-pub mod thread;
diff --git a/src/main.rs b/src/main.rs
index a07de9d..d47c14f 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -11,6 +11,7 @@ use clap::{Parser, Subcommand};
 use collective::{
   cli::{AppOpts, ConfigurableAppOpts},
   config::ConfigFileFormat,
+  thread,
 };
 use lazy_static::lazy_static;
 use monitor::Monitor;
@@ -31,7 +32,6 @@ pub mod files;
 pub mod monitor;
 pub mod server;
 pub mod stats;
-pub mod thread;
 
 lazy_static! {
   static ref MONITOR: Monitor = Monitor::new();
@@ -157,7 +157,6 @@ async fn serve(config: Arc<config::Config>) -> Result<()> {
 
   // Keep track of what threads have been started.
   let mut server_thread_ids = HashSet::new();
-  let mut server_thread_handles = vec![];
 
   // Set up unbundler.
   let serve_dir = Arc::new(RwLock::new(None));
@@ -166,15 +165,15 @@ async fn serve(config: Arc<config::Config>) -> Result<()> {
   // Set up main server.
   let server = Server::new(config.server.clone(), serve_dir);
 
-  let (server_handle, server_join_handle, server_thread_handle) = server
-    .spawn(monitor_tx.clone())
-    .await
-    .map_err(|err| Error::ServeError {
-      source: Box::new(err),
-    })?;
+  let (server_handle, server_thread_handle) =
+    server
+      .spawn(monitor_tx.clone())
+      .await
+      .map_err(|err| Error::ServeError {
+        source: Box::new(err),
+      })?;
 
-  server_thread_ids.insert(server_join_handle.thread().id());
-  server_thread_handles.push(server_thread_handle);
+  server_thread_ids.insert(server_thread_handle.thread().id());
 
   // Set up optional stats server.
   let mut maybe_stats_server_handle = None;
@@ -183,35 +182,34 @@ async fn serve(config: Arc<config::Config>) -> Result<()> {
     Some(stats_config) => {
       let stats_server = StatsServer::new(stats_config.clone(), unbundler.clone());
 
-      let (stats_server_handle, stats_join_handle, stats_thread_handle) = stats_server
+      let (stats_server_handle, stats_thread_handle) = stats_server
         .spawn(monitor_tx.clone())
         .await
         .context(ServeStats)?;
 
       maybe_stats_server_handle = Some(stats_server_handle);
-      server_thread_ids.insert(stats_join_handle.thread().id());
-      server_thread_handles.push(stats_thread_handle);
+      server_thread_ids.insert(stats_thread_handle.thread().id());
     }
     None => {}
   }
 
-  let (unbundler_join_handle, unbundler_thread_handle) =
-    thread::spawn(monitor_tx.clone(), move || {
-      let sys = System::new();
+  let unbundler_thread_handle = thread::handle::spawn(monitor_tx.clone(), move || {
+    let sys = System::new();
 
-      let result = sys
-        .block_on(async move { unbundler.enter().await })
-        .context(Unbundle);
+    let result = sys
+      .block_on(async move { unbundler.enter().await })
+      .context(Unbundle);
 
-      if let Err(e) = result {
-        error!("Unbundler failed: {:?}", e);
-      }
-    });
+    if let Err(e) = result {
+      error!("Unbundler failed: {:?}", e);
+    }
+  });
+  let unbundler_thread_id = unbundler_thread_handle.thread().id();
 
-  let (_, monitor_thread_handle) = thread::spawn(monitor_tx.clone(), move || {
+  let monitor_thread_handle = thread::handle::spawn(monitor_tx.clone(), move || {
     let mut watched_thread_ids = HashSet::new();
 
-    watched_thread_ids.insert(unbundler_join_handle.thread().id());
+    watched_thread_ids.insert(unbundler_thread_id);
 
     for server_thread_id in server_thread_ids {
       watched_thread_ids.insert(server_thread_id);
@@ -226,11 +224,11 @@ async fn serve(config: Arc<config::Config>) -> Result<()> {
   loop {
     monitor_rx.recv().map_err(|_| Error::RecvNotify)?;
 
-    if Ok(true) == monitor_thread_handle.has_ended() {
+    if Ok(true) == monitor_thread_handle.get_end_handle().has_ended() {
       info!("Stopping servers due to a panic.");
 
       break;
-    } else if Ok(true) == unbundler_thread_handle.has_ended() {
+    } else if Ok(true) == unbundler_thread_handle.get_end_handle().has_ended() {
       info!("Stopping servers due to unbundler shutdown.");
 
       break;
diff --git a/src/server.rs b/src/server.rs
index dfcbdbf..3ae0be1 100644
--- a/src/server.rs
+++ b/src/server.rs
@@ -1,8 +1,6 @@
 use crate::{
   config::{CompressionConfig, ServerConfig},
   files::{path_context::PathContext, Files},
-  thread,
-  thread::ThreadHandle,
 };
 use actix_http::http::uri::InvalidUri;
 use actix_rt::Runtime;
@@ -11,6 +9,7 @@ use actix_web::{
   middleware::{self, Condition, Logger},
   App, HttpServer,
 };
+use collective::thread::{self, handle::ThreadHandle};
 use snafu::{ResultExt, Snafu};
 use std::{
   convert::{TryFrom, TryInto},
@@ -19,7 +18,7 @@ use std::{
     mpsc::{self, RecvError, Sender},
     Arc,
   },
-  thread::{JoinHandle, Thread},
+  thread::Thread,
 };
 use tokio::sync::RwLock;
 
@@ -57,7 +56,7 @@ impl Server {
   pub async fn spawn(
     self,
     notify_sender: Sender<Thread>,
-  ) -> Result<(ServerHandle, JoinHandle<Result<()>>, ThreadHandle)> {
+  ) -> Result<(ServerHandle, ThreadHandle<Result<()>>)> {
     log::debug!("Starting server thread");
 
     let (tx, rx) = mpsc::channel();
@@ -83,7 +82,7 @@ impl Server {
 
     let path_contexts = Arc::new(path_contexts);
 
-    let (join_handle, thread_handle) = thread::spawn(notify_sender, move || {
+    let thread_handle = thread::handle::spawn(notify_sender, move || {
       let rt = Runtime::new().unwrap();
 
       let srv = HttpServer::new(move || {
@@ -118,10 +117,6 @@ impl Server {
       Ok(())
     });
 
-    Ok((
-      rx.recv().context(ChannelReceive)?,
-      join_handle,
-      thread_handle,
-    ))
+    Ok((rx.recv().context(ChannelReceive)?, thread_handle))
   }
 }
diff --git a/src/stats.rs b/src/stats.rs
index 2eef1e4..0c513f2 100644
--- a/src/stats.rs
+++ b/src/stats.rs
@@ -1,4 +1,3 @@
-use crate::thread::{self, ThreadHandle};
 use crate::{
   bundle::{Unbundler, UnbundlerStatus},
   config::StatsConfig,
@@ -7,13 +6,14 @@ use actix_rt::Runtime;
 use actix_web::dev::ServerHandle;
 use actix_web::web::Data;
 use actix_web::{middleware::Logger, App, HttpResponse, HttpServer, Responder};
+use collective::thread::{self, handle::ThreadHandle};
 use mpsc::{RecvError, SendError, Sender};
 use serde::Serialize;
 use snafu::{ResultExt, Snafu};
 use std::{
   path::PathBuf,
   sync::{mpsc, Arc},
-  thread::{JoinHandle, Thread},
+  thread::Thread,
 };
 
 #[derive(Debug, Snafu)]
@@ -43,10 +43,10 @@ impl StatsServer {
   pub async fn spawn(
     self,
     notify_sender: Sender<Thread>,
-  ) -> Result<(ServerHandle, JoinHandle<Result<()>>, ThreadHandle)> {
+  ) -> Result<(ServerHandle, ThreadHandle<Result<()>>)> {
     let (tx, rx) = mpsc::channel();
 
-    let (join_handle, thread_handle) = thread::spawn(notify_sender, move || {
+    let thread_handle = thread::handle::spawn(notify_sender, move || {
       let rt = Runtime::new().unwrap();
 
       let unbundler = self.unbundler.clone();
@@ -73,11 +73,7 @@ impl StatsServer {
       Ok(())
     });
 
-    Ok((
-      rx.recv().context(ChannelReceive)?,
-      join_handle,
-      thread_handle,
-    ))
+    Ok((rx.recv().context(ChannelReceive)?, thread_handle))
   }
 }
 
diff --git a/src/thread.rs b/src/thread.rs
deleted file mode 100644
index 8d48c37..0000000
--- a/src/thread.rs
+++ /dev/null
@@ -1,167 +0,0 @@
-//! Multithreading Utilities.
-
-use mpsc::Sender;
-use snafu::Snafu;
-use std::{
-  sync::{mpsc, Arc, RwLock},
-  thread::JoinHandle,
-};
-
-#[derive(Snafu, Debug, PartialEq)]
-pub enum Error {
-  // Unable to obtain a read lock to check the status of a thread.
-  LockRead,
-}
-
-pub type Result<T, E = Error> = std::result::Result<T, E>;
-
-/// A lightweight abstration over a regular thread that provides an API for
-/// determining if a thread has terminated.
-pub struct ThreadHandle {
-  ended: Arc<RwLock<bool>>,
-}
-
-impl ThreadHandle {
-  /// Attempts to check if the thread has ended.
-  ///
-  /// An error may be returned if the underlying channel is disconnected.
-  pub fn has_ended(&self) -> Result<bool> {
-    let result = self.ended.read().map_err(|_| Error::LockRead)?;
-
-    Ok(*result)
-  }
-}
-
-/// Like `std::thread::spawn`, but returns a `ThreadHandle` instead.
-///
-/// # Examples
-///
-/// Create ten threads and wait for all threads to finish.
-///
-/// ```
-/// use espresso::thread::spawn;
-/// use std::{
-///    collections::HashMap,
-///    sync::{mpsc, Arc, Barrier},
-///  };
-///
-/// let (monitor_tx, monitor_rx) = mpsc::channel();
-/// let barrier = Arc::new(Barrier::new(10));
-///
-/// let mut handles = HashMap::new();
-///
-/// for _ in 0..10 {
-///   let bc = barrier.clone();
-///
-///   let (join_handle, thread_handle) = spawn(monitor_tx.clone(), move || {
-///     /// Sync all threads.
-///     bc.wait();
-///   });
-///
-///   handles.insert(join_handle.thread().id(), thread_handle);
-/// }
-///
-/// // Loop until we have been notified of every thread ending.
-/// loop {
-///   let thread = monitor_rx.recv().unwrap();
-///
-///   handles.remove(&thread.id());
-///
-///   if handles.is_empty() {
-///     break;
-///   }
-/// }
-/// ```
-pub fn spawn<F, T>(
-  notify_sender: Sender<std::thread::Thread>,
-  f: F,
-) -> (JoinHandle<T>, ThreadHandle)
-where
-  F: FnOnce() -> T,
-  F: Send + 'static,
-  T: Send + 'static,
-{
-  let ended = Arc::new(RwLock::new(false));
-  let ended_for_spawn = ended.clone();
-
-  let join_handle = std::thread::spawn(move || {
-    let ended = ended_for_spawn.clone();
-
-    let result = f();
-
-    let mut ended = ended.write().unwrap();
-    *ended = true;
-    notify_sender.send(std::thread::current()).unwrap();
-
-    result
-  });
-
-  (join_handle, ThreadHandle { ended })
-}
-
-#[cfg(test)]
-mod tests {
-  use super::spawn;
-  use std::{
-    collections::HashMap,
-    sync::{mpsc, Arc, Barrier},
-  };
-
-  #[test]
-  fn test_spawn() {
-    let (monitor_tx, monitor_rx) = mpsc::channel();
-    let (ready_tx, ready_rx) = mpsc::channel();
-    let (end_tx, end_rx) = mpsc::channel();
-
-    let (join_handle, handle) = spawn(monitor_tx, move || {
-      ready_tx.send(()).unwrap();
-
-      end_rx.recv().unwrap();
-    });
-
-    ready_rx.recv().unwrap();
-
-    assert_eq!((&handle).has_ended(), Ok(false));
-
-    end_tx.send(()).unwrap();
-
-    monitor_rx.recv().unwrap();
-    join_handle.join().unwrap();
-
-    assert_eq!(handle.has_ended(), Ok(true));
-  }
-
-  #[test]
-  fn test_multiple() {
-    let (monitor_tx, monitor_rx) = mpsc::channel();
-    let barrier = Arc::new(Barrier::new(11));
-
-    let mut handles = HashMap::new();
-
-    for _ in 0..10 {
-      let bc = barrier.clone();
-
-      let (join_handle, thread_handle) = spawn(monitor_tx.clone(), move || {
-        bc.wait();
-      });
-
-      handles.insert(join_handle.thread().id(), thread_handle);
-    }
-
-    for (_, handle) in &handles {
-      assert_eq!(handle.has_ended(), Ok(false));
-    }
-
-    barrier.wait();
-
-    loop {
-      let thread = monitor_rx.recv().unwrap();
-
-      handles.remove(&thread.id());
-
-      if handles.is_empty() {
-        break;
-      }
-    }
-  }
-}
-- 
GitLab