From bf150723a4fa1904460cca66371b0e485122172b Mon Sep 17 00:00:00 2001
From: Robin Appelman <robin@icewind.nl>
Date: Thu, 11 Nov 2021 17:23:23 +0100
Subject: [PATCH] add support for listening with tls

Signed-off-by: Robin Appelman <robin@icewind.nl>
---
 Cargo.lock           |  1 +
 Cargo.toml           |  2 +-
 README.md            |  7 +++++++
 src/config.rs        | 33 +++++++++++++++++++++++++++++++++
 src/lib.rs           | 36 +++++++++++++++++++++++++++---------
 src/main.rs          |  9 +++++++--
 src/metrics.rs       |  5 +++--
 tests/integration.rs |  3 ++-
 8 files changed, 81 insertions(+), 15 deletions(-)

diff --git a/Cargo.lock b/Cargo.lock
index 5d9a7d7..e681fbc 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -2611,6 +2611,7 @@ dependencies = [
  "serde_json",
  "serde_urlencoded",
  "tokio",
+ "tokio-rustls",
  "tokio-stream",
  "tokio-tungstenite",
  "tokio-util",
diff --git a/Cargo.toml b/Cargo.toml
index 6c34325..e585966 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -9,7 +9,7 @@ redis = { version = "0.21", features = ["tokio-comp", "aio", "cluster"] }
 serde = { version = "1", features = ["derive"] }
 serde_json = "1"
 thiserror = "1"
-warp = "0.3"
+warp = { version = "0.3", features = ["tls"] }
 tokio = { version = "1", features = ["macros", "rt-multi-thread", "signal"] }
 futures = "0.3"
 log = "0.4"
diff --git a/README.md b/README.md
index b337c7c..88e7eb6 100644
--- a/README.md
+++ b/README.md
@@ -135,6 +135,13 @@ Alternatively you can configure the server to listen on a unix socket by setting
 Note that Nextcloud load all files matching `*.config.php` in the config directory in additional to the main config file.
 You can enable this same behavior by passing the `--glob-config` option.
 
+#### TLS Configuration
+
+The push server can be configured to serve over TLS. This is mostly intended for securing the traffic between the push server
+and the reverse proxy if they are running on different hosts, running without a reverse proxy (or load balancer) is not recommended.
+
+TLS can be enabled by setting the `--tls-cert` and `--tls-key` arguments (or the `TLS_CERT` and `TLS_KEY` environment variables).
+
 #### Starting the service
 
 Once the systemd service file is set up with the correct configuration you can start it using
diff --git a/src/config.rs b/src/config.rs
index b33d6e7..4f938cd 100644
--- a/src/config.rs
+++ b/src/config.rs
@@ -69,6 +69,12 @@ pub struct Opt {
     /// Load other files named *.config.php in the config folder
     #[structopt(long)]
     pub glob_config: bool,
+    /// TLS certificate
+    #[structopt(long)]
+    pub tls_cert: Option<PathBuf>,
+    /// TLS key
+    #[structopt(long)]
+    pub tls_key: Option<PathBuf>,
 }
 
 #[derive(Debug)]
@@ -82,6 +88,13 @@ pub struct Config {
     pub bind: Bind,
     pub allow_self_signed: bool,
     pub no_ansi: bool,
+    pub tls: Option<TlsConfig>,
+}
+
+#[derive(Debug, Clone)]
+pub struct TlsConfig {
+    pub key: PathBuf,
+    pub cert: PathBuf,
 }
 
 #[derive(Clone, Derivative)]
@@ -166,6 +179,7 @@ impl TryFrom<PartialConfig> for Config {
             bind,
             allow_self_signed: config.allow_self_signed.unwrap_or(false),
             no_ansi: config.no_ansi.unwrap_or(false),
+            tls: config.tls,
         })
     }
 }
