From ebd653aa1a2ebeb37a4cd085dfb1e319a9e4dd92 Mon Sep 17 00:00:00 2001 From: james-mcnulty Date: Tue, 2 Jun 2026 07:12:59 -0700 Subject: [PATCH 1/3] Add Validation to Several Config Fields --- Cargo.lock | 53 +++++++++++++++++++++ Cargo.toml | 1 + src/config.rs | 117 +++++++++++++++++++++++++++++++++++++++++++++-- src/fetch/mod.rs | 24 ++-------- src/main.rs | 4 +- 5 files changed, 172 insertions(+), 27 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 5b3dbf74..f634694c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2056,6 +2056,28 @@ dependencies = [ "toml_edit", ] +[[package]] +name = "proc-macro-error-attr2" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "96de42df36bb9bba5542fe9f1a054b8cc87e172759a1868aa05c1f3acc89dfc5" +dependencies = [ + "proc-macro2", + "quote", +] + +[[package]] +name = "proc-macro-error2" +version = "2.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "11ec05c52be0a07b08061f7dd003e7d7092e0472bc731b4af7bb1ef876109802" +dependencies = [ + "proc-macro-error-attr2", + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "proc-macro2" version = "1.0.101" @@ -3171,6 +3193,7 @@ dependencies = [ "tracing", "tracing-subscriber", "uuid", + "validator", ] [[package]] @@ -3702,6 +3725,36 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "validator" +version = "0.20.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "43fb22e1a008ece370ce08a3e9e4447a910e92621bb49b85d6e48a45397e7cfa" +dependencies = [ + "idna", + "once_cell", + "regex", + "serde", + "serde_derive", + "serde_json", + "url", + "validator_derive", +] + +[[package]] +name = "validator_derive" +version = "0.20.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b7df16e474ef958526d1205f6dda359fdfab79d9aa6d54bafcb92dcd07673dca" +dependencies = [ + "darling", + "once_cell", + "proc-macro-error2", + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "valuable" version = "0.1.1" diff --git a/Cargo.toml b/Cargo.toml index 5d736cf0..8603190d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -29,6 +29,7 @@ libsqlite3-sys = "0.30.1" metrics = "0.24.0" metrics-exporter-statsd = "0.9.0" prost = "0.14" +validator = { version = "0.20.0", features = ["derive"] } prost-types = "0.14" rand = "0.8.5" rdkafka = { version = "0.37.0", features = ["cmake-build", "ssl"] } diff --git a/src/config.rs b/src/config.rs index 17705933..e26d706b 100644 --- a/src/config.rs +++ b/src/config.rs @@ -2,12 +2,15 @@ use std::borrow::Cow; use std::collections::BTreeMap; +use anyhow::Result; use figment::providers::{Env, Format, Yaml}; use figment::{Figment, Metadata, Profile, Provider}; use rdkafka::ClientConfig; use serde::{Deserialize, Serialize}; +use validator::{Validate, ValidationError}; use crate::Args; +use crate::fetch::MAX_FETCH_THREADS; use crate::logging::LogFormat; #[derive(Clone, Copy, Debug, Eq, PartialEq, Hash, Deserialize, Serialize)] @@ -31,7 +34,7 @@ pub enum DeliveryMode { Push, } -#[derive(PartialEq, Debug, Deserialize, Serialize)] +#[derive(PartialEq, Debug, Deserialize, Serialize, Validate)] pub struct Config { /// The sentry DSN to use for error reporting. pub sentry_dsn: Option, @@ -300,34 +303,41 @@ pub struct Config { pub delivery_mode: DeliveryMode, /// The number of concurrent fetch loops in push mode, which should be ≤ `MAX_FETCH_THREADS` and a power of two. - /// If it's not a power of two or it's too large, it will be rounded to a valid nearby value. + #[validate(range(min = 1, max = MAX_FETCH_THREADS), custom(function = "validate_power_of_two"))] pub fetch_threads: usize, /// Time in milliseconds to wait between fetch attempts when no pending activation is found. pub fetch_wait_ms: u64, /// The number of activations to claim with a single fetch query. + #[validate(range(min = 1))] pub fetch_batch_size: i32, - /// The number of concurrent pushers each dispatcher should run. + /// The number of concurrent push threads to run. + #[validate(range(min = 1))] pub push_threads: usize, /// The size of the push queue. + #[validate(range(min = 1))] pub push_queue_size: usize, /// Maximum time in milliseconds to wait when submitting an activation to the push pool. + #[validate(range(min = 1))] pub push_queue_timeout_ms: u64, /// Maximum time in milliseconds for a single push RPC to the worker service. This should be greater than the worker's internal timeout. + #[validate(range(min = 1))] pub push_timeout_ms: u64, /// Update statuses from the gRPC server in batches? pub batch_status_updates: bool, /// The size of a batch of status updates. + #[validate(range(min = 1))] pub status_update_batch_size: usize, /// Maximum milliseconds to wait before flushing a batch of status updates. + #[validate(range(min = 1))] pub status_update_interval_ms: u64, /// Maps every application to its worker endpoint, both represented as strings. @@ -454,13 +464,18 @@ impl Default for Config { impl Config { /// Build a config instance from defaults, env vars, file + CLI options - pub fn from_args(args: &Args) -> Result> { + pub fn from_args(args: &Args) -> Result { let mut builder = Figment::from(Config::default()); + if let Some(path) = &args.config { builder = builder.merge(Yaml::file(path)); } + builder = builder.merge(Env::prefixed("TASKBROKER_")); - let config = builder.extract()?; + + let config: Config = builder.extract()?; + config.validate()?; + Ok(config) } @@ -558,12 +573,29 @@ impl Provider for Config { } } +/// Ensures that `n` is a power of two, used to validate `fetch_threads`. +fn validate_power_of_two(n: usize) -> Result<(), ValidationError> { + let mut m = n; + + while m > 1 { + if m % 2 == 1 { + // Not divisible by two + return Err(ValidationError::new("not_power_of_two")); + } + + m /= 2; + } + + Ok(()) +} + #[cfg(test)] mod tests { use std::borrow::Cow; use std::collections::BTreeMap; use figment::Jail; + use validator::Validate; use crate::Args; use crate::logging::LogFormat; @@ -591,6 +623,81 @@ mod tests { ); } + #[test] + fn test_validate_rejects_invalid_fields() { + let mut config = Config::default(); + + // Fetch threads cannot be zero + config.fetch_threads = 0; + assert!(config.validate().is_err()); + + config.fetch_threads = 1; + assert!(config.validate().is_ok()); + + // Fetch threads must be a power of two + config.fetch_threads = 3; + assert!(config.validate().is_err()); + + config.fetch_threads = 4; + assert!(config.validate().is_ok()); + + // Fetch threads must be ≤ 256 + config.fetch_threads = 512; + assert!(config.validate().is_err()); + + config.fetch_threads = 1; + assert!(config.validate().is_ok()); + + // Fetch batch size cannot be zero + config.fetch_batch_size = 0; + assert!(config.validate().is_err()); + + config.fetch_batch_size = 1; + assert!(config.validate().is_ok()); + + // Push threads cannot be zero + config.push_threads = 0; + assert!(config.validate().is_err()); + + config.push_threads = 1; + assert!(config.validate().is_ok()); + + // Push queue size cannot be zero + config.push_queue_size = 0; + assert!(config.validate().is_err()); + + config.push_queue_size = 1; + assert!(config.validate().is_ok()); + + // Push queue timeout cannot be zero + config.push_queue_timeout_ms = 0; + assert!(config.validate().is_err()); + + config.push_queue_timeout_ms = 1; + assert!(config.validate().is_ok()); + + // Push timeout cannot be zero + config.push_timeout_ms = 0; + assert!(config.validate().is_err()); + + config.push_timeout_ms = 1; + assert!(config.validate().is_ok()); + + // Status update batch size cannot be zero + config.status_update_batch_size = 0; + assert!(config.validate().is_err()); + + config.status_update_batch_size = 1; + assert!(config.validate().is_ok()); + + // Status update interval cannot be zero + config.status_update_interval_ms = 0; + assert!(config.validate().is_err()); + + config.status_update_interval_ms = 1; + assert!(config.validate().is_ok()); + } + #[test] fn test_from_args_config_file() { Jail::expect_with(|jail| { diff --git a/src/fetch/mod.rs b/src/fetch/mod.rs index 615a57a3..0e07bba5 100644 --- a/src/fetch/mod.rs +++ b/src/fetch/mod.rs @@ -19,28 +19,12 @@ use crate::timed; /// This value should be a power of two. If it decreases, some ranges will no longer be queried. /// That means the pending activation query will skip tasks within these ranges. -pub const MAX_FETCH_THREADS: i16 = 256; - -/// Returns the largest positive divisor of [`MAX_FETCH_THREADS`] that is also a power of two. -pub fn normalize_fetch_threads(n: usize) -> usize { - let n = n.max(1); - let mut v = MAX_FETCH_THREADS; - - while v > 1 { - if (v as usize) <= n { - return v as usize; - } - - v /= 2; - } - - 1 -} +pub const MAX_FETCH_THREADS: usize = 256; /// Inclusive bucket range for fetch thread `thread_index` when using `fetch_threads` concurrent fetch loops. /// Requires `fetch_threads` to divide [`MAX_FETCH_THREADS`] (enforced via [`normalize_fetch_threads`]). pub fn bucket_range_for_fetch_thread(thread_index: usize, fetch_threads: usize) -> BucketRange { - let maximum = MAX_FETCH_THREADS as usize; + let maximum = MAX_FETCH_THREADS; let buckets_per_range = maximum / fetch_threads; let low = (thread_index * buckets_per_range) as i16; @@ -79,14 +63,14 @@ impl FetchPool { #[framed] pub async fn start(&self) -> Result<()> { let fetch_wait_ms = self.config.fetch_wait_ms; - let fetch_threads = normalize_fetch_threads(self.config.fetch_threads); + let fetch_threads = self.config.fetch_threads; let mut fetch_pool = crate::tokio::spawn_pool(fetch_threads, |thread_index| { let store = self.store.clone(); let sender = self.sender.clone(); let config = self.config.clone(); - let limit = Some(config.fetch_batch_size.max(1)); + let limit = Some(config.fetch_batch_size); let bucket = Some(bucket_range_for_fetch_thread(thread_index, fetch_threads)); let guard = get_shutdown_guard().shutdown_on_drop(); diff --git a/src/main.rs b/src/main.rs index 8dd4a99d..6e5b0bf7 100644 --- a/src/main.rs +++ b/src/main.rs @@ -204,7 +204,7 @@ async fn main() -> Result<(), Error> { // Status update flush task let (status_update_tx, status_update_task) = if config.batch_status_updates { - let (tx, rx) = tokio::sync::mpsc::channel(config.status_update_batch_size.max(1)); + let (tx, rx) = tokio::sync::mpsc::channel(config.status_update_batch_size); let flusher_store = store.clone(); let flusher_config = config.clone(); @@ -293,7 +293,7 @@ async fn main() -> Result<(), Error> { let mut workers: Vec = vec![]; // For every push thread, create a map from applications to worker connections - for _ in 0..config.push_threads.max(1) { + for _ in 0..config.push_threads { let mut map = HashMap::new(); for (application, endpoint) in config.worker_map.clone() { From e5ad6da977642e6d80ff02469fefc6edddff3896 Mon Sep 17 00:00:00 2001 From: james-mcnulty Date: Tue, 2 Jun 2026 09:59:23 -0700 Subject: [PATCH 2/3] Move Validate Call --- src/config.rs | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/src/config.rs b/src/config.rs index 99c4c194..53e0b421 100644 --- a/src/config.rs +++ b/src/config.rs @@ -589,10 +589,14 @@ impl Config { // Use "__" for nested configurations via environment variables, like `TASKBROKER_KAFKA_TOPICS__PROFILES__CLUSTER` builder = builder.merge(Env::prefixed("TASKBROKER_").split("__")); - let mut config: Config = builder.extract()?; + + // Normalize and validate Kafka values config.normalize_and_validate()?; + // Validate all other values + config.validate()?; + Ok(config) } @@ -858,9 +862,6 @@ impl Config { )); } - // Validate all other values - self.validate()?; - Ok(()) } From 202ddeec125871b37b6ab1b66a207c4d7c146d95 Mon Sep 17 00:00:00 2001 From: james-mcnulty Date: Tue, 2 Jun 2026 10:02:04 -0700 Subject: [PATCH 3/3] Use `usize::is_power_of_two` --- src/config.rs | 21 ++++++++------------- 1 file changed, 8 insertions(+), 13 deletions(-) diff --git a/src/config.rs b/src/config.rs index 53e0b421..5231b2d4 100644 --- a/src/config.rs +++ b/src/config.rs @@ -1027,18 +1027,11 @@ impl Provider for Config { /// Ensures that `n` is a power of two, used to validate `fetch_threads`. fn validate_power_of_two(n: usize) -> Result<(), ValidationError> { - let mut m = n; - - while m > 1 { - if m % 2 == 1 { - // Not divisible by two - return Err(ValidationError::new("not_power_of_two")); - } - - m /= 2; + if n.is_power_of_two() { + Ok(()) + } else { + Err(ValidationError::new("not_power_of_two")) } - - Ok(()) } #[cfg(test)] @@ -1082,10 +1075,12 @@ mod tests { #[test] fn test_validate_rejects_invalid_fields() { - let mut config = Config::default(); + let mut config = Config { + fetch_threads: 0, + ..Config::default() + }; // Fetch threads cannot be zero - config.fetch_threads = 0; assert!(config.validate().is_err()); config.fetch_threads = 1;