diff --git a/src/event.rs b/src/event.rs index 073ceb6237ad013cf5195f9a16d666657763a110..6d72c7cd09090644b505474f08017f32f575fa14 100644 --- a/src/event.rs +++ b/src/event.rs @@ -3,10 +3,7 @@ use std::{collections::HashSet, sync::Arc}; use anyhow::Result; use dbus::nonblock::SyncConnection; use serde_derive::{Deserialize, Serialize}; -use tokio::{ - sync::{broadcast, mpsc}, - task::JoinHandle, -}; +use tokio::{sync::mpsc, task::JoinHandle}; use crate::{ config::Config, @@ -19,6 +16,8 @@ pub enum Event { ConnectivityStateChanged { connectivity_state: ConnectivityState, }, + DaemonStarted, + DaemonStopping, DeviceAdded { device_path: dbus::Path<'static>, }, @@ -42,35 +41,32 @@ pub enum Event { pub async fn handle_events( conn: Arc<SyncConnection>, config: Arc<Config>, - mut stop_signal_rx: broadcast::Receiver<()>, mut event_rx: mpsc::Receiver<Event>, ) -> Result<JoinHandle<Result<()>>> { log::info!("Starting event handler task"); let join_handle = tokio::spawn(async move { loop { - tokio::select! { - event = event_rx.recv() => { - match event { - Some(event) => { - log::debug!("Got event: {:?}", &event); - - for rule in &config.rules { - if let Err(err) = rule.evaluate(&conn, &event).await { - log::error!("Got error evaluating rule: {:?} {:?}", rule, err); - }; - } - } - None => { - log::warn!("Got an empty event"); - } + let event = event_rx.recv().await; + + match event { + Some(event) => { + log::debug!("Got event: {:?}", &event); + + for rule in &config.rules { + if let Err(err) = rule.evaluate(&conn, &event).await { + log::error!("Got error evaluating rule: {:?} {:?}", rule, err); + }; } - }, - _ = stop_signal_rx.recv() => { - log::info!("Stoping event handler task"); + if let Event::DaemonStopping = event { + log::info!("Stoping event handler task"); - break; + break; + } + } + None => { + log::warn!("Got an empty event"); } } } @@ -87,6 +83,8 @@ pub enum Trigger { ConnectivityStateChanged { states: HashSet<ConnectivityState>, }, + DaemonStarted, + DaemonStopping, DeviceAdded { device_identifier: Option<DeviceIdentifier>, }, @@ -170,7 +168,9 @@ impl Trigger { }, (Event::StateChanged { state }, Trigger::StateChanged { states }) => { Ok(states.contains(state)) - } + }, + (Event::DaemonStarted, Trigger::DaemonStarted) => {Ok(true)}, + (Event::DaemonStopping, Trigger::DaemonStopping) => {Ok(true)}, _ => Ok(false), } } diff --git a/src/main.rs b/src/main.rs index 78e59a03affcb30378b5afb95bd0fe34fd03f21b..2857659fd12482dfccf7ec508559823994cfb6b8 100644 --- a/src/main.rs +++ b/src/main.rs @@ -12,7 +12,7 @@ use tokio::{ use anyhow::Result; use clap::Parser; -use crate::event::handle_events; +use crate::event::{handle_events, Event}; mod action; mod condition; @@ -84,18 +84,23 @@ async fn inner_main() -> anyhow::Result<()> { let (stop_signal_tx, _stop_signal_rx) = broadcast::channel(1); let (event_tx, event_rx) = mpsc::channel(50); + let lifecycle_event_tx = event_tx.clone(); let mut task_handles = vec![]; let watcher_handle = watcher::watch(&conn, stop_signal_tx.subscribe(), event_tx).await?; let event_handler_handle = - handle_events(conn.clone(), config, stop_signal_tx.subscribe(), event_rx).await?; + handle_events(conn.clone(), config, event_rx).await?; task_handles.push(watcher_handle); task_handles.push(event_handler_handle); + lifecycle_event_tx.send(Event::DaemonStarted).await?; + signal::ctrl_c().await?; + lifecycle_event_tx.send(Event::DaemonStopping).await?; + stop_signal_tx.send(())?; for task_handle in task_handles {