@@ -200,6 +214,7 @@ struct PartialConfig {
     pub socket_permissions: Option<String>,
     pub allow_self_signed: Option<bool>,
     pub no_ansi: Option<bool>,
+    pub tls: Option<TlsConfig>,
 }
 
 impl PartialConfig {
@@ -219,6 +234,15 @@ impl PartialConfig {
         let allow_self_signed = var("ALLOW_SELF_SIGNED").map(|val| val == "true").ok();
         let no_ansi = var("NO_ANSI").map(|val| val == "true").ok();
 
+        let tls_cert = parse_var("TLS_CERT").wrap_err("Invalid TLS_CERT")?;
+        let tls_key = parse_var("TLS_KEY").wrap_err("Invalid TLS_KEY")?;
+
+        let tls = if let (Some(cert), Some(key)) = (tls_cert, tls_key) {
+            Some(TlsConfig { cert, key })
+        } else {
+            None
+        };
+
         Ok(PartialConfig {
             database,
             database_prefix,
@@ -233,6 +257,7 @@ impl PartialConfig {
             socket_permissions,
             allow_self_signed,
             no_ansi,
+            tls,
         })
     }
 
@@ -241,6 +266,12 @@ impl PartialConfig {
     }
 
     fn from_opt(opt: Opt) -> Self {
+        let tls = if let (Some(cert), Some(key)) = (opt.tls_cert, opt.tls_key) {
+            Some(TlsConfig { cert, key })
+        } else {
+            None
+        };
+
         PartialConfig {
             database: opt.database_url,
             database_prefix: opt.database_prefix,
@@ -259,6 +290,7 @@ impl PartialConfig {
                 None
             },
             no_ansi: if opt.no_ansi { Some(true) } else { None },
+            tls,
         }
     }
 
@@ -281,6 +313,7 @@ impl PartialConfig {
             socket_permissions: self.socket_permissions.or(fallback.socket_permissions),
             allow_self_signed: self.allow_self_signed.or(fallback.allow_self_signed),
             no_ansi: self.no_ansi.or(fallback.no_ansi),
+            tls: self.tls.or(fallback.tls),
         }
     }
 }
