From 0e98ab58f8cbd67ee348204848a7fe536147e033 Mon Sep 17 00:00:00 2001
From: Robin Appelman <robin@icewind.nl>
Date: Thu, 9 Dec 2021 18:16:50 +0100
Subject: [PATCH] make fileid notifications opt in

Signed-off-by: Robin Appelman <robin@icewind.nl>
---
 src/connection.rs       | 23 ++++++++++++++++++-----
 src/message.rs          | 19 +++++++++++--------
 test_client/src/main.rs |  3 +++
 3 files changed, 32 insertions(+), 13 deletions(-)

diff --git a/src/connection.rs b/src/connection.rs
index f92b3e3..e3d337a 100644
--- a/src/connection.rs
+++ b/src/connection.rs
@@ -7,7 +7,7 @@ use dashmap::DashMap;
 use futures::{future::select, pin_mut, SinkExt, StreamExt};
 use std::net::IpAddr;
 use std::num::NonZeroUsize;
-use std::sync::atomic::{AtomicUsize, Ordering};
+use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
 use std::sync::Arc;
 use std::time::{Duration, Instant};
 use tokio::sync::broadcast;
@@ -42,6 +42,11 @@ impl ActiveConnections {
     }
 }
 
+#[derive(Default)]
+pub struct ConnectionOptions {
+    pub listen_file_id: AtomicBool,
+}
+
 pub async fn handle_user_socket(mut ws: WebSocket, app: Arc<App>, forwarded_for: Vec<IpAddr>) {
     let user_id = match timeout(
         Duration::from_secs(15),
@@ -78,6 +83,8 @@ pub async fn handle_user_socket(mut ws: WebSocket, app: Arc<App>, forwarded_for:
 
     METRICS.add_connection();
 
+    let opts = ConnectionOptions::default();
+
     // Every time we send a ping, we set this to a random non-zero value
     // when a pong is returned, we check it against the expected value and reset this to 0
     // If we get the wrong pong back, or the expected value hasn't been cleared
@@ -85,7 +92,7 @@ pub async fn handle_user_socket(mut ws: WebSocket, app: Arc<App>, forwarded_for:
     let expect_pong = AtomicUsize::default();
     let expect_pong = &expect_pong;
 
-    let transmit = async move {
+    let transmit = async {
         let mut send_queue = SendQueue::default();
 
         let mut reset = app.reset_rx();
@@ -103,7 +110,7 @@ pub async fn handle_user_socket(mut ws: WebSocket, app: Arc<App>, forwarded_for:
                                 log::debug!(target: "notify_push::send", "Sending {} to {}", msg, user_id);
                                 METRICS.add_message();
                                 last_send = now;
-                                user_ws_tx.send(msg.into()).await.ok();
+                                user_ws_tx.send(msg.to_message(&opts)).await.ok();
                             }
                         }
                         Err(_timout) => {
@@ -111,7 +118,7 @@ pub async fn handle_user_socket(mut ws: WebSocket, app: Arc<App>, forwarded_for:
                                 last_send = now;
                                 METRICS.add_message();
                                 log::debug!(target: "notify_push::send", "Sending debounced {} to {}", msg, user_id);
-                                user_ws_tx.feed(msg.into()).await.ok();
+                                user_ws_tx.feed(msg.to_message(&opts)).await.ok();
                             }
 
                             if now.duration_since(last_send) > ping_interval {
@@ -144,7 +151,7 @@ pub async fn handle_user_socket(mut ws: WebSocket, app: Arc<App>, forwarded_for:
         }
     };
 
-    let receive = async move {
+    let receive = async {
         // handle messages until the client closes the connection
         while let Some(result) = user_ws_rx.next().await {
             match result {
@@ -155,6 +162,12 @@ pub async fn handle_user_socket(mut ws: WebSocket, app: Arc<App>, forwarded_for:
                         break;
                     }
                 }
+                Ok(msg) if msg.is_text() => {
+                    let text = msg.to_str().unwrap_or_default();
+                    if text == "listen notify_file_id" {
+                        opts.listen_file_id.store(true, Ordering::Relaxed);
+                    }
+                }
                 Ok(_) => {}
                 Err(e) => {
                     let formatted = e.to_string();
diff --git a/src/message.rs b/src/message.rs
index d6539bd..3f6479c 100644
--- a/src/message.rs
+++ b/src/message.rs
@@ -1,3 +1,4 @@
+use crate::connection::ConnectionOptions;
 use parse_display::Display;
 use serde_json::Value;
 use smallvec::SmallVec;
@@ -71,15 +72,17 @@ impl PushMessage {
     }
 }
 
-impl From<PushMessage> for Message {
-    fn from(msg: PushMessage) -> Self {
-        match msg {
+impl PushMessage {
+    pub fn to_message(self, opts: &ConnectionOptions) -> Message {
+        match self {
             PushMessage::File(ids) => match ids {
-                UpdatedFiles::Unknown => Message::text(String::from("notify_file")),
-                UpdatedFiles::Known(ids) => Message::text(format!(
-                    "notify_file {}",
-                    serde_json::to_string(&ids).unwrap()
-                )),
+                UpdatedFiles::Known(ids) if opts.listen_file_id.load(Ordering::Relaxed) => {
+                    Message::text(format!(
+                        "notify_file_id {}",
+                        serde_json::to_string(&ids).unwrap()
+                    ))
+                }
+                _ => Message::text(String::from("notify_file")),
             },
             PushMessage::Activity => Message::text(String::from("notify_activity")),
             PushMessage::Notification => Message::text(String::from("notify_file")),
diff --git a/test_client/src/main.rs b/test_client/src/main.rs
index fb6072d..5c05b6c 100644
--- a/test_client/src/main.rs
+++ b/test_client/src/main.rs
@@ -34,6 +34,9 @@ fn main() -> Result<()> {
     socket
         .write_message(Message::Text(password))
         .wrap_err("Failed to send password")?;
+    // socket
+    //     .write_message(Message::Text("listen notify_file_id".into()))
+    //     .wrap_err("Failed to send username")?;
 
     loop {
         if let Message::Text(text) = socket.read_message()? {
-- 
GitLab