From 0e98ab58f8cbd67ee348204848a7fe536147e033 Mon Sep 17 00:00:00 2001 From: Robin Appelman <robin@icewind.nl> Date: Thu, 9 Dec 2021 18:16:50 +0100 Subject: [PATCH] make fileid notifications opt in Signed-off-by: Robin Appelman <robin@icewind.nl> --- src/connection.rs | 23 ++++++++++++++++++----- src/message.rs | 19 +++++++++++-------- test_client/src/main.rs | 3 +++ 3 files changed, 32 insertions(+), 13 deletions(-) diff --git a/src/connection.rs b/src/connection.rs index f92b3e3..e3d337a 100644 --- a/src/connection.rs +++ b/src/connection.rs @@ -7,7 +7,7 @@ use dashmap::DashMap; use futures::{future::select, pin_mut, SinkExt, StreamExt}; use std::net::IpAddr; use std::num::NonZeroUsize; -use std::sync::atomic::{AtomicUsize, Ordering}; +use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering}; use std::sync::Arc; use std::time::{Duration, Instant}; use tokio::sync::broadcast; @@ -42,6 +42,11 @@ impl ActiveConnections { } } +#[derive(Default)] +pub struct ConnectionOptions { + pub listen_file_id: AtomicBool, +} + pub async fn handle_user_socket(mut ws: WebSocket, app: Arc<App>, forwarded_for: Vec<IpAddr>) { let user_id = match timeout( Duration::from_secs(15), @@ -78,6 +83,8 @@ pub async fn handle_user_socket(mut ws: WebSocket, app: Arc<App>, forwarded_for: METRICS.add_connection(); + let opts = ConnectionOptions::default(); + // Every time we send a ping, we set this to a random non-zero value // when a pong is returned, we check it against the expected value and reset this to 0 // If we get the wrong pong back, or the expected value hasn't been cleared @@ -85,7 +92,7 @@ pub async fn handle_user_socket(mut ws: WebSocket, app: Arc<App>, forwarded_for: let expect_pong = AtomicUsize::default(); let expect_pong = &expect_pong; - let transmit = async move { + let transmit = async { let mut send_queue = SendQueue::default(); let mut reset = app.reset_rx(); @@ -103,7 +110,7 @@ pub async fn handle_user_socket(mut ws: WebSocket, app: Arc<App>, forwarded_for: log::debug!(target: "notify_push::send", "Sending {} to {}", msg, user_id); METRICS.add_message(); last_send = now; - user_ws_tx.send(msg.into()).await.ok(); + user_ws_tx.send(msg.to_message(&opts)).await.ok(); } } Err(_timout) => { @@ -111,7 +118,7 @@ pub async fn handle_user_socket(mut ws: WebSocket, app: Arc<App>, forwarded_for: last_send = now; METRICS.add_message(); log::debug!(target: "notify_push::send", "Sending debounced {} to {}", msg, user_id); - user_ws_tx.feed(msg.into()).await.ok(); + user_ws_tx.feed(msg.to_message(&opts)).await.ok(); } if now.duration_since(last_send) > ping_interval { @@ -144,7 +151,7 @@ pub async fn handle_user_socket(mut ws: WebSocket, app: Arc<App>, forwarded_for: } }; - let receive = async move { + let receive = async { // handle messages until the client closes the connection while let Some(result) = user_ws_rx.next().await { match result { @@ -155,6 +162,12 @@ pub async fn handle_user_socket(mut ws: WebSocket, app: Arc<App>, forwarded_for: break; } } + Ok(msg) if msg.is_text() => { + let text = msg.to_str().unwrap_or_default(); + if text == "listen notify_file_id" { + opts.listen_file_id.store(true, Ordering::Relaxed); + } + } Ok(_) => {} Err(e) => { let formatted = e.to_string(); diff --git a/src/message.rs b/src/message.rs index d6539bd..3f6479c 100644 --- a/src/message.rs +++ b/src/message.rs @@ -1,3 +1,4 @@ +use crate::connection::ConnectionOptions; use parse_display::Display; use serde_json::Value; use smallvec::SmallVec; @@ -71,15 +72,17 @@ impl PushMessage { } } -impl From<PushMessage> for Message { - fn from(msg: PushMessage) -> Self { - match msg { +impl PushMessage { + pub fn to_message(self, opts: &ConnectionOptions) -> Message { + match self { PushMessage::File(ids) => match ids { - UpdatedFiles::Unknown => Message::text(String::from("notify_file")), - UpdatedFiles::Known(ids) => Message::text(format!( - "notify_file {}", - serde_json::to_string(&ids).unwrap() - )), + UpdatedFiles::Known(ids) if opts.listen_file_id.load(Ordering::Relaxed) => { + Message::text(format!( + "notify_file_id {}", + serde_json::to_string(&ids).unwrap() + )) + } + _ => Message::text(String::from("notify_file")), }, PushMessage::Activity => Message::text(String::from("notify_activity")), PushMessage::Notification => Message::text(String::from("notify_file")), diff --git a/test_client/src/main.rs b/test_client/src/main.rs index fb6072d..5c05b6c 100644 --- a/test_client/src/main.rs +++ b/test_client/src/main.rs @@ -34,6 +34,9 @@ fn main() -> Result<()> { socket .write_message(Message::Text(password)) .wrap_err("Failed to send password")?; + // socket + // .write_message(Message::Text("listen notify_file_id".into())) + // .wrap_err("Failed to send username")?; loop { if let Message::Text(text) = socket.read_message()? { -- GitLab