diff --git a/src/lib.rs b/src/lib.rs
index 8ae6bb7..7ae6454 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -1,4 +1,4 @@
-use crate::config::{Bind, Config};
+use crate::config::{Bind, Config, TlsConfig};
 use crate::connection::{handle_user_socket, ActiveConnections};
 use crate::event::{
     Activity, Custom, Event, GroupUpdate, Notification, PreAuth, ShareCreate, StorageUpdate,
@@ -241,6 +241,7 @@ pub fn serve(
     app: Arc<App>,
     bind: Bind,
     cancel: oneshot::Receiver<()>,
+    tls: Option<&TlsConfig>,
 ) -> Result<impl Future<Output = ()> + Send> {
     let app = warp::any().map(move || app.clone());
 
@@ -354,22 +355,39 @@ pub fn serve(
 
     let routes = routes.clone().or(warp::path!("push" / ..).and(routes));
 
-    serve_at(routes, bind, cancel)
+    serve_at(routes, bind, cancel, tls)
 }
 
-fn serve_at<F, C>(filter: F, bind: Bind, cancel: C) -> Result<impl Future<Output = ()> + Send>
+fn serve_at<F, C>(
+    filter: F,
+    bind: Bind,
+    cancel: C,
+    tls: Option<&TlsConfig>,
+) -> Result<impl Future<Output = ()> + Send>
 where
     C: Future + Send + Sync + 'static,
     F: Filter + Clone + Send + Sync + 'static,
     F::Extract: Reply,
 {
     let cancel = cancel.map(|_| ());
-    match bind {
-        Bind::Tcp(addr) => {
-            let (_, server) = warp::serve(filter).bind_with_graceful_shutdown(addr, cancel);
-            Ok(Either::Left(server))
+    let server = warp::serve(filter);
+    match (bind, tls) {
+        (Bind::Tcp(addr), Some(tls)) => {
+            let (_, server) = server
+                .tls()
+                .cert_path(&tls.cert)
+                .key_path(&tls.key)
+                .bind_with_graceful_shutdown(addr, cancel);
+            Ok(Either::Left(Either::Left(server)))
+        }
+        (Bind::Tcp(addr), None) => {
+            let (_, server) = server.bind_with_graceful_shutdown(addr, cancel);
+            Ok(Either::Left(Either::Right(server)))
         }
-        Bind::Unix(socket_path, permissions) => {
+        (Bind::Unix(socket_path, permissions), tls) => {
+            if tls.is_some() {
+                log::warn!("Serving with TLS over a unix socket is not supported");
+            }
             fs::remove_file(&socket_path).ok();
 
             let listener = UnixListener::bind(&socket_path).wrap_err_with(|| {
@@ -382,7 +400,7 @@ where
 
             let stream = UnixListenerStream::new(listener);
             Ok(Either::Right(
-                warp::serve(filter)
+                server
                     .serve_incoming_with_graceful_shutdown(stream, cancel)
                     .map(move |_| {
                         fs::remove_file(&socket_path).ok();
diff --git a/src/main.rs b/src/main.rs
index 22db885..3b59a57 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -53,6 +53,7 @@ async fn main() -> Result<()> {
     }
 
     let bind = config.bind.clone();
+    let tls = config.tls.clone();
     let metrics_bind = config.metrics_bind.clone();
     let app = Arc::new(App::new(config, log_handle).await?);
     if let Err(e) = app.self_test().await {
@@ -60,11 +61,15 @@ async fn main() -> Result<()> {
     }
 
     log::trace!("Listening on {}", bind);
-    let server = spawn(serve(app.clone(), bind, serve_cancel_handle)?);
+    let server = spawn(serve(app.clone(), bind, serve_cancel_handle, tls.as_ref())?);
 
     if let Some(metrics_bind) = metrics_bind {
         log::trace!("Metrics listening {}", metrics_bind);
-        spawn(serve_metrics(metrics_bind, metrics_cancel_handle)?);
+        spawn(serve_metrics(
+            metrics_bind,
+            metrics_cancel_handle,
+            tls.as_ref(),
+        )?);
     }
 
     spawn(listen_loop(app, listen_cancel_handle));
diff --git a/src/metrics.rs b/src/metrics.rs
index 9cd2948..dc67311 100644
--- a/src/metrics.rs
+++ b/src/metrics.rs
@@ -1,4 +1,4 @@
-use crate::config::Bind;
+use crate::config::{Bind, TlsConfig};
 use crate::serve_at;
 use color_eyre::Result;
 use serde::{Serialize, Serializer};
@@ -117,6 +117,7 @@ impl Metrics {
 pub fn serve_metrics(
     bind: Bind,
     cancel: oneshot::Receiver<()>,
+    tls: Option<&TlsConfig>,
 ) -> Result<impl Future<Output = ()> + Send> {
     let metrics = warp::path!("metrics").map(|| {
         let mut response = String::with_capacity(128);
@@ -148,5 +149,5 @@ pub fn serve_metrics(
         response
     });
 
-    serve_at(metrics, bind, cancel)
+    serve_at(metrics, bind, cancel, tls)
 }
diff --git a/tests/integration.rs b/tests/integration.rs
index 7f6b203..ca1cbaa 100644
--- a/tests/integration.rs
+++ b/tests/integration.rs
@@ -154,6 +154,7 @@ impl Services {
             bind: Bind::Tcp(self.nextcloud.clone()),
             allow_self_signed: false,
             no_ansi: false,
+            tls: None,
         }
     }
 
@@ -178,7 +179,7 @@ impl Services {
 
         let bind = Bind::Tcp(addr);
         spawn(async move {
-            let serve = serve(app.clone(), bind, serve_rx).unwrap();
+            let serve = serve(app.clone(), bind, serve_rx, None).unwrap();
             let listen = listen_loop(app.clone(), listen_rx);
 
             pin_mut!(serve);
-- 
GitLab