diff --git a/appinfo/info.xml b/appinfo/info.xml index f62e4a5120977a8be237719f65353743026017b1..46c802ca63fc761b0626947c7167c59ea6cf19d9 100644 --- a/appinfo/info.xml +++ b/appinfo/info.xml @@ -40,5 +40,6 @@ Once the app is installed, the push binary needs to be setup. You can either use <command>OCA\NotifyPush\Command\SelfTest</command> <command>OCA\NotifyPush\Command\Log</command> <command>OCA\NotifyPush\Command\Metrics</command> + <command>OCA\NotifyPush\Command\Reset</command> </commands> </info> diff --git a/lib/Command/Reset.php b/lib/Command/Reset.php new file mode 100644 index 0000000000000000000000000000000000000000..a7946496a1f437e36a09774e2eaee7e03d75a451 --- /dev/null +++ b/lib/Command/Reset.php @@ -0,0 +1,58 @@ +<?php + +declare(strict_types=1); +/** + * @copyright Copyright (c) 2021 Robin Appelman <robin@icewind.nl> + * + * @license GNU AGPL version 3 or any later version + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as + * published by the Free Software Foundation, either version 3 of the + * License, or (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see <http://www.gnu.org/licenses/>. + * + */ + +namespace OCA\NotifyPush\Command; + + +use OCA\NotifyPush\Queue\IQueue; +use Symfony\Component\Console\Command\Command; +use Symfony\Component\Console\Input\InputArgument; +use Symfony\Component\Console\Input\InputInterface; +use Symfony\Component\Console\Input\InputOption; +use Symfony\Component\Console\Output\OutputInterface; + +class Reset extends Command { + private $queue; + + public function __construct( + IQueue $queue + ) { + parent::__construct(); + $this->queue = $queue; + } + + /** + * @return void + */ + protected function configure() { + $this + ->setName('notify_push:reset') + ->setDescription('Cancel all active connections to the push server'); + parent::configure(); + } + + protected function execute(InputInterface $input, OutputInterface $output) { + $this->queue->push("notify_signal", "reset"); + return 0; + } +} diff --git a/src/connection.rs b/src/connection.rs index 89c000e4b082fa0fe948abfaf2da5337b4fa419c..6d6a2cb155bf9f213a9a0f8be25f2ddc557181ea 100644 --- a/src/connection.rs +++ b/src/connection.rs @@ -87,32 +87,44 @@ pub async fn handle_user_socket(mut ws: WebSocket, app: Arc<App>, forwarded_for: let transmit = async move { let mut debounce = DebounceMap::default(); - loop { - match timeout(Duration::from_secs(30), rx.recv()).await { - Ok(Ok(msg)) => { - log::debug!(target: "notify_push::send", "Sending {} to {}", msg, user_id); - if debounce.should_send(&msg) { - METRICS.add_message(); - user_ws_tx.send(msg.into()).await.ok(); - } - } - Err(_timout) => { - let data = rand::random::<NonZeroUsize>().into(); - let last_ping = expect_pong.swap(data, Ordering::SeqCst); - if last_ping > 0 { - log::info!("{} didn't reply to ping, closing", user_id); - break; + + let mut reset = app.reset_rx(); + + 'tx_loop: loop { + tokio::select! { + msg = timeout(Duration::from_secs(30), rx.recv()) => { + match msg { + Ok(Ok(msg)) => { + log::debug!(target: "notify_push::send", "Sending {} to {}", msg, user_id); + if debounce.should_send(&msg) { + METRICS.add_message(); + user_ws_tx.send(msg.into()).await.ok(); + } + } + Err(_timout) => { + let data = rand::random::<NonZeroUsize>().into(); + let last_ping = expect_pong.swap(data, Ordering::SeqCst); + if last_ping > 0 { + log::info!("{} didn't reply to ping, closing", user_id); + break; + } + log::debug!(target: "notify_push::send", "Sending ping to {}", user_id); + user_ws_tx + .send(Message::ping(data.to_le_bytes())) + .await + .ok(); + } + Ok(Err(_)) => { + // we dont care about dropped messages + } } - log::debug!(target: "notify_push::send", "Sending ping to {}", user_id); - user_ws_tx - .send(Message::ping(data.to_le_bytes())) - .await - .ok(); - } - Ok(Err(_)) => { - // we dont care about dropped messages - } - } + }, + _ = reset.recv() => { + user_ws_tx.close().await.ok(); + log::debug!("Connection closed by reset request"); + break 'tx_loop; + }, + }; } }; diff --git a/src/event.rs b/src/event.rs index f3a01f4b9b228106faa82ed28a734c690b5e18bd..5194215c2e494615a2178b7d7a4288cb03229529 100644 --- a/src/event.rs +++ b/src/event.rs @@ -63,6 +63,12 @@ pub struct Custom { pub body: Value, } +#[derive(Debug, Deserialize, Display)] +#[serde(rename_all = "snake_case")] +pub enum Signal { + Reset, +} + #[derive(Debug, Display)] pub enum Event { #[display("storage update notification for storage {0.storage} and path {0.path}")] @@ -85,6 +91,8 @@ pub enum Event { Config(Config), #[display("{0} query")] Query(Query), + #[display("{0} signal")] + Signal(Signal), } #[derive(Debug, Error)] @@ -130,6 +138,9 @@ impl TryFrom<Msg> for Event { "notify_query" => Ok(Event::Query(serde_json::from_slice( msg.get_payload_bytes(), )?)), + "notify_signal" => Ok(Event::Signal(serde_json::from_slice( + msg.get_payload_bytes(), + )?)), _ => Err(MessageDecodeError::UnsupportedEventType), } } @@ -153,6 +164,7 @@ pub async fn subscribe( "notify_custom", "notify_config", "notify_query", + "notify_signal", ]; for channel in channels.iter() { pubsub diff --git a/src/lib.rs b/src/lib.rs index d75560491307a403f03d1cfc09df45bcd6828392..8ae6bb77b6d16bc47250b0c89c22df48f467cc66 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -25,8 +25,8 @@ use std::os::unix::fs::PermissionsExt; use std::sync::atomic::{AtomicU32, Ordering}; use std::time::{Duration, Instant}; use tokio::net::UnixListener; -use tokio::sync::oneshot; use tokio::sync::Mutex; +use tokio::sync::{broadcast, oneshot}; use tokio::time::sleep; use tokio_stream::wrappers::UnixListenerStream; use warp::filters::addr::remote; @@ -51,6 +51,8 @@ pub struct App { test_cookie: AtomicU32, redis: Redis, log_handle: Mutex<LoggerHandle>, + reset_tx: broadcast::Sender<()>, + _reset_rx: broadcast::Receiver<()>, } impl App { @@ -64,6 +66,8 @@ impl App { let redis = Redis::new(config.redis)?; + let (reset_tx, reset_rx) = broadcast::channel(1); + Ok(App { connections, nc_client, @@ -72,6 +76,8 @@ impl App { storage_mapping, redis, log_handle: Mutex::new(log_handle), + reset_tx, + _reset_rx: reset_rx, }) } @@ -91,6 +97,8 @@ impl App { let redis = Redis::new(config.redis)?; + let (reset_tx, reset_rx) = broadcast::channel(1); + Ok(App { connections, nc_client, @@ -99,6 +107,8 @@ impl App { storage_mapping, redis, log_handle: Mutex::new(log_handle), + reset_tx, + _reset_rx: reset_rx, }) } @@ -213,8 +223,18 @@ impl App { } Err(e) => log::warn!("Failed to set metrics: {}", e), }, + Event::Signal(event::Signal::Reset) => { + log::info!("Stopping all open connections"); + if let Err(e) = self.reset_tx.send(()) { + log::warn!("Failed to send reset command to all connections: {}", e); + } + } } } + + pub fn reset_rx(&self) -> broadcast::Receiver<()> { + self.reset_tx.subscribe() + } } pub fn serve(