From c32b31bf98095ba7480cc6a986c9199fd80ca299 Mon Sep 17 00:00:00 2001
From: Eduardo Trujillo <ed@chromabits.com>
Date: Tue, 15 Nov 2022 10:02:32 -0800
Subject: [PATCH] refactor: Use tokio's async-aware RwLock

---
 src/bundle/mod.rs    | 42 +++++++++++++++++++++---------------------
 src/files/mod.rs     |  3 ++-
 src/files/service.rs | 22 ++++++++++------------
 src/main.rs          |  3 ++-
 src/server.rs        |  3 ++-
 src/stats.rs         | 17 +++++++++--------
 6 files changed, 46 insertions(+), 44 deletions(-)

diff --git a/src/bundle/mod.rs b/src/bundle/mod.rs
index 68cd98e..69f291f 100644
--- a/src/bundle/mod.rs
+++ b/src/bundle/mod.rs
@@ -3,11 +3,11 @@ use rundir::RunDir;
 use s3::packager::S3BundlePackager;
 use serde::Serialize;
 use snafu::{ResultExt, Snafu};
-use std::{
-  path::PathBuf,
-  sync::{Arc, RwLock},
+use std::{path::PathBuf, sync::Arc};
+use tokio::{
+  sync::RwLock,
+  time::{interval, Duration},
 };
-use tokio::time::{interval, Duration};
 
 use self::{local_dir::poller::LocalBundlePoller, s3::poller::S3BundlePoller};
 
@@ -145,20 +145,20 @@ impl Unbundler {
     }
   }
 
-  pub fn get_status(&self) -> Result<UnbundlerStatus> {
-    let state = self.state.read().map_err(|_| Error::LockRead)?;
+  pub async fn get_status(&self) -> Result<UnbundlerStatus> {
+    let state = self.state.read().await;
 
     Ok(state.status)
   }
 
-  pub fn get_serve_dir(&self) -> Result<Option<PathBuf>> {
-    let serve_dir = self.serve_dir.read().map_err(|_| Error::LockRead)?;
+  pub async fn get_serve_dir(&self) -> Result<Option<PathBuf>> {
+    let serve_dir = self.serve_dir.read().await;
 
     Ok(serve_dir.clone())
   }
 
   pub async fn enter(&self) -> Result<()> {
-    self.init()?;
+    self.init().await?;
 
     let mut interval = interval(Duration::from_secs(self.config.unbundler.poll_seconds));
     let ctrl_c_fut = tokio::signal::ctrl_c();
@@ -181,13 +181,13 @@ impl Unbundler {
       }
     }
 
-    self.deinit()
+    self.deinit().await
   }
 
-  fn init(&self) -> Result<()> {
+  async fn init(&self) -> Result<()> {
     info!("Unbundler: Initializing...");
 
-    let mut state = self.state.write().map_err(|_| Error::LockWrite)?;
+    let mut state = self.state.write().await;
 
     state.rundir.initialize().context(InitRunDir)?;
     state.temp_dir = Some(state.rundir.create_subdir("temp").context(InitRunDir)?);
@@ -205,7 +205,7 @@ impl Unbundler {
     //
     // Lock will be released when `initial_state` goes out of scope.
     let active_bundle = {
-      let mut initial_state = self.state.write().map_err(|_| Error::LockWrite)?;
+      let mut initial_state = self.state.write().await;
 
       let active_bundle = initial_state.active_bundle.clone();
       if active_bundle.is_none() {
@@ -220,7 +220,7 @@ impl Unbundler {
     // Invoke the poller.
     let result = match self.poller.poll(&active_bundle).await {
       Err(err) => {
-        let mut state = self.state.write().map_err(|_| Error::LockWrite)?;
+        let mut state = self.state.write().await;
 
         // Rollback status if the poll fails.
         if active_bundle.is_none() {
@@ -237,7 +237,7 @@ impl Unbundler {
 
     match result {
       poller::PollResult::Skip => {
-        let mut state = self.state.write().map_err(|_| Error::LockWrite)?;
+        let mut state = self.state.write().await;
 
         state.status = UnbundlerStatus::Ready;
 
@@ -246,21 +246,21 @@ impl Unbundler {
         Ok(())
       }
       poller::PollResult::StaticUpdateReady { etag, path } => {
-        let mut state = self.state.write().map_err(|_| Error::LockWrite)?;
+        let mut state = self.state.write().await;
 
         // Replacing active bundle.
         state.active_bundle = Some(Bundle { etag });
         state.staging_bundle = None;
         state.status = UnbundlerStatus::Ready;
 
-        let mut serve_dir = self.serve_dir.write().map_err(|_| Error::LockWrite)?;
+        let mut serve_dir = self.serve_dir.write().await;
 
         serve_dir.replace(path);
 
         Ok(())
       }
       poller::PollResult::UpdateReady { etag } => {
-        let mut state = self.state.write().map_err(|_| Error::LockWrite)?;
+        let mut state = self.state.write().await;
 
         if state.rundir.subdir_exists(&etag).context(SubDirError)? {
           warn!("Unbundler: Skipping update. Subdir already exists.");
@@ -292,7 +292,7 @@ impl Unbundler {
         state.staging_bundle = None;
         state.status = UnbundlerStatus::Ready;
 
-        let mut serve_dir = self.serve_dir.write().map_err(|_| Error::LockWrite)?;
+        let mut serve_dir = self.serve_dir.write().await;
 
         serve_dir.replace(newdir);
 
@@ -301,8 +301,8 @@ impl Unbundler {
     }
   }
 
-  fn deinit(&self) -> Result<()> {
-    let mut state = self.state.write().map_err(|_| Error::LockWrite)?;
+  async fn deinit(&self) -> Result<()> {
+    let mut state = self.state.write().await;
 
     state.status = UnbundlerStatus::Idle;
     state.rundir.cleanup().context(DeinitRunDir)?;
diff --git a/src/files/mod.rs b/src/files/mod.rs
index 7d85c67..9eb8f71 100644
--- a/src/files/mod.rs
+++ b/src/files/mod.rs
@@ -4,7 +4,7 @@
 use std::cell::RefCell;
 use std::path::PathBuf;
 use std::rc::Rc;
-use std::sync::{Arc, RwLock};
+use std::sync::{Arc};
 
 use actix_service::boxed::{self, BoxServiceFactory};
 use actix_service::{IntoServiceFactory, ServiceFactory, ServiceFactoryExt};
@@ -17,6 +17,7 @@ use futures_util::future::LocalBoxFuture;
 
 use path_context::PathContext;
 use service::FilesService;
+use tokio::sync::RwLock;
 
 use self::service::FilesServiceInner;
 
diff --git a/src/files/service.rs b/src/files/service.rs
index b1e3771..8700c5c 100644
--- a/src/files/service.rs
+++ b/src/files/service.rs
@@ -20,13 +20,14 @@ use async_recursion::async_recursion;
 use futures_util::future::LocalBoxFuture;
 use snafu::ResultExt;
 use snafu::Snafu;
+use tokio::sync::RwLock;
 use std::{
   convert::TryInto,
   io,
   ops::Deref,
   path::{Path, PathBuf},
   rc::Rc,
-  sync::{Arc, RwLock},
+  sync::{Arc},
   task::{Context, Poll},
 };
 
@@ -163,8 +164,8 @@ impl FilesServiceInner {
     }
   }
 
-  fn get_request_context_for_request(&self, req: &ServiceRequest) -> Result<RequestContext, Error> {
-    let serve_dir = self.get_serve_dir()?;
+  async fn get_request_context_for_request(&self, req: &ServiceRequest) -> Result<RequestContext, Error> {
+    let serve_dir = self.get_serve_dir().await?;
     let path_from_request: UriPathBuf = req.try_into().context(BuildUriPath)?;
 
     Ok(RequestContext {
@@ -179,7 +180,7 @@ impl FilesServiceInner {
     e: Error,
     req: ServiceRequest,
   ) -> Result<ServiceResponse, ActixError> {
-    let request_context = self.get_request_context_for_request(&req);
+    let request_context = self.get_request_context_for_request(&req).await;
 
     self.handle_err(e, req, request_context.ok()).await
   }
@@ -382,11 +383,8 @@ impl FilesServiceInner {
     }
   }
 
-  fn get_serve_dir(&self) -> Result<PathBuf, Error> {
-    let maybe_dir = match self.directory.read().map(|dir| dir.clone()) {
-      Ok(maybe_dir) => maybe_dir,
-      Err(_) => return Err(Error::ServerDirReadLockFail),
-    };
+  async fn get_serve_dir(&self) -> Result<PathBuf, Error> {
+    let maybe_dir = self.directory.read().await.clone();
 
     match maybe_dir {
       Some(dir) => Ok(dir),
@@ -394,11 +392,11 @@ impl FilesServiceInner {
     }
   }
 
-  fn get_canonical_path_from_request(
+  async fn get_canonical_path_from_request(
     &self,
     req: &ServiceRequest,
   ) -> Result<(PathBuf, PathBuf, UriPathBuf), Error> {
-    let serve_dir = self.get_serve_dir()?;
+    let serve_dir = self.get_serve_dir().await?;
 
     let path_from_request: UriPathBuf = req.try_into().context(BuildUriPath)?;
 
@@ -463,7 +461,7 @@ impl Service<ServiceRequest> for FilesService {
           .await;
       }
 
-      let maybe_path = this.get_canonical_path_from_request(&req);
+      let maybe_path = this.get_canonical_path_from_request(&req).await;
       let (path, serve_dir, path_from_request) = match maybe_path {
         Ok(path) => path,
         Err(e) => return this.handle_early_err(e, req).await,
diff --git a/src/main.rs b/src/main.rs
index 4dd2cff..a9b2c1e 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -11,11 +11,12 @@ use monitor::Monitor;
 use server::Server;
 use snafu::{ResultExt, Snafu};
 use stats::StatsServer;
+use tokio::sync::RwLock;
 use std::{
   collections::HashSet,
   net::SocketAddr,
   path::PathBuf,
-  sync::{mpsc, Arc, RwLock},
+  sync::{mpsc, Arc},
 };
 
 pub mod bundle;
diff --git a/src/server.rs b/src/server.rs
index 19554ba..331fbcd 100644
--- a/src/server.rs
+++ b/src/server.rs
@@ -12,12 +12,13 @@ use actix_web::{
   App, HttpServer,
 };
 use snafu::{ResultExt, Snafu};
+use tokio::sync::RwLock;
 use std::{
   convert::{TryFrom, TryInto},
   path::PathBuf,
   sync::{
     mpsc::{self, RecvError, Sender},
-    Arc, RwLock,
+    Arc,
   },
   thread::{JoinHandle, Thread},
 };
diff --git a/src/stats.rs b/src/stats.rs
index 835fe1b..b374f45 100644
--- a/src/stats.rs
+++ b/src/stats.rs
@@ -68,9 +68,7 @@ impl StatsServer {
 
       tx.send(srv.handle()).context(ChannelSend)?;
 
-      rt.block_on(async {
-        srv.await
-      }).context(SystemRun)?;
+      rt.block_on(async { srv.await }).context(SystemRun)?;
 
       Ok(())
     });
@@ -90,15 +88,18 @@ struct GetStatusResponse {
 }
 
 async fn try_get_status(data: actix_web::web::Data<State>) -> anyhow::Result<String> {
-  let status = data.unbundler.get_status()?;
-    let serve_dir = data.unbundler.get_serve_dir()?;
+  let status = data.unbundler.get_status().await?;
+  let serve_dir = data.unbundler.get_serve_dir().await?;
 
-  Ok(serde_json::to_string(&GetStatusResponse { status, serve_dir })?)
+  Ok(serde_json::to_string(&GetStatusResponse {
+    status,
+    serve_dir,
+  })?)
 }
 
 async fn get_status(data: actix_web::web::Data<State>) -> impl Responder {
   let result = try_get_status(data).await;
-  
+
   match result {
     Ok(response) => HttpResponse::Ok().body(response),
     Err(_) => HttpResponse::InternalServerError().body("Internal Server Error"),
@@ -106,7 +107,7 @@ async fn get_status(data: actix_web::web::Data<State>) -> impl Responder {
 }
 
 async fn get_health(data: actix_web::web::Data<State>) -> impl Responder {
-  match data.unbundler.get_status() {
+  match data.unbundler.get_status().await {
     Ok(status) => match status {
       UnbundlerStatus::Ready => HttpResponse::Ok().body("Ready"),
       UnbundlerStatus::Polling => HttpResponse::Ok().body("Ready"),
-- 
GitLab