diff --git a/src/main.rs b/src/main.rs index 645a3c5..80d1b2f 100644 --- a/src/main.rs +++ b/src/main.rs @@ -7,14 +7,14 @@ use tower_http::cors::{Any, CorsLayer}; use tower_http::trace::TraceLayer; use tracing::Span; mod color; -mod routes; mod models; +mod redis_migrations; +mod repository; +mod routes; use models::*; -use tokio::sync::broadcast; - - +use crate::repository::Repository; #[tokio::main] async fn main() { @@ -23,31 +23,24 @@ async fn main() { .allow_headers(Any) .allow_methods(Any); - let redis_string = env::var("REDIS_STRING").expect("REDIS_STRING is not set"); - let jwt_key = env::var("JWT_KEY").expect("JWT_KEY is not set"); - let client = redis::Client::open(redis_string.to_owned()).expect("Could not connect to redis"); - let manager = redis::aio::ConnectionManager::new(client.clone()) - .await - .unwrap(); - let instance_properties = InstanceProperties { demo: env::var("INSTANCE_DEMO").unwrap_or("false".to_owned()) == "true", donation: env::var("INSTANCE_DONATION_PAYPAL") .map(|id| Some(vec![DonationMethod::PayPal(id)])) .unwrap_or(None), + s3_host: env::var("S3_HOST").unwrap_or("".to_owned()), }; + let jwt_key = env::var("JWT_KEY").expect("JWT_KEY is not set"); + let redis_string = env::var("REDIS_STRING").expect("REDIS_STRING is not set"); - let (redis_task_tx, redis_task_rx) = broadcast::channel::(10); + let repository = Repository::new(redis_string).await; let state: SharedState = Arc::new(AppState { - redis: manager.clone(), + repository, jwt_key, instance_properties, - redis_task_rx, }); - routes::ws::spawn_global_redis_listener_task(manager, client, redis_task_tx); - let app = Router::new() .nest("/api/ws", routes::ws::routes()) .nest("/api/timer", routes::timer::routes(state.clone())) diff --git a/src/models.rs b/src/models.rs index a46805d..3276759 100644 --- a/src/models.rs +++ b/src/models.rs @@ -1,54 +1,6 @@ +use crate::repository::{DisplayOptions, Repository, Segment, Timer, TimerMetadata}; use serde::{Deserialize, Serialize}; -use crate::color::Color; use std::sync::Arc; -use tokio::sync::broadcast; - -//main.rs -fn default_zero() -> u32 { - 0 -} - -#[derive(Serialize, Deserialize, Clone, Debug)] -pub struct Segment { - label: String, - time: u32, - sound: bool, - color: Option, - #[serde(default = "default_zero")] - count_to: u32, -} - -#[derive(Serialize, Deserialize, Clone, Debug)] -pub enum PreStartBehaviour { - ShowZero, - RunNormally, -} - -impl Default for PreStartBehaviour { - fn default() -> Self { - PreStartBehaviour::ShowZero - } -} - -#[derive(Serialize, Deserialize, Clone, Default, Debug)] -pub struct DisplayOptions { - #[serde(default)] - clock: bool, - #[serde(default)] - pre_start_behaviour: PreStartBehaviour, -} - -#[derive(Serialize, Deserialize, Clone, Default, Debug)] -pub struct Timer { - // Return after TimerRequest - pub segments: Vec, - pub repeat: bool, - pub display_options: Option, - pub start_at: u64, - pub stop_at: Option, - pub password: String, - pub id: String, // 5 random chars -} #[derive(Serialize, Clone)] #[serde(tag = "type", content = "data")] @@ -60,17 +12,16 @@ pub enum DonationMethod { pub struct InstanceProperties { pub demo: bool, pub donation: Option>, + pub s3_host: String, } pub type SharedState = Arc; pub struct AppState { - pub redis: redis::aio::ConnectionManager, + pub repository: Repository, pub jwt_key: String, pub instance_properties: InstanceProperties, - pub redis_task_rx: broadcast::Receiver, } - //timer.rs #[derive(Serialize, Deserialize, Debug)] @@ -81,6 +32,7 @@ pub struct TimerResponse { pub display_options: DisplayOptions, pub start_at: u64, pub stop_at: Option, + pub metadata: TimerMetadata, } impl Into for Timer { @@ -89,9 +41,10 @@ impl Into for Timer { segments: self.segments, id: self.id, repeat: self.repeat, - display_options: self.display_options.unwrap_or(DisplayOptions::default()), + display_options: self.display_options, start_at: self.start_at, stop_at: self.stop_at, + metadata: self.metadata, } } } @@ -104,20 +57,36 @@ pub struct TimerCreationResponse { #[derive(Serialize, Deserialize)] pub struct TimerCreationRequest { - // Get from User pub segments: Vec, pub id: String, pub password: String, pub repeat: bool, pub start_at: u64, + pub metadata: TimerMetadata, + pub display_options: DisplayOptions, +} + +impl TimerCreationRequest { + pub fn into(self, hashed_password: String) -> Timer { + Timer { + segments: self.segments, + repeat: self.repeat, + display_options: self.display_options, + start_at: self.start_at, + stop_at: None, + password: hashed_password, + id: self.id, + metadata: self.metadata, + } + } } #[derive(Serialize, Deserialize)] pub struct TimerUpdateRequest { - // Get from User pub segments: Vec, pub repeat: bool, - pub display_options: Option, + pub display_options: DisplayOptions, + pub metadata: TimerMetadata, pub start_at: u64, pub stop_at: Option, } @@ -139,3 +108,30 @@ pub struct Claims { pub struct TokenResponse { pub token: String, } + +/// +/// Websocket +/// + +#[derive(Serialize, Deserialize, Debug)] +pub struct WsTimerResponse { + pub segments: Vec, + pub id: String, + pub repeat: bool, + pub display_options: DisplayOptions, + pub start_at: u64, + pub stop_at: Option, +} + +impl Into for Timer { + fn into(self) -> WsTimerResponse { + WsTimerResponse { + segments: self.segments, + id: self.id, + repeat: self.repeat, + display_options: self.display_options, + start_at: self.start_at, + stop_at: self.stop_at, + } + } +} diff --git a/src/redis_migrations/display_options.rs b/src/redis_migrations/display_options.rs new file mode 100644 index 0000000..0ad9f87 --- /dev/null +++ b/src/redis_migrations/display_options.rs @@ -0,0 +1,48 @@ +use serde::Deserialize; + +use crate::repository::DisplayOptions; + +use super::pre_start_behaviour::RedisPreStartBehaviour; + +#[derive(Deserialize, Clone)] +#[serde(untagged)] +pub enum RedisDisplayOptions { + V0(DisplayOptionsV0), +} + +impl Into for RedisDisplayOptions { + fn into(self) -> DisplayOptions { + match self { + RedisDisplayOptions::V0(v0) => v0.into(), + } + } +} + +impl Into for Option { + fn into(self) -> DisplayOptions { + self.map(|o| o.into()).unwrap_or(DisplayOptions::default()) + } +} + +#[derive(Deserialize, Clone)] +pub struct DisplayOptionsV0 { + #[serde(default)] + clock: bool, + #[serde(default)] + pre_start_behaviour: RedisPreStartBehaviour, +} + +impl Into for DisplayOptionsV0 { + fn into(self) -> DisplayOptions { + DisplayOptions { + clock: self.clock, + pre_start_behaviour: self.pre_start_behaviour.into(), + } + } +} + +impl Into for Option { + fn into(self) -> DisplayOptions { + self.map(|o| o.into()).unwrap_or(DisplayOptions::default()) + } +} diff --git a/src/redis_migrations/mod.rs b/src/redis_migrations/mod.rs new file mode 100644 index 0000000..2cd32c3 --- /dev/null +++ b/src/redis_migrations/mod.rs @@ -0,0 +1,8 @@ +mod display_options; +mod pre_start_behaviour; +mod segment; +mod tests; +mod timer; +mod timer_metadata; + +pub use timer::RedisTimer; diff --git a/src/redis_migrations/pre_start_behaviour.rs b/src/redis_migrations/pre_start_behaviour.rs new file mode 100644 index 0000000..be9f891 --- /dev/null +++ b/src/redis_migrations/pre_start_behaviour.rs @@ -0,0 +1,71 @@ +use serde::Deserialize; + +use crate::repository::PreStartBehaviour; + +#[derive(Deserialize, Clone)] +#[serde(untagged)] +pub enum RedisPreStartBehaviour { + V1(PreStartBehaviourV1), + V0(PreStartBehaviourV0), +} + +impl Into for RedisPreStartBehaviour { + fn into(self) -> PreStartBehaviour { + match self { + RedisPreStartBehaviour::V0(v0) => v0.into(), + RedisPreStartBehaviour::V1(v1) => v1.into(), + } + } +} + +impl Default for RedisPreStartBehaviour { + fn default() -> Self { + Self::V1(PreStartBehaviourV1::default()) + } +} + +/// === V1 === +#[derive(Deserialize, Clone)] +pub enum PreStartBehaviourV1 { + ShowFirstSegment, + ShowLastSegment, + RunNormally, +} + +impl Default for PreStartBehaviourV1 { + fn default() -> Self { + PreStartBehaviourV1::ShowFirstSegment + } +} + +impl Into for PreStartBehaviourV1 { + fn into(self) -> PreStartBehaviour { + match self { + PreStartBehaviourV1::RunNormally => PreStartBehaviour::RunNormally, + PreStartBehaviourV1::ShowFirstSegment => PreStartBehaviour::ShowFirstSegment, + PreStartBehaviourV1::ShowLastSegment => PreStartBehaviour::ShowLastSegment, + } + } +} + +/// === V0 === +#[derive(Deserialize, Clone)] +pub enum PreStartBehaviourV0 { + ShowZero, + RunNormally, +} + +impl Default for PreStartBehaviourV0 { + fn default() -> Self { + PreStartBehaviourV0::ShowZero + } +} + +impl Into for PreStartBehaviourV0 { + fn into(self) -> PreStartBehaviour { + match self { + PreStartBehaviourV0::RunNormally => PreStartBehaviour::RunNormally, + PreStartBehaviourV0::ShowZero => PreStartBehaviour::ShowFirstSegment, + } + } +} diff --git a/src/redis_migrations/segment.rs b/src/redis_migrations/segment.rs new file mode 100644 index 0000000..34c53fa --- /dev/null +++ b/src/redis_migrations/segment.rs @@ -0,0 +1,43 @@ +use serde::Deserialize; + +use crate::{color::Color, repository::Segment}; + +#[derive(Deserialize, Clone)] +#[serde(untagged)] +pub enum RedisSegment { + V0(SegmentV0), +} + +impl Into for RedisSegment { + fn into(self) -> Segment { + match self { + RedisSegment::V0(v0) => v0.into(), + } + } +} + +fn default_zero() -> u32 { + 0 +} + +#[derive(Deserialize, Clone)] +pub struct SegmentV0 { + label: String, + time: u32, + sound: bool, + color: Option, + #[serde(default = "default_zero")] + count_to: u32, +} + +impl Into for SegmentV0 { + fn into(self) -> Segment { + Segment { + label: self.label, + time: self.time, + sound: self.sound, + color: self.color, + count_to: self.count_to, + } + } +} diff --git a/src/redis_migrations/tests.rs b/src/redis_migrations/tests.rs new file mode 100644 index 0000000..10fd6c6 --- /dev/null +++ b/src/redis_migrations/tests.rs @@ -0,0 +1,79 @@ +#[allow(unused_imports)] +use crate::repository::PreStartBehaviour; +#[allow(unused_imports)] +use crate::{redis_migrations::timer::RedisTimer, repository::Timer}; + +#[test] +fn test_v0() { + let payload = r##" + { + "segments":[ + { + "label":"Boulder", + "time":230000, + "sound":true, + "color":"#26A269", + "count_to":11000 + } + ], + "id":"v0", + "repeat":true, + "display_options":{ + "clock":false, + "pre_start_behaviour":"ShowZero" + }, + "start_at":1688236579108, + "stop_at":null, + "password": "test" + } + "##; + + let timer: RedisTimer = serde_json::from_str(payload).unwrap(); + let timer: Timer = timer.into(); + assert_eq!(timer.segments.len(), 1); + assert_eq!(timer.segments[0].label, "Boulder"); + assert_eq!(timer.metadata.delay_start_stop, 0); + assert_eq!( + timer.display_options.pre_start_behaviour, + PreStartBehaviour::ShowFirstSegment + ); +} + +#[test] +fn test_v1() { + let payload = r##" + { + "segments":[ + { + "label":"Boulder", + "time":230000, + "sound":true, + "color":"#26A269", + "count_to":11000 + } + ], + "id":"v0", + "repeat":true, + "display_options":{ + "clock":false, + "pre_start_behaviour":"ShowLastSegment" + }, + "start_at":1688236579108, + "stop_at":null, + "password": "test", + "metadata": { + "delay_start_stop": 5 + } + } + "##; + + let timer: RedisTimer = serde_json::from_str(payload).unwrap(); + let timer: Timer = timer.into(); + assert_eq!(timer.segments.len(), 1); + assert_eq!(timer.segments[0].label, "Boulder"); + assert_eq!(timer.metadata.delay_start_stop, 5); + assert_eq!( + timer.display_options.pre_start_behaviour, + PreStartBehaviour::ShowLastSegment + ); +} diff --git a/src/redis_migrations/timer.rs b/src/redis_migrations/timer.rs new file mode 100644 index 0000000..9cc0107 --- /dev/null +++ b/src/redis_migrations/timer.rs @@ -0,0 +1,78 @@ +use serde::Deserialize; + +use crate::repository::{Timer, TimerMetadata}; + +use super::display_options::RedisDisplayOptions; +use super::segment::RedisSegment; +use super::timer_metadata::RedisTimerMetadata; + +#[derive(Deserialize, Clone)] +#[serde(untagged)] +pub enum RedisTimer { + V1(TimerV1), + V0(TimerV0), +} + +impl Into for RedisTimer { + fn into(self) -> Timer { + match self { + RedisTimer::V0(t) => t.into(), + RedisTimer::V1(t) => t.into(), + } + } +} + +/// === V1 === +#[derive(Deserialize, Clone)] +pub struct TimerV1 { + pub segments: Vec, + pub repeat: bool, + pub display_options: Option, + pub start_at: u64, + pub stop_at: Option, + pub password: String, + pub id: String, + pub metadata: RedisTimerMetadata, +} + +impl Into for TimerV1 { + fn into(self) -> Timer { + Timer { + segments: self.segments.into_iter().map(|s| s.into()).collect(), + repeat: self.repeat, + display_options: self.display_options.into(), + start_at: self.start_at, + stop_at: self.stop_at, + password: self.password, + id: self.id, + metadata: self.metadata.into(), + } + } +} + +/// === V0 === +#[derive(Deserialize, Clone)] +pub struct TimerV0 { + pub segments: Vec, + pub repeat: bool, + pub display_options: Option, + pub start_at: u64, + pub stop_at: Option, + pub password: String, + pub id: String, +} + +impl Into for TimerV0 { + fn into(self) -> Timer { + Timer { + segments: self.segments.into_iter().map(|s| s.into()).collect(), + repeat: self.repeat, + display_options: self.display_options.into(), + start_at: self.start_at, + stop_at: self.stop_at, + password: self.password, + id: self.id, + metadata: TimerMetadata::default(), + } + } +} diff --git a/src/redis_migrations/timer_metadata.rs b/src/redis_migrations/timer_metadata.rs new file mode 100644 index 0000000..e0d69ae --- /dev/null +++ b/src/redis_migrations/timer_metadata.rs @@ -0,0 +1,30 @@ +use serde::Deserialize; + +use crate::repository::TimerMetadata; + +#[derive(Deserialize, Clone)] +#[serde(untagged)] +pub enum RedisTimerMetadata { + V0(TimerMetadataV0), +} + +impl Into for RedisTimerMetadata { + fn into(self) -> TimerMetadata { + match self { + RedisTimerMetadata::V0(v0) => v0.into(), + } + } +} + +#[derive(Deserialize, Clone)] +pub struct TimerMetadataV0 { + pub delay_start_stop: u32, +} + +impl Into for TimerMetadataV0 { + fn into(self) -> TimerMetadata { + TimerMetadata { + delay_start_stop: self.delay_start_stop, + } + } +} diff --git a/src/repository.rs b/src/repository.rs new file mode 100644 index 0000000..3210dc8 --- /dev/null +++ b/src/repository.rs @@ -0,0 +1,172 @@ +use std::sync::Arc; + +use crate::{color::Color, redis_migrations::RedisTimer}; +use futures::StreamExt; +use serde::{Deserialize, Serialize}; + +use redis::AsyncCommands; +use tokio::{ + sync::broadcast::{self, Receiver}, + task::JoinHandle, +}; + +#[derive(Serialize, Deserialize, Clone, Debug)] +pub struct Segment { + pub label: String, + pub time: u32, + pub sound: bool, + pub color: Option, + pub count_to: u32, +} + +#[derive(Serialize, Deserialize, Clone, Debug)] +pub struct Sound { + pub filename: String, + pub play_at: u32, +} + +#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)] +pub enum PreStartBehaviour { + ShowFirstSegment, + ShowLastSegment, + RunNormally, +} + +impl Default for PreStartBehaviour { + fn default() -> Self { + PreStartBehaviour::ShowFirstSegment + } +} + +#[derive(Serialize, Deserialize, Clone, Default, Debug)] +pub struct DisplayOptions { + pub clock: bool, + pub pre_start_behaviour: PreStartBehaviour, +} + +#[derive(Serialize, Deserialize, Clone, Default, Debug)] +pub struct TimerMetadata { + pub delay_start_stop: u32, +} + +#[derive(Serialize, Deserialize, Clone, Default, Debug)] +pub struct Timer { + pub segments: Vec, + pub repeat: bool, + pub display_options: DisplayOptions, + pub start_at: u64, + pub stop_at: Option, + pub password: String, + pub id: String, + pub metadata: TimerMetadata, +} + +#[derive(Clone)] +pub struct Repository { + redis: redis::aio::ConnectionManager, + pub updates_rx: Arc>, +} + +impl Repository { + pub async fn new(redis_string: String) -> Self { + let client = + redis::Client::open(redis_string.to_owned()).expect("Could not connect to redis"); + let manager = redis::aio::ConnectionManager::new(client.clone()) + .await + .unwrap(); + + let (redis_task_tx, redis_task_rx) = broadcast::channel::(10); + spawn_global_redis_listener_task(manager.clone(), client, redis_task_tx); + + Repository { + redis: manager, + updates_rx: Arc::new(redis_task_rx), + } + } + + pub async fn get_timer(&self, id: String) -> Option { + let mut redis = self.redis.clone(); + let timer = &redis.get::(id).await; + + if timer.is_err() { + return None; + } + + let timer: RedisTimer = serde_json::from_str(timer.as_ref().unwrap()).unwrap(); + Some(timer.into()) + } + + pub async fn create_timer(&self, timer: &Timer) -> Result<(), ()> { + let mut redis = self.redis.clone(); + if redis + .exists::(timer.id.clone()) + .await + .unwrap() + { + return Err(()); + } + + redis + .set::(timer.id.clone(), serde_json::to_string(timer).unwrap()) + .await + .unwrap(); + + return Ok(()); + } + + pub async fn update_timer(&self, timer: &Timer) { + let mut redis = self.redis.clone(); + redis + .set::(timer.id.clone(), serde_json::to_string(timer).unwrap()) + .await + .unwrap(); + } + + pub async fn delete_timer(&self, id: String) -> Result<(), ()> { + self.redis + .clone() + .del::(id) + .await + .map_err(|_| ()) + } +} + +pub fn spawn_global_redis_listener_task( + mut redis: redis::aio::ConnectionManager, + redis_client: redis::Client, + redis_task_tx: broadcast::Sender, +) -> JoinHandle<()> { + tokio::spawn(async move { + let mut connection = redis_client.get_async_connection().await.unwrap(); + let _: () = redis::cmd("CONFIG") + .arg("SET") + .arg("notify-keyspace-events") + .arg("KEA") + .query_async(&mut connection) + .await + .unwrap(); + + let mut pubsub = connection.into_pubsub(); + + pubsub + .psubscribe("__keyspace@*__:*") + .await + .expect("Failed to subscribe to redis channel"); + + let mut pubsub = pubsub.into_on_message(); + + while let Some(msg) = pubsub.next().await { + println!("Updated! {:?}", msg); + let timer_id = msg.get_channel_name().split(":").last().unwrap(); + + let timer_str = &redis + .get::(String::from(timer_id)) + .await + .expect("Did not find timer in redis"); + let timer: Timer = serde_json::from_str(timer_str).unwrap(); + + // Broadcast to all listeners + redis_task_tx.send(timer).unwrap(); + } + }) +} diff --git a/src/routes/timer.rs b/src/routes/timer.rs index 276a99f..af8dd81 100644 --- a/src/routes/timer.rs +++ b/src/routes/timer.rs @@ -10,8 +10,6 @@ use axum::{ Json, TypedHeader, }; use jsonwebtoken::{decode, encode, Algorithm, DecodingKey, EncodingKey, Header, Validation}; -use redis::aio::ConnectionManager; -use redis::AsyncCommands; use regex::Regex; use std::str; @@ -21,6 +19,7 @@ use argon2::{ }; use crate::models::*; +use crate::repository::Timer; async fn auth_middleware( State(state): State, @@ -61,20 +60,6 @@ fn check_password_hash(password: &str, password_hash: &str) -> bool { .is_ok() } -async fn get_timer_from_redis( - id: String, - redis: &mut ConnectionManager, -) -> Result { - let timer = &redis.get::(id).await; - - if timer.is_err() { - return Err(StatusCode::UNAUTHORIZED); - } - - let timer: Timer = serde_json::from_str(timer.as_ref().unwrap()).unwrap(); - Ok(timer) -} - fn create_jwt(id: String, key: &str) -> String { let claims = Claims { id: id, @@ -94,9 +79,12 @@ async fn create_token( State(state): State, Json(request): Json, ) -> Result, StatusCode> { - let mut redis = state.as_ref().redis.clone(); - - let timer = get_timer_from_redis(request.id.clone(), &mut redis).await?; + let timer = state + .repository + .get_timer(request.id.clone()) + .await + .map(|t| Ok(t)) + .unwrap_or(Err(StatusCode::UNAUTHORIZED))?; if !check_password_hash(&request.password, &timer.password) { return Err(StatusCode::UNAUTHORIZED); @@ -115,32 +103,15 @@ async fn create_timer( if !id_regex.is_match(&request.id) { return Err(StatusCode::BAD_REQUEST); } - // Timer already exists - let mut redis = state.as_ref().redis.clone(); - if redis - .exists::(request.id.clone()) - .await - .unwrap() - { - return Err(StatusCode::CONFLICT); - } - let password = hash_password(&request.password); + let hashed_password = hash_password(&request.password); + let timer = request.into(hashed_password); - let timer = Timer { - segments: request.segments, - repeat: request.repeat, - start_at: request.start_at, - stop_at: None, - display_options: None, - password, - id: request.id, - }; - - redis - .set::(timer.id.clone(), serde_json::to_string(&timer).unwrap()) + state + .repository + .create_timer(&timer) .await - .unwrap(); + .map_err(|_| StatusCode::CONFLICT)?; let token = create_jwt(timer.id.clone(), &state.jwt_key); @@ -154,8 +125,12 @@ async fn get_timer( State(state): State, Path(id): Path, ) -> Result, StatusCode> { - let mut redis = state.as_ref().redis.clone(); - let timer = get_timer_from_redis(id, &mut redis).await?; + let timer = state + .repository + .get_timer(id) + .await + .map(|t| Ok(t)) + .unwrap_or(Err(StatusCode::UNAUTHORIZED))?; Ok(Json(timer.into())) } @@ -164,23 +139,24 @@ async fn update_timer( Path(id): Path, Json(request): Json, ) -> Result, StatusCode> { - let mut redis = state.as_ref().redis.clone(); - - let old_timer: Timer = get_timer_from_redis(id, &mut redis).await?; + let old_timer: Timer = state + .repository + .get_timer(id) + .await + .map(|t| Ok(t)) + .unwrap_or(Err(StatusCode::UNAUTHORIZED))?; let timer = Timer { segments: request.segments, repeat: request.repeat, display_options: request.display_options, + metadata: request.metadata, start_at: request.start_at, stop_at: request.stop_at, ..old_timer }; - redis - .set::(timer.id.clone(), serde_json::to_string(&timer).unwrap()) - .await - .unwrap(); + state.repository.update_timer(&timer).await; Ok(Json(timer.into())) } @@ -189,12 +165,12 @@ async fn delete_timer( State(state): State, Path(id): Path, ) -> impl IntoResponse { - let mut redis = state.as_ref().redis.clone(); - if redis.del::(id).await.is_err() { - StatusCode::NOT_FOUND - } else { - StatusCode::OK - } + state + .repository + .delete_timer(id) + .await + .map(|_| StatusCode::OK) + .map_err(|_| StatusCode::NOT_FOUND) } pub fn routes(state: SharedState) -> Router { diff --git a/src/routes/ws.rs b/src/routes/ws.rs index 587a21d..94840a1 100644 --- a/src/routes/ws.rs +++ b/src/routes/ws.rs @@ -16,13 +16,13 @@ use futures::{ use serde::{Deserialize, Serialize}; use serde_json; -use redis::AsyncCommands; -use tokio::sync::broadcast; use tokio::sync::mpsc::{Receiver, Sender}; use tokio::task::JoinHandle; -use crate::SharedState; -use crate::Timer; +use crate::{ + repository::{Repository, Timer}, + SharedState, +}; use std::time::{SystemTime, UNIX_EPOCH}; @@ -33,7 +33,7 @@ use crate::models::*; enum WSMessage { Hello(String), GetTime, - Timer(TimerResponse), + Timer(WsTimerResponse), Timestamp(u128), Error((u128, String)), } @@ -48,7 +48,7 @@ impl WsConnection { let ws_sender_task = WsConnection::spawn_ws_sender_task(ws_sender, ws_message_rx); let ws_receiver_task = WsConnection::spawn_ws_receiver_task( - state.redis.clone(), + state.repository.clone(), ws_message_tx.clone(), redis_listen_id_tx, ws_receiver, @@ -56,7 +56,7 @@ impl WsConnection { let redis_listener_task = WsConnection::spawn_redis_listener_task( ws_message_tx, redis_listen_id_rx, - state.redis_task_rx.resubscribe(), + state.repository.updates_rx.resubscribe(), ); ws_receiver_task.await.unwrap(); @@ -104,21 +104,21 @@ impl WsConnection { } fn spawn_ws_receiver_task( - redis: redis::aio::ConnectionManager, + repository: Repository, ws_message_tx: Sender, redis_listen_id_tx: Sender, ws_receiver: SplitStream, ) -> JoinHandle<()> { tokio::spawn(async move { let mut message_handler = - WsMessageHandler::new(redis, ws_message_tx, redis_listen_id_tx, ws_receiver); + WsMessageHandler::new(repository, ws_message_tx, redis_listen_id_tx, ws_receiver); message_handler.listen().await; }) } } struct WsMessageHandler { - redis: redis::aio::ConnectionManager, + repository: Repository, ws_message_tx: Sender, redis_listen_id_tx: Sender, ws_receiver: SplitStream, @@ -126,13 +126,13 @@ struct WsMessageHandler { impl WsMessageHandler { fn new( - redis: redis::aio::ConnectionManager, + repository: Repository, ws_message_tx: Sender, redis_listen_id_tx: Sender, ws_receiver: SplitStream, ) -> Self { WsMessageHandler { - redis, + repository, ws_message_tx, redis_listen_id_tx, ws_receiver, @@ -174,14 +174,10 @@ impl WsMessageHandler { self.redis_listen_id_tx.send(id.clone()).await.unwrap(); - let timer_string = self.redis.get::(id).await; - if timer_string.is_err() { - return WSMessage::Error((404, "Timer not found!".to_owned())); - } - - let timer: Timer = serde_json::from_str(&timer_string.unwrap()).unwrap(); - - WSMessage::Timer(timer.into()) + self.repository.get_timer(id).await.map_or_else( + || WSMessage::Error((404, "Timer not found!".to_owned())), + |t| WSMessage::Timer(t.into()), + ) } } @@ -192,46 +188,6 @@ pub async fn ws_handler( ws.on_upgrade(move |socket| WsConnection::new(state, socket)) } -pub fn spawn_global_redis_listener_task( - mut redis: redis::aio::ConnectionManager, - redis_client: redis::Client, - redis_task_tx: broadcast::Sender, -) -> JoinHandle<()> { - tokio::spawn(async move { - let mut connection = redis_client.get_async_connection().await.unwrap(); - let _: () = redis::cmd("CONFIG") - .arg("SET") - .arg("notify-keyspace-events") - .arg("KEA") - .query_async(&mut connection) - .await - .unwrap(); - - let mut pubsub = connection.into_pubsub(); - - pubsub - .psubscribe("__keyspace@*__:*") - .await - .expect("Failed to subscribe to redis channel"); - - let mut pubsub = pubsub.into_on_message(); - - while let Some(msg) = pubsub.next().await { - println!("Updated! {:?}", msg); - let timer_id = msg.get_channel_name().split(":").last().unwrap(); - - let timer_str = &redis - .get::(String::from(timer_id)) - .await - .expect("Did not find timer in redis"); - let timer: Timer = serde_json::from_str(timer_str).unwrap(); - - // Broadcast to all listeners - redis_task_tx.send(timer).unwrap(); - } - }) -} - pub fn routes() -> Router { Router::new().route("/", get(ws_handler)) }