From 918572591f383661297f03cc3ddb46bbb202b6fd Mon Sep 17 00:00:00 2001 From: Markus Unterwaditzer Date: Mon, 8 Jun 2026 12:41:47 +0200 Subject: [PATCH] Basic implementation of batching writes across topics --- src/kafka/activation_batcher.rs | 443 ------------ src/kafka/activation_writer.rs | 1129 +++++++++++++++++++++---------- src/kafka/mod.rs | 1 - src/main.rs | 28 +- 4 files changed, 773 insertions(+), 828 deletions(-) delete mode 100644 src/kafka/activation_batcher.rs diff --git a/src/kafka/activation_batcher.rs b/src/kafka/activation_batcher.rs deleted file mode 100644 index 86c905cd..00000000 --- a/src/kafka/activation_batcher.rs +++ /dev/null @@ -1,443 +0,0 @@ -use std::mem::replace; -use std::sync::Arc; -use std::time::Duration; - -use chrono::Utc; -use futures::future::join_all; -use rdkafka::config::ClientConfig; -use rdkafka::producer::{FutureProducer, FutureRecord}; -use rdkafka::util::Timeout; - -use crate::config::Config; -use crate::runtime_config::RuntimeConfigManager; -use crate::store::activation::Activation; - -use super::consumer::{ - ReduceConfig, ReduceShutdownBehaviour, ReduceShutdownCondition, Reducer, - ReducerWhenFullBehaviour, -}; - -pub struct ActivationBatcherConfig { - pub producer_config: ClientConfig, - pub kafka_topic: String, - pub kafka_long_topic: String, - pub send_timeout_ms: u64, - pub max_batch_time_ms: u64, - pub max_batch_len: usize, - pub max_batch_size: usize, -} - -impl ActivationBatcherConfig { - /// Convert from application configuration into ActivationBatcher config for a - /// single consumed topic. Each consumer has its own batcher, so the topic is - /// passed explicitly rather than derived from "the" consumable topic. 
- pub fn from_topic(config: &Config, topic_name: &str) -> Self { - Self { - producer_config: config.kafka_producer_config(), - kafka_topic: topic_name.to_owned(), - kafka_long_topic: config.kafka_long_topic.clone(), - send_timeout_ms: config.kafka_send_timeout_ms, - max_batch_time_ms: config.db_insert_batch_max_time_ms, - max_batch_len: config.db_insert_batch_max_len, - max_batch_size: config.db_insert_batch_max_size, - } - } -} - -pub struct ActivationBatcher { - batch: Vec, - batch_size: usize, - forward_batch: Vec>, // payload - config: ActivationBatcherConfig, - runtime_config_manager: Arc, - producer: Arc, - producer_cluster: String, -} - -impl ActivationBatcher { - pub fn new( - config: ActivationBatcherConfig, - runtime_config_manager: Arc, - ) -> Self { - let producer: Arc = Arc::new( - config - .producer_config - .create() - .expect("Could not create kafka producer in activation batcher"), - ); - let producer_cluster = config - .producer_config - .get("bootstrap.servers") - .unwrap() - .to_owned(); - Self { - batch: Vec::with_capacity(config.max_batch_len), - batch_size: 0, - forward_batch: Vec::with_capacity(config.max_batch_len), - config, - runtime_config_manager, - producer, - producer_cluster, - } - } -} - -impl Reducer for ActivationBatcher { - type Input = Activation; - - type Output = Vec; - - async fn reduce(&mut self, t: Self::Input) -> Result<(), anyhow::Error> { - let runtime_config = self.runtime_config_manager.read().await; - let forward_topic = runtime_config - .demoted_topic - .clone() - .unwrap_or(self.config.kafka_long_topic.clone()); - let task_name = &t.taskname; - let namespace = &t.namespace; - - if runtime_config.drop_task_killswitch.contains(task_name) { - metrics::counter!( - "filter.drop_task_killswitch", - "topic" => self.config.kafka_topic.clone(), - "taskname" => task_name.clone(), - ) - .increment(1); - return Ok(()); - } - - if let Some(expires_at) = t.expires_at - && Utc::now() > expires_at - { - metrics::counter!( - 
"filter.expired_at_consumer", - "topic" => self.config.kafka_topic.clone(), - ) - .increment(1); - return Ok(()); - } - - if runtime_config.demoted_namespaces.contains(namespace) { - if forward_topic == self.config.kafka_topic { - metrics::counter!( - "filter.forward_task_demoted_namespace.skipped", - "topic" => self.config.kafka_topic.clone(), - "namespace" => namespace.clone(), - "taskname" => task_name.clone(), - ) - .increment(1); - } else { - metrics::counter!( - "filter.forward_task_demoted_namespace", - "topic" => self.config.kafka_topic.clone(), - "namespace" => namespace.clone(), - "taskname" => task_name.clone(), - ) - .increment(1); - self.forward_batch.push(t.activation); - return Ok(()); - } - } - - self.batch_size += t.activation.len(); - self.batch.push(t); - - Ok(()) - } - - async fn flush(&mut self) -> Result, anyhow::Error> { - if self.batch.is_empty() && self.forward_batch.is_empty() { - return Ok(None); - } - - metrics::histogram!("consumer.batch_rows", "topic" => self.config.kafka_topic.clone()) - .record(self.batch.len() as f64); - metrics::histogram!("consumer.batch_bytes", "topic" => self.config.kafka_topic.clone()) - .record(self.batch_size as f64); - - // Send all forward batch in parallel - if !self.forward_batch.is_empty() { - let runtime_config = self.runtime_config_manager.read().await; - // The forwarding producer authenticates against the deadletter - // cluster, so default demoted forwarding there too (and consistently - // with upkeep) when no demoted_topic_cluster is configured. 
- let forward_cluster = - runtime_config - .demoted_topic_cluster - .clone() - .unwrap_or_else(|| { - self.config - .producer_config - .get("bootstrap.servers") - .expect("producer config always sets bootstrap.servers") - .to_string() - }); - if self.producer_cluster != forward_cluster { - let mut new_config = self.config.producer_config.clone(); - new_config.set("bootstrap.servers", &forward_cluster); - self.producer = Arc::new( - new_config - .create() - .expect("Could not create kafka producer in activation batcher"), - ); - self.producer_cluster = forward_cluster; - } - let forward_topic = runtime_config - .demoted_topic - .clone() - .unwrap_or(self.config.kafka_long_topic.clone()); - let sends = self.forward_batch.iter().map(|payload| { - self.producer.send( - FutureRecord::<(), Vec>::to(&forward_topic).payload(payload), - Timeout::After(Duration::from_millis(self.config.send_timeout_ms)), - ) - }); - - let results = join_all(sends).await; - let success_count = results.iter().filter(|r| r.is_ok()).count(); - - let topic = self.config.kafka_topic.clone(); - metrics::histogram!("consumer.forward_attempts", "topic" => topic.clone()) - .record(results.len() as f64); - metrics::histogram!("consumer.forward_successes", "topic" => topic.clone()) - .record(success_count as f64); - metrics::histogram!("consumer.forward_failures", "topic" => topic) - .record((results.len() - success_count) as f64); - - self.forward_batch.clear(); - } - - self.batch_size = 0; - - Ok(Some(replace( - &mut self.batch, - Vec::with_capacity(self.config.max_batch_len), - ))) - } - - fn reset(&mut self) { - self.batch_size = 0; - self.forward_batch.clear(); - self.batch.clear(); - } - - async fn is_full(&self) -> bool { - self.batch.len() >= self.config.max_batch_len - || self.batch_size >= self.config.max_batch_size - } - - fn get_reduce_config(&self) -> ReduceConfig { - ReduceConfig { - shutdown_condition: ReduceShutdownCondition::Signal, - shutdown_behaviour: ReduceShutdownBehaviour::Drop, - 
when_full_behaviour: ReducerWhenFullBehaviour::Flush, - flush_interval: Some(Duration::from_millis(self.config.max_batch_time_ms)), - } - } -} - -#[cfg(test)] -mod tests { - use std::io::Write; - use std::sync::Arc; - - use chrono::Utc; - use tempfile::NamedTempFile; - - use crate::store::activation::ActivationBuilder; - use crate::test_utils::{TaskActivationBuilder, generate_unique_namespace}; - - use super::{ - ActivationBatcher, ActivationBatcherConfig, Config, Reducer, RuntimeConfigManager, - }; - - #[tokio::test] - async fn test_drop_task_due_to_killswitch() { - let test_yaml = r#" -drop_task_killswitch: - - task_to_be_filtered -demoted_namespaces: - -"#; - - let mut config_file = NamedTempFile::new().unwrap(); - writeln!(config_file, "{}", test_yaml).unwrap(); - config_file.flush().unwrap(); - - let runtime_config = Arc::new( - RuntimeConfigManager::new(Some(config_file.path().to_str().unwrap().to_string())).await, - ); - let mut config = Config::default(); - config.normalize_and_validate().unwrap(); - let config = Arc::new(config); - let mut batcher = ActivationBatcher::new( - ActivationBatcherConfig::from_topic(&config, config.consumable_topics().unwrap()[0].0), - runtime_config, - ); - - let namespace = generate_unique_namespace(); - - let activation_0 = ActivationBuilder::new() - .id("0") - .taskname("task_to_be_filtered") - .namespace(&namespace) - .build(TaskActivationBuilder::new()); - - batcher.reduce(activation_0).await.unwrap(); - assert_eq!(batcher.batch.len(), 0); - } - - #[tokio::test] - async fn test_drop_task_due_to_expiry() { - let runtime_config = Arc::new(RuntimeConfigManager::new(None).await); - let mut config = Config::default(); - config.normalize_and_validate().unwrap(); - let config = Arc::new(config); - let mut batcher = ActivationBatcher::new( - ActivationBatcherConfig::from_topic(&config, config.consumable_topics().unwrap()[0].0), - runtime_config, - ); - - let namespace = generate_unique_namespace(); - - let activation_0 = 
ActivationBuilder::new() - .id("0") - .taskname("task_to_be_filtered") - .namespace(&namespace) - .expires_at(Utc::now()) - .build(TaskActivationBuilder::new()); - - batcher.reduce(activation_0).await.unwrap(); - assert_eq!(batcher.batch.len(), 0); - } - - #[tokio::test] - async fn test_close_by_bytes_limit() { - let runtime_config = Arc::new(RuntimeConfigManager::new(None).await); - let mut config = Config { - db_insert_batch_max_size: 1, - db_insert_batch_max_len: 2, - ..Default::default() - }; - config.normalize_and_validate().unwrap(); - let config = Arc::new(config); - - let mut batcher = ActivationBatcher::new( - ActivationBatcherConfig::from_topic(&config, config.consumable_topics().unwrap()[0].0), - runtime_config, - ); - - let namespace = generate_unique_namespace(); - - let activation_0 = ActivationBuilder::new() - .id("0") - .taskname("taskname") - .namespace(&namespace) - .build(TaskActivationBuilder::new()); - - batcher.reduce(activation_0).await.unwrap(); - assert!(batcher.is_full().await); - batcher.flush().await.unwrap(); - assert!(!batcher.is_full().await) - } - - #[tokio::test] - async fn test_close_by_rows_limit() { - let runtime_config = Arc::new(RuntimeConfigManager::new(None).await); - let mut config = Config { - db_insert_batch_max_size: 100000, - db_insert_batch_max_len: 2, - ..Default::default() - }; - config.normalize_and_validate().unwrap(); - let config = Arc::new(config); - - let mut batcher = ActivationBatcher::new( - ActivationBatcherConfig::from_topic(&config, config.consumable_topics().unwrap()[0].0), - runtime_config, - ); - - let namespace = generate_unique_namespace(); - - let activation_0 = ActivationBuilder::new() - .id("0") - .taskname("taskname") - .namespace(&namespace) - .build(TaskActivationBuilder::new()); - - let activation_1 = ActivationBuilder::new() - .id("1") - .taskname("taskname") - .namespace(&namespace) - .build(TaskActivationBuilder::new()); - - batcher.reduce(activation_0).await.unwrap(); - 
batcher.reduce(activation_1).await.unwrap(); - assert!(batcher.is_full().await); - batcher.flush().await.unwrap(); - assert!(!batcher.is_full().await) - } - - #[tokio::test] - async fn test_forward_task_due_to_demoted_namespace() { - let test_yaml = r#" -drop_task_killswitch: - - -demoted_namespaces: - - bad_namespace -demoted_topic_cluster: 0.0.0.0:9092 -demoted_topic: taskworker-demoted"#; - - let mut config_file = NamedTempFile::new().unwrap(); - writeln!(config_file, "{}", test_yaml).unwrap(); - config_file.flush().unwrap(); - - let runtime_config = Arc::new( - RuntimeConfigManager::new(Some(config_file.path().to_str().unwrap().to_string())).await, - ); - let mut config = Config::default(); - config.normalize_and_validate().unwrap(); - let config = Arc::new(config); - let mut batcher = ActivationBatcher::new( - ActivationBatcherConfig::from_topic(&config, config.consumable_topics().unwrap()[0].0), - runtime_config, - ); - - let (_, topic_config) = config.consumable_topics().unwrap()[0]; - let cluster_address = config - .cluster(&topic_config.cluster) - .unwrap() - .address - .clone(); - assert_eq!(batcher.producer_cluster, cluster_address); - - let activation_0 = ActivationBuilder::new() - .id("0") - .taskname("task_to_be_filtered") - .namespace("bad_namespace") - .build(TaskActivationBuilder::new()); - - let activation_1 = ActivationBuilder::new() - .id("1") - .taskname("good_task") - .namespace("good_namespace") - .build(TaskActivationBuilder::new()); - - batcher.reduce(activation_0).await.unwrap(); - batcher.reduce(activation_1).await.unwrap(); - - assert_eq!(batcher.batch.len(), 1); - assert_eq!(batcher.forward_batch.len(), 1); - - let flush_result = batcher.flush().await.unwrap(); - assert!(flush_result.is_some()); - assert_eq!(flush_result.as_ref().unwrap().len(), 1); - assert_eq!( - flush_result.as_ref().unwrap()[0].namespace, - "good_namespace" - ); - assert_eq!(flush_result.as_ref().unwrap()[0].taskname, "good_task"); - assert_eq!(batcher.batch.len(), 
0); - assert_eq!(batcher.forward_batch.len(), 0); - assert_eq!(batcher.producer_cluster, "0.0.0.0:9092"); - } -} diff --git a/src/kafka/activation_writer.rs b/src/kafka/activation_writer.rs index c9b64fe5..9527407a 100644 --- a/src/kafka/activation_writer.rs +++ b/src/kafka/activation_writer.rs @@ -1,11 +1,45 @@ +//! A single, process-wide activation writer shared by every consumer. +//! +//! All consumers feed it raw deserialized activations via a thin +//! [`ActivationWriterClient`] reduce stage. The shared [`ActivationWriter`] task +//! applies the global filters, forwards demoted-namespace tasks through a single +//! shared producer, and coalesces activations across consumers into larger DB +//! writes. +//! +//! ## Durability / commit coupling +//! +//! `ActivationWriterClient::flush()` ships its locally accumulated batch to the +//! writer and blocks until the writer confirms the batch was **durably stored**. +//! Only then does it return `Ok(Some(()))`, which lets the generic `reduce()` +//! wrapper hand the batch's Kafka messages to the per-consumer `commit` actor. +//! Because each client ships sequentially (it blocks until the previous batch is +//! durable), per-partition offsets are confirmed in order and the existing +//! max-offset commit logic stays correct. +//! +//! There is deliberately no "retry, re-ship" signal back to the client: once the +//! activations are handed to the writer the client no longer holds them, so it +//! must not advance to newer offsets until they are durable. Backpressure and +//! write failures are therefore retried **inside** the writer; the client simply +//! waits. If the writer shuts down it drops the ack channel, the client observes +//! a closed channel, returns `Ok(None)` (no commit), and Kafka redelivers — the +//! at-least-once safety net. 
+ use std::sync::Arc; use std::time::{Duration, Instant}; +use anyhow::Error; use chrono::Utc; +use futures::future::join_all; +use rdkafka::config::ClientConfig; +use rdkafka::producer::{FutureProducer, FutureRecord}; +use rdkafka::util::Timeout; +use tokio::select; +use tokio::sync::{mpsc, oneshot}; use tokio::time::sleep; use tracing::{debug, error, instrument}; use crate::config::Config; +use crate::runtime_config::RuntimeConfigManager; use crate::store::activation::{Activation, ActivationStatus}; use crate::store::traits::ActivationStore; use crate::store::types::DepthCounts; @@ -15,74 +49,365 @@ use super::consumer::{ ReducerWhenFullBehaviour, }; +/// How long to wait between re-checks while the writer is backpressured by the +/// database depth limits. Mirrors the consumer's repoll cadence. +const BACKPRESSURE_POLL_MS: u64 = 250; + +/// Upper bound on how many consumer batches the writer fuses into one DB write, +/// expressed as a multiple of `db_insert_batch_max_len`. Bounds the size of a +/// single INSERT while still allowing several consumers' batches to coalesce. +const COALESCE_FACTOR: usize = 4; + +/// A batch shipped from one [`ActivationWriterClient`] to the writer, with a one-shot +/// channel the writer signals once the batch is durably stored. +pub struct ActivationWriteRequest { + /// The consumed topic the activations came from. Used for the + /// demoted-namespace self-forward guard and for metric tags. + pub source_topic: Arc, + pub activations: Vec, + /// Signalled (with `()`) once the batch has been durably persisted. + pub ack: oneshot::Sender<()>, +} + +/// Activations from one or more requests, classified by disposition and ready +/// to be forwarded/written. The `acks` are fired together once the DB write of +/// `db_batch` succeeds. +struct PendingWrite { + db_batch: Vec, + forward_payloads: Vec>, + acks: Vec>, +} + pub struct ActivationWriterConfig { - /// The consumed topic this writer belongs to, used as a metric tag. 
Each - /// consumer has its own writer, so writer metrics are per-topic. - pub topic: String, - pub max_buf_len: usize, + pub producer_config: ClientConfig, + pub kafka_long_topic: String, + pub send_timeout_ms: u64, + /// Maximum number of activations fused into one DB write. + pub max_combined_len: usize, pub max_pending_activations: usize, pub max_processing_activations: usize, pub max_delay_activations: usize, pub db_max_size: Option, pub write_failure_backoff_ms: u64, + pub channel_capacity: usize, } impl ActivationWriterConfig { - /// Convert from application configuration into ActivationWriter config for a - /// single consumed topic. - pub fn from_topic(config: &Config, topic: &str) -> Self { + pub fn from_config(config: &Config) -> Self { Self { - topic: topic.to_owned(), - db_max_size: config.db_max_size, - max_buf_len: config.db_insert_batch_max_len, + producer_config: config.kafka_producer_config(), + kafka_long_topic: config.kafka_long_topic.clone(), + send_timeout_ms: config.kafka_send_timeout_ms, + max_combined_len: config.db_insert_batch_max_len * COALESCE_FACTOR, max_pending_activations: config.max_pending_count, max_processing_activations: config.max_processing_count, max_delay_activations: config.max_delay_count, + db_max_size: config.db_max_size, write_failure_backoff_ms: config.db_write_failure_backoff_ms, + channel_capacity: config.db_insert_batch_max_len, } } } -pub struct ActivationWriter { +/// Spawn the shared activation writer. Returns the sender used to construct +/// [`ActivationWriterClient`]s and the task handle (so the caller can await its final +/// drain at shutdown). The writer exits, after handling shutdown, once every +/// sender has been dropped. 
+pub fn spawn( + store: Arc, + runtime_config_manager: Arc, config: ActivationWriterConfig, +) -> ( + mpsc::Sender, + tokio::task::JoinHandle>, +) { + let (tx, rx) = mpsc::channel(config.channel_capacity); + + let producer: FutureProducer = config + .producer_config + .create() + .expect("Could not create kafka producer in shared writer"); + let producer_cluster = config + .producer_config + .get("bootstrap.servers") + .expect("producer config always sets bootstrap.servers") + .to_owned(); + + let writer = ActivationWriter { + store, + runtime_config_manager, + producer: Arc::new(producer), + producer_cluster, + config, + rx, + }; + + let handle = crate::tokio::spawn(writer.run()); + + (tx, handle) +} + +struct ActivationWriter { store: Arc, - batch: Option>, + runtime_config_manager: Arc, + producer: Arc, + producer_cluster: String, + config: ActivationWriterConfig, + rx: mpsc::Receiver, } impl ActivationWriter { - pub fn new(store: Arc, config: ActivationWriterConfig) -> Self { - Self { - config, - store, - batch: None, + #[instrument(skip_all)] + async fn run(mut self) -> Result<(), Error> { + let guard = elegant_departure::get_shutdown_guard().shutdown_on_drop(); + + loop { + // Gather phase: block for the first request, then greedily drain + // whatever else is already queued so concurrently-shipped batches + // from different consumers fuse into one write. The store latency of + // the previous cycle acts as the natural gather window (group commit). + let first = select! { + biased; + _ = guard.wait() => break, + req = self.rx.recv() => req, + }; + let Some(first) = first else { + // All clients dropped their senders: orderly drain complete. 
+ break; + }; + + let mut raw = vec![first]; + let mut estimate = raw[0].activations.len(); + while estimate < self.config.max_combined_len { + match self.rx.try_recv() { + Ok(req) => { + estimate += req.activations.len(); + raw.push(req); + } + Err(_) => break, + } + } + + let coalesced = raw.len(); + metrics::histogram!("consumer.activation_writer.coalesced_requests") + .record(coalesced as f64); + + let mut pending = self.classify(raw).await; + + // Forwards are best-effort and sent exactly once, before the write + // loop, so a store retry never re-forwards (matches the old batcher, + // which forwarded independently of the writer). + self.forward(&pending.forward_payloads).await; + + if self.write(&mut pending, &guard).await.is_err() { + break; + } + + for ack in pending.acks { + // Receiver gone (client revoked/aborted) just means those offsets + // won't be committed and the messages get redelivered. + let _ = ack.send(()); + } } + + debug!("Shared writer shutdown complete"); + Ok(()) } -} -impl Reducer for ActivationWriter { - type Input = Vec; + /// Apply the global filters and split requests into a DB batch and a set of + /// forward payloads. Killswitched/expired activations are dropped (their + /// offsets still commit once the request is acked). 
+ async fn classify(&self, raw: Vec) -> PendingWrite { + let runtime_config = self.runtime_config_manager.read().await; + let forward_topic = runtime_config + .demoted_topic + .clone() + .unwrap_or_else(|| self.config.kafka_long_topic.clone()); + + let mut db_batch = Vec::new(); + let mut forward_payloads = Vec::new(); + let mut acks = Vec::with_capacity(raw.len()); + + for ActivationWriteRequest { + source_topic, + activations, + ack, + } in raw + { + acks.push(ack); + + for activation in activations { + if runtime_config + .drop_task_killswitch + .contains(&activation.taskname) + { + metrics::counter!( + "filter.drop_task_killswitch", + "topic" => source_topic.to_string(), + "taskname" => activation.taskname.clone(), + ) + .increment(1); + continue; + } + + if let Some(expires_at) = activation.expires_at + && Utc::now() > expires_at + { + metrics::counter!( + "filter.expired_at_consumer", + "topic" => source_topic.to_string(), + ) + .increment(1); + continue; + } + + if runtime_config + .demoted_namespaces + .contains(&activation.namespace) + { + if forward_topic.as_str() == source_topic.as_ref() { + // Already on the demoted topic; don't forward to self, + // write it normally instead. + metrics::counter!( + "filter.forward_task_demoted_namespace.skipped", + "topic" => source_topic.to_string(), + "namespace" => activation.namespace.clone(), + "taskname" => activation.taskname.clone(), + ) + .increment(1); + } else { + metrics::counter!( + "filter.forward_task_demoted_namespace", + "topic" => source_topic.to_string(), + "namespace" => activation.namespace.clone(), + "taskname" => activation.taskname.clone(), + ) + .increment(1); + forward_payloads.push(activation.activation); + continue; + } + } + + db_batch.push(activation); + } + } + + PendingWrite { + db_batch, + forward_payloads, + acks, + } + } - type Output = (); + /// Forward demoted-namespace payloads to the long/demoted topic. Best-effort: + /// failures are counted but not retried. 
+ async fn forward(&mut self, payloads: &[Vec]) { + if payloads.is_empty() { + return; + } - async fn reduce(&mut self, batch: Self::Input) -> Result<(), anyhow::Error> { - assert!(self.batch.is_none()); - self.batch = Some(batch); - Ok(()) + let runtime_config = self.runtime_config_manager.read().await; + // The forwarding producer authenticates against the deadletter cluster, + // so default demoted forwarding there too when no cluster is configured. + let forward_cluster = runtime_config + .demoted_topic_cluster + .clone() + .unwrap_or_else(|| { + self.config + .producer_config + .get("bootstrap.servers") + .expect("producer config always sets bootstrap.servers") + .to_string() + }); + if self.producer_cluster != forward_cluster { + let mut new_config = self.config.producer_config.clone(); + new_config.set("bootstrap.servers", &forward_cluster); + self.producer = Arc::new( + new_config + .create() + .expect("Could not create kafka producer in shared writer"), + ); + self.producer_cluster = forward_cluster; + } + let forward_topic = runtime_config + .demoted_topic + .clone() + .unwrap_or_else(|| self.config.kafka_long_topic.clone()); + + let sends = payloads.iter().map(|payload| { + self.producer.send( + FutureRecord::<(), Vec>::to(&forward_topic).payload(payload), + Timeout::After(Duration::from_millis(self.config.send_timeout_ms)), + ) + }); + let results = join_all(sends).await; + let success_count = results.iter().filter(|r| r.is_ok()).count(); + + metrics::histogram!("consumer.forward_attempts").record(results.len() as f64); + metrics::histogram!("consumer.forward_successes").record(success_count as f64); + metrics::histogram!("consumer.forward_failures") + .record((results.len() - success_count) as f64); } - #[instrument(skip_all)] - async fn flush(&mut self) -> Result, anyhow::Error> { - let Some(ref batch) = self.batch else { - return Ok(None); - }; + /// Write `pending.db_batch`, waiting out depth-limit backpressure and + /// retrying write failures. 
Returns `Err(())` if a shutdown was observed + /// while waiting (the caller then exits without acking, so the batch is + /// redelivered). + async fn write( + &self, + pending: &mut PendingWrite, + guard: &elegant_departure::ShutdownGuard, + ) -> Result<(), ()> { + loop { + if self.is_backpressured(&pending.db_batch).await { + select! { + _ = guard.wait() => return Err(()), + _ = sleep(Duration::from_millis(BACKPRESSURE_POLL_MS)) => continue, + } + } - // If batch is empty (all tasks were forwarded), just mark as complete - if batch.is_empty() { - self.batch.take(); - return Ok(Some(())); + if pending.db_batch.is_empty() { + return Ok(()); + } + + let write_start = Instant::now(); + match self.store.store(&pending.db_batch).await { + Ok(entries) => { + let lag = Utc::now() + - pending + .db_batch + .iter() + .map(|item| item.received_at) + .min_by_key(|item| item.timestamp()) + .unwrap(); + metrics::histogram!("consumer.inflight_activation_writer.write_to_store") + .record(write_start.elapsed()); + metrics::histogram!("consumer.inflight_activation_writer.insert_lag") + .record(lag.num_seconds() as f64); + metrics::counter!("consumer.inflight_activation_writer.stored") + .increment(entries); + return Ok(()); + } + Err(err) => { + error!("Unable to write activations to store: {}", err); + metrics::counter!("consumer.inflight_activation_writer.write_failed") + .increment(1); + select! { + _ = guard.wait() => return Err(()), + _ = sleep(Duration::from_millis(self.config.write_failure_backoff_ms)) => continue, + } + } + } + } + } + + /// Depth-limit backpressure, evaluated over the combined batch (more correct + /// than each consumer deciding independently). 
+ async fn is_backpressured(&self, db_batch: &[Activation]) -> bool { + if db_batch.is_empty() { + return false; } - // Check if writing the batch would exceed the limits let DepthCounts { pending, delay, @@ -94,8 +419,8 @@ impl Reducer for ActivationWriter { .await .expect("Error communicating with activation store"); - let exceeded_pending_limit = pending + batch.len() > self.config.max_pending_activations; - let exceeded_delay_limit = delay + batch.len() > self.config.max_delay_activations; + let exceeded_pending_limit = pending + db_batch.len() > self.config.max_pending_activations; + let exceeded_delay_limit = delay + db_batch.len() > self.config.max_delay_activations; let exceeded_processing_limit = processing + claimed >= self.config.max_processing_activations; let exceeded_db_size = if let Some(db_max_size) = self.config.db_max_size { @@ -108,20 +433,13 @@ impl Reducer for ActivationWriter { false }; - // Check if the entire batch is either pending or delay - let has_delay = batch + let has_delay = db_batch .iter() .any(|activation| activation.status == ActivationStatus::Delay); - let has_pending = batch + let has_pending = db_batch .iter() .any(|activation| activation.status == ActivationStatus::Pending); - // Backpressure if any of these conditions are met: - // 1. The processing limit is exceeded - // 2. The delay limit is exceeded AND either: - // a. There are delay activations in the batch, OR - // b. The pending limit is also exceeded - // 3. The pending limit is exceeded AND there are pending activations if exceeded_processing_limit || exceeded_db_size || exceeded_delay_limit && (has_delay || exceeded_pending_limit) @@ -138,440 +456,505 @@ impl Reducer for ActivationWriter { }; metrics::counter!( "consumer.inflight_activation_writer.backpressure", - "topic" => self.config.topic.clone(), "reason" => reason, ) .increment(1); + true + } else { + false + } + } +} + +/// Per-consumer reduce stage. 
Accumulates deserialized activations and ships +/// them to the shared writer, blocking until they are durably stored. +pub struct ActivationWriterClient { + tx: mpsc::Sender, + source_topic: Arc, + batch: Vec, + batch_size: usize, + max_batch_len: usize, + max_batch_size: usize, + flush_interval: Duration, +} + +impl ActivationWriterClient { + pub fn new( + tx: mpsc::Sender, + source_topic: &str, + config: &Config, + ) -> Self { + Self { + tx, + source_topic: Arc::from(source_topic), + batch: Vec::with_capacity(config.db_insert_batch_max_len), + batch_size: 0, + max_batch_len: config.db_insert_batch_max_len, + max_batch_size: config.db_insert_batch_max_size, + flush_interval: Duration::from_millis(config.db_insert_batch_max_time_ms), + } + } +} + +impl Reducer for ActivationWriterClient { + type Input = Activation; + type Output = (); + + async fn reduce(&mut self, activation: Self::Input) -> Result<(), Error> { + self.batch_size += activation.activation.len(); + self.batch.push(activation); + Ok(()) + } + + #[instrument(skip_all)] + async fn flush(&mut self) -> Result, Error> { + if self.batch.is_empty() { + // Nothing buffered (e.g. a timer tick on an idle consumer): there are + // still in-flight messages to commit, so signal completion. + return Ok(Some(())); + } + + let activations = std::mem::take(&mut self.batch); + self.batch_size = 0; + + let (ack_tx, ack_rx) = oneshot::channel(); + let request = ActivationWriteRequest { + source_topic: self.source_topic.clone(), + activations, + ack: ack_tx, + }; + + if self.tx.send(request).await.is_err() { + // Writer is gone (shutting down). Don't commit; messages redeliver. 
return Ok(None); } - // I suspect that 'store' occasionally hangs and want to confirm - let insert_id = Utc::now().timestamp_millis(); - debug!("Preparing insert {:?}", insert_id); - - let write_to_store_start = Instant::now(); - let res = self.store.store(batch).await; - - // If every "preparing" has a matching "completed" we are good - debug!("Completed insert {:?}", insert_id); - - match res { - Ok(entries) => { - let batch = self.batch.take().unwrap(); - let lag = Utc::now() - - batch - .iter() - .map(|item| item.received_at) - .min_by_key(|item| item.timestamp()) - .unwrap(); - - metrics::histogram!( - "consumer.inflight_activation_writer.write_to_store", - "topic" => self.config.topic.clone(), - ) - .record(write_to_store_start.elapsed()); - metrics::histogram!( - "consumer.inflight_activation_writer.insert_lag", - "topic" => self.config.topic.clone(), - ) - .record(lag.num_seconds() as f64); - metrics::counter!( - "consumer.inflight_activation_writer.stored", - "topic" => self.config.topic.clone(), - ) - .increment(entries); - debug!( - "Inserted {:?} entries with max lag: {:?}s", - entries, - lag.num_seconds() - ); - Ok(Some(())) - } - Err(err) => { - error!("Unable to write to sqlite: {}", err); - metrics::counter!( - "consumer.inflight_activation_writer.write_failed", - "topic" => self.config.topic.clone(), - ) - .increment(1); - sleep(Duration::from_millis(self.config.write_failure_backoff_ms)).await; - Ok(None) - } + match ack_rx.await { + // Durably stored: commit the offsets for this batch. + Ok(()) => Ok(Some(())), + // Writer dropped the ack (shutdown). Don't commit; messages redeliver. 
+ Err(_) => Ok(None), } } - fn reset(&mut self) {} + fn reset(&mut self) { + self.batch.clear(); + self.batch_size = 0; + } async fn is_full(&self) -> bool { - self.batch.is_some() + self.batch.len() >= self.max_batch_len || self.batch_size >= self.max_batch_size } fn get_reduce_config(&self) -> ReduceConfig { ReduceConfig { when_full_behaviour: ReducerWhenFullBehaviour::Flush, - shutdown_behaviour: ReduceShutdownBehaviour::Flush, + // Drop the un-shipped buffer on revoke; those offsets aren't + // committed and the new owner reprocesses them. + shutdown_behaviour: ReduceShutdownBehaviour::Drop, shutdown_condition: ReduceShutdownCondition::Signal, - flush_interval: None, + flush_interval: Some(self.flush_interval), } } } #[cfg(test)] mod tests { - use chrono::DateTime; + use std::io::Write; + use std::sync::Arc; + + use chrono::Utc; use rstest::rstest; + use tempfile::NamedTempFile; + use tokio::sync::{mpsc, oneshot}; - use crate::store::activation::{ActivationBuilder, ActivationStatus}; + use crate::config::Config; + use crate::runtime_config::RuntimeConfigManager; + use crate::store::activation::ActivationBuilder; + use crate::store::traits::ActivationStore; use crate::test_utils::{ TaskActivationBuilder, create_test_store, generate_unique_namespace, make_activations, }; - use super::{ActivationWriter, ActivationWriterConfig, Reducer}; + use super::{ + ActivationWriteRequest, ActivationWriter, ActivationWriterClient, ActivationWriterConfig, + Reducer, + }; + + /// A validated default config (populates `kafka_topics`, needed to build the + /// forward producer). + fn validated_config() -> Config { + let mut config = Config::default(); + config.normalize_and_validate().unwrap(); + config + } + + /// Build a writer wired to `store`/`runtime_config` for exercising the + /// writer's internal methods directly. The receiver is a throwaway. 
+ fn make_writer( + store: Arc, + runtime_config_manager: Arc, + config: ActivationWriterConfig, + ) -> ActivationWriter { + let producer = config + .producer_config + .create() + .expect("could not create test producer"); + let producer_cluster = config + .producer_config + .get("bootstrap.servers") + .unwrap() + .to_owned(); + let (_tx, rx) = mpsc::channel(1); + ActivationWriter { + store, + runtime_config_manager, + producer: Arc::new(producer), + producer_cluster, + config, + rx, + } + } + + fn request( + source_topic: &str, + activations: Vec, + ) -> ActivationWriteRequest { + let (ack, _ack_rx) = oneshot::channel(); + ActivationWriteRequest { + source_topic: Arc::from(source_topic), + activations, + ack, + } + } + + async fn runtime_config_from_yaml(yaml: &str) -> Arc { + let mut file = NamedTempFile::new().unwrap(); + writeln!(file, "{}", yaml).unwrap(); + file.flush().unwrap(); + Arc::new(RuntimeConfigManager::new(Some(file.path().to_str().unwrap().to_string())).await) + } #[tokio::test] #[rstest] #[case::sqlite("sqlite")] #[case::postgres("postgres")] - async fn test_writer_flush_batch(#[case] adapter: &str) { + async fn test_classify_drops_killswitch(#[case] adapter: &str) { let store = create_test_store(adapter).await; - let writer_config = ActivationWriterConfig { - topic: "test-topic".to_string(), - db_max_size: None, - max_buf_len: 100, - max_pending_activations: 10, - max_processing_activations: 10, - max_delay_activations: 10, - write_failure_backoff_ms: 4000, - }; - let mut writer = ActivationWriter::new(store, writer_config); + let runtime_config = runtime_config_from_yaml( + r#" +drop_task_killswitch: + - task_to_be_filtered +demoted_namespaces: + -"#, + ) + .await; + let writer = make_writer( + store.clone(), + runtime_config, + ActivationWriterConfig::from_config(&validated_config()), + ); - let received_at = DateTime::from_timestamp_nanos(0); let namespace = generate_unique_namespace(); + let dropped = ActivationBuilder::new() + .id("0") + 
.taskname("task_to_be_filtered") + .namespace(&namespace) + .build(TaskActivationBuilder::new()); + let kept = ActivationBuilder::new() + .id("1") + .taskname("good_task") + .namespace(&namespace) + .build(TaskActivationBuilder::new()); - let batch = vec![ - ActivationBuilder::new() - .id("0") - .taskname("pending_task") - .namespace(&namespace) - .received_at(received_at) - .build(TaskActivationBuilder::new()), - ActivationBuilder::new() - .id("1") - .taskname("delay_task") - .namespace(&namespace) - .received_at(received_at) - .build(TaskActivationBuilder::new()), - ]; - - writer.reduce(batch).await.unwrap(); - writer.flush().await.unwrap(); - let count_pending = writer.store.count_pending_activations().await.unwrap(); - let count_delay = writer - .store - .count_by_status(ActivationStatus::Delay) - .await - .unwrap(); - assert_eq!(count_pending + count_delay, 2); - writer.store.remove_db().await.unwrap(); + let pending = writer + .classify(vec![request("topic", vec![dropped, kept])]) + .await; + assert_eq!(pending.db_batch.len(), 1); + assert_eq!(pending.db_batch[0].taskname, "good_task"); + assert!(pending.forward_payloads.is_empty()); + store.remove_db().await.unwrap(); } #[tokio::test] #[rstest] #[case::sqlite("sqlite")] #[case::postgres("postgres")] - async fn test_writer_flush_only_pending(#[case] adapter: &str) { + async fn test_classify_drops_expired(#[case] adapter: &str) { let store = create_test_store(adapter).await; - let writer_config = ActivationWriterConfig { - topic: "test-topic".to_string(), - db_max_size: None, - max_buf_len: 100, - max_pending_activations: 10, - max_processing_activations: 10, - max_delay_activations: 10, - write_failure_backoff_ms: 4000, - }; - let mut writer = ActivationWriter::new(store, writer_config); + let runtime_config = Arc::new(RuntimeConfigManager::new(None).await); + let writer = make_writer( + store.clone(), + runtime_config, + ActivationWriterConfig::from_config(&validated_config()), + ); - let received_at = 
DateTime::from_timestamp_nanos(0); let namespace = generate_unique_namespace(); + let expired = ActivationBuilder::new() + .id("0") + .taskname("task") + .namespace(&namespace) + .expires_at(Utc::now()) + .build(TaskActivationBuilder::new()); - let batch = vec![ - ActivationBuilder::new() - .id("0") - .taskname("pending_task") - .namespace(&namespace) - .received_at(received_at) - .build(TaskActivationBuilder::new()), - ]; - - writer.reduce(batch).await.unwrap(); - writer.flush().await.unwrap(); - let count_pending = writer.store.count_pending_activations().await.unwrap(); - assert_eq!(count_pending, 1); - writer.store.remove_db().await.unwrap(); + let pending = writer.classify(vec![request("topic", vec![expired])]).await; + assert!(pending.db_batch.is_empty()); + assert!(pending.forward_payloads.is_empty()); + store.remove_db().await.unwrap(); } #[tokio::test] #[rstest] #[case::sqlite("sqlite")] #[case::postgres("postgres")] - async fn test_writer_flush_only_delay(#[case] adapter: &str) { + async fn test_classify_forwards_demoted_namespace(#[case] adapter: &str) { let store = create_test_store(adapter).await; - let writer_config = ActivationWriterConfig { - topic: "test-topic".to_string(), - db_max_size: None, - max_buf_len: 100, - max_pending_activations: 0, - max_processing_activations: 10, - max_delay_activations: 10, - write_failure_backoff_ms: 4000, - }; - let mut writer = ActivationWriter::new(store, writer_config); + let runtime_config = runtime_config_from_yaml( + r#" +drop_task_killswitch: + - +demoted_namespaces: + - bad_namespace +demoted_topic_cluster: 0.0.0.0:9092 +demoted_topic: taskworker-demoted"#, + ) + .await; + let writer = make_writer( + store.clone(), + runtime_config, + ActivationWriterConfig::from_config(&validated_config()), + ); + + let demoted = ActivationBuilder::new() + .id("0") + .taskname("task") + .namespace("bad_namespace") + .build(TaskActivationBuilder::new()); + let normal = ActivationBuilder::new() + .id("1") + .taskname("task") 
+ .namespace("good_namespace") + .build(TaskActivationBuilder::new()); - let received_at = DateTime::from_timestamp_nanos(0); - let namespace = generate_unique_namespace(); + let pending = writer + .classify(vec![request("taskworker", vec![demoted, normal])]) + .await; + assert_eq!(pending.forward_payloads.len(), 1); + assert_eq!(pending.db_batch.len(), 1); + assert_eq!(pending.db_batch[0].namespace, "good_namespace"); + store.remove_db().await.unwrap(); + } - let batch = vec![ - ActivationBuilder::new() - .id("0") - .taskname("pending_task") - .namespace(&namespace) - .received_at(received_at) - .status(ActivationStatus::Delay) - .build(TaskActivationBuilder::new()), - ]; + #[tokio::test] + #[rstest] + #[case::sqlite("sqlite")] + #[case::postgres("postgres")] + async fn test_classify_skips_self_forward(#[case] adapter: &str) { + let store = create_test_store(adapter).await; + // Forward topic equals the source topic: demoted tasks are written + // normally instead of being forwarded to themselves. 
+ let runtime_config = runtime_config_from_yaml( + r#" +drop_task_killswitch: + - +demoted_namespaces: + - bad_namespace +demoted_topic: taskworker-demoted"#, + ) + .await; + let writer = make_writer( + store.clone(), + runtime_config, + ActivationWriterConfig::from_config(&validated_config()), + ); + + let demoted = ActivationBuilder::new() + .id("0") + .taskname("task") + .namespace("bad_namespace") + .build(TaskActivationBuilder::new()); - writer.reduce(batch).await.unwrap(); - writer.flush().await.unwrap(); - let count_delay = writer - .store - .count_by_status(ActivationStatus::Delay) - .await - .unwrap(); - assert_eq!(count_delay, 1); - writer.store.remove_db().await.unwrap(); + let pending = writer + .classify(vec![request("taskworker-demoted", vec![demoted])]) + .await; + assert!(pending.forward_payloads.is_empty()); + assert_eq!(pending.db_batch.len(), 1); + store.remove_db().await.unwrap(); } #[tokio::test] #[rstest] #[case::sqlite("sqlite")] #[case::postgres("postgres")] - async fn test_writer_backpressure_pending_limit_reached(#[case] adapter: &str) { + async fn test_classify_coalesces_across_topics(#[case] adapter: &str) { let store = create_test_store(adapter).await; - let writer_config = ActivationWriterConfig { - topic: "test-topic".to_string(), - db_max_size: None, - max_buf_len: 100, - max_pending_activations: 0, - max_processing_activations: 10, - max_delay_activations: 0, - write_failure_backoff_ms: 4000, - }; - let mut writer = ActivationWriter::new(store, writer_config); + let runtime_config = Arc::new(RuntimeConfigManager::new(None).await); + let writer = make_writer( + store.clone(), + runtime_config, + ActivationWriterConfig::from_config(&validated_config()), + ); - let received_at = DateTime::from_timestamp_nanos(0); let namespace = generate_unique_namespace(); - - let batch = vec![ + let mk = |id: &str| { ActivationBuilder::new() - .id("0") - .taskname("pending_task") + .id(id) + .taskname("task") .namespace(&namespace) - 
.received_at(received_at) - .build(TaskActivationBuilder::new()), - ActivationBuilder::new() - .id("1") - .taskname("delay_task") - .namespace(&namespace) - .received_at(received_at) - .build(TaskActivationBuilder::new()), - ]; + .build(TaskActivationBuilder::new()) + }; - writer.reduce(batch).await.unwrap(); - writer.flush().await.unwrap(); - let count_pending = writer.store.count_pending_activations().await.unwrap(); - assert_eq!(count_pending, 0); - let count_delay = writer - .store - .count_by_status(ActivationStatus::Delay) - .await - .unwrap(); - assert_eq!(count_delay, 0); - writer.store.remove_db().await.unwrap(); + // Two requests from different consumers fuse into one DB batch, with one + // ack per request. + let pending = writer + .classify(vec![ + request("topic-a", vec![mk("0"), mk("1")]), + request("topic-b", vec![mk("2")]), + ]) + .await; + assert_eq!(pending.db_batch.len(), 3); + assert_eq!(pending.acks.len(), 2); + store.remove_db().await.unwrap(); } #[tokio::test] #[rstest] #[case::sqlite("sqlite")] #[case::postgres("postgres")] - async fn test_writer_backpressure_only_delay_limit_reached_and_entire_batch_is_pending( - #[case] adapter: &str, - ) { + async fn test_is_backpressured_pending_limit(#[case] adapter: &str) { let store = create_test_store(adapter).await; - let writer_config = ActivationWriterConfig { - topic: "test-topic".to_string(), - db_max_size: None, - max_buf_len: 100, - max_pending_activations: 10, - max_processing_activations: 10, - max_delay_activations: 0, - write_failure_backoff_ms: 4000, - }; - let mut writer = ActivationWriter::new(store, writer_config); + let runtime_config = Arc::new(RuntimeConfigManager::new(None).await); + let mut config = ActivationWriterConfig::from_config(&validated_config()); + config.max_pending_activations = 0; + config.max_delay_activations = 0; + let writer = make_writer(store.clone(), runtime_config, config); - let received_at = DateTime::from_timestamp_nanos(0); let namespace = 
generate_unique_namespace(); - let batch = vec![ ActivationBuilder::new() .id("0") - .taskname("pending_task") + .taskname("task") .namespace(&namespace) - .received_at(received_at) - .build(TaskActivationBuilder::new()), - ActivationBuilder::new() - .id("1") - .taskname("pending_task") - .namespace(&namespace) - .received_at(received_at) .build(TaskActivationBuilder::new()), ]; - - writer.reduce(batch).await.unwrap(); - writer.flush().await.unwrap(); - let count_pending = writer.store.count_pending_activations().await.unwrap(); - assert_eq!(count_pending, 2); - let count_delay = writer - .store - .count_by_status(ActivationStatus::Delay) - .await - .unwrap(); - assert_eq!(count_delay, 0); - writer.store.remove_db().await.unwrap(); + assert!(writer.is_backpressured(&batch).await); + // An empty batch is never backpressured. + assert!(!writer.is_backpressured(&[]).await); + store.remove_db().await.unwrap(); } #[tokio::test] #[rstest] #[case::sqlite("sqlite")] #[case::postgres("postgres")] - async fn test_writer_backpressure_processing_limit_reached(#[case] adapter: &str) { + async fn test_is_backpressured_db_size_limit(#[case] adapter: &str) { let store = create_test_store(adapter).await; - let writer_config = ActivationWriterConfig { - topic: "test-topic".to_string(), - db_max_size: None, - max_buf_len: 100, - max_pending_activations: 10, - max_processing_activations: 1, - max_delay_activations: 0, - write_failure_backoff_ms: 4000, - }; - - let received_at = DateTime::from_timestamp_nanos(0); - let namespace = generate_unique_namespace(); + let runtime_config = Arc::new(RuntimeConfigManager::new(None).await); + let mut config = ActivationWriterConfig::from_config(&validated_config()); + // 200 rows is ~50KB. 
+ config.db_max_size = Some(50_000); + config.max_pending_activations = 5000; + config.max_processing_activations = 5000; + let first_round = make_activations(200); + store.store(&first_round).await.unwrap(); + assert!(store.db_size().await.unwrap() > 50_000); - let existing_activation = ActivationBuilder::new() - .id("existing") - .taskname("existing_task") - .namespace(&namespace) - .received_at(received_at) - .status(ActivationStatus::Processing) - .build(TaskActivationBuilder::new()); + let writer = make_writer(store.clone(), runtime_config, config); + let batch = make_activations(1); + assert!(writer.is_backpressured(&batch).await); + store.remove_db().await.unwrap(); + } - store.store(&[existing_activation]).await.unwrap(); + #[tokio::test] + async fn test_client_is_full() { + let config = Config { + db_insert_batch_max_len: 2, + db_insert_batch_max_size: 100_000, + ..Default::default() + }; + let (tx, _rx) = mpsc::channel(1); + let mut client = ActivationWriterClient::new(tx, "topic", &config); - let mut writer = ActivationWriter::new(store.clone(), writer_config); - let batch = vec![ - ActivationBuilder::new() - .id("0") - .taskname("pending_task") - .namespace(&namespace) - .received_at(received_at) - .build(TaskActivationBuilder::new()), + let namespace = generate_unique_namespace(); + let mk = |id: &str| { ActivationBuilder::new() - .id("1") - .taskname("delay_task") + .id(id) + .taskname("task") .namespace(&namespace) - .received_at(received_at) - .build(TaskActivationBuilder::new()), - ]; - - writer.reduce(batch).await.unwrap(); - let flush_result = writer.flush().await.unwrap(); + .build(TaskActivationBuilder::new()) + }; - assert!(flush_result.is_none()); + assert!(!client.is_full().await); + client.reduce(mk("0")).await.unwrap(); + assert!(!client.is_full().await); + client.reduce(mk("1")).await.unwrap(); + assert!(client.is_full().await); + } - let count_pending = writer.store.count_pending_activations().await.unwrap(); - assert_eq!(count_pending, 
0); - let count_delay = writer - .store - .count_by_status(ActivationStatus::Delay) - .await - .unwrap(); - assert_eq!(count_delay, 0); - let count_processing = writer - .store - .count_by_status(ActivationStatus::Processing) - .await - .unwrap(); - // Only the existing processing activation should remain, new ones should be blocked - assert_eq!(count_processing, 1); - // TODO: Because the store and the writer both access the DB, both need to be cleaned up. - // Uncomment this when we figure out how to do that cleanly. - // writer.store.remove_db().await.unwrap(); + #[tokio::test] + async fn test_client_flush_empty_commits() { + let config = Config::default(); + let (tx, _rx) = mpsc::channel(1); + let mut client = ActivationWriterClient::new(tx, "topic", &config); + // An empty buffer still reports success so idle-timer flushes commit any + // in-flight (dropped/forwarded) offsets. + assert_eq!(client.flush().await.unwrap(), Some(())); } #[tokio::test] - #[rstest] - #[case::sqlite("sqlite")] - #[case::postgres("postgres")] - async fn test_writer_backpressure_db_size_limit_reached(#[case] adapter: &str) { - let store = create_test_store(adapter).await; - let writer_config = ActivationWriterConfig { - topic: "test-topic".to_string(), - // 200 rows is ~50KB - db_max_size: Some(50_000), - max_buf_len: 100, - max_pending_activations: 5000, - max_processing_activations: 5000, - max_delay_activations: 0, - write_failure_backoff_ms: 4000, - }; - let first_round = make_activations(200); - store.store(&first_round).await.unwrap(); - assert!(store.db_size().await.unwrap() > 50_000); + async fn test_client_flush_commits_after_durable_ack() { + let config = Config::default(); + let (tx, mut rx) = mpsc::channel::(1); + let mut client = ActivationWriterClient::new(tx, "topic", &config); - // Make more activations that won't be stored. 
- let second_round = make_activations(10); + let namespace = generate_unique_namespace(); + client + .reduce( + ActivationBuilder::new() + .id("0") + .taskname("task") + .namespace(&namespace) + .build(TaskActivationBuilder::new()), + ) + .await + .unwrap(); - let mut writer = ActivationWriter::new(store.clone(), writer_config); - writer.reduce(second_round).await.unwrap(); - let flush_result = writer.flush().await.unwrap(); - assert!(flush_result.is_none()); + // Stand in for the writer: receive the request, then ack durability. + let writer = tokio::spawn(async move { + let req = rx.recv().await.unwrap(); + assert_eq!(req.activations.len(), 1); + req.ack.send(()).unwrap(); + }); - let count_pending = writer.store.count_pending_activations().await.unwrap(); - assert_eq!(count_pending, 200); - writer.store.remove_db().await.unwrap(); + assert_eq!(client.flush().await.unwrap(), Some(())); + writer.await.unwrap(); } #[tokio::test] - #[rstest] - #[case::sqlite("sqlite")] - #[case::postgres("postgres")] - async fn test_writer_flush_empty_batch(#[case] adapter: &str) { - let store = create_test_store(adapter).await; - let writer_config = ActivationWriterConfig { - topic: "test-topic".to_string(), - db_max_size: None, - max_buf_len: 100, - max_pending_activations: 10, - max_processing_activations: 10, - max_delay_activations: 10, - write_failure_backoff_ms: 4000, - }; - let mut writer = ActivationWriter::new(store.clone(), writer_config); - writer.reduce(vec![]).await.unwrap(); - let flush_result = writer.flush().await.unwrap(); - assert!(flush_result.is_some()); - writer.store.remove_db().await.unwrap(); + async fn test_client_flush_no_commit_when_writer_drops_ack() { + let config = Config::default(); + let (tx, mut rx) = mpsc::channel::(1); + let mut client = ActivationWriterClient::new(tx, "topic", &config); + + let namespace = generate_unique_namespace(); + client + .reduce( + ActivationBuilder::new() + .id("0") + .taskname("task") + .namespace(&namespace) + 
.build(TaskActivationBuilder::new()), + ) + .await + .unwrap(); + + // Writer receives the request but drops the ack (e.g. shutdown). + let writer = tokio::spawn(async move { + let _req = rx.recv().await.unwrap(); + // drop _req -> drops the ack sender + }); + + assert_eq!(client.flush().await.unwrap(), None); + writer.await.unwrap(); } } diff --git a/src/kafka/mod.rs b/src/kafka/mod.rs index 7a7776fb..75094c77 100644 --- a/src/kafka/mod.rs +++ b/src/kafka/mod.rs @@ -1,4 +1,3 @@ -pub mod activation_batcher; pub mod activation_writer; pub mod admin; pub mod consumer; diff --git a/src/main.rs b/src/main.rs index cb34dc3e..93dc1b3b 100644 --- a/src/main.rs +++ b/src/main.rs @@ -20,8 +20,7 @@ use taskbroker::fetch::FetchPool; use taskbroker::grpc::auth_middleware::AuthLayer; use taskbroker::grpc::metrics_middleware::MetricsLayer; use taskbroker::grpc::server::{TaskbrokerServer, flush_updates}; -use taskbroker::kafka::activation_batcher::{ActivationBatcher, ActivationBatcherConfig}; -use taskbroker::kafka::activation_writer::{ActivationWriter, ActivationWriterConfig}; +use taskbroker::kafka::activation_writer::{self, ActivationWriterClient, ActivationWriterConfig}; use taskbroker::kafka::admin::create_missing_topics; use taskbroker::kafka::consumer::start_consumer; use taskbroker::kafka::deserialize::{self, DeserializeConfig}; @@ -169,12 +168,21 @@ async fn main() -> Result<(), Error> { .map(|(name, _)| name.to_owned()) .collect(); + // One process-wide activation writer fed by every consumer. It filters, + // forwards and writes activations, coalescing batches across consumers into + // larger DB writes through a single shared producer/store. 
+ let (writer_tx, writer_task) = activation_writer::spawn( + store.clone(), + runtime_config_manager.clone(), + ActivationWriterConfig::from_config(&config), + ); + let mut consumer_tasks: Vec<(String, JoinHandle>)> = Vec::new(); for topic in consumer_topics { let consumer_store = store.clone(); let consumer_config = config.clone(); - let runtime_config_manager = runtime_config_manager.clone(); let task_topic = topic.clone(); + let writer_tx = writer_tx.clone(); let handle = taskbroker::tokio::spawn(async move { // The consumer has an internal thread that listens for cancellations, so it doesn't need @@ -195,14 +203,7 @@ async fn main() -> Result<(), Error> { deserialize::new(DeserializeConfig::from_topic(&consumer_config, &task_topic)), reduce: - ActivationBatcher::new( - ActivationBatcherConfig::from_topic(&consumer_config, &task_topic), - runtime_config_manager.clone() - ), - ActivationWriter::new( - consumer_store.clone(), - ActivationWriterConfig::from_topic(&consumer_config, &task_topic) - ), + ActivationWriterClient::new(writer_tx.clone(), &task_topic, &consumer_config), }), ) @@ -210,6 +211,9 @@ async fn main() -> Result<(), Error> { }); consumer_tasks.push((topic, handle)); } + // The writer outlives individual consumers; drop our handle so it can finish + // draining once every consumer has stopped. + drop(writer_tx); // Status update flush task let (status_update_tx, status_update_task) = if config.batch_status_updates { @@ -362,6 +366,8 @@ async fn main() -> Result<(), Error> { departure.on_completion(log_task_completion(format!("consumer:{topic}"), handle)); } + departure = departure.on_completion(log_task_completion("activation_writer", writer_task)); + if let Some(task) = grpc_server_task { departure = departure.on_completion(log_task_completion("grpc_server", task)); }