diff --git a/Cargo.lock b/Cargo.lock index 5b3dbf74..f634694c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2056,6 +2056,28 @@ dependencies = [ "toml_edit", ] +[[package]] +name = "proc-macro-error-attr2" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "96de42df36bb9bba5542fe9f1a054b8cc87e172759a1868aa05c1f3acc89dfc5" +dependencies = [ + "proc-macro2", + "quote", +] + +[[package]] +name = "proc-macro-error2" +version = "2.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "11ec05c52be0a07b08061f7dd003e7d7092e0472bc731b4af7bb1ef876109802" +dependencies = [ + "proc-macro-error-attr2", + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "proc-macro2" version = "1.0.101" @@ -3171,6 +3193,7 @@ dependencies = [ "tracing", "tracing-subscriber", "uuid", + "validator", ] [[package]] @@ -3702,6 +3725,36 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "validator" +version = "0.20.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "43fb22e1a008ece370ce08a3e9e4447a910e92621bb49b85d6e48a45397e7cfa" +dependencies = [ + "idna", + "once_cell", + "regex", + "serde", + "serde_derive", + "serde_json", + "url", + "validator_derive", +] + +[[package]] +name = "validator_derive" +version = "0.20.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b7df16e474ef958526d1205f6dda359fdfab79d9aa6d54bafcb92dcd07673dca" +dependencies = [ + "darling", + "once_cell", + "proc-macro-error2", + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "valuable" version = "0.1.1" diff --git a/Cargo.toml b/Cargo.toml index 5d736cf0..8603190d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -29,6 +29,7 @@ libsqlite3-sys = "0.30.1" metrics = "0.24.0" metrics-exporter-statsd = "0.9.0" prost = "0.14" +validator = { version = "0.20.0", features = ["derive"] } prost-types = "0.14" rand = "0.8.5" rdkafka = { version = "0.37.0", features = ["cmake-build", "ssl"] } diff --git a/src/config.rs b/src/config.rs index 23f62005..5231b2d4 100644 --- a/src/config.rs +++ b/src/config.rs @@ -2,13 +2,16 @@ use std::borrow::Cow; use std::collections::BTreeMap; +use anyhow::{Result, anyhow}; use figment::providers::{Env, Format, Yaml}; use figment::{Figment, Metadata, Profile, Provider}; use rdkafka::ClientConfig; use serde::{Deserialize, Serialize}; use tracing::warn; +use validator::{Validate, ValidationError}; use crate::Args; +use crate::fetch::MAX_FETCH_THREADS; use crate::logging::LogFormat; /// Configuration for a single Kafka topic in multi-topic mode. @@ -109,7 +112,7 @@ pub enum DeliveryMode { Push, } -#[derive(PartialEq, Debug, Deserialize, Serialize)] +#[derive(PartialEq, Debug, Deserialize, Serialize, Validate)] pub struct Config { /// The sentry DSN to use for error reporting. pub sentry_dsn: Option, @@ -402,34 +405,41 @@ pub struct Config { pub delivery_mode: DeliveryMode, /// The number of concurrent fetch loops in push mode, which should be ≤ `MAX_FETCH_THREADS` and a power of two. - /// If it's not a power of two or it's too large, it will be rounded to a valid nearby value. + #[validate(range(min = 1, max = MAX_FETCH_THREADS), custom(function = "validate_power_of_two"))] pub fetch_threads: usize, /// Time in milliseconds to wait between fetch attempts when no pending activation is found. pub fetch_wait_ms: u64, /// The number of activations to claim with a single fetch query. + #[validate(range(min = 1))] pub fetch_batch_size: i32, - /// The number of concurrent pushers each dispatcher should run. + /// The number of concurrent push threads to run. + #[validate(range(min = 1))] pub push_threads: usize, /// The size of the push queue. + #[validate(range(min = 1))] pub push_queue_size: usize, /// Maximum time in milliseconds to wait when submitting an activation to the push pool. + #[validate(range(min = 1))] pub push_queue_timeout_ms: u64, /// Maximum time in milliseconds for a single push RPC to the worker service. This should be greater than the worker's internal timeout. + #[validate(range(min = 1))] pub push_timeout_ms: u64, /// Update statuses from the gRPC server in batches? pub batch_status_updates: bool, /// The size of a batch of status updates. + #[validate(range(min = 1))] pub status_update_batch_size: usize, /// Maximum milliseconds to wait before flushing a batch of status updates. + #[validate(range(min = 1))] pub status_update_interval_ms: u64, /// Maps every application to its worker endpoint, both represented as strings. @@ -570,16 +580,23 @@ impl Default for Config { impl Config { /// Build a config instance from defaults, env vars, file + CLI options - pub fn from_args(args: &Args) -> Result> { + pub fn from_args(args: &Args) -> Result { let mut builder = Figment::from(Config::default()); + if let Some(path) = &args.config { builder = builder.merge(Yaml::file(path)); } - // Use split("__") to support nested config via envvars like: - // TASKBROKER_KAFKA_TOPICS__PROFILES__CLUSTER=my-cluster + + // Use "__" for nested configurations via environment variables, like `TASKBROKER_KAFKA_TOPICS__PROFILES__CLUSTER` builder = builder.merge(Env::prefixed("TASKBROKER_").split("__")); let mut config: Config = builder.extract()?; + + // Normalize and validate Kafka values config.normalize_and_validate()?; + + // Validate all other values + config.validate()?; + Ok(config) } @@ -592,7 +609,7 @@ impl Config { /// zero-config case, where the historical `taskworker` defaults apply), they /// are normalized into `kafka_topics`/`kafka_clusters`. After this, /// `kafka_topics` and `kafka_clusters` are always populated. - pub(crate) fn normalize_and_validate(&mut self) -> Result<(), Box> { + pub(crate) fn normalize_and_validate(&mut self) -> Result<()> { const DEFAULT_CLUSTER: &str = "default"; const DEADLETTER_CLUSTER: &str = "deadletter"; const DEFAULT_TOPIC: &str = "taskworker"; @@ -623,26 +640,26 @@ impl Config { || self.kafka_deadletter_ssl_key_location.is_some(); if uses_new_format && uses_legacy { - return Err(Box::new(figment::Error::from( + return Err(anyhow!( "cannot mix the deprecated kafka_cluster/kafka_topic/kafka_consumer_group/\ kafka_deadletter_cluster (and related kafka_sasl_*/kafka_ssl_*/kafka_deadletter_* \ auth fields) with kafka_topics/kafka_clusters; use one config format or the other" .to_owned(), - ))); + )); } if uses_new_format { // New format: the maps are the source of truth. Require both halves // so a topic always has a cluster to resolve against. if self.kafka_topics.is_empty() { - return Err(Box::new(figment::Error::from( + return Err(anyhow!( "kafka_clusters is set but kafka_topics is empty".to_owned(), - ))); + )); } if self.kafka_clusters.is_empty() { - return Err(Box::new(figment::Error::from( + return Err(anyhow!( "kafka_topics is set but kafka_clusters is empty".to_owned(), - ))); + )); } } else { if self.kafka_cluster.is_some() { @@ -760,10 +777,11 @@ impl Config { }, ); if prev.is_some() { - return Err(Box::new(figment::Error::from(format!( + return Err(anyhow!( "kafka_deadletter_topic '{}' must differ from the consumed topic '{}'", - self.kafka_deadletter_topic, topic_name - )))); + self.kafka_deadletter_topic, + topic_name + )); } // Register the retry topic on the deadletter cluster: retries are @@ -772,10 +790,10 @@ impl Config { // the deadletter topic is rejected to avoid a name collision. if let Some(ref retry_topic) = self.kafka_retry_topic { if retry_topic == &self.kafka_deadletter_topic { - return Err(Box::new(figment::Error::from(format!( + return Err(anyhow!( "kafka_retry_topic '{}' must differ from kafka_deadletter_topic", retry_topic - )))); + )); } self.kafka_topics .entry(retry_topic.clone()) @@ -808,10 +826,10 @@ impl Config { // resolve its cluster. In legacy mode it was added above; in the new // format the user must declare it (produce-only) in kafka_topics. if !self.kafka_topics.contains_key(&self.kafka_deadletter_topic) { - return Err(Box::new(figment::Error::from(format!( + return Err(anyhow!( "kafka_deadletter_topic '{}' is not defined in kafka_topics", self.kafka_deadletter_topic - )))); + )); } // The upkeep producer connects to the deadletter topic's cluster but is @@ -834,11 +852,14 @@ impl Config { .cluster(&self.kafka_topics[&self.kafka_deadletter_topic].cluster)? .address; if retry_address != deadletter_address { - return Err(Box::new(figment::Error::from(format!( + return Err(anyhow!( "retry target topic '{}' is on cluster '{}', but deadletter topic '{}' is on \ '{}'; they share a single producer and must be on the same cluster", - retry_target, retry_address, self.kafka_deadletter_topic, deadletter_address - )))); + retry_target, + retry_address, + self.kafka_deadletter_topic, + deadletter_address + )); } Ok(()) @@ -1004,12 +1025,22 @@ impl Provider for Config { } } +/// Ensures that `n` is a power of two, used to validate `fetch_threads`. +fn validate_power_of_two(n: usize) -> Result<(), ValidationError> { + if n.is_power_of_two() { + Ok(()) + } else { + Err(ValidationError::new("not_power_of_two")) + } +} + #[cfg(test)] mod tests { use std::borrow::Cow; use std::collections::BTreeMap; use figment::Jail; + use validator::Validate; use crate::Args; use crate::logging::LogFormat; @@ -1042,6 +1073,83 @@ mod tests { ); } + #[test] + fn test_validate_rejects_invalid_fields() { + let mut config = Config { + fetch_threads: 0, + ..Config::default() + }; + + // Fetch threads cannot be zero + assert!(config.validate().is_err()); + + config.fetch_threads = 1; + assert!(config.validate().is_ok()); + + // Fetch threads must be a power of two + config.fetch_threads = 3; + assert!(config.validate().is_err()); + + config.fetch_threads = 4; + assert!(config.validate().is_ok()); + + // Fetch threads must be ≤ 256 + config.fetch_threads = 512; + assert!(config.validate().is_err()); + + config.fetch_threads = 1; + assert!(config.validate().is_ok()); + + // Fetch batch size cannot be zero + config.fetch_batch_size = 0; + assert!(config.validate().is_err()); + + config.fetch_batch_size = 1; + assert!(config.validate().is_ok()); + + // Push threads cannot be zero + config.push_threads = 0; + assert!(config.validate().is_err()); + + config.push_threads = 1; + assert!(config.validate().is_ok()); + + // Push queue size cannot be zero + config.push_queue_size = 0; + assert!(config.validate().is_err()); + + config.push_queue_size = 1; + assert!(config.validate().is_ok()); + + // Push queue timeout cannot be zero + config.push_queue_timeout_ms = 0; + assert!(config.validate().is_err()); + + config.push_queue_timeout_ms = 1; + assert!(config.validate().is_ok()); + + // Push timeout cannot be zero + config.push_timeout_ms = 0; + assert!(config.validate().is_err()); + + config.push_timeout_ms = 1; + assert!(config.validate().is_ok()); + + // Status update batch size cannot be zero + config.status_update_batch_size = 0; + assert!(config.validate().is_err()); + + config.status_update_batch_size = 1; + assert!(config.validate().is_ok()); + + // Status update interval cannot be zero + config.status_update_interval_ms = 0; + assert!(config.validate().is_err()); + + config.status_update_interval_ms = 1; + assert!(config.validate().is_ok()); + } + #[test] fn test_from_args_config_file() { Jail::expect_with(|jail| { diff --git a/src/fetch/mod.rs b/src/fetch/mod.rs index 615a57a3..0e07bba5 100644 --- a/src/fetch/mod.rs +++ b/src/fetch/mod.rs @@ -19,28 +19,12 @@ use crate::timed; /// This value should be a power of two. If it decreases, some ranges will no longer be queried. /// That means the pending activation query will skip tasks within these ranges. -pub const MAX_FETCH_THREADS: i16 = 256; - -/// Returns the largest positive divisor of [`MAX_FETCH_THREADS`] that is also a power of two. -pub fn normalize_fetch_threads(n: usize) -> usize { - let n = n.max(1); - let mut v = MAX_FETCH_THREADS; - - while v > 1 { - if (v as usize) <= n { - return v as usize; - } - - v /= 2; - } - - 1 -} +pub const MAX_FETCH_THREADS: usize = 256; /// Inclusive bucket range for fetch thread `thread_index` when using `fetch_threads` concurrent fetch loops. /// Requires `fetch_threads` to divide [`MAX_FETCH_THREADS`] (enforced via [`normalize_fetch_threads`]). pub fn bucket_range_for_fetch_thread(thread_index: usize, fetch_threads: usize) -> BucketRange { - let maximum = MAX_FETCH_THREADS as usize; + let maximum = MAX_FETCH_THREADS; let buckets_per_range = maximum / fetch_threads; let low = (thread_index * buckets_per_range) as i16; @@ -79,14 +63,14 @@ impl FetchPool { #[framed] pub async fn start(&self) -> Result<()> { let fetch_wait_ms = self.config.fetch_wait_ms; - let fetch_threads = normalize_fetch_threads(self.config.fetch_threads); + let fetch_threads = self.config.fetch_threads; let mut fetch_pool = crate::tokio::spawn_pool(fetch_threads, |thread_index| { let store = self.store.clone(); let sender = self.sender.clone(); let config = self.config.clone(); - let limit = Some(config.fetch_batch_size.max(1)); + let limit = Some(config.fetch_batch_size); let bucket = Some(bucket_range_for_fetch_thread(thread_index, fetch_threads)); let guard = get_shutdown_guard().shutdown_on_drop(); diff --git a/src/main.rs b/src/main.rs index ab644caf..dd3f9939 100644 --- a/src/main.rs +++ b/src/main.rs @@ -205,7 +205,7 @@ async fn main() -> Result<(), Error> { // Status update flush task let (status_update_tx, status_update_task) = if config.batch_status_updates { - let (tx, rx) = tokio::sync::mpsc::channel(config.status_update_batch_size.max(1)); + let (tx, rx) = tokio::sync::mpsc::channel(config.status_update_batch_size); let flusher_store = store.clone(); let flusher_config = config.clone(); @@ -294,7 +294,7 @@ async fn main() -> Result<(), Error> { let mut workers: Vec = vec![]; // For every push thread, create a map from applications to worker connections - for _ in 0..config.push_threads.max(1) { + for _ in 0..config.push_threads { let mut map = HashMap::new(); for (application, endpoint) in config.worker_map.clone() {