diff --git a/core/connectors/runtime/src/api/sink.rs b/core/connectors/runtime/src/api/sink.rs index 0a2c8791bc..50c9d9b7a4 100644 --- a/core/connectors/runtime/src/api/sink.rs +++ b/core/connectors/runtime/src/api/sink.rs @@ -30,7 +30,7 @@ use axum::{ extract::{Path, Query, State}, http::{HeaderMap, StatusCode, header}, response::IntoResponse, - routing::get, + routing::{get, post}, }; use serde::Deserialize; use std::sync::Arc; @@ -52,6 +52,7 @@ pub fn router(state: Arc) -> Router { "/sinks/{key}/configs/active", get(get_sink_active_config).put(update_sink_active_config), ) + .route("/sinks/{key}/restart", post(restart_sink)) .with_state(state) } @@ -246,3 +247,20 @@ async fn delete_sink_config( .await?; Ok(StatusCode::NO_CONTENT) } + +async fn restart_sink( + State(context): State>, + Path(key): Path, +) -> Result { + context + .sinks + .restart_connector( + &key, + context.config_provider.as_ref(), + &context.iggy_clients.consumer, + &context.metrics, + &context, + ) + .await?; + Ok(StatusCode::NO_CONTENT) +} diff --git a/core/connectors/runtime/src/api/source.rs b/core/connectors/runtime/src/api/source.rs index 7db1c9974f..0387421fb4 100644 --- a/core/connectors/runtime/src/api/source.rs +++ b/core/connectors/runtime/src/api/source.rs @@ -30,7 +30,7 @@ use axum::{ extract::{Path, Query, State}, http::{HeaderMap, StatusCode, header}, response::IntoResponse, - routing::get, + routing::{get, post}, }; use serde::Deserialize; use std::sync::Arc; @@ -55,6 +55,7 @@ pub fn router(state: Arc) -> Router { "/sources/{key}/configs/active", get(get_source_active_config).put(update_source_active_config), ) + .route("/sources/{key}/restart", post(restart_source)) .with_state(state) } @@ -249,3 +250,21 @@ async fn delete_source_config( .await?; Ok(StatusCode::NO_CONTENT) } + +async fn restart_source( + State(context): State>, + Path(key): Path, +) -> Result { + context + .sources + .restart_connector( + &key, + context.config_provider.as_ref(), + &context.iggy_clients.producer, + &context.metrics, + &context.state_path, + &context, + ) + .await?; + Ok(StatusCode::NO_CONTENT) +} diff --git a/core/connectors/runtime/src/context.rs b/core/connectors/runtime/src/context.rs index 292ea226c0..a87095463b 100644 --- a/core/connectors/runtime/src/context.rs +++ b/core/connectors/runtime/src/context.rs @@ -19,6 +19,7 @@ use crate::configs::connectors::{ConnectorsConfigProvider, SinkConfig, SourceConfig}; use crate::configs::runtime::ConnectorsRuntimeConfig; use crate::metrics::Metrics; +use crate::stream::IggyClients; use crate::{ SinkConnectorWrapper, SourceConnectorWrapper, manager::{ @@ -31,6 +32,7 @@ use iggy_connector_sdk::api::ConnectorError; use iggy_connector_sdk::api::ConnectorStatus; use std::collections::HashMap; use std::sync::Arc; +use tokio::sync::Mutex; use tracing::error; pub struct RuntimeContext { @@ -40,8 +42,11 @@ pub struct RuntimeContext { pub config_provider: Arc, pub metrics: Arc, pub start_time: IggyTimestamp, + pub iggy_clients: Arc, + pub state_path: String, } +#[allow(clippy::too_many_arguments)] pub fn init( config: &ConnectorsRuntimeConfig, sinks_config: &HashMap, @@ -49,6 +54,8 @@ pub fn init( sink_wrappers: &[SinkConnectorWrapper], source_wrappers: &[SourceConnectorWrapper], config_provider: Box, + iggy_clients: Arc, + state_path: String, ) -> RuntimeContext { let metrics = Arc::new(Metrics::init()); let sinks = SinkManager::new(map_sinks(sinks_config, sink_wrappers)); @@ -64,6 +71,8 @@ pub fn init( config_provider: Arc::from(config_provider), metrics, start_time: IggyTimestamp::now(), + iggy_clients, + state_path, } } @@ -103,6 +112,10 @@ fn map_sinks( plugin_config_format: sink_plugin.config_format, }, config: sink_config.clone(), + shutdown_tx: None, + task_handles: vec![], + container: None, + restart_guard: Arc::new(Mutex::new(())), }); } } @@ -145,6 +158,9 @@ fn map_sources( plugin_config_format: source_plugin.config_format, }, config: source_config.clone(), + handler_tasks: vec![], + container: None, + restart_guard: Arc::new(Mutex::new(())), }); } } diff --git a/core/connectors/runtime/src/main.rs b/core/connectors/runtime/src/main.rs index 1d53e71cfd..cca15d47a4 100644 --- a/core/connectors/runtime/src/main.rs +++ b/core/connectors/runtime/src/main.rs @@ -29,6 +29,7 @@ use figlet_rs::FIGfont; use iggy::prelude::{Client, IggyConsumer, IggyProducer}; use iggy_connector_sdk::{ StreamDecoder, StreamEncoder, + api::ConnectorStatus, sink::ConsumeCallback, source::{HandleCallback, SendCallback}, transforms::Transform, @@ -40,7 +41,7 @@ use std::{ env, sync::{Arc, atomic::AtomicU32}, }; -use tracing::info; +use tracing::{error, info}; mod api; pub(crate) mod configs; @@ -69,8 +70,8 @@ static PLUGIN_ID: AtomicU32 = AtomicU32::new(1); const ALLOWED_PLUGIN_EXTENSIONS: [&str; 3] = ["so", "dylib", "dll"]; const DEFAULT_CONFIG_PATH: &str = "core/connectors/runtime/config.toml"; -#[derive(WrapperApi)] -struct SourceApi { +#[derive(WrapperApi, Debug)] +pub(crate) struct SourceApi { iggy_source_open: extern "C" fn( id: u32, config_ptr: *const u8, @@ -84,8 +85,8 @@ struct SourceApi { iggy_source_version: extern "C" fn() -> *const std::ffi::c_char, } -#[derive(WrapperApi)] -struct SinkApi { +#[derive(WrapperApi, Debug)] +pub(crate) struct SinkApi { iggy_sink_open: extern "C" fn( id: u32, config_ptr: *const u8, @@ -143,7 +144,7 @@ async fn main() -> Result<(), RuntimeError> { info!("State will be stored in: {}", config.state.path); - let iggy_clients = stream::init(config.iggy.clone()).await?; + let iggy_clients = Arc::new(stream::init(config.iggy.clone()).await?); let connectors_config_provider: Box = create_connectors_config_provider(&config.connectors).await?; @@ -166,47 +167,31 @@ async fn main() -> Result<(), RuntimeError> { let sinks = sink::init(sinks_config.clone(), &iggy_clients.consumer).await?; let mut sink_wrappers = vec![]; - let mut sink_with_plugins = HashMap::new(); - for (key, sink) in sinks { - let plugin_ids = sink - .plugins - .iter() - .filter(|plugin| plugin.error.is_none()) - .map(|plugin| plugin.id) - .collect(); + let mut sink_containers_by_key: HashMap>> = HashMap::new(); + for (_path, sink) in sinks { + let container = Arc::new(sink.container); + let callback = container.iggy_sink_consume; + for plugin in &sink.plugins { + sink_containers_by_key.insert(plugin.key.clone(), container.clone()); + } sink_wrappers.push(SinkConnectorWrapper { - callback: sink.container.iggy_sink_consume, + callback, plugins: sink.plugins, }); - sink_with_plugins.insert( - key, - SinkWithPlugins { - container: sink.container, - plugin_ids, - }, - ); } let mut source_wrappers = vec![]; - let mut source_with_plugins = HashMap::new(); - for (key, source) in sources { - let plugin_ids = source - .plugins - .iter() - .filter(|plugin| plugin.error.is_none()) - .map(|plugin| plugin.id) - .collect(); + let mut source_containers_by_key: HashMap>> = HashMap::new(); + for (_path, source) in sources { + let container = Arc::new(source.container); + let callback = container.iggy_source_handle; + for plugin in &source.plugins { + source_containers_by_key.insert(plugin.key.clone(), container.clone()); + } source_wrappers.push(SourceConnectorWrapper { - callback: source.container.iggy_source_handle, + callback, plugins: source.plugins, }); - source_with_plugins.insert( - key, - SourceWithPlugins { - container: source.container, - plugin_ids, - }, - ); } let context = context::init( @@ -216,13 +201,47 @@ async fn main() -> Result<(), RuntimeError> { &sink_wrappers, &source_wrappers, connectors_config_provider, + iggy_clients.clone(), + config.state.path.clone(), ); + for (key, container) in sink_containers_by_key { + if let Some(details) = context.sinks.get(&key).await { + let mut details = details.lock().await; + details.container = Some(container); + } + } + for (key, container) in source_containers_by_key { + if let Some(details) = context.sources.get(&key).await { + let mut details = details.lock().await; + details.container = Some(container); + } + } + let context = Arc::new(context); - api::init(&config.http, context.clone()).await; - let source_handler_tasks = source::handle(source_wrappers, context.clone()); - sink::consume(sink_wrappers, context.clone()); + let source_handles = source::handle(source_wrappers, context.clone()); + for (key, handler_tasks) in source_handles { + if let Some(details) = context.sources.get(&key).await { + let mut details = details.lock().await; + details.handler_tasks = handler_tasks; + } + } + + let sink_handles = sink::consume(sink_wrappers, context.clone()); + for (key, shutdown_tx, task_handles) in sink_handles { + if let Some(details) = context.sinks.get(&key).await { + let mut details = details.lock().await; + details.shutdown_tx = Some(shutdown_tx); + details.task_handles = task_handles; + } + context + .sinks + .update_status(&key, ConnectorStatus::Running, Some(&context.metrics)) + .await; + } + info!("All sources and sinks spawned."); + api::init(&config.http, context.clone()).await; #[cfg(unix)] let (mut ctrl_c, mut sigterm) = { @@ -243,26 +262,37 @@ async fn main() -> Result<(), RuntimeError> { } } - for (key, source) in source_with_plugins { - for id in source.plugin_ids { - info!("Closing source connector with ID: {id} for plugin: {key}"); - source.container.iggy_source_close(id); - source::cleanup_sender(id); - info!("Closed source connector with ID: {id} for plugin: {key}"); + let source_keys: Vec = context + .sources + .get_all() + .await + .into_iter() + .map(|s| s.key) + .collect(); + for key in &source_keys { + if let Err(err) = context + .sources + .stop_connector_with_guard(key, &context.metrics) + .await + { + error!("Failed to stop source connector: {key}. {err}"); } } - // Wait for source handler tasks to drain remaining messages and persist state - // before shutting down the Iggy clients they depend on. - for handle in source_handler_tasks { - let _ = tokio::time::timeout(std::time::Duration::from_secs(5), handle).await; - } - - for (key, sink) in sink_with_plugins { - for id in sink.plugin_ids { - info!("Closing sink connector with ID: {id} for plugin: {key}"); - sink.container.iggy_sink_close(id); - info!("Closed sink connector with ID: {id} for plugin: {key}"); + let sink_keys: Vec = context + .sinks + .get_all() + .await + .into_iter() + .map(|s| s.key) + .collect(); + for key in &sink_keys { + if let Err(err) = context + .sinks + .stop_connector_with_guard(key, &context.metrics) + .await + { + error!("Failed to stop sink connector: {key}. {err}"); } } @@ -400,11 +430,6 @@ struct SinkConnectorWrapper { plugins: Vec, } -struct SinkWithPlugins { - container: Container, - plugin_ids: Vec, -} - struct SourceConnector { container: Container, plugins: Vec, @@ -429,11 +454,6 @@ struct SourceConnectorProducer { producer: IggyProducer, } -struct SourceWithPlugins { - container: Container, - plugin_ids: Vec, -} - struct SourceConnectorWrapper { callback: HandleCallback, plugins: Vec, diff --git a/core/connectors/runtime/src/manager/sink.rs b/core/connectors/runtime/src/manager/sink.rs index 0bc058201d..9c6cafa1ef 100644 --- a/core/connectors/runtime/src/manager/sink.rs +++ b/core/connectors/runtime/src/manager/sink.rs @@ -16,13 +16,25 @@ * specific language governing permissions and limitations * under the License. */ -use crate::configs::connectors::{ConfigFormat, SinkConfig}; +use crate::PLUGIN_ID; +use crate::SinkApi; +use crate::configs::connectors::{ConfigFormat, ConnectorsConfigProvider, SinkConfig}; +use crate::context::RuntimeContext; +use crate::error::RuntimeError; use crate::metrics::Metrics; +use crate::sink; use dashmap::DashMap; +use dlopen2::wrapper::Container; +use iggy::prelude::IggyClient; use iggy_connector_sdk::api::{ConnectorError, ConnectorStatus}; use std::collections::HashMap; +use std::fmt; use std::sync::Arc; -use tokio::sync::Mutex; +use std::sync::atomic::Ordering; +use std::time::Duration; +use tokio::sync::{Mutex, watch}; +use tokio::task::JoinHandle; +use tracing::info; #[derive(Debug)] pub struct SinkManager { @@ -96,6 +108,166 @@ impl SinkManager { sink.info.last_error = Some(ConnectorError::new(error_message)); } } + + pub async fn stop_connector_with_guard( + &self, + key: &str, + metrics: &Arc, + ) -> Result<(), RuntimeError> { + let guard = { + let details = self + .sinks + .get(key) + .map(|e| e.value().clone()) + .ok_or_else(|| RuntimeError::SinkNotFound(key.to_string()))?; + let details = details.lock().await; + details.restart_guard.clone() + }; + let _lock = guard.lock().await; + self.stop_connector(key, metrics).await + } + + pub async fn stop_connector( + &self, + key: &str, + metrics: &Arc, + ) -> Result<(), RuntimeError> { + let details = self + .sinks + .get(key) + .map(|e| e.value().clone()) + .ok_or_else(|| RuntimeError::SinkNotFound(key.to_string()))?; + + let (shutdown_tx, task_handles, plugin_id, container) = { + let mut details = details.lock().await; + ( + details.shutdown_tx.take(), + std::mem::take(&mut details.task_handles), + details.info.id, + details.container.clone(), + ) + }; + + if let Some(tx) = shutdown_tx { + let _ = tx.send(()); + } + + for handle in task_handles { + let _ = tokio::time::timeout(Duration::from_secs(5), handle).await; + } + + if let Some(container) = &container { + info!("Closing sink connector with ID: {plugin_id} for plugin: {key}"); + (container.iggy_sink_close)(plugin_id); + info!("Closed sink connector with ID: {plugin_id} for plugin: {key}"); + } + + { + let mut details = details.lock().await; + let old_status = details.info.status; + details.info.status = ConnectorStatus::Stopped; + details.info.last_error = None; + if old_status == ConnectorStatus::Running { + metrics.decrement_sinks_running(); + } + } + + Ok(()) + } + + pub async fn start_connector( + &self, + key: &str, + config: &SinkConfig, + iggy_client: &IggyClient, + metrics: &Arc, + context: &Arc, + ) -> Result<(), RuntimeError> { + let details = self + .sinks + .get(key) + .map(|e| e.value().clone()) + .ok_or_else(|| RuntimeError::SinkNotFound(key.to_string()))?; + + let container = { + let details = details.lock().await; + details.container.clone().ok_or_else(|| { + RuntimeError::InvalidConfiguration(format!("No container loaded for sink: {key}")) + })? + }; + + let plugin_id = PLUGIN_ID.fetch_add(1, Ordering::SeqCst); + + sink::init_sink( + &container, + &config.plugin_config.clone().unwrap_or_default(), + plugin_id, + )?; + info!("Sink connector with ID: {plugin_id} for plugin: {key} initialized successfully."); + + let consumers = sink::setup_sink_consumers(key, config, iggy_client).await?; + + let callback = container.iggy_sink_consume; + let (shutdown_tx, task_handles) = sink::spawn_consume_tasks( + plugin_id, + key, + consumers, + callback, + config.verbose, + metrics, + context.clone(), + ); + + { + let mut details = details.lock().await; + details.info.id = plugin_id; + details.info.status = ConnectorStatus::Running; + details.info.last_error = None; + details.config = config.clone(); + details.shutdown_tx = Some(shutdown_tx); + details.task_handles = task_handles; + metrics.increment_sinks_running(); + } + + Ok(()) + } + + pub async fn restart_connector( + &self, + key: &str, + config_provider: &dyn ConnectorsConfigProvider, + iggy_client: &IggyClient, + metrics: &Arc, + context: &Arc, + ) -> Result<(), RuntimeError> { + let guard = { + let details = self + .sinks + .get(key) + .map(|e| e.value().clone()) + .ok_or_else(|| RuntimeError::SinkNotFound(key.to_string()))?; + let details = details.lock().await; + details.restart_guard.clone() + }; + let Ok(_lock) = guard.try_lock() else { + info!("Restart already in progress for sink connector: {key}, skipping."); + return Ok(()); + }; + + info!("Restarting sink connector: {key}"); + self.stop_connector(key, metrics).await?; + + let config = config_provider + .get_sink_config(key, None) + .await + .map_err(|e| RuntimeError::InvalidConfiguration(e.to_string()))? + .ok_or_else(|| RuntimeError::SinkNotFound(key.to_string()))?; + + self.start_connector(key, &config, iggy_client, metrics, context) + .await?; + info!("Sink connector: {key} restarted successfully."); + Ok(()) + } } #[derive(Debug, Clone)] @@ -111,8 +283,302 @@ pub struct SinkInfo { pub plugin_config_format: Option, } -#[derive(Debug)] pub struct SinkDetails { pub info: SinkInfo, pub config: SinkConfig, + pub shutdown_tx: Option>, + pub task_handles: Vec>, + pub container: Option>>, + pub restart_guard: Arc>, +} + +impl fmt::Debug for SinkDetails { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("SinkDetails") + .field("info", &self.info) + .field("config", &self.config) + .field("container", &self.container.as_ref().map(|_| "...")) + .finish() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::configs::connectors::SinkConfig; + + fn create_test_sink_info(key: &str, id: u32) -> SinkInfo { + SinkInfo { + id, + key: key.to_string(), + name: format!("{key} sink"), + path: format!("/path/to/{key}"), + version: "1.0.0".to_string(), + enabled: true, + status: ConnectorStatus::Running, + last_error: None, + plugin_config_format: None, + } + } + + fn create_test_sink_details(key: &str, id: u32) -> SinkDetails { + SinkDetails { + info: create_test_sink_info(key, id), + config: SinkConfig { + key: key.to_string(), + enabled: true, + version: 1, + name: format!("{key} sink"), + path: format!("/path/to/{key}"), + ..Default::default() + }, + shutdown_tx: None, + task_handles: vec![], + container: None, + restart_guard: Arc::new(Mutex::new(())), + } + } + + #[tokio::test] + async fn should_create_manager_with_sinks() { + let manager = SinkManager::new(vec![ + create_test_sink_details("es", 1), + create_test_sink_details("pg", 2), + ]); + + let all = manager.get_all().await; + assert_eq!(all.len(), 2); + } + + #[tokio::test] + async fn should_get_existing_sink() { + let manager = SinkManager::new(vec![create_test_sink_details("es", 1)]); + + let sink = manager.get("es").await; + assert!(sink.is_some()); + let binding = sink.unwrap(); + let details = binding.lock().await; + assert_eq!(details.info.key, "es"); + assert_eq!(details.info.id, 1); + } + + #[tokio::test] + async fn should_return_none_for_unknown_key() { + let manager = SinkManager::new(vec![create_test_sink_details("es", 1)]); + + assert!(manager.get("nonexistent").await.is_none()); + } + + #[tokio::test] + async fn should_get_config() { + let manager = SinkManager::new(vec![create_test_sink_details("es", 1)]); + + let config = manager.get_config("es").await; + assert!(config.is_some()); + assert_eq!(config.unwrap().key, "es"); + } + + #[tokio::test] + async fn should_return_none_config_for_unknown_key() { + let manager = SinkManager::new(vec![]); + + assert!(manager.get_config("nonexistent").await.is_none()); + } + + #[tokio::test] + async fn should_get_all_sinks() { + let manager = SinkManager::new(vec![ + create_test_sink_details("es", 1), + create_test_sink_details("pg", 2), + create_test_sink_details("stdout", 3), + ]); + + let all = manager.get_all().await; + assert_eq!(all.len(), 3); + let keys: Vec = all.iter().map(|s| s.key.clone()).collect(); + assert!(keys.contains(&"es".to_string())); + assert!(keys.contains(&"pg".to_string())); + assert!(keys.contains(&"stdout".to_string())); + } + + #[tokio::test] + async fn should_update_status() { + let manager = SinkManager::new(vec![create_test_sink_details("es", 1)]); + + manager + .update_status("es", ConnectorStatus::Stopped, None) + .await; + + let sink = manager.get("es").await.unwrap(); + let details = sink.lock().await; + assert_eq!(details.info.status, ConnectorStatus::Stopped); + } + + #[tokio::test] + async fn should_increment_metrics_when_transitioning_to_running() { + let metrics = Arc::new(Metrics::init()); + let mut details = create_test_sink_details("es", 1); + details.info.status = ConnectorStatus::Stopped; + let manager = SinkManager::new(vec![details]); + + manager + .update_status("es", ConnectorStatus::Running, Some(&metrics)) + .await; + + assert_eq!(metrics.get_sinks_running(), 1); + } + + #[tokio::test] + async fn should_decrement_metrics_when_leaving_running() { + let metrics = Arc::new(Metrics::init()); + let manager = SinkManager::new(vec![create_test_sink_details("es", 1)]); + metrics.increment_sinks_running(); + + manager + .update_status("es", ConnectorStatus::Stopped, Some(&metrics)) + .await; + + assert_eq!(metrics.get_sinks_running(), 0); + } + + #[tokio::test] + async fn should_clear_error_when_status_becomes_running() { + let manager = SinkManager::new(vec![create_test_sink_details("es", 1)]); + manager.set_error("es", "some error").await; + + manager + .update_status("es", ConnectorStatus::Running, None) + .await; + + let sink = manager.get("es").await.unwrap(); + let details = sink.lock().await; + assert!(details.info.last_error.is_none()); + } + + #[tokio::test] + async fn should_set_error_status_and_message() { + let manager = SinkManager::new(vec![create_test_sink_details("es", 1)]); + + manager.set_error("es", "connection failed").await; + + let sink = manager.get("es").await.unwrap(); + let details = sink.lock().await; + assert_eq!(details.info.status, ConnectorStatus::Error); + assert!(details.info.last_error.is_some()); + } + + #[tokio::test] + async fn stop_should_return_not_found_for_unknown_key() { + let metrics = Arc::new(Metrics::init()); + let manager = SinkManager::new(vec![]); + + let result = manager.stop_connector("nonexistent", &metrics).await; + assert!(result.is_err()); + let err = result.unwrap_err(); + assert!(matches!(err, RuntimeError::SinkNotFound(_))); + } + + #[tokio::test] + async fn stop_should_send_shutdown_signal_and_update_status() { + let metrics = Arc::new(Metrics::init()); + metrics.increment_sinks_running(); + let (shutdown_tx, mut shutdown_rx) = watch::channel(()); + let handle = tokio::spawn(async move { + let _ = shutdown_rx.changed().await; + }); + let mut details = create_test_sink_details("es", 1); + details.shutdown_tx = Some(shutdown_tx); + details.task_handles = vec![handle]; + let manager = SinkManager::new(vec![details]); + + let result = manager.stop_connector("es", &metrics).await; + assert!(result.is_ok()); + + let sink = manager.get("es").await.unwrap(); + let details = sink.lock().await; + assert_eq!(details.info.status, ConnectorStatus::Stopped); + assert!(details.shutdown_tx.is_none()); + assert!(details.task_handles.is_empty()); + } + + #[tokio::test] + async fn stop_should_work_without_container() { + let metrics = Arc::new(Metrics::init()); + let mut details = create_test_sink_details("es", 1); + details.container = None; + details.info.status = ConnectorStatus::Stopped; + let manager = SinkManager::new(vec![details]); + + let result = manager.stop_connector("es", &metrics).await; + assert!(result.is_ok()); + } + + #[tokio::test] + async fn stop_should_decrement_metrics_from_running() { + let metrics = Arc::new(Metrics::init()); + metrics.increment_sinks_running(); + let manager = SinkManager::new(vec![create_test_sink_details("es", 1)]); + + manager.stop_connector("es", &metrics).await.unwrap(); + + assert_eq!(metrics.get_sinks_running(), 0); + } + + #[tokio::test] + async fn should_clear_error_when_status_becomes_stopped() { + let manager = SinkManager::new(vec![create_test_sink_details("es", 1)]); + manager.set_error("es", "some error").await; + + manager + .update_status("es", ConnectorStatus::Stopped, None) + .await; + + let sink = manager.get("es").await.unwrap(); + let details = sink.lock().await; + assert_eq!(details.info.status, ConnectorStatus::Stopped); + assert!(details.info.last_error.is_none()); + } + + #[tokio::test] + async fn stop_should_clear_last_error() { + let metrics = Arc::new(Metrics::init()); + let mut details = create_test_sink_details("es", 1); + details.info.status = ConnectorStatus::Error; + details.info.last_error = Some(ConnectorError::new("previous error")); + let manager = SinkManager::new(vec![details]); + + manager.stop_connector("es", &metrics).await.unwrap(); + + let sink = manager.get("es").await.unwrap(); + let details = sink.lock().await; + assert!(details.info.last_error.is_none()); + } + + #[tokio::test] + async fn stop_should_not_decrement_metrics_from_non_running() { + let metrics = Arc::new(Metrics::init()); + let mut details = create_test_sink_details("es", 1); + details.info.status = ConnectorStatus::Stopped; + let manager = SinkManager::new(vec![details]); + + manager.stop_connector("es", &metrics).await.unwrap(); + + assert_eq!(metrics.get_sinks_running(), 0); + } + + #[tokio::test] + async fn update_status_should_be_noop_for_unknown_key() { + let manager = SinkManager::new(vec![]); + + manager + .update_status("nonexistent", ConnectorStatus::Running, None) + .await; + } + + #[tokio::test] + async fn set_error_should_be_noop_for_unknown_key() { + let manager = SinkManager::new(vec![]); + + manager.set_error("nonexistent", "some error").await; + } } diff --git a/core/connectors/runtime/src/manager/source.rs b/core/connectors/runtime/src/manager/source.rs index b259fd8cf4..14fbb2c0d8 100644 --- a/core/connectors/runtime/src/manager/source.rs +++ b/core/connectors/runtime/src/manager/source.rs @@ -16,13 +16,26 @@ * specific language governing permissions and limitations * under the License. */ -use crate::configs::connectors::{ConfigFormat, SourceConfig}; +use crate::PLUGIN_ID; +use crate::SourceApi; +use crate::configs::connectors::{ConfigFormat, ConnectorsConfigProvider, SourceConfig}; +use crate::context::RuntimeContext; +use crate::error::RuntimeError; use crate::metrics::Metrics; +use crate::source; +use crate::state::{StateProvider, StateStorage}; use dashmap::DashMap; +use dlopen2::wrapper::Container; +use iggy::prelude::IggyClient; use iggy_connector_sdk::api::{ConnectorError, ConnectorStatus}; use std::collections::HashMap; +use std::fmt; use std::sync::Arc; +use std::sync::atomic::Ordering; +use std::time::Duration; use tokio::sync::Mutex; +use tokio::task::JoinHandle; +use tracing::info; #[derive(Debug)] pub struct SourceManager { @@ -96,6 +109,173 @@ impl SourceManager { source.info.last_error = Some(ConnectorError::new(error_message)); } } + + pub async fn stop_connector_with_guard( + &self, + key: &str, + metrics: &Arc, + ) -> Result<(), RuntimeError> { + let guard = { + let details = self + .sources + .get(key) + .map(|e| e.value().clone()) + .ok_or_else(|| RuntimeError::SourceNotFound(key.to_string()))?; + let details = details.lock().await; + details.restart_guard.clone() + }; + let _lock = guard.lock().await; + self.stop_connector(key, metrics).await + } + + pub async fn stop_connector( + &self, + key: &str, + metrics: &Arc, + ) -> Result<(), RuntimeError> { + let details = self + .sources + .get(key) + .map(|e| e.value().clone()) + .ok_or_else(|| RuntimeError::SourceNotFound(key.to_string()))?; + + let (task_handles, plugin_id, container) = { + let mut details = details.lock().await; + ( + std::mem::take(&mut details.handler_tasks), + details.info.id, + details.container.clone(), + ) + }; + + source::cleanup_sender(plugin_id); + + for handle in task_handles { + let _ = tokio::time::timeout(Duration::from_secs(5), handle).await; + } + + if let Some(container) = &container { + info!("Closing source connector with ID: {plugin_id} for plugin: {key}"); + (container.iggy_source_close)(plugin_id); + info!("Closed source connector with ID: {plugin_id} for plugin: {key}"); + } + + { + let mut details = details.lock().await; + let old_status = details.info.status; + details.info.status = ConnectorStatus::Stopped; + details.info.last_error = None; + if old_status == ConnectorStatus::Running { + metrics.decrement_sources_running(); + } + } + + Ok(()) + } + + pub async fn start_connector( + &self, + key: &str, + config: &SourceConfig, + iggy_client: &IggyClient, + metrics: &Arc, + state_path: &str, + context: &Arc, + ) -> Result<(), RuntimeError> { + let details = self + .sources + .get(key) + .map(|e| e.value().clone()) + .ok_or_else(|| RuntimeError::SourceNotFound(key.to_string()))?; + + let container = { + let details = details.lock().await; + details.container.clone().ok_or_else(|| { + RuntimeError::InvalidConfiguration(format!("No container loaded for source: {key}")) + })? + }; + + let plugin_id = PLUGIN_ID.fetch_add(1, Ordering::SeqCst); + + let state_storage = source::get_state_storage(state_path, key); + let state = match &state_storage { + StateStorage::File(file) => file.load().await?, + }; + + source::init_source( + &container, + &config.plugin_config.clone().unwrap_or_default(), + plugin_id, + state, + )?; + info!("Source connector with ID: {plugin_id} for plugin: {key} initialized successfully."); + + let (producer, encoder, transforms) = + source::setup_source_producer(key, config, iggy_client).await?; + + let callback = container.iggy_source_handle; + let handler_tasks = source::spawn_source_handler( + plugin_id, + key, + config.verbose, + producer, + encoder, + transforms, + state_storage, + callback, + context.clone(), + ); + + { + let mut details = details.lock().await; + details.info.id = plugin_id; + details.info.status = ConnectorStatus::Running; + details.info.last_error = None; + details.config = config.clone(); + details.handler_tasks = handler_tasks; + metrics.increment_sources_running(); + } + + Ok(()) + } + + pub async fn restart_connector( + &self, + key: &str, + config_provider: &dyn ConnectorsConfigProvider, + iggy_client: &IggyClient, + metrics: &Arc, + state_path: &str, + context: &Arc, + ) -> Result<(), RuntimeError> { + let guard = { + let details = self + .sources + .get(key) + .map(|e| e.value().clone()) + .ok_or_else(|| RuntimeError::SourceNotFound(key.to_string()))?; + let details = details.lock().await; + details.restart_guard.clone() + }; + let Ok(_lock) = guard.try_lock() else { + info!("Restart already in progress for source connector: {key}, skipping."); + return Ok(()); + }; + + info!("Restarting source connector: {key}"); + self.stop_connector(key, metrics).await?; + + let config = config_provider + .get_source_config(key, None) + .await + .map_err(|e| RuntimeError::InvalidConfiguration(e.to_string()))? + .ok_or_else(|| RuntimeError::SourceNotFound(key.to_string()))?; + + self.start_connector(key, &config, iggy_client, metrics, state_path, context) + .await?; + info!("Source connector: {key} restarted successfully."); + Ok(()) + } } #[derive(Debug, Clone)] @@ -111,8 +291,295 @@ pub struct SourceInfo { pub plugin_config_format: Option, } -#[derive(Debug)] pub struct SourceDetails { pub info: SourceInfo, pub config: SourceConfig, + pub handler_tasks: Vec>, + pub container: Option>>, + pub restart_guard: Arc>, +} + +impl fmt::Debug for SourceDetails { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("SourceDetails") + .field("info", &self.info) + .field("config", &self.config) + .field("container", &self.container.as_ref().map(|_| "...")) + .finish() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::configs::connectors::SourceConfig; + + fn create_test_source_info(key: &str, id: u32) -> SourceInfo { + SourceInfo { + id, + key: key.to_string(), + name: format!("{key} source"), + path: format!("/path/to/{key}"), + version: "1.0.0".to_string(), + enabled: true, + status: ConnectorStatus::Running, + last_error: None, + plugin_config_format: None, + } + } + + fn create_test_source_details(key: &str, id: u32) -> SourceDetails { + SourceDetails { + info: create_test_source_info(key, id), + config: SourceConfig { + key: key.to_string(), + enabled: true, + version: 1, + name: format!("{key} source"), + path: format!("/path/to/{key}"), + ..Default::default() + }, + handler_tasks: vec![], + container: None, + restart_guard: Arc::new(Mutex::new(())), + } + } + + #[tokio::test] + async fn should_create_manager_with_sources() { + let manager = SourceManager::new(vec![ + create_test_source_details("pg", 1), + create_test_source_details("random", 2), + ]); + + let all = manager.get_all().await; + assert_eq!(all.len(), 2); + } + + #[tokio::test] + async fn should_get_existing_source() { + let manager = SourceManager::new(vec![create_test_source_details("pg", 1)]); + + let source = manager.get("pg").await; + assert!(source.is_some()); + let binding = source.unwrap(); + let details = binding.lock().await; + assert_eq!(details.info.key, "pg"); + assert_eq!(details.info.id, 1); + } + + #[tokio::test] + async fn should_return_none_for_unknown_key() { + let manager = SourceManager::new(vec![create_test_source_details("pg", 1)]); + + assert!(manager.get("nonexistent").await.is_none()); + } + + #[tokio::test] + async fn should_get_config() { + let manager = SourceManager::new(vec![create_test_source_details("pg", 1)]); + + let config = manager.get_config("pg").await; + assert!(config.is_some()); + assert_eq!(config.unwrap().key, "pg"); + } + + #[tokio::test] + async fn should_return_none_config_for_unknown_key() { + let manager = SourceManager::new(vec![]); + + assert!(manager.get_config("nonexistent").await.is_none()); + } + + #[tokio::test] + async fn should_get_all_sources() { + let manager = SourceManager::new(vec![ + create_test_source_details("pg", 1), + create_test_source_details("random", 2), + create_test_source_details("es", 3), + ]); + + let all = manager.get_all().await; + assert_eq!(all.len(), 3); + let keys: Vec = all.iter().map(|s| s.key.clone()).collect(); + assert!(keys.contains(&"pg".to_string())); + assert!(keys.contains(&"random".to_string())); + assert!(keys.contains(&"es".to_string())); + } + + #[tokio::test] + async fn should_update_status() { + let manager = SourceManager::new(vec![create_test_source_details("pg", 1)]); + + manager + .update_status("pg", ConnectorStatus::Stopped, None) + .await; + + let source = manager.get("pg").await.unwrap(); + let details = source.lock().await; + assert_eq!(details.info.status, ConnectorStatus::Stopped); + } + + #[tokio::test] + async fn should_increment_metrics_when_transitioning_to_running() { + let metrics = Arc::new(Metrics::init()); + let mut details = create_test_source_details("pg", 1); + details.info.status = ConnectorStatus::Stopped; + let manager = SourceManager::new(vec![details]); + + manager + .update_status("pg", ConnectorStatus::Running, Some(&metrics)) + .await; + + assert_eq!(metrics.get_sources_running(), 1); + } + + #[tokio::test] + async fn should_decrement_metrics_when_leaving_running() { + let metrics = Arc::new(Metrics::init()); + let manager = SourceManager::new(vec![create_test_source_details("pg", 1)]); + metrics.increment_sources_running(); + + manager + .update_status("pg", ConnectorStatus::Stopped, Some(&metrics)) + .await; + + assert_eq!(metrics.get_sources_running(), 0); + } + + #[tokio::test] + async fn should_clear_error_when_status_becomes_running() { + let manager = SourceManager::new(vec![create_test_source_details("pg", 1)]); + manager.set_error("pg", "some error").await; + + manager + .update_status("pg", ConnectorStatus::Running, None) + .await; + + let source = manager.get("pg").await.unwrap(); + let details = source.lock().await; + assert!(details.info.last_error.is_none()); + } + + #[tokio::test] + async fn should_set_error_status_and_message() { + let manager = SourceManager::new(vec![create_test_source_details("pg", 1)]); + + manager.set_error("pg", "connection failed").await; + + let source = manager.get("pg").await.unwrap(); + let details = source.lock().await; + assert_eq!(details.info.status, ConnectorStatus::Error); + assert!(details.info.last_error.is_some()); + } + + #[tokio::test] + async fn stop_should_return_not_found_for_unknown_key() { + let metrics = Arc::new(Metrics::init()); + let manager = SourceManager::new(vec![]); + + let result = manager.stop_connector("nonexistent", &metrics).await; + assert!(result.is_err()); + let err = result.unwrap_err(); + assert!(matches!(err, RuntimeError::SourceNotFound(_))); + } + + #[tokio::test] + async fn stop_should_drain_tasks_and_update_status() { + let metrics = Arc::new(Metrics::init()); + metrics.increment_sources_running(); + let handle = tokio::spawn(async {}); + let mut details = create_test_source_details("pg", 1); + details.handler_tasks = vec![handle]; + let manager = SourceManager::new(vec![details]); + + let result = manager.stop_connector("pg", &metrics).await; + assert!(result.is_ok()); + + let source = manager.get("pg").await.unwrap(); + let details = source.lock().await; + assert_eq!(details.info.status, ConnectorStatus::Stopped); + assert!(details.handler_tasks.is_empty()); + } + + #[tokio::test] + async fn stop_should_work_without_container() { + let metrics = Arc::new(Metrics::init()); + let mut details = create_test_source_details("pg", 1); + details.container = None; + details.info.status = ConnectorStatus::Stopped; + let manager = SourceManager::new(vec![details]); + + let result = manager.stop_connector("pg", &metrics).await; + assert!(result.is_ok()); + } + + #[tokio::test] + async fn stop_should_decrement_metrics_from_running() { + let metrics = Arc::new(Metrics::init()); + metrics.increment_sources_running(); + let manager = SourceManager::new(vec![create_test_source_details("pg", 1)]); + + manager.stop_connector("pg", &metrics).await.unwrap(); + + assert_eq!(metrics.get_sources_running(), 0); + } + + #[tokio::test] + async fn should_clear_error_when_status_becomes_stopped() { + let manager = SourceManager::new(vec![create_test_source_details("pg", 1)]); + manager.set_error("pg", "some error").await; + + manager + .update_status("pg", ConnectorStatus::Stopped, None) + .await; + + let source = manager.get("pg").await.unwrap(); + let details = source.lock().await; + assert_eq!(details.info.status, ConnectorStatus::Stopped); + assert!(details.info.last_error.is_none()); + } + + #[tokio::test] + async fn stop_should_clear_last_error() { + let metrics = Arc::new(Metrics::init()); + let mut details = create_test_source_details("pg", 1); + details.info.status = ConnectorStatus::Error; + details.info.last_error = Some(ConnectorError::new("previous error")); + let manager = SourceManager::new(vec![details]); + + manager.stop_connector("pg", &metrics).await.unwrap(); + + let source = manager.get("pg").await.unwrap(); + let details = source.lock().await; + assert!(details.info.last_error.is_none()); + } + + #[tokio::test] + async fn stop_should_not_decrement_metrics_from_non_running() { + let metrics = Arc::new(Metrics::init()); + let mut details = create_test_source_details("pg", 1); + details.info.status = ConnectorStatus::Stopped; + let manager = SourceManager::new(vec![details]); + + manager.stop_connector("pg", &metrics).await.unwrap(); + + assert_eq!(metrics.get_sources_running(), 0); + } + + #[tokio::test] + async fn update_status_should_be_noop_for_unknown_key() { + let manager = SourceManager::new(vec![]); + + manager + .update_status("nonexistent", ConnectorStatus::Running, None) + .await; + } + + #[tokio::test] + async fn set_error_should_be_noop_for_unknown_key() { + let manager = SourceManager::new(vec![]); + + manager.set_error("nonexistent", "some error").await; + } } diff --git a/core/connectors/runtime/src/sink.rs b/core/connectors/runtime/src/sink.rs index 1e82703fa6..2074f845ca 100644 --- a/core/connectors/runtime/src/sink.rs +++ b/core/connectors/runtime/src/sink.rs @@ -31,7 +31,6 @@ use iggy::prelude::{ AutoCommit, AutoCommitWhen, IggyClient, IggyConsumer, IggyDuration, IggyMessage, PollingStrategy, }; -use iggy_connector_sdk::api::ConnectorStatus; use iggy_connector_sdk::{ DecodedMessage, MessagesMetadata, RawMessage, RawMessages, ReceivedMessage, StreamDecoder, TopicMetadata, sink::ConsumeCallback, transforms::Transform, @@ -42,6 +41,8 @@ use std::{ sync::{Arc, atomic::Ordering}, time::Instant, }; +use tokio::sync::watch; +use tokio::task::JoinHandle; use tracing::{debug, error, info, warn}; pub async fn init( @@ -50,7 +51,7 @@ pub async fn init( ) -> Result, RuntimeError> { let mut sink_connectors: HashMap = HashMap::new(); for (key, config) in sink_configs { - let name = config.name; + let name = config.name.clone(); if !config.enabled { warn!("Sink: {name} is disabled ({key})"); continue; @@ -68,7 +69,7 @@ pub async fn init( let version = get_plugin_version(&container.container); init_error = init_sink( &container.container, - &config.plugin_config.unwrap_or_default(), + &config.plugin_config.clone().unwrap_or_default(), plugin_id, ) .err() @@ -96,7 +97,7 @@ pub async fn init( let version = get_plugin_version(&container); init_error = init_sink( &container, - &config.plugin_config.unwrap_or_default(), + &config.plugin_config.clone().unwrap_or_default(), plugin_id, ) .err() @@ -129,21 +130,7 @@ pub async fn init( ); } - let transforms = if let Some(transforms_config) = config.transforms { - let transforms = transform::load(&transforms_config).map_err(|error| { - RuntimeError::InvalidConfiguration(format!("Failed to load transforms: {error}")) - })?; - let types = transforms - .iter() - .map(|t| t.r#type().into()) - .collect::>() - .join(", "); - info!("Enabled transforms for sink: {name} ({key}): {types}",); - transforms - } else { - vec![] - }; - + let consumers = setup_sink_consumers(&key, &config, iggy_client).await?; let connector = sink_connectors.get_mut(&path).ok_or_else(|| { RuntimeError::InvalidConfiguration(format!("Sink connector not found for path: {path}")) })?; @@ -156,46 +143,24 @@ pub async fn init( "Sink plugin not found for ID: {plugin_id}" )) })?; - - for stream in config.streams.iter() { - let poll_interval = IggyDuration::from_str( - stream.poll_interval.as_deref().unwrap_or("5ms"), - ) - .map_err(|error| { - RuntimeError::InvalidConfiguration(format!("Invalid poll interval: {error}")) - })?; - let default_consumer_group = format!("iggy-connect-sink-{key}"); - let consumer_group = stream - .consumer_group - .as_deref() - .unwrap_or(&default_consumer_group); - let batch_length = stream.batch_length.unwrap_or(1000); - for topic in stream.topics.iter() { - let mut consumer = iggy_client - .consumer_group(consumer_group, &stream.stream, topic)? - .auto_commit(AutoCommit::When(AutoCommitWhen::PollingMessages)) - .create_consumer_group_if_not_exists() - .auto_join_consumer_group() - .polling_strategy(PollingStrategy::next()) - .poll_interval(poll_interval) - .batch_length(batch_length) - .build(); - - consumer.init().await?; - plugin.consumers.push(SinkConnectorConsumer { - consumer, - decoder: stream.schema.decoder(), - batch_size: batch_length, - transforms: transforms.clone(), - }); - } + for (consumer, decoder, batch_size, transforms) in consumers { + plugin.consumers.push(SinkConnectorConsumer { + consumer, + decoder, + batch_size, + transforms, + }); } } Ok(sink_connectors) } -pub fn consume(sinks: Vec, context: Arc) { +pub fn consume( + sinks: Vec, + context: Arc, +) -> Vec<(String, watch::Sender<()>, Vec>)> { + let mut handles = Vec::new(); for sink in sinks { for plugin in sink.plugins { if let Some(error) = &plugin.error { @@ -206,56 +171,79 @@ pub fn consume(sinks: Vec, context: Arc) { continue; } info!("Starting consume for sink with ID: {}...", plugin.id); - for consumer in plugin.consumers { - let plugin_key = plugin.key.clone(); - let context = context.clone(); - - tokio::spawn(async move { - context - .sinks - .update_status( - &plugin_key, - ConnectorStatus::Running, - Some(&context.metrics), - ) - .await; - - if let Err(error) = consume_messages( - plugin.id, - consumer.decoder, - consumer.batch_size, - sink.callback, - consumer.transforms, - consumer.consumer, - plugin.verbose, - &plugin_key, - &context.metrics, - ) - .await - { - let error_msg = format!( - "Failed to consume messages for sink connector with ID: {}. {error}", - plugin.id - ); - error!("{error_msg}"); - context - .metrics - .increment_errors(&plugin_key, ConnectorType::Sink); - context.sinks.set_error(&plugin_key, &error_msg).await; - return; - } - info!( - "Consume messages for sink connector with ID: {} started successfully.", - plugin.id - ); - }); - } + let consumers = plugin + .consumers + .into_iter() + .map(|c| (c.consumer, c.decoder, c.batch_size, c.transforms)) + .collect(); + let (shutdown_tx, task_handles) = spawn_consume_tasks( + plugin.id, + &plugin.key, + consumers, + sink.callback, + plugin.verbose, + &context.metrics, + context.clone(), + ); + handles.push((plugin.key, shutdown_tx, task_handles)); } } + handles +} + +#[allow(clippy::type_complexity)] +pub(crate) fn spawn_consume_tasks( + plugin_id: u32, + plugin_key: &str, + consumers: Vec<( + IggyConsumer, + Arc, + u32, + Vec>, + )>, + callback: ConsumeCallback, + verbose: bool, + metrics: &Arc, + context: Arc, +) -> (watch::Sender<()>, Vec>) { + let (shutdown_tx, shutdown_rx) = watch::channel(()); + let mut task_handles = Vec::new(); + for (consumer, decoder, batch_size, transforms) in consumers { + let plugin_key = plugin_key.to_string(); + let metrics = metrics.clone(); + let shutdown_rx = shutdown_rx.clone(); + let context = context.clone(); + let handle = tokio::spawn(async move { + if let Err(error) = consume_messages( + plugin_id, + decoder, + batch_size, + callback, + transforms, + consumer, + verbose, + &plugin_key, + &metrics, + shutdown_rx, + ) + .await + { + error!( + "Failed to consume messages for sink connector with ID: {plugin_id}: {error}" + ); + context + .sinks + .set_error(&plugin_key, &error.to_string()) + .await; + } + }); + task_handles.push(handle); + } + (shutdown_tx, task_handles) } #[allow(clippy::too_many_arguments)] -async fn consume_messages( +pub(crate) async fn consume_messages( plugin_id: u32, decoder: Arc, batch_size: u32, @@ -265,6 +253,7 @@ async fn consume_messages( verbose: bool, plugin_key: &str, metrics: &Arc, + mut shutdown_rx: watch::Receiver<()>, ) -> Result<(), RuntimeError> { info!("Started consuming messages for sink connector with ID: {plugin_id}"); let batch_size = batch_size as usize; @@ -274,7 +263,18 @@ async fn consume_messages( topic: consumer.topic().to_string(), }; - while let Some(message) = consumer.next().await { + loop { + let message = tokio::select! { + _ = shutdown_rx.changed() => { + info!("Sink connector with ID: {plugin_id} received shutdown signal"); + break; + } + msg = consumer.next() => msg, + }; + + let Some(message) = message else { + break; + }; let Ok(message) = message else { error!("Failed to receive message."); continue; @@ -356,7 +356,7 @@ fn get_plugin_version(container: &Container) -> String { } } -fn init_sink( +pub(crate) fn init_sink( container: &Container, plugin_config: &serde_json::Value, id: u32, @@ -377,6 +377,67 @@ fn init_sink( } } +pub(crate) async fn setup_sink_consumers( + key: &str, + config: &SinkConfig, + iggy_client: &IggyClient, +) -> Result< + Vec<( + IggyConsumer, + Arc, + u32, + Vec>, + )>, + RuntimeError, +> { + let transforms = if let Some(transforms_config) = &config.transforms { + let loaded = transform::load(transforms_config).map_err(|error| { + RuntimeError::InvalidConfiguration(format!("Failed to load transforms: {error}")) + })?; + for t in &loaded { + info!("Loaded transform: {:?} for sink: {key}", t.r#type()); + } + loaded + } else { + vec![] + }; + + let mut consumers = Vec::new(); + for stream in config.streams.iter() { + let poll_interval = IggyDuration::from_str( + stream.poll_interval.as_deref().unwrap_or("5ms"), + ) + .map_err(|error| { + RuntimeError::InvalidConfiguration(format!("Invalid poll interval: {error}")) + })?; + let default_consumer_group = format!("iggy-connect-sink-{key}"); + let consumer_group = stream + .consumer_group + .as_deref() + .unwrap_or(&default_consumer_group); + let batch_length = stream.batch_length.unwrap_or(1000); + for topic in stream.topics.iter() { + let mut consumer = iggy_client + .consumer_group(consumer_group, &stream.stream, topic)? + .auto_commit(AutoCommit::When(AutoCommitWhen::PollingMessages)) + .create_consumer_group_if_not_exists() + .auto_join_consumer_group() + .polling_strategy(PollingStrategy::next()) + .poll_interval(poll_interval) + .batch_length(batch_length) + .build(); + consumer.init().await?; + consumers.push(( + consumer, + stream.schema.decoder(), + batch_length, + transforms.clone(), + )); + } + } + Ok(consumers) +} + async fn process_messages( plugin_id: u32, messages_metadata: MessagesMetadata, diff --git a/core/connectors/runtime/src/source.rs b/core/connectors/runtime/src/source.rs index 8fe4225100..51ec060269 100644 --- a/core/connectors/runtime/src/source.rs +++ b/core/connectors/runtime/src/source.rs @@ -22,10 +22,11 @@ use dlopen2::wrapper::Container; use flume::{Receiver, Sender}; use iggy::prelude::{ DirectConfig, HeaderKey, HeaderValue, IggyClient, IggyDuration, IggyError, IggyMessage, + IggyProducer, }; use iggy_connector_sdk::{ ConnectorState, DecodedMessage, Error, ProducedMessages, StreamEncoder, TopicMetadata, - transforms::Transform, + source::HandleCallback, transforms::Transform, }; use once_cell::sync::Lazy; use std::{ @@ -46,6 +47,7 @@ use crate::{ transform, }; use iggy_connector_sdk::api::ConnectorStatus; +use tokio::task::JoinHandle; pub static SOURCE_SENDERS: Lazy>> = Lazy::new(DashMap::new); @@ -60,7 +62,7 @@ pub async fn init( ) -> Result, RuntimeError> { let mut source_connectors: HashMap = HashMap::new(); for (key, config) in source_configs { - let name = config.name; + let name = config.name.clone(); if !config.enabled { warn!("Source: {name} is disabled ({key})"); continue; @@ -82,7 +84,7 @@ pub async fn init( let version = get_plugin_version(&container.container); init_error = init_source( &container.container, - &config.plugin_config.unwrap_or_default(), + &config.plugin_config.clone().unwrap_or_default(), plugin_id, state, ) @@ -113,7 +115,7 @@ pub async fn init( let version = get_plugin_version(&container); init_error = init_source( &container, - &config.plugin_config.unwrap_or_default(), + &config.plugin_config.clone().unwrap_or_default(), plugin_id, state, ) @@ -149,20 +151,8 @@ pub async fn init( ); } - let transforms = if let Some(transforms_config) = config.transforms { - let transforms = transform::load(&transforms_config).map_err(|error| { - RuntimeError::InvalidConfiguration(format!("Failed to load transforms: {error}")) - })?; - let types = transforms - .iter() - .map(|t| t.r#type().into()) - .collect::>() - .join(", "); - info!("Enabled transforms for source: {name} ({key}): {types}",); - transforms - } else { - vec![] - }; + let (producer, encoder, transforms) = + setup_source_producer(&key, &config, iggy_client).await?; let connector = source_connectors.get_mut(&path).ok_or_else(|| { RuntimeError::InvalidConfiguration(format!( @@ -178,32 +168,8 @@ pub async fn init( "Source plugin not found for ID: {plugin_id}" )) })?; - - for stream in config.streams.iter() { - let linger_time = IggyDuration::from_str( - stream.linger_time.as_deref().unwrap_or("5ms"), - ) - .map_err(|error| { - RuntimeError::InvalidConfiguration(format!("Invalid linger time: {error}")) - })?; - let batch_length = stream.batch_length.unwrap_or(1000); - let producer = iggy_client - .producer(&stream.stream, &stream.topic)? - .direct( - DirectConfig::builder() - .batch_length(batch_length) - .linger_time(linger_time) - .build(), - ) - .build(); - - producer.init().await?; - plugin.producer = Some(SourceConnectorProducer { - producer, - encoder: stream.schema.encoder(), - }); - plugin.transforms = transforms.clone(); - } + plugin.producer = Some(SourceConnectorProducer { producer, encoder }); + plugin.transforms = transforms; } Ok(source_connectors) @@ -218,7 +184,7 @@ fn get_plugin_version(container: &Container) -> String { } } -fn init_source( +pub(crate) fn init_source( container: &Container, plugin_config: &serde_json::Value, id: u32, @@ -246,21 +212,256 @@ fn init_source( } } -fn get_state_storage(state_path: &str, key: &str) -> StateStorage { +pub(crate) fn get_state_storage(state_path: &str, key: &str) -> StateStorage { let path = format!("{state_path}/source_{key}.state"); StateStorage::File(FileStateProvider::new(path)) } +pub(crate) async fn setup_source_producer( + key: &str, + config: &SourceConfig, + iggy_client: &IggyClient, +) -> Result< + ( + IggyProducer, + Arc, + Vec>, + ), + RuntimeError, +> { + let transforms = if let Some(transforms_config) = &config.transforms { + let loaded = transform::load(transforms_config).map_err(|error| { + RuntimeError::InvalidConfiguration(format!("Failed to load transforms: {error}")) + })?; + for t in &loaded { + info!("Loaded transform: {:?} for source: {key}", t.r#type()); + } + loaded + } else { + vec![] + }; + + let mut last_producer = None; + let mut last_encoder = None; + for stream in config.streams.iter() { + let linger_time = IggyDuration::from_str(stream.linger_time.as_deref().unwrap_or("5ms")) + .map_err(|error| { + RuntimeError::InvalidConfiguration(format!("Invalid linger time: {error}")) + })?; + let batch_length = stream.batch_length.unwrap_or(1000); + let producer = iggy_client + .producer(&stream.stream, &stream.topic)? + .direct( + DirectConfig::builder() + .batch_length(batch_length) + .linger_time(linger_time) + .build(), + ) + .build(); + producer.init().await?; + last_encoder = Some(stream.schema.encoder()); + last_producer = Some(producer); + } + + let producer = last_producer.ok_or_else(|| { + RuntimeError::InvalidConfiguration("No streams configured for source".to_string()) + })?; + let encoder = last_encoder.ok_or_else(|| { + RuntimeError::InvalidConfiguration("No encoder configured for source".to_string()) + })?; + + Ok((producer, encoder, transforms)) +} + +#[allow(clippy::too_many_arguments)] +pub(crate) async fn source_forwarding_loop( + plugin_id: u32, + plugin_key: String, + verbose: bool, + producer: IggyProducer, + encoder: Arc, + transforms: Vec>, + state_storage: StateStorage, + receiver: Receiver, + context: Arc, +) { + info!("Source connector with ID: {plugin_id} started."); + context + .sources + .update_status( + &plugin_key, + ConnectorStatus::Running, + Some(&context.metrics), + ) + .await; + + let mut number = 1u64; + let topic_metadata = TopicMetadata { + stream: producer.stream().to_string(), + topic: producer.topic().to_string(), + }; + + while let Ok(produced_messages) = receiver.recv_async().await { + let count = produced_messages.messages.len(); + context + .metrics + .increment_messages_produced(&plugin_key, count as u64); + if verbose { + info!("Source connector with ID: {plugin_id} received {count} messages"); + } else { + debug!("Source connector with ID: {plugin_id} received {count} messages"); + } + let schema = produced_messages.schema; + let mut messages: Vec = Vec::with_capacity(count); + for message in produced_messages.messages { + let Ok(payload) = schema.try_into_payload(message.payload) else { + error!( + "Failed to decode message payload with schema: {schema} for source connector with ID: {plugin_id}", + ); + continue; + }; + + debug!( + "Source connector with ID: {plugin_id}] received message: {number} | schema: {schema} | payload: {payload}" + ); + messages.push(DecodedMessage { + id: message.id, + offset: None, + headers: message.headers, + checksum: message.checksum, + timestamp: message.timestamp, + origin_timestamp: message.origin_timestamp, + payload, + }); + number += 1; + } + + let Ok(iggy_messages) = + process_messages(plugin_id, &encoder, &topic_metadata, messages, &transforms) + else { + let error_msg = format!( + "Failed to process {count} messages by source connector with ID: {plugin_id} before sending them to stream: {}, topic: {}.", + producer.stream(), + producer.topic() + ); + error!("{error_msg}"); + context + .metrics + .increment_errors(&plugin_key, ConnectorType::Source); + context.sources.set_error(&plugin_key, &error_msg).await; + continue; + }; + + if let Err(error) = producer.send(iggy_messages).await { + let error_msg = format!( + "Failed to send {count} messages to stream: {}, topic: {} by source connector with ID: {plugin_id}. {error}", + producer.stream(), + producer.topic(), + ); + error!("{error_msg}"); + context + .metrics + .increment_errors(&plugin_key, ConnectorType::Source); + context.sources.set_error(&plugin_key, &error_msg).await; + continue; + } + + context + .metrics + .increment_messages_sent(&plugin_key, count as u64); + + if verbose { + info!( + "Sent {count} messages to stream: {}, topic: {} by source connector with ID: {plugin_id}", + producer.stream(), + producer.topic() + ); + } else { + debug!( + "Sent {count} messages to stream: {}, topic: {} by source connector with ID: {plugin_id}", + producer.stream(), + producer.topic() + ); + } + + let Some(state) = produced_messages.state else { + debug!("No state provided for source connector with ID: {plugin_id}"); + continue; + }; + + match &state_storage { + StateStorage::File(file) => { + if let Err(error) = file.save(state).await { + let error_msg = format!( + "Failed to save state for source connector with ID: {plugin_id}. {error}" + ); + error!("{error_msg}"); + context.sources.set_error(&plugin_key, &error_msg).await; + continue; + } + debug!("State saved for source connector with ID: {plugin_id}"); + } + } + } + + info!("Source connector with ID: {plugin_id} stopped."); + context + .sources + .update_status( + &plugin_key, + ConnectorStatus::Stopped, + Some(&context.metrics), + ) + .await; +} + +#[allow(clippy::too_many_arguments)] +pub(crate) fn spawn_source_handler( + plugin_id: u32, + plugin_key: &str, + verbose: bool, + producer: IggyProducer, + encoder: Arc, + transforms: Vec>, + state_storage: StateStorage, + callback: HandleCallback, + context: Arc, +) -> Vec> { + let (sender, receiver) = flume::unbounded(); + SOURCE_SENDERS.insert(plugin_id, sender); + + let blocking_handle = tokio::task::spawn_blocking(move || { + callback(plugin_id, handle_produced_messages); + }); + + let plugin_key = plugin_key.to_string(); + let handler_task = tokio::spawn(async move { + source_forwarding_loop( + plugin_id, + plugin_key, + verbose, + producer, + encoder, + transforms, + state_storage, + receiver, + context, + ) + .await; + }); + + vec![blocking_handle, handler_task] +} + pub fn handle( sources: Vec, context: Arc, -) -> Vec> { - let mut handler_tasks = Vec::new(); +) -> Vec<(String, Vec>)> { + let mut handles = Vec::new(); for source in sources { for plugin in source.plugins { let plugin_id = plugin.id; let plugin_key = plugin.key.clone(); - let context = context.clone(); if let Some(error) = &plugin.error { error!( @@ -270,165 +471,27 @@ pub fn handle( } info!("Starting handler for source connector with ID: {plugin_id}..."); - let handle = source.callback; - tokio::task::spawn_blocking(move || { - handle(plugin_id, handle_produced_messages); - }); - info!("Handler for source connector with ID: {plugin_id} started successfully."); - - let (sender, receiver): (Sender, Receiver) = - flume::unbounded(); - SOURCE_SENDERS.insert(plugin_id, sender); - let handler_task = tokio::spawn(async move { - info!("Source connector with ID: {plugin_id} started."); - let Some(producer) = &plugin.producer else { - error!("Producer not initialized for source connector with ID: {plugin_id}"); - context - .sources - .set_error(&plugin_key, "Producer not initialized") - .await; - return; - }; - - context - .sources - .update_status( - &plugin_key, - ConnectorStatus::Running, - Some(&context.metrics), - ) - .await; - let encoder = producer.encoder.clone(); - let producer = &producer.producer; - let mut number = 1u64; - - let topic_metadata = TopicMetadata { - stream: producer.stream().to_string(), - topic: producer.topic().to_string(), - }; - - while let Ok(produced_messages) = receiver.recv_async().await { - let count = produced_messages.messages.len(); - context - .metrics - .increment_messages_produced(&plugin_key, count as u64); - if plugin.verbose { - info!("Source connector with ID: {plugin_id} received {count} messages"); - } else { - debug!("Source connector with ID: {plugin_id} received {count} messages"); - } - let schema = produced_messages.schema; - let mut messages: Vec = Vec::with_capacity(count); - for message in produced_messages.messages { - let Ok(payload) = schema.try_into_payload(message.payload) else { - error!( - "Failed to decode message payload with schema: {} for source connector with ID: {plugin_id}", - produced_messages.schema - ); - continue; - }; - - debug!( - "Source connector with ID: {plugin_id}] received message: {number} | schema: {schema} | payload: {payload}" - ); - messages.push(DecodedMessage { - id: message.id, - offset: None, - headers: message.headers, - checksum: message.checksum, - timestamp: message.timestamp, - origin_timestamp: message.origin_timestamp, - payload, - }); - number += 1; - } - - let Ok(iggy_messages) = process_messages( - plugin_id, - &encoder, - &topic_metadata, - messages, - &plugin.transforms, - ) else { - let error_msg = format!( - "Failed to process {count} messages by source connector with ID: {plugin_id} before sending them to stream: {}, topic: {}.", - producer.stream(), - producer.topic() - ); - error!("{error_msg}"); - context - .metrics - .increment_errors(&plugin_key, ConnectorType::Source); - context.sources.set_error(&plugin_key, &error_msg).await; - continue; - }; - - if let Err(error) = producer.send(iggy_messages).await { - let error_msg = format!( - "Failed to send {count} messages to stream: {}, topic: {} by source connector with ID: {plugin_id}. {error}", - producer.stream(), - producer.topic(), - ); - error!("{error_msg}"); - context - .metrics - .increment_errors(&plugin_key, ConnectorType::Source); - context.sources.set_error(&plugin_key, &error_msg).await; - continue; - } - - context - .metrics - .increment_messages_sent(&plugin_key, count as u64); - - if plugin.verbose { - info!( - "Sent {count} messages to stream: {}, topic: {} by source connector with ID: {plugin_id}", - producer.stream(), - producer.topic() - ); - } else { - debug!( - "Sent {count} messages to stream: {}, topic: {} by source connector with ID: {plugin_id}", - producer.stream(), - producer.topic() - ); - } + let Some(producer_wrapper) = plugin.producer else { + error!("Producer not initialized for source connector with ID: {plugin_id}"); + continue; + }; - let Some(state) = produced_messages.state else { - debug!("No state provided for source connector with ID: {plugin_id}"); - continue; - }; - - match &plugin.state_storage { - StateStorage::File(file) => { - if let Err(error) = file.save(state).await { - let error_msg = format!( - "Failed to save state for source connector with ID: {plugin_id}. {error}" - ); - error!("{error_msg}"); - context.sources.set_error(&plugin_key, &error_msg).await; - continue; - } - debug!("State saved for source connector with ID: {plugin_id}"); - } - } - } + let handler_tasks = spawn_source_handler( + plugin_id, + &plugin_key, + plugin.verbose, + producer_wrapper.producer, + producer_wrapper.encoder, + plugin.transforms, + plugin.state_storage, + source.callback, + context.clone(), + ); - info!("Source connector with ID: {plugin_id} stopped."); - context - .sources - .update_status( - &plugin_key, - ConnectorStatus::Stopped, - Some(&context.metrics), - ) - .await; - }); - handler_tasks.push(handler_task); + handles.push((plugin_key, handler_tasks)); } } - handler_tasks + handles } fn process_messages( @@ -475,7 +538,7 @@ fn process_messages( Ok(iggy_messages) } -extern "C" fn handle_produced_messages( +pub(crate) extern "C" fn handle_produced_messages( plugin_id: u32, messages_ptr: *const u8, messages_len: usize, diff --git a/core/integration/tests/connectors/postgres/mod.rs b/core/integration/tests/connectors/postgres/mod.rs index cb16dfbfba..c91abf9aba 100644 --- a/core/integration/tests/connectors/postgres/mod.rs +++ b/core/integration/tests/connectors/postgres/mod.rs @@ -19,6 +19,7 @@ mod postgres_sink; mod postgres_source; +mod restart; use crate::connectors::TestMessage; use serde::Deserialize; diff --git a/core/integration/tests/connectors/postgres/restart.rs b/core/integration/tests/connectors/postgres/restart.rs new file mode 100644 index 0000000000..97d9f2422c --- /dev/null +++ b/core/integration/tests/connectors/postgres/restart.rs @@ -0,0 +1,238 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +use super::{POLL_ATTEMPTS, POLL_INTERVAL_MS, TEST_MESSAGE_COUNT}; +use crate::connectors::fixtures::{PostgresOps, PostgresSinkFixture}; +use crate::connectors::{TestMessage, create_test_messages}; +use bytes::Bytes; +use iggy::prelude::{IggyMessage, Partitioning}; +use iggy_binary_protocol::MessageClient; +use iggy_common::Identifier; +use iggy_connector_sdk::api::{ConnectorStatus, SinkInfoResponse}; +use integration::harness::seeds; +use integration::iggy_harness; +use reqwest::Client; +use std::time::Duration; +use tokio::time::sleep; + +const API_KEY: &str = "test-api-key"; +const SINK_TABLE: &str = "iggy_messages"; +const SINK_KEY: &str = "postgres"; + +type SinkRow = (i64, String, String, Vec); + +#[iggy_harness( + server(connectors_runtime(config_path = "tests/connectors/postgres/sink.toml")), + seed = seeds::connector_stream +)] +async fn restart_sink_connector_continues_processing( + harness: &TestHarness, + fixture: PostgresSinkFixture, +) { + let client = harness.root_client().await.unwrap(); + let api_url = harness + .connectors_runtime() + .expect("connector runtime should be available") + .http_url(); + let http = Client::new(); + let pool = fixture.create_pool().await.expect("Failed to create pool"); + + fixture.wait_for_table(&pool, SINK_TABLE).await; + + let stream_id: Identifier = seeds::names::STREAM.try_into().unwrap(); + let topic_id: Identifier = seeds::names::TOPIC.try_into().unwrap(); + + wait_for_sink_status(&http, &api_url, ConnectorStatus::Running).await; + + let first_batch = create_test_messages(TEST_MESSAGE_COUNT); + let mut messages = build_messages(&first_batch, 0); + client + .send_messages( + &stream_id, + &topic_id, + &Partitioning::partition_id(0), + &mut messages, + ) + .await + .expect("Failed to send first batch"); + + let query = format!( + "SELECT iggy_offset, iggy_stream, iggy_topic, payload FROM {SINK_TABLE} ORDER BY iggy_offset" + ); + let rows: Vec = fixture + .fetch_rows_as(&pool, &query, TEST_MESSAGE_COUNT) + .await + .expect("Failed to fetch first batch rows"); + + assert_eq!( + rows.len(), + TEST_MESSAGE_COUNT, + "Expected {TEST_MESSAGE_COUNT} rows before restart" + ); + + let resp = http + .post(format!("{api_url}/sinks/{SINK_KEY}/restart")) + .header("api-key", API_KEY) + .send() + .await + .expect("Failed to call restart endpoint"); + + assert_eq!( + resp.status().as_u16(), + 204, + "Restart endpoint should return 204 No Content" + ); + + wait_for_sink_status(&http, &api_url, ConnectorStatus::Running).await; + + let second_batch = create_test_messages(TEST_MESSAGE_COUNT); + let mut messages = build_messages(&second_batch, TEST_MESSAGE_COUNT); + client + .send_messages( + &stream_id, + &topic_id, + &Partitioning::partition_id(0), + &mut messages, + ) + .await + .expect("Failed to send second batch"); + + let total_expected = TEST_MESSAGE_COUNT * 2; + let rows: Vec = fixture + .fetch_rows_as(&pool, &query, total_expected) + .await + .expect("Failed to fetch rows after restart"); + + assert!( + rows.len() >= total_expected, + "Expected at least {total_expected} rows after restart, got {}", + rows.len() + ); +} + +#[iggy_harness( + server(connectors_runtime(config_path = "tests/connectors/postgres/sink.toml")), + seed = seeds::connector_stream +)] +async fn parallel_restart_requests_should_not_break_connector( + harness: &TestHarness, + fixture: PostgresSinkFixture, +) { + let client = harness.root_client().await.unwrap(); + let api_url = harness + .connectors_runtime() + .expect("connector runtime should be available") + .http_url(); + let http = Client::new(); + let pool = fixture.create_pool().await.expect("Failed to create pool"); + + fixture.wait_for_table(&pool, SINK_TABLE).await; + + let stream_id: Identifier = seeds::names::STREAM.try_into().unwrap(); + let topic_id: Identifier = seeds::names::TOPIC.try_into().unwrap(); + + wait_for_sink_status(&http, &api_url, ConnectorStatus::Running).await; + + let mut tasks = Vec::new(); + for _ in 0..5 { + let http = http.clone(); + let url = format!("{api_url}/sinks/{SINK_KEY}/restart"); + tasks.push(tokio::spawn(async move { + http.post(&url) + .header("api-key", API_KEY) + .send() + .await + .expect("Failed to call restart endpoint") + })); + } + + let responses = futures::future::join_all(tasks).await; + for resp in responses { + let resp = resp.expect("Task panicked"); + assert_eq!( + resp.status().as_u16(), + 204, + "All restart requests should return 204" + ); + } + + wait_for_sink_status(&http, &api_url, ConnectorStatus::Running).await; + + let batch = create_test_messages(TEST_MESSAGE_COUNT); + let mut messages = build_messages(&batch, 0); + client + .send_messages( + &stream_id, + &topic_id, + &Partitioning::partition_id(0), + &mut messages, + ) + .await + .expect("Failed to send messages after parallel restarts"); + + let query = format!( + "SELECT iggy_offset, iggy_stream, iggy_topic, payload FROM {SINK_TABLE} ORDER BY iggy_offset" + ); + let rows: Vec = fixture + .fetch_rows_as(&pool, &query, TEST_MESSAGE_COUNT) + .await + .expect("Failed to fetch rows after parallel restarts"); + + assert!( + rows.len() >= TEST_MESSAGE_COUNT, + "Expected at least {TEST_MESSAGE_COUNT} rows after parallel restarts, got {}", + rows.len() + ); +} + +async fn wait_for_sink_status( + http: &Client, + api_url: &str, + expected: ConnectorStatus, +) -> SinkInfoResponse { + for _ in 0..POLL_ATTEMPTS { + if let Ok(resp) = http + .get(format!("{api_url}/sinks/{SINK_KEY}")) + .header("api-key", API_KEY) + .send() + .await + && let Ok(info) = resp.json::().await + && info.status == expected + { + return info; + } + sleep(Duration::from_millis(POLL_INTERVAL_MS)).await; + } + panic!("Sink connector did not reach {expected:?} status in time"); +} + +fn build_messages(messages_data: &[TestMessage], id_offset: usize) -> Vec { + messages_data + .iter() + .enumerate() + .map(|(i, msg)| { + let payload = serde_json::to_vec(msg).expect("Failed to serialize message"); + IggyMessage::builder() + .id((id_offset + i + 1) as u128) + .payload(Bytes::from(payload)) + .build() + .expect("Failed to build message") + }) + .collect() +}