From f7d2c1636d1671777d66bdee85bb545f110147d2 Mon Sep 17 00:00:00 2001
From: Robin Appelman <robin@icewind.nl>
Date: Thu, 9 Dec 2021 17:24:34 +0100
Subject: [PATCH] send fileids in notification if known

Signed-off-by: Robin Appelman <robin@icewind.nl>
---
 Cargo.lock              |   3 +
 Cargo.toml              |   2 +-
 lib/Listener.php        |   1 +
 src/connection.rs       |  62 ++++-----
 src/event.rs            |   1 +
 src/lib.rs              |  20 +--
 src/message.rs          | 278 +++++++++++++++++++++++++++-------------
 test_client/src/main.rs |   4 +-
 8 files changed, 243 insertions(+), 128 deletions(-)

diff --git a/Cargo.lock b/Cargo.lock
index e681fbc..deebd59 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -1858,6 +1858,9 @@ name = "smallvec"
 version = "1.7.0"
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "1ecab6c735a6bb4139c0caafd0cc3635748bbb3acf4550e8138122099251f309"
+dependencies = [
+ "serde",
+]
 
 [[package]]
 name = "socket2"
diff --git a/Cargo.toml b/Cargo.toml
index e585966..eda32b9 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -18,7 +18,7 @@ dotenv = "0.15"
 dashmap = "4"
 once_cell = "1"
 color-eyre = "0.5"
-smallvec = "1"
+smallvec = { version = "1", features = ["serde"] }
 reqwest = { version = "0.11", default-features = false, features = ["rustls-tls", "json"] }
 warp-real-ip = "0.2"
 parse-display = "0.5"
diff --git a/lib/Listener.php b/lib/Listener.php
index 29f4291..3fe79b6 100644
--- a/lib/Listener.php
+++ b/lib/Listener.php
@@ -49,6 +49,7 @@ class Listener implements IConsumer, IApp, INotifier, IDismissableNotifier {
 			$this->queue->push('notify_storage_update', [
 				'storage' => $event->getStorageId(),
 				'path' => $event->getPath(),
+				'file_id' => $event->getFileId(),
 			]);
 		}
 	}
diff --git a/src/connection.rs b/src/connection.rs
index 64c8133..f92b3e3 100644
--- a/src/connection.rs
+++ b/src/connection.rs
@@ -1,4 +1,4 @@
-use crate::message::{DebounceMap, MessageType};
+use crate::message::{PushMessage, SendQueue};
 use crate::metrics::METRICS;
 use crate::{App, UserId};
 use ahash::RandomState;
@@ -17,10 +17,10 @@ use warp::filters::ws::{Message, WebSocket};
 const USER_CONNECTION_LIMIT: usize = 64;
 
 #[derive(Default)]
-pub struct ActiveConnections(DashMap<UserId, broadcast::Sender<MessageType>, RandomState>);
+pub struct ActiveConnections(DashMap<UserId, broadcast::Sender<PushMessage>, RandomState>);
 
 impl ActiveConnections {
-    pub async fn add(&self, user: UserId) -> Result<broadcast::Receiver<MessageType>> {
+    pub async fn add(&self, user: UserId) -> Result<broadcast::Receiver<PushMessage>> {
         if let Some(sender) = self.0.get(&user) {
             // stop a single user from trying to eat all the resources
             if sender.receiver_count() > USER_CONNECTION_LIMIT {
@@ -35,7 +35,7 @@ impl ActiveConnections {
         }
     }
 
-    pub async fn send_to_user(&self, user: &UserId, msg: MessageType) {
+    pub async fn send_to_user(&self, user: &UserId, msg: PushMessage) {
         if let Some(tx) = self.0.get(user) {
             tx.send(msg).ok();
         }
@@ -86,45 +86,49 @@ pub async fn handle_user_socket(mut ws: WebSocket, app: Arc<App>, forwarded_for:
     let expect_pong = &expect_pong;
 
     let transmit = async move {
-        let mut debounce = DebounceMap::default();
+        let mut send_queue = SendQueue::default();
 
         let mut reset = app.reset_rx();
 
+        let ping_interval = Duration::from_secs(30);
+        let mut last_send = Instant::now() - ping_interval;
+
         'tx_loop: loop {
             tokio::select! {
-                msg = timeout(Duration::from_secs(30), rx.recv()) => {
+                msg = timeout(Duration::from_millis(500), rx.recv()) => {
+                    let now = Instant::now();
                     match msg {
                         Ok(Ok(msg)) => {
-                            if debounce.should_send(&msg) {
+                            if let Some(msg) = send_queue.push(msg, now) {
                                 log::debug!(target: "notify_push::send", "Sending {} to {}", msg, user_id);
                                 METRICS.add_message();
+                                last_send = now;
                                 user_ws_tx.send(msg.into()).await.ok();
-                            } else {
-                                log::debug!(target: "notify_push::send", "Debouncing {} to {}", msg, user_id);
-                            }
-                        }
-                        Err(_timout) if debounce.has_held_message() => {
-                            // if any message got held back for debounce, we try sending them now
-                            for msg in debounce.get_held_messages() {
-                                if debounce.should_send(&msg) {
-                                    log::debug!(target: "notify_push::send", "Sending debounced {} to {}", msg, user_id);
-                                    METRICS.add_message();
-                                    user_ws_tx.send(msg.into()).await.ok();
-                                }
                             }
                         }
                         Err(_timout) => {
-                            let data = rand::random::<NonZeroUsize>().into();
-                            let last_ping = expect_pong.swap(data, Ordering::SeqCst);
-                            if last_ping > 0 {
-                                log::info!("{} didn't reply to ping, closing", user_id);
-                                break;
+                            for msg in send_queue.drain(now) {
+                                last_send = now;
+                                METRICS.add_message();
+                                log::debug!(target: "notify_push::send", "Sending debounced {} to {}", msg, user_id);
+                                user_ws_tx.feed(msg.into()).await.ok();
+                            }
+
+                            if now.duration_since(last_send) > ping_interval {
+                                let data = rand::random::<NonZeroUsize>().into();
+                                let last_ping = expect_pong.swap(data, Ordering::SeqCst);
+                                if last_ping > 0 {
+                                    log::info!("{} didn't reply to ping, closing", user_id);
+                                    break;
+                                }
+                                log::debug!(target: "notify_push::send", "Sending ping to {}", user_id);
+                                last_send = now;
+                                user_ws_tx
+                                    .feed(Message::ping(data.to_le_bytes()))
+                                    .await
+                                    .ok();
                             }
-                            log::debug!(target: "notify_push::send", "Sending ping to {}", user_id);
-                            user_ws_tx
-                                .send(Message::ping(data.to_le_bytes()))
-                                .await
-                                .ok();
+                            user_ws_tx.flush().await.ok();
                         }
                         Ok(Err(_)) => {
                             // we dont care about dropped messages
diff --git a/src/event.rs b/src/event.rs
index 5194215..2ca4db8 100644
--- a/src/event.rs
+++ b/src/event.rs
@@ -13,6 +13,7 @@ use tokio_stream::{Stream, StreamExt};
 pub struct StorageUpdate {
     pub storage: u32,
     pub path: String,
+    pub file_id: Option<u64>,
 }
 
 #[derive(Debug, Deserialize)]
diff --git a/src/lib.rs b/src/lib.rs
index 7ae6454..10ecbe5 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -3,7 +3,7 @@ use crate::connection::{handle_user_socket, ActiveConnections};
 use crate::event::{
     Activity, Custom, Event, GroupUpdate, Notification, PreAuth, ShareCreate, StorageUpdate,
 };
-use crate::message::MessageType;
+use crate::message::{PushMessage, UpdatedFiles};
 use crate::metrics::METRICS;
 use crate::redis::Redis;
 use crate::storage_mapping::StorageMapping;
@@ -148,7 +148,11 @@ impl App {
 
     async fn handle_event(&self, event: Event) {
         match event {
-            Event::StorageUpdate(StorageUpdate { storage, path }) => {
+            Event::StorageUpdate(StorageUpdate {
+                storage,
+                path,
+                file_id,
+            }) => {
                 match self
                     .storage_mapping
                     .get_users_for_storage_path(storage, &path)
@@ -157,7 +161,7 @@ impl App {
                     Ok(users) => {
                         for user in users {
                             self.connections
-                                .send_to_user(&user, MessageType::File)
+                                .send_to_user(&user, PushMessage::File(file_id.into()))
                                 .await;
                         }
                     }
@@ -166,12 +170,12 @@ impl App {
             }
             Event::GroupUpdate(GroupUpdate { user, .. }) => {
                 self.connections
-                    .send_to_user(&user, MessageType::File)
+                    .send_to_user(&user, PushMessage::File(UpdatedFiles::Unknown))
                     .await;
             }
             Event::ShareCreate(ShareCreate { user }) => {
                 self.connections
-                    .send_to_user(&user, MessageType::File)
+                    .send_to_user(&user, PushMessage::File(UpdatedFiles::Unknown))
                     .await;
             }
             Event::TestCookie(cookie) => {
@@ -179,12 +183,12 @@ impl App {
             }
             Event::Activity(Activity { user }) => {
                 self.connections
-                    .send_to_user(&user, MessageType::Activity)
+                    .send_to_user(&user, PushMessage::Activity)
                     .await;
             }
             Event::Notification(Notification { user }) => {
                 self.connections
-                    .send_to_user(&user, MessageType::Notification)
+                    .send_to_user(&user, PushMessage::Notification)
                     .await;
             }
             Event::PreAuth(PreAuth { user, token }) => {
@@ -196,7 +200,7 @@ impl App {
                 body,
             }) => {
                 self.connections
-                    .send_to_user(&user, MessageType::Custom(message, body))
+                    .send_to_user(&user, PushMessage::Custom(message, body))
                     .await;
             }
             Event::Config(event::Config::LogSpec(spec)) => {
diff --git a/src/message.rs b/src/message.rs
index 0862430..d6539bd 100644
--- a/src/message.rs
+++ b/src/message.rs
@@ -1,16 +1,50 @@
 use parse_display::Display;
-use rand::{thread_rng, Rng};
 use serde_json::Value;
+use smallvec::SmallVec;
 use std::fmt::Write;
 use std::sync::atomic::{AtomicBool, Ordering};
 use std::time::Instant;
 use tokio::time::Duration;
 use warp::ws::Message;
 
-#[derive(Debug, Clone, Display)]
-pub enum MessageType {
+#[derive(Debug, Clone, PartialEq)]
+pub enum UpdatedFiles {
+    Unknown,
+    Known(SmallVec<[u64; 4]>),
+}
+
+impl UpdatedFiles {
+    pub fn extend(&mut self, more: &UpdatedFiles) {
+        match (self, more) {
+            (UpdatedFiles::Known(items), UpdatedFiles::Known(b)) => {
+                for id in b {
+                    if !items.contains(id) {
+                        items.push(*id);
+                    }
+                }
+            }
+            (self_, _) => *self_ = UpdatedFiles::Unknown,
+        }
+    }
+}
+
+impl From<Option<u64>> for UpdatedFiles {
+    fn from(id: Option<u64>) -> Self {
+        match id {
+            Some(id) => {
+                let mut ids = SmallVec::new();
+                ids.push(id);
+                UpdatedFiles::Known(ids)
+            }
+            None => UpdatedFiles::Unknown,
+        }
+    }
+}
+
+#[derive(Debug, Clone, Display, PartialEq)]
+pub enum PushMessage {
     #[display("notify_file")]
-    File,
+    File(UpdatedFiles),
     #[display("notify_activity")]
     Activity,
     #[display("notify_notification")]
@@ -19,14 +53,38 @@ pub enum MessageType {
     Custom(String, Value),
 }
 
-impl From<MessageType> for Message {
-    fn from(msg: MessageType) -> Self {
+impl PushMessage {
+    pub fn merge(&mut self, other: &PushMessage) {
+        match (self, other) {
+            (PushMessage::File(a), PushMessage::File(b)) => a.extend(b),
+            _ => {}
+        }
+    }
+
+    pub fn debounce_time(&self) -> Duration {
+        match self {
+            PushMessage::File(_) => Duration::from_secs(60),
+            PushMessage::Activity => Duration::from_secs(120),
+            PushMessage::Notification => Duration::from_secs(30),
+            PushMessage::Custom(..) => Duration::from_millis(1), // no debouncing for custom messages
+        }
+    }
+}
+
+impl From<PushMessage> for Message {
+    fn from(msg: PushMessage) -> Self {
         match msg {
-            MessageType::File => Message::text(String::from("notify_file")),
-            MessageType::Activity => Message::text(String::from("notify_activity")),
-            MessageType::Notification => Message::text(String::from("notify_notification")),
-            MessageType::Custom(ty, Value::Null) => Message::text(ty),
-            MessageType::Custom(ty, body) => Message::text({
+            PushMessage::File(ids) => match ids {
+                UpdatedFiles::Unknown => Message::text(String::from("notify_file")),
+                UpdatedFiles::Known(ids) => Message::text(format!(
+                    "notify_file {}",
+                    serde_json::to_string(&ids).unwrap()
+                )),
+            },
+            PushMessage::Activity => Message::text(String::from("notify_activity")),
+            PushMessage::Notification => Message::text(String::from("notify_file")),
+            PushMessage::Custom(ty, Value::Null) => Message::text(ty),
+            PushMessage::Custom(ty, body) => Message::text({
                 let mut str = ty;
                 write!(&mut str, " {}", body).ok();
                 str
@@ -37,99 +95,143 @@ impl From<MessageType> for Message {
 
 pub static DEBOUNCE_ENABLE: AtomicBool = AtomicBool::new(true);
 
-pub struct DebounceMap {
-    file: Instant,
-    activity: Instant,
-    notification: Instant,
-    file_held: bool,
-    activity_held: bool,
-    notification_held: bool,
+#[derive(Clone, Debug)]
+struct SendQueueItem {
+    received: Instant,
+    sent: Instant,
+    message: Option<PushMessage>,
 }
 
-impl Default for DebounceMap {
+impl Default for SendQueueItem {
     fn default() -> Self {
-        let past = Instant::now() - Duration::from_secs(600);
-        DebounceMap {
-            file: past,
-            activity: past,
-            notification: past,
-            file_held: false,
-            activity_held: false,
-            notification_held: false,
+        SendQueueItem {
+            received: Instant::now() - Duration::from_secs(120),
+            sent: Instant::now() - Duration::from_secs(120),
+            message: None,
         }
     }
 }
 
-impl DebounceMap {
-    /// Check if the debounce time has passed and set the last send time if so
-    pub fn should_send(&mut self, ty: &MessageType) -> bool {
-        if DEBOUNCE_ENABLE.load(Ordering::Relaxed) {
-            let last_send = self.get_last_send(ty);
-            if Instant::now().duration_since(last_send) > Self::debounce_time(ty) {
-                self.set_last_send(ty);
-                self.set_held(ty, false);
-                true
-            } else if Instant::now().duration_since(last_send) > Duration::from_millis(100) {
-                self.set_held(ty, true);
-                false
-            } else {
-                false
-            }
-        } else {
-            true
-        }
-    }
-
-    pub fn has_held_message(&self) -> bool {
-        self.file_held || self.activity_held || self.notification_held
-    }
+#[derive(Default, Debug)]
+pub struct SendQueue {
+    items: [SendQueueItem; 3],
+}
 
-    pub fn get_held_messages(&self) -> impl Iterator<Item = MessageType> {
-        let file_opt = self.file_held.then(|| MessageType::File);
-        let activity_opt = self.activity_held.then(|| MessageType::Activity);
-        let notification_opt = self.notification_held.then(|| MessageType::Notification);
-        file_opt
-            .into_iter()
-            .chain(activity_opt.into_iter())
-            .chain(notification_opt.into_iter())
+impl SendQueue {
+    pub fn new() -> Self {
+        SendQueue::default()
     }
 
-    fn get_last_send(&self, ty: &MessageType) -> Instant {
-        match ty {
-            MessageType::File => self.file,
-            MessageType::Activity => self.activity,
-            MessageType::Notification => self.notification,
-            MessageType::Custom(..) => Instant::now() - Duration::from_secs(600), // no debouncing for custom messages
+    fn item_mut(&mut self, message: &PushMessage) -> Option<&mut SendQueueItem> {
+        match message {
+            PushMessage::File(_) => Some(&mut self.items[0]),
+            PushMessage::Activity => Some(&mut self.items[1]),
+            PushMessage::Notification => Some(&mut self.items[2]),
+            PushMessage::Custom(_, _) => None,
         }
     }
 
-    fn set_last_send(&mut self, ty: &MessageType) {
-        // apply a randomized offset to the last_send
-        // this helps mitigate against load bursts from many clients receiving the same updates
-        let spread = Duration::from_millis(thread_rng().gen_range(0..1000));
-        match ty {
-            MessageType::File => self.file = Instant::now() - spread,
-            MessageType::Activity => self.activity = Instant::now() - spread,
-            MessageType::Notification => self.notification = Instant::now() - spread,
-            MessageType::Custom(..) => {} // no debouncing for custom messages
+    pub fn push(&mut self, message: PushMessage, time: Instant) -> Option<PushMessage> {
+        if !DEBOUNCE_ENABLE.load(Ordering::Relaxed) {
+            return Some(message);
         }
-    }
+        let item = match self.item_mut(&message) {
+            Some(item) => item,
+            None => return Some(message),
+        };
 
-    fn set_held(&mut self, ty: &MessageType, held: bool) {
-        match ty {
-            MessageType::File => self.file_held = held,
-            MessageType::Activity => self.activity_held = held,
-            MessageType::Notification => self.notification_held = held,
-            MessageType::Custom(..) => {} // no debouncing for custom messages
-        }
+        match &mut item.message {
+            Some(queued) => {
+                queued.merge(&message);
+            }
+            opt => {
+                *opt = Some(message);
+            }
+        };
+        item.received = time;
+
+        None
     }
 
-    fn debounce_time(ty: &MessageType) -> Duration {
-        match ty {
-            MessageType::File => Duration::from_secs(60),
-            MessageType::Activity => Duration::from_secs(120),
-            MessageType::Notification => Duration::from_secs(30),
-            MessageType::Custom(..) => Duration::from_millis(1), // no debouncing for custom messages
-        }
+    pub fn drain<'a>(&'a mut self, now: Instant) -> impl Iterator<Item = PushMessage> + 'a {
+        self.items.iter_mut().filter_map(move |item| {
+            let debounce_time = item.message.as_ref()?.debounce_time();
+            if now.duration_since(item.sent) > debounce_time {
+                if now.duration_since(item.received) > Duration::from_millis(100) {
+                    item.sent = now;
+                    item.message.take()
+                } else {
+                    None
+                }
+            } else {
+                None
+            }
+        })
     }
 }
+
+#[test]
+fn test_send_queue() {
+    let base_time = Instant::now();
+    let mut queue = SendQueue::new();
+    queue.push(PushMessage::Activity, base_time);
+    queue.push(
+        PushMessage::File(UpdatedFiles::Known(vec![1].into())),
+        base_time,
+    );
+    queue.push(
+        PushMessage::File(UpdatedFiles::Known(vec![2].into())),
+        base_time + Duration::from_millis(10),
+    );
+
+    // without 100ms the messages get merged
+    assert_eq!(
+        Vec::<PushMessage>::new(),
+        queue
+            .drain(base_time + Duration::from_millis(20))
+            .collect::<Vec<_>>()
+    );
+
+    // after 100ms the merged messages get send
+    assert_eq!(
+        vec![
+            PushMessage::File(UpdatedFiles::Known(vec![1, 2].into())),
+            PushMessage::Activity
+        ],
+        queue
+            .drain(base_time + Duration::from_millis(200))
+            .collect::<Vec<_>>()
+    );
+
+    // messages send within debounce time get held back
+    queue.push(
+        PushMessage::File(UpdatedFiles::Known(vec![3].into())),
+        base_time + Duration::from_secs(5),
+    );
+    queue.push(
+        PushMessage::File(UpdatedFiles::Known(vec![4].into())),
+        base_time + Duration::from_secs(6),
+    );
+    assert_eq!(
+        Vec::<PushMessage>::new(),
+        queue
+            .drain(base_time + Duration::from_secs(10))
+            .collect::<Vec<_>>()
+    );
+
+    // after debounce time we get the merged messages from the timeframe
+    assert_eq!(
+        vec![PushMessage::File(UpdatedFiles::Known(vec![3, 4].into()))],
+        queue
+            .drain(base_time + Duration::from_secs(70))
+            .collect::<Vec<_>>()
+    );
+
+    // nothing left
+    assert_eq!(
+        Vec::<PushMessage>::new(),
+        queue
+            .drain(base_time + Duration::from_secs(300))
+            .collect::<Vec<_>>()
+    );
+}
diff --git a/test_client/src/main.rs b/test_client/src/main.rs
index aab78ef..fb6072d 100644
--- a/test_client/src/main.rs
+++ b/test_client/src/main.rs
@@ -40,8 +40,8 @@ fn main() -> Result<()> {
             if text.starts_with("err: ") {
                 eprintln!("Received error: {}", &text[5..]);
                 return Ok(());
-            } else if text == "notify_file" {
-                println!("Received file update notification");
+            } else if text.starts_with("notify_file") {
+                println!("Received file update notification {}", text);
             } else if text == "notify_activity" {
                 println!("Received activity notification");
             } else if text == "notify_notification" {
-- 
GitLab