From 15f32fea3d075b47002bf3dc4ea8cc6ef5be019c Mon Sep 17 00:00:00 2001 From: imbolc Date: Sat, 9 May 2026 17:09:56 +0600 Subject: [PATCH 01/44] Do not unlock live tasks on worker startup --- ...9c981673f0139ec30884dead832f19bae36d9.json | 12 ---- src/error.rs | 2 - src/worker.rs | 60 ++++++------------- 3 files changed, 18 insertions(+), 56 deletions(-) delete mode 100644 .sqlx/query-24d7ef3a1dc86b9e408c7c691cb9c981673f0139ec30884dead832f19bae36d9.json diff --git a/.sqlx/query-24d7ef3a1dc86b9e408c7c691cb9c981673f0139ec30884dead832f19bae36d9.json b/.sqlx/query-24d7ef3a1dc86b9e408c7c691cb9c981673f0139ec30884dead832f19bae36d9.json deleted file mode 100644 index ee1c393..0000000 --- a/.sqlx/query-24d7ef3a1dc86b9e408c7c691cb9c981673f0139ec30884dead832f19bae36d9.json +++ /dev/null @@ -1,12 +0,0 @@ -{ - "db_name": "PostgreSQL", - "query": "UPDATE pg_task SET is_running = false WHERE is_running = true", - "describe": { - "columns": [], - "parameters": { - "Left": [] - }, - "nullable": [] - }, - "hash": "24d7ef3a1dc86b9e408c7c691cb9c981673f0139ec30884dead832f19bae36d9" -} diff --git a/src/error.rs b/src/error.rs index ec3e3a4..322bb90 100644 --- a/src/error.rs +++ b/src/error.rs @@ -13,8 +13,6 @@ pub enum Error { scheduling and running of the step): {1} */ DeserializeStep(#[source] serde_json::Error, String), - /// can't unlock stale tasks - UnlockStaleTasks(#[source] sqlx::Error), /// listener can't connect to the db ListenerConnect(#[source] sqlx::Error), /// can't start listening for table changes diff --git a/src/worker.rs b/src/worker.rs index a31e3e8..9400f4b 100644 --- a/src/worker.rs +++ b/src/worker.rs @@ -7,7 +7,7 @@ use crate::{ use sqlx::postgres::PgPool; use std::{marker::PhantomData, num::NonZeroUsize, sync::Arc, time::Duration}; use tokio::{sync::Semaphore, time::sleep}; -use tracing::{debug, error, info, trace, warn}; +use tracing::{error, info, trace, warn}; const LOCKED_TASK_RECHECK_DELAY: Duration = Duration::from_millis(100); @@ -40,7 +40,6 @@ impl + 
'static> Worker { /// Runs all ready tasks to completion and waits for new ones pub async fn run(&self) -> Result<()> { - self.unlock_stale_tasks().await?; self.listener.listen(self.db.clone()).await?; let semaphore = Arc::new(Semaphore::new(self.concurrency.get())); @@ -74,24 +73,6 @@ impl + 'static> Worker { self.finish_run(result, semaphore).await } - /// Unlocks all tasks. This is intended to run at the start of the worker as - /// some tasks could remain locked as running indefinitely if the - /// previous run ended due to some kind of crash. - async fn unlock_stale_tasks(&self) -> Result<()> { - let unlocked = - sqlx::query!("UPDATE pg_task SET is_running = false WHERE is_running = true") - .execute(&self.db) - .await - .map_err(Error::UnlockStaleTasks)? - .rows_affected(); - if unlocked == 0 { - debug!("No stale tasks to unlock") - } else { - debug!("Unlocked {} stale tasks", unlocked) - } - Ok(()) - } - /// Waits until the next task is ready, marks it running and returns it. /// Returns `None` if the worker is stopped async fn recv_task(&self) -> Result> { @@ -537,21 +518,6 @@ mod tests { }) } - #[sqlx::test(migrations = "./migrations")] - async fn run_returns_unlock_stale_task_errors(pool: PgPool) { - sqlx::query!("ALTER TABLE pg_task RENAME COLUMN is_running TO running_state") - .execute(&pool) - .await - .unwrap(); - - let err = Worker::::new(pool).run().await.unwrap_err(); - - assert!(matches!( - err, - Error::UnlockStaleTasks(sqlx::Error::Database(_)) - )); - } - #[sqlx::test(migrations = "./migrations")] async fn run_returns_listener_startup_errors(pool: PgPool) { let worker = Worker::::new(pool); @@ -979,27 +945,37 @@ mod tests { } #[sqlx::test(migrations = "./migrations")] - async fn run_unlocks_stale_tasks_before_processing(pool: PgPool) { + async fn starting_another_worker_does_not_unlock_live_tasks(pool: PgPool) { let state = StepStateGuard::new(); insert_task( &pool, - &TestTask::Complete(Complete { key: state.key() }), - true, + 
&TestTask::Blocking(Blocking { key: state.key() }), + false, ) .await; - let worker = spawn_worker(pool.clone()); - + let first_worker = spawn_worker(pool.clone()); state.state().wait_for_events(1).await; + + let second_worker = spawn_worker(pool.clone()); + sleep(Duration::from_millis(150)).await; + assert_eq!(state.state().events(), vec!["started"]); + stop_worker(&pool).await; + state.state().release(); - timeout(Duration::from_secs(1), worker) + timeout(Duration::from_secs(1), first_worker) + .await + .unwrap() + .unwrap() + .unwrap(); + timeout(Duration::from_secs(1), second_worker) .await .unwrap() .unwrap() .unwrap(); - assert_eq!(state.state().events(), vec!["complete"]); + assert_eq!(state.state().events(), vec!["started", "completed"]); assert_eq!(task_count(&pool).await, 0); } From 603cd5c8eae101232cb43f4df9122d924b82768a Mon Sep 17 00:00:00 2001 From: imbolc Date: Sat, 9 May 2026 18:21:26 +0600 Subject: [PATCH 02/44] Add task lease ownership checks --- ...b3a757de18d1ad1eac534a8b48f811ed6ddea.json | 23 + ...887598b499bb439da0e15a5b7aba6ee5ab2b.json} | 4 +- ...698fc6226e641b2216d4e1266e8f49bcaeed5.json | 14 - ...d80f7aa197d1486407bc43b7d1a8aad5e5fd0.json | 24 - ...b197d8447bfe601667b8fba5690d8aa4b05d9.json | 15 - ...7a5fde522a8600ab6bc190be70637b9c7fa19.json | 20 + ...fc0dd82bb91191a7b895d7ca61b53d76b3aa4.json | 17 + ...cc674f4ef3137aef2db4ea85373c64d6884d7.json | 20 - ...827457acdf48a3fd870e745d9eefa0f0b5ee7.json | 16 + ...65c4400e46cb175daebcec8bc95c6c1e74f73.json | 25 + ...edad674e8c1ebfbd134ae115abbcbf6db6cf7.json | 15 + ...3418241b5071c94314838c22c1d99416ac1b3.json | 16 + ...b906b7b52ff09a55a3064ed7943558234b103.json | 14 - ...dc382dd6fff214a80f93c4b9fd082bf24696c.json | 16 - Cargo.toml | 1 + examples/tutorial.rs | 7 +- src/lib.rs | 19 +- src/task.rs | 464 ++++++++++++++---- src/worker.rs | 131 +++-- 19 files changed, 616 insertions(+), 245 deletions(-) create mode 100644 
.sqlx/query-052393d8aa7dd13457e6fbf7e98b3a757de18d1ad1eac534a8b48f811ed6ddea.json rename .sqlx/{query-3560f4907b25647d12debefd6de3ea5968877075d7f792ceddb6a31b0890d669.json => query-0ab59bb88ea3816ec98e78ad071e887598b499bb439da0e15a5b7aba6ee5ab2b.json} (64%) delete mode 100644 .sqlx/query-1a69b80bf0909e445dcbe33ef6f698fc6226e641b2216d4e1266e8f49bcaeed5.json delete mode 100644 .sqlx/query-2cee211c9b3b8cb5ba895facf21d80f7aa197d1486407bc43b7d1a8aad5e5fd0.json delete mode 100644 .sqlx/query-2d2d8318e918473d99f96fed61ab197d8447bfe601667b8fba5690d8aa4b05d9.json create mode 100644 .sqlx/query-33ef20cd3dea5b74190e9540ca47a5fde522a8600ab6bc190be70637b9c7fa19.json create mode 100644 .sqlx/query-7de8839f5ef28990b78ca789cdefc0dd82bb91191a7b895d7ca61b53d76b3aa4.json delete mode 100644 .sqlx/query-91d8de5a94cbb3ac437ce760100cc674f4ef3137aef2db4ea85373c64d6884d7.json create mode 100644 .sqlx/query-b404bf65c635097d909d262996f827457acdf48a3fd870e745d9eefa0f0b5ee7.json create mode 100644 .sqlx/query-bf30581629ab7f798cb4cfa403d65c4400e46cb175daebcec8bc95c6c1e74f73.json create mode 100644 .sqlx/query-d37b82e8ea2e95ecd9c77994696edad674e8c1ebfbd134ae115abbcbf6db6cf7.json create mode 100644 .sqlx/query-eb70293509d0d3d24f1a6d8e13b3418241b5071c94314838c22c1d99416ac1b3.json delete mode 100644 .sqlx/query-ebc5a43458570f6f64356d4fdffb906b7b52ff09a55a3064ed7943558234b103.json delete mode 100644 .sqlx/query-f7824b0e7bc69b17d2c3de68b35dc382dd6fff214a80f93c4b9fd082bf24696c.json diff --git a/.sqlx/query-052393d8aa7dd13457e6fbf7e98b3a757de18d1ad1eac534a8b48f811ed6ddea.json b/.sqlx/query-052393d8aa7dd13457e6fbf7e98b3a757de18d1ad1eac534a8b48f811ed6ddea.json new file mode 100644 index 0000000..49538aa --- /dev/null +++ b/.sqlx/query-052393d8aa7dd13457e6fbf7e98b3a757de18d1ad1eac534a8b48f811ed6ddea.json @@ -0,0 +1,23 @@ +{ + "db_name": "PostgreSQL", + "query": "\n UPDATE pg_task\n SET error = $2,\n wakeup_at = now(),\n locked_by = NULL,\n lock_expires_at = NULL\n WHERE id = $1\n RETURNING step::TEXT as 
\"step!\"\n ", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "step!", + "type_info": "Text" + } + ], + "parameters": { + "Left": [ + "Uuid", + "Text" + ] + }, + "nullable": [ + false + ] + }, + "hash": "052393d8aa7dd13457e6fbf7e98b3a757de18d1ad1eac534a8b48f811ed6ddea" +} diff --git a/.sqlx/query-3560f4907b25647d12debefd6de3ea5968877075d7f792ceddb6a31b0890d669.json b/.sqlx/query-0ab59bb88ea3816ec98e78ad071e887598b499bb439da0e15a5b7aba6ee5ab2b.json similarity index 64% rename from .sqlx/query-3560f4907b25647d12debefd6de3ea5968877075d7f792ceddb6a31b0890d669.json rename to .sqlx/query-0ab59bb88ea3816ec98e78ad071e887598b499bb439da0e15a5b7aba6ee5ab2b.json index ea33e03..6ca2ff2 100644 --- a/.sqlx/query-3560f4907b25647d12debefd6de3ea5968877075d7f792ceddb6a31b0890d669.json +++ b/.sqlx/query-0ab59bb88ea3816ec98e78ad071e887598b499bb439da0e15a5b7aba6ee5ab2b.json @@ -1,6 +1,6 @@ { "db_name": "PostgreSQL", - "query": "\n SELECT\n id,\n step,\n tried\n FROM pg_task\n WHERE is_running = false\n AND error IS NULL\n AND wakeup_at <= now()\n ORDER BY wakeup_at\n LIMIT 1\n FOR UPDATE SKIP LOCKED\n ", + "query": "\n SELECT\n id,\n step,\n tried\n FROM pg_task\n WHERE error IS NULL\n AND wakeup_at <= now()\n AND (locked_by IS NULL OR lock_expires_at <= now())\n ORDER BY wakeup_at\n LIMIT 1\n FOR UPDATE SKIP LOCKED\n ", "describe": { "columns": [ { @@ -28,5 +28,5 @@ false ] }, - "hash": "3560f4907b25647d12debefd6de3ea5968877075d7f792ceddb6a31b0890d669" + "hash": "0ab59bb88ea3816ec98e78ad071e887598b499bb439da0e15a5b7aba6ee5ab2b" } diff --git a/.sqlx/query-1a69b80bf0909e445dcbe33ef6f698fc6226e641b2216d4e1266e8f49bcaeed5.json b/.sqlx/query-1a69b80bf0909e445dcbe33ef6f698fc6226e641b2216d4e1266e8f49bcaeed5.json deleted file mode 100644 index fe8b8cb..0000000 --- a/.sqlx/query-1a69b80bf0909e445dcbe33ef6f698fc6226e641b2216d4e1266e8f49bcaeed5.json +++ /dev/null @@ -1,14 +0,0 @@ -{ - "db_name": "PostgreSQL", - "query": "UPDATE pg_task SET is_running = true WHERE id = $1", - 
"describe": { - "columns": [], - "parameters": { - "Left": [ - "Uuid" - ] - }, - "nullable": [] - }, - "hash": "1a69b80bf0909e445dcbe33ef6f698fc6226e641b2216d4e1266e8f49bcaeed5" -} diff --git a/.sqlx/query-2cee211c9b3b8cb5ba895facf21d80f7aa197d1486407bc43b7d1a8aad5e5fd0.json b/.sqlx/query-2cee211c9b3b8cb5ba895facf21d80f7aa197d1486407bc43b7d1a8aad5e5fd0.json deleted file mode 100644 index a8c258b..0000000 --- a/.sqlx/query-2cee211c9b3b8cb5ba895facf21d80f7aa197d1486407bc43b7d1a8aad5e5fd0.json +++ /dev/null @@ -1,24 +0,0 @@ -{ - "db_name": "PostgreSQL", - "query": "\n UPDATE pg_task\n SET is_running = false,\n tried = tried + $3,\n error = $2,\n wakeup_at = now()\n WHERE id = $1\n RETURNING step::TEXT as \"step!\"\n ", - "describe": { - "columns": [ - { - "ordinal": 0, - "name": "step!", - "type_info": "Text" - } - ], - "parameters": { - "Left": [ - "Uuid", - "Text", - "Int4" - ] - }, - "nullable": [ - false - ] - }, - "hash": "2cee211c9b3b8cb5ba895facf21d80f7aa197d1486407bc43b7d1a8aad5e5fd0" -} diff --git a/.sqlx/query-2d2d8318e918473d99f96fed61ab197d8447bfe601667b8fba5690d8aa4b05d9.json b/.sqlx/query-2d2d8318e918473d99f96fed61ab197d8447bfe601667b8fba5690d8aa4b05d9.json deleted file mode 100644 index 6ef6f8c..0000000 --- a/.sqlx/query-2d2d8318e918473d99f96fed61ab197d8447bfe601667b8fba5690d8aa4b05d9.json +++ /dev/null @@ -1,15 +0,0 @@ -{ - "db_name": "PostgreSQL", - "query": "\n UPDATE pg_task\n SET is_running = false,\n tried = tried + 1,\n wakeup_at = $2\n WHERE id = $1\n ", - "describe": { - "columns": [], - "parameters": { - "Left": [ - "Uuid", - "Timestamptz" - ] - }, - "nullable": [] - }, - "hash": "2d2d8318e918473d99f96fed61ab197d8447bfe601667b8fba5690d8aa4b05d9" -} diff --git a/.sqlx/query-33ef20cd3dea5b74190e9540ca47a5fde522a8600ab6bc190be70637b9c7fa19.json b/.sqlx/query-33ef20cd3dea5b74190e9540ca47a5fde522a8600ab6bc190be70637b9c7fa19.json new file mode 100644 index 0000000..d80d201 --- /dev/null +++ 
b/.sqlx/query-33ef20cd3dea5b74190e9540ca47a5fde522a8600ab6bc190be70637b9c7fa19.json @@ -0,0 +1,20 @@ +{ + "db_name": "PostgreSQL", + "query": "\n SELECT\n CASE\n WHEN locked_by IS NOT NULL THEN\n GREATEST(wakeup_at, lock_expires_at)\n ELSE\n wakeup_at\n END AS \"next_at!\"\n FROM pg_task\n WHERE error IS NULL\n ORDER BY 1\n LIMIT 1\n ", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "next_at!", + "type_info": "Timestamptz" + } + ], + "parameters": { + "Left": [] + }, + "nullable": [ + null + ] + }, + "hash": "33ef20cd3dea5b74190e9540ca47a5fde522a8600ab6bc190be70637b9c7fa19" +} diff --git a/.sqlx/query-7de8839f5ef28990b78ca789cdefc0dd82bb91191a7b895d7ca61b53d76b3aa4.json b/.sqlx/query-7de8839f5ef28990b78ca789cdefc0dd82bb91191a7b895d7ca61b53d76b3aa4.json new file mode 100644 index 0000000..a86711f --- /dev/null +++ b/.sqlx/query-7de8839f5ef28990b78ca789cdefc0dd82bb91191a7b895d7ca61b53d76b3aa4.json @@ -0,0 +1,17 @@ +{ + "db_name": "PostgreSQL", + "query": "\n UPDATE pg_task\n SET tried = 0,\n step = $2,\n wakeup_at = $3,\n locked_by = NULL,\n lock_expires_at = NULL\n WHERE id = $1\n AND locked_by = $4\n AND lock_expires_at > now()\n ", + "describe": { + "columns": [], + "parameters": { + "Left": [ + "Uuid", + "Text", + "Timestamptz", + "Uuid" + ] + }, + "nullable": [] + }, + "hash": "7de8839f5ef28990b78ca789cdefc0dd82bb91191a7b895d7ca61b53d76b3aa4" +} diff --git a/.sqlx/query-91d8de5a94cbb3ac437ce760100cc674f4ef3137aef2db4ea85373c64d6884d7.json b/.sqlx/query-91d8de5a94cbb3ac437ce760100cc674f4ef3137aef2db4ea85373c64d6884d7.json deleted file mode 100644 index 8ff5240..0000000 --- a/.sqlx/query-91d8de5a94cbb3ac437ce760100cc674f4ef3137aef2db4ea85373c64d6884d7.json +++ /dev/null @@ -1,20 +0,0 @@ -{ - "db_name": "PostgreSQL", - "query": "\n SELECT wakeup_at\n FROM pg_task\n WHERE is_running = false\n AND error IS NULL\n ORDER BY wakeup_at\n LIMIT 1\n ", - "describe": { - "columns": [ - { - "ordinal": 0, - "name": "wakeup_at", - "type_info": "Timestamptz" - } - 
], - "parameters": { - "Left": [] - }, - "nullable": [ - false - ] - }, - "hash": "91d8de5a94cbb3ac437ce760100cc674f4ef3137aef2db4ea85373c64d6884d7" -} diff --git a/.sqlx/query-b404bf65c635097d909d262996f827457acdf48a3fd870e745d9eefa0f0b5ee7.json b/.sqlx/query-b404bf65c635097d909d262996f827457acdf48a3fd870e745d9eefa0f0b5ee7.json new file mode 100644 index 0000000..7ce6bf3 --- /dev/null +++ b/.sqlx/query-b404bf65c635097d909d262996f827457acdf48a3fd870e745d9eefa0f0b5ee7.json @@ -0,0 +1,16 @@ +{ + "db_name": "PostgreSQL", + "query": "\n UPDATE pg_task\n SET locked_by = $2,\n lock_expires_at = $3\n WHERE id = $1\n ", + "describe": { + "columns": [], + "parameters": { + "Left": [ + "Uuid", + "Uuid", + "Timestamptz" + ] + }, + "nullable": [] + }, + "hash": "b404bf65c635097d909d262996f827457acdf48a3fd870e745d9eefa0f0b5ee7" +} diff --git a/.sqlx/query-bf30581629ab7f798cb4cfa403d65c4400e46cb175daebcec8bc95c6c1e74f73.json b/.sqlx/query-bf30581629ab7f798cb4cfa403d65c4400e46cb175daebcec8bc95c6c1e74f73.json new file mode 100644 index 0000000..2c1622f --- /dev/null +++ b/.sqlx/query-bf30581629ab7f798cb4cfa403d65c4400e46cb175daebcec8bc95c6c1e74f73.json @@ -0,0 +1,25 @@ +{ + "db_name": "PostgreSQL", + "query": "\n UPDATE pg_task\n SET tried = tried + $3,\n error = $2,\n wakeup_at = now(),\n locked_by = NULL,\n lock_expires_at = NULL\n WHERE id = $1\n AND locked_by = $4\n AND lock_expires_at > now()\n RETURNING step::TEXT as \"step!\"\n ", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "step!", + "type_info": "Text" + } + ], + "parameters": { + "Left": [ + "Uuid", + "Text", + "Int4", + "Uuid" + ] + }, + "nullable": [ + false + ] + }, + "hash": "bf30581629ab7f798cb4cfa403d65c4400e46cb175daebcec8bc95c6c1e74f73" +} diff --git a/.sqlx/query-d37b82e8ea2e95ecd9c77994696edad674e8c1ebfbd134ae115abbcbf6db6cf7.json b/.sqlx/query-d37b82e8ea2e95ecd9c77994696edad674e8c1ebfbd134ae115abbcbf6db6cf7.json new file mode 100644 index 0000000..81316a0 --- /dev/null +++ 
b/.sqlx/query-d37b82e8ea2e95ecd9c77994696edad674e8c1ebfbd134ae115abbcbf6db6cf7.json @@ -0,0 +1,15 @@ +{ + "db_name": "PostgreSQL", + "query": "\n DELETE FROM pg_task\n WHERE id = $1\n AND locked_by = $2\n AND lock_expires_at > now()\n ", + "describe": { + "columns": [], + "parameters": { + "Left": [ + "Uuid", + "Uuid" + ] + }, + "nullable": [] + }, + "hash": "d37b82e8ea2e95ecd9c77994696edad674e8c1ebfbd134ae115abbcbf6db6cf7" +} diff --git a/.sqlx/query-eb70293509d0d3d24f1a6d8e13b3418241b5071c94314838c22c1d99416ac1b3.json b/.sqlx/query-eb70293509d0d3d24f1a6d8e13b3418241b5071c94314838c22c1d99416ac1b3.json new file mode 100644 index 0000000..c995175 --- /dev/null +++ b/.sqlx/query-eb70293509d0d3d24f1a6d8e13b3418241b5071c94314838c22c1d99416ac1b3.json @@ -0,0 +1,16 @@ +{ + "db_name": "PostgreSQL", + "query": "\n UPDATE pg_task\n SET tried = tried + 1,\n wakeup_at = $2,\n locked_by = NULL,\n lock_expires_at = NULL\n WHERE id = $1\n AND locked_by = $3\n AND lock_expires_at > now()\n ", + "describe": { + "columns": [], + "parameters": { + "Left": [ + "Uuid", + "Timestamptz", + "Uuid" + ] + }, + "nullable": [] + }, + "hash": "eb70293509d0d3d24f1a6d8e13b3418241b5071c94314838c22c1d99416ac1b3" +} diff --git a/.sqlx/query-ebc5a43458570f6f64356d4fdffb906b7b52ff09a55a3064ed7943558234b103.json b/.sqlx/query-ebc5a43458570f6f64356d4fdffb906b7b52ff09a55a3064ed7943558234b103.json deleted file mode 100644 index 8c0474e..0000000 --- a/.sqlx/query-ebc5a43458570f6f64356d4fdffb906b7b52ff09a55a3064ed7943558234b103.json +++ /dev/null @@ -1,14 +0,0 @@ -{ - "db_name": "PostgreSQL", - "query": "DELETE FROM pg_task WHERE id = $1", - "describe": { - "columns": [], - "parameters": { - "Left": [ - "Uuid" - ] - }, - "nullable": [] - }, - "hash": "ebc5a43458570f6f64356d4fdffb906b7b52ff09a55a3064ed7943558234b103" -} diff --git a/.sqlx/query-f7824b0e7bc69b17d2c3de68b35dc382dd6fff214a80f93c4b9fd082bf24696c.json b/.sqlx/query-f7824b0e7bc69b17d2c3de68b35dc382dd6fff214a80f93c4b9fd082bf24696c.json deleted 
file mode 100644 index 7f934db..0000000 --- a/.sqlx/query-f7824b0e7bc69b17d2c3de68b35dc382dd6fff214a80f93c4b9fd082bf24696c.json +++ /dev/null @@ -1,16 +0,0 @@ -{ - "db_name": "PostgreSQL", - "query": "\n UPDATE pg_task\n SET is_running = false,\n tried = 0,\n step = $2,\n wakeup_at = $3\n WHERE id = $1\n ", - "describe": { - "columns": [], - "parameters": { - "Left": [ - "Uuid", - "Text", - "Timestamptz" - ] - }, - "nullable": [] - }, - "hash": "f7824b0e7bc69b17d2c3de68b35dc382dd6fff214a80f93c4b9fd082bf24696c" -} diff --git a/Cargo.toml b/Cargo.toml index ea185e2..68e707f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -25,6 +25,7 @@ sqlx = { version = "0.8", features = [ thiserror = "2" tokio = "1" tracing = "0.1" +uuid = { version = "1", features = ["v4"] } [dev-dependencies] anyhow = "1" diff --git a/examples/tutorial.rs b/examples/tutorial.rs index cbf38d3..d335d3b 100644 --- a/examples/tutorial.rs +++ b/examples/tutorial.rs @@ -151,7 +151,7 @@ mod tests { let errored_row = timeout(Duration::from_secs(8), async { loop { let row = sqlx::query!( - "SELECT tried, is_running, error FROM pg_task WHERE id = $1", + "SELECT tried, locked_by, lock_expires_at, error FROM pg_task WHERE id = $1", id, ) .fetch_optional(&pool) @@ -172,7 +172,8 @@ mod tests { errored_row.tried, >::RETRY_LIMIT + 1 ); - assert!(!errored_row.is_running); + assert!(errored_row.locked_by.is_none()); + assert!(errored_row.lock_expires_at.is_none()); assert!(errored_row.error.is_some()); std::fs::write(&path, "Fixed World").unwrap(); @@ -193,7 +194,7 @@ mod tests { timeout(Duration::from_secs(2), async { loop { if sqlx::query!( - "SELECT tried, is_running, error FROM pg_task WHERE id = $1", + "SELECT tried, locked_by, lock_expires_at, error FROM pg_task WHERE id = $1", id, ) .fetch_optional(&pool) diff --git a/src/lib.rs b/src/lib.rs index 4c46657..1b6b693 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -116,14 +116,15 @@ final error message. 
Let's look into the DB to find out what happened: ```bash ~$ psql pg_task -c 'table pg_task' -[ RECORD 1 ]------------------------------------------------ -id | cddf7de1-1194-4bee-90c6-af73d9206ce2 -step | {"Greeter":{"ReadName":{"filename":"name.txt"}}} -wakeup_at | 2024-06-30 09:32:27.703599+06 -tried | 6 -is_running | f -error | No such file or directory (os error 2) -created_at | 2024-06-30 09:32:22.628563+06 -updated_at | 2024-06-30 09:32:27.703599+06 +id | cddf7de1-1194-4bee-90c6-af73d9206ce2 +step | {"Greeter":{"ReadName":{"filename":"name.txt"}}} +wakeup_at | 2024-06-30 09:32:27.703599+06 +tried | 6 +locked_by | +lock_expires_at | +error | No such file or directory (os error 2) +created_at | 2024-06-30 09:32:22.628563+06 +updated_at | 2024-06-30 09:32:27.703599+06 ``` - a non-null `error` field indicates that the task has errored and contains the @@ -270,7 +271,7 @@ The workers would wait until the current step of all the tasks is finished and then exit. You can wait for this by checking for the existence of running tasks: ```sql -SELECT EXISTS(SELECT 1 FROM pg_task WHERE is_running = true); +SELECT EXISTS(SELECT 1 FROM pg_task WHERE locked_by IS NOT NULL); ``` ### Delaying Steps diff --git a/src/task.rs b/src/task.rs index 4efa3eb..e997306 100644 --- a/src/task.rs +++ b/src/task.rs @@ -6,11 +6,11 @@ use chrono::{DateTime, Utc}; use serde::Serialize; use sqlx::{ postgres::{PgConnection, PgPool}, - types::Uuid, PgExecutor, }; use std::{fmt, time::Duration}; -use tracing::{debug, error, info, trace}; +use tracing::{debug, error, info, trace, warn}; +use uuid::Uuid; #[derive(Debug)] pub struct Task { @@ -19,6 +19,25 @@ pub struct Task { tried: i32, } +#[derive(Clone, Copy, Debug)] +pub(crate) struct TaskLease { + worker_id: Uuid, + timeout: chrono::Duration, +} + +impl TaskLease { + pub(crate) fn new(worker_id: Uuid, timeout: Duration) -> Self { + Self { + worker_id, + timeout: std_duration_to_chrono(timeout), + } + } + + fn expires_at(self) -> DateTime { + 
Utc::now() + self.timeout + } +} + impl Task { /// Returns a delay before a task should run at the given time. pub fn delay_until(wakeup_at: DateTime) -> Option { @@ -30,7 +49,7 @@ impl Task { } } - /// Fetches the closest ready unlocked task to run. + /// Fetches the closest ready task to run. pub async fn fetch_ready(con: &mut PgConnection) -> Result> { trace!("Fetching the closest ready task to run"); sqlx::query_as!( @@ -41,9 +60,9 @@ impl Task { step, tried FROM pg_task - WHERE is_running = false - AND error IS NULL + WHERE error IS NULL AND wakeup_at <= now() + AND (locked_by IS NULL OR lock_expires_at <= now()) ORDER BY wakeup_at LIMIT 1 FOR UPDATE SKIP LOCKED @@ -54,16 +73,21 @@ impl Task { .map_err(db_error!()) } - /// Fetches the closest scheduled task time. - pub async fn fetch_next_wakeup_at(con: &mut PgConnection) -> Result>> { - trace!("Fetching the closest scheduled task time"); + /// Fetches the closest time when a task may become claimable. + pub async fn fetch_next_available_at(con: &mut PgConnection) -> Result>> { + trace!("Fetching the closest task availability time"); sqlx::query_scalar!( r#" - SELECT wakeup_at + SELECT + CASE + WHEN locked_by IS NOT NULL THEN + GREATEST(wakeup_at, lock_expires_at) + ELSE + wakeup_at + END AS "next_at!" 
FROM pg_task - WHERE is_running = false - AND error IS NULL - ORDER BY wakeup_at + WHERE error IS NULL + ORDER BY 1 LIMIT 1 "#, ) @@ -73,11 +97,22 @@ impl Task { } /// Marks the task running - pub async fn mark_running(&self, con: &mut PgConnection) -> Result<()> { + pub(crate) async fn mark_running( + &self, + con: &mut PgConnection, + lease: TaskLease, + ) -> Result<()> { trace!("[{}] mark running", self.id); sqlx::query!( - "UPDATE pg_task SET is_running = true WHERE id = $1", - self.id + r#" + UPDATE pg_task + SET locked_by = $2, + lock_expires_at = $3 + WHERE id = $1 + "#, + self.id, + lease.worker_id, + lease.expires_at(), ) .execute(con) .await @@ -88,21 +123,30 @@ impl Task { /// Deserializes the current task step and marks it running. /// If deserialization fails, stores the error instead and leaves the task /// non-running. - pub(crate) async fn claim>(&self, con: &mut PgConnection) -> Result> { + pub(crate) async fn claim>( + &self, + con: &mut PgConnection, + lease: TaskLease, + ) -> Result> { let step = match self.parse_step() { Ok(step) => step, Err(e) => { - self.save_error(&mut *con, e.into(), false).await?; + self.save_claim_error(&mut *con, e.into()).await?; return Ok(None); } }; - self.mark_running(con).await?; + self.mark_running(con, lease).await?; Ok(Some(step)) } /// Runs the current step of the task to completion - pub async fn run_step>(&self, db: &PgPool, step: S) -> Result<()> { + pub(crate) async fn run_step>( + &self, + db: &PgPool, + step: S, + lease: TaskLease, + ) -> Result<()> { info!( "[{id}]{attempt} run step {step}", id = self.id, @@ -118,15 +162,19 @@ impl Task { match step.step(db).await { Err(e) => { if self.tried < retry_limit { - self.retry(db, self.tried, retry_limit, retry_delay, e) + self.retry(db, self.tried, retry_limit, retry_delay, e, lease) .await?; } else { - self.save_error(db, e, true).await?; + self.save_step_error(db, e, true, lease).await?; } } - Ok(NextStep::None) => self.complete(db).await?, - 
Ok(NextStep::Now(step)) => self.save_next_step(db, step, Duration::ZERO).await?, - Ok(NextStep::Delayed(step, delay)) => self.save_next_step(db, step, delay).await?, + Ok(NextStep::None) => self.complete(db, lease).await?, + Ok(NextStep::Now(step)) => { + self.save_next_step(db, step, Duration::ZERO, lease).await?; + } + Ok(NextStep::Delayed(step, delay)) => { + self.save_next_step(db, step, delay, lease).await?; + } }; Ok(()) } @@ -136,43 +184,88 @@ impl Task { .map_err(|e| Error::DeserializeStep(e, self.step.to_string())) } - /// Saves the task error - async fn save_error<'e, E>(&self, db: E, err: StepError, increment_tried: bool) -> Result<()> + /// Saves a deserialization error for a task before the worker owns it. + async fn save_claim_error<'e, E>(&self, db: E, err: StepError) -> Result<()> where E: PgExecutor<'e>, { + let err_str = source_chain::to_string(&*err); + let step = sqlx::query!( + r#" + UPDATE pg_task + SET error = $2, + wakeup_at = now(), + locked_by = NULL, + lock_expires_at = NULL + WHERE id = $1 + RETURNING step::TEXT as "step!" + "#, + self.id, + &err_str, + ) + .fetch_one(db) + .await + .map(|r| r.step) + .map_err(db_error!())?; + + error!( + "[{id}] couldn't deserialize step {step}: {err_str}", + id = self.id + ); + + Ok(()) + } + + /// Saves the task error if the worker still owns the task. + async fn save_step_error( + &self, + db: &PgPool, + err: StepError, + increment_tried: bool, + lease: TaskLease, + ) -> Result<()> { let err_str = source_chain::to_string(&*err); let tried_increment = if increment_tried { 1 } else { 0 }; let step = sqlx::query!( r#" UPDATE pg_task - SET is_running = false, - tried = tried + $3, + SET tried = tried + $3, error = $2, - wakeup_at = now() + wakeup_at = now(), + locked_by = NULL, + lock_expires_at = NULL WHERE id = $1 + AND locked_by = $4 + AND lock_expires_at > now() RETURNING step::TEXT as "step!" 
"#, self.id, &err_str, tried_increment, + lease.worker_id, ) - .fetch_one(db) + .fetch_optional(db) .await - .map(|r| r.step) .map_err(db_error!())?; + let Some(row) = step else { + self.log_lost_lease(lease.worker_id, "save the step error"); + return Ok(()); + }; + if increment_tried { let attempt = self.tried + 1; error!( "[{id}] resulted in an error at step {step} on {attempt} attempt: {err_str}", id = self.id, + step = row.step, attempt = ordinal(attempt) ); } else { error!( "[{id}] couldn't deserialize step {step}: {err_str}", - id = self.id + id = self.id, + step = row.step ); } @@ -185,41 +278,62 @@ impl Task { db: &PgPool, step: impl Serialize + fmt::Debug, delay: Duration, + lease: TaskLease, ) -> Result<()> { let step = match serde_json::to_string(&step) .map_err(|e| Error::SerializeStep(e, format!("{step:?}"))) { Ok(x) => x, - Err(e) => return self.save_error(db, e.into(), true).await, + Err(e) => return self.save_step_error(db, e.into(), true, lease).await, }; debug!("[{}] moved to the next step {step}", self.id); - sqlx::query!( + let result = sqlx::query!( " UPDATE pg_task - SET is_running = false, - tried = 0, + SET tried = 0, step = $2, - wakeup_at = $3 + wakeup_at = $3, + locked_by = NULL, + lock_expires_at = NULL WHERE id = $1 + AND locked_by = $4 + AND lock_expires_at > now() ", self.id, step, Utc::now() + std_duration_to_chrono(delay), + lease.worker_id, ) .execute(db) .await .map_err(db_error!())?; + if result.rows_affected() == 0 { + self.log_lost_lease(lease.worker_id, "save the next step"); + } Ok(()) } /// Removes the finished task - async fn complete(&self, db: &PgPool) -> Result<()> { - info!("[{}] is successfully completed", self.id); - sqlx::query!("DELETE FROM pg_task WHERE id = $1", self.id) - .execute(db) - .await - .map_err(db_error!())?; + async fn complete(&self, db: &PgPool, lease: TaskLease) -> Result<()> { + let result = sqlx::query!( + r#" + DELETE FROM pg_task + WHERE id = $1 + AND locked_by = $2 + AND lock_expires_at > now() 
+ "#, + self.id, + lease.worker_id, + ) + .execute(db) + .await + .map_err(db_error!())?; + if result.rows_affected() == 0 { + self.log_lost_lease(lease.worker_id, "complete the task"); + } else { + info!("[{}] is successfully completed", self.id); + } Ok(()) } @@ -231,6 +345,7 @@ impl Task { retry_limit: i32, delay: Duration, err: StepError, + lease: TaskLease, ) -> Result<()> { let delay = std_duration_to_chrono(delay); debug!( @@ -240,32 +355,47 @@ impl Task { err = source_chain::to_string(&*err), ); - sqlx::query!( + let result = sqlx::query!( " UPDATE pg_task - SET is_running = false, - tried = tried + 1, - wakeup_at = $2 + SET tried = tried + 1, + wakeup_at = $2, + locked_by = NULL, + lock_expires_at = NULL WHERE id = $1 + AND locked_by = $3 + AND lock_expires_at > now() ", self.id, Utc::now() + delay, + lease.worker_id, ) .execute(db) .await .map_err(db_error!())?; + if result.rows_affected() == 0 { + self.log_lost_lease(lease.worker_id, "schedule a retry"); + } Ok(()) } + + fn log_lost_lease(&self, worker_id: Uuid, action: &str) { + warn!( + "[{}] couldn't {action} because worker {worker_id} no longer owns the task", + self.id + ); + } } #[cfg(test)] mod tests { - use super::Task; + use super::{Task, TaskLease}; use crate::{NextStep, Step}; use chrono::{DateTime, Duration as ChronoDuration, Utc}; - use sqlx::{types::Uuid, PgPool}; + use sqlx::PgPool; use std::{io, time::Duration}; + use uuid::Uuid; fn init_tracing() { static INIT: std::sync::Once = std::sync::Once::new(); @@ -399,7 +529,8 @@ mod tests { step: String, wakeup_at: DateTime, tried: i32, - is_running: bool, + locked_by: Option, + lock_expires_at: Option>, error: Option, } @@ -427,24 +558,45 @@ mod tests { assert!(matches!(err, crate::Error::Db(sqlx::Error::Database(_), _))); } + fn worker_id() -> Uuid { + Uuid::from_u128(1) + } + + fn other_worker_id() -> Uuid { + Uuid::from_u128(2) + } + + fn task_lease() -> TaskLease { + TaskLease::new(worker_id(), Duration::from_secs(60)) + } + async fn 
insert_task_row( pool: &PgPool, step: &str, wakeup_at: DateTime, tried: i32, - is_running: bool, + is_leased: bool, error: Option<&str>, ) -> Uuid { + let (locked_by, lock_expires_at) = if is_leased { + ( + Some(other_worker_id()), + Some(Utc::now() + ChronoDuration::seconds(60)), + ) + } else { + (None, None) + }; sqlx::query!( " - INSERT INTO pg_task (step, wakeup_at, tried, is_running, error) - VALUES ($1, $2, $3, $4, $5) + INSERT INTO pg_task (step, wakeup_at, tried, locked_by, lock_expires_at, error) + VALUES ($1, $2, $3, $4, $5, $6) RETURNING id ", step, wakeup_at, tried, - is_running, + locked_by, + lock_expires_at, error, ) .fetch_one(pool) @@ -453,33 +605,60 @@ mod tests { .id } - async fn insert_task(pool: &PgPool, step: &TestTask, tried: i32, is_running: bool) -> Uuid { + async fn insert_task(pool: &PgPool, step: &TestTask, tried: i32, is_leased: bool) -> Uuid { let step = serialized_step(step); insert_task_row( pool, &step, Utc::now() - ChronoDuration::milliseconds(1), tried, - is_running, + is_leased, None, ) .await } - async fn claim_task(pool: &PgPool, step: TestTask, tried: i32) -> (Task, TestTask) { + async fn set_task_lease( + pool: &PgPool, + id: Uuid, + worker_id: Uuid, + lock_expires_at: DateTime, + ) { + sqlx::query!( + r#" + UPDATE pg_task + SET locked_by = $2, + lock_expires_at = $3 + WHERE id = $1 + "#, + id, + worker_id, + lock_expires_at, + ) + .execute(pool) + .await + .unwrap(); + } + + async fn claim_task(pool: &PgPool, step: TestTask, tried: i32) -> (Task, TestTask, TaskLease) { let id = insert_task(pool, &step, tried, false).await; let mut tx = pool.begin().await.unwrap(); let task = Task::fetch_ready(&mut tx).await.unwrap().unwrap(); assert_eq!(task.id, id); - let claimed = task.claim::(&mut tx).await.unwrap().unwrap(); + let lease = task_lease(); + let claimed = task + .claim::(&mut tx, lease) + .await + .unwrap() + .unwrap(); tx.commit().await.unwrap(); - (task, claimed) + (task, claimed, lease) } async fn fetch_task_row(pool: 
&PgPool, id: Uuid) -> Option { sqlx::query!( " - SELECT step, wakeup_at, tried, is_running, error + SELECT step, wakeup_at, tried, locked_by, lock_expires_at, error FROM pg_task WHERE id = $1 ", @@ -492,7 +671,8 @@ mod tests { step: row.step, wakeup_at: row.wakeup_at, tried: row.tried, - is_running: row.is_running, + locked_by: row.locked_by, + lock_expires_at: row.lock_expires_at, error: row.error, }) } @@ -526,17 +706,23 @@ mod tests { let mut tx = pool.begin().await.unwrap(); let task = Task::fetch_ready(&mut tx).await.unwrap().unwrap(); - assert!(task.claim::(&mut tx).await.unwrap().is_none()); + assert!(task + .claim::(&mut tx, task_lease()) + .await + .unwrap() + .is_none()); tx.commit().await.unwrap(); - let row = sqlx::query!("SELECT tried, is_running, error FROM pg_task LIMIT 1") - .fetch_one(&pool) - .await - .unwrap(); + let row = + sqlx::query!("SELECT tried, locked_by, lock_expires_at, error FROM pg_task LIMIT 1") + .fetch_one(&pool) + .await + .unwrap(); assert_eq!(row.tried, 0); - assert!(!row.is_running); + assert!(row.locked_by.is_none()); + assert!(row.lock_expires_at.is_none()); assert!(row.error.is_some()); } @@ -588,13 +774,13 @@ mod tests { async fn mark_running_returns_db_errors_for_update_failures(pool: PgPool) { let id = insert_task(&pool, &TestTask::Valid(Valid), 0, false).await; let task = task_with_step(id, &TestTask::Valid(Valid), 0); - sqlx::query!("ALTER TABLE pg_task RENAME COLUMN is_running TO running_state") + sqlx::query!("ALTER TABLE pg_task RENAME COLUMN locked_by TO task_locked_by") .execute(&pool) .await .unwrap(); let mut tx = pool.begin().await.unwrap(); - let err = task.mark_running(&mut tx).await.unwrap_err(); + let err = task.mark_running(&mut tx, task_lease()).await.unwrap_err(); assert_database_error(err); } @@ -609,37 +795,47 @@ mod tests { .unwrap(); let mut tx = pool.begin().await.unwrap(); - let err = task.claim::(&mut tx).await.unwrap_err(); + let err = task + .claim::(&mut tx, task_lease()) + .await + 
.unwrap_err(); assert_database_error(err); } #[sqlx::test(migrations = "./migrations")] - async fn claim_marks_valid_steps_running(pool: PgPool) { + async fn claim_marks_valid_steps_leased(pool: PgPool) { let id = insert_task(&pool, &TestTask::Valid(Valid), 0, false).await; let mut tx = pool.begin().await.unwrap(); let task = Task::fetch_ready(&mut tx).await.unwrap().unwrap(); - let claimed = task.claim::(&mut tx).await.unwrap(); + let started_at = Utc::now(); + let claimed = task.claim::(&mut tx, task_lease()).await.unwrap(); tx.commit().await.unwrap(); + let finished_at = Utc::now(); assert!(matches!(claimed, Some(TestTask::Valid(Valid)))); let row = fetch_task_row(&pool, id).await.unwrap(); assert_eq!(row.step, serialized_step(&TestTask::Valid(Valid))); assert_eq!(row.tried, 0); - assert!(row.is_running); + assert_eq!(row.locked_by, Some(worker_id())); + assert_timestamp_between( + row.lock_expires_at.unwrap(), + started_at + ChronoDuration::seconds(60), + finished_at + ChronoDuration::seconds(61), + ); assert!(row.error.is_none()); } #[sqlx::test(migrations = "./migrations")] - async fn fetch_ready_ignores_running_errored_and_future_tasks_and_picks_the_earliest_ready_one( + async fn fetch_ready_ignores_leased_errored_and_future_tasks_and_picks_the_earliest_ready_one( pool: PgPool, ) { let now = Utc::now(); let valid = serialized_step(&TestTask::Valid(Valid)); - insert_task_row( + let live_lease = insert_task_row( &pool, &valid, now - ChronoDuration::seconds(3), @@ -648,6 +844,13 @@ mod tests { None, ) .await; + set_task_lease( + &pool, + live_lease, + other_worker_id(), + now + ChronoDuration::seconds(60), + ) + .await; insert_task_row( &pool, &valid, @@ -686,10 +889,38 @@ mod tests { } #[sqlx::test(migrations = "./migrations")] - async fn fetch_next_wakeup_at_returns_the_earliest_visible_eligible_task(pool: PgPool) { + async fn fetch_ready_returns_expired_leased_tasks(pool: PgPool) { let now = Utc::now(); let valid = serialized_step(&TestTask::Valid(Valid)); - 
insert_task_row( + let expected = insert_task_row( + &pool, + &valid, + now - ChronoDuration::seconds(1), + 0, + true, + None, + ) + .await; + set_task_lease( + &pool, + expected, + other_worker_id(), + now - ChronoDuration::seconds(1), + ) + .await; + + let mut tx = pool.begin().await.unwrap(); + let task = Task::fetch_ready(&mut tx).await.unwrap().unwrap(); + tx.commit().await.unwrap(); + + assert_eq!(task.id, expected); + } + + #[sqlx::test(migrations = "./migrations")] + async fn fetch_next_available_at_returns_the_earliest_visible_eligible_task(pool: PgPool) { + let now = Utc::now(); + let valid = serialized_step(&TestTask::Valid(Valid)); + let live_lease = insert_task_row( &pool, &valid, now - ChronoDuration::seconds(3), @@ -698,6 +929,13 @@ mod tests { None, ) .await; + set_task_lease( + &pool, + live_lease, + other_worker_id(), + now + ChronoDuration::seconds(2), + ) + .await; insert_task_row( &pool, &valid, @@ -727,7 +965,10 @@ mod tests { .await; let mut tx = pool.begin().await.unwrap(); - let wakeup_at = Task::fetch_next_wakeup_at(&mut tx).await.unwrap().unwrap(); + let wakeup_at = Task::fetch_next_available_at(&mut tx) + .await + .unwrap() + .unwrap(); tx.commit().await.unwrap(); let row = fetch_task_row(&pool, expected).await.unwrap(); @@ -737,13 +978,32 @@ mod tests { #[sqlx::test(migrations = "./migrations")] async fn run_step_completes_tasks(pool: PgPool) { init_tracing(); - let (task, step) = claim_task(&pool, TestTask::Valid(Valid), 0).await; + let (task, step, lease) = claim_task(&pool, TestTask::Valid(Valid), 0).await; - task.run_step(&pool, step).await.unwrap(); + task.run_step(&pool, step, lease).await.unwrap(); assert!(fetch_task_row(&pool, task.id).await.is_none()); } + #[sqlx::test(migrations = "./migrations")] + async fn run_step_does_not_complete_tasks_after_losing_the_lease(pool: PgPool) { + init_tracing(); + let (task, step, lease) = claim_task(&pool, TestTask::Valid(Valid), 0).await; + set_task_lease( + &pool, + task.id, + 
other_worker_id(), + Utc::now() + ChronoDuration::seconds(60), + ) + .await; + + task.run_step(&pool, step, lease).await.unwrap(); + + let row = fetch_task_row(&pool, task.id).await.unwrap(); + assert_eq!(row.locked_by, Some(other_worker_id())); + assert!(row.error.is_none()); + } + #[sqlx::test(migrations = "./migrations")] async fn run_step_returns_db_errors_when_completing_tasks_fails(pool: PgPool) { let step = TestTask::Valid(Valid); @@ -754,18 +1014,18 @@ mod tests { .await .unwrap(); - let err = task.run_step(&pool, step).await.unwrap_err(); + let err = task.run_step(&pool, step, task_lease()).await.unwrap_err(); assert_database_error(err); } #[sqlx::test(migrations = "./migrations")] async fn run_step_saves_immediate_next_step_and_resets_retries(pool: PgPool) { - let (task, step) = + let (task, step, lease) = claim_task(&pool, TestTask::AdvanceNow(AdvanceNow { value: 41 }), 2).await; let started_at = Utc::now(); - task.run_step(&pool, step).await.unwrap(); + task.run_step(&pool, step, lease).await.unwrap(); let finished_at = Utc::now(); let row = fetch_task_row(&pool, task.id).await.unwrap(); @@ -774,7 +1034,8 @@ mod tests { serialized_step(&TestTask::Followup(Followup { value: 42 })), ); assert_eq!(row.tried, 0); - assert!(!row.is_running); + assert!(row.locked_by.is_none()); + assert!(row.lock_expires_at.is_none()); assert!(row.error.is_none()); assert_timestamp_between( row.wakeup_at, @@ -786,7 +1047,7 @@ mod tests { #[sqlx::test(migrations = "./migrations")] async fn run_step_saves_delayed_next_step_and_resets_retries(pool: PgPool) { let delay = Duration::from_millis(250); - let (task, step) = claim_task( + let (task, step, lease) = claim_task( &pool, TestTask::AdvanceLater(AdvanceLater { value: 9, @@ -797,7 +1058,7 @@ mod tests { .await; let started_at = Utc::now(); - task.run_step(&pool, step).await.unwrap(); + task.run_step(&pool, step, lease).await.unwrap(); let finished_at = Utc::now(); let row = fetch_task_row(&pool, task.id).await.unwrap(); @@ 
-807,7 +1068,8 @@ mod tests { serialized_step(&TestTask::Followup(Followup { value: 10 })), ); assert_eq!(row.tried, 0); - assert!(!row.is_running); + assert!(row.locked_by.is_none()); + assert!(row.lock_expires_at.is_none()); assert!(row.error.is_none()); assert_timestamp_between( row.wakeup_at, @@ -826,7 +1088,7 @@ mod tests { .await .unwrap(); - let err = task.run_step(&pool, step).await.unwrap_err(); + let err = task.run_step(&pool, step, task_lease()).await.unwrap_err(); assert_database_error(err); } @@ -835,17 +1097,18 @@ mod tests { async fn run_step_schedules_retries_before_the_retry_limit(pool: PgPool) { init_tracing(); let retry_delay = >::RETRY_DELAY; - let (task, step) = claim_task(&pool, TestTask::RetryFail(RetryFail), 1).await; + let (task, step, lease) = claim_task(&pool, TestTask::RetryFail(RetryFail), 1).await; let started_at = Utc::now(); - task.run_step(&pool, step).await.unwrap(); + task.run_step(&pool, step, lease).await.unwrap(); let finished_at = Utc::now(); let row = fetch_task_row(&pool, task.id).await.unwrap(); let retry_delay = ChronoDuration::from_std(retry_delay).unwrap(); assert_eq!(row.step, serialized_step(&TestTask::RetryFail(RetryFail))); assert_eq!(row.tried, 2); - assert!(!row.is_running); + assert!(row.locked_by.is_none()); + assert!(row.lock_expires_at.is_none()); assert!(row.error.is_none()); assert_timestamp_between( row.wakeup_at, @@ -864,7 +1127,7 @@ mod tests { .await .unwrap(); - let err = task.run_step(&pool, step).await.unwrap_err(); + let err = task.run_step(&pool, step, task_lease()).await.unwrap_err(); assert_database_error(err); } @@ -873,16 +1136,18 @@ mod tests { async fn run_step_saves_terminal_errors_after_retry_limit(pool: PgPool) { init_tracing(); let retry_limit = >::RETRY_LIMIT; - let (task, step) = claim_task(&pool, TestTask::RetryFail(RetryFail), retry_limit).await; + let (task, step, lease) = + claim_task(&pool, TestTask::RetryFail(RetryFail), retry_limit).await; let started_at = Utc::now(); - 
task.run_step(&pool, step).await.unwrap(); + task.run_step(&pool, step, lease).await.unwrap(); let finished_at = Utc::now(); let row = fetch_task_row(&pool, task.id).await.unwrap(); assert_eq!(row.step, serialized_step(&TestTask::RetryFail(RetryFail))); assert_eq!(row.tried, retry_limit + 1); - assert!(!row.is_running); + assert!(row.locked_by.is_none()); + assert!(row.lock_expires_at.is_none()); assert!(row .error .as_deref() @@ -905,21 +1170,22 @@ mod tests { .await .unwrap(); - let err = task.run_step(&pool, step).await.unwrap_err(); + let err = task.run_step(&pool, step, task_lease()).await.unwrap_err(); assert_database_error(err); } #[sqlx::test(migrations = "./migrations")] async fn run_step_saves_next_step_serialization_failures_as_errors(pool: PgPool) { - let (task, step) = claim_task(&pool, TestTask::BrokenNext(BrokenNext), 0).await; + let (task, step, lease) = claim_task(&pool, TestTask::BrokenNext(BrokenNext), 0).await; - task.run_step(&pool, step).await.unwrap(); + task.run_step(&pool, step, lease).await.unwrap(); let row = fetch_task_row(&pool, task.id).await.unwrap(); assert_eq!(row.step, serialized_step(&TestTask::BrokenNext(BrokenNext)),); assert_eq!(row.tried, 1); - assert!(!row.is_running); + assert!(row.locked_by.is_none()); + assert!(row.lock_expires_at.is_none()); assert!(row .error .as_deref() diff --git a/src/worker.rs b/src/worker.rs index 9400f4b..a8228cb 100644 --- a/src/worker.rs +++ b/src/worker.rs @@ -1,6 +1,6 @@ use crate::{ listener::Listener, - task::Task, + task::{Task, TaskLease}, util::{db_error, is_connection_error, is_pool_timeout, wait_for_reconnection}, Error, Result, Step, LOST_CONNECTION_SLEEP, }; @@ -8,8 +8,10 @@ use sqlx::postgres::PgPool; use std::{marker::PhantomData, num::NonZeroUsize, sync::Arc, time::Duration}; use tokio::{sync::Semaphore, time::sleep}; use tracing::{error, info, trace, warn}; +use uuid::Uuid; const LOCKED_TASK_RECHECK_DELAY: Duration = Duration::from_millis(100); +const DEFAULT_LEASE_TIMEOUT: 
Duration = Duration::from_secs(60); /// A worker for processing tasks pub struct Worker { @@ -17,6 +19,8 @@ pub struct Worker { listener: Listener, tasks: PhantomData, concurrency: NonZeroUsize, + worker_id: Uuid, + lease_timeout: Duration, } impl + 'static> Worker { @@ -28,6 +32,8 @@ impl + 'static> Worker { db, listener, concurrency, + worker_id: Uuid::new_v4(), + lease_timeout: DEFAULT_LEASE_TIMEOUT, tasks: PhantomData, } } @@ -38,6 +44,15 @@ impl + 'static> Worker { self } + /// Sets the task lease timeout. + /// + /// If a worker cannot renew this lease before it expires, another worker + /// may reclaim the task. + pub fn with_lease_timeout(mut self, lease_timeout: Duration) -> Self { + self.lease_timeout = lease_timeout; + self + } + /// Runs all ready tasks to completion and waits for new ones pub async fn run(&self) -> Result<()> { self.listener.listen(self.db.clone()).await?; @@ -46,7 +61,7 @@ impl + 'static> Worker { let result = loop { match self.recv_task().await { - Ok(Some((task, step))) => { + Ok(Some((task, step, lease))) => { let permit = semaphore .clone() .acquire_owned() @@ -54,7 +69,7 @@ impl + 'static> Worker { .map_err(Error::UnreachableWorkerSemaphoreClosed)?; let db = self.db.clone(); tokio::spawn(async move { - if let Err(e) = task.run_step(&db, step).await { + if let Err(e) = task.run_step(&db, step, lease).await { error!("[{}] {}", task.id, source_chain::to_string(&e)); }; drop(permit); @@ -75,7 +90,7 @@ impl + 'static> Worker { /// Waits until the next task is ready, marks it running and returns it. /// Returns `None` if the worker is stopped - async fn recv_task(&self) -> Result> { + async fn recv_task(&self) -> Result> { trace!("Receiving the next task"); loop { @@ -91,11 +106,12 @@ impl + 'static> Worker { let mut tx = self.db.begin().await.map_err(db_error!("begin"))?; let Some(task) = Task::fetch_ready(&mut tx).await? 
else { - let next_wakeup_at = Task::fetch_next_wakeup_at(&mut tx).await?; + let next_available_at = Task::fetch_next_available_at(&mut tx).await?; tx.commit().await.map_err(db_error!("no ready tasks"))?; - if let Some(wakeup_at) = next_wakeup_at { - let delay = Task::delay_until(wakeup_at).unwrap_or(LOCKED_TASK_RECHECK_DELAY); + if let Some(available_at) = next_available_at { + let delay = + Task::delay_until(available_at).unwrap_or(LOCKED_TASK_RECHECK_DELAY); table_changes.wait_for(delay).await; } else { table_changes.wait_forever().await; @@ -103,12 +119,13 @@ impl + 'static> Worker { continue; }; - let Some(step) = task.claim(&mut tx).await? else { + let lease = TaskLease::new(self.worker_id, self.lease_timeout); + let Some(step) = task.claim(&mut tx, lease).await? else { tx.commit().await.map_err(db_error!("save error"))?; continue; }; tx.commit().await.map_err(db_error!("mark running"))?; - return Ok(Some((task, step))); + return Ok(Some((task, step, lease))); } } @@ -173,7 +190,7 @@ mod tests { use super::Worker; use crate::{Error, NextStep, Step}; use chrono::{Duration as ChronoDuration, Utc}; - use sqlx::{postgres::PgPoolOptions, types::Uuid, PgPool}; + use sqlx::{postgres::PgPoolOptions, PgPool}; use std::{ collections::HashMap, io, @@ -188,6 +205,7 @@ mod tests { sync::{Notify, Semaphore}, time::{sleep, timeout}, }; + use uuid::Uuid; fn init_tracing() { static INIT: std::sync::Once = std::sync::Once::new(); @@ -419,18 +437,27 @@ mod tests { pool: &PgPool, step: &str, wakeup_at: chrono::DateTime, - is_running: bool, + is_leased: bool, error: Option<&str>, ) -> Uuid { + let (locked_by, lock_expires_at) = if is_leased { + ( + Some(Uuid::from_u128(1)), + Some(Utc::now() + ChronoDuration::seconds(60)), + ) + } else { + (None, None) + }; sqlx::query!( " - INSERT INTO pg_task (step, wakeup_at, is_running, error) - VALUES ($1, $2, $3, $4) + INSERT INTO pg_task (step, wakeup_at, locked_by, lock_expires_at, error) + VALUES ($1, $2, $3, $4, $5) RETURNING id ", step, 
wakeup_at, - is_running, + locked_by, + lock_expires_at, error, ) .fetch_one(pool) @@ -443,28 +470,45 @@ mod tests { pool: &PgPool, step: &TestTask, wakeup_at: chrono::DateTime, - is_running: bool, + is_leased: bool, ) -> Uuid { insert_raw_task( pool, &serde_json::to_string(step).unwrap(), wakeup_at, - is_running, + is_leased, None, ) .await } - async fn insert_task(pool: &PgPool, step: &TestTask, is_running: bool) { + async fn insert_task(pool: &PgPool, step: &TestTask, is_leased: bool) { insert_task_at( pool, step, Utc::now() - ChronoDuration::milliseconds(1), - is_running, + is_leased, ) .await; } + async fn set_task_lease(pool: &PgPool, id: Uuid, lock_expires_at: chrono::DateTime) { + sqlx::query!( + r#" + UPDATE pg_task + SET locked_by = $2, + lock_expires_at = $3 + WHERE id = $1 + "#, + id, + Uuid::from_u128(1), + lock_expires_at, + ) + .execute(pool) + .await + .unwrap(); + } + async fn connect_to_current_db( pool: &PgPool, max_connections: u32, @@ -663,7 +707,34 @@ mod tests { tx.rollback().await.unwrap(); - let (task, step) = timeout(Duration::from_secs(1), recv) + let (task, step, _lease) = timeout(Duration::from_secs(1), recv) + .await + .unwrap() + .unwrap() + .unwrap() + .unwrap(); + assert_eq!(task.id, id); + assert!(matches!(step, TestTask::Noop(Noop))); + } + + #[sqlx::test(migrations = "./migrations")] + async fn recv_task_rechecks_leased_tasks_when_their_lease_expires(pool: PgPool) { + let id = insert_task_at( + &pool, + &TestTask::Noop(Noop), + Utc::now() - ChronoDuration::milliseconds(1), + true, + ) + .await; + set_task_lease(&pool, id, Utc::now() + ChronoDuration::milliseconds(100)).await; + + let worker = Worker::::new(pool); + let recv = tokio::spawn(async move { worker.recv_task().await }); + + sleep(Duration::from_millis(50)).await; + assert!(!recv.is_finished()); + + let (task, step, _lease) = timeout(Duration::from_secs(1), recv) .await .unwrap() .unwrap() @@ -695,18 +766,19 @@ mod tests { let first_recv = tokio::spawn(async move { 
first_worker.recv_task().await }); let second_recv = tokio::spawn(async move { second_worker.recv_task().await }); - let (first_task, first_step) = timeout(Duration::from_secs(1), first_recv) - .await - .unwrap() - .unwrap() - .unwrap() - .unwrap(); - let (second_task, second_step) = timeout(Duration::from_secs(1), second_recv) + let (first_task, first_step, _first_lease) = timeout(Duration::from_secs(1), first_recv) .await .unwrap() .unwrap() .unwrap() .unwrap(); + let (second_task, second_step, _second_lease) = + timeout(Duration::from_secs(1), second_recv) + .await + .unwrap() + .unwrap() + .unwrap() + .unwrap(); assert!(matches!(first_step, TestTask::Noop(Noop))); assert!(matches!(second_step, TestTask::Noop(Noop))); @@ -714,7 +786,7 @@ mod tests { assert!([first_id, second_id].contains(&first_task.id)); assert!([first_id, second_id].contains(&second_task.id)); - let running = sqlx::query!("SELECT id FROM pg_task WHERE is_running = true") + let running = sqlx::query!("SELECT id FROM pg_task WHERE locked_by IS NOT NULL") .fetch_all(&pool) .await .unwrap(); @@ -1009,7 +1081,7 @@ mod tests { .unwrap(); let invalid_row = sqlx::query!( - "SELECT tried, is_running, error FROM pg_task WHERE id = $1", + "SELECT tried, locked_by, lock_expires_at, error FROM pg_task WHERE id = $1", invalid_id, ) .fetch_one(&pool) @@ -1018,7 +1090,8 @@ mod tests { assert_eq!(state.state().events(), vec!["complete"]); assert_eq!(invalid_row.tried, 0); - assert!(!invalid_row.is_running); + assert!(invalid_row.locked_by.is_none()); + assert!(invalid_row.lock_expires_at.is_none()); assert!(invalid_row.error.is_some()); assert_eq!(task_count(&pool).await, 1); } From bf68a6d20b1d3456f37b4823bd5a9e4f1ef35309 Mon Sep 17 00:00:00 2001 From: imbolc Date: Sat, 9 May 2026 18:44:43 +0600 Subject: [PATCH 03/44] Renew task leases from worker heartbeat --- ...bb94fc5f78b9f934329c54515f7168b3b905b.json | 15 ++ migrations/20260509130000_task-leases.sql | 23 +++ src/task.rs | 80 +++++++++++ src/worker.rs | 
135 +++++++++++++++++- 4 files changed, 248 insertions(+), 5 deletions(-) create mode 100644 .sqlx/query-fdaa8469f4274846a0b397016a5bb94fc5f78b9f934329c54515f7168b3b905b.json create mode 100644 migrations/20260509130000_task-leases.sql diff --git a/.sqlx/query-fdaa8469f4274846a0b397016a5bb94fc5f78b9f934329c54515f7168b3b905b.json b/.sqlx/query-fdaa8469f4274846a0b397016a5bb94fc5f78b9f934329c54515f7168b3b905b.json new file mode 100644 index 0000000..a2ac801 --- /dev/null +++ b/.sqlx/query-fdaa8469f4274846a0b397016a5bb94fc5f78b9f934329c54515f7168b3b905b.json @@ -0,0 +1,15 @@ +{ + "db_name": "PostgreSQL", + "query": "\n UPDATE pg_task\n SET lock_expires_at = $2\n WHERE locked_by = $1\n AND lock_expires_at > now()\n ", + "describe": { + "columns": [], + "parameters": { + "Left": [ + "Uuid", + "Timestamptz" + ] + }, + "nullable": [] + }, + "hash": "fdaa8469f4274846a0b397016a5bb94fc5f78b9f934329c54515f7168b3b905b" +} diff --git a/migrations/20260509130000_task-leases.sql b/migrations/20260509130000_task-leases.sql new file mode 100644 index 0000000..ef9483c --- /dev/null +++ b/migrations/20260509130000_task-leases.sql @@ -0,0 +1,23 @@ +ALTER TABLE pg_task +ADD COLUMN locked_by UUID, +ADD COLUMN lock_expires_at timestamptz; + +UPDATE pg_task +SET locked_by = gen_random_uuid(), + lock_expires_at = now() +WHERE is_running = true; + +ALTER TABLE pg_task +ADD CONSTRAINT pg_task_lease_state_check CHECK ( + (locked_by IS NULL AND lock_expires_at IS NULL) + OR (locked_by IS NOT NULL AND lock_expires_at IS NOT NULL) +); + +CREATE INDEX pg_task_lock_expires_at_idx ON pg_task (lock_expires_at) +WHERE locked_by IS NOT NULL + AND error IS NULL; + +ALTER TABLE pg_task DROP COLUMN is_running; + +COMMENT ON COLUMN pg_task.locked_by IS 'Worker currently owning the running step lease'; +COMMENT ON COLUMN pg_task.lock_expires_at IS 'Time when the running step lease expires and can be reclaimed'; diff --git a/src/task.rs b/src/task.rs index e997306..370538c 100644 --- a/src/task.rs +++ 
b/src/task.rs @@ -120,6 +120,25 @@ impl Task { Ok(()) } + /// Renews all live task leases owned by a worker. + pub(crate) async fn renew_leases(db: &PgPool, lease: TaskLease) -> Result { + trace!("Renewing task leases for worker {}", lease.worker_id); + sqlx::query!( + r#" + UPDATE pg_task + SET lock_expires_at = $2 + WHERE locked_by = $1 + AND lock_expires_at > now() + "#, + lease.worker_id, + lease.expires_at(), + ) + .execute(db) + .await + .map(|result| result.rows_affected()) + .map_err(db_error!("renew leases")) + } + /// Deserializes the current task step and marks it running. /// If deserialization fails, stores the error instead and leaves the task /// non-running. @@ -828,6 +847,67 @@ mod tests { assert!(row.error.is_none()); } + #[sqlx::test(migrations = "./migrations")] + async fn renew_leases_extends_only_live_owned_leases(pool: PgPool) { + let now = Utc::now(); + let valid = serialized_step(&TestTask::Valid(Valid)); + let owned = insert_task_row( + &pool, + &valid, + now - ChronoDuration::seconds(1), + 0, + false, + None, + ) + .await; + let owned_expires_at = now + ChronoDuration::seconds(30); + set_task_lease(&pool, owned, worker_id(), owned_expires_at).await; + let expired = insert_task_row( + &pool, + &valid, + now - ChronoDuration::seconds(1), + 0, + false, + None, + ) + .await; + let expired_expires_at = now - ChronoDuration::seconds(1); + set_task_lease(&pool, expired, worker_id(), expired_expires_at).await; + let other_worker = insert_task_row( + &pool, + &valid, + now - ChronoDuration::seconds(1), + 0, + false, + None, + ) + .await; + let other_worker_expires_at = now + ChronoDuration::seconds(30); + set_task_lease( + &pool, + other_worker, + other_worker_id(), + other_worker_expires_at, + ) + .await; + + let started_at = Utc::now(); + let renewed = Task::renew_leases(&pool, task_lease()).await.unwrap(); + let finished_at = Utc::now(); + + assert_eq!(renewed, 1); + let owned = fetch_task_row(&pool, owned).await.unwrap(); + 
assert_timestamp_between( + owned.lock_expires_at.unwrap(), + started_at + ChronoDuration::seconds(60), + finished_at + ChronoDuration::seconds(61), + ); + let expired = fetch_task_row(&pool, expired).await.unwrap(); + assert!(expired.lock_expires_at.unwrap() < started_at); + let other_worker = fetch_task_row(&pool, other_worker).await.unwrap(); + assert!(other_worker.lock_expires_at.unwrap() < started_at + ChronoDuration::seconds(45)); + } + #[sqlx::test(migrations = "./migrations")] async fn fetch_ready_ignores_leased_errored_and_future_tasks_and_picks_the_earliest_ready_one( pool: PgPool, diff --git a/src/worker.rs b/src/worker.rs index a8228cb..06adb79 100644 --- a/src/worker.rs +++ b/src/worker.rs @@ -6,12 +6,16 @@ use crate::{ }; use sqlx::postgres::PgPool; use std::{marker::PhantomData, num::NonZeroUsize, sync::Arc, time::Duration}; -use tokio::{sync::Semaphore, time::sleep}; +use tokio::{ + sync::{mpsc, Semaphore}, + time::{interval, sleep, MissedTickBehavior}, +}; use tracing::{error, info, trace, warn}; use uuid::Uuid; const LOCKED_TASK_RECHECK_DELAY: Duration = Duration::from_millis(100); const DEFAULT_LEASE_TIMEOUT: Duration = Duration::from_secs(60); +const DEFAULT_HEARTBEAT_INTERVAL: Duration = Duration::from_secs(20); /// A worker for processing tasks pub struct Worker { @@ -21,6 +25,7 @@ pub struct Worker { concurrency: NonZeroUsize, worker_id: Uuid, lease_timeout: Duration, + heartbeat_interval: Duration, } impl + 'static> Worker { @@ -34,6 +39,7 @@ impl + 'static> Worker { concurrency, worker_id: Uuid::new_v4(), lease_timeout: DEFAULT_LEASE_TIMEOUT, + heartbeat_interval: DEFAULT_HEARTBEAT_INTERVAL, tasks: PhantomData, } } @@ -53,14 +59,34 @@ impl + 'static> Worker { self } + /// Sets how often the worker renews leases for its running tasks. 
+ pub fn with_heartbeat_interval(mut self, heartbeat_interval: Duration) -> Self { + assert!( + !heartbeat_interval.is_zero(), + "heartbeat interval must be non-zero" + ); + self.heartbeat_interval = heartbeat_interval; + self + } + /// Runs all ready tasks to completion and waits for new ones pub async fn run(&self) -> Result<()> { self.listener.listen(self.db.clone()).await?; let semaphore = Arc::new(Semaphore::new(self.concurrency.get())); + let heartbeat = self.spawn_heartbeat(); + let (step_error_sender, mut step_errors) = mpsc::unbounded_channel(); let result = loop { - match self.recv_task().await { + let received = tokio::select! { + biased; + + Some(error) = step_errors.recv() => { + break Err(error); + } + received = self.recv_task() => received, + }; + match received { Ok(Some((task, step, lease))) => { let permit = semaphore .clone() @@ -68,9 +94,11 @@ impl + 'static> Worker { .await .map_err(Error::UnreachableWorkerSemaphoreClosed)?; let db = self.db.clone(); + let step_error_sender = step_error_sender.clone(); tokio::spawn(async move { if let Err(e) = task.run_step(&db, step, lease).await { error!("[{}] {}", task.id, source_chain::to_string(&e)); + let _ = step_error_sender.send(e); }; drop(permit); }); @@ -85,7 +113,7 @@ impl + 'static> Worker { } } }; - self.finish_run(result, semaphore).await + self.finish_run(result, semaphore, heartbeat).await } /// Waits until the next task is ready, marks it running and returns it. 
@@ -151,11 +179,42 @@ impl + 'static> Worker { } } - async fn finish_run(&self, result: Result<()>, semaphore: Arc) -> Result<()> { + fn spawn_heartbeat(&self) -> tokio::task::AbortHandle { + let db = self.db.clone(); + let lease = TaskLease::new(self.worker_id, self.lease_timeout); + let mut heartbeat = interval(self.heartbeat_interval); + heartbeat.set_missed_tick_behavior(MissedTickBehavior::Delay); + tokio::spawn(async move { + loop { + heartbeat.tick().await; + match Task::renew_leases(&db, lease).await { + Ok(renewed) if renewed > 0 => { + trace!("Renewed {renewed} task leases"); + } + Ok(_) => {} + Err(error) => { + warn!( + "Task lease renewal failed:\n{}", + source_chain::to_string(&error) + ); + } + } + } + }) + .abort_handle() + } + + async fn finish_run( + &self, + result: Result<()>, + semaphore: Arc, + heartbeat: tokio::task::AbortHandle, + ) -> Result<()> { self.listener.shutdown(); // Drain in-flight steps before returning so a restarted worker can't // reclaim them as stale while they are still running. 
self.wait_for_steps_to_finish(semaphore).await; + heartbeat.abort(); if result.is_ok() { info!("Stopped"); } @@ -509,6 +568,28 @@ mod tests { .unwrap(); } + async fn fetch_task_lease(pool: &PgPool, id: Uuid) -> Option<(Uuid, chrono::DateTime)> { + sqlx::query!( + " + SELECT locked_by, lock_expires_at + FROM pg_task + WHERE id = $1 + ", + id, + ) + .fetch_optional(pool) + .await + .unwrap() + .map(|row| (row.locked_by.unwrap(), row.lock_expires_at.unwrap())) + } + + fn idle_heartbeat() -> tokio::task::AbortHandle { + tokio::spawn(async { + std::future::pending::<()>().await; + }) + .abort_handle() + } + async fn connect_to_current_db( pool: &PgPool, max_connections: u32, @@ -793,6 +874,49 @@ mod tests { assert_eq!(running.len(), 2); } + #[sqlx::test(migrations = "./migrations")] + async fn run_renews_leases_for_running_tasks(pool: PgPool) { + let state = StepStateGuard::new(); + let id = insert_task_at( + &pool, + &TestTask::Blocking(Blocking { key: state.key() }), + Utc::now() - ChronoDuration::milliseconds(1), + false, + ) + .await; + + let worker = tokio::spawn({ + let pool = pool.clone(); + async move { + Worker::::new(pool) + .with_concurrency(nonzero(1)) + .with_lease_timeout(Duration::from_millis(200)) + .with_heartbeat_interval(Duration::from_millis(50)) + .run() + .await + } + }); + + state.state().wait_for_events(1).await; + let (locked_by, initial_expires_at) = fetch_task_lease(&pool, id).await.unwrap(); + + sleep(Duration::from_millis(350)).await; + + let (renewed_by, renewed_expires_at) = fetch_task_lease(&pool, id).await.unwrap(); + assert_eq!(renewed_by, locked_by); + assert!(renewed_expires_at > initial_expires_at); + assert!(renewed_expires_at > Utc::now()); + + stop_worker(&pool).await; + state.state().release(); + + timeout(Duration::from_secs(1), worker) + .await + .unwrap() + .unwrap() + .unwrap(); + } + #[tokio::test] async fn finish_run_waits_for_inflight_steps_before_returning_errors() { init_tracing(); @@ -817,6 +941,7 @@ mod tests { 
"listener failed".into(), ))), semaphore, + idle_heartbeat(), ) .await } @@ -1127,7 +1252,7 @@ mod tests { } #[sqlx::test(migrations = "./migrations")] - async fn run_returns_fetch_errors_after_spawned_step_persistence_errors(pool: PgPool) { + async fn run_returns_spawned_step_persistence_errors(pool: PgPool) { let state = StepStateGuard::new(); insert_task( &pool, From 753210461a2935fa0da0e6450f92a472c9ab5f42 Mon Sep 17 00:00:00 2001 From: imbolc Date: Sat, 9 May 2026 22:33:33 +0600 Subject: [PATCH 04/44] Stop workers when task lease heartbeats expire --- src/worker.rs | 258 ++++++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 250 insertions(+), 8 deletions(-) diff --git a/src/worker.rs b/src/worker.rs index 06adb79..af0d6b8 100644 --- a/src/worker.rs +++ b/src/worker.rs @@ -5,7 +5,12 @@ use crate::{ Error, Result, Step, LOST_CONNECTION_SLEEP, }; use sqlx::postgres::PgPool; -use std::{marker::PhantomData, num::NonZeroUsize, sync::Arc, time::Duration}; +use std::{ + marker::PhantomData, + num::NonZeroUsize, + sync::{Arc, Mutex}, + time::{Duration, Instant}, +}; use tokio::{ sync::{mpsc, Semaphore}, time::{interval, sleep, MissedTickBehavior}, @@ -17,6 +22,12 @@ const LOCKED_TASK_RECHECK_DELAY: Duration = Duration::from_millis(100); const DEFAULT_LEASE_TIMEOUT: Duration = Duration::from_secs(60); const DEFAULT_HEARTBEAT_INTERVAL: Duration = Duration::from_secs(20); +enum HeartbeatEvent { + Failed, + Recovered, + Expired(Error), +} + /// A worker for processing tasks pub struct Worker { db: PgPool, @@ -74,8 +85,12 @@ impl + 'static> Worker { self.listener.listen(self.db.clone()).await?; let semaphore = Arc::new(Semaphore::new(self.concurrency.get())); - let heartbeat = self.spawn_heartbeat(); + let running_steps = Arc::new(Mutex::new(Vec::new())); + let (heartbeat_events_sender, mut heartbeat_events) = mpsc::unbounded_channel(); + let heartbeat = self.spawn_heartbeat(heartbeat_events_sender, running_steps.clone()); let (step_error_sender, mut 
step_errors) = mpsc::unbounded_channel(); + let mut heartbeat_healthy = true; + let mut abort_running_steps = false; let result = loop { let received = tokio::select! { @@ -84,7 +99,33 @@ impl + 'static> Worker { Some(error) = step_errors.recv() => { break Err(error); } - received = self.recv_task() => received, + Some(event) = heartbeat_events.recv() => { + if let Err(error) = Self::handle_heartbeat_event(event, &mut heartbeat_healthy) { + abort_running_steps = true; + break Err(error); + } + continue; + } + _ = sleep(LOCKED_TASK_RECHECK_DELAY), if !heartbeat_healthy => { + if self.listener.time_to_stop_worker() { + break Ok(()); + } + if let Some(error) = self.listener.take_error() { + if let Err(error) = self + .handle_recv_task_error_or_heartbeat( + error, + &mut heartbeat_events, + &mut heartbeat_healthy, + &mut abort_running_steps, + ) + .await + { + break Err(error); + } + } + continue; + } + received = self.recv_task(), if heartbeat_healthy => received, }; match received { Ok(Some((task, step, lease))) => { @@ -95,25 +136,41 @@ impl + 'static> Worker { .map_err(Error::UnreachableWorkerSemaphoreClosed)?; let db = self.db.clone(); let step_error_sender = step_error_sender.clone(); - tokio::spawn(async move { + let step = tokio::spawn(async move { if let Err(e) = task.run_step(&db, step, lease).await { error!("[{}] {}", task.id, source_chain::to_string(&e)); let _ = step_error_sender.send(e); }; drop(permit); }); + Self::track_running_step(&running_steps, step.abort_handle()); } Ok(None) => { break Ok(()); } Err(e) => { - if let Err(error) = self.handle_recv_task_error(e).await { + if let Err(error) = self + .handle_recv_task_error_or_heartbeat( + e, + &mut heartbeat_events, + &mut heartbeat_healthy, + &mut abort_running_steps, + ) + .await + { break Err(error); } } } }; - self.finish_run(result, semaphore, heartbeat).await + self.finish_run( + result, + semaphore, + heartbeat, + running_steps, + abort_running_steps, + ) + .await } /// Waits until the next 
task is ready, marks it running and returns it. @@ -157,6 +214,49 @@ impl + 'static> Worker { } } + fn handle_heartbeat_event(event: HeartbeatEvent, heartbeat_healthy: &mut bool) -> Result<()> { + match event { + HeartbeatEvent::Failed => { + if *heartbeat_healthy { + warn!("Task fetching paused because task leases are not renewing"); + } + *heartbeat_healthy = false; + Ok(()) + } + HeartbeatEvent::Recovered => { + if !*heartbeat_healthy { + warn!("Task lease renewal recovered; task fetching resumed"); + } + *heartbeat_healthy = true; + Ok(()) + } + HeartbeatEvent::Expired(error) => Err(error), + } + } + + async fn handle_recv_task_error_or_heartbeat( + &self, + error: Error, + heartbeat_events: &mut mpsc::UnboundedReceiver, + heartbeat_healthy: &mut bool, + abort_running_steps: &mut bool, + ) -> Result<()> { + let handle_error = self.handle_recv_task_error(error); + tokio::pin!(handle_error); + + loop { + tokio::select! { + result = &mut handle_error => return result, + Some(event) = heartbeat_events.recv() => { + if let Err(error) = Self::handle_heartbeat_event(event, heartbeat_healthy) { + *abort_running_steps = true; + return Err(error); + } + } + } + } + } + async fn handle_recv_task_error(&self, error: Error) -> Result<()> { if matches!(&error, Error::Db(db_error, _) if is_connection_error(db_error)) { warn!( @@ -179,24 +279,55 @@ impl + 'static> Worker { } } - fn spawn_heartbeat(&self) -> tokio::task::AbortHandle { + fn spawn_heartbeat( + &self, + events: mpsc::UnboundedSender, + running_steps: Arc>>, + ) -> tokio::task::AbortHandle { let db = self.db.clone(); let lease = TaskLease::new(self.worker_id, self.lease_timeout); let mut heartbeat = interval(self.heartbeat_interval); + let heartbeat_interval = self.heartbeat_interval; + let lease_timeout = self.lease_timeout; heartbeat.set_missed_tick_behavior(MissedTickBehavior::Delay); tokio::spawn(async move { + let mut last_renewed_at = Instant::now(); + let mut renewal_failed = false; + heartbeat.tick().await; 
loop { heartbeat.tick().await; match Task::renew_leases(&db, lease).await { Ok(renewed) if renewed > 0 => { trace!("Renewed {renewed} task leases"); + last_renewed_at = Instant::now(); + if renewal_failed { + let _ = events.send(HeartbeatEvent::Recovered); + renewal_failed = false; + } + } + Ok(_) => { + last_renewed_at = Instant::now(); + if renewal_failed { + let _ = events.send(HeartbeatEvent::Recovered); + renewal_failed = false; + } } - Ok(_) => {} Err(error) => { warn!( "Task lease renewal failed:\n{}", source_chain::to_string(&error) ); + if !renewal_failed { + let _ = events.send(HeartbeatEvent::Failed); + renewal_failed = true; + } + if Self::has_running_steps(&running_steps) + && last_renewed_at.elapsed().saturating_add(heartbeat_interval) + >= lease_timeout + { + let _ = events.send(HeartbeatEvent::Expired(error)); + break; + } } } } @@ -204,13 +335,47 @@ impl + 'static> Worker { .abort_handle() } + fn track_running_step( + running_steps: &Mutex>, + step: tokio::task::AbortHandle, + ) { + let mut running_steps = running_steps + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner); + running_steps.retain(|step| !step.is_finished()); + running_steps.push(step); + } + + fn abort_running_steps(running_steps: &Mutex>) { + let running_steps = running_steps + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner); + for step in &*running_steps { + step.abort(); + } + } + + fn has_running_steps(running_steps: &Mutex>) -> bool { + running_steps + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner) + .iter() + .any(|step| !step.is_finished()) + } + async fn finish_run( &self, result: Result<()>, semaphore: Arc, heartbeat: tokio::task::AbortHandle, + running_steps: Arc>>, + abort_running_steps: bool, ) -> Result<()> { self.listener.shutdown(); + if abort_running_steps { + heartbeat.abort(); + Self::abort_running_steps(&running_steps); + } // Drain in-flight steps before returning so a restarted worker can't // reclaim them as stale while 
they are still running. self.wait_for_steps_to_finish(semaphore).await; @@ -917,6 +1082,81 @@ mod tests { .unwrap(); } + #[sqlx::test(migrations = "./migrations")] + async fn run_aborts_running_steps_when_heartbeat_cannot_renew_before_the_lease_expires( + pool: PgPool, + ) { + let state = StepStateGuard::new(); + insert_task( + &pool, + &TestTask::Blocking(Blocking { key: state.key() }), + false, + ) + .await; + + let worker = tokio::spawn({ + let pool = pool.clone(); + async move { + Worker::::new(pool) + .with_concurrency(nonzero(1)) + .with_lease_timeout(Duration::from_millis(250)) + .with_heartbeat_interval(Duration::from_millis(100)) + .run() + .await + } + }); + + state.state().wait_for_events(1).await; + sleep(Duration::from_millis(50)).await; + sqlx::query!("ALTER TABLE pg_task RENAME COLUMN lock_expires_at TO task_lock_expires_at") + .execute(&pool) + .await + .unwrap(); + + let err = timeout(Duration::from_secs(2), worker) + .await + .unwrap() + .unwrap() + .unwrap_err(); + + assert_eq!(state.state().events(), vec!["started"]); + assert!(matches!(err, Error::Db(sqlx::Error::Database(_), _))); + } + + #[sqlx::test(migrations = "./migrations")] + async fn run_pauses_fetching_while_heartbeat_cannot_renew(pool: PgPool) { + let worker_pool = connect_to_current_db(&pool, 1, Duration::from_millis(20)).await; + let state = StepStateGuard::new(); + + let worker = tokio::spawn(async move { + Worker::::new(worker_pool) + .with_concurrency(nonzero(1)) + .with_lease_timeout(Duration::from_millis(200)) + .with_heartbeat_interval(Duration::from_millis(50)) + .run() + .await + }); + + sleep(Duration::from_millis(100)).await; + insert_task( + &pool, + &TestTask::Complete(Complete { key: state.key() }), + false, + ) + .await; + sleep(Duration::from_millis(150)).await; + + stop_worker(&pool).await; + + timeout(Duration::from_secs(2), worker) + .await + .unwrap() + .unwrap() + .unwrap(); + + assert!(state.state().events().is_empty()); + } + #[tokio::test] async fn 
finish_run_waits_for_inflight_steps_before_returning_errors() { init_tracing(); @@ -942,6 +1182,8 @@ mod tests { ))), semaphore, idle_heartbeat(), + Arc::new(Mutex::new(Vec::new())), + false, ) .await } From 31827266a1cf110c43e815d15e12decfac12dc0d Mon Sep 17 00:00:00 2001 From: imbolc Date: Sat, 9 May 2026 22:38:41 +0600 Subject: [PATCH 05/44] Stop workers when task lease heartbeats expire --- Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Cargo.toml b/Cargo.toml index 68e707f..d4ce06a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -23,7 +23,7 @@ sqlx = { version = "0.8", features = [ "uuid", ] } thiserror = "2" -tokio = "1" +tokio = { version = "1", features = ["macros"] } tracing = "0.1" uuid = { version = "1", features = ["v4"] } From 9b5a5acdf3f69b002bf63f3e7ae3d0a2e3eb1e61 Mon Sep 17 00:00:00 2001 From: imbolc Date: Sat, 9 May 2026 22:50:22 +0600 Subject: [PATCH 06/44] Cover heartbeat recovery and reject zero lease timeouts --- src/worker.rs | 75 +++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 75 insertions(+) diff --git a/src/worker.rs b/src/worker.rs index af0d6b8..e5ec185 100644 --- a/src/worker.rs +++ b/src/worker.rs @@ -66,6 +66,7 @@ impl + 'static> Worker { /// If a worker cannot renew this lease before it expires, another worker /// may reclaim the task. 
pub fn with_lease_timeout(mut self, lease_timeout: Duration) -> Self { + assert!(!lease_timeout.is_zero(), "lease timeout must be non-zero"); self.lease_timeout = lease_timeout; self } @@ -808,6 +809,28 @@ mod tests { }) } + #[tokio::test] + #[should_panic(expected = "lease timeout must be non-zero")] + async fn with_lease_timeout_rejects_zero() { + Worker::::new( + PgPoolOptions::new() + .connect_lazy("postgres:///pg_task") + .unwrap(), + ) + .with_lease_timeout(Duration::ZERO); + } + + #[tokio::test] + #[should_panic(expected = "heartbeat interval must be non-zero")] + async fn with_heartbeat_interval_rejects_zero() { + Worker::::new( + PgPoolOptions::new() + .connect_lazy("postgres:///pg_task") + .unwrap(), + ) + .with_heartbeat_interval(Duration::ZERO); + } + #[sqlx::test(migrations = "./migrations")] async fn run_returns_listener_startup_errors(pool: PgPool) { let worker = Worker::::new(pool); @@ -1157,6 +1180,58 @@ mod tests { assert!(state.state().events().is_empty()); } + #[sqlx::test(migrations = "./migrations")] + async fn run_resumes_fetching_after_heartbeat_recovers(pool: PgPool) { + let worker_pool = connect_to_current_db(&pool, 2, Duration::from_millis(20)).await; + let held_connection = worker_pool.acquire().await.unwrap(); + let state = StepStateGuard::new(); + + let worker = tokio::spawn({ + let worker_pool = worker_pool.clone(); + async move { + Worker::::new(worker_pool) + .with_concurrency(nonzero(1)) + .with_lease_timeout(Duration::from_millis(300)) + .with_heartbeat_interval(Duration::from_millis(50)) + .run() + .await + } + }); + + sleep(Duration::from_millis(100)).await; + insert_task( + &pool, + &TestTask::Complete(Complete { key: state.key() }), + false, + ) + .await; + sleep(Duration::from_millis(150)).await; + + assert!(state.state().events().is_empty()); + drop(held_connection); + + timeout(Duration::from_secs(3), async { + loop { + if !state.state().events().is_empty() { + break; + } + sleep(Duration::from_millis(10)).await; + } + }) + 
.await + .unwrap(); + + stop_worker(&pool).await; + + timeout(Duration::from_secs(2), worker) + .await + .unwrap() + .unwrap() + .unwrap(); + + assert_eq!(state.state().events(), vec!["complete"]); + } + #[tokio::test] async fn finish_run_waits_for_inflight_steps_before_returning_errors() { init_tracing(); From 58585d492cbeb8dc5802556fc59188ca149ef5d4 Mon Sep 17 00:00:00 2001 From: imbolc Date: Sat, 9 May 2026 22:56:12 +0600 Subject: [PATCH 07/44] Cover aborting steps after lease renewal expiry --- src/worker.rs | 35 +++++++++++++++++++++++++++++++++++ 1 file changed, 35 insertions(+) diff --git a/src/worker.rs b/src/worker.rs index e5ec185..5095e74 100644 --- a/src/worker.rs +++ b/src/worker.rs @@ -1276,6 +1276,41 @@ mod tests { )); } + #[tokio::test] + async fn finish_run_aborts_inflight_steps_when_lease_renewal_expires() { + init_tracing(); + let worker = Worker::::new( + PgPoolOptions::new() + .connect_lazy("postgres:///pg_task") + .unwrap(), + ) + .with_concurrency(nonzero(1)); + let semaphore = Arc::new(Semaphore::new(1)); + let permit = semaphore.clone().acquire_owned().await.unwrap(); + let running_step = tokio::spawn(async move { + let _permit = permit; + std::future::pending::<()>().await; + }); + let running_steps = Arc::new(Mutex::new(vec![running_step.abort_handle()])); + + let err = timeout( + Duration::from_secs(1), + worker.finish_run( + Err(Error::Db(sqlx::Error::PoolTimedOut, "test".into())), + semaphore, + idle_heartbeat(), + running_steps, + true, + ), + ) + .await + .unwrap() + .unwrap_err(); + + assert!(matches!(err, Error::Db(sqlx::Error::PoolTimedOut, _))); + assert!(running_step.await.unwrap_err().is_cancelled()); + } + #[tokio::test] async fn wait_for_steps_to_finish_rechecks_when_the_inflight_task_count_changes() { init_tracing(); From 3e2d50f0360de6d5e85703185cbcfcc564426bb7 Mon Sep 17 00:00:00 2001 From: imbolc Date: Sat, 9 May 2026 22:58:19 +0600 Subject: [PATCH 08/44] Cover heartbeat lifetime during graceful shutdown --- src/worker.rs 
| 48 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 48 insertions(+) diff --git a/src/worker.rs b/src/worker.rs index 5095e74..6f9d70e 100644 --- a/src/worker.rs +++ b/src/worker.rs @@ -1276,6 +1276,54 @@ mod tests { )); } + #[tokio::test] + async fn finish_run_keeps_heartbeat_alive_while_waiting_for_inflight_steps() { + init_tracing(); + let worker = Arc::new( + Worker::::new( + PgPoolOptions::new() + .connect_lazy("postgres:///pg_task") + .unwrap(), + ) + .with_concurrency(nonzero(1)), + ); + let semaphore = Arc::new(Semaphore::new(1)); + let permit = semaphore.clone().acquire_owned().await.unwrap(); + let heartbeat = tokio::spawn(async { + std::future::pending::<()>().await; + }); + + let finish = tokio::spawn({ + let worker = worker.clone(); + let semaphore = semaphore.clone(); + let heartbeat_abort = heartbeat.abort_handle(); + async move { + worker + .finish_run( + Ok(()), + semaphore, + heartbeat_abort, + Arc::new(Mutex::new(Vec::new())), + false, + ) + .await + } + }); + + sleep(Duration::from_millis(50)).await; + assert!(!finish.is_finished()); + assert!(!heartbeat.is_finished()); + + drop(permit); + + timeout(Duration::from_secs(1), finish) + .await + .unwrap() + .unwrap() + .unwrap(); + assert!(heartbeat.await.unwrap_err().is_cancelled()); + } + #[tokio::test] async fn finish_run_aborts_inflight_steps_when_lease_renewal_expires() { init_tracing(); From 1afe7d3104c76737e0822654460447748c81d0bc Mon Sep 17 00:00:00 2001 From: imbolc Date: Sat, 9 May 2026 23:03:47 +0600 Subject: [PATCH 09/44] Add worker heartbeat and task availability tests --- src/task.rs | 40 +++++++++++++++++++++++ src/worker.rs | 89 +++++++++++++++++++++++++++++++++++++++++++++++++-- 2 files changed, 127 insertions(+), 2 deletions(-) diff --git a/src/task.rs b/src/task.rs index 370538c..5708607 100644 --- a/src/task.rs +++ b/src/task.rs @@ -711,6 +711,19 @@ mod tests { ); } + #[test] + fn delay_until_returns_none_for_ready_times() { + 
assert!(Task::delay_until(Utc::now() - ChronoDuration::milliseconds(1)).is_none()); + } + + #[test] + fn delay_until_returns_duration_for_future_times() { + let delay = Task::delay_until(Utc::now() + ChronoDuration::milliseconds(250)).unwrap(); + + assert!(delay <= Duration::from_millis(250)); + assert!(delay > Duration::ZERO); + } + #[sqlx::test(migrations = "./migrations")] async fn claim_marks_invalid_steps_errored(pool: PgPool) { sqlx::query!( @@ -996,6 +1009,33 @@ mod tests { assert_eq!(task.id, expected); } + #[sqlx::test(migrations = "./migrations")] + async fn fetch_next_available_at_returns_none_when_no_tasks_are_visible(pool: PgPool) { + let mut tx = pool.begin().await.unwrap(); + assert!(Task::fetch_next_available_at(&mut tx) + .await + .unwrap() + .is_none()); + tx.commit().await.unwrap(); + + insert_task_row( + &pool, + &serialized_step(&TestTask::Valid(Valid)), + Utc::now() - ChronoDuration::seconds(1), + 0, + false, + Some("boom"), + ) + .await; + + let mut tx = pool.begin().await.unwrap(); + assert!(Task::fetch_next_available_at(&mut tx) + .await + .unwrap() + .is_none()); + tx.commit().await.unwrap(); + } + #[sqlx::test(migrations = "./migrations")] async fn fetch_next_available_at_returns_the_earliest_visible_eligible_task(pool: PgPool) { let now = Utc::now(); diff --git a/src/worker.rs b/src/worker.rs index 6f9d70e..3041643 100644 --- a/src/worker.rs +++ b/src/worker.rs @@ -412,7 +412,7 @@ impl + 'static> Worker { #[cfg(test)] mod tests { - use super::Worker; + use super::{HeartbeatEvent, Worker}; use crate::{Error, NextStep, Step}; use chrono::{Duration as ChronoDuration, Utc}; use sqlx::{postgres::PgPoolOptions, PgPool}; @@ -427,7 +427,7 @@ mod tests { time::Duration, }; use tokio::{ - sync::{Notify, Semaphore}, + sync::{mpsc, Notify, Semaphore}, time::{sleep, timeout}, }; use uuid::Uuid; @@ -831,6 +831,91 @@ mod tests { .with_heartbeat_interval(Duration::ZERO); } + #[test] + fn heartbeat_events_pause_resume_and_expire_fetching() { + let mut 
heartbeat_healthy = true; + Worker::::handle_heartbeat_event(HeartbeatEvent::Failed, &mut heartbeat_healthy) + .unwrap(); + assert!(!heartbeat_healthy); + + Worker::::handle_heartbeat_event( + HeartbeatEvent::Recovered, + &mut heartbeat_healthy, + ) + .unwrap(); + assert!(heartbeat_healthy); + + let err = Worker::::handle_heartbeat_event( + HeartbeatEvent::Expired(Error::Db(sqlx::Error::PoolTimedOut, "test".into())), + &mut heartbeat_healthy, + ) + .unwrap_err(); + assert!(matches!(err, Error::Db(sqlx::Error::PoolTimedOut, _))); + } + + #[tokio::test] + async fn heartbeat_expiry_interrupts_retryable_fetch_error_handling() { + init_tracing(); + let worker = Worker::::new( + PgPoolOptions::new() + .connect_lazy("postgres:///pg_task") + .unwrap(), + ); + let (heartbeat_events, mut heartbeat_events_receiver) = mpsc::unbounded_channel(); + heartbeat_events + .send(HeartbeatEvent::Expired(Error::Db( + sqlx::Error::PoolTimedOut, + "heartbeat".into(), + ))) + .unwrap(); + let mut heartbeat_healthy = true; + let mut abort_running_steps = false; + + let err = timeout( + Duration::from_millis(100), + worker.handle_recv_task_error_or_heartbeat( + Error::Db(sqlx::Error::PoolTimedOut, "fetch".into()), + &mut heartbeat_events_receiver, + &mut heartbeat_healthy, + &mut abort_running_steps, + ), + ) + .await + .unwrap() + .unwrap_err(); + + assert!(matches!(err, Error::Db(sqlx::Error::PoolTimedOut, _))); + assert!(abort_running_steps); + } + + #[tokio::test] + async fn running_step_tracking_prunes_finished_steps() { + let finished_step = tokio::spawn(async {}); + let finished_step_abort = finished_step.abort_handle(); + finished_step.await.unwrap(); + + let running_step = tokio::spawn(async { + std::future::pending::<()>().await; + }); + let running_step_abort = running_step.abort_handle(); + let running_steps = Mutex::new(vec![finished_step_abort]); + + Worker::::track_running_step(&running_steps, running_step_abort); + + assert_eq!( + running_steps + .lock() + 
.unwrap_or_else(std::sync::PoisonError::into_inner) + .len(), + 1, + ); + assert!(Worker::::has_running_steps(&running_steps)); + + Worker::::abort_running_steps(&running_steps); + assert!(running_step.await.unwrap_err().is_cancelled()); + assert!(!Worker::::has_running_steps(&running_steps)); + } + #[sqlx::test(migrations = "./migrations")] async fn run_returns_listener_startup_errors(pool: PgPool) { let worker = Worker::::new(pool); From f1abc79ad7007bf8010f652196199751be24bb4c Mon Sep 17 00:00:00 2001 From: imbolc Date: Sat, 9 May 2026 23:11:50 +0600 Subject: [PATCH 10/44] Expand lease and worker wakeup test coverage --- src/listener.rs | 14 +++++ src/macros.rs | 30 +++++++++- src/task.rs | 156 ++++++++++++++++++++++++++++++++++++++++++++++++ src/worker.rs | 47 +++++++++++++++ 4 files changed, 246 insertions(+), 1 deletion(-) diff --git a/src/listener.rs b/src/listener.rs index d61afdd..fd7106c 100644 --- a/src/listener.rs +++ b/src/listener.rs @@ -344,6 +344,20 @@ mod tests { assert!(listener.time_to_stop_worker()); } + #[tokio::test] + async fn wait_for_returns_when_the_timeout_expires_without_a_wakeup() { + let listener = Listener::new(); + + timeout( + Duration::from_millis(100), + listener.subscribe().wait_for(Duration::from_millis(10)), + ) + .await + .unwrap(); + assert!(!listener.time_to_stop_worker()); + assert!(listener.take_error().is_none()); + } + #[tokio::test] async fn pool_timeouts_do_not_become_terminal_listener_errors() { init_tracing(); diff --git a/src/macros.rs b/src/macros.rs index a3aff73..6e18058 100644 --- a/src/macros.rs +++ b/src/macros.rs @@ -64,7 +64,14 @@ mod tests { #[derive(Debug, serde::Deserialize, serde::Serialize)] pub(super) struct Second; - crate::task!(MacroTask { First, Second }); + #[derive(Debug, serde::Deserialize, serde::Serialize)] + pub(super) struct Third; + + crate::task!(MacroTask { + First, + Second, + Third + }); crate::scheduler!(MacroScheduler { MacroTask }); #[async_trait::async_trait] @@ -84,6 +91,13 @@ mod 
tests { } } + #[async_trait::async_trait] + impl crate::Step for Third { + async fn step(self, _db: &PgPool) -> crate::StepResult { + crate::NextStep::now(Second) + } + } + fn assert_scheduler() {} #[test] @@ -132,4 +146,18 @@ mod tests { crate::NextStep::None )); } + + #[tokio::test] + async fn task_macro_forwards_immediate_steps() { + let pool = PgPoolOptions::new() + .connect_lazy("postgres:///pg_task") + .unwrap(); + + assert!(matches!( + crate::Step::::step(MacroTask::Third(Third), &pool) + .await + .unwrap(), + crate::NextStep::Now(MacroTask::Second(Second)) + )); + } } diff --git a/src/task.rs b/src/task.rs index 5708607..d0b5357 100644 --- a/src/task.rs +++ b/src/task.rs @@ -1095,6 +1095,70 @@ mod tests { assert_eq!(wakeup_at, row.wakeup_at); } + #[sqlx::test(migrations = "./migrations")] + async fn fetch_next_available_at_returns_lease_expiry_for_ready_leased_tasks(pool: PgPool) { + let now = Utc::now(); + let valid = serialized_step(&TestTask::Valid(Valid)); + let expected = insert_task_row( + &pool, + &valid, + now - ChronoDuration::seconds(1), + 0, + true, + None, + ) + .await; + set_task_lease( + &pool, + expected, + other_worker_id(), + now + ChronoDuration::seconds(5), + ) + .await; + + let mut tx = pool.begin().await.unwrap(); + let wakeup_at = Task::fetch_next_available_at(&mut tx) + .await + .unwrap() + .unwrap(); + tx.commit().await.unwrap(); + + let row = fetch_task_row(&pool, expected).await.unwrap(); + assert_eq!(wakeup_at, row.lock_expires_at.unwrap()); + } + + #[sqlx::test(migrations = "./migrations")] + async fn fetch_next_available_at_keeps_future_leased_tasks_delayed_until_wakeup(pool: PgPool) { + let now = Utc::now(); + let valid = serialized_step(&TestTask::Valid(Valid)); + let expected = insert_task_row( + &pool, + &valid, + now + ChronoDuration::seconds(5), + 0, + true, + None, + ) + .await; + set_task_lease( + &pool, + expected, + other_worker_id(), + now + ChronoDuration::seconds(1), + ) + .await; + + let mut tx = 
pool.begin().await.unwrap(); + let wakeup_at = Task::fetch_next_available_at(&mut tx) + .await + .unwrap() + .unwrap(); + tx.commit().await.unwrap(); + + let row = fetch_task_row(&pool, expected).await.unwrap(); + assert_eq!(wakeup_at, row.wakeup_at); + } + #[sqlx::test(migrations = "./migrations")] async fn run_step_completes_tasks(pool: PgPool) { init_tracing(); @@ -1198,6 +1262,31 @@ mod tests { ); } + #[sqlx::test(migrations = "./migrations")] + async fn run_step_does_not_save_next_step_after_losing_the_lease(pool: PgPool) { + init_tracing(); + let (task, step, lease) = + claim_task(&pool, TestTask::AdvanceNow(AdvanceNow { value: 41 }), 2).await; + set_task_lease( + &pool, + task.id, + other_worker_id(), + Utc::now() + ChronoDuration::seconds(60), + ) + .await; + + task.run_step(&pool, step, lease).await.unwrap(); + + let row = fetch_task_row(&pool, task.id).await.unwrap(); + assert_eq!( + row.step, + serialized_step(&TestTask::AdvanceNow(AdvanceNow { value: 41 })), + ); + assert_eq!(row.tried, 2); + assert_eq!(row.locked_by, Some(other_worker_id())); + assert!(row.error.is_none()); + } + #[sqlx::test(migrations = "./migrations")] async fn run_step_returns_db_errors_when_saving_next_step_fails(pool: PgPool) { let step = TestTask::AdvanceNow(AdvanceNow { value: 41 }); @@ -1237,6 +1326,27 @@ mod tests { ); } + #[sqlx::test(migrations = "./migrations")] + async fn run_step_does_not_schedule_retries_after_losing_the_lease(pool: PgPool) { + init_tracing(); + let (task, step, lease) = claim_task(&pool, TestTask::RetryFail(RetryFail), 1).await; + set_task_lease( + &pool, + task.id, + other_worker_id(), + Utc::now() + ChronoDuration::seconds(60), + ) + .await; + + task.run_step(&pool, step, lease).await.unwrap(); + + let row = fetch_task_row(&pool, task.id).await.unwrap(); + assert_eq!(row.step, serialized_step(&TestTask::RetryFail(RetryFail))); + assert_eq!(row.tried, 1); + assert_eq!(row.locked_by, Some(other_worker_id())); + assert!(row.error.is_none()); + } + 
#[sqlx::test(migrations = "./migrations")] async fn run_step_returns_db_errors_when_retrying_fails(pool: PgPool) { let step = TestTask::RetryFail(RetryFail); @@ -1279,6 +1389,29 @@ mod tests { ); } + #[sqlx::test(migrations = "./migrations")] + async fn run_step_does_not_save_terminal_errors_after_losing_the_lease(pool: PgPool) { + init_tracing(); + let retry_limit = >::RETRY_LIMIT; + let (task, step, lease) = + claim_task(&pool, TestTask::RetryFail(RetryFail), retry_limit).await; + set_task_lease( + &pool, + task.id, + other_worker_id(), + Utc::now() + ChronoDuration::seconds(60), + ) + .await; + + task.run_step(&pool, step, lease).await.unwrap(); + + let row = fetch_task_row(&pool, task.id).await.unwrap(); + assert_eq!(row.step, serialized_step(&TestTask::RetryFail(RetryFail))); + assert_eq!(row.tried, retry_limit); + assert_eq!(row.locked_by, Some(other_worker_id())); + assert!(row.error.is_none()); + } + #[sqlx::test(migrations = "./migrations")] async fn run_step_returns_db_errors_when_saving_terminal_errors_fails(pool: PgPool) { let step = TestTask::RetryFail(RetryFail); @@ -1311,4 +1444,27 @@ mod tests { .as_deref() .is_some_and(|error| error.contains("can't serialize test step"))); } + + #[sqlx::test(migrations = "./migrations")] + async fn run_step_does_not_save_next_step_serialization_errors_after_losing_the_lease( + pool: PgPool, + ) { + init_tracing(); + let (task, step, lease) = claim_task(&pool, TestTask::BrokenNext(BrokenNext), 0).await; + set_task_lease( + &pool, + task.id, + other_worker_id(), + Utc::now() + ChronoDuration::seconds(60), + ) + .await; + + task.run_step(&pool, step, lease).await.unwrap(); + + let row = fetch_task_row(&pool, task.id).await.unwrap(); + assert_eq!(row.step, serialized_step(&TestTask::BrokenNext(BrokenNext)),); + assert_eq!(row.tried, 0); + assert_eq!(row.locked_by, Some(other_worker_id())); + assert!(row.error.is_none()); + } } diff --git a/src/worker.rs b/src/worker.rs index 3041643..66716a6 100644 --- a/src/worker.rs 
+++ b/src/worker.rs @@ -1037,6 +1037,53 @@ mod tests { } } + #[sqlx::test(migrations = "./migrations")] + async fn recv_task_stops_while_waiting_for_work(pool: PgPool) { + let worker = Arc::new(Worker::::new(pool)); + let recv = tokio::spawn({ + let worker = worker.clone(); + async move { worker.recv_task().await } + }); + + sleep(Duration::from_millis(50)).await; + assert!(!recv.is_finished()); + worker.listener.stop_worker_for_tests(); + + let received = timeout(Duration::from_secs(1), recv) + .await + .unwrap() + .unwrap() + .unwrap(); + assert!(received.is_none()); + } + + #[sqlx::test(migrations = "./migrations")] + async fn recv_task_returns_listener_errors_while_waiting_for_work(pool: PgPool) { + let worker = Arc::new(Worker::::new(pool)); + let recv = tokio::spawn({ + let worker = worker.clone(); + async move { worker.recv_task().await } + }); + + sleep(Duration::from_millis(50)).await; + assert!(!recv.is_finished()); + worker + .listener + .set_error_and_notify_for_tests(Error::ListenerReceive(sqlx::Error::Protocol( + "listener failed".into(), + ))); + + let err = timeout(Duration::from_secs(1), recv) + .await + .unwrap() + .unwrap() + .unwrap_err(); + assert!(matches!( + err, + Error::ListenerReceive(sqlx::Error::Protocol(_)) + )); + } + #[sqlx::test(migrations = "./migrations")] async fn recv_task_rechecks_locked_ready_tasks_without_notifications(pool: PgPool) { let id = insert_task_at( From 9424b0839dba9f2d0a3273cb1bd0fe0dd3b5584b Mon Sep 17 00:00:00 2001 From: imbolc Date: Sat, 9 May 2026 23:52:09 +0600 Subject: [PATCH 11/44] Add coverage-guided worker and task tests --- examples/counter.rs | 24 +++ examples/delay.rs | 13 ++ examples/tutorial.rs | 22 +++ src/error.rs | 17 ++ src/listener.rs | 84 +++++++++ src/macros.rs | 13 ++ src/task.rs | 122 +++++++++++++ src/traits.rs | 54 ++++++ src/util.rs | 19 ++ src/worker.rs | 409 +++++++++++++++++++++++++++++++++++++++++++ 10 files changed, 777 insertions(+) diff --git a/examples/counter.rs 
b/examples/counter.rs index 734da0a..f922822 100644 --- a/examples/counter.rs +++ b/examples/counter.rs @@ -181,6 +181,30 @@ mod tests { } } + #[tokio::test] + async fn proceed_finishes_single_step_counts() { + let started_at = Utc::now(); + let next = Proceed { + up_to: 1, + started_at, + cur: 0, + } + .step(&lazy_pool()) + .await + .unwrap(); + + match next { + NextStep::Now(Count::Finish(Finish { + up_to, + started_at: finished_started_at, + })) => { + assert_eq!(up_to, 1); + assert_eq!(finished_started_at, started_at); + } + _ => panic!("expected the finish step"), + } + } + #[tokio::test] async fn finish_completes_the_task() { assert!(matches!( diff --git a/examples/delay.rs b/examples/delay.rs index 736e70a..6acd5a1 100644 --- a/examples/delay.rs +++ b/examples/delay.rs @@ -75,6 +75,19 @@ mod tests { } } + #[tokio::test] + async fn sleep_allows_zero_delay_wakeups() { + let next = Sleep(0).step(&lazy_pool()).await.unwrap(); + + match next { + NextStep::Delayed(Sleeper::Wakeup(Wakeup(seconds)), delay) => { + assert_eq!(seconds, 0); + assert_eq!(delay, Duration::ZERO); + } + _ => panic!("expected the delayed wakeup step"), + } + } + #[tokio::test] async fn wakeup_finishes_the_task() { assert!(matches!( diff --git a/examples/tutorial.rs b/examples/tutorial.rs index d335d3b..b4b8ae5 100644 --- a/examples/tutorial.rs +++ b/examples/tutorial.rs @@ -102,6 +102,28 @@ mod tests { } } + #[tokio::test] + async fn read_name_preserves_file_contents_exactly() { + let path = temp_path(); + std::fs::write(&path, "Alice\n").unwrap(); + + let next = ReadName { + filename: path.display().to_string(), + } + .step(&lazy_pool()) + .await + .unwrap(); + + std::fs::remove_file(path).unwrap(); + + match next { + NextStep::Now(Greeter::SayHello(SayHello { name })) => { + assert_eq!(name, "Alice\n"); + } + _ => panic!("expected the greeting step"), + } + } + #[tokio::test] async fn read_name_returns_io_errors_for_missing_files() { let result = ReadName { diff --git a/src/error.rs 
b/src/error.rs index 322bb90..2b452f1 100644 --- a/src/error.rs +++ b/src/error.rs @@ -33,3 +33,20 @@ pub type StepError = Box; /// Result returning from task steps pub type StepResult = StdResult, StepError>; + +#[cfg(test)] +mod tests { + use super::Error; + + #[test] + fn error_display_messages_are_stable() { + assert_eq!( + Error::Db(sqlx::Error::PoolTimedOut, "fetch task".into()).to_string(), + "db error: fetch task", + ); + assert_eq!( + Error::ListenerReceive(sqlx::Error::PoolClosed).to_string(), + "listener can't receive table change notifications", + ); + } +} diff --git a/src/listener.rs b/src/listener.rs index fd7106c..c33dc52 100644 --- a/src/listener.rs +++ b/src/listener.rs @@ -328,6 +328,20 @@ mod tests { )); } + #[tokio::test] + async fn take_error_clears_stored_errors() { + let listener = Listener::new(); + listener.set_error_for_tests(Error::ListenerReceive(sqlx::Error::Protocol( + "listener failed".into(), + ))); + + assert!(matches!( + listener.take_error(), + Some(Error::ListenerReceive(sqlx::Error::Protocol(_))) + )); + assert!(listener.take_error().is_none()); + } + #[tokio::test] async fn wait_for_returns_when_a_wakeup_arrives_before_the_timeout() { let listener = Listener::new(); @@ -492,6 +506,21 @@ mod tests { assert!(second_error.is_cancelled()); } + #[tokio::test] + async fn shutdown_aborts_the_background_task() { + let listener = Listener::new(); + let task = tokio::spawn(pending::<()>()); + listener.set_task_for_tests(task.abort_handle()); + + listener.shutdown(); + + let error = timeout(Duration::from_millis(50), task) + .await + .unwrap() + .unwrap_err(); + assert!(error.is_cancelled()); + } + #[sqlx::test(migrations = "./migrations")] async fn listen_wakes_subscribers_for_task_inserts(pool: PgPool) { let listener = Listener::new(); @@ -515,6 +544,61 @@ mod tests { assert!(listener.take_error().is_none()); } + #[sqlx::test(migrations = "./migrations")] + async fn listen_wakes_subscribers_for_non_stop_notifications(pool: PgPool) { + 
let listener = Listener::new(); + listener.listen(pool.clone()).await.unwrap(); + + let subscription = listener.subscribe(); + sqlx::query!("NOTIFY pg_task_changed, 'wake'") + .execute(&pool) + .await + .unwrap(); + + timeout(Duration::from_secs(1), subscription.wait_forever()) + .await + .unwrap(); + + assert!(!listener.time_to_stop_worker()); + assert!(listener.take_error().is_none()); + } + + #[sqlx::test(migrations = "./migrations")] + async fn listen_wakes_subscribers_for_task_updates_and_deletes(pool: PgPool) { + let listener = Listener::new(); + listener.listen(pool.clone()).await.unwrap(); + let id = sqlx::query!( + "INSERT INTO pg_task (step, wakeup_at) VALUES ($1, $2) RETURNING id", + "{}", + Utc::now(), + ) + .fetch_one(&pool) + .await + .unwrap() + .id; + + let update_subscription = listener.subscribe(); + sqlx::query!("UPDATE pg_task SET error = $2 WHERE id = $1", id, "boom",) + .execute(&pool) + .await + .unwrap(); + timeout(Duration::from_secs(1), update_subscription.wait_forever()) + .await + .unwrap(); + + let delete_subscription = listener.subscribe(); + sqlx::query!("DELETE FROM pg_task WHERE id = $1", id) + .execute(&pool) + .await + .unwrap(); + timeout(Duration::from_secs(1), delete_subscription.wait_forever()) + .await + .unwrap(); + + assert!(!listener.time_to_stop_worker()); + assert!(listener.take_error().is_none()); + } + #[sqlx::test(migrations = "./migrations")] async fn stop_worker_notifications_wake_future_subscribers(pool: PgPool) { let listener = Listener::new(); diff --git a/src/macros.rs b/src/macros.rs index 6e18058..d6e915f 100644 --- a/src/macros.rs +++ b/src/macros.rs @@ -112,6 +112,19 @@ mod tests { assert_scheduler::(); } + #[sqlx::test(migrations = "./migrations")] + async fn scheduler_macro_schedules_wrapped_tasks(pool: PgPool) { + let task = MacroScheduler::MacroTask(MacroTask::First(First)); + + let id = crate::enqueue(&pool, &task).await.unwrap(); + + let row = sqlx::query!("SELECT step FROM pg_task WHERE id = $1", id) 
+ .fetch_one(&pool) + .await + .unwrap(); + assert_eq!(row.step, serde_json::to_string(&task).unwrap()); + } + #[tokio::test] async fn task_macro_forwards_step_and_retry_metadata() { let pool = PgPoolOptions::new() diff --git a/src/task.rs b/src/task.rs index d0b5357..3263fd2 100644 --- a/src/task.rs +++ b/src/task.rs @@ -860,6 +860,40 @@ mod tests { assert!(row.error.is_none()); } + #[sqlx::test(migrations = "./migrations")] + async fn task_lease_columns_must_be_set_together(pool: PgPool) { + let valid = serialized_step(&TestTask::Valid(Valid)); + let now = Utc::now(); + + let err = sqlx::query!( + " + INSERT INTO pg_task (step, wakeup_at, locked_by) + VALUES ($1, $2, $3) + ", + &valid, + now, + worker_id(), + ) + .execute(&pool) + .await + .unwrap_err(); + assert!(matches!(err, sqlx::Error::Database(_))); + + let err = sqlx::query!( + " + INSERT INTO pg_task (step, wakeup_at, lock_expires_at) + VALUES ($1, $2, $3) + ", + &valid, + now, + now, + ) + .execute(&pool) + .await + .unwrap_err(); + assert!(matches!(err, sqlx::Error::Database(_))); + } + #[sqlx::test(migrations = "./migrations")] async fn renew_leases_extends_only_live_owned_leases(pool: PgPool) { let now = Utc::now(); @@ -921,6 +955,46 @@ mod tests { assert!(other_worker.lock_expires_at.unwrap() < started_at + ChronoDuration::seconds(45)); } + #[sqlx::test(migrations = "./migrations")] + async fn renew_leases_returns_zero_when_no_live_owned_leases_exist(pool: PgPool) { + let now = Utc::now(); + let valid = serialized_step(&TestTask::Valid(Valid)); + let expired = insert_task_row( + &pool, + &valid, + now - ChronoDuration::seconds(1), + 0, + false, + None, + ) + .await; + set_task_lease( + &pool, + expired, + worker_id(), + now - ChronoDuration::seconds(1), + ) + .await; + let other_worker = insert_task_row( + &pool, + &valid, + now - ChronoDuration::seconds(1), + 0, + false, + None, + ) + .await; + set_task_lease( + &pool, + other_worker, + other_worker_id(), + now + ChronoDuration::seconds(30), + ) + 
.await; + + assert_eq!(Task::renew_leases(&pool, task_lease()).await.unwrap(), 0); + } + #[sqlx::test(migrations = "./migrations")] async fn fetch_ready_ignores_leased_errored_and_future_tasks_and_picks_the_earliest_ready_one( pool: PgPool, @@ -1009,6 +1083,34 @@ mod tests { assert_eq!(task.id, expected); } + #[sqlx::test(migrations = "./migrations")] + async fn fetch_ready_returns_none_when_no_tasks_are_ready(pool: PgPool) { + let now = Utc::now(); + let valid = serialized_step(&TestTask::Valid(Valid)); + insert_task_row( + &pool, + &valid, + now + ChronoDuration::seconds(1), + 0, + false, + None, + ) + .await; + insert_task_row( + &pool, + &valid, + now - ChronoDuration::seconds(1), + 0, + false, + Some("boom"), + ) + .await; + + let mut tx = pool.begin().await.unwrap(); + assert!(Task::fetch_ready(&mut tx).await.unwrap().is_none()); + tx.commit().await.unwrap(); + } + #[sqlx::test(migrations = "./migrations")] async fn fetch_next_available_at_returns_none_when_no_tasks_are_visible(pool: PgPool) { let mut tx = pool.begin().await.unwrap(); @@ -1188,6 +1290,26 @@ mod tests { assert!(row.error.is_none()); } + #[sqlx::test(migrations = "./migrations")] + async fn run_step_does_not_complete_tasks_after_the_lease_expires(pool: PgPool) { + init_tracing(); + let (task, step, lease) = claim_task(&pool, TestTask::Valid(Valid), 0).await; + set_task_lease( + &pool, + task.id, + worker_id(), + Utc::now() - ChronoDuration::seconds(1), + ) + .await; + + task.run_step(&pool, step, lease).await.unwrap(); + + let row = fetch_task_row(&pool, task.id).await.unwrap(); + assert_eq!(row.locked_by, Some(worker_id())); + assert!(row.lock_expires_at.unwrap() < Utc::now()); + assert!(row.error.is_none()); + } + #[sqlx::test(migrations = "./migrations")] async fn run_step_returns_db_errors_when_completing_tasks_fails(pool: PgPool) { let step = TestTask::Valid(Valid); diff --git a/src/traits.rs b/src/traits.rs index e14bdbc..4501213 100644 --- a/src/traits.rs +++ b/src/traits.rs @@ -175,6 
+175,23 @@ mod tests { ); } + #[sqlx::test(migrations = "./migrations")] + async fn zero_delay_schedules_tasks_immediately(pool: PgPool) { + let task = ScheduledTask { value: 12 }; + let started_at = Utc::now(); + + let id = crate::delay(&pool, &task, Duration::ZERO).await.unwrap(); + + let finished_at = Utc::now(); + let row = fetch_task_row(&pool, id).await; + assert_eq!(row.step, serde_json::to_string(&task).unwrap()); + assert_timestamp_between( + row.wakeup_at, + started_at, + finished_at + ChronoDuration::seconds(1), + ); + } + #[sqlx::test(migrations = "./migrations")] async fn enqueue_schedules_tasks_immediately(pool: PgPool) { let task = ScheduledTask { value: 9 }; @@ -192,6 +209,43 @@ mod tests { ); } + #[sqlx::test(migrations = "./migrations")] + async fn schedule_accepts_transaction_executors(pool: PgPool) { + let task = ScheduledTask { value: 10 }; + let at = Utc::now() + ChronoDuration::seconds(5); + let mut tx = pool.begin().await.unwrap(); + + let id = crate::schedule(&mut *tx, &task, at).await.unwrap(); + tx.commit().await.unwrap(); + + let row = fetch_task_row(&pool, id).await; + assert_eq!(row.step, serde_json::to_string(&task).unwrap()); + assert!( + (row.wakeup_at - at) + .num_microseconds() + .is_some_and(|diff| diff.abs() <= 1), + "scheduled time {0:?} should match {1:?}", + row.wakeup_at, + at, + ); + } + + #[sqlx::test(migrations = "./migrations")] + async fn rolled_back_transactions_discard_scheduled_tasks(pool: PgPool) { + let mut tx = pool.begin().await.unwrap(); + + crate::enqueue(&mut *tx, &ScheduledTask { value: 11 }) + .await + .unwrap(); + tx.rollback().await.unwrap(); + + let row_count = sqlx::query!("SELECT id FROM pg_task") + .fetch_all(&pool) + .await + .unwrap(); + assert_eq!(row_count.len(), 0); + } + #[sqlx::test(migrations = "./migrations")] async fn schedule_returns_serialization_errors_before_touching_the_database(pool: PgPool) { let err = crate::schedule(&pool, &UnserializableTask, Utc::now()) diff --git a/src/util.rs 
b/src/util.rs index 46fcdd3..c71c214 100644 --- a/src/util.rs +++ b/src/util.rs @@ -166,6 +166,18 @@ mod tests { assert!(is_retryable_database_error(Some("53300"), true)); } + #[test] + fn documented_database_connection_error_codes_are_retryable() { + for code in [ + "08000", "08001", "08003", "08004", "08006", "08007", "57P01", "57P02", + ] { + assert!( + is_retryable_database_error(Some(code), false), + "{code} should be retryable", + ); + } + } + #[test] fn protocol_violation_is_not_retryable() { assert!(!is_retryable_database_error(Some("08P01"), false)); @@ -185,6 +197,13 @@ mod tests { assert!(matches!(err, crate::Error::Db(sqlx::Error::Database(_), _))); } + #[sqlx::test(migrations = "./migrations")] + async fn wait_for_reconnection_returns_when_the_database_is_available(pool: PgPool) { + wait_for_reconnection(&pool, Duration::from_millis(10)) + .await + .unwrap(); + } + #[sqlx::test(migrations = "./migrations")] async fn wait_for_reconnection_retries_pool_timeouts_until_the_database_is_available( pool: PgPool, diff --git a/src/worker.rs b/src/worker.rs index 66716a6..41ac3ba 100644 --- a/src/worker.rs +++ b/src/worker.rs @@ -471,6 +471,11 @@ mod tests { key: u64, } + #[derive(Debug, serde::Deserialize, serde::Serialize)] + pub(super) struct FailStep { + key: u64, + } + crate::task!(TestTask { Noop, Advance, @@ -478,6 +483,7 @@ mod tests { Complete, Blocking, FailSavingError, + FailStep, }); #[async_trait::async_trait] @@ -540,6 +546,18 @@ mod tests { } } + #[async_trait::async_trait] + impl Step for FailStep { + async fn step(self, _db: &PgPool) -> crate::StepResult { + step_state(self.key).record("started"); + Err(io::Error::other("step failed").into()) + } + + fn retry_limit(&self) -> i32 { + 0 + } + } + struct StepState { events: Mutex>, events_changed: Notify, @@ -837,7 +855,16 @@ mod tests { Worker::::handle_heartbeat_event(HeartbeatEvent::Failed, &mut heartbeat_healthy) .unwrap(); assert!(!heartbeat_healthy); + 
Worker::::handle_heartbeat_event(HeartbeatEvent::Failed, &mut heartbeat_healthy) + .unwrap(); + assert!(!heartbeat_healthy); + Worker::::handle_heartbeat_event( + HeartbeatEvent::Recovered, + &mut heartbeat_healthy, + ) + .unwrap(); + assert!(heartbeat_healthy); Worker::::handle_heartbeat_event( HeartbeatEvent::Recovered, &mut heartbeat_healthy, @@ -888,6 +915,167 @@ mod tests { assert!(abort_running_steps); } + #[tokio::test] + async fn heartbeat_recovery_preserves_retryable_fetch_error_handling() { + init_tracing(); + let worker = Worker::::new( + PgPoolOptions::new() + .connect_lazy("postgres:///pg_task") + .unwrap(), + ); + let (heartbeat_events, mut heartbeat_events_receiver) = mpsc::unbounded_channel(); + heartbeat_events.send(HeartbeatEvent::Failed).unwrap(); + heartbeat_events.send(HeartbeatEvent::Recovered).unwrap(); + let mut heartbeat_healthy = true; + let mut abort_running_steps = false; + + worker + .handle_recv_task_error_or_heartbeat( + Error::Db(sqlx::Error::PoolTimedOut, "fetch".into()), + &mut heartbeat_events_receiver, + &mut heartbeat_healthy, + &mut abort_running_steps, + ) + .await + .unwrap(); + + assert!(heartbeat_healthy); + assert!(!abort_running_steps); + } + + #[tokio::test] + async fn heartbeat_failure_pauses_after_retryable_fetch_error_handling() { + init_tracing(); + let worker = Worker::::new( + PgPoolOptions::new() + .connect_lazy("postgres:///pg_task") + .unwrap(), + ); + let (heartbeat_events, mut heartbeat_events_receiver) = mpsc::unbounded_channel(); + heartbeat_events.send(HeartbeatEvent::Failed).unwrap(); + let mut heartbeat_healthy = true; + let mut abort_running_steps = false; + + worker + .handle_recv_task_error_or_heartbeat( + Error::Db(sqlx::Error::PoolTimedOut, "fetch".into()), + &mut heartbeat_events_receiver, + &mut heartbeat_healthy, + &mut abort_running_steps, + ) + .await + .unwrap(); + + assert!(!heartbeat_healthy); + assert!(!abort_running_steps); + } + + #[sqlx::test(migrations = "./migrations")] + async fn 
heartbeat_failures_without_running_steps_do_not_expire(pool: PgPool) { + init_tracing(); + sqlx::query!("ALTER TABLE pg_task RENAME COLUMN lock_expires_at TO task_lock_expires_at") + .execute(&pool) + .await + .unwrap(); + let worker = Worker::::new(pool) + .with_lease_timeout(Duration::from_millis(80)) + .with_heartbeat_interval(Duration::from_millis(20)); + let (events, mut events_receiver) = mpsc::unbounded_channel(); + let heartbeat = worker.spawn_heartbeat(events, Arc::new(Mutex::new(Vec::new()))); + + let event = timeout(Duration::from_secs(1), events_receiver.recv()) + .await + .unwrap() + .unwrap(); + assert!(matches!(event, HeartbeatEvent::Failed)); + assert!(timeout(Duration::from_millis(150), events_receiver.recv()) + .await + .is_err()); + + heartbeat.abort(); + } + + #[sqlx::test(migrations = "./migrations")] + async fn heartbeat_reports_recovery_after_renewal_failures_stop(pool: PgPool) { + init_tracing(); + let worker_pool = connect_to_current_db(&pool, 1, Duration::from_millis(20)).await; + let held_connection = worker_pool.acquire().await.unwrap(); + let worker = Worker::::new(worker_pool) + .with_lease_timeout(Duration::from_millis(500)) + .with_heartbeat_interval(Duration::from_millis(20)); + let (events, mut events_receiver) = mpsc::unbounded_channel(); + let heartbeat = worker.spawn_heartbeat(events, Arc::new(Mutex::new(Vec::new()))); + + let event = timeout(Duration::from_secs(1), events_receiver.recv()) + .await + .unwrap() + .unwrap(); + assert!(matches!(event, HeartbeatEvent::Failed)); + + drop(held_connection); + + let event = timeout(Duration::from_secs(1), events_receiver.recv()) + .await + .unwrap() + .unwrap(); + assert!(matches!(event, HeartbeatEvent::Recovered)); + + heartbeat.abort(); + } + + #[sqlx::test(migrations = "./migrations")] + async fn heartbeat_reports_recovery_after_live_leases_are_renewed(pool: PgPool) { + init_tracing(); + let worker_pool = connect_to_current_db(&pool, 1, Duration::from_millis(20)).await; + let 
held_connection = worker_pool.acquire().await.unwrap(); + let worker = Worker::::new(worker_pool) + .with_lease_timeout(Duration::from_millis(500)) + .with_heartbeat_interval(Duration::from_millis(20)); + let id = insert_task_at( + &pool, + &TestTask::Noop(Noop), + Utc::now() - ChronoDuration::milliseconds(1), + false, + ) + .await; + let initial_expires_at = Utc::now() + ChronoDuration::milliseconds(200); + sqlx::query!( + " + UPDATE pg_task + SET locked_by = $2, + lock_expires_at = $3 + WHERE id = $1 + ", + id, + worker.worker_id, + initial_expires_at, + ) + .execute(&pool) + .await + .unwrap(); + let (events, mut events_receiver) = mpsc::unbounded_channel(); + let heartbeat = worker.spawn_heartbeat(events, Arc::new(Mutex::new(Vec::new()))); + + let event = timeout(Duration::from_secs(1), events_receiver.recv()) + .await + .unwrap() + .unwrap(); + assert!(matches!(event, HeartbeatEvent::Failed)); + + drop(held_connection); + + let event = timeout(Duration::from_secs(1), events_receiver.recv()) + .await + .unwrap() + .unwrap(); + assert!(matches!(event, HeartbeatEvent::Recovered)); + + let (_locked_by, renewed_expires_at) = fetch_task_lease(&pool, id).await.unwrap(); + assert!(renewed_expires_at > initial_expires_at); + + heartbeat.abort(); + } + #[tokio::test] async fn running_step_tracking_prunes_finished_steps() { let finished_step = tokio::spawn(async {}); @@ -1084,6 +1272,41 @@ mod tests { )); } + #[sqlx::test(migrations = "./migrations")] + async fn recv_task_skips_invalid_tasks_and_returns_next_ready_task(pool: PgPool) { + let invalid_id = insert_raw_task( + &pool, + "not-json", + Utc::now() - ChronoDuration::seconds(2), + false, + None, + ) + .await; + let expected = insert_task_at( + &pool, + &TestTask::Noop(Noop), + Utc::now() - ChronoDuration::seconds(1), + false, + ) + .await; + let worker = Worker::::new(pool.clone()); + + let (task, step, _lease) = worker.recv_task().await.unwrap().unwrap(); + + assert_eq!(task.id, expected); + assert!(matches!(step, 
TestTask::Noop(Noop))); + let invalid_row = sqlx::query!( + "SELECT locked_by, lock_expires_at, error FROM pg_task WHERE id = $1", + invalid_id, + ) + .fetch_one(&pool) + .await + .unwrap(); + assert!(invalid_row.locked_by.is_none()); + assert!(invalid_row.lock_expires_at.is_none()); + assert!(invalid_row.error.is_some()); + } + #[sqlx::test(migrations = "./migrations")] async fn recv_task_rechecks_locked_ready_tasks_without_notifications(pool: PgPool) { let id = insert_task_at( @@ -1145,6 +1368,28 @@ mod tests { assert!(matches!(step, TestTask::Noop(Noop))); } + #[sqlx::test(migrations = "./migrations")] + async fn recv_task_replaces_expired_lease_with_the_current_worker(pool: PgPool) { + let id = insert_task_at( + &pool, + &TestTask::Noop(Noop), + Utc::now() - ChronoDuration::milliseconds(1), + true, + ) + .await; + set_task_lease(&pool, id, Utc::now() - ChronoDuration::milliseconds(1)).await; + let worker = Worker::::new(pool.clone()); + let worker_id = worker.worker_id; + + let (task, step, _lease) = worker.recv_task().await.unwrap().unwrap(); + + assert_eq!(task.id, id); + assert!(matches!(step, TestTask::Noop(Noop))); + let (locked_by, lock_expires_at) = fetch_task_lease(&pool, id).await.unwrap(); + assert_eq!(locked_by, worker_id); + assert!(lock_expires_at > Utc::now()); + } + #[sqlx::test(migrations = "./migrations")] async fn two_workers_claim_ready_tasks_once(pool: PgPool) { let first_id = insert_task_at( @@ -1312,6 +1557,68 @@ mod tests { assert!(state.state().events().is_empty()); } + #[sqlx::test(migrations = "./migrations")] + async fn run_returns_listener_errors_while_fetching_is_paused(pool: PgPool) { + let worker_pool = connect_to_current_db(&pool, 1, Duration::from_millis(20)).await; + let worker = Arc::new( + Worker::::new(worker_pool) + .with_concurrency(nonzero(1)) + .with_lease_timeout(Duration::from_millis(200)) + .with_heartbeat_interval(Duration::from_millis(50)), + ); + let run = tokio::spawn({ + let worker = worker.clone(); + async move 
{ worker.run().await } + }); + + sleep(Duration::from_millis(1250)).await; + worker + .listener + .set_error_for_tests(Error::ListenerReceive(sqlx::Error::Protocol( + "listener failed".into(), + ))); + + let err = timeout(Duration::from_secs(3), run) + .await + .unwrap() + .unwrap() + .unwrap_err(); + assert!(matches!( + err, + Error::ListenerReceive(sqlx::Error::Protocol(_)) + )); + } + + #[sqlx::test(migrations = "./migrations")] + async fn run_keeps_waiting_after_retryable_errors_while_fetching_is_paused(pool: PgPool) { + let worker_pool = connect_to_current_db(&pool, 1, Duration::from_millis(20)).await; + let worker = Arc::new( + Worker::::new(worker_pool) + .with_concurrency(nonzero(1)) + .with_lease_timeout(Duration::from_millis(200)) + .with_heartbeat_interval(Duration::from_millis(50)), + ); + let run = tokio::spawn({ + let worker = worker.clone(); + async move { worker.run().await } + }); + + sleep(Duration::from_millis(1250)).await; + worker + .listener + .set_error_for_tests(Error::Db(sqlx::Error::PoolTimedOut, "fetch task".into())); + sleep(Duration::from_millis(1300)).await; + assert!(!run.is_finished()); + + stop_worker(&pool).await; + + timeout(Duration::from_secs(3), run) + .await + .unwrap() + .unwrap() + .unwrap(); + } + #[sqlx::test(migrations = "./migrations")] async fn run_resumes_fetching_after_heartbeat_recovers(pool: PgPool) { let worker_pool = connect_to_current_db(&pool, 2, Duration::from_millis(20)).await; @@ -1589,6 +1896,21 @@ mod tests { .unwrap(); } + #[sqlx::test(migrations = "./migrations")] + async fn run_stops_when_stop_notification_arrives_while_idle(pool: PgPool) { + let worker = spawn_worker(pool.clone()); + + sleep(Duration::from_millis(100)).await; + stop_worker(&pool).await; + + timeout(Duration::from_secs(1), worker) + .await + .unwrap() + .unwrap() + .unwrap(); + assert_eq!(task_count(&pool).await, 0); + } + #[sqlx::test(migrations = "./migrations")] async fn run_wakes_up_for_tasks_inserted_while_idle(pool: PgPool) { let 
state = StepStateGuard::new(); @@ -1806,6 +2128,33 @@ mod tests { assert!(matches!(err, Error::Db(sqlx::Error::Database(_), _))); } + #[sqlx::test(migrations = "./migrations")] + async fn run_returns_step_errors_from_spawned_tasks(pool: PgPool) { + let state = StepStateGuard::new(); + sqlx::query!("ALTER TABLE pg_task ADD CONSTRAINT reject_errors CHECK (error IS NULL)") + .execute(&pool) + .await + .unwrap(); + insert_task( + &pool, + &TestTask::FailStep(FailStep { key: state.key() }), + false, + ) + .await; + + let worker = spawn_worker(pool); + + state.state().wait_for_events(1).await; + let err = timeout(Duration::from_secs(1), worker) + .await + .unwrap() + .unwrap() + .unwrap_err(); + + assert_eq!(state.state().events(), vec!["started"]); + assert!(matches!(err, Error::Db(sqlx::Error::Database(_), _))); + } + #[sqlx::test(migrations = "./migrations")] async fn run_processes_multiple_blocking_steps_up_to_the_concurrency_limit(pool: PgPool) { let first = StepStateGuard::new(); @@ -1844,4 +2193,64 @@ mod tests { assert_eq!(second.state().events(), vec!["started", "completed"]); assert_eq!(task_count(&pool).await, 0); } + + #[sqlx::test(migrations = "./migrations")] + async fn run_respects_the_configured_concurrency_limit(pool: PgPool) { + let first = StepStateGuard::new(); + let second = StepStateGuard::new(); + insert_task( + &pool, + &TestTask::Blocking(Blocking { key: first.key() }), + false, + ) + .await; + insert_task( + &pool, + &TestTask::Blocking(Blocking { key: second.key() }), + false, + ) + .await; + + let worker = spawn_worker_with_concurrency(pool.clone(), 1); + + timeout(Duration::from_secs(1), async { + loop { + let started_count = usize::from(!first.state().events().is_empty()) + + usize::from(!second.state().events().is_empty()); + if started_count == 1 { + break; + } + sleep(Duration::from_millis(10)).await; + } + }) + .await + .unwrap(); + + sleep(Duration::from_millis(100)).await; + let first_started = !first.state().events().is_empty(); + let 
second_started = !second.state().events().is_empty(); + assert_ne!(first_started, second_started); + + if first_started { + first.state().release(); + second.state().wait_for_events(1).await; + stop_worker(&pool).await; + second.state().release(); + } else { + second.state().release(); + first.state().wait_for_events(1).await; + stop_worker(&pool).await; + first.state().release(); + } + + timeout(Duration::from_secs(1), worker) + .await + .unwrap() + .unwrap() + .unwrap(); + + assert_eq!(first.state().events(), vec!["started", "completed"]); + assert_eq!(second.state().events(), vec!["started", "completed"]); + assert_eq!(task_count(&pool).await, 0); + } } From f1f263b01a9fb2b7ed2995985a09509a2c2a0b81 Mon Sep 17 00:00:00 2001 From: imbolc Date: Sun, 10 May 2026 06:49:11 +0600 Subject: [PATCH 12/44] Fix lease expiry and worker drain edge cases --- Cargo.toml | 2 +- src/error.rs | 6 ++ src/lib.rs | 7 +- src/listener.rs | 4 ++ src/worker.rs | 172 +++++++++++++++++++++++++++++++++++++++++++++--- 5 files changed, 181 insertions(+), 10 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index d4ce06a..d007d48 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -23,7 +23,7 @@ sqlx = { version = "0.8", features = [ "uuid", ] } thiserror = "2" -tokio = { version = "1", features = ["macros"] } +tokio = { version = "1", features = ["macros", "rt", "sync", "time"] } tracing = "0.1" uuid = { version = "1", features = ["v4"] } diff --git a/src/error.rs b/src/error.rs index 2b452f1..dc5fc8d 100644 --- a/src/error.rs +++ b/src/error.rs @@ -21,6 +21,8 @@ pub enum Error { ListenerReceive(#[source] sqlx::Error), /// unreachable: worker semaphore is closed UnreachableWorkerSemaphoreClosed(#[source] tokio::sync::AcquireError), + /// task lease expired before the worker could renew it + TaskLeaseExpired, /// db error: {1} Db(#[source] sqlx::Error, String), } @@ -48,5 +50,9 @@ mod tests { Error::ListenerReceive(sqlx::Error::PoolClosed).to_string(), "listener can't receive table change 
notifications", ); + assert_eq!( + Error::TaskLeaseExpired.to_string(), + "task lease expired before the worker could renew it", + ); } } diff --git a/src/lib.rs b/src/lib.rs index 1b6b693..9bd8a6d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -271,7 +271,12 @@ The workers would wait until the current step of all the tasks is finished and then exit. You can wait for this by checking for the existence of running tasks: ```sql -SELECT EXISTS(SELECT 1 FROM pg_task WHERE locked_by IS NOT NULL); +SELECT EXISTS( + SELECT 1 + FROM pg_task + WHERE locked_by IS NOT NULL + AND lock_expires_at > now() +); ``` ### Delaying Steps diff --git a/src/listener.rs b/src/listener.rs index c33dc52..e229023 100644 --- a/src/listener.rs +++ b/src/listener.rs @@ -567,6 +567,7 @@ mod tests { async fn listen_wakes_subscribers_for_task_updates_and_deletes(pool: PgPool) { let listener = Listener::new(); listener.listen(pool.clone()).await.unwrap(); + let insert_subscription = listener.subscribe(); let id = sqlx::query!( "INSERT INTO pg_task (step, wakeup_at) VALUES ($1, $2) RETURNING id", "{}", @@ -576,6 +577,9 @@ mod tests { .await .unwrap() .id; + timeout(Duration::from_secs(1), insert_subscription.wait_forever()) + .await + .unwrap(); let update_subscription = listener.subscribe(); sqlx::query!("UPDATE pg_task SET error = $2 WHERE id = $1", id, "boom",) diff --git a/src/worker.rs b/src/worker.rs index 41ac3ba..f72f0ca 100644 --- a/src/worker.rs +++ b/src/worker.rs @@ -92,23 +92,30 @@ impl + 'static> Worker { let (step_error_sender, mut step_errors) = mpsc::unbounded_channel(); let mut heartbeat_healthy = true; let mut abort_running_steps = false; + let mut reserved_permit = None; let result = loop { let received = tokio::select! 
{ biased; Some(error) = step_errors.recv() => { + drop(reserved_permit.take()); break Err(error); } Some(event) = heartbeat_events.recv() => { if let Err(error) = Self::handle_heartbeat_event(event, &mut heartbeat_healthy) { abort_running_steps = true; + drop(reserved_permit.take()); break Err(error); } + if !heartbeat_healthy { + drop(reserved_permit.take()); + } continue; } _ = sleep(LOCKED_TASK_RECHECK_DELAY), if !heartbeat_healthy => { if self.listener.time_to_stop_worker() { + drop(reserved_permit.take()); break Ok(()); } if let Some(error) = self.listener.take_error() { @@ -121,20 +128,23 @@ impl + 'static> Worker { ) .await { + drop(reserved_permit.take()); break Err(error); } } continue; } - received = self.recv_task(), if heartbeat_healthy => received, + permit = semaphore.clone().acquire_owned(), if heartbeat_healthy && reserved_permit.is_none() => { + reserved_permit = Some(permit.map_err(Error::UnreachableWorkerSemaphoreClosed)?); + continue; + } + received = self.recv_task(), if heartbeat_healthy && reserved_permit.is_some() => received, }; match received { Ok(Some((task, step, lease))) => { - let permit = semaphore - .clone() - .acquire_owned() - .await - .map_err(Error::UnreachableWorkerSemaphoreClosed)?; + let permit = reserved_permit + .take() + .expect("task fetching requires a reserved semaphore permit"); let db = self.db.clone(); let step_error_sender = step_error_sender.clone(); let step = tokio::spawn(async move { @@ -147,9 +157,11 @@ impl + 'static> Worker { Self::track_running_step(&running_steps, step.abort_handle()); } Ok(None) => { + drop(reserved_permit.take()); break Ok(()); } Err(e) => { + drop(reserved_permit.take()); if let Err(error) = self .handle_recv_task_error_or_heartbeat( e, @@ -159,6 +171,7 @@ impl + 'static> Worker { ) .await { + drop(reserved_permit.take()); break Err(error); } } @@ -170,6 +183,7 @@ impl + 'static> Worker { heartbeat, running_steps, abort_running_steps, + heartbeat_events, ) .await } @@ -306,13 +320,26 @@ 
impl + 'static> Worker { renewal_failed = false; } } - Ok(_) => { + Ok(_) if !Self::has_running_steps(&running_steps) => { last_renewed_at = Instant::now(); if renewal_failed { let _ = events.send(HeartbeatEvent::Recovered); renewal_failed = false; } } + Ok(_) => { + warn!("Task lease renewal updated no rows while steps are still running"); + if !renewal_failed { + let _ = events.send(HeartbeatEvent::Failed); + renewal_failed = true; + } + if last_renewed_at.elapsed().saturating_add(heartbeat_interval) + >= lease_timeout + { + let _ = events.send(HeartbeatEvent::Expired(Error::TaskLeaseExpired)); + break; + } + } Err(error) => { warn!( "Task lease renewal failed:\n{}", @@ -371,6 +398,7 @@ impl + 'static> Worker { heartbeat: tokio::task::AbortHandle, running_steps: Arc>>, abort_running_steps: bool, + heartbeat_events: mpsc::UnboundedReceiver, ) -> Result<()> { self.listener.shutdown(); if abort_running_steps { @@ -379,7 +407,22 @@ impl + 'static> Worker { } // Drain in-flight steps before returning so a restarted worker can't // reclaim them as stale while they are still running. 
- self.wait_for_steps_to_finish(semaphore).await; + let result = if abort_running_steps { + self.wait_for_steps_to_finish(semaphore).await; + result + } else { + match self + .wait_for_steps_to_finish_or_heartbeat(semaphore.clone(), heartbeat_events) + .await + { + Ok(()) => result, + Err(error) => { + Self::abort_running_steps(&running_steps); + self.wait_for_steps_to_finish(semaphore).await; + Err(error) + } + } + }; heartbeat.abort(); if result.is_ok() { info!("Stopped"); @@ -408,6 +451,39 @@ impl + 'static> Worker { trace!("The current step of every task is done") } } + + async fn wait_for_steps_to_finish_or_heartbeat( + &self, + semaphore: Arc, + mut heartbeat_events: mpsc::UnboundedReceiver, + ) -> Result<()> { + let mut logged_tasks_left = None; + let mut heartbeat_healthy = true; + loop { + let tasks_left = self.concurrency.get() - semaphore.available_permits(); + if tasks_left == 0 { + break; + } + if let Some(logged) = logged_tasks_left { + if logged != tasks_left { + trace!("Waiting for the current steps of {tasks_left} tasks to finish..."); + } + } else { + info!("Waiting for the current steps of {tasks_left} tasks to finish..."); + } + logged_tasks_left = Some(tasks_left); + tokio::select! 
{ + Some(event) = heartbeat_events.recv() => { + Self::handle_heartbeat_event(event, &mut heartbeat_healthy)?; + } + _ = sleep(Duration::from_secs_f32(0.1)) => {} + } + } + if logged_tasks_left.is_some() { + trace!("The current step of every task is done") + } + Ok(()) + } } #[cfg(test)] @@ -774,6 +850,11 @@ mod tests { .abort_handle() } + fn idle_heartbeat_events() -> mpsc::UnboundedReceiver { + let (_sender, receiver) = mpsc::unbounded_channel(); + receiver + } + async fn connect_to_current_db( pool: &PgPool, max_connections: u32, @@ -995,6 +1076,38 @@ mod tests { heartbeat.abort(); } + #[sqlx::test(migrations = "./migrations")] + async fn heartbeat_expires_when_running_steps_have_no_live_leases(pool: PgPool) { + init_tracing(); + let worker = Worker::::new(pool) + .with_lease_timeout(Duration::from_millis(80)) + .with_heartbeat_interval(Duration::from_millis(20)); + let running_step = tokio::spawn(async { + std::future::pending::<()>().await; + }); + let running_steps = Arc::new(Mutex::new(vec![running_step.abort_handle()])); + let (events, mut events_receiver) = mpsc::unbounded_channel(); + let heartbeat = worker.spawn_heartbeat(events, running_steps); + + let event = timeout(Duration::from_secs(1), events_receiver.recv()) + .await + .unwrap() + .unwrap(); + assert!(matches!(event, HeartbeatEvent::Failed)); + + let event = timeout(Duration::from_secs(1), events_receiver.recv()) + .await + .unwrap() + .unwrap(); + assert!(matches!( + event, + HeartbeatEvent::Expired(Error::TaskLeaseExpired) + )); + + heartbeat.abort(); + running_step.abort(); + } + #[sqlx::test(migrations = "./migrations")] async fn heartbeat_reports_recovery_after_renewal_failures_stop(pool: PgPool) { init_tracing(); @@ -1698,6 +1811,7 @@ mod tests { idle_heartbeat(), Arc::new(Mutex::new(Vec::new())), false, + idle_heartbeat_events(), ) .await } @@ -1744,6 +1858,7 @@ mod tests { heartbeat_abort, Arc::new(Mutex::new(Vec::new())), false, + idle_heartbeat_events(), ) .await } @@ -1788,6 +1903,7 @@ 
mod tests { idle_heartbeat(), running_steps, true, + idle_heartbeat_events(), ), ) .await @@ -1798,6 +1914,46 @@ mod tests { assert!(running_step.await.unwrap_err().is_cancelled()); } + #[tokio::test] + async fn finish_run_aborts_inflight_steps_when_heartbeat_expires_while_draining() { + init_tracing(); + let worker = Worker::::new( + PgPoolOptions::new() + .connect_lazy("postgres:///pg_task") + .unwrap(), + ) + .with_concurrency(nonzero(1)); + let semaphore = Arc::new(Semaphore::new(1)); + let permit = semaphore.clone().acquire_owned().await.unwrap(); + let running_step = tokio::spawn(async move { + let _permit = permit; + std::future::pending::<()>().await; + }); + let running_steps = Arc::new(Mutex::new(vec![running_step.abort_handle()])); + let (heartbeat_events_sender, heartbeat_events) = mpsc::unbounded_channel(); + heartbeat_events_sender + .send(HeartbeatEvent::Expired(Error::TaskLeaseExpired)) + .unwrap(); + + let err = timeout( + Duration::from_secs(1), + worker.finish_run( + Ok(()), + semaphore, + idle_heartbeat(), + running_steps, + false, + heartbeat_events, + ), + ) + .await + .unwrap() + .unwrap_err(); + + assert!(matches!(err, Error::TaskLeaseExpired)); + assert!(running_step.await.unwrap_err().is_cancelled()); + } + #[tokio::test] async fn wait_for_steps_to_finish_rechecks_when_the_inflight_task_count_changes() { init_tracing(); From 0368312f4198bc1e675ac0d2267bbaf5f374f656 Mon Sep 17 00:00:00 2001 From: imbolc Date: Sun, 10 May 2026 06:52:57 +0600 Subject: [PATCH 13/44] Add lease notification and renewal indexes --- ...fd28c7bd3bb925f1bb260c7ca01d0a418cd08d44783c445ea5a7.json} | 4 ++-- src/task.rs | 1 + 2 files changed, 3 insertions(+), 2 deletions(-) rename .sqlx/{query-fdaa8469f4274846a0b397016a5bb94fc5f78b9f934329c54515f7168b3b905b.json => query-03dce7dbc7f7fd28c7bd3bb925f1bb260c7ca01d0a418cd08d44783c445ea5a7.json} (66%) diff --git a/.sqlx/query-fdaa8469f4274846a0b397016a5bb94fc5f78b9f934329c54515f7168b3b905b.json 
b/.sqlx/query-03dce7dbc7f7fd28c7bd3bb925f1bb260c7ca01d0a418cd08d44783c445ea5a7.json similarity index 66% rename from .sqlx/query-fdaa8469f4274846a0b397016a5bb94fc5f78b9f934329c54515f7168b3b905b.json rename to .sqlx/query-03dce7dbc7f7fd28c7bd3bb925f1bb260c7ca01d0a418cd08d44783c445ea5a7.json index a2ac801..a4866c3 100644 --- a/.sqlx/query-fdaa8469f4274846a0b397016a5bb94fc5f78b9f934329c54515f7168b3b905b.json +++ b/.sqlx/query-03dce7dbc7f7fd28c7bd3bb925f1bb260c7ca01d0a418cd08d44783c445ea5a7.json @@ -1,6 +1,6 @@ { "db_name": "PostgreSQL", - "query": "\n UPDATE pg_task\n SET lock_expires_at = $2\n WHERE locked_by = $1\n AND lock_expires_at > now()\n ", + "query": "\n UPDATE pg_task\n SET lock_expires_at = $2\n WHERE locked_by = $1\n AND lock_expires_at > now()\n AND error IS NULL\n ", "describe": { "columns": [], "parameters": { @@ -11,5 +11,5 @@ }, "nullable": [] }, - "hash": "fdaa8469f4274846a0b397016a5bb94fc5f78b9f934329c54515f7168b3b905b" + "hash": "03dce7dbc7f7fd28c7bd3bb925f1bb260c7ca01d0a418cd08d44783c445ea5a7" } diff --git a/src/task.rs b/src/task.rs index 3263fd2..9592c14 100644 --- a/src/task.rs +++ b/src/task.rs @@ -129,6 +129,7 @@ impl Task { SET lock_expires_at = $2 WHERE locked_by = $1 AND lock_expires_at > now() + AND error IS NULL "#, lease.worker_id, lease.expires_at(), From d4632bf4c66c74b34d69417ff07b25582ab1da2e Mon Sep 17 00:00:00 2001 From: imbolc Date: Sun, 10 May 2026 06:56:51 +0600 Subject: [PATCH 14/44] Add lease notification and index migrations --- .pre-commit.sh | 2 +- ...2d39daffaa7f4394c4b50e69db37f9d24a376.json | 12 +++++ ...c874a976a6583242402773fc3e57861632ef5.json | 12 +++++ ...84e60c4843823da0b06ed3bf42431ec311236.json | 38 ++++++++++++++ ...dae1a16d5c691dc05aa5a4852434180516b1e.json | 12 +++++ ...a1931391252b367b22f56a0058d676e7c72e6.json | 23 ++++++++ ...b9fd49c4605c64604d8f4a03226e923d7e4a1.json | 22 ++++++++ ...bed3c117f2edca11d3976bffad53725c8469d.json | 16 ++++++ ...edeb995c0569dfd937bb2a993637ef6851b41.json | 16 ++++++ 
...7ca620d671461c37ea4b95980075c394bc8a2.json | 26 ++++++++++ ...56303e9bd18d97d87667217d7e2d323a52035.json | 34 ++++++++++++ ...01e747634ec03d6f2afa45f6d9f1c697b43e0.json | 20 +++++++ ...9422cc625816045c989302013ec7cb8eafd1a.json | 20 +++++++ ...c83580f50abe7a2a90a95331255e20ee20682.json | 12 +++++ ...591df3dfa0991e9b5acb99331a7bf046cceb1.json | 12 +++++ ...ffd226245f51981a63818a67190fef4f44bf8.json | 27 ++++++++++ ...26e84928e7d4b5ac27fd0118489a04700bc02.json | 28 ++++++++++ ...1fb9cbf4c30c2a767b5314e4cc7803b338cd9.json | 12 +++++ ...ff7d986e6348680f2e386dbb40fd6f37e0fbc.json | 29 +++++++++++ ...6efe8f10ef3e3e056494ca58f0dc013f92643.json | 28 ++++++++++ ...3c4f057a82ce4f453c56e4b6728f5ec9931b6.json | 20 +++++++ ...d737c54c095c4fa8a8e85ecba8593c8a42267.json | 20 +++++++ ...d57eafced6836523a3312cac57fb13396834e.json | 15 ++++++ ...1895da495a64ff6719724150b1f939617971a.json | 40 ++++++++++++++ ...7b650e08c2f7de73485c20758eddb1164c827.json | 12 +++++ ...d9d9486783dcba5e1a8d96a58f074e766db83.json | 15 ++++++ ...e0cf015372eff542ef6220439ae4d8f2ec43e.json | 12 +++++ ...dc35efc017db15316c1b2e1ab8320988f9b83.json | 12 +++++ ...b906b7b52ff09a55a3064ed7943558234b103.json | 14 +++++ ...018d997ccb4affb898df3eb59ba06b2d5a3ce.json | 52 +++++++++++++++++++ ...a198b4801f453ef61e8dcc216432b1b1b947e.json | 22 ++++++++ .../20260510000000_notify-on-delete.sql | 37 +++++++++++++ .../20260510000100_task-lease-indexes.sql | 17 ++++++ 33 files changed, 688 insertions(+), 1 deletion(-) create mode 100644 .sqlx/query-0e27da7580e82a38a7fb3d0d7ab2d39daffaa7f4394c4b50e69db37f9d24a376.json create mode 100644 .sqlx/query-19c3ed955dd0cf1800b7c5d6757c874a976a6583242402773fc3e57861632ef5.json create mode 100644 .sqlx/query-1b13a067bda37fcf4fcf4a081c884e60c4843823da0b06ed3bf42431ec311236.json create mode 100644 .sqlx/query-419e27727faf579f041c979d5fddae1a16d5c691dc05aa5a4852434180516b1e.json create mode 100644 .sqlx/query-489b35c9fb11e7647e75ae3b30ca1931391252b367b22f56a0058d676e7c72e6.json create 
mode 100644 .sqlx/query-4ce326ceeda8adef3cc7b549e0db9fd49c4605c64604d8f4a03226e923d7e4a1.json create mode 100644 .sqlx/query-4e5e8277069eecc7c3f0380a431bed3c117f2edca11d3976bffad53725c8469d.json create mode 100644 .sqlx/query-4f88b4ad437e4b89cd80abdb978edeb995c0569dfd937bb2a993637ef6851b41.json create mode 100644 .sqlx/query-50e892ba3ad57ab603dbe2516bd7ca620d671461c37ea4b95980075c394bc8a2.json create mode 100644 .sqlx/query-742156444758d69f3a2b634359e56303e9bd18d97d87667217d7e2d323a52035.json create mode 100644 .sqlx/query-85366f4ba53750ca317ef4515b201e747634ec03d6f2afa45f6d9f1c697b43e0.json create mode 100644 .sqlx/query-978c0565523b783d1a3c0140a239422cc625816045c989302013ec7cb8eafd1a.json create mode 100644 .sqlx/query-9cc24e2538c05af994fb2cbed6cc83580f50abe7a2a90a95331255e20ee20682.json create mode 100644 .sqlx/query-acc7fc8f1a01f86436d22057525591df3dfa0991e9b5acb99331a7bf046cceb1.json create mode 100644 .sqlx/query-b435c637c0d110cfac080c591edffd226245f51981a63818a67190fef4f44bf8.json create mode 100644 .sqlx/query-beb4efe0422bd9ecf970d3a77f426e84928e7d4b5ac27fd0118489a04700bc02.json create mode 100644 .sqlx/query-bf104b9a07b92e6daef1346255c1fb9cbf4c30c2a767b5314e4cc7803b338cd9.json create mode 100644 .sqlx/query-c0e3d8bf263772b25bacfdf4ba0ff7d986e6348680f2e386dbb40fd6f37e0fbc.json create mode 100644 .sqlx/query-cb9d98bc740a6380e027f17cacd6efe8f10ef3e3e056494ca58f0dc013f92643.json create mode 100644 .sqlx/query-cc6b874dc24ea589c93c81f9dbc3c4f057a82ce4f453c56e4b6728f5ec9931b6.json create mode 100644 .sqlx/query-d40cd4c235045a61e8fb6d607d8d737c54c095c4fa8a8e85ecba8593c8a42267.json create mode 100644 .sqlx/query-d4feaaaaee7399d5c85aa76a02fd57eafced6836523a3312cac57fb13396834e.json create mode 100644 .sqlx/query-d50ef96acf789ca4e715dfaec291895da495a64ff6719724150b1f939617971a.json create mode 100644 .sqlx/query-dafc146b51ef7a3fc6fe1d9c5487b650e08c2f7de73485c20758eddb1164c827.json create mode 100644 
.sqlx/query-de8b346243c7de75d3329eb4852d9d9486783dcba5e1a8d96a58f074e766db83.json create mode 100644 .sqlx/query-e206f3b2ae75c04296bd3b44f26e0cf015372eff542ef6220439ae4d8f2ec43e.json create mode 100644 .sqlx/query-e4688afeb7fb77e0e5961ac1650dc35efc017db15316c1b2e1ab8320988f9b83.json create mode 100644 .sqlx/query-ebc5a43458570f6f64356d4fdffb906b7b52ff09a55a3064ed7943558234b103.json create mode 100644 .sqlx/query-ef3f510b525d1693e135cc7ae3e018d997ccb4affb898df3eb59ba06b2d5a3ce.json create mode 100644 .sqlx/query-f44843af1870d4d9dc02dcec353a198b4801f453ef61e8dcc216432b1b1b947e.json create mode 100644 migrations/20260510000000_notify-on-delete.sql create mode 100644 migrations/20260510000100_task-lease-indexes.sql diff --git a/.pre-commit.sh b/.pre-commit.sh index adaed53..fc2d6a5 100755 --- a/.pre-commit.sh +++ b/.pre-commit.sh @@ -31,5 +31,5 @@ cargo +nightly fmt -- --check cargo sort -c cargo test --all-targets cargo test --doc -cargo sqlx prepare && git add .sqlx +cargo sqlx prepare -- --all-targets && git add .sqlx cargo clippy --all-targets -- -D warnings diff --git a/.sqlx/query-0e27da7580e82a38a7fb3d0d7ab2d39daffaa7f4394c4b50e69db37f9d24a376.json b/.sqlx/query-0e27da7580e82a38a7fb3d0d7ab2d39daffaa7f4394c4b50e69db37f9d24a376.json new file mode 100644 index 0000000..5ec9f59 --- /dev/null +++ b/.sqlx/query-0e27da7580e82a38a7fb3d0d7ab2d39daffaa7f4394c4b50e69db37f9d24a376.json @@ -0,0 +1,12 @@ +{ + "db_name": "PostgreSQL", + "query": "ALTER TABLE pg_task RENAME COLUMN wakeup_at TO scheduled_at", + "describe": { + "columns": [], + "parameters": { + "Left": [] + }, + "nullable": [] + }, + "hash": "0e27da7580e82a38a7fb3d0d7ab2d39daffaa7f4394c4b50e69db37f9d24a376" +} diff --git a/.sqlx/query-19c3ed955dd0cf1800b7c5d6757c874a976a6583242402773fc3e57861632ef5.json b/.sqlx/query-19c3ed955dd0cf1800b7c5d6757c874a976a6583242402773fc3e57861632ef5.json new file mode 100644 index 0000000..7afd2ca --- /dev/null +++ 
b/.sqlx/query-19c3ed955dd0cf1800b7c5d6757c874a976a6583242402773fc3e57861632ef5.json @@ -0,0 +1,12 @@ +{ + "db_name": "PostgreSQL", + "query": "ALTER TABLE pg_task RENAME COLUMN locked_by TO task_locked_by", + "describe": { + "columns": [], + "parameters": { + "Left": [] + }, + "nullable": [] + }, + "hash": "19c3ed955dd0cf1800b7c5d6757c874a976a6583242402773fc3e57861632ef5" +} diff --git a/.sqlx/query-1b13a067bda37fcf4fcf4a081c884e60c4843823da0b06ed3bf42431ec311236.json b/.sqlx/query-1b13a067bda37fcf4fcf4a081c884e60c4843823da0b06ed3bf42431ec311236.json new file mode 100644 index 0000000..9ec3da1 --- /dev/null +++ b/.sqlx/query-1b13a067bda37fcf4fcf4a081c884e60c4843823da0b06ed3bf42431ec311236.json @@ -0,0 +1,38 @@ +{ + "db_name": "PostgreSQL", + "query": "SELECT tried, locked_by, lock_expires_at, error FROM pg_task LIMIT 1", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "tried", + "type_info": "Int4" + }, + { + "ordinal": 1, + "name": "locked_by", + "type_info": "Uuid" + }, + { + "ordinal": 2, + "name": "lock_expires_at", + "type_info": "Timestamptz" + }, + { + "ordinal": 3, + "name": "error", + "type_info": "Text" + } + ], + "parameters": { + "Left": [] + }, + "nullable": [ + false, + true, + true, + true + ] + }, + "hash": "1b13a067bda37fcf4fcf4a081c884e60c4843823da0b06ed3bf42431ec311236" +} diff --git a/.sqlx/query-419e27727faf579f041c979d5fddae1a16d5c691dc05aa5a4852434180516b1e.json b/.sqlx/query-419e27727faf579f041c979d5fddae1a16d5c691dc05aa5a4852434180516b1e.json new file mode 100644 index 0000000..50e33cb --- /dev/null +++ b/.sqlx/query-419e27727faf579f041c979d5fddae1a16d5c691dc05aa5a4852434180516b1e.json @@ -0,0 +1,12 @@ +{ + "db_name": "PostgreSQL", + "query": "ALTER TABLE pg_task ADD CONSTRAINT reject_errors CHECK (error IS NULL)", + "describe": { + "columns": [], + "parameters": { + "Left": [] + }, + "nullable": [] + }, + "hash": "419e27727faf579f041c979d5fddae1a16d5c691dc05aa5a4852434180516b1e" +} diff --git 
a/.sqlx/query-489b35c9fb11e7647e75ae3b30ca1931391252b367b22f56a0058d676e7c72e6.json b/.sqlx/query-489b35c9fb11e7647e75ae3b30ca1931391252b367b22f56a0058d676e7c72e6.json new file mode 100644 index 0000000..ae9963b --- /dev/null +++ b/.sqlx/query-489b35c9fb11e7647e75ae3b30ca1931391252b367b22f56a0058d676e7c72e6.json @@ -0,0 +1,23 @@ +{ + "db_name": "PostgreSQL", + "query": "\n UPDATE pg_task\n SET error = $2\n WHERE id = $1\n RETURNING updated_at\n ", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "updated_at", + "type_info": "Timestamptz" + } + ], + "parameters": { + "Left": [ + "Uuid", + "Text" + ] + }, + "nullable": [ + false + ] + }, + "hash": "489b35c9fb11e7647e75ae3b30ca1931391252b367b22f56a0058d676e7c72e6" +} diff --git a/.sqlx/query-4ce326ceeda8adef3cc7b549e0db9fd49c4605c64604d8f4a03226e923d7e4a1.json b/.sqlx/query-4ce326ceeda8adef3cc7b549e0db9fd49c4605c64604d8f4a03226e923d7e4a1.json new file mode 100644 index 0000000..828279b --- /dev/null +++ b/.sqlx/query-4ce326ceeda8adef3cc7b549e0db9fd49c4605c64604d8f4a03226e923d7e4a1.json @@ -0,0 +1,22 @@ +{ + "db_name": "PostgreSQL", + "query": "SELECT step FROM pg_task WHERE id = $1", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "step", + "type_info": "Text" + } + ], + "parameters": { + "Left": [ + "Uuid" + ] + }, + "nullable": [ + false + ] + }, + "hash": "4ce326ceeda8adef3cc7b549e0db9fd49c4605c64604d8f4a03226e923d7e4a1" +} diff --git a/.sqlx/query-4e5e8277069eecc7c3f0380a431bed3c117f2edca11d3976bffad53725c8469d.json b/.sqlx/query-4e5e8277069eecc7c3f0380a431bed3c117f2edca11d3976bffad53725c8469d.json new file mode 100644 index 0000000..266523f --- /dev/null +++ b/.sqlx/query-4e5e8277069eecc7c3f0380a431bed3c117f2edca11d3976bffad53725c8469d.json @@ -0,0 +1,16 @@ +{ + "db_name": "PostgreSQL", + "query": "\n INSERT INTO pg_task (step, wakeup_at, locked_by)\n VALUES ($1, $2, $3)\n ", + "describe": { + "columns": [], + "parameters": { + "Left": [ + "Text", + "Timestamptz", + "Uuid" + ] + }, + 
"nullable": [] + }, + "hash": "4e5e8277069eecc7c3f0380a431bed3c117f2edca11d3976bffad53725c8469d" +} diff --git a/.sqlx/query-4f88b4ad437e4b89cd80abdb978edeb995c0569dfd937bb2a993637ef6851b41.json b/.sqlx/query-4f88b4ad437e4b89cd80abdb978edeb995c0569dfd937bb2a993637ef6851b41.json new file mode 100644 index 0000000..1138c13 --- /dev/null +++ b/.sqlx/query-4f88b4ad437e4b89cd80abdb978edeb995c0569dfd937bb2a993637ef6851b41.json @@ -0,0 +1,16 @@ +{ + "db_name": "PostgreSQL", + "query": "\n INSERT INTO pg_task (step, wakeup_at, lock_expires_at)\n VALUES ($1, $2, $3)\n ", + "describe": { + "columns": [], + "parameters": { + "Left": [ + "Text", + "Timestamptz", + "Timestamptz" + ] + }, + "nullable": [] + }, + "hash": "4f88b4ad437e4b89cd80abdb978edeb995c0569dfd937bb2a993637ef6851b41" +} diff --git a/.sqlx/query-50e892ba3ad57ab603dbe2516bd7ca620d671461c37ea4b95980075c394bc8a2.json b/.sqlx/query-50e892ba3ad57ab603dbe2516bd7ca620d671461c37ea4b95980075c394bc8a2.json new file mode 100644 index 0000000..e768218 --- /dev/null +++ b/.sqlx/query-50e892ba3ad57ab603dbe2516bd7ca620d671461c37ea4b95980075c394bc8a2.json @@ -0,0 +1,26 @@ +{ + "db_name": "PostgreSQL", + "query": "\n INSERT INTO pg_task (step, wakeup_at, locked_by, lock_expires_at, error)\n VALUES ($1, $2, $3, $4, $5)\n RETURNING id\n ", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "id", + "type_info": "Uuid" + } + ], + "parameters": { + "Left": [ + "Text", + "Timestamptz", + "Uuid", + "Timestamptz", + "Text" + ] + }, + "nullable": [ + false + ] + }, + "hash": "50e892ba3ad57ab603dbe2516bd7ca620d671461c37ea4b95980075c394bc8a2" +} diff --git a/.sqlx/query-742156444758d69f3a2b634359e56303e9bd18d97d87667217d7e2d323a52035.json b/.sqlx/query-742156444758d69f3a2b634359e56303e9bd18d97d87667217d7e2d323a52035.json new file mode 100644 index 0000000..9519f29 --- /dev/null +++ b/.sqlx/query-742156444758d69f3a2b634359e56303e9bd18d97d87667217d7e2d323a52035.json @@ -0,0 +1,34 @@ +{ + "db_name": "PostgreSQL", + "query": "SELECT 
locked_by, lock_expires_at, error FROM pg_task WHERE id = $1", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "locked_by", + "type_info": "Uuid" + }, + { + "ordinal": 1, + "name": "lock_expires_at", + "type_info": "Timestamptz" + }, + { + "ordinal": 2, + "name": "error", + "type_info": "Text" + } + ], + "parameters": { + "Left": [ + "Uuid" + ] + }, + "nullable": [ + true, + true, + true + ] + }, + "hash": "742156444758d69f3a2b634359e56303e9bd18d97d87667217d7e2d323a52035" +} diff --git a/.sqlx/query-85366f4ba53750ca317ef4515b201e747634ec03d6f2afa45f6d9f1c697b43e0.json b/.sqlx/query-85366f4ba53750ca317ef4515b201e747634ec03d6f2afa45f6d9f1c697b43e0.json new file mode 100644 index 0000000..05a8eab --- /dev/null +++ b/.sqlx/query-85366f4ba53750ca317ef4515b201e747634ec03d6f2afa45f6d9f1c697b43e0.json @@ -0,0 +1,20 @@ +{ + "db_name": "PostgreSQL", + "query": "SELECT step FROM pg_task", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "step", + "type_info": "Text" + } + ], + "parameters": { + "Left": [] + }, + "nullable": [ + false + ] + }, + "hash": "85366f4ba53750ca317ef4515b201e747634ec03d6f2afa45f6d9f1c697b43e0" +} diff --git a/.sqlx/query-978c0565523b783d1a3c0140a239422cc625816045c989302013ec7cb8eafd1a.json b/.sqlx/query-978c0565523b783d1a3c0140a239422cc625816045c989302013ec7cb8eafd1a.json new file mode 100644 index 0000000..5bc5298 --- /dev/null +++ b/.sqlx/query-978c0565523b783d1a3c0140a239422cc625816045c989302013ec7cb8eafd1a.json @@ -0,0 +1,20 @@ +{ + "db_name": "PostgreSQL", + "query": "SELECT id FROM pg_task", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "id", + "type_info": "Uuid" + } + ], + "parameters": { + "Left": [] + }, + "nullable": [ + false + ] + }, + "hash": "978c0565523b783d1a3c0140a239422cc625816045c989302013ec7cb8eafd1a" +} diff --git a/.sqlx/query-9cc24e2538c05af994fb2cbed6cc83580f50abe7a2a90a95331255e20ee20682.json b/.sqlx/query-9cc24e2538c05af994fb2cbed6cc83580f50abe7a2a90a95331255e20ee20682.json new file 
mode 100644 index 0000000..3c24145 --- /dev/null +++ b/.sqlx/query-9cc24e2538c05af994fb2cbed6cc83580f50abe7a2a90a95331255e20ee20682.json @@ -0,0 +1,12 @@ +{ + "db_name": "PostgreSQL", + "query": "NOTIFY pg_task_changed, 'wake'", + "describe": { + "columns": [], + "parameters": { + "Left": [] + }, + "nullable": [] + }, + "hash": "9cc24e2538c05af994fb2cbed6cc83580f50abe7a2a90a95331255e20ee20682" +} diff --git a/.sqlx/query-acc7fc8f1a01f86436d22057525591df3dfa0991e9b5acb99331a7bf046cceb1.json b/.sqlx/query-acc7fc8f1a01f86436d22057525591df3dfa0991e9b5acb99331a7bf046cceb1.json new file mode 100644 index 0000000..64fc5f6 --- /dev/null +++ b/.sqlx/query-acc7fc8f1a01f86436d22057525591df3dfa0991e9b5acb99331a7bf046cceb1.json @@ -0,0 +1,12 @@ +{ + "db_name": "PostgreSQL", + "query": "ALTER TABLE pg_task RENAME COLUMN step TO task_step", + "describe": { + "columns": [], + "parameters": { + "Left": [] + }, + "nullable": [] + }, + "hash": "acc7fc8f1a01f86436d22057525591df3dfa0991e9b5acb99331a7bf046cceb1" +} diff --git a/.sqlx/query-b435c637c0d110cfac080c591edffd226245f51981a63818a67190fef4f44bf8.json b/.sqlx/query-b435c637c0d110cfac080c591edffd226245f51981a63818a67190fef4f44bf8.json new file mode 100644 index 0000000..6d37041 --- /dev/null +++ b/.sqlx/query-b435c637c0d110cfac080c591edffd226245f51981a63818a67190fef4f44bf8.json @@ -0,0 +1,27 @@ +{ + "db_name": "PostgreSQL", + "query": "\n INSERT INTO pg_task (step, wakeup_at, tried, locked_by, lock_expires_at, error)\n VALUES ($1, $2, $3, $4, $5, $6)\n RETURNING id\n ", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "id", + "type_info": "Uuid" + } + ], + "parameters": { + "Left": [ + "Text", + "Timestamptz", + "Int4", + "Uuid", + "Timestamptz", + "Text" + ] + }, + "nullable": [ + false + ] + }, + "hash": "b435c637c0d110cfac080c591edffd226245f51981a63818a67190fef4f44bf8" +} diff --git a/.sqlx/query-beb4efe0422bd9ecf970d3a77f426e84928e7d4b5ac27fd0118489a04700bc02.json 
b/.sqlx/query-beb4efe0422bd9ecf970d3a77f426e84928e7d4b5ac27fd0118489a04700bc02.json new file mode 100644 index 0000000..3f71faa --- /dev/null +++ b/.sqlx/query-beb4efe0422bd9ecf970d3a77f426e84928e7d4b5ac27fd0118489a04700bc02.json @@ -0,0 +1,28 @@ +{ + "db_name": "PostgreSQL", + "query": "SELECT step, wakeup_at FROM pg_task WHERE id = $1", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "step", + "type_info": "Text" + }, + { + "ordinal": 1, + "name": "wakeup_at", + "type_info": "Timestamptz" + } + ], + "parameters": { + "Left": [ + "Uuid" + ] + }, + "nullable": [ + false, + false + ] + }, + "hash": "beb4efe0422bd9ecf970d3a77f426e84928e7d4b5ac27fd0118489a04700bc02" +} diff --git a/.sqlx/query-bf104b9a07b92e6daef1346255c1fb9cbf4c30c2a767b5314e4cc7803b338cd9.json b/.sqlx/query-bf104b9a07b92e6daef1346255c1fb9cbf4c30c2a767b5314e4cc7803b338cd9.json new file mode 100644 index 0000000..f34ec31 --- /dev/null +++ b/.sqlx/query-bf104b9a07b92e6daef1346255c1fb9cbf4c30c2a767b5314e4cc7803b338cd9.json @@ -0,0 +1,12 @@ +{ + "db_name": "PostgreSQL", + "query": "NOTIFY pg_task_changed, 'stop_worker'", + "describe": { + "columns": [], + "parameters": { + "Left": [] + }, + "nullable": [] + }, + "hash": "bf104b9a07b92e6daef1346255c1fb9cbf4c30c2a767b5314e4cc7803b338cd9" +} diff --git a/.sqlx/query-c0e3d8bf263772b25bacfdf4ba0ff7d986e6348680f2e386dbb40fd6f37e0fbc.json b/.sqlx/query-c0e3d8bf263772b25bacfdf4ba0ff7d986e6348680f2e386dbb40fd6f37e0fbc.json new file mode 100644 index 0000000..b354491 --- /dev/null +++ b/.sqlx/query-c0e3d8bf263772b25bacfdf4ba0ff7d986e6348680f2e386dbb40fd6f37e0fbc.json @@ -0,0 +1,29 @@ +{ + "db_name": "PostgreSQL", + "query": "\n INSERT INTO pg_task (step, wakeup_at)\n VALUES ($1, $2)\n RETURNING id, updated_at\n ", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "id", + "type_info": "Uuid" + }, + { + "ordinal": 1, + "name": "updated_at", + "type_info": "Timestamptz" + } + ], + "parameters": { + "Left": [ + "Text", + "Timestamptz" + ] + }, 
+ "nullable": [ + false, + false + ] + }, + "hash": "c0e3d8bf263772b25bacfdf4ba0ff7d986e6348680f2e386dbb40fd6f37e0fbc" +} diff --git a/.sqlx/query-cb9d98bc740a6380e027f17cacd6efe8f10ef3e3e056494ca58f0dc013f92643.json b/.sqlx/query-cb9d98bc740a6380e027f17cacd6efe8f10ef3e3e056494ca58f0dc013f92643.json new file mode 100644 index 0000000..550c544 --- /dev/null +++ b/.sqlx/query-cb9d98bc740a6380e027f17cacd6efe8f10ef3e3e056494ca58f0dc013f92643.json @@ -0,0 +1,28 @@ +{ + "db_name": "PostgreSQL", + "query": "\n SELECT locked_by, lock_expires_at\n FROM pg_task\n WHERE id = $1\n ", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "locked_by", + "type_info": "Uuid" + }, + { + "ordinal": 1, + "name": "lock_expires_at", + "type_info": "Timestamptz" + } + ], + "parameters": { + "Left": [ + "Uuid" + ] + }, + "nullable": [ + true, + true + ] + }, + "hash": "cb9d98bc740a6380e027f17cacd6efe8f10ef3e3e056494ca58f0dc013f92643" +} diff --git a/.sqlx/query-cc6b874dc24ea589c93c81f9dbc3c4f057a82ce4f453c56e4b6728f5ec9931b6.json b/.sqlx/query-cc6b874dc24ea589c93c81f9dbc3c4f057a82ce4f453c56e4b6728f5ec9931b6.json new file mode 100644 index 0000000..449c528 --- /dev/null +++ b/.sqlx/query-cc6b874dc24ea589c93c81f9dbc3c4f057a82ce4f453c56e4b6728f5ec9931b6.json @@ -0,0 +1,20 @@ +{ + "db_name": "PostgreSQL", + "query": "SELECT current_database() AS \"db_name!\"", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "db_name!", + "type_info": "Name" + } + ], + "parameters": { + "Left": [] + }, + "nullable": [ + null + ] + }, + "hash": "cc6b874dc24ea589c93c81f9dbc3c4f057a82ce4f453c56e4b6728f5ec9931b6" +} diff --git a/.sqlx/query-d40cd4c235045a61e8fb6d607d8d737c54c095c4fa8a8e85ecba8593c8a42267.json b/.sqlx/query-d40cd4c235045a61e8fb6d607d8d737c54c095c4fa8a8e85ecba8593c8a42267.json new file mode 100644 index 0000000..cc333ee --- /dev/null +++ b/.sqlx/query-d40cd4c235045a61e8fb6d607d8d737c54c095c4fa8a8e85ecba8593c8a42267.json @@ -0,0 +1,20 @@ +{ + "db_name": "PostgreSQL", + "query": 
"SELECT id FROM pg_task WHERE locked_by IS NOT NULL", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "id", + "type_info": "Uuid" + } + ], + "parameters": { + "Left": [] + }, + "nullable": [ + false + ] + }, + "hash": "d40cd4c235045a61e8fb6d607d8d737c54c095c4fa8a8e85ecba8593c8a42267" +} diff --git a/.sqlx/query-d4feaaaaee7399d5c85aa76a02fd57eafced6836523a3312cac57fb13396834e.json b/.sqlx/query-d4feaaaaee7399d5c85aa76a02fd57eafced6836523a3312cac57fb13396834e.json new file mode 100644 index 0000000..a04480a --- /dev/null +++ b/.sqlx/query-d4feaaaaee7399d5c85aa76a02fd57eafced6836523a3312cac57fb13396834e.json @@ -0,0 +1,15 @@ +{ + "db_name": "PostgreSQL", + "query": "INSERT INTO pg_task (step, wakeup_at) VALUES ($1, $2)", + "describe": { + "columns": [], + "parameters": { + "Left": [ + "Text", + "Timestamptz" + ] + }, + "nullable": [] + }, + "hash": "d4feaaaaee7399d5c85aa76a02fd57eafced6836523a3312cac57fb13396834e" +} diff --git a/.sqlx/query-d50ef96acf789ca4e715dfaec291895da495a64ff6719724150b1f939617971a.json b/.sqlx/query-d50ef96acf789ca4e715dfaec291895da495a64ff6719724150b1f939617971a.json new file mode 100644 index 0000000..95f24e4 --- /dev/null +++ b/.sqlx/query-d50ef96acf789ca4e715dfaec291895da495a64ff6719724150b1f939617971a.json @@ -0,0 +1,40 @@ +{ + "db_name": "PostgreSQL", + "query": "SELECT tried, locked_by, lock_expires_at, error FROM pg_task WHERE id = $1", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "tried", + "type_info": "Int4" + }, + { + "ordinal": 1, + "name": "locked_by", + "type_info": "Uuid" + }, + { + "ordinal": 2, + "name": "lock_expires_at", + "type_info": "Timestamptz" + }, + { + "ordinal": 3, + "name": "error", + "type_info": "Text" + } + ], + "parameters": { + "Left": [ + "Uuid" + ] + }, + "nullable": [ + false, + true, + true, + true + ] + }, + "hash": "d50ef96acf789ca4e715dfaec291895da495a64ff6719724150b1f939617971a" +} diff --git 
a/.sqlx/query-dafc146b51ef7a3fc6fe1d9c5487b650e08c2f7de73485c20758eddb1164c827.json b/.sqlx/query-dafc146b51ef7a3fc6fe1d9c5487b650e08c2f7de73485c20758eddb1164c827.json new file mode 100644 index 0000000..1bb66d9 --- /dev/null +++ b/.sqlx/query-dafc146b51ef7a3fc6fe1d9c5487b650e08c2f7de73485c20758eddb1164c827.json @@ -0,0 +1,12 @@ +{ + "db_name": "PostgreSQL", + "query": "ALTER TABLE pg_task RENAME COLUMN error TO task_error", + "describe": { + "columns": [], + "parameters": { + "Left": [] + }, + "nullable": [] + }, + "hash": "dafc146b51ef7a3fc6fe1d9c5487b650e08c2f7de73485c20758eddb1164c827" +} diff --git a/.sqlx/query-de8b346243c7de75d3329eb4852d9d9486783dcba5e1a8d96a58f074e766db83.json b/.sqlx/query-de8b346243c7de75d3329eb4852d9d9486783dcba5e1a8d96a58f074e766db83.json new file mode 100644 index 0000000..53a287c --- /dev/null +++ b/.sqlx/query-de8b346243c7de75d3329eb4852d9d9486783dcba5e1a8d96a58f074e766db83.json @@ -0,0 +1,15 @@ +{ + "db_name": "PostgreSQL", + "query": "UPDATE pg_task SET error = $2 WHERE id = $1", + "describe": { + "columns": [], + "parameters": { + "Left": [ + "Uuid", + "Text" + ] + }, + "nullable": [] + }, + "hash": "de8b346243c7de75d3329eb4852d9d9486783dcba5e1a8d96a58f074e766db83" +} diff --git a/.sqlx/query-e206f3b2ae75c04296bd3b44f26e0cf015372eff542ef6220439ae4d8f2ec43e.json b/.sqlx/query-e206f3b2ae75c04296bd3b44f26e0cf015372eff542ef6220439ae4d8f2ec43e.json new file mode 100644 index 0000000..1ad98c6 --- /dev/null +++ b/.sqlx/query-e206f3b2ae75c04296bd3b44f26e0cf015372eff542ef6220439ae4d8f2ec43e.json @@ -0,0 +1,12 @@ +{ + "db_name": "PostgreSQL", + "query": "ALTER TABLE pg_task RENAME COLUMN id TO task_id", + "describe": { + "columns": [], + "parameters": { + "Left": [] + }, + "nullable": [] + }, + "hash": "e206f3b2ae75c04296bd3b44f26e0cf015372eff542ef6220439ae4d8f2ec43e" +} diff --git a/.sqlx/query-e4688afeb7fb77e0e5961ac1650dc35efc017db15316c1b2e1ab8320988f9b83.json 
b/.sqlx/query-e4688afeb7fb77e0e5961ac1650dc35efc017db15316c1b2e1ab8320988f9b83.json new file mode 100644 index 0000000..a045eeb --- /dev/null +++ b/.sqlx/query-e4688afeb7fb77e0e5961ac1650dc35efc017db15316c1b2e1ab8320988f9b83.json @@ -0,0 +1,12 @@ +{ + "db_name": "PostgreSQL", + "query": "ALTER TABLE pg_task RENAME COLUMN lock_expires_at TO task_lock_expires_at", + "describe": { + "columns": [], + "parameters": { + "Left": [] + }, + "nullable": [] + }, + "hash": "e4688afeb7fb77e0e5961ac1650dc35efc017db15316c1b2e1ab8320988f9b83" +} diff --git a/.sqlx/query-ebc5a43458570f6f64356d4fdffb906b7b52ff09a55a3064ed7943558234b103.json b/.sqlx/query-ebc5a43458570f6f64356d4fdffb906b7b52ff09a55a3064ed7943558234b103.json new file mode 100644 index 0000000..8c0474e --- /dev/null +++ b/.sqlx/query-ebc5a43458570f6f64356d4fdffb906b7b52ff09a55a3064ed7943558234b103.json @@ -0,0 +1,14 @@ +{ + "db_name": "PostgreSQL", + "query": "DELETE FROM pg_task WHERE id = $1", + "describe": { + "columns": [], + "parameters": { + "Left": [ + "Uuid" + ] + }, + "nullable": [] + }, + "hash": "ebc5a43458570f6f64356d4fdffb906b7b52ff09a55a3064ed7943558234b103" +} diff --git a/.sqlx/query-ef3f510b525d1693e135cc7ae3e018d997ccb4affb898df3eb59ba06b2d5a3ce.json b/.sqlx/query-ef3f510b525d1693e135cc7ae3e018d997ccb4affb898df3eb59ba06b2d5a3ce.json new file mode 100644 index 0000000..07fa624 --- /dev/null +++ b/.sqlx/query-ef3f510b525d1693e135cc7ae3e018d997ccb4affb898df3eb59ba06b2d5a3ce.json @@ -0,0 +1,52 @@ +{ + "db_name": "PostgreSQL", + "query": "\n SELECT step, wakeup_at, tried, locked_by, lock_expires_at, error\n FROM pg_task\n WHERE id = $1\n ", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "step", + "type_info": "Text" + }, + { + "ordinal": 1, + "name": "wakeup_at", + "type_info": "Timestamptz" + }, + { + "ordinal": 2, + "name": "tried", + "type_info": "Int4" + }, + { + "ordinal": 3, + "name": "locked_by", + "type_info": "Uuid" + }, + { + "ordinal": 4, + "name": "lock_expires_at", + "type_info": 
"Timestamptz" + }, + { + "ordinal": 5, + "name": "error", + "type_info": "Text" + } + ], + "parameters": { + "Left": [ + "Uuid" + ] + }, + "nullable": [ + false, + false, + false, + true, + true, + true + ] + }, + "hash": "ef3f510b525d1693e135cc7ae3e018d997ccb4affb898df3eb59ba06b2d5a3ce" +} diff --git a/.sqlx/query-f44843af1870d4d9dc02dcec353a198b4801f453ef61e8dcc216432b1b1b947e.json b/.sqlx/query-f44843af1870d4d9dc02dcec353a198b4801f453ef61e8dcc216432b1b1b947e.json new file mode 100644 index 0000000..22fb4da --- /dev/null +++ b/.sqlx/query-f44843af1870d4d9dc02dcec353a198b4801f453ef61e8dcc216432b1b1b947e.json @@ -0,0 +1,22 @@ +{ + "db_name": "PostgreSQL", + "query": "SELECT id FROM pg_task WHERE id = $1 FOR UPDATE", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "id", + "type_info": "Uuid" + } + ], + "parameters": { + "Left": [ + "Uuid" + ] + }, + "nullable": [ + false + ] + }, + "hash": "f44843af1870d4d9dc02dcec353a198b4801f453ef61e8dcc216432b1b1b947e" +} diff --git a/migrations/20260510000000_notify-on-delete.sql b/migrations/20260510000000_notify-on-delete.sql new file mode 100644 index 0000000..f1e7e9e --- /dev/null +++ b/migrations/20260510000000_notify-on-delete.sql @@ -0,0 +1,37 @@ +DROP TRIGGER pg_task_changed ON pg_task; + +CREATE OR REPLACE FUNCTION pg_task_notify_on_change() +RETURNS trigger AS $$ +BEGIN + PERFORM pg_notify('pg_task_changed', ''); + IF TG_OP = 'DELETE' THEN + RETURN OLD; + END IF; + RETURN NEW; +END; +$$ LANGUAGE plpgsql; + +CREATE TRIGGER pg_task_changed_insert +AFTER INSERT +ON pg_task +FOR EACH ROW +EXECUTE PROCEDURE pg_task_notify_on_change(); + +CREATE TRIGGER pg_task_changed_delete +AFTER DELETE +ON pg_task +FOR EACH ROW +EXECUTE PROCEDURE pg_task_notify_on_change(); + +CREATE TRIGGER pg_task_changed_update +AFTER UPDATE +ON pg_task +FOR EACH ROW +WHEN ( + OLD.step IS DISTINCT FROM NEW.step + OR OLD.wakeup_at IS DISTINCT FROM NEW.wakeup_at + OR OLD.tried IS DISTINCT FROM NEW.tried + OR OLD.error IS DISTINCT FROM 
NEW.error + OR OLD.locked_by IS DISTINCT FROM NEW.locked_by +) +EXECUTE PROCEDURE pg_task_notify_on_change(); diff --git a/migrations/20260510000100_task-lease-indexes.sql b/migrations/20260510000100_task-lease-indexes.sql new file mode 100644 index 0000000..253bcd2 --- /dev/null +++ b/migrations/20260510000100_task-lease-indexes.sql @@ -0,0 +1,17 @@ +DROP INDEX pg_task_lock_expires_at_idx; + +CREATE INDEX pg_task_locked_by_idx +ON pg_task (locked_by) +WHERE locked_by IS NOT NULL + AND error IS NULL; + +CREATE INDEX pg_task_next_available_at_idx +ON pg_task (( + CASE + WHEN locked_by IS NOT NULL THEN + GREATEST(wakeup_at, lock_expires_at) + ELSE + wakeup_at + END +)) +WHERE error IS NULL; From af1a6c2aa05c04b012b38a7d16b130ec9a4384e2 Mon Sep 17 00:00:00 2001 From: imbolc Date: Sun, 10 May 2026 07:51:15 +0600 Subject: [PATCH 15/44] Fix worker lease renewal lifecycle --- src/task.rs | 2 +- src/worker.rs | 238 +++++++++++++++++++++++++++++++++++--------------- 2 files changed, 170 insertions(+), 70 deletions(-) diff --git a/src/task.rs b/src/task.rs index 9592c14..789acba 100644 --- a/src/task.rs +++ b/src/task.rs @@ -402,7 +402,7 @@ impl Task { fn log_lost_lease(&self, worker_id: Uuid, action: &str) { warn!( - "[{}] couldn't {action} because worker {worker_id} no longer owns the task", + "[{}] couldn't {action} because worker {worker_id}'s lease expired or is no longer owned by this worker", self.id ); } diff --git a/src/worker.rs b/src/worker.rs index f72f0ca..43d8ec6 100644 --- a/src/worker.rs +++ b/src/worker.rs @@ -34,7 +34,6 @@ pub struct Worker { listener: Listener, tasks: PhantomData, concurrency: NonZeroUsize, - worker_id: Uuid, lease_timeout: Duration, heartbeat_interval: Duration, } @@ -48,7 +47,6 @@ impl + 'static> Worker { db, listener, concurrency, - worker_id: Uuid::new_v4(), lease_timeout: DEFAULT_LEASE_TIMEOUT, heartbeat_interval: DEFAULT_HEARTBEAT_INTERVAL, tasks: PhantomData, @@ -83,12 +81,14 @@ impl + 'static> Worker { /// Runs all ready tasks to 
completion and waits for new ones pub async fn run(&self) -> Result<()> { + self.validate_lease_timing(); self.listener.listen(self.db.clone()).await?; + let lease = TaskLease::new(Uuid::new_v4(), self.lease_timeout); let semaphore = Arc::new(Semaphore::new(self.concurrency.get())); let running_steps = Arc::new(Mutex::new(Vec::new())); let (heartbeat_events_sender, mut heartbeat_events) = mpsc::unbounded_channel(); - let heartbeat = self.spawn_heartbeat(heartbeat_events_sender, running_steps.clone()); + let heartbeat = self.spawn_heartbeat(heartbeat_events_sender, running_steps.clone(), lease); let (step_error_sender, mut step_errors) = mpsc::unbounded_channel(); let mut heartbeat_healthy = true; let mut abort_running_steps = false; @@ -138,7 +138,7 @@ impl + 'static> Worker { reserved_permit = Some(permit.map_err(Error::UnreachableWorkerSemaphoreClosed)?); continue; } - received = self.recv_task(), if heartbeat_healthy && reserved_permit.is_some() => received, + received = self.recv_task(lease), if heartbeat_healthy && reserved_permit.is_some() => received, }; match received { Ok(Some((task, step, lease))) => { @@ -188,9 +188,16 @@ impl + 'static> Worker { .await } + fn validate_lease_timing(&self) { + assert!( + self.heartbeat_interval < self.lease_timeout, + "heartbeat interval must be shorter than lease timeout" + ); + } + /// Waits until the next task is ready, marks it running and returns it. /// Returns `None` if the worker is stopped - async fn recv_task(&self) -> Result> { + async fn recv_task(&self, lease: TaskLease) -> Result> { trace!("Receiving the next task"); loop { @@ -219,7 +226,6 @@ impl + 'static> Worker { continue; }; - let lease = TaskLease::new(self.worker_id, self.lease_timeout); let Some(step) = task.claim(&mut tx, lease).await? 
else { tx.commit().await.map_err(db_error!("save error"))?; continue; @@ -298,9 +304,10 @@ impl + 'static> Worker { &self, events: mpsc::UnboundedSender, running_steps: Arc>>, + lease: TaskLease, ) -> tokio::task::AbortHandle { + self.validate_lease_timing(); let db = self.db.clone(); - let lease = TaskLease::new(self.worker_id, self.lease_timeout); let mut heartbeat = interval(self.heartbeat_interval); let heartbeat_interval = self.heartbeat_interval; let lease_timeout = self.lease_timeout; @@ -311,6 +318,14 @@ impl + 'static> Worker { heartbeat.tick().await; loop { heartbeat.tick().await; + if !Self::has_running_steps(&running_steps) { + last_renewed_at = Instant::now(); + if renewal_failed { + let _ = events.send(HeartbeatEvent::Recovered); + renewal_failed = false; + } + continue; + } match Task::renew_leases(&db, lease).await { Ok(renewed) if renewed > 0 => { trace!("Renewed {renewed} task leases"); @@ -431,31 +446,15 @@ impl + 'static> Worker { } async fn wait_for_steps_to_finish(&self, semaphore: Arc) { - let mut logged_tasks_left = None; - loop { - let tasks_left = self.concurrency.get() - semaphore.available_permits(); - if tasks_left == 0 { - break; - } - if let Some(logged) = logged_tasks_left { - if logged != tasks_left { - trace!("Waiting for the current steps of {tasks_left} tasks to finish..."); - } - } else { - info!("Waiting for the current steps of {tasks_left} tasks to finish..."); - } - logged_tasks_left = Some(tasks_left); - sleep(Duration::from_secs_f32(0.1)).await; - } - if logged_tasks_left.is_some() { - trace!("The current step of every task is done") - } + self.wait_for_steps_to_finish_impl(semaphore, None) + .await + .expect("waiting without heartbeat events cannot fail"); } - async fn wait_for_steps_to_finish_or_heartbeat( + async fn wait_for_steps_to_finish_impl( &self, semaphore: Arc, - mut heartbeat_events: mpsc::UnboundedReceiver, + mut heartbeat_events: Option>, ) -> Result<()> { let mut logged_tasks_left = None; let mut 
heartbeat_healthy = true; @@ -472,11 +471,15 @@ impl + 'static> Worker { info!("Waiting for the current steps of {tasks_left} tasks to finish..."); } logged_tasks_left = Some(tasks_left); - tokio::select! { - Some(event) = heartbeat_events.recv() => { - Self::handle_heartbeat_event(event, &mut heartbeat_healthy)?; + if let Some(heartbeat_events) = heartbeat_events.as_mut() { + tokio::select! { + Some(event) = heartbeat_events.recv() => { + Self::handle_heartbeat_event(event, &mut heartbeat_healthy)?; + } + _ = sleep(Duration::from_secs_f32(0.1)) => {} } - _ = sleep(Duration::from_secs_f32(0.1)) => {} + } else { + sleep(Duration::from_secs_f32(0.1)).await; } } if logged_tasks_left.is_some() { @@ -484,12 +487,21 @@ impl + 'static> Worker { } Ok(()) } + + async fn wait_for_steps_to_finish_or_heartbeat( + &self, + semaphore: Arc, + heartbeat_events: mpsc::UnboundedReceiver, + ) -> Result<()> { + self.wait_for_steps_to_finish_impl(semaphore, Some(heartbeat_events)) + .await + } } #[cfg(test)] mod tests { - use super::{HeartbeatEvent, Worker}; - use crate::{Error, NextStep, Step}; + use super::{HeartbeatEvent, Worker, DEFAULT_HEARTBEAT_INTERVAL, DEFAULT_LEASE_TIMEOUT}; + use crate::{task::TaskLease, Error, NextStep, Step}; use chrono::{Duration as ChronoDuration, Utc}; use sqlx::{postgres::PgPoolOptions, PgPool}; use std::{ @@ -892,6 +904,10 @@ mod tests { NonZeroUsize::new(value).unwrap() } + fn worker_lease(worker: &Worker) -> TaskLease { + TaskLease::new(Uuid::new_v4(), worker.lease_timeout) + } + fn spawn_worker(pool: PgPool) -> tokio::task::JoinHandle> { spawn_worker_with_concurrency(pool, 1) } @@ -930,6 +946,32 @@ mod tests { .with_heartbeat_interval(Duration::ZERO); } + #[tokio::test] + #[should_panic(expected = "heartbeat interval must be shorter than lease timeout")] + async fn run_rejects_lease_timeout_that_is_not_longer_than_the_heartbeat_interval() { + let worker = Worker::::new( + PgPoolOptions::new() + .connect_lazy("postgres:///pg_task") + .unwrap(), + ) + 
.with_lease_timeout(DEFAULT_HEARTBEAT_INTERVAL); + + let _ = worker.run().await; + } + + #[tokio::test] + #[should_panic(expected = "heartbeat interval must be shorter than lease timeout")] + async fn run_rejects_heartbeat_interval_that_is_not_shorter_than_the_lease_timeout() { + let worker = Worker::::new( + PgPoolOptions::new() + .connect_lazy("postgres:///pg_task") + .unwrap(), + ) + .with_heartbeat_interval(DEFAULT_LEASE_TIMEOUT); + + let _ = worker.run().await; + } + #[test] fn heartbeat_events_pause_resume_and_expire_fetching() { let mut heartbeat_healthy = true; @@ -1052,7 +1094,7 @@ mod tests { } #[sqlx::test(migrations = "./migrations")] - async fn heartbeat_failures_without_running_steps_do_not_expire(pool: PgPool) { + async fn heartbeat_skips_renewal_without_running_steps(pool: PgPool) { init_tracing(); sqlx::query!("ALTER TABLE pg_task RENAME COLUMN lock_expires_at TO task_lock_expires_at") .execute(&pool) @@ -1062,13 +1104,12 @@ mod tests { .with_lease_timeout(Duration::from_millis(80)) .with_heartbeat_interval(Duration::from_millis(20)); let (events, mut events_receiver) = mpsc::unbounded_channel(); - let heartbeat = worker.spawn_heartbeat(events, Arc::new(Mutex::new(Vec::new()))); + let heartbeat = worker.spawn_heartbeat( + events, + Arc::new(Mutex::new(Vec::new())), + worker_lease(&worker), + ); - let event = timeout(Duration::from_secs(1), events_receiver.recv()) - .await - .unwrap() - .unwrap(); - assert!(matches!(event, HeartbeatEvent::Failed)); assert!(timeout(Duration::from_millis(150), events_receiver.recv()) .await .is_err()); @@ -1087,7 +1128,7 @@ mod tests { }); let running_steps = Arc::new(Mutex::new(vec![running_step.abort_handle()])); let (events, mut events_receiver) = mpsc::unbounded_channel(); - let heartbeat = worker.spawn_heartbeat(events, running_steps); + let heartbeat = worker.spawn_heartbeat(events, running_steps, worker_lease(&worker)); let event = timeout(Duration::from_secs(1), events_receiver.recv()) .await @@ -1109,7 
+1150,7 @@ mod tests { } #[sqlx::test(migrations = "./migrations")] - async fn heartbeat_reports_recovery_after_renewal_failures_stop(pool: PgPool) { + async fn heartbeat_skips_pool_timeouts_without_running_steps(pool: PgPool) { init_tracing(); let worker_pool = connect_to_current_db(&pool, 1, Duration::from_millis(20)).await; let held_connection = worker_pool.acquire().await.unwrap(); @@ -1117,22 +1158,17 @@ mod tests { .with_lease_timeout(Duration::from_millis(500)) .with_heartbeat_interval(Duration::from_millis(20)); let (events, mut events_receiver) = mpsc::unbounded_channel(); - let heartbeat = worker.spawn_heartbeat(events, Arc::new(Mutex::new(Vec::new()))); + let heartbeat = worker.spawn_heartbeat( + events, + Arc::new(Mutex::new(Vec::new())), + worker_lease(&worker), + ); - let event = timeout(Duration::from_secs(1), events_receiver.recv()) + assert!(timeout(Duration::from_millis(150), events_receiver.recv()) .await - .unwrap() - .unwrap(); - assert!(matches!(event, HeartbeatEvent::Failed)); + .is_err()); drop(held_connection); - - let event = timeout(Duration::from_secs(1), events_receiver.recv()) - .await - .unwrap() - .unwrap(); - assert!(matches!(event, HeartbeatEvent::Recovered)); - heartbeat.abort(); } @@ -1144,6 +1180,8 @@ mod tests { let worker = Worker::::new(worker_pool) .with_lease_timeout(Duration::from_millis(500)) .with_heartbeat_interval(Duration::from_millis(20)); + let worker_id = Uuid::new_v4(); + let lease = TaskLease::new(worker_id, worker.lease_timeout); let id = insert_task_at( &pool, &TestTask::Noop(Noop), @@ -1160,14 +1198,18 @@ mod tests { WHERE id = $1 ", id, - worker.worker_id, + worker_id, initial_expires_at, ) .execute(&pool) .await .unwrap(); + let running_step = tokio::spawn(async { + std::future::pending::<()>().await; + }); + let running_steps = Arc::new(Mutex::new(vec![running_step.abort_handle()])); let (events, mut events_receiver) = mpsc::unbounded_channel(); - let heartbeat = worker.spawn_heartbeat(events, 
Arc::new(Mutex::new(Vec::new()))); + let heartbeat = worker.spawn_heartbeat(events, running_steps, lease); let event = timeout(Duration::from_secs(1), events_receiver.recv()) .await @@ -1187,6 +1229,7 @@ mod tests { assert!(renewed_expires_at > initial_expires_at); heartbeat.abort(); + running_step.abort(); } #[tokio::test] @@ -1302,7 +1345,8 @@ mod tests { "listener failed".into(), ))); - let err = worker.recv_task().await.unwrap_err(); + let lease = worker_lease(&worker); + let err = worker.recv_task(lease).await.unwrap_err(); assert!(matches!( err, @@ -1320,7 +1364,8 @@ mod tests { ))); worker.listener.stop_worker_for_tests(); - assert!(worker.recv_task().await.unwrap().is_none()); + let lease = worker_lease(&worker); + assert!(worker.recv_task(lease).await.unwrap().is_none()); } #[sqlx::test(migrations = "./migrations")] @@ -1328,7 +1373,8 @@ mod tests { let worker = Worker::::new(pool.clone()); pool.close().await; - let err = worker.recv_task().await.unwrap_err(); + let lease = worker_lease(&worker); + let err = worker.recv_task(lease).await.unwrap_err(); match err { Error::Db(sqlx::Error::PoolClosed, context) => { @@ -1341,9 +1387,10 @@ mod tests { #[sqlx::test(migrations = "./migrations")] async fn recv_task_stops_while_waiting_for_work(pool: PgPool) { let worker = Arc::new(Worker::::new(pool)); + let lease = worker_lease(&worker); let recv = tokio::spawn({ let worker = worker.clone(); - async move { worker.recv_task().await } + async move { worker.recv_task(lease).await } }); sleep(Duration::from_millis(50)).await; @@ -1361,9 +1408,10 @@ mod tests { #[sqlx::test(migrations = "./migrations")] async fn recv_task_returns_listener_errors_while_waiting_for_work(pool: PgPool) { let worker = Arc::new(Worker::::new(pool)); + let lease = worker_lease(&worker); let recv = tokio::spawn({ let worker = worker.clone(); - async move { worker.recv_task().await } + async move { worker.recv_task(lease).await } }); sleep(Duration::from_millis(50)).await; @@ -1403,8 +1451,9 @@ 
mod tests { ) .await; let worker = Worker::::new(pool.clone()); + let lease = worker_lease(&worker); - let (task, step, _lease) = worker.recv_task().await.unwrap().unwrap(); + let (task, step, _lease) = worker.recv_task(lease).await.unwrap().unwrap(); assert_eq!(task.id, expected); assert!(matches!(step, TestTask::Noop(Noop))); @@ -1437,7 +1486,8 @@ mod tests { assert_eq!(locked.id, id); let worker = Worker::::new(pool); - let recv = tokio::spawn(async move { worker.recv_task().await }); + let lease = worker_lease(&worker); + let recv = tokio::spawn(async move { worker.recv_task(lease).await }); sleep(Duration::from_millis(50)).await; assert!(!recv.is_finished()); @@ -1466,7 +1516,8 @@ mod tests { set_task_lease(&pool, id, Utc::now() + ChronoDuration::milliseconds(100)).await; let worker = Worker::::new(pool); - let recv = tokio::spawn(async move { worker.recv_task().await }); + let lease = worker_lease(&worker); + let recv = tokio::spawn(async move { worker.recv_task(lease).await }); sleep(Duration::from_millis(50)).await; assert!(!recv.is_finished()); @@ -1492,9 +1543,10 @@ mod tests { .await; set_task_lease(&pool, id, Utc::now() - ChronoDuration::milliseconds(1)).await; let worker = Worker::::new(pool.clone()); - let worker_id = worker.worker_id; + let worker_id = Uuid::new_v4(); + let lease = TaskLease::new(worker_id, worker.lease_timeout); - let (task, step, _lease) = worker.recv_task().await.unwrap().unwrap(); + let (task, step, _lease) = worker.recv_task(lease).await.unwrap().unwrap(); assert_eq!(task.id, id); assert!(matches!(step, TestTask::Noop(Noop))); @@ -1521,9 +1573,11 @@ mod tests { .await; let first_worker = Worker::::new(pool.clone()); let second_worker = Worker::::new(pool.clone()); + let first_lease = worker_lease(&first_worker); + let second_lease = worker_lease(&second_worker); - let first_recv = tokio::spawn(async move { first_worker.recv_task().await }); - let second_recv = tokio::spawn(async move { second_worker.recv_task().await }); + let 
first_recv = tokio::spawn(async move { first_worker.recv_task(first_lease).await }); + let second_recv = tokio::spawn(async move { second_worker.recv_task(second_lease).await }); let (first_task, first_step, _first_lease) = timeout(Duration::from_secs(1), first_recv) .await @@ -2284,6 +2338,52 @@ mod tests { assert!(matches!(err, Error::Db(sqlx::Error::Database(_), _))); } + #[sqlx::test(migrations = "./migrations")] + async fn rerunning_worker_does_not_renew_abandoned_leases_from_previous_runs(pool: PgPool) { + init_tracing(); + sqlx::query!("ALTER TABLE pg_task ADD CONSTRAINT reject_errors CHECK (error IS NULL)") + .execute(&pool) + .await + .unwrap(); + let state = StepStateGuard::new(); + let id = insert_task_at( + &pool, + &TestTask::FailStep(FailStep { key: state.key() }), + Utc::now() - ChronoDuration::milliseconds(1), + false, + ) + .await; + let worker = Worker::::new(pool.clone()) + .with_concurrency(nonzero(1)) + .with_lease_timeout(Duration::from_secs(1)) + .with_heartbeat_interval(Duration::from_millis(50)); + + let err = timeout(Duration::from_secs(1), worker.run()) + .await + .unwrap() + .unwrap_err(); + assert!(matches!(err, Error::Db(sqlx::Error::Database(_), _))); + let (abandoned_owner, abandoned_expires_at) = fetch_task_lease(&pool, id).await.unwrap(); + + let rerun = tokio::spawn({ + let worker = worker; + async move { worker.run().await } + }); + + sleep(Duration::from_millis(150)).await; + let (locked_by, lock_expires_at) = fetch_task_lease(&pool, id).await.unwrap(); + assert_eq!(locked_by, abandoned_owner); + assert_eq!(lock_expires_at, abandoned_expires_at); + + stop_worker(&pool).await; + + timeout(Duration::from_secs(1), rerun) + .await + .unwrap() + .unwrap() + .unwrap(); + } + #[sqlx::test(migrations = "./migrations")] async fn run_returns_step_errors_from_spawned_tasks(pool: PgPool) { let state = StepStateGuard::new(); From 954d8919b90f5ffc632290b96e7f1663457779a5 Mon Sep 17 00:00:00 2001 From: imbolc Date: Sun, 10 May 2026 15:03:30 
+0600 Subject: [PATCH 16/44] Fix lease timing and drain step errors --- ...1bb260c7ca01d0a418cd08d44783c445ea5a7.json | 15 -- ...33efd3bbce043b99c2b308622fb67a1357d36.json | 16 ++ ...07d7c64ad2c7d851bf25984dccb39665a3e40.json | 15 ++ src/task.rs | 26 ++- src/worker.rs | 212 ++++++++++++++++-- 5 files changed, 233 insertions(+), 51 deletions(-) delete mode 100644 .sqlx/query-03dce7dbc7f7fd28c7bd3bb925f1bb260c7ca01d0a418cd08d44783c445ea5a7.json create mode 100644 .sqlx/query-5fa64eb35dad3a1a8639ec4f5b933efd3bbce043b99c2b308622fb67a1357d36.json create mode 100644 .sqlx/query-6e749543b97bb037981ba70bc0407d7c64ad2c7d851bf25984dccb39665a3e40.json diff --git a/.sqlx/query-03dce7dbc7f7fd28c7bd3bb925f1bb260c7ca01d0a418cd08d44783c445ea5a7.json b/.sqlx/query-03dce7dbc7f7fd28c7bd3bb925f1bb260c7ca01d0a418cd08d44783c445ea5a7.json deleted file mode 100644 index a4866c3..0000000 --- a/.sqlx/query-03dce7dbc7f7fd28c7bd3bb925f1bb260c7ca01d0a418cd08d44783c445ea5a7.json +++ /dev/null @@ -1,15 +0,0 @@ -{ - "db_name": "PostgreSQL", - "query": "\n UPDATE pg_task\n SET lock_expires_at = $2\n WHERE locked_by = $1\n AND lock_expires_at > now()\n AND error IS NULL\n ", - "describe": { - "columns": [], - "parameters": { - "Left": [ - "Uuid", - "Timestamptz" - ] - }, - "nullable": [] - }, - "hash": "03dce7dbc7f7fd28c7bd3bb925f1bb260c7ca01d0a418cd08d44783c445ea5a7" -} diff --git a/.sqlx/query-5fa64eb35dad3a1a8639ec4f5b933efd3bbce043b99c2b308622fb67a1357d36.json b/.sqlx/query-5fa64eb35dad3a1a8639ec4f5b933efd3bbce043b99c2b308622fb67a1357d36.json new file mode 100644 index 0000000..21d2f92 --- /dev/null +++ b/.sqlx/query-5fa64eb35dad3a1a8639ec4f5b933efd3bbce043b99c2b308622fb67a1357d36.json @@ -0,0 +1,16 @@ +{ + "db_name": "PostgreSQL", + "query": "\n UPDATE pg_task\n SET locked_by = $2,\n lock_expires_at = now() + $3::interval\n WHERE id = $1\n ", + "describe": { + "columns": [], + "parameters": { + "Left": [ + "Uuid", + "Uuid", + "Interval" + ] + }, + "nullable": [] + }, + "hash": 
"5fa64eb35dad3a1a8639ec4f5b933efd3bbce043b99c2b308622fb67a1357d36" +} diff --git a/.sqlx/query-6e749543b97bb037981ba70bc0407d7c64ad2c7d851bf25984dccb39665a3e40.json b/.sqlx/query-6e749543b97bb037981ba70bc0407d7c64ad2c7d851bf25984dccb39665a3e40.json new file mode 100644 index 0000000..d15ede7 --- /dev/null +++ b/.sqlx/query-6e749543b97bb037981ba70bc0407d7c64ad2c7d851bf25984dccb39665a3e40.json @@ -0,0 +1,15 @@ +{ + "db_name": "PostgreSQL", + "query": "\n UPDATE pg_task\n SET lock_expires_at = now() + $2::interval\n WHERE locked_by = $1\n AND lock_expires_at > now()\n AND error IS NULL\n ", + "describe": { + "columns": [], + "parameters": { + "Left": [ + "Uuid", + "Interval" + ] + }, + "nullable": [] + }, + "hash": "6e749543b97bb037981ba70bc0407d7c64ad2c7d851bf25984dccb39665a3e40" +} diff --git a/src/task.rs b/src/task.rs index 789acba..eb6a629 100644 --- a/src/task.rs +++ b/src/task.rs @@ -5,7 +5,7 @@ use crate::{ use chrono::{DateTime, Utc}; use serde::Serialize; use sqlx::{ - postgres::{PgConnection, PgPool}, + postgres::{types::PgInterval, PgConnection, PgPool}, PgExecutor, }; use std::{fmt, time::Duration}; @@ -22,20 +22,22 @@ pub struct Task { #[derive(Clone, Copy, Debug)] pub(crate) struct TaskLease { worker_id: Uuid, - timeout: chrono::Duration, + timeout: PgInterval, } impl TaskLease { pub(crate) fn new(worker_id: Uuid, timeout: Duration) -> Self { + let microseconds = timeout.as_nanos().saturating_add(999) / 1_000; + let microseconds = microseconds.min(i64::MAX as u128) as i64; Self { worker_id, - timeout: std_duration_to_chrono(timeout), + timeout: PgInterval { + months: 0, + days: 0, + microseconds, + }, } } - - fn expires_at(self) -> DateTime { - Utc::now() + self.timeout - } } impl Task { @@ -107,12 +109,12 @@ impl Task { r#" UPDATE pg_task SET locked_by = $2, - lock_expires_at = $3 + lock_expires_at = now() + $3::interval WHERE id = $1 "#, self.id, lease.worker_id, - lease.expires_at(), + lease.timeout, ) .execute(con) .await @@ -126,13 +128,13 @@ impl 
Task { sqlx::query!( r#" UPDATE pg_task - SET lock_expires_at = $2 + SET lock_expires_at = now() + $2::interval WHERE locked_by = $1 AND lock_expires_at > now() AND error IS NULL "#, lease.worker_id, - lease.expires_at(), + lease.timeout, ) .execute(db) .await @@ -840,9 +842,9 @@ mod tests { async fn claim_marks_valid_steps_leased(pool: PgPool) { let id = insert_task(&pool, &TestTask::Valid(Valid), 0, false).await; + let started_at = Utc::now(); let mut tx = pool.begin().await.unwrap(); let task = Task::fetch_ready(&mut tx).await.unwrap().unwrap(); - let started_at = Utc::now(); let claimed = task.claim::(&mut tx, task_lease()).await.unwrap(); tx.commit().await.unwrap(); let finished_at = Utc::now(); diff --git a/src/worker.rs b/src/worker.rs index 43d8ec6..65529f2 100644 --- a/src/worker.rs +++ b/src/worker.rs @@ -28,6 +28,11 @@ enum HeartbeatEvent { Expired(Error), } +struct RunEvents { + heartbeat: mpsc::UnboundedReceiver, + step_errors: mpsc::UnboundedReceiver, +} + /// A worker for processing tasks pub struct Worker { db: PgPool, @@ -183,7 +188,10 @@ impl + 'static> Worker { heartbeat, running_steps, abort_running_steps, - heartbeat_events, + RunEvents { + heartbeat: heartbeat_events, + step_errors, + }, ) .await } @@ -413,7 +421,7 @@ impl + 'static> Worker { heartbeat: tokio::task::AbortHandle, running_steps: Arc>>, abort_running_steps: bool, - heartbeat_events: mpsc::UnboundedReceiver, + events: RunEvents, ) -> Result<()> { self.listener.shutdown(); if abort_running_steps { @@ -427,10 +435,15 @@ impl + 'static> Worker { result } else { match self - .wait_for_steps_to_finish_or_heartbeat(semaphore.clone(), heartbeat_events) + .wait_for_steps_to_finish_or_events( + semaphore.clone(), + events.heartbeat, + events.step_errors, + result, + ) .await { - Ok(()) => result, + Ok(result) => result, Err(error) => { Self::abort_running_steps(&running_steps); self.wait_for_steps_to_finish(semaphore).await; @@ -446,21 +459,27 @@ impl + 'static> Worker { } async fn 
wait_for_steps_to_finish(&self, semaphore: Arc) { - self.wait_for_steps_to_finish_impl(semaphore, None) + self.wait_for_steps_to_finish_impl(semaphore, None, None, Ok(())) .await - .expect("waiting without heartbeat events cannot fail"); + .expect("waiting without event receivers cannot fail") + .expect("waiting without step errors cannot fail"); } async fn wait_for_steps_to_finish_impl( &self, semaphore: Arc, mut heartbeat_events: Option>, - ) -> Result<()> { + mut step_errors: Option>, + mut result: Result<()>, + ) -> Result> { let mut logged_tasks_left = None; let mut heartbeat_healthy = true; loop { let tasks_left = self.concurrency.get() - semaphore.available_permits(); if tasks_left == 0 { + if let Some(step_errors) = step_errors.as_mut() { + Self::record_step_errors(step_errors, &mut result); + } break; } if let Some(logged) = logged_tasks_left { @@ -471,36 +490,82 @@ impl + 'static> Worker { info!("Waiting for the current steps of {tasks_left} tasks to finish..."); } logged_tasks_left = Some(tasks_left); - if let Some(heartbeat_events) = heartbeat_events.as_mut() { - tokio::select! { - Some(event) = heartbeat_events.recv() => { - Self::handle_heartbeat_event(event, &mut heartbeat_healthy)?; + match (heartbeat_events.as_mut(), step_errors.as_mut()) { + (Some(heartbeat_events), Some(step_errors)) => { + tokio::select! { + Some(event) = heartbeat_events.recv() => { + Self::handle_heartbeat_event(event, &mut heartbeat_healthy)?; + } + Some(error) = step_errors.recv() => { + Self::record_step_error(error, &mut result); + } + _ = sleep(Duration::from_secs_f32(0.1)) => {} } - _ = sleep(Duration::from_secs_f32(0.1)) => {} } - } else { - sleep(Duration::from_secs_f32(0.1)).await; + (Some(heartbeat_events), None) => { + tokio::select! { + Some(event) = heartbeat_events.recv() => { + Self::handle_heartbeat_event(event, &mut heartbeat_healthy)?; + } + _ = sleep(Duration::from_secs_f32(0.1)) => {} + } + } + (None, Some(step_errors)) => { + tokio::select! 
{ + Some(error) = step_errors.recv() => { + Self::record_step_error(error, &mut result); + } + _ = sleep(Duration::from_secs_f32(0.1)) => {} + } + } + (None, None) => { + sleep(Duration::from_secs_f32(0.1)).await; + } } } if logged_tasks_left.is_some() { trace!("The current step of every task is done") } - Ok(()) + Ok(result) + } + + fn record_step_errors( + step_errors: &mut mpsc::UnboundedReceiver, + result: &mut Result<()>, + ) { + while let Ok(error) = step_errors.try_recv() { + Self::record_step_error(error, result); + } + } + + fn record_step_error(error: Error, result: &mut Result<()>) { + if result.is_ok() { + *result = Err(error); + } } - async fn wait_for_steps_to_finish_or_heartbeat( + async fn wait_for_steps_to_finish_or_events( &self, semaphore: Arc, heartbeat_events: mpsc::UnboundedReceiver, - ) -> Result<()> { - self.wait_for_steps_to_finish_impl(semaphore, Some(heartbeat_events)) - .await + step_errors: mpsc::UnboundedReceiver, + result: Result<()>, + ) -> Result> { + self.wait_for_steps_to_finish_impl( + semaphore, + Some(heartbeat_events), + Some(step_errors), + result, + ) + .await } } #[cfg(test)] mod tests { - use super::{HeartbeatEvent, Worker, DEFAULT_HEARTBEAT_INTERVAL, DEFAULT_LEASE_TIMEOUT}; + use super::{ + HeartbeatEvent, RunEvents, Worker, DEFAULT_HEARTBEAT_INTERVAL, DEFAULT_LEASE_TIMEOUT, + }; use crate::{task::TaskLease, Error, NextStep, Step}; use chrono::{Duration as ChronoDuration, Utc}; use sqlx::{postgres::PgPoolOptions, PgPool}; @@ -867,6 +932,18 @@ mod tests { receiver } + fn idle_step_errors() -> mpsc::UnboundedReceiver { + let (_sender, receiver) = mpsc::unbounded_channel(); + receiver + } + + fn idle_run_events() -> RunEvents { + RunEvents { + heartbeat: idle_heartbeat_events(), + step_errors: idle_step_errors(), + } + } + async fn connect_to_current_db( pool: &PgPool, max_connections: u32, @@ -1865,7 +1942,7 @@ mod tests { idle_heartbeat(), Arc::new(Mutex::new(Vec::new())), false, - idle_heartbeat_events(), + 
idle_run_events(), ) .await } @@ -1883,6 +1960,57 @@ mod tests { )); } + #[tokio::test] + async fn finish_run_returns_step_errors_received_while_draining() { + init_tracing(); + let worker = Arc::new( + Worker::::new( + PgPoolOptions::new() + .connect_lazy("postgres:///pg_task") + .unwrap(), + ) + .with_concurrency(nonzero(1)), + ); + let semaphore = Arc::new(Semaphore::new(1)); + let permit = semaphore.clone().acquire_owned().await.unwrap(); + let (step_error_sender, step_errors) = mpsc::unbounded_channel(); + + let finish = tokio::spawn({ + let worker = worker.clone(); + let semaphore = semaphore.clone(); + async move { + worker + .finish_run( + Ok(()), + semaphore, + idle_heartbeat(), + Arc::new(Mutex::new(Vec::new())), + false, + RunEvents { + heartbeat: idle_heartbeat_events(), + step_errors, + }, + ) + .await + } + }); + + sleep(Duration::from_millis(50)).await; + assert!(!finish.is_finished()); + + step_error_sender + .send(Error::Db(sqlx::Error::PoolTimedOut, "step".into())) + .unwrap(); + drop(permit); + + let err = timeout(Duration::from_secs(1), finish) + .await + .unwrap() + .unwrap() + .unwrap_err(); + assert!(matches!(err, Error::Db(sqlx::Error::PoolTimedOut, _))); + } + #[tokio::test] async fn finish_run_keeps_heartbeat_alive_while_waiting_for_inflight_steps() { init_tracing(); @@ -1912,7 +2040,7 @@ mod tests { heartbeat_abort, Arc::new(Mutex::new(Vec::new())), false, - idle_heartbeat_events(), + idle_run_events(), ) .await } @@ -1957,7 +2085,7 @@ mod tests { idle_heartbeat(), running_steps, true, - idle_heartbeat_events(), + idle_run_events(), ), ) .await @@ -1997,7 +2125,10 @@ mod tests { idle_heartbeat(), running_steps, false, - heartbeat_events, + RunEvents { + heartbeat: heartbeat_events, + step_errors: idle_step_errors(), + }, ), ) .await @@ -2315,6 +2446,39 @@ mod tests { assert_eq!(task_count(&pool).await, 0); } + #[sqlx::test(migrations = "./migrations")] + async fn run_returns_step_errors_received_after_stop_while_draining(pool: PgPool) { + 
let state = StepStateGuard::new(); + insert_task( + &pool, + &TestTask::Blocking(Blocking { key: state.key() }), + false, + ) + .await; + + let worker = spawn_worker_with_concurrency(pool.clone(), 2); + + state.state().wait_for_events(1).await; + stop_worker(&pool).await; + sleep(Duration::from_millis(50)).await; + assert!(!worker.is_finished()); + + sqlx::query!("ALTER TABLE pg_task RENAME COLUMN id TO task_id") + .execute(&pool) + .await + .unwrap(); + state.state().release(); + + let err = timeout(Duration::from_secs(1), worker) + .await + .unwrap() + .unwrap() + .unwrap_err(); + + assert_eq!(state.state().events(), vec!["started", "completed"]); + assert!(matches!(err, Error::Db(sqlx::Error::Database(_), _))); + } + #[sqlx::test(migrations = "./migrations")] async fn run_returns_spawned_step_persistence_errors(pool: PgPool) { let state = StepStateGuard::new(); From cff2db26716d004258993ee3357ca3cf7d9c7753 Mon Sep 17 00:00:00 2001 From: imbolc Date: Sun, 10 May 2026 16:05:50 +0600 Subject: [PATCH 17/44] Make task claiming non-cancellable --- ...d4f47672582033af5a77057e5dcaab41d9305.json | 28 +++ src/worker.rs | 165 +++++++++++++----- 2 files changed, 149 insertions(+), 44 deletions(-) create mode 100644 .sqlx/query-c86aa8374fd6533f5bc5e8a7d06d4f47672582033af5a77057e5dcaab41d9305.json diff --git a/.sqlx/query-c86aa8374fd6533f5bc5e8a7d06d4f47672582033af5a77057e5dcaab41d9305.json b/.sqlx/query-c86aa8374fd6533f5bc5e8a7d06d4f47672582033af5a77057e5dcaab41d9305.json new file mode 100644 index 0000000..6a74906 --- /dev/null +++ b/.sqlx/query-c86aa8374fd6533f5bc5e8a7d06d4f47672582033af5a77057e5dcaab41d9305.json @@ -0,0 +1,28 @@ +{ + "db_name": "PostgreSQL", + "query": "SELECT locked_by, lock_expires_at FROM pg_task WHERE id = $1", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "locked_by", + "type_info": "Uuid" + }, + { + "ordinal": 1, + "name": "lock_expires_at", + "type_info": "Timestamptz" + } + ], + "parameters": { + "Left": [ + "Uuid" + ] + }, + 
"nullable": [ + true, + true + ] + }, + "hash": "c86aa8374fd6533f5bc5e8a7d06d4f47672582033af5a77057e5dcaab41d9305" +} diff --git a/src/worker.rs b/src/worker.rs index 65529f2..9a6d1e8 100644 --- a/src/worker.rs +++ b/src/worker.rs @@ -33,6 +33,11 @@ struct RunEvents { step_errors: mpsc::UnboundedReceiver, } +enum TaskAvailability { + Ready, + Stopped, +} + /// A worker for processing tasks pub struct Worker { db: PgPool, @@ -100,7 +105,7 @@ impl + 'static> Worker { let mut reserved_permit = None; let result = loop { - let received = tokio::select! { + let availability = tokio::select! { biased; Some(error) = step_errors.recv() => { @@ -143,25 +148,43 @@ impl + 'static> Worker { reserved_permit = Some(permit.map_err(Error::UnreachableWorkerSemaphoreClosed)?); continue; } - received = self.recv_task(lease), if heartbeat_healthy && reserved_permit.is_some() => received, + availability = self.wait_for_available_task(), if heartbeat_healthy && reserved_permit.is_some() => availability, }; - match received { - Ok(Some((task, step, lease))) => { - let permit = reserved_permit - .take() - .expect("task fetching requires a reserved semaphore permit"); - let db = self.db.clone(); - let step_error_sender = step_error_sender.clone(); - let step = tokio::spawn(async move { - if let Err(e) = task.run_step(&db, step, lease).await { - error!("[{}] {}", task.id, source_chain::to_string(&e)); - let _ = step_error_sender.send(e); - }; - drop(permit); - }); - Self::track_running_step(&running_steps, step.abort_handle()); - } - Ok(None) => { + match availability { + Ok(TaskAvailability::Ready) => match self.claim_available_task(lease).await { + Ok(Some((task, step, lease))) => { + let permit = reserved_permit + .take() + .expect("task claiming requires a reserved semaphore permit"); + let db = self.db.clone(); + let step_error_sender = step_error_sender.clone(); + let step = tokio::spawn(async move { + if let Err(e) = task.run_step(&db, step, lease).await { + error!("[{}] {}", task.id, 
source_chain::to_string(&e)); + let _ = step_error_sender.send(e); + }; + drop(permit); + }); + Self::track_running_step(&running_steps, step.abort_handle()); + } + Ok(None) => continue, + Err(e) => { + drop(reserved_permit.take()); + if let Err(error) = self + .handle_recv_task_error_or_heartbeat( + e, + &mut heartbeat_events, + &mut heartbeat_healthy, + &mut abort_running_steps, + ) + .await + { + drop(reserved_permit.take()); + break Err(error); + } + } + }, + Ok(TaskAvailability::Stopped) => { drop(reserved_permit.take()); break Ok(()); } @@ -203,14 +226,12 @@ impl + 'static> Worker { ); } - /// Waits until the next task is ready, marks it running and returns it. - /// Returns `None` if the worker is stopped - async fn recv_task(&self, lease: TaskLease) -> Result> { - trace!("Receiving the next task"); - + /// Waits until a task may be claimed without mutating task leases. + async fn wait_for_available_task(&self) -> Result { + trace!("Waiting for an available task"); loop { if self.listener.time_to_stop_worker() { - return Ok(None); + return Ok(TaskAvailability::Stopped); } if let Some(error) = self.listener.take_error() { @@ -220,26 +241,56 @@ impl + 'static> Worker { let table_changes = self.listener.subscribe(); let mut tx = self.db.begin().await.map_err(db_error!("begin"))?; - let Some(task) = Task::fetch_ready(&mut tx).await? 
else { - let next_available_at = Task::fetch_next_available_at(&mut tx).await?; - tx.commit().await.map_err(db_error!("no ready tasks"))?; + if Task::fetch_ready(&mut tx).await?.is_some() { + tx.commit().await.map_err(db_error!("ready task check"))?; + return Ok(TaskAvailability::Ready); + } - if let Some(available_at) = next_available_at { - let delay = - Task::delay_until(available_at).unwrap_or(LOCKED_TASK_RECHECK_DELAY); - table_changes.wait_for(delay).await; - } else { - table_changes.wait_forever().await; - } - continue; - }; + let next_available_at = Task::fetch_next_available_at(&mut tx).await?; + tx.commit().await.map_err(db_error!("no ready tasks"))?; - let Some(step) = task.claim(&mut tx, lease).await? else { - tx.commit().await.map_err(db_error!("save error"))?; - continue; - }; - tx.commit().await.map_err(db_error!("mark running"))?; - return Ok(Some((task, step, lease))); + if let Some(available_at) = next_available_at { + let delay = Task::delay_until(available_at).unwrap_or(LOCKED_TASK_RECHECK_DELAY); + table_changes.wait_for(delay).await; + } else { + table_changes.wait_forever().await; + } + } + } + + /// Claims a currently available task and marks it running. + async fn claim_available_task(&self, lease: TaskLease) -> Result> { + trace!("Claiming an available task"); + let mut tx = self.db.begin().await.map_err(db_error!("begin"))?; + + let Some(task) = Task::fetch_ready(&mut tx).await? else { + tx.commit().await.map_err(db_error!("no ready tasks"))?; + return Ok(None); + }; + + let Some(step) = task.claim(&mut tx, lease).await? else { + tx.commit().await.map_err(db_error!("save error"))?; + return Ok(None); + }; + tx.commit().await.map_err(db_error!("mark running"))?; + Ok(Some((task, step, lease))) + } + + /// Waits until the next task is ready, marks it running and returns it. 
+ /// Returns `None` if the worker is stopped + #[cfg(test)] + async fn recv_task(&self, lease: TaskLease) -> Result> { + trace!("Receiving the next task"); + + loop { + match self.wait_for_available_task().await? { + TaskAvailability::Ready => {} + TaskAvailability::Stopped => return Ok(None), + } + + if let Some(task) = self.claim_available_task(lease).await? { + return Ok(Some(task)); + } } } @@ -564,7 +615,8 @@ impl + 'static> Worker { #[cfg(test)] mod tests { use super::{ - HeartbeatEvent, RunEvents, Worker, DEFAULT_HEARTBEAT_INTERVAL, DEFAULT_LEASE_TIMEOUT, + HeartbeatEvent, RunEvents, TaskAvailability, Worker, DEFAULT_HEARTBEAT_INTERVAL, + DEFAULT_LEASE_TIMEOUT, }; use crate::{task::TaskLease, Error, NextStep, Step}; use chrono::{Duration as ChronoDuration, Utc}; @@ -1461,6 +1513,31 @@ mod tests { } } + #[sqlx::test(migrations = "./migrations")] + async fn wait_for_available_task_does_not_claim_ready_tasks(pool: PgPool) { + let id = insert_task_at( + &pool, + &TestTask::Noop(Noop), + Utc::now() - ChronoDuration::milliseconds(1), + false, + ) + .await; + let worker = Worker::::new(pool.clone()); + + let availability = worker.wait_for_available_task().await.unwrap(); + + assert!(matches!(availability, TaskAvailability::Ready)); + let lease = sqlx::query!( + "SELECT locked_by, lock_expires_at FROM pg_task WHERE id = $1", + id, + ) + .fetch_one(&pool) + .await + .unwrap(); + assert!(lease.locked_by.is_none()); + assert!(lease.lock_expires_at.is_none()); + } + #[sqlx::test(migrations = "./migrations")] async fn recv_task_stops_while_waiting_for_work(pool: PgPool) { let worker = Arc::new(Worker::::new(pool)); From 4f4e00539883e4ac729d573625e8acf7c1f882ed Mon Sep 17 00:00:00 2001 From: imbolc Date: Sun, 10 May 2026 16:22:43 +0600 Subject: [PATCH 18/44] Detect partial lease renewal failures --- ...912a20c6b9f92396c707dbd8411b3e30d625f.json | 24 +++ ...07d7c64ad2c7d851bf25984dccb39665a3e40.json | 15 -- Cargo.toml | 2 +- migrations/20260509130000_task-leases.sql | 4 - 
.../20260510000100_task-lease-indexes.sql | 2 - src/task.rs | 48 +++-- src/worker.rs | 204 ++++++++++++++---- 7 files changed, 225 insertions(+), 74 deletions(-) create mode 100644 .sqlx/query-0acd322b71ea8e7cfea28155ee3912a20c6b9f92396c707dbd8411b3e30d625f.json delete mode 100644 .sqlx/query-6e749543b97bb037981ba70bc0407d7c64ad2c7d851bf25984dccb39665a3e40.json diff --git a/.sqlx/query-0acd322b71ea8e7cfea28155ee3912a20c6b9f92396c707dbd8411b3e30d625f.json b/.sqlx/query-0acd322b71ea8e7cfea28155ee3912a20c6b9f92396c707dbd8411b3e30d625f.json new file mode 100644 index 0000000..02dfdb1 --- /dev/null +++ b/.sqlx/query-0acd322b71ea8e7cfea28155ee3912a20c6b9f92396c707dbd8411b3e30d625f.json @@ -0,0 +1,24 @@ +{ + "db_name": "PostgreSQL", + "query": "\n UPDATE pg_task\n SET lock_expires_at = now() + $2::interval\n WHERE locked_by = $1\n AND id = ANY($3)\n AND lock_expires_at > now()\n AND error IS NULL\n RETURNING id\n ", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "id", + "type_info": "Uuid" + } + ], + "parameters": { + "Left": [ + "Uuid", + "Interval", + "UuidArray" + ] + }, + "nullable": [ + false + ] + }, + "hash": "0acd322b71ea8e7cfea28155ee3912a20c6b9f92396c707dbd8411b3e30d625f" +} diff --git a/.sqlx/query-6e749543b97bb037981ba70bc0407d7c64ad2c7d851bf25984dccb39665a3e40.json b/.sqlx/query-6e749543b97bb037981ba70bc0407d7c64ad2c7d851bf25984dccb39665a3e40.json deleted file mode 100644 index d15ede7..0000000 --- a/.sqlx/query-6e749543b97bb037981ba70bc0407d7c64ad2c7d851bf25984dccb39665a3e40.json +++ /dev/null @@ -1,15 +0,0 @@ -{ - "db_name": "PostgreSQL", - "query": "\n UPDATE pg_task\n SET lock_expires_at = now() + $2::interval\n WHERE locked_by = $1\n AND lock_expires_at > now()\n AND error IS NULL\n ", - "describe": { - "columns": [], - "parameters": { - "Left": [ - "Uuid", - "Interval" - ] - }, - "nullable": [] - }, - "hash": "6e749543b97bb037981ba70bc0407d7c64ad2c7d851bf25984dccb39665a3e40" -} diff --git a/Cargo.toml b/Cargo.toml index d007d48..d05a4ec 
100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,7 +4,7 @@ edition = "2021" license = "MIT" name = "pg_task" repository = "https://github.com/imbolc/pg_task" -version = "0.2.2" +version = "0.3.0" [dependencies] async-trait = "0.1" diff --git a/migrations/20260509130000_task-leases.sql b/migrations/20260509130000_task-leases.sql index ef9483c..b1ac808 100644 --- a/migrations/20260509130000_task-leases.sql +++ b/migrations/20260509130000_task-leases.sql @@ -13,10 +13,6 @@ ADD CONSTRAINT pg_task_lease_state_check CHECK ( OR (locked_by IS NOT NULL AND lock_expires_at IS NOT NULL) ); -CREATE INDEX pg_task_lock_expires_at_idx ON pg_task (lock_expires_at) -WHERE locked_by IS NOT NULL - AND error IS NULL; - ALTER TABLE pg_task DROP COLUMN is_running; COMMENT ON COLUMN pg_task.locked_by IS 'Worker currently owning the running step lease'; diff --git a/migrations/20260510000100_task-lease-indexes.sql b/migrations/20260510000100_task-lease-indexes.sql index 253bcd2..ecb3ea0 100644 --- a/migrations/20260510000100_task-lease-indexes.sql +++ b/migrations/20260510000100_task-lease-indexes.sql @@ -1,5 +1,3 @@ -DROP INDEX pg_task_lock_expires_at_idx; - CREATE INDEX pg_task_locked_by_idx ON pg_task (locked_by) WHERE locked_by IS NOT NULL diff --git a/src/task.rs b/src/task.rs index eb6a629..1ed1a47 100644 --- a/src/task.rs +++ b/src/task.rs @@ -122,23 +122,33 @@ impl Task { Ok(()) } - /// Renews all live task leases owned by a worker. - pub(crate) async fn renew_leases(db: &PgPool, lease: TaskLease) -> Result { + /// Renews live task leases owned by a worker. 
+ pub(crate) async fn renew_leases( + db: &PgPool, + lease: TaskLease, + task_ids: &[Uuid], + ) -> Result> { + if task_ids.is_empty() { + return Ok(Vec::new()); + } trace!("Renewing task leases for worker {}", lease.worker_id); sqlx::query!( r#" UPDATE pg_task SET lock_expires_at = now() + $2::interval WHERE locked_by = $1 + AND id = ANY($3) AND lock_expires_at > now() AND error IS NULL + RETURNING id "#, lease.worker_id, lease.timeout, + task_ids, ) - .execute(db) + .fetch_all(db) .await - .map(|result| result.rows_affected()) + .map(|rows| rows.into_iter().map(|row| row.id).collect()) .map_err(db_error!("renew leases")) } @@ -308,8 +318,6 @@ impl Task { Ok(x) => x, Err(e) => return self.save_step_error(db, e.into(), true, lease).await, }; - debug!("[{}] moved to the next step {step}", self.id); - let result = sqlx::query!( " UPDATE pg_task @@ -332,6 +340,8 @@ impl Task { .map_err(db_error!())?; if result.rows_affected() == 0 { self.log_lost_lease(lease.worker_id, "save the next step"); + } else { + debug!("[{}] moved to the next step {step}", self.id); } Ok(()) } @@ -370,12 +380,6 @@ impl Task { lease: TaskLease, ) -> Result<()> { let delay = std_duration_to_chrono(delay); - debug!( - "[{id}] scheduled {attempt} of {retry_limit} retries in {delay:?} on error: {err}", - id = self.id, - attempt = ordinal(tried + 1), - err = source_chain::to_string(&*err), - ); let result = sqlx::query!( " @@ -397,6 +401,13 @@ impl Task { .map_err(db_error!())?; if result.rows_affected() == 0 { self.log_lost_lease(lease.worker_id, "schedule a retry"); + } else { + debug!( + "[{id}] scheduled {attempt} of {retry_limit} retries in {delay:?} on error: {err}", + id = self.id, + attempt = ordinal(tried + 1), + err = source_chain::to_string(&*err), + ); } Ok(()) @@ -942,10 +953,12 @@ mod tests { .await; let started_at = Utc::now(); - let renewed = Task::renew_leases(&pool, task_lease()).await.unwrap(); + let renewed = Task::renew_leases(&pool, task_lease(), &[owned, expired, 
other_worker]) + .await + .unwrap(); let finished_at = Utc::now(); - assert_eq!(renewed, 1); + assert_eq!(renewed, vec![owned]); let owned = fetch_task_row(&pool, owned).await.unwrap(); assert_timestamp_between( owned.lock_expires_at.unwrap(), @@ -995,7 +1008,12 @@ mod tests { ) .await; - assert_eq!(Task::renew_leases(&pool, task_lease()).await.unwrap(), 0); + assert!( + Task::renew_leases(&pool, task_lease(), &[expired, other_worker]) + .await + .unwrap() + .is_empty() + ); } #[sqlx::test(migrations = "./migrations")] diff --git a/src/worker.rs b/src/worker.rs index 9a6d1e8..45cb5aa 100644 --- a/src/worker.rs +++ b/src/worker.rs @@ -33,6 +33,11 @@ struct RunEvents { step_errors: mpsc::UnboundedReceiver, } +struct RunningStep { + task_id: Uuid, + abort_handle: tokio::task::AbortHandle, +} + enum TaskAvailability { Ready, Stopped, @@ -158,6 +163,7 @@ impl + 'static> Worker { .expect("task claiming requires a reserved semaphore permit"); let db = self.db.clone(); let step_error_sender = step_error_sender.clone(); + let task_id = task.id; let step = tokio::spawn(async move { if let Err(e) = task.run_step(&db, step, lease).await { error!("[{}] {}", task.id, source_chain::to_string(&e)); @@ -165,7 +171,7 @@ impl + 'static> Worker { }; drop(permit); }); - Self::track_running_step(&running_steps, step.abort_handle()); + Self::track_running_step(&running_steps, task_id, step.abort_handle()); } Ok(None) => continue, Err(e) => { @@ -362,7 +368,7 @@ impl + 'static> Worker { fn spawn_heartbeat( &self, events: mpsc::UnboundedSender, - running_steps: Arc>>, + running_steps: Arc>>, lease: TaskLease, ) -> tokio::task::AbortHandle { self.validate_lease_timing(); @@ -377,7 +383,8 @@ impl + 'static> Worker { heartbeat.tick().await; loop { heartbeat.tick().await; - if !Self::has_running_steps(&running_steps) { + let running_task_ids = Self::running_task_ids(&running_steps); + if running_task_ids.is_empty() { last_renewed_at = Instant::now(); if renewal_failed { let _ = 
events.send(HeartbeatEvent::Recovered); @@ -385,24 +392,27 @@ impl + 'static> Worker { } continue; } - match Task::renew_leases(&db, lease).await { - Ok(renewed) if renewed > 0 => { - trace!("Renewed {renewed} task leases"); - last_renewed_at = Instant::now(); - if renewal_failed { - let _ = events.send(HeartbeatEvent::Recovered); - renewal_failed = false; - } - } - Ok(_) if !Self::has_running_steps(&running_steps) => { + match Task::renew_leases(&db, lease, &running_task_ids).await { + Ok(renewed_task_ids) + if Self::renewed_all_running_leases( + &running_task_ids, + &renewed_task_ids, + &running_steps, + ) => + { + trace!("Renewed {} task leases", renewed_task_ids.len()); last_renewed_at = Instant::now(); if renewal_failed { let _ = events.send(HeartbeatEvent::Recovered); renewal_failed = false; } } - Ok(_) => { - warn!("Task lease renewal updated no rows while steps are still running"); + Ok(renewed_task_ids) => { + warn!( + "Task lease renewal updated {} of {} running task leases", + renewed_task_ids.len(), + running_task_ids.len() + ); if !renewal_failed { let _ = events.send(HeartbeatEvent::Failed); renewal_failed = true; @@ -423,7 +433,7 @@ impl + 'static> Worker { let _ = events.send(HeartbeatEvent::Failed); renewal_failed = true; } - if Self::has_running_steps(&running_steps) + if !Self::running_task_ids(&running_steps).is_empty() && last_renewed_at.elapsed().saturating_add(heartbeat_interval) >= lease_timeout { @@ -438,31 +448,51 @@ impl + 'static> Worker { } fn track_running_step( - running_steps: &Mutex>, - step: tokio::task::AbortHandle, + running_steps: &Mutex>, + task_id: Uuid, + abort_handle: tokio::task::AbortHandle, ) { let mut running_steps = running_steps .lock() .unwrap_or_else(std::sync::PoisonError::into_inner); - running_steps.retain(|step| !step.is_finished()); - running_steps.push(step); + running_steps.retain(|step| !step.abort_handle.is_finished()); + running_steps.push(RunningStep { + task_id, + abort_handle, + }); } - fn 
abort_running_steps(running_steps: &Mutex>) { + fn abort_running_steps(running_steps: &Mutex>) { let running_steps = running_steps .lock() .unwrap_or_else(std::sync::PoisonError::into_inner); for step in &*running_steps { - step.abort(); + step.abort_handle.abort(); } } - fn has_running_steps(running_steps: &Mutex>) -> bool { - running_steps + #[cfg(test)] + fn has_running_steps(running_steps: &Mutex>) -> bool { + !Self::running_task_ids(running_steps).is_empty() + } + + fn running_task_ids(running_steps: &Mutex>) -> Vec { + let mut running_steps = running_steps .lock() - .unwrap_or_else(std::sync::PoisonError::into_inner) - .iter() - .any(|step| !step.is_finished()) + .unwrap_or_else(std::sync::PoisonError::into_inner); + running_steps.retain(|step| !step.abort_handle.is_finished()); + running_steps.iter().map(|step| step.task_id).collect() + } + + fn renewed_all_running_leases( + running_task_ids: &[Uuid], + renewed_task_ids: &[Uuid], + running_steps: &Mutex>, + ) -> bool { + let still_running_task_ids = Self::running_task_ids(running_steps); + running_task_ids.iter().all(|task_id| { + renewed_task_ids.contains(task_id) || !still_running_task_ids.contains(task_id) + }) } async fn finish_run( @@ -470,7 +500,7 @@ impl + 'static> Worker { result: Result<()>, semaphore: Arc, heartbeat: tokio::task::AbortHandle, - running_steps: Arc>>, + running_steps: Arc>>, abort_running_steps: bool, events: RunEvents, ) -> Result<()> { @@ -615,8 +645,8 @@ impl + 'static> Worker { #[cfg(test)] mod tests { use super::{ - HeartbeatEvent, RunEvents, TaskAvailability, Worker, DEFAULT_HEARTBEAT_INTERVAL, - DEFAULT_LEASE_TIMEOUT, + HeartbeatEvent, RunEvents, RunningStep, TaskAvailability, Worker, + DEFAULT_HEARTBEAT_INTERVAL, DEFAULT_LEASE_TIMEOUT, }; use crate::{task::TaskLease, Error, NextStep, Step}; use chrono::{Duration as ChronoDuration, Utc}; @@ -941,6 +971,15 @@ mod tests { } async fn set_task_lease(pool: &PgPool, id: Uuid, lock_expires_at: chrono::DateTime) { + 
set_task_lease_for_worker(pool, id, Uuid::from_u128(1), lock_expires_at).await; + } + + async fn set_task_lease_for_worker( + pool: &PgPool, + id: Uuid, + worker_id: Uuid, + lock_expires_at: chrono::DateTime, + ) { sqlx::query!( r#" UPDATE pg_task @@ -949,7 +988,7 @@ mod tests { WHERE id = $1 "#, id, - Uuid::from_u128(1), + worker_id, lock_expires_at, ) .execute(pool) @@ -1037,6 +1076,13 @@ mod tests { TaskLease::new(Uuid::new_v4(), worker.lease_timeout) } + fn running_step_entry(task_id: Uuid, abort_handle: tokio::task::AbortHandle) -> RunningStep { + RunningStep { + task_id, + abort_handle, + } + } + fn spawn_worker(pool: PgPool) -> tokio::task::JoinHandle> { spawn_worker_with_concurrency(pool, 1) } @@ -1255,7 +1301,10 @@ mod tests { let running_step = tokio::spawn(async { std::future::pending::<()>().await; }); - let running_steps = Arc::new(Mutex::new(vec![running_step.abort_handle()])); + let running_steps = Arc::new(Mutex::new(vec![running_step_entry( + Uuid::new_v4(), + running_step.abort_handle(), + )])); let (events, mut events_receiver) = mpsc::unbounded_channel(); let heartbeat = worker.spawn_heartbeat(events, running_steps, worker_lease(&worker)); @@ -1278,6 +1327,75 @@ mod tests { running_step.abort(); } + #[sqlx::test(migrations = "./migrations")] + async fn heartbeat_expires_when_any_running_step_loses_its_lease(pool: PgPool) { + init_tracing(); + let worker = Worker::::new(pool.clone()) + .with_lease_timeout(Duration::from_millis(80)) + .with_heartbeat_interval(Duration::from_millis(20)); + let worker_id = Uuid::new_v4(); + let lease = TaskLease::new(worker_id, worker.lease_timeout); + let live = insert_task_at( + &pool, + &TestTask::Noop(Noop), + Utc::now() - ChronoDuration::milliseconds(1), + false, + ) + .await; + let expired = insert_task_at( + &pool, + &TestTask::Noop(Noop), + Utc::now() - ChronoDuration::milliseconds(1), + false, + ) + .await; + set_task_lease_for_worker( + &pool, + live, + worker_id, + Utc::now() + 
ChronoDuration::milliseconds(200), + ) + .await; + set_task_lease_for_worker( + &pool, + expired, + worker_id, + Utc::now() - ChronoDuration::milliseconds(1), + ) + .await; + let live_step = tokio::spawn(async { + std::future::pending::<()>().await; + }); + let expired_step = tokio::spawn(async { + std::future::pending::<()>().await; + }); + let running_steps = Arc::new(Mutex::new(vec![ + running_step_entry(live, live_step.abort_handle()), + running_step_entry(expired, expired_step.abort_handle()), + ])); + let (events, mut events_receiver) = mpsc::unbounded_channel(); + let heartbeat = worker.spawn_heartbeat(events, running_steps, lease); + + let event = timeout(Duration::from_secs(1), events_receiver.recv()) + .await + .unwrap() + .unwrap(); + assert!(matches!(event, HeartbeatEvent::Failed)); + + let event = timeout(Duration::from_secs(1), events_receiver.recv()) + .await + .unwrap() + .unwrap(); + assert!(matches!( + event, + HeartbeatEvent::Expired(Error::TaskLeaseExpired) + )); + + heartbeat.abort(); + live_step.abort(); + expired_step.abort(); + } + #[sqlx::test(migrations = "./migrations")] async fn heartbeat_skips_pool_timeouts_without_running_steps(pool: PgPool) { init_tracing(); @@ -1336,7 +1454,10 @@ mod tests { let running_step = tokio::spawn(async { std::future::pending::<()>().await; }); - let running_steps = Arc::new(Mutex::new(vec![running_step.abort_handle()])); + let running_steps = Arc::new(Mutex::new(vec![running_step_entry( + id, + running_step.abort_handle(), + )])); let (events, mut events_receiver) = mpsc::unbounded_channel(); let heartbeat = worker.spawn_heartbeat(events, running_steps, lease); @@ -1371,9 +1492,12 @@ mod tests { std::future::pending::<()>().await; }); let running_step_abort = running_step.abort_handle(); - let running_steps = Mutex::new(vec![finished_step_abort]); + let running_steps = Mutex::new(vec![running_step_entry( + Uuid::new_v4(), + finished_step_abort, + )]); - Worker::::track_running_step(&running_steps, 
running_step_abort); + Worker::::track_running_step(&running_steps, Uuid::new_v4(), running_step_abort); assert_eq!( running_steps @@ -2152,7 +2276,10 @@ mod tests { let _permit = permit; std::future::pending::<()>().await; }); - let running_steps = Arc::new(Mutex::new(vec![running_step.abort_handle()])); + let running_steps = Arc::new(Mutex::new(vec![running_step_entry( + Uuid::new_v4(), + running_step.abort_handle(), + )])); let err = timeout( Duration::from_secs(1), @@ -2188,7 +2315,10 @@ mod tests { let _permit = permit; std::future::pending::<()>().await; }); - let running_steps = Arc::new(Mutex::new(vec![running_step.abort_handle()])); + let running_steps = Arc::new(Mutex::new(vec![running_step_entry( + Uuid::new_v4(), + running_step.abort_handle(), + )])); let (heartbeat_events_sender, heartbeat_events) = mpsc::unbounded_channel(); heartbeat_events_sender .send(HeartbeatEvent::Expired(Error::TaskLeaseExpired)) From 1c29201474db425baf174a323df2ede4c611ecd9 Mon Sep 17 00:00:00 2001 From: imbolc Date: Mon, 11 May 2026 06:30:56 +0600 Subject: [PATCH 19/44] Exclude .sqlx from PR review --- .gitattributes | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.gitattributes b/.gitattributes index b3839fb..252b938 100644 --- a/.gitattributes +++ b/.gitattributes @@ -4,3 +4,6 @@ # Exclude from `git grep` /.sqlx/*.* binary + +# Mark it as generated so GitHub collapses it in PR diffs/reviews +.sqlx/** linguist-generated=true From f3b51719593dd4e56a82456a9e6b4593f8d963ae Mon Sep 17 00:00:00 2001 From: imbolc Date: Mon, 11 May 2026 06:37:51 +0600 Subject: [PATCH 20/44] Combine branch lease migrations --- migrations/20260509130000_task-leases.sql | 54 +++++++++++++++++++ .../20260510000000_notify-on-delete.sql | 37 ------------- .../20260510000100_task-lease-indexes.sql | 15 ------ 3 files changed, 54 insertions(+), 52 deletions(-) delete mode 100644 migrations/20260510000000_notify-on-delete.sql delete mode 100644 migrations/20260510000100_task-lease-indexes.sql diff 
--git a/migrations/20260509130000_task-leases.sql b/migrations/20260509130000_task-leases.sql index b1ac808..b137ab4 100644 --- a/migrations/20260509130000_task-leases.sql +++ b/migrations/20260509130000_task-leases.sql @@ -2,6 +2,8 @@ ALTER TABLE pg_task ADD COLUMN locked_by UUID, ADD COLUMN lock_expires_at timestamptz; +DROP TRIGGER pg_task_changed ON pg_task; + UPDATE pg_task SET locked_by = gen_random_uuid(), lock_expires_at = now() @@ -17,3 +19,55 @@ ALTER TABLE pg_task DROP COLUMN is_running; COMMENT ON COLUMN pg_task.locked_by IS 'Worker currently owning the running step lease'; COMMENT ON COLUMN pg_task.lock_expires_at IS 'Time when the running step lease expires and can be reclaimed'; + +CREATE OR REPLACE FUNCTION pg_task_notify_on_change() +RETURNS trigger AS $$ +BEGIN + PERFORM pg_notify('pg_task_changed', ''); + IF TG_OP = 'DELETE' THEN + RETURN OLD; + END IF; + RETURN NEW; +END; +$$ LANGUAGE plpgsql; + +CREATE TRIGGER pg_task_changed_insert +AFTER INSERT +ON pg_task +FOR EACH ROW +EXECUTE PROCEDURE pg_task_notify_on_change(); + +CREATE TRIGGER pg_task_changed_delete +AFTER DELETE +ON pg_task +FOR EACH ROW +EXECUTE PROCEDURE pg_task_notify_on_change(); + +CREATE TRIGGER pg_task_changed_update +AFTER UPDATE +ON pg_task +FOR EACH ROW +WHEN ( + OLD.step IS DISTINCT FROM NEW.step + OR OLD.wakeup_at IS DISTINCT FROM NEW.wakeup_at + OR OLD.tried IS DISTINCT FROM NEW.tried + OR OLD.error IS DISTINCT FROM NEW.error + OR OLD.locked_by IS DISTINCT FROM NEW.locked_by +) +EXECUTE PROCEDURE pg_task_notify_on_change(); + +CREATE INDEX pg_task_locked_by_idx +ON pg_task (locked_by) +WHERE locked_by IS NOT NULL + AND error IS NULL; + +CREATE INDEX pg_task_next_available_at_idx +ON pg_task (( + CASE + WHEN locked_by IS NOT NULL THEN + GREATEST(wakeup_at, lock_expires_at) + ELSE + wakeup_at + END +)) +WHERE error IS NULL; diff --git a/migrations/20260510000000_notify-on-delete.sql b/migrations/20260510000000_notify-on-delete.sql deleted file mode 100644 index 
f1e7e9e..0000000 --- a/migrations/20260510000000_notify-on-delete.sql +++ /dev/null @@ -1,37 +0,0 @@ -DROP TRIGGER pg_task_changed ON pg_task; - -CREATE OR REPLACE FUNCTION pg_task_notify_on_change() -RETURNS trigger AS $$ -BEGIN - PERFORM pg_notify('pg_task_changed', ''); - IF TG_OP = 'DELETE' THEN - RETURN OLD; - END IF; - RETURN NEW; -END; -$$ LANGUAGE plpgsql; - -CREATE TRIGGER pg_task_changed_insert -AFTER INSERT -ON pg_task -FOR EACH ROW -EXECUTE PROCEDURE pg_task_notify_on_change(); - -CREATE TRIGGER pg_task_changed_delete -AFTER DELETE -ON pg_task -FOR EACH ROW -EXECUTE PROCEDURE pg_task_notify_on_change(); - -CREATE TRIGGER pg_task_changed_update -AFTER UPDATE -ON pg_task -FOR EACH ROW -WHEN ( - OLD.step IS DISTINCT FROM NEW.step - OR OLD.wakeup_at IS DISTINCT FROM NEW.wakeup_at - OR OLD.tried IS DISTINCT FROM NEW.tried - OR OLD.error IS DISTINCT FROM NEW.error - OR OLD.locked_by IS DISTINCT FROM NEW.locked_by -) -EXECUTE PROCEDURE pg_task_notify_on_change(); diff --git a/migrations/20260510000100_task-lease-indexes.sql b/migrations/20260510000100_task-lease-indexes.sql deleted file mode 100644 index ecb3ea0..0000000 --- a/migrations/20260510000100_task-lease-indexes.sql +++ /dev/null @@ -1,15 +0,0 @@ -CREATE INDEX pg_task_locked_by_idx -ON pg_task (locked_by) -WHERE locked_by IS NOT NULL - AND error IS NULL; - -CREATE INDEX pg_task_next_available_at_idx -ON pg_task (( - CASE - WHEN locked_by IS NOT NULL THEN - GREATEST(wakeup_at, lock_expires_at) - ELSE - wakeup_at - END -)) -WHERE error IS NULL; From 13f77bc758a4727849edd4589f1a4e802e8e59ae Mon Sep 17 00:00:00 2001 From: imbolc Date: Mon, 11 May 2026 06:55:03 +0600 Subject: [PATCH 21/44] Cache SQLx test harness queries --- .pre-commit.sh | 12 ++++++---- ...44d100ea478371c9bf0fe6bae27744225e1b6.json | 22 +++++++++++++++++++ 2 files changed, 30 insertions(+), 4 deletions(-) create mode 100644 .sqlx/query-5b28fca1c7225c43f465cb24fd244d100ea478371c9bf0fe6bae27744225e1b6.json diff --git a/.pre-commit.sh 
b/.pre-commit.sh index fc2d6a5..8361a6f 100755 --- a/.pre-commit.sh +++ b/.pre-commit.sh @@ -29,7 +29,11 @@ typos . cargo shear cargo +nightly fmt -- --check cargo sort -c -cargo test --all-targets -cargo test --doc -cargo sqlx prepare -- --all-targets && git add .sqlx -cargo clippy --all-targets -- -D warnings + +cargo sqlx prepare -- --all-targets --all-features +# `cargo sqlx prepare` uses `cargo check`, which misses query macros compiled only by test harnesses. +SQLX_OFFLINE=false SQLX_OFFLINE_DIR=.sqlx cargo test --all-targets --all-features --no-run +git add .sqlx +SQLX_OFFLINE=true cargo test --all-targets +SQLX_OFFLINE=true cargo test --doc +SQLX_OFFLINE=true cargo clippy --all-targets --all-features -- -D warnings diff --git a/.sqlx/query-5b28fca1c7225c43f465cb24fd244d100ea478371c9bf0fe6bae27744225e1b6.json b/.sqlx/query-5b28fca1c7225c43f465cb24fd244d100ea478371c9bf0fe6bae27744225e1b6.json new file mode 100644 index 0000000..2b8c480 --- /dev/null +++ b/.sqlx/query-5b28fca1c7225c43f465cb24fd244d100ea478371c9bf0fe6bae27744225e1b6.json @@ -0,0 +1,22 @@ +{ + "db_name": "PostgreSQL", + "query": "SELECT id FROM pg_task WHERE id = $1", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "id", + "type_info": "Uuid" + } + ], + "parameters": { + "Left": [ + "Uuid" + ] + }, + "nullable": [ + false + ] + }, + "hash": "5b28fca1c7225c43f465cb24fd244d100ea478371c9bf0fe6bae27744225e1b6" +} From afd74c62050cbd97554e594bf78d10e796ebb708 Mon Sep 17 00:00:00 2001 From: imbolc Date: Mon, 11 May 2026 07:02:33 +0600 Subject: [PATCH 22/44] Provide Postgres for CI tests --- .github/workflows/ci.yml | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index e3b5a82..a54a1c0 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -56,6 +56,22 @@ jobs: test: runs-on: ubuntu-latest + services: + postgres: + image: postgres:16 + env: + POSTGRES_DB: pg_task + POSTGRES_PASSWORD: postgres + 
POSTGRES_USER: postgres + ports: + - 5432:5432 + options: >- + --health-cmd pg_isready + --health-interval 10s + --health-timeout 5s + --health-retries 5 + env: + DATABASE_URL: postgres://postgres:postgres@localhost:5432/pg_task steps: - name: Checkout uses: actions/checkout@v4 From 1b43b90c14f33e3b57d952e6443481d7e3ef0792 Mon Sep 17 00:00:00 2001 From: imbolc Date: Mon, 11 May 2026 07:28:17 +0600 Subject: [PATCH 23/44] Preserve CI database URL in tests --- examples/util.rs | 21 ++++++++++++++++++++- src/util.rs | 18 ++++++++++++++++-- src/worker.rs | 18 ++++++++++++++++-- 3 files changed, 52 insertions(+), 5 deletions(-) diff --git a/examples/util.rs b/examples/util.rs index 7549152..fb7868c 100644 --- a/examples/util.rs +++ b/examples/util.rs @@ -37,7 +37,7 @@ mod tests { .await .unwrap(); - std::env::set_var("DATABASE_URL", format!("postgres:///{db_name}")); + std::env::set_var("DATABASE_URL", current_database_url(&db_name)); std::env::remove_var("RUST_LOG"); let db = init().await.unwrap(); @@ -52,4 +52,23 @@ mod tests { .await .unwrap(); } + + // Point DATABASE_URL at the database created by sqlx::test while keeping + // the original host, user, password, and query parameters. CI connects over + // TCP with password auth, so postgres:///{db_name} would lose credentials. 
+ fn current_database_url(db_name: &str) -> String { + let database_url = std::env::var("DATABASE_URL").expect("DATABASE_URL must be set"); + let (url, query) = database_url + .split_once('?') + .map_or((database_url.as_str(), None), |(url, query)| { + (url, Some(query)) + }); + let (prefix, _) = url.rsplit_once('/').unwrap(); + let mut current_database_url = format!("{prefix}/{db_name}"); + if let Some(query) = query { + current_database_url.push('?'); + current_database_url.push_str(query); + } + current_database_url + } } diff --git a/src/util.rs b/src/util.rs index c71c214..7a8feb9 100644 --- a/src/util.rs +++ b/src/util.rs @@ -107,7 +107,10 @@ mod tests { ordinal, std_duration_to_chrono, wait_for_reconnection, }; use chrono::Duration as ChronoDuration; - use sqlx::{postgres::PgPoolOptions, PgPool}; + use sqlx::{ + postgres::{PgConnectOptions, PgPoolOptions}, + PgPool, + }; use std::{io, time::Duration}; #[test] @@ -215,7 +218,7 @@ mod tests { let retry_pool = PgPoolOptions::new() .max_connections(1) .acquire_timeout(Duration::from_millis(20)) - .connect(&format!("postgres:///{db_name}")) + .connect_with(current_database_options(&db_name)) .await .unwrap(); let held_connection = retry_pool.acquire().await.unwrap(); @@ -231,4 +234,15 @@ mod tests { waiter.await.unwrap().unwrap(); } + + // Connect to the database created by sqlx::test while keeping the + // connection settings from DATABASE_URL. CI needs its TCP host and password; + // postgres:///{db_name} only works for local peer-auth socket setups. 
+ fn current_database_options(db_name: &str) -> PgConnectOptions { + std::env::var("DATABASE_URL") + .expect("DATABASE_URL must be set") + .parse::<PgConnectOptions>() + .unwrap() + .database(db_name) + } } diff --git a/src/worker.rs b/src/worker.rs index 45cb5aa..2b4e547 100644 --- a/src/worker.rs +++ b/src/worker.rs @@ -650,7 +650,10 @@ mod tests { }; use crate::{task::TaskLease, Error, NextStep, Step}; use chrono::{Duration as ChronoDuration, Utc}; - use sqlx::{postgres::PgPoolOptions, PgPool}; + use sqlx::{ + postgres::{PgConnectOptions, PgPoolOptions}, + PgPool, + }; use std::{ collections::HashMap, io, @@ -1048,11 +1051,22 @@ mod tests { PgPoolOptions::new() .max_connections(max_connections) .acquire_timeout(acquire_timeout) - .connect(&format!("postgres:///{db_name}")) + .connect_with(current_database_options(&db_name)) .await .unwrap() } + // Connect to the database created by sqlx::test while keeping the + // connection settings from DATABASE_URL. CI needs its TCP host and password; + // postgres:///{db_name} only works for local peer-auth socket setups. 
+ fn current_database_options(db_name: &str) -> PgConnectOptions { + std::env::var("DATABASE_URL") + .expect("DATABASE_URL must be set") + .parse::<PgConnectOptions>() + .unwrap() + .database(db_name) + } + async fn task_count(pool: &PgPool) -> i64 { sqlx::query!("SELECT id FROM pg_task") .fetch_all(pool) From a3b9cab4a9c22e1d6084bff4a0381fb000863457 Mon Sep 17 00:00:00 2001 From: imbolc Date: Mon, 11 May 2026 07:37:15 +0600 Subject: [PATCH 24/44] Use live database for CI tests --- .github/workflows/ci.yml | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index a54a1c0..0119138 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -6,9 +6,6 @@ on: - ci pull_request: -env: - SQLX_OFFLINE: true - jobs: rustfmt: runs-on: ubuntu-latest @@ -27,6 +24,8 @@ jobs: clippy: runs-on: ubuntu-latest + env: + SQLX_OFFLINE: true steps: - name: Checkout uses: actions/checkout@v4 @@ -42,6 +41,8 @@ jobs: rustdoc: runs-on: ubuntu-latest + env: + SQLX_OFFLINE: true steps: - name: Checkout uses: actions/checkout@v4 From a261c25b59b6226bfd854302890c7c0820fb6321 Mon Sep 17 00:00:00 2001 From: imbolc Date: Mon, 11 May 2026 07:55:18 +0600 Subject: [PATCH 25/44] Run migrations before CI tests --- .github/workflows/ci.yml | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 0119138..e7600d8 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -82,6 +82,12 @@ jobs: with: toolchain: stable + - name: Install sqlx-cli + run: cargo install --locked sqlx-cli --version 0.8.6 --no-default-features --features postgres,rustls + + - name: Run migrations + run: cargo sqlx migrate run + - name: Test all targets run: cargo test --all-targets From 3c8993ea7bc9a86b528924603947550855a82a9a Mon Sep 17 00:00:00 2001 From: imbolc Date: Mon, 11 May 2026 08:15:00 +0600 Subject: [PATCH 26/44] Relax pool timeout tests for CI --- src/util.rs | 6 +++++- 
src/worker.rs | 18 +++++++++++------- 2 files changed, 16 insertions(+), 8 deletions(-) diff --git a/src/util.rs b/src/util.rs index 7a8feb9..8a37013 100644 --- a/src/util.rs +++ b/src/util.rs @@ -113,6 +113,10 @@ mod tests { }; use std::{io, time::Duration}; + // Short enough to exercise PoolTimedOut, but long enough for CI to open + // the first TCP connection before the pool is intentionally exhausted. + const POOL_TIMEOUT: Duration = Duration::from_millis(100); + #[test] fn chrono_duration_to_std_uses_the_absolute_value() { let duration = ChronoDuration::seconds(-1) - ChronoDuration::milliseconds(250); @@ -217,7 +221,7 @@ mod tests { .unwrap(); let retry_pool = PgPoolOptions::new() .max_connections(1) - .acquire_timeout(Duration::from_millis(20)) + .acquire_timeout(POOL_TIMEOUT) .connect_with(current_database_options(&db_name)) .await .unwrap(); diff --git a/src/worker.rs b/src/worker.rs index 2b4e547..9b6127f 100644 --- a/src/worker.rs +++ b/src/worker.rs @@ -670,6 +670,10 @@ mod tests { }; use uuid::Uuid; + // Short enough to exercise PoolTimedOut, but long enough for CI to open + // the first TCP connection before the pool is intentionally exhausted. 
+ const POOL_TIMEOUT: Duration = Duration::from_millis(100); + fn init_tracing() { static INIT: std::sync::Once = std::sync::Once::new(); INIT.call_once(|| { @@ -1413,7 +1417,7 @@ mod tests { #[sqlx::test(migrations = "./migrations")] async fn heartbeat_skips_pool_timeouts_without_running_steps(pool: PgPool) { init_tracing(); - let worker_pool = connect_to_current_db(&pool, 1, Duration::from_millis(20)).await; + let worker_pool = connect_to_current_db(&pool, 1, POOL_TIMEOUT).await; let held_connection = worker_pool.acquire().await.unwrap(); let worker = Worker::::new(worker_pool) .with_lease_timeout(Duration::from_millis(500)) @@ -1436,7 +1440,7 @@ mod tests { #[sqlx::test(migrations = "./migrations")] async fn heartbeat_reports_recovery_after_live_leases_are_renewed(pool: PgPool) { init_tracing(); - let worker_pool = connect_to_current_db(&pool, 1, Duration::from_millis(20)).await; + let worker_pool = connect_to_current_db(&pool, 1, POOL_TIMEOUT).await; let held_connection = worker_pool.acquire().await.unwrap(); let worker = Worker::::new(worker_pool) .with_lease_timeout(Duration::from_millis(500)) @@ -1984,7 +1988,7 @@ mod tests { #[sqlx::test(migrations = "./migrations")] async fn run_pauses_fetching_while_heartbeat_cannot_renew(pool: PgPool) { - let worker_pool = connect_to_current_db(&pool, 1, Duration::from_millis(20)).await; + let worker_pool = connect_to_current_db(&pool, 1, POOL_TIMEOUT).await; let state = StepStateGuard::new(); let worker = tokio::spawn(async move { @@ -2018,7 +2022,7 @@ mod tests { #[sqlx::test(migrations = "./migrations")] async fn run_returns_listener_errors_while_fetching_is_paused(pool: PgPool) { - let worker_pool = connect_to_current_db(&pool, 1, Duration::from_millis(20)).await; + let worker_pool = connect_to_current_db(&pool, 1, POOL_TIMEOUT).await; let worker = Arc::new( Worker::::new(worker_pool) .with_concurrency(nonzero(1)) @@ -2050,7 +2054,7 @@ mod tests { #[sqlx::test(migrations = "./migrations")] async fn 
run_keeps_waiting_after_retryable_errors_while_fetching_is_paused(pool: PgPool) { - let worker_pool = connect_to_current_db(&pool, 1, Duration::from_millis(20)).await; + let worker_pool = connect_to_current_db(&pool, 1, POOL_TIMEOUT).await; let worker = Arc::new( Worker::::new(worker_pool) .with_concurrency(nonzero(1)) @@ -2080,7 +2084,7 @@ mod tests { #[sqlx::test(migrations = "./migrations")] async fn run_resumes_fetching_after_heartbeat_recovers(pool: PgPool) { - let worker_pool = connect_to_current_db(&pool, 2, Duration::from_millis(20)).await; + let worker_pool = connect_to_current_db(&pool, 2, POOL_TIMEOUT).await; let held_connection = worker_pool.acquire().await.unwrap(); let state = StepStateGuard::new(); @@ -2443,7 +2447,7 @@ mod tests { #[sqlx::test(migrations = "./migrations")] async fn run_recovers_from_pool_timeouts_until_a_stop_notification_arrives(pool: PgPool) { - let worker_pool = connect_to_current_db(&pool, 1, Duration::from_millis(20)).await; + let worker_pool = connect_to_current_db(&pool, 1, POOL_TIMEOUT).await; let worker = spawn_worker(worker_pool); sleep(Duration::from_millis(100)).await; From b72dc40e74147bb84f814a2f66faca5570da4dc8 Mon Sep 17 00:00:00 2001 From: imbolc Date: Wed, 13 May 2026 16:19:58 +0600 Subject: [PATCH 27/44] Up migration --- migrations/20260509130000_task-leases.sql | 57 +++++++++++++---------- 1 file changed, 33 insertions(+), 24 deletions(-) diff --git a/migrations/20260509130000_task-leases.sql b/migrations/20260509130000_task-leases.sql index b137ab4..6352e15 100644 --- a/migrations/20260509130000_task-leases.sql +++ b/migrations/20260509130000_task-leases.sql @@ -1,13 +1,10 @@ +-- Add lease related columns ALTER TABLE pg_task ADD COLUMN locked_by UUID, ADD COLUMN lock_expires_at timestamptz; -DROP TRIGGER pg_task_changed ON pg_task; - -UPDATE pg_task -SET locked_by = gen_random_uuid(), - lock_expires_at = now() -WHERE is_running = true; +COMMENT ON COLUMN pg_task.locked_by IS 'Worker currently owning the running 
step lease'; +COMMENT ON COLUMN pg_task.lock_expires_at IS 'Time when the running step lease expires and can be reclaimed'; ALTER TABLE pg_task ADD CONSTRAINT pg_task_lease_state_check CHECK ( @@ -15,10 +12,37 @@ ADD CONSTRAINT pg_task_lease_state_check CHECK ( OR (locked_by IS NOT NULL AND lock_expires_at IS NOT NULL) ); +CREATE INDEX pg_task_locked_by_idx +ON pg_task (locked_by) +WHERE locked_by IS NOT NULL + AND error IS NULL; + +CREATE INDEX pg_task_next_available_at_idx +ON pg_task (( + CASE + WHEN locked_by IS NOT NULL THEN + GREATEST(wakeup_at, lock_expires_at) + ELSE + wakeup_at + END +)) +WHERE error IS NULL; + +-- Remove `is_running` column +UPDATE pg_task +SET locked_by = gen_random_uuid(), + lock_expires_at = now() +WHERE is_running = true; + ALTER TABLE pg_task DROP COLUMN is_running; -COMMENT ON COLUMN pg_task.locked_by IS 'Worker currently owning the running step lease'; -COMMENT ON COLUMN pg_task.lock_expires_at IS 'Time when the running step lease expires and can be reclaimed'; +-- Update trigger +-- +-- The old trigger in migrations/20230714025134_trigger.sql:9 fired on every INSERT or UPDATE. +-- With leases, that became too noisy because heartbeat renewal updates only lock_expires_at. +-- If the old trigger stayed, every lease renewal would NOTIFY pg_task_changed, +-- waking waiting workers even though no new task became claimable. +DROP TRIGGER pg_task_changed ON pg_task; CREATE OR REPLACE FUNCTION pg_task_notify_on_change() RETURNS trigger AS $$ @@ -43,6 +67,7 @@ ON pg_task FOR EACH ROW EXECUTE PROCEDURE pg_task_notify_on_change(); +-- The new trigger only notifies on updates to fields that affect task state or claimability. 
CREATE TRIGGER pg_task_changed_update AFTER UPDATE ON pg_task @@ -55,19 +80,3 @@ WHEN ( OR OLD.locked_by IS DISTINCT FROM NEW.locked_by ) EXECUTE PROCEDURE pg_task_notify_on_change(); - -CREATE INDEX pg_task_locked_by_idx -ON pg_task (locked_by) -WHERE locked_by IS NOT NULL - AND error IS NULL; - -CREATE INDEX pg_task_next_available_at_idx -ON pg_task (( - CASE - WHEN locked_by IS NOT NULL THEN - GREATEST(wakeup_at, lock_expires_at) - ELSE - wakeup_at - END -)) -WHERE error IS NULL; From 6640fe7caffb0f099ac78a73be763a394b3ece4e Mon Sep 17 00:00:00 2001 From: imbolc Date: Wed, 13 May 2026 16:44:14 +0600 Subject: [PATCH 28/44] Narrow task change notifications --- ...ffb906b7b52ff09a55a3064ed7943558234b103.json | 14 -------------- migrations/20260509130000_task-leases.sql | 17 ----------------- src/listener.rs | 11 +---------- 3 files changed, 1 insertion(+), 41 deletions(-) delete mode 100644 .sqlx/query-ebc5a43458570f6f64356d4fdffb906b7b52ff09a55a3064ed7943558234b103.json diff --git a/.sqlx/query-ebc5a43458570f6f64356d4fdffb906b7b52ff09a55a3064ed7943558234b103.json b/.sqlx/query-ebc5a43458570f6f64356d4fdffb906b7b52ff09a55a3064ed7943558234b103.json deleted file mode 100644 index 8c0474e..0000000 --- a/.sqlx/query-ebc5a43458570f6f64356d4fdffb906b7b52ff09a55a3064ed7943558234b103.json +++ /dev/null @@ -1,14 +0,0 @@ -{ - "db_name": "PostgreSQL", - "query": "DELETE FROM pg_task WHERE id = $1", - "describe": { - "columns": [], - "parameters": { - "Left": [ - "Uuid" - ] - }, - "nullable": [] - }, - "hash": "ebc5a43458570f6f64356d4fdffb906b7b52ff09a55a3064ed7943558234b103" -} diff --git a/migrations/20260509130000_task-leases.sql b/migrations/20260509130000_task-leases.sql index 6352e15..f854cb6 100644 --- a/migrations/20260509130000_task-leases.sql +++ b/migrations/20260509130000_task-leases.sql @@ -44,29 +44,12 @@ ALTER TABLE pg_task DROP COLUMN is_running; -- waking waiting workers even though no new task became claimable. 
DROP TRIGGER pg_task_changed ON pg_task; -CREATE OR REPLACE FUNCTION pg_task_notify_on_change() -RETURNS trigger AS $$ -BEGIN - PERFORM pg_notify('pg_task_changed', ''); - IF TG_OP = 'DELETE' THEN - RETURN OLD; - END IF; - RETURN NEW; -END; -$$ LANGUAGE plpgsql; - CREATE TRIGGER pg_task_changed_insert AFTER INSERT ON pg_task FOR EACH ROW EXECUTE PROCEDURE pg_task_notify_on_change(); -CREATE TRIGGER pg_task_changed_delete -AFTER DELETE -ON pg_task -FOR EACH ROW -EXECUTE PROCEDURE pg_task_notify_on_change(); - -- The new trigger only notifies on updates to fields that affect task state or claimability. CREATE TRIGGER pg_task_changed_update AFTER UPDATE diff --git a/src/listener.rs b/src/listener.rs index e229023..c8dcd1c 100644 --- a/src/listener.rs +++ b/src/listener.rs @@ -564,7 +564,7 @@ mod tests { } #[sqlx::test(migrations = "./migrations")] - async fn listen_wakes_subscribers_for_task_updates_and_deletes(pool: PgPool) { + async fn listen_wakes_subscribers_for_task_inserts_and_updates(pool: PgPool) { let listener = Listener::new(); listener.listen(pool.clone()).await.unwrap(); let insert_subscription = listener.subscribe(); @@ -590,15 +590,6 @@ mod tests { .await .unwrap(); - let delete_subscription = listener.subscribe(); - sqlx::query!("DELETE FROM pg_task WHERE id = $1", id) - .execute(&pool) - .await - .unwrap(); - timeout(Duration::from_secs(1), delete_subscription.wait_forever()) - .await - .unwrap(); - assert!(!listener.time_to_stop_worker()); assert!(listener.take_error().is_none()); } From e5bf27df9521da6a0cd867adb0233ba6ab144850 Mon Sep 17 00:00:00 2001 From: imbolc Date: Wed, 13 May 2026 17:00:57 +0600 Subject: [PATCH 29/44] Use thiserror for error display --- Cargo.toml | 1 - src/error.rs | 55 ++++++++++++++++++++-------------------------------- 2 files changed, 21 insertions(+), 35 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index d05a4ec..4259aac 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,7 +10,6 @@ version = "0.3.0" async-trait = "0.1" 
chrono = { version = "0.4", features = ["std", "serde"] } code-path = "0.4" -displaydoc = "0.2" num_cpus = "1" serde = { version = "1", features = ["derive"] } serde_json = "1" diff --git a/src/error.rs b/src/error.rs index dc5fc8d..5be04bd 100644 --- a/src/error.rs +++ b/src/error.rs @@ -2,28 +2,36 @@ use crate::NextStep; use std::{error::Error as StdError, result::Result as StdResult}; /// The crate error -#[derive(Debug, displaydoc::Display, thiserror::Error)] +#[derive(Debug, thiserror::Error)] pub enum Error { - /// can't add task + /// Can't add a task. + #[error("can't add task")] AddTask(#[source] sqlx::Error), - /// can't serialize step: {1} + /// Can't serialize the task step. + #[error("can't serialize step: {1}")] SerializeStep(#[source] serde_json::Error, String), - /** - can't deserialize step (the task was likely changed between the - scheduling and running of the step): {1} - */ + /// Can't deserialize the task step. + #[error( + "can't deserialize step (the task was likely changed between the scheduling and running of the step): {1}" + )] DeserializeStep(#[source] serde_json::Error, String), - /// listener can't connect to the db + /// Listener can't connect to the database. + #[error("listener can't connect to the db")] ListenerConnect(#[source] sqlx::Error), - /// can't start listening for table changes + /// Can't start listening for table changes. + #[error("can't start listening for table changes")] ListenerListen(#[source] sqlx::Error), - /// listener can't receive table change notifications + /// Listener can't receive table change notifications. + #[error("listener can't receive table change notifications")] ListenerReceive(#[source] sqlx::Error), - /// unreachable: worker semaphore is closed + /// Worker semaphore is closed. 
+ #[error("unreachable: worker semaphore is closed")] UnreachableWorkerSemaphoreClosed(#[source] tokio::sync::AcquireError), - /// task lease expired before the worker could renew it + /// Task lease expired before the worker could renew it. + #[error("task lease expired before the worker could renew it")] TaskLeaseExpired, - /// db error: {1} + /// Database operation failed. + #[error("db error: {1}")] Db(#[source] sqlx::Error, String), } @@ -35,24 +43,3 @@ pub type StepError = Box; /// Result returning from task steps pub type StepResult = StdResult, StepError>; - -#[cfg(test)] -mod tests { - use super::Error; - - #[test] - fn error_display_messages_are_stable() { - assert_eq!( - Error::Db(sqlx::Error::PoolTimedOut, "fetch task".into()).to_string(), - "db error: fetch task", - ); - assert_eq!( - Error::ListenerReceive(sqlx::Error::PoolClosed).to_string(), - "listener can't receive table change notifications", - ); - assert_eq!( - Error::TaskLeaseExpired.to_string(), - "task lease expired before the worker could renew it", - ); - } -} From 1b0bb0bf8da982e15a110413777beb3215df9825 Mon Sep 17 00:00:00 2001 From: imbolc Date: Fri, 15 May 2026 08:24:00 +0600 Subject: [PATCH 30/44] Simplify task docs wording --- README.md | 4 ++++ src/lib.rs | 38 ++++++++++++++++---------------------- 2 files changed, 20 insertions(+), 22 deletions(-) diff --git a/README.md b/README.md index 42bb389..624757c 100644 --- a/README.md +++ b/README.md @@ -6,6 +6,10 @@ FSM-based resumable Postgres tasks. +pg_task stores task state in Postgres and runs each task as a resumable state +machine, with scheduling, retries, delays, errors, and worker leases handled by +a single table. 
+ The full crate documentation, tutorial, and API examples live on [docs.rs/pg_task][docs] diff --git a/src/lib.rs b/src/lib.rs index 9bd8a6d..ba1bc45 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -6,7 +6,7 @@ FSM-based Resumable Postgres tasks - **FSM-based** - each task is a granular state machine - **Resumable** - on error, after you fix the step logic or the external world, the task is able to pick up where it stopped -- **Postgres** - a single table is enough to handle task scheduling, state +- **Postgres** - a single table handles task scheduling, state transitions, and error processing ## Table of Contents @@ -101,8 +101,7 @@ impl Step for SayHello { The second step prints the greeting and finishes the task returning `NextStep::none()`. -That's essentially all, except for some boilerplate you can find in the [full -code][tutorial-example]. Let's run it: +The [full code][tutorial-example] includes the remaining setup. Run it with: ```bash cargo run --example tutorial @@ -110,8 +109,8 @@ cargo run --example tutorial ### Investigating Errors -You'll see log messages about the 6 (first try + `RETRY_LIMIT`) attempts and the -final error message. Let's look into the DB to find out what happened: +The example logs 6 attempts: the first try plus `RETRY_LIMIT` retries. Inspect +the row to see what happened: ```bash ~$ psql pg_task -c 'table pg_task' @@ -134,29 +133,26 @@ updated_at | 2024-06-30 09:32:27.703599+06 ### Fixing the World -In this case, the error is due to the external world state. Let's fix it by -creating the file: +The task failed because the file is missing. Create it: ```bash echo 'Fixed World' >name.txt ``` -To rerun the task, we just need to clear its `error`: +Clear `error` to rerun the task: ```bash psql pg_task -c 'update pg_task set error = null' ``` -You'll see the log messages about rerunning the task and the greeting message of -the final step. That's all 🎉. +The worker reruns the task and prints the greeting from the final step. 
### Scheduling Tasks -Essentially scheduling a task is done by inserting a corresponding row into the -`pg_task` table. You can do it by hand from `psql` or code in any language. +Scheduling a task means inserting a row into the `pg_task` table. You can do it +from `psql` or from code in any language. -There's also a few helpers to take care of the first step serialization and time -scheduling: +The crate also provides helpers for first-step serialization and scheduling: - [`enqueue`] - to run the task immediately - [`delay`] - to run it with a delay @@ -255,27 +251,25 @@ pg_task::Worker::::new(db).run().await?; # } ``` -All the communication is synchronized by the DB, so it doesn't matter how or how -many workers you run. It could be a separate process as well as in-process -[`tokio::spawn`]. +Workers coordinate through Postgres, so you can run one or many of them, either +in separate processes or with [`tokio::spawn`]. ### Stopping Workers -You can gracefully stop task runners by sending a notification using the DB: +Gracefully stop workers by sending a notification through the database: ```sql SELECT pg_notify('pg_task_changed', 'stop_worker'); ``` -The workers would wait until the current step of all the tasks is finished and -then exit. You can wait for this by checking for the existence of running tasks: +Workers finish their current steps before exiting. 
To wait for them, check for +live leases: ```sql SELECT EXISTS( SELECT 1 FROM pg_task - WHERE locked_by IS NOT NULL - AND lock_expires_at > now() + WHERE lock_expires_at > now() ); ``` From a06d0a116d78208a493149ed68933b2fde3e93cc Mon Sep 17 00:00:00 2001 From: imbolc Date: Fri, 15 May 2026 08:43:55 +0600 Subject: [PATCH 31/44] Use parking_lot mutexes --- Cargo.toml | 1 + src/listener.rs | 56 ++++++++++++--------------------------------- src/worker.rs | 61 ++++++++++++------------------------------------- 3 files changed, 30 insertions(+), 88 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 4259aac..d20ddc6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,6 +11,7 @@ async-trait = "0.1" chrono = { version = "0.4", features = ["std", "serde"] } code-path = "0.4" num_cpus = "1" +parking_lot = "0.12" serde = { version = "1", features = ["derive"] } serde_json = "1" source-chain = "0.1" diff --git a/src/listener.rs b/src/listener.rs index c8dcd1c..1f193dc 100644 --- a/src/listener.rs +++ b/src/listener.rs @@ -1,9 +1,10 @@ use crate::{util, LOST_CONNECTION_SLEEP}; +use parking_lot::Mutex; use sqlx::{postgres::PgListener, PgPool}; use std::{ sync::{ atomic::{AtomicBool, Ordering}, - Arc, Mutex, + Arc, }, time::Duration, }; @@ -103,10 +104,7 @@ impl Listener { } pub(crate) fn take_error(&self) -> Option { - self.error - .lock() - .unwrap_or_else(std::sync::PoisonError::into_inner) - .take() + self.error.lock().take() } pub(crate) fn shutdown(&self) { @@ -114,30 +112,22 @@ impl Listener { } fn set_error(error_slot: &Mutex>, error: crate::Error) { - *error_slot - .lock() - .unwrap_or_else(std::sync::PoisonError::into_inner) = Some(error); + *error_slot.lock() = Some(error); } fn replace_task( task_slot: &Mutex>, task: tokio::task::AbortHandle, ) { - if let Some(old_task) = task_slot - .lock() - .unwrap_or_else(std::sync::PoisonError::into_inner) - .replace(task) - { + let old_task = task_slot.lock().replace(task); + if let Some(old_task) = old_task { old_task.abort(); 
} } fn clear_task(task_slot: &Mutex>) { - if let Some(task) = task_slot - .lock() - .unwrap_or_else(std::sync::PoisonError::into_inner) - .take() - { + let task = task_slot.lock().take(); + if let Some(task) = task { task.abort(); } } @@ -255,8 +245,9 @@ mod tests { use super::Listener; use crate::Error; use chrono::{DateTime, Utc}; + use parking_lot::Mutex; use sqlx::{postgres::PgPoolOptions, types::Uuid, PgPool}; - use std::{future::pending, io, sync::Mutex, time::Duration}; + use std::{future::pending, io, time::Duration}; use tokio::{ sync::Notify, time::{sleep, timeout}, @@ -385,10 +376,7 @@ mod tests { !Listener::handle_recv_error(&error_slot, &notify, &db, sqlx::Error::PoolTimedOut) .await ); - assert!(error_slot - .lock() - .unwrap_or_else(std::sync::PoisonError::into_inner) - .is_none()); + assert!(error_slot.lock().is_none()); } #[tokio::test] @@ -415,10 +403,7 @@ mod tests { .unwrap(); assert!(matches!( - error_slot - .lock() - .unwrap_or_else(std::sync::PoisonError::into_inner) - .take(), + error_slot.lock().take(), Some(Error::ListenerReceive(sqlx::Error::Protocol(_))) )); } @@ -436,10 +421,7 @@ mod tests { assert!(timeout(Duration::from_millis(50), subscription) .await .is_err()); - assert!(error_slot - .lock() - .unwrap_or_else(std::sync::PoisonError::into_inner) - .is_none()); + assert!(error_slot.lock().is_none()); } #[sqlx::test(migrations = "./migrations")] @@ -460,10 +442,7 @@ mod tests { .await .unwrap(); assert!(matches!( - error_slot - .lock() - .unwrap_or_else(std::sync::PoisonError::into_inner) - .take(), + error_slot.lock().take(), Some(Error::Db(sqlx::Error::Database(_), _)) )); } @@ -663,12 +642,7 @@ mod tests { timeout(Duration::from_secs(1), async { loop { - if listener - .error - .lock() - .unwrap_or_else(std::sync::PoisonError::into_inner) - .is_some() - { + if listener.error.lock().is_some() { return; } sleep(Duration::from_millis(10)).await; diff --git a/src/worker.rs b/src/worker.rs index 9b6127f..4cbad70 100644 --- a/src/worker.rs 
+++ b/src/worker.rs @@ -4,11 +4,12 @@ use crate::{ util::{db_error, is_connection_error, is_pool_timeout, wait_for_reconnection}, Error, Result, Step, LOST_CONNECTION_SLEEP, }; +use parking_lot::Mutex; use sqlx::postgres::PgPool; use std::{ marker::PhantomData, num::NonZeroUsize, - sync::{Arc, Mutex}, + sync::Arc, time::{Duration, Instant}, }; use tokio::{ @@ -452,9 +453,7 @@ impl + 'static> Worker { task_id: Uuid, abort_handle: tokio::task::AbortHandle, ) { - let mut running_steps = running_steps - .lock() - .unwrap_or_else(std::sync::PoisonError::into_inner); + let mut running_steps = running_steps.lock(); running_steps.retain(|step| !step.abort_handle.is_finished()); running_steps.push(RunningStep { task_id, @@ -463,9 +462,7 @@ impl + 'static> Worker { } fn abort_running_steps(running_steps: &Mutex>) { - let running_steps = running_steps - .lock() - .unwrap_or_else(std::sync::PoisonError::into_inner); + let running_steps = running_steps.lock(); for step in &*running_steps { step.abort_handle.abort(); } @@ -477,9 +474,7 @@ impl + 'static> Worker { } fn running_task_ids(running_steps: &Mutex>) -> Vec { - let mut running_steps = running_steps - .lock() - .unwrap_or_else(std::sync::PoisonError::into_inner); + let mut running_steps = running_steps.lock(); running_steps.retain(|step| !step.abort_handle.is_finished()); running_steps.iter().map(|step| step.task_id).collect() } @@ -650,6 +645,7 @@ mod tests { }; use crate::{task::TaskLease, Error, NextStep, Step}; use chrono::{Duration as ChronoDuration, Utc}; + use parking_lot::Mutex; use sqlx::{ postgres::{PgConnectOptions, PgPoolOptions}, PgPool, @@ -660,7 +656,7 @@ mod tests { num::NonZeroUsize, sync::{ atomic::{AtomicU64, Ordering}, - Arc, Mutex, OnceLock, + Arc, OnceLock, }, time::Duration, }; @@ -816,10 +812,7 @@ mod tests { } fn record(&self, event: &'static str) { - self.events - .lock() - .unwrap_or_else(std::sync::PoisonError::into_inner) - .push(event); + self.events.lock().push(event); 
self.events_changed.notify_waiters(); } @@ -828,22 +821,13 @@ mod tests { } fn events(&self) -> Vec<&'static str> { - self.events - .lock() - .unwrap_or_else(std::sync::PoisonError::into_inner) - .clone() + self.events.lock().clone() } async fn wait_for_events(&self, count: usize) { timeout(Duration::from_secs(1), async { loop { - if self - .events - .lock() - .unwrap_or_else(std::sync::PoisonError::into_inner) - .len() - >= count - { + if self.events.lock().len() >= count { return; } self.events_changed.notified().await; @@ -867,10 +851,7 @@ mod tests { fn new() -> Self { let key = NEXT_STEP_STATE_KEY.fetch_add(1, Ordering::Relaxed); let state = Arc::new(StepState::new()); - step_states() - .lock() - .unwrap_or_else(std::sync::PoisonError::into_inner) - .insert(key, state.clone()); + step_states().lock().insert(key, state.clone()); Self { key, state } } @@ -885,10 +866,7 @@ mod tests { impl Drop for StepStateGuard { fn drop(&mut self) { - step_states() - .lock() - .unwrap_or_else(std::sync::PoisonError::into_inner) - .remove(&self.key); + step_states().lock().remove(&self.key); } } @@ -900,12 +878,7 @@ mod tests { } fn step_state(key: u64) -> Arc { - step_states() - .lock() - .unwrap_or_else(std::sync::PoisonError::into_inner) - .get(&key) - .cloned() - .unwrap() + step_states().lock().get(&key).cloned().unwrap() } fn connection_error() -> Error { @@ -1517,13 +1490,7 @@ mod tests { Worker::::track_running_step(&running_steps, Uuid::new_v4(), running_step_abort); - assert_eq!( - running_steps - .lock() - .unwrap_or_else(std::sync::PoisonError::into_inner) - .len(), - 1, - ); + assert_eq!(running_steps.lock().len(), 1); assert!(Worker::::has_running_steps(&running_steps)); Worker::::abort_running_steps(&running_steps); From d389dad8a1e529bbe507e7353685a994e71ce027 Mon Sep 17 00:00:00 2001 From: imbolc Date: Fri, 15 May 2026 09:00:05 +0600 Subject: [PATCH 32/44] Stabilize lease renewal test --- src/worker.rs | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 
deletions(-) diff --git a/src/worker.rs b/src/worker.rs index 4cbad70..acd15fc 100644 --- a/src/worker.rs +++ b/src/worker.rs @@ -1885,7 +1885,7 @@ mod tests { async move { Worker::::new(pool) .with_concurrency(nonzero(1)) - .with_lease_timeout(Duration::from_millis(200)) + .with_lease_timeout(Duration::from_secs(1)) .with_heartbeat_interval(Duration::from_millis(50)) .run() .await @@ -1895,11 +1895,18 @@ mod tests { state.state().wait_for_events(1).await; let (locked_by, initial_expires_at) = fetch_task_lease(&pool, id).await.unwrap(); - sleep(Duration::from_millis(350)).await; - - let (renewed_by, renewed_expires_at) = fetch_task_lease(&pool, id).await.unwrap(); + let (renewed_by, renewed_expires_at) = timeout(Duration::from_secs(1), async { + loop { + let (renewed_by, renewed_expires_at) = fetch_task_lease(&pool, id).await.unwrap(); + if renewed_expires_at > initial_expires_at { + return (renewed_by, renewed_expires_at); + } + sleep(Duration::from_millis(10)).await; + } + }) + .await + .unwrap(); assert_eq!(renewed_by, locked_by); - assert!(renewed_expires_at > initial_expires_at); assert!(renewed_expires_at > Utc::now()); stop_worker(&pool).await; From 1046aee94bf31af941e51af438320f817203650b Mon Sep 17 00:00:00 2001 From: imbolc Date: Fri, 15 May 2026 09:15:39 +0600 Subject: [PATCH 33/44] Stabilize leased task expiry test --- src/worker.rs | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/src/worker.rs b/src/worker.rs index acd15fc..9a6adc1 100644 --- a/src/worker.rs +++ b/src/worker.rs @@ -1776,21 +1776,20 @@ mod tests { true, ) .await; - set_task_lease(&pool, id, Utc::now() + ChronoDuration::milliseconds(100)).await; + let lock_expires_at = Utc::now() + ChronoDuration::seconds(1); + set_task_lease(&pool, id, lock_expires_at).await; let worker = Worker::::new(pool); let lease = worker_lease(&worker); let recv = tokio::spawn(async move { worker.recv_task(lease).await }); - sleep(Duration::from_millis(50)).await; - 
assert!(!recv.is_finished()); - - let (task, step, _lease) = timeout(Duration::from_secs(1), recv) + let (task, step, _lease) = timeout(Duration::from_secs(2), recv) .await .unwrap() .unwrap() .unwrap() .unwrap(); + assert!(Utc::now() >= lock_expires_at); assert_eq!(task.id, id); assert!(matches!(step, TestTask::Noop(Noop))); } From 0a48b09bc0359830b8ede7d4868d1e2f69e76d28 Mon Sep 17 00:00:00 2001 From: imbolc Date: Fri, 15 May 2026 09:22:46 +0600 Subject: [PATCH 34/44] Simplify task lease interval conversion --- src/task.rs | 35 ++++++++++++++++++++++++++++------- 1 file changed, 28 insertions(+), 7 deletions(-) diff --git a/src/task.rs b/src/task.rs index 1ed1a47..9052bdb 100644 --- a/src/task.rs +++ b/src/task.rs @@ -27,19 +27,22 @@ pub(crate) struct TaskLease { impl TaskLease { pub(crate) fn new(worker_id: Uuid, timeout: Duration) -> Self { - let microseconds = timeout.as_nanos().saturating_add(999) / 1_000; - let microseconds = microseconds.min(i64::MAX as u128) as i64; Self { worker_id, - timeout: PgInterval { - months: 0, - days: 0, - microseconds, - }, + timeout: duration_to_pg_interval(timeout), } } } +fn duration_to_pg_interval(duration: Duration) -> PgInterval { + let microseconds = duration.as_nanos().div_ceil(1_000); + PgInterval { + months: 0, + days: 0, + microseconds: microseconds.min(i64::MAX as u128) as i64, + } +} + impl Task { /// Returns a delay before a task should run at the given time. 
pub fn delay_until(wakeup_at: DateTime) -> Option { @@ -603,6 +606,24 @@ mod tests { TaskLease::new(worker_id(), Duration::from_secs(60)) } + #[test] + fn task_lease_converts_timeout_to_microsecond_interval() { + let lease = TaskLease::new(worker_id(), Duration::from_micros(42)); + assert_eq!(lease.timeout.microseconds, 42); + } + + #[test] + fn task_lease_rounds_timeout_up_to_microseconds() { + let lease = TaskLease::new(worker_id(), Duration::from_nanos(1)); + assert_eq!(lease.timeout.microseconds, 1); + } + + #[test] + fn task_lease_saturates_large_timeouts() { + let lease = TaskLease::new(worker_id(), Duration::MAX); + assert_eq!(lease.timeout.microseconds, i64::MAX); + } + async fn insert_task_row( pool: &PgPool, step: &str, From fa5fe6bfa7abda26991b88c3ef485d0d5708ee40 Mon Sep 17 00:00:00 2001 From: imbolc Date: Fri, 15 May 2026 11:32:40 +0600 Subject: [PATCH 35/44] Align ready task query with availability index --- ...e887598b499bb439da0e15a5b7aba6ee5ab2b.json | 32 ------------------- ...11eeacd3ea47bb9683ae1b3b60d7923cf41cf.json | 32 +++++++++++++++++++ src/task.rs | 18 +++++++++-- 3 files changed, 47 insertions(+), 35 deletions(-) delete mode 100644 .sqlx/query-0ab59bb88ea3816ec98e78ad071e887598b499bb439da0e15a5b7aba6ee5ab2b.json create mode 100644 .sqlx/query-4f8f60c2a3232ab099284a6c80d11eeacd3ea47bb9683ae1b3b60d7923cf41cf.json diff --git a/.sqlx/query-0ab59bb88ea3816ec98e78ad071e887598b499bb439da0e15a5b7aba6ee5ab2b.json b/.sqlx/query-0ab59bb88ea3816ec98e78ad071e887598b499bb439da0e15a5b7aba6ee5ab2b.json deleted file mode 100644 index 6ca2ff2..0000000 --- a/.sqlx/query-0ab59bb88ea3816ec98e78ad071e887598b499bb439da0e15a5b7aba6ee5ab2b.json +++ /dev/null @@ -1,32 +0,0 @@ -{ - "db_name": "PostgreSQL", - "query": "\n SELECT\n id,\n step,\n tried\n FROM pg_task\n WHERE error IS NULL\n AND wakeup_at <= now()\n AND (locked_by IS NULL OR lock_expires_at <= now())\n ORDER BY wakeup_at\n LIMIT 1\n FOR UPDATE SKIP LOCKED\n ", - "describe": { - "columns": [ - { - 
"ordinal": 0, - "name": "id", - "type_info": "Uuid" - }, - { - "ordinal": 1, - "name": "step", - "type_info": "Text" - }, - { - "ordinal": 2, - "name": "tried", - "type_info": "Int4" - } - ], - "parameters": { - "Left": [] - }, - "nullable": [ - false, - false, - false - ] - }, - "hash": "0ab59bb88ea3816ec98e78ad071e887598b499bb439da0e15a5b7aba6ee5ab2b" -} diff --git a/.sqlx/query-4f8f60c2a3232ab099284a6c80d11eeacd3ea47bb9683ae1b3b60d7923cf41cf.json b/.sqlx/query-4f8f60c2a3232ab099284a6c80d11eeacd3ea47bb9683ae1b3b60d7923cf41cf.json new file mode 100644 index 0000000..28434f3 --- /dev/null +++ b/.sqlx/query-4f8f60c2a3232ab099284a6c80d11eeacd3ea47bb9683ae1b3b60d7923cf41cf.json @@ -0,0 +1,32 @@ +{ + "db_name": "PostgreSQL", + "query": "\n SELECT\n id,\n step,\n tried\n FROM pg_task\n WHERE error IS NULL\n AND (\n CASE\n WHEN locked_by IS NOT NULL THEN\n GREATEST(wakeup_at, lock_expires_at)\n ELSE\n wakeup_at\n END\n ) <= now()\n ORDER BY\n CASE\n WHEN locked_by IS NOT NULL THEN\n GREATEST(wakeup_at, lock_expires_at)\n ELSE\n wakeup_at\n END\n LIMIT 1\n FOR UPDATE SKIP LOCKED\n ", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "id", + "type_info": "Uuid" + }, + { + "ordinal": 1, + "name": "step", + "type_info": "Text" + }, + { + "ordinal": 2, + "name": "tried", + "type_info": "Int4" + } + ], + "parameters": { + "Left": [] + }, + "nullable": [ + false, + false, + false + ] + }, + "hash": "4f8f60c2a3232ab099284a6c80d11eeacd3ea47bb9683ae1b3b60d7923cf41cf" +} diff --git a/src/task.rs b/src/task.rs index 9052bdb..738e525 100644 --- a/src/task.rs +++ b/src/task.rs @@ -66,9 +66,21 @@ impl Task { tried FROM pg_task WHERE error IS NULL - AND wakeup_at <= now() - AND (locked_by IS NULL OR lock_expires_at <= now()) - ORDER BY wakeup_at + AND ( + CASE + WHEN locked_by IS NOT NULL THEN + GREATEST(wakeup_at, lock_expires_at) + ELSE + wakeup_at + END + ) <= now() + ORDER BY + CASE + WHEN locked_by IS NOT NULL THEN + GREATEST(wakeup_at, lock_expires_at) + ELSE + 
wakeup_at + END LIMIT 1 FOR UPDATE SKIP LOCKED "#, From 4e71971785aba95dbd41e19bd4b00e6129efd5e0 Mon Sep 17 00:00:00 2001 From: imbolc Date: Fri, 15 May 2026 11:56:13 +0600 Subject: [PATCH 36/44] Clarify save_step_error locals --- src/task.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/task.rs b/src/task.rs index 738e525..7f1643f 100644 --- a/src/task.rs +++ b/src/task.rs @@ -273,7 +273,7 @@ impl Task { ) -> Result<()> { let err_str = source_chain::to_string(&*err); let tried_increment = if increment_tried { 1 } else { 0 }; - let step = sqlx::query!( + let updated_task = sqlx::query!( r#" UPDATE pg_task SET tried = tried + $3, @@ -295,7 +295,7 @@ impl Task { .await .map_err(db_error!())?; - let Some(row) = step else { + let Some(updated_task) = updated_task else { self.log_lost_lease(lease.worker_id, "save the step error"); return Ok(()); }; @@ -305,14 +305,14 @@ impl Task { error!( "[{id}] resulted in an error at step {step} on {attempt} attempt: {err_str}", id = self.id, - step = row.step, + step = updated_task.step, attempt = ordinal(attempt) ); } else { error!( "[{id}] couldn't deserialize step {step}: {err_str}", id = self.id, - step = row.step + step = updated_task.step ); } From 0c591c53a3434c6b8543e685dd587f0d56048030 Mon Sep 17 00:00:00 2001 From: imbolc Date: Fri, 15 May 2026 13:18:55 +0600 Subject: [PATCH 37/44] up --- README.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 624757c..5976900 100644 --- a/README.md +++ b/README.md @@ -6,9 +6,9 @@ FSM-based resumable Postgres tasks. -pg_task stores task state in Postgres and runs each task as a resumable state -machine, with scheduling, retries, delays, errors, and worker leases handled by -a single table. +Stores task state in Postgres and runs each task as a resumable state machine, +with scheduling, retries, delays, errors, and worker leases handled by a single +table. 
The full crate documentation, tutorial, and API examples live on [docs.rs/pg_task][docs] From badafc4ad9c9be44031046a42d3c2beaf6d2ed2b Mon Sep 17 00:00:00 2001 From: imbolc Date: Fri, 15 May 2026 13:30:09 +0600 Subject: [PATCH 38/44] Simplify worker lease naming and delay conversion --- src/task.rs | 96 +++++++++++++++++++++++++++++++-------------------- src/util.rs | 22 ++---------- src/worker.rs | 25 ++++++++------ 3 files changed, 74 insertions(+), 69 deletions(-) diff --git a/src/task.rs b/src/task.rs index 7f1643f..80bcf90 100644 --- a/src/task.rs +++ b/src/task.rs @@ -1,5 +1,5 @@ use crate::{ - util::{chrono_duration_to_std, db_error, ordinal, std_duration_to_chrono}, + util::{db_error, ordinal, std_duration_to_chrono}, Error, NextStep, Result, Step, StepError, }; use chrono::{DateTime, Utc}; @@ -20,12 +20,12 @@ pub struct Task { } #[derive(Clone, Copy, Debug)] -pub(crate) struct TaskLease { +pub(crate) struct WorkerLease { worker_id: Uuid, timeout: PgInterval, } -impl TaskLease { +impl WorkerLease { pub(crate) fn new(worker_id: Uuid, timeout: Duration) -> Self { Self { worker_id, @@ -46,12 +46,10 @@ fn duration_to_pg_interval(duration: Duration) -> PgInterval { impl Task { /// Returns a delay before a task should run at the given time. pub fn delay_until(wakeup_at: DateTime) -> Option { - let delay = wakeup_at - Utc::now(); - if delay <= chrono::Duration::zero() { - None - } else { - Some(chrono_duration_to_std(delay)) - } + (wakeup_at - Utc::now()) + .to_std() + .ok() + .filter(|delay| !delay.is_zero()) } /// Fetches the closest ready task to run. @@ -117,7 +115,7 @@ impl Task { pub(crate) async fn mark_running( &self, con: &mut PgConnection, - lease: TaskLease, + lease: WorkerLease, ) -> Result<()> { trace!("[{}] mark running", self.id); sqlx::query!( @@ -140,7 +138,7 @@ impl Task { /// Renews live task leases owned by a worker. 
pub(crate) async fn renew_leases( db: &PgPool, - lease: TaskLease, + lease: WorkerLease, task_ids: &[Uuid], ) -> Result> { if task_ids.is_empty() { @@ -173,7 +171,7 @@ impl Task { pub(crate) async fn claim>( &self, con: &mut PgConnection, - lease: TaskLease, + lease: WorkerLease, ) -> Result> { let step = match self.parse_step() { Ok(step) => step, @@ -192,7 +190,7 @@ impl Task { &self, db: &PgPool, step: S, - lease: TaskLease, + lease: WorkerLease, ) -> Result<()> { info!( "[{id}]{attempt} run step {step}", @@ -269,7 +267,7 @@ impl Task { db: &PgPool, err: StepError, increment_tried: bool, - lease: TaskLease, + lease: WorkerLease, ) -> Result<()> { let err_str = source_chain::to_string(&*err); let tried_increment = if increment_tried { 1 } else { 0 }; @@ -325,7 +323,7 @@ impl Task { db: &PgPool, step: impl Serialize + fmt::Debug, delay: Duration, - lease: TaskLease, + lease: WorkerLease, ) -> Result<()> { let step = match serde_json::to_string(&step) .map_err(|e| Error::SerializeStep(e, format!("{step:?}"))) @@ -362,7 +360,7 @@ impl Task { } /// Removes the finished task - async fn complete(&self, db: &PgPool, lease: TaskLease) -> Result<()> { + async fn complete(&self, db: &PgPool, lease: WorkerLease) -> Result<()> { let result = sqlx::query!( r#" DELETE FROM pg_task @@ -392,7 +390,7 @@ impl Task { retry_limit: i32, delay: Duration, err: StepError, - lease: TaskLease, + lease: WorkerLease, ) -> Result<()> { let delay = std_duration_to_chrono(delay); @@ -438,7 +436,7 @@ impl Task { #[cfg(test)] mod tests { - use super::{Task, TaskLease}; + use super::{Task, WorkerLease}; use crate::{NextStep, Step}; use chrono::{DateTime, Duration as ChronoDuration, Utc}; use sqlx::PgPool; @@ -614,25 +612,25 @@ mod tests { Uuid::from_u128(2) } - fn task_lease() -> TaskLease { - TaskLease::new(worker_id(), Duration::from_secs(60)) + fn worker_lease() -> WorkerLease { + WorkerLease::new(worker_id(), Duration::from_secs(60)) } #[test] - fn 
task_lease_converts_timeout_to_microsecond_interval() { - let lease = TaskLease::new(worker_id(), Duration::from_micros(42)); + fn worker_lease_converts_timeout_to_microsecond_interval() { + let lease = WorkerLease::new(worker_id(), Duration::from_micros(42)); assert_eq!(lease.timeout.microseconds, 42); } #[test] - fn task_lease_rounds_timeout_up_to_microseconds() { - let lease = TaskLease::new(worker_id(), Duration::from_nanos(1)); + fn worker_lease_rounds_timeout_up_to_microseconds() { + let lease = WorkerLease::new(worker_id(), Duration::from_nanos(1)); assert_eq!(lease.timeout.microseconds, 1); } #[test] - fn task_lease_saturates_large_timeouts() { - let lease = TaskLease::new(worker_id(), Duration::MAX); + fn worker_lease_saturates_large_timeouts() { + let lease = WorkerLease::new(worker_id(), Duration::MAX); assert_eq!(lease.timeout.microseconds, i64::MAX); } @@ -706,12 +704,16 @@ mod tests { .unwrap(); } - async fn claim_task(pool: &PgPool, step: TestTask, tried: i32) -> (Task, TestTask, TaskLease) { + async fn claim_task( + pool: &PgPool, + step: TestTask, + tried: i32, + ) -> (Task, TestTask, WorkerLease) { let id = insert_task(pool, &step, tried, false).await; let mut tx = pool.begin().await.unwrap(); let task = Task::fetch_ready(&mut tx).await.unwrap().unwrap(); assert_eq!(task.id, id); - let lease = task_lease(); + let lease = worker_lease(); let claimed = task .claim::(&mut tx, lease) .await @@ -786,7 +788,7 @@ mod tests { let task = Task::fetch_ready(&mut tx).await.unwrap().unwrap(); assert!(task - .claim::(&mut tx, task_lease()) + .claim::(&mut tx, worker_lease()) .await .unwrap() .is_none()); @@ -859,7 +861,10 @@ mod tests { .unwrap(); let mut tx = pool.begin().await.unwrap(); - let err = task.mark_running(&mut tx, task_lease()).await.unwrap_err(); + let err = task + .mark_running(&mut tx, worker_lease()) + .await + .unwrap_err(); assert_database_error(err); } @@ -875,7 +880,7 @@ mod tests { let mut tx = pool.begin().await.unwrap(); let err = task - 
.claim::(&mut tx, task_lease()) + .claim::(&mut tx, worker_lease()) .await .unwrap_err(); @@ -889,7 +894,10 @@ mod tests { let started_at = Utc::now(); let mut tx = pool.begin().await.unwrap(); let task = Task::fetch_ready(&mut tx).await.unwrap().unwrap(); - let claimed = task.claim::(&mut tx, task_lease()).await.unwrap(); + let claimed = task + .claim::(&mut tx, worker_lease()) + .await + .unwrap(); tx.commit().await.unwrap(); let finished_at = Utc::now(); @@ -986,7 +994,7 @@ mod tests { .await; let started_at = Utc::now(); - let renewed = Task::renew_leases(&pool, task_lease(), &[owned, expired, other_worker]) + let renewed = Task::renew_leases(&pool, worker_lease(), &[owned, expired, other_worker]) .await .unwrap(); let finished_at = Utc::now(); @@ -1042,7 +1050,7 @@ mod tests { .await; assert!( - Task::renew_leases(&pool, task_lease(), &[expired, other_worker]) + Task::renew_leases(&pool, worker_lease(), &[expired, other_worker]) .await .unwrap() .is_empty() @@ -1374,7 +1382,10 @@ mod tests { .await .unwrap(); - let err = task.run_step(&pool, step, task_lease()).await.unwrap_err(); + let err = task + .run_step(&pool, step, worker_lease()) + .await + .unwrap_err(); assert_database_error(err); } @@ -1473,7 +1484,10 @@ mod tests { .await .unwrap(); - let err = task.run_step(&pool, step, task_lease()).await.unwrap_err(); + let err = task + .run_step(&pool, step, worker_lease()) + .await + .unwrap_err(); assert_database_error(err); } @@ -1533,7 +1547,10 @@ mod tests { .await .unwrap(); - let err = task.run_step(&pool, step, task_lease()).await.unwrap_err(); + let err = task + .run_step(&pool, step, worker_lease()) + .await + .unwrap_err(); assert_database_error(err); } @@ -1599,7 +1616,10 @@ mod tests { .await .unwrap(); - let err = task.run_step(&pool, step, task_lease()).await.unwrap_err(); + let err = task + .run_step(&pool, step, worker_lease()) + .await + .unwrap_err(); assert_database_error(err); } diff --git a/src/util.rs b/src/util.rs index 8a37013..95d5492 
100644 --- a/src/util.rs +++ b/src/util.rs @@ -1,11 +1,3 @@ -/// Converts a chrono duration to std, it uses absolute value of the chrono -/// duration -pub fn chrono_duration_to_std(chrono_duration: chrono::Duration) -> std::time::Duration { - let seconds = chrono_duration.num_seconds(); - let nanos = chrono_duration.num_nanoseconds().unwrap_or(0) % 1_000_000_000; - std::time::Duration::new(seconds.unsigned_abs(), nanos.unsigned_abs() as u32) -} - /// Converts a std duration to chrono pub fn std_duration_to_chrono(std_duration: std::time::Duration) -> chrono::Duration { chrono::Duration::from_std(std_duration).unwrap_or(chrono::Duration::MAX) @@ -103,8 +95,8 @@ pub(crate) use db_error; #[cfg(test)] mod tests { use super::{ - chrono_duration_to_std, is_connection_error, is_pool_timeout, is_retryable_database_error, - ordinal, std_duration_to_chrono, wait_for_reconnection, + is_connection_error, is_pool_timeout, is_retryable_database_error, ordinal, + std_duration_to_chrono, wait_for_reconnection, }; use chrono::Duration as ChronoDuration; use sqlx::{ @@ -117,16 +109,6 @@ mod tests { // the first TCP connection before the pool is intentionally exhausted. 
const POOL_TIMEOUT: Duration = Duration::from_millis(100); - #[test] - fn chrono_duration_to_std_uses_the_absolute_value() { - let duration = ChronoDuration::seconds(-1) - ChronoDuration::milliseconds(250); - - assert_eq!( - chrono_duration_to_std(duration), - Duration::from_millis(1250) - ); - } - #[test] fn std_duration_to_chrono_saturates_on_overflow() { assert_eq!(std_duration_to_chrono(Duration::MAX), ChronoDuration::MAX); diff --git a/src/worker.rs b/src/worker.rs index 9a6adc1..e13f713 100644 --- a/src/worker.rs +++ b/src/worker.rs @@ -1,6 +1,6 @@ use crate::{ listener::Listener, - task::{Task, TaskLease}, + task::{Task, WorkerLease}, util::{db_error, is_connection_error, is_pool_timeout, wait_for_reconnection}, Error, Result, Step, LOST_CONNECTION_SLEEP, }; @@ -100,7 +100,7 @@ impl + 'static> Worker { self.validate_lease_timing(); self.listener.listen(self.db.clone()).await?; - let lease = TaskLease::new(Uuid::new_v4(), self.lease_timeout); + let lease = WorkerLease::new(Uuid::new_v4(), self.lease_timeout); let semaphore = Arc::new(Semaphore::new(self.concurrency.get())); let running_steps = Arc::new(Mutex::new(Vec::new())); let (heartbeat_events_sender, mut heartbeat_events) = mpsc::unbounded_channel(); @@ -266,7 +266,10 @@ impl + 'static> Worker { } /// Claims a currently available task and marks it running. - async fn claim_available_task(&self, lease: TaskLease) -> Result> { + async fn claim_available_task( + &self, + lease: WorkerLease, + ) -> Result> { trace!("Claiming an available task"); let mut tx = self.db.begin().await.map_err(db_error!("begin"))?; @@ -286,7 +289,7 @@ impl + 'static> Worker { /// Waits until the next task is ready, marks it running and returns it. 
/// Returns `None` if the worker is stopped #[cfg(test)] - async fn recv_task(&self, lease: TaskLease) -> Result> { + async fn recv_task(&self, lease: WorkerLease) -> Result> { trace!("Receiving the next task"); loop { @@ -370,7 +373,7 @@ impl + 'static> Worker { &self, events: mpsc::UnboundedSender, running_steps: Arc>>, - lease: TaskLease, + lease: WorkerLease, ) -> tokio::task::AbortHandle { self.validate_lease_timing(); let db = self.db.clone(); @@ -643,7 +646,7 @@ mod tests { HeartbeatEvent, RunEvents, RunningStep, TaskAvailability, Worker, DEFAULT_HEARTBEAT_INTERVAL, DEFAULT_LEASE_TIMEOUT, }; - use crate::{task::TaskLease, Error, NextStep, Step}; + use crate::{task::WorkerLease, Error, NextStep, Step}; use chrono::{Duration as ChronoDuration, Utc}; use parking_lot::Mutex; use sqlx::{ @@ -1063,8 +1066,8 @@ mod tests { NonZeroUsize::new(value).unwrap() } - fn worker_lease(worker: &Worker) -> TaskLease { - TaskLease::new(Uuid::new_v4(), worker.lease_timeout) + fn worker_lease(worker: &Worker) -> WorkerLease { + WorkerLease::new(Uuid::new_v4(), worker.lease_timeout) } fn running_step_entry(task_id: Uuid, abort_handle: tokio::task::AbortHandle) -> RunningStep { @@ -1325,7 +1328,7 @@ mod tests { .with_lease_timeout(Duration::from_millis(80)) .with_heartbeat_interval(Duration::from_millis(20)); let worker_id = Uuid::new_v4(); - let lease = TaskLease::new(worker_id, worker.lease_timeout); + let lease = WorkerLease::new(worker_id, worker.lease_timeout); let live = insert_task_at( &pool, &TestTask::Noop(Noop), @@ -1419,7 +1422,7 @@ mod tests { .with_lease_timeout(Duration::from_millis(500)) .with_heartbeat_interval(Duration::from_millis(20)); let worker_id = Uuid::new_v4(); - let lease = TaskLease::new(worker_id, worker.lease_timeout); + let lease = WorkerLease::new(worker_id, worker.lease_timeout); let id = insert_task_at( &pool, &TestTask::Noop(Noop), @@ -1806,7 +1809,7 @@ mod tests { set_task_lease(&pool, id, Utc::now() - ChronoDuration::milliseconds(1)).await; let 
worker = Worker::::new(pool.clone()); let worker_id = Uuid::new_v4(); - let lease = TaskLease::new(worker_id, worker.lease_timeout); + let lease = WorkerLease::new(worker_id, worker.lease_timeout); let (task, step, _lease) = worker.recv_task(lease).await.unwrap().unwrap(); From 8e9e8e90d5a3608c93fd20a4a20c6781dbdcc17b Mon Sep 17 00:00:00 2001 From: imbolc Date: Fri, 15 May 2026 13:35:21 +0600 Subject: [PATCH 39/44] Share database interruption classification --- src/listener.rs | 58 ++++++++++++++++++++++++--------------------- src/util.rs | 62 ++++++++++++++++++++++++++++++++++++++++++------- src/worker.rs | 45 ++++++++++++++++++++--------------- 3 files changed, 110 insertions(+), 55 deletions(-) diff --git a/src/listener.rs b/src/listener.rs index 1f193dc..0e68c07 100644 --- a/src/listener.rs +++ b/src/listener.rs @@ -154,39 +154,43 @@ impl Listener { db: &PgPool, error: sqlx::Error, ) -> bool { - if util::is_connection_error(&error) { - warn!( - "Listening for task table changes stopped because the database connection was interrupted:\n{}", - source_chain::to_string(&error) - ); - sleep(LOST_CONNECTION_SLEEP).await; - if let Err(error) = util::wait_for_reconnection(db, LOST_CONNECTION_SLEEP).await { + match util::db_interruption(&error) { + util::DbInterruption::Connection => { warn!( - "Couldn't wait for the database to recover after listener failure:\n{}", + "Listening for task table changes stopped because the database connection was interrupted:\n{}", source_chain::to_string(&error) ); - Self::set_error(error_slot, error); + sleep(LOST_CONNECTION_SLEEP).await; + if let Err(error) = util::wait_for_reconnection(db, LOST_CONNECTION_SLEEP).await { + warn!( + "Couldn't wait for the database to recover after listener failure:\n{}", + source_chain::to_string(&error) + ); + Self::set_error(error_slot, error); + notify.notify_one(); + true + } else { + warn!("Listening for task table changes resumed"); + false + } + } + util::DbInterruption::PoolTimeout => { + warn!( 
+ "Listening for task table changes is waiting for a free database connection from the pool:\n{}", + source_chain::to_string(&error) + ); + sleep(LOST_CONNECTION_SLEEP).await; + false + } + util::DbInterruption::Permanent => { + warn!( + "Listening for task table changes failed:\n{}", + source_chain::to_string(&error) + ); + Self::set_error(error_slot, crate::Error::ListenerReceive(error)); notify.notify_one(); true - } else { - warn!("Listening for task table changes resumed"); - false } - } else if util::is_pool_timeout(&error) { - warn!( - "Listening for task table changes is waiting for a free database connection from the pool:\n{}", - source_chain::to_string(&error) - ); - sleep(LOST_CONNECTION_SLEEP).await; - false - } else { - warn!( - "Listening for task table changes failed:\n{}", - source_chain::to_string(&error) - ); - Self::set_error(error_slot, crate::Error::ListenerReceive(error)); - notify.notify_one(); - true } } diff --git a/src/util.rs b/src/util.rs index 95d5492..f3a23ef 100644 --- a/src/util.rs +++ b/src/util.rs @@ -16,8 +16,27 @@ pub fn ordinal(n: i32) -> String { } } +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +pub(crate) enum DbInterruption { + Connection, + PoolTimeout, + Permanent, +} + +/// Classifies whether a SQLx error should interrupt database work permanently +/// or be retried after waiting. +pub(crate) fn db_interruption(error: &sqlx::Error) -> DbInterruption { + if is_connection_error(error) { + DbInterruption::Connection + } else if is_pool_timeout(error) { + DbInterruption::PoolTimeout + } else { + DbInterruption::Permanent + } +} + /// Returns true if the SQLx error points to a lost or unavailable connection. 
-pub(crate) fn is_connection_error(error: &sqlx::Error) -> bool { +fn is_connection_error(error: &sqlx::Error) -> bool { match error { sqlx::Error::Io(_) => true, sqlx::Error::Database(error) => is_retryable_database_error( @@ -30,7 +49,7 @@ pub(crate) fn is_connection_error(error: &sqlx::Error) -> bool { /// Returns true if the SQLx error indicates that the pool has no free /// connections right now. -pub(crate) fn is_pool_timeout(error: &sqlx::Error) -> bool { +fn is_pool_timeout(error: &sqlx::Error) -> bool { matches!(error, sqlx::Error::PoolTimedOut) } @@ -71,11 +90,15 @@ pub async fn wait_for_reconnection( .await { Ok(_) => return Ok(()), - Err(error) if is_connection_error(&error) || is_pool_timeout(&error) => { - tracing::trace!("Waiting for a database connection to become available"); - tokio::time::sleep(sleep).await; - } - Err(error) => return Err(db_error!("wait for reconnection")(error)), + Err(error) => match db_interruption(&error) { + DbInterruption::Connection | DbInterruption::PoolTimeout => { + tracing::trace!("Waiting for a database connection to become available"); + tokio::time::sleep(sleep).await; + } + DbInterruption::Permanent => { + return Err(db_error!("wait for reconnection")(error)); + } + }, } } } @@ -95,8 +118,8 @@ pub(crate) use db_error; #[cfg(test)] mod tests { use super::{ - is_connection_error, is_pool_timeout, is_retryable_database_error, ordinal, - std_duration_to_chrono, wait_for_reconnection, + db_interruption, is_connection_error, is_pool_timeout, is_retryable_database_error, + ordinal, std_duration_to_chrono, wait_for_reconnection, DbInterruption, }; use chrono::Duration as ChronoDuration; use sqlx::{ @@ -141,6 +164,27 @@ mod tests { assert!(is_pool_timeout(&sqlx::Error::PoolTimedOut)); } + #[test] + fn db_interruption_classifies_retryable_errors() { + assert_eq!( + db_interruption(&sqlx::Error::Io(io::Error::new( + io::ErrorKind::BrokenPipe, + "connection dropped", + ))), + DbInterruption::Connection, + ); + assert_eq!( + 
db_interruption(&sqlx::Error::PoolTimedOut), + DbInterruption::PoolTimeout, + ); + assert_eq!( + db_interruption(&sqlx::Error::Tls( + io::Error::other("bad certificate").into(), + )), + DbInterruption::Permanent, + ); + } + #[test] fn permanent_non_database_errors_are_not_retryable() { assert!(!is_connection_error(&sqlx::Error::Tls( diff --git a/src/worker.rs b/src/worker.rs index e13f713..05af454 100644 --- a/src/worker.rs +++ b/src/worker.rs @@ -1,7 +1,7 @@ use crate::{ listener::Listener, task::{Task, WorkerLease}, - util::{db_error, is_connection_error, is_pool_timeout, wait_for_reconnection}, + util::{db_error, db_interruption, wait_for_reconnection, DbInterruption}, Error, Result, Step, LOST_CONNECTION_SLEEP, }; use parking_lot::Mutex; @@ -348,24 +348,31 @@ impl + 'static> Worker { } async fn handle_recv_task_error(&self, error: Error) -> Result<()> { - if matches!(&error, Error::Db(db_error, _) if is_connection_error(db_error)) { - warn!( - "Task fetching stopped because the database connection was interrupted:\n{}", - source_chain::to_string(&error) - ); - sleep(LOST_CONNECTION_SLEEP).await; - wait_for_reconnection(&self.db, LOST_CONNECTION_SLEEP).await?; - warn!("Task fetching resumed"); - Ok(()) - } else if matches!(&error, Error::Db(db_error, _) if is_pool_timeout(db_error)) { - warn!( - "Task fetching is waiting for a free database connection from the pool:\n{}", - source_chain::to_string(&error) - ); - sleep(LOST_CONNECTION_SLEEP).await; - Ok(()) - } else { - Err(error) + let interruption = match &error { + Error::Db(db_error, _) => db_interruption(db_error), + _ => DbInterruption::Permanent, + }; + + match interruption { + DbInterruption::Connection => { + warn!( + "Task fetching stopped because the database connection was interrupted:\n{}", + source_chain::to_string(&error) + ); + sleep(LOST_CONNECTION_SLEEP).await; + wait_for_reconnection(&self.db, LOST_CONNECTION_SLEEP).await?; + warn!("Task fetching resumed"); + Ok(()) + } + 
DbInterruption::PoolTimeout => { + warn!( + "Task fetching is waiting for a free database connection from the pool:\n{}", + source_chain::to_string(&error) + ); + sleep(LOST_CONNECTION_SLEEP).await; + Ok(()) + } + DbInterruption::Permanent => Err(error), } } From 2a0cb2b2cb226156cd7215d83238b485a51ca55d Mon Sep 17 00:00:00 2001 From: imbolc Date: Fri, 15 May 2026 14:12:17 +0600 Subject: [PATCH 40/44] Move worker tests into child module --- src/worker.rs | 2235 +------------------------------------------ src/worker/tests.rs | 2225 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 2226 insertions(+), 2234 deletions(-) create mode 100644 src/worker/tests.rs diff --git a/src/worker.rs b/src/worker.rs index 05af454..1c50fc4 100644 --- a/src/worker.rs +++ b/src/worker.rs @@ -648,2237 +648,4 @@ impl + 'static> Worker { } #[cfg(test)] -mod tests { - use super::{ - HeartbeatEvent, RunEvents, RunningStep, TaskAvailability, Worker, - DEFAULT_HEARTBEAT_INTERVAL, DEFAULT_LEASE_TIMEOUT, - }; - use crate::{task::WorkerLease, Error, NextStep, Step}; - use chrono::{Duration as ChronoDuration, Utc}; - use parking_lot::Mutex; - use sqlx::{ - postgres::{PgConnectOptions, PgPoolOptions}, - PgPool, - }; - use std::{ - collections::HashMap, - io, - num::NonZeroUsize, - sync::{ - atomic::{AtomicU64, Ordering}, - Arc, OnceLock, - }, - time::Duration, - }; - use tokio::{ - sync::{mpsc, Notify, Semaphore}, - time::{sleep, timeout}, - }; - use uuid::Uuid; - - // Short enough to exercise PoolTimedOut, but long enough for CI to open - // the first TCP connection before the pool is intentionally exhausted. 
- const POOL_TIMEOUT: Duration = Duration::from_millis(100); - - fn init_tracing() { - static INIT: std::sync::Once = std::sync::Once::new(); - INIT.call_once(|| { - let _ = tracing_subscriber::fmt() - .with_max_level(tracing::Level::TRACE) - .with_test_writer() - .without_time() - .try_init(); - }); - } - - #[derive(Debug, serde::Deserialize, serde::Serialize)] - pub(super) struct Noop; - - #[derive(Debug, serde::Deserialize, serde::Serialize)] - pub(super) struct Advance { - key: u64, - } - - #[derive(Debug, serde::Deserialize, serde::Serialize)] - pub(super) struct Finish { - key: u64, - } - - #[derive(Debug, serde::Deserialize, serde::Serialize)] - pub(super) struct Complete { - key: u64, - } - - #[derive(Debug, serde::Deserialize, serde::Serialize)] - pub(super) struct Blocking { - key: u64, - } - - #[derive(Debug, serde::Deserialize, serde::Serialize)] - pub(super) struct FailSavingError { - key: u64, - } - - #[derive(Debug, serde::Deserialize, serde::Serialize)] - pub(super) struct FailStep { - key: u64, - } - - crate::task!(TestTask { - Noop, - Advance, - Finish, - Complete, - Blocking, - FailSavingError, - FailStep, - }); - - #[async_trait::async_trait] - impl Step for Noop { - async fn step(self, _db: &PgPool) -> crate::StepResult { - Ok(NextStep::None) - } - } - - #[async_trait::async_trait] - impl Step for Advance { - async fn step(self, _db: &PgPool) -> crate::StepResult { - step_state(self.key).record("advance"); - NextStep::now(Finish { key: self.key }) - } - } - - #[async_trait::async_trait] - impl Step for Finish { - async fn step(self, _db: &PgPool) -> crate::StepResult { - step_state(self.key).record("finish"); - NextStep::none() - } - } - - #[async_trait::async_trait] - impl Step for Complete { - async fn step(self, _db: &PgPool) -> crate::StepResult { - step_state(self.key).record("complete"); - NextStep::none() - } - } - - #[async_trait::async_trait] - impl Step for Blocking { - async fn step(self, _db: &PgPool) -> crate::StepResult { - let 
state = step_state(self.key); - state.record("started"); - state.wait_for_release().await; - state.record("completed"); - NextStep::none() - } - } - - #[async_trait::async_trait] - impl Step for FailSavingError { - async fn step(self, db: &PgPool) -> crate::StepResult { - let state = step_state(self.key); - state.record("started"); - sqlx::query!("ALTER TABLE pg_task RENAME COLUMN error TO task_error") - .execute(db) - .await - .unwrap(); - state.record("save error failed"); - Err(io::Error::other("step failed").into()) - } - - fn retry_limit(&self) -> i32 { - 0 - } - } - - #[async_trait::async_trait] - impl Step for FailStep { - async fn step(self, _db: &PgPool) -> crate::StepResult { - step_state(self.key).record("started"); - Err(io::Error::other("step failed").into()) - } - - fn retry_limit(&self) -> i32 { - 0 - } - } - - struct StepState { - events: Mutex>, - events_changed: Notify, - release: Notify, - } - - impl StepState { - fn new() -> Self { - Self { - events: Mutex::new(Vec::new()), - events_changed: Notify::new(), - release: Notify::new(), - } - } - - fn record(&self, event: &'static str) { - self.events.lock().push(event); - self.events_changed.notify_waiters(); - } - - fn release(&self) { - self.release.notify_waiters(); - } - - fn events(&self) -> Vec<&'static str> { - self.events.lock().clone() - } - - async fn wait_for_events(&self, count: usize) { - timeout(Duration::from_secs(1), async { - loop { - if self.events.lock().len() >= count { - return; - } - self.events_changed.notified().await; - } - }) - .await - .unwrap(); - } - - async fn wait_for_release(&self) { - self.release.notified().await; - } - } - - struct StepStateGuard { - key: u64, - state: Arc, - } - - impl StepStateGuard { - fn new() -> Self { - let key = NEXT_STEP_STATE_KEY.fetch_add(1, Ordering::Relaxed); - let state = Arc::new(StepState::new()); - step_states().lock().insert(key, state.clone()); - Self { key, state } - } - - fn key(&self) -> u64 { - self.key - } - - fn state(&self) 
-> Arc { - self.state.clone() - } - } - - impl Drop for StepStateGuard { - fn drop(&mut self) { - step_states().lock().remove(&self.key); - } - } - - static NEXT_STEP_STATE_KEY: AtomicU64 = AtomicU64::new(1); - static STEP_STATES: OnceLock>>> = OnceLock::new(); - - fn step_states() -> &'static Mutex>> { - STEP_STATES.get_or_init(|| Mutex::new(HashMap::new())) - } - - fn step_state(key: u64) -> Arc { - step_states().lock().get(&key).cloned().unwrap() - } - - fn connection_error() -> Error { - Error::Db( - sqlx::Error::Io(io::Error::new( - io::ErrorKind::BrokenPipe, - "worker connection dropped", - )), - "test".into(), - ) - } - - async fn insert_raw_task( - pool: &PgPool, - step: &str, - wakeup_at: chrono::DateTime, - is_leased: bool, - error: Option<&str>, - ) -> Uuid { - let (locked_by, lock_expires_at) = if is_leased { - ( - Some(Uuid::from_u128(1)), - Some(Utc::now() + ChronoDuration::seconds(60)), - ) - } else { - (None, None) - }; - sqlx::query!( - " - INSERT INTO pg_task (step, wakeup_at, locked_by, lock_expires_at, error) - VALUES ($1, $2, $3, $4, $5) - RETURNING id - ", - step, - wakeup_at, - locked_by, - lock_expires_at, - error, - ) - .fetch_one(pool) - .await - .unwrap() - .id - } - - async fn insert_task_at( - pool: &PgPool, - step: &TestTask, - wakeup_at: chrono::DateTime, - is_leased: bool, - ) -> Uuid { - insert_raw_task( - pool, - &serde_json::to_string(step).unwrap(), - wakeup_at, - is_leased, - None, - ) - .await - } - - async fn insert_task(pool: &PgPool, step: &TestTask, is_leased: bool) { - insert_task_at( - pool, - step, - Utc::now() - ChronoDuration::milliseconds(1), - is_leased, - ) - .await; - } - - async fn set_task_lease(pool: &PgPool, id: Uuid, lock_expires_at: chrono::DateTime) { - set_task_lease_for_worker(pool, id, Uuid::from_u128(1), lock_expires_at).await; - } - - async fn set_task_lease_for_worker( - pool: &PgPool, - id: Uuid, - worker_id: Uuid, - lock_expires_at: chrono::DateTime, - ) { - sqlx::query!( - r#" - UPDATE pg_task - SET 
locked_by = $2, - lock_expires_at = $3 - WHERE id = $1 - "#, - id, - worker_id, - lock_expires_at, - ) - .execute(pool) - .await - .unwrap(); - } - - async fn fetch_task_lease(pool: &PgPool, id: Uuid) -> Option<(Uuid, chrono::DateTime)> { - sqlx::query!( - " - SELECT locked_by, lock_expires_at - FROM pg_task - WHERE id = $1 - ", - id, - ) - .fetch_optional(pool) - .await - .unwrap() - .map(|row| (row.locked_by.unwrap(), row.lock_expires_at.unwrap())) - } - - fn idle_heartbeat() -> tokio::task::AbortHandle { - tokio::spawn(async { - std::future::pending::<()>().await; - }) - .abort_handle() - } - - fn idle_heartbeat_events() -> mpsc::UnboundedReceiver { - let (_sender, receiver) = mpsc::unbounded_channel(); - receiver - } - - fn idle_step_errors() -> mpsc::UnboundedReceiver { - let (_sender, receiver) = mpsc::unbounded_channel(); - receiver - } - - fn idle_run_events() -> RunEvents { - RunEvents { - heartbeat: idle_heartbeat_events(), - step_errors: idle_step_errors(), - } - } - - async fn connect_to_current_db( - pool: &PgPool, - max_connections: u32, - acquire_timeout: Duration, - ) -> PgPool { - let db_name: String = sqlx::query_scalar!(r#"SELECT current_database() AS "db_name!""#) - .fetch_one(pool) - .await - .unwrap(); - - PgPoolOptions::new() - .max_connections(max_connections) - .acquire_timeout(acquire_timeout) - .connect_with(current_database_options(&db_name)) - .await - .unwrap() - } - - // Connect to the database created by sqlx::test while keeping the - // connection settings from DATABASE_URL. CI needs its TCP host and password; - // postgres:///{db_name} only works for local peer-auth socket setups. 
- fn current_database_options(db_name: &str) -> PgConnectOptions { - std::env::var("DATABASE_URL") - .expect("DATABASE_URL must be set") - .parse::() - .unwrap() - .database(db_name) - } - - async fn task_count(pool: &PgPool) -> i64 { - sqlx::query!("SELECT id FROM pg_task") - .fetch_all(pool) - .await - .unwrap() - .len() as i64 - } - - async fn stop_worker(pool: &PgPool) { - sqlx::query!("NOTIFY pg_task_changed, 'stop_worker'") - .execute(pool) - .await - .unwrap(); - } - - fn nonzero(value: usize) -> NonZeroUsize { - NonZeroUsize::new(value).unwrap() - } - - fn worker_lease(worker: &Worker) -> WorkerLease { - WorkerLease::new(Uuid::new_v4(), worker.lease_timeout) - } - - fn running_step_entry(task_id: Uuid, abort_handle: tokio::task::AbortHandle) -> RunningStep { - RunningStep { - task_id, - abort_handle, - } - } - - fn spawn_worker(pool: PgPool) -> tokio::task::JoinHandle> { - spawn_worker_with_concurrency(pool, 1) - } - - fn spawn_worker_with_concurrency( - pool: PgPool, - concurrency: usize, - ) -> tokio::task::JoinHandle> { - tokio::spawn(async move { - Worker::::new(pool) - .with_concurrency(nonzero(concurrency)) - .run() - .await - }) - } - - #[tokio::test] - #[should_panic(expected = "lease timeout must be non-zero")] - async fn with_lease_timeout_rejects_zero() { - Worker::::new( - PgPoolOptions::new() - .connect_lazy("postgres:///pg_task") - .unwrap(), - ) - .with_lease_timeout(Duration::ZERO); - } - - #[tokio::test] - #[should_panic(expected = "heartbeat interval must be non-zero")] - async fn with_heartbeat_interval_rejects_zero() { - Worker::::new( - PgPoolOptions::new() - .connect_lazy("postgres:///pg_task") - .unwrap(), - ) - .with_heartbeat_interval(Duration::ZERO); - } - - #[tokio::test] - #[should_panic(expected = "heartbeat interval must be shorter than lease timeout")] - async fn run_rejects_lease_timeout_that_is_not_longer_than_the_heartbeat_interval() { - let worker = Worker::::new( - PgPoolOptions::new() - 
.connect_lazy("postgres:///pg_task") - .unwrap(), - ) - .with_lease_timeout(DEFAULT_HEARTBEAT_INTERVAL); - - let _ = worker.run().await; - } - - #[tokio::test] - #[should_panic(expected = "heartbeat interval must be shorter than lease timeout")] - async fn run_rejects_heartbeat_interval_that_is_not_shorter_than_the_lease_timeout() { - let worker = Worker::::new( - PgPoolOptions::new() - .connect_lazy("postgres:///pg_task") - .unwrap(), - ) - .with_heartbeat_interval(DEFAULT_LEASE_TIMEOUT); - - let _ = worker.run().await; - } - - #[test] - fn heartbeat_events_pause_resume_and_expire_fetching() { - let mut heartbeat_healthy = true; - Worker::::handle_heartbeat_event(HeartbeatEvent::Failed, &mut heartbeat_healthy) - .unwrap(); - assert!(!heartbeat_healthy); - Worker::::handle_heartbeat_event(HeartbeatEvent::Failed, &mut heartbeat_healthy) - .unwrap(); - assert!(!heartbeat_healthy); - - Worker::::handle_heartbeat_event( - HeartbeatEvent::Recovered, - &mut heartbeat_healthy, - ) - .unwrap(); - assert!(heartbeat_healthy); - Worker::::handle_heartbeat_event( - HeartbeatEvent::Recovered, - &mut heartbeat_healthy, - ) - .unwrap(); - assert!(heartbeat_healthy); - - let err = Worker::::handle_heartbeat_event( - HeartbeatEvent::Expired(Error::Db(sqlx::Error::PoolTimedOut, "test".into())), - &mut heartbeat_healthy, - ) - .unwrap_err(); - assert!(matches!(err, Error::Db(sqlx::Error::PoolTimedOut, _))); - } - - #[tokio::test] - async fn heartbeat_expiry_interrupts_retryable_fetch_error_handling() { - init_tracing(); - let worker = Worker::::new( - PgPoolOptions::new() - .connect_lazy("postgres:///pg_task") - .unwrap(), - ); - let (heartbeat_events, mut heartbeat_events_receiver) = mpsc::unbounded_channel(); - heartbeat_events - .send(HeartbeatEvent::Expired(Error::Db( - sqlx::Error::PoolTimedOut, - "heartbeat".into(), - ))) - .unwrap(); - let mut heartbeat_healthy = true; - let mut abort_running_steps = false; - - let err = timeout( - Duration::from_millis(100), - 
worker.handle_recv_task_error_or_heartbeat( - Error::Db(sqlx::Error::PoolTimedOut, "fetch".into()), - &mut heartbeat_events_receiver, - &mut heartbeat_healthy, - &mut abort_running_steps, - ), - ) - .await - .unwrap() - .unwrap_err(); - - assert!(matches!(err, Error::Db(sqlx::Error::PoolTimedOut, _))); - assert!(abort_running_steps); - } - - #[tokio::test] - async fn heartbeat_recovery_preserves_retryable_fetch_error_handling() { - init_tracing(); - let worker = Worker::::new( - PgPoolOptions::new() - .connect_lazy("postgres:///pg_task") - .unwrap(), - ); - let (heartbeat_events, mut heartbeat_events_receiver) = mpsc::unbounded_channel(); - heartbeat_events.send(HeartbeatEvent::Failed).unwrap(); - heartbeat_events.send(HeartbeatEvent::Recovered).unwrap(); - let mut heartbeat_healthy = true; - let mut abort_running_steps = false; - - worker - .handle_recv_task_error_or_heartbeat( - Error::Db(sqlx::Error::PoolTimedOut, "fetch".into()), - &mut heartbeat_events_receiver, - &mut heartbeat_healthy, - &mut abort_running_steps, - ) - .await - .unwrap(); - - assert!(heartbeat_healthy); - assert!(!abort_running_steps); - } - - #[tokio::test] - async fn heartbeat_failure_pauses_after_retryable_fetch_error_handling() { - init_tracing(); - let worker = Worker::::new( - PgPoolOptions::new() - .connect_lazy("postgres:///pg_task") - .unwrap(), - ); - let (heartbeat_events, mut heartbeat_events_receiver) = mpsc::unbounded_channel(); - heartbeat_events.send(HeartbeatEvent::Failed).unwrap(); - let mut heartbeat_healthy = true; - let mut abort_running_steps = false; - - worker - .handle_recv_task_error_or_heartbeat( - Error::Db(sqlx::Error::PoolTimedOut, "fetch".into()), - &mut heartbeat_events_receiver, - &mut heartbeat_healthy, - &mut abort_running_steps, - ) - .await - .unwrap(); - - assert!(!heartbeat_healthy); - assert!(!abort_running_steps); - } - - #[sqlx::test(migrations = "./migrations")] - async fn heartbeat_skips_renewal_without_running_steps(pool: PgPool) { - 
init_tracing(); - sqlx::query!("ALTER TABLE pg_task RENAME COLUMN lock_expires_at TO task_lock_expires_at") - .execute(&pool) - .await - .unwrap(); - let worker = Worker::::new(pool) - .with_lease_timeout(Duration::from_millis(80)) - .with_heartbeat_interval(Duration::from_millis(20)); - let (events, mut events_receiver) = mpsc::unbounded_channel(); - let heartbeat = worker.spawn_heartbeat( - events, - Arc::new(Mutex::new(Vec::new())), - worker_lease(&worker), - ); - - assert!(timeout(Duration::from_millis(150), events_receiver.recv()) - .await - .is_err()); - - heartbeat.abort(); - } - - #[sqlx::test(migrations = "./migrations")] - async fn heartbeat_expires_when_running_steps_have_no_live_leases(pool: PgPool) { - init_tracing(); - let worker = Worker::::new(pool) - .with_lease_timeout(Duration::from_millis(80)) - .with_heartbeat_interval(Duration::from_millis(20)); - let running_step = tokio::spawn(async { - std::future::pending::<()>().await; - }); - let running_steps = Arc::new(Mutex::new(vec![running_step_entry( - Uuid::new_v4(), - running_step.abort_handle(), - )])); - let (events, mut events_receiver) = mpsc::unbounded_channel(); - let heartbeat = worker.spawn_heartbeat(events, running_steps, worker_lease(&worker)); - - let event = timeout(Duration::from_secs(1), events_receiver.recv()) - .await - .unwrap() - .unwrap(); - assert!(matches!(event, HeartbeatEvent::Failed)); - - let event = timeout(Duration::from_secs(1), events_receiver.recv()) - .await - .unwrap() - .unwrap(); - assert!(matches!( - event, - HeartbeatEvent::Expired(Error::TaskLeaseExpired) - )); - - heartbeat.abort(); - running_step.abort(); - } - - #[sqlx::test(migrations = "./migrations")] - async fn heartbeat_expires_when_any_running_step_loses_its_lease(pool: PgPool) { - init_tracing(); - let worker = Worker::::new(pool.clone()) - .with_lease_timeout(Duration::from_millis(80)) - .with_heartbeat_interval(Duration::from_millis(20)); - let worker_id = Uuid::new_v4(); - let lease = 
WorkerLease::new(worker_id, worker.lease_timeout); - let live = insert_task_at( - &pool, - &TestTask::Noop(Noop), - Utc::now() - ChronoDuration::milliseconds(1), - false, - ) - .await; - let expired = insert_task_at( - &pool, - &TestTask::Noop(Noop), - Utc::now() - ChronoDuration::milliseconds(1), - false, - ) - .await; - set_task_lease_for_worker( - &pool, - live, - worker_id, - Utc::now() + ChronoDuration::milliseconds(200), - ) - .await; - set_task_lease_for_worker( - &pool, - expired, - worker_id, - Utc::now() - ChronoDuration::milliseconds(1), - ) - .await; - let live_step = tokio::spawn(async { - std::future::pending::<()>().await; - }); - let expired_step = tokio::spawn(async { - std::future::pending::<()>().await; - }); - let running_steps = Arc::new(Mutex::new(vec![ - running_step_entry(live, live_step.abort_handle()), - running_step_entry(expired, expired_step.abort_handle()), - ])); - let (events, mut events_receiver) = mpsc::unbounded_channel(); - let heartbeat = worker.spawn_heartbeat(events, running_steps, lease); - - let event = timeout(Duration::from_secs(1), events_receiver.recv()) - .await - .unwrap() - .unwrap(); - assert!(matches!(event, HeartbeatEvent::Failed)); - - let event = timeout(Duration::from_secs(1), events_receiver.recv()) - .await - .unwrap() - .unwrap(); - assert!(matches!( - event, - HeartbeatEvent::Expired(Error::TaskLeaseExpired) - )); - - heartbeat.abort(); - live_step.abort(); - expired_step.abort(); - } - - #[sqlx::test(migrations = "./migrations")] - async fn heartbeat_skips_pool_timeouts_without_running_steps(pool: PgPool) { - init_tracing(); - let worker_pool = connect_to_current_db(&pool, 1, POOL_TIMEOUT).await; - let held_connection = worker_pool.acquire().await.unwrap(); - let worker = Worker::::new(worker_pool) - .with_lease_timeout(Duration::from_millis(500)) - .with_heartbeat_interval(Duration::from_millis(20)); - let (events, mut events_receiver) = mpsc::unbounded_channel(); - let heartbeat = worker.spawn_heartbeat( 
- events, - Arc::new(Mutex::new(Vec::new())), - worker_lease(&worker), - ); - - assert!(timeout(Duration::from_millis(150), events_receiver.recv()) - .await - .is_err()); - - drop(held_connection); - heartbeat.abort(); - } - - #[sqlx::test(migrations = "./migrations")] - async fn heartbeat_reports_recovery_after_live_leases_are_renewed(pool: PgPool) { - init_tracing(); - let worker_pool = connect_to_current_db(&pool, 1, POOL_TIMEOUT).await; - let held_connection = worker_pool.acquire().await.unwrap(); - let worker = Worker::::new(worker_pool) - .with_lease_timeout(Duration::from_millis(500)) - .with_heartbeat_interval(Duration::from_millis(20)); - let worker_id = Uuid::new_v4(); - let lease = WorkerLease::new(worker_id, worker.lease_timeout); - let id = insert_task_at( - &pool, - &TestTask::Noop(Noop), - Utc::now() - ChronoDuration::milliseconds(1), - false, - ) - .await; - let initial_expires_at = Utc::now() + ChronoDuration::milliseconds(200); - sqlx::query!( - " - UPDATE pg_task - SET locked_by = $2, - lock_expires_at = $3 - WHERE id = $1 - ", - id, - worker_id, - initial_expires_at, - ) - .execute(&pool) - .await - .unwrap(); - let running_step = tokio::spawn(async { - std::future::pending::<()>().await; - }); - let running_steps = Arc::new(Mutex::new(vec![running_step_entry( - id, - running_step.abort_handle(), - )])); - let (events, mut events_receiver) = mpsc::unbounded_channel(); - let heartbeat = worker.spawn_heartbeat(events, running_steps, lease); - - let event = timeout(Duration::from_secs(1), events_receiver.recv()) - .await - .unwrap() - .unwrap(); - assert!(matches!(event, HeartbeatEvent::Failed)); - - drop(held_connection); - - let event = timeout(Duration::from_secs(1), events_receiver.recv()) - .await - .unwrap() - .unwrap(); - assert!(matches!(event, HeartbeatEvent::Recovered)); - - let (_locked_by, renewed_expires_at) = fetch_task_lease(&pool, id).await.unwrap(); - assert!(renewed_expires_at > initial_expires_at); - - heartbeat.abort(); - 
running_step.abort(); - } - - #[tokio::test] - async fn running_step_tracking_prunes_finished_steps() { - let finished_step = tokio::spawn(async {}); - let finished_step_abort = finished_step.abort_handle(); - finished_step.await.unwrap(); - - let running_step = tokio::spawn(async { - std::future::pending::<()>().await; - }); - let running_step_abort = running_step.abort_handle(); - let running_steps = Mutex::new(vec![running_step_entry( - Uuid::new_v4(), - finished_step_abort, - )]); - - Worker::::track_running_step(&running_steps, Uuid::new_v4(), running_step_abort); - - assert_eq!(running_steps.lock().len(), 1); - assert!(Worker::::has_running_steps(&running_steps)); - - Worker::::abort_running_steps(&running_steps); - assert!(running_step.await.unwrap_err().is_cancelled()); - assert!(!Worker::::has_running_steps(&running_steps)); - } - - #[sqlx::test(migrations = "./migrations")] - async fn run_returns_listener_startup_errors(pool: PgPool) { - let worker = Worker::::new(pool); - worker.listener.fail_next_listen_for_tests(); - - let err = worker.run().await.unwrap_err(); - - assert!(matches!( - err, - Error::ListenerListen(sqlx::Error::Protocol(_)) - )); - } - - #[sqlx::test(migrations = "./migrations")] - async fn handle_recv_task_error_returns_permanent_fetch_errors(pool: PgPool) { - sqlx::query!("ALTER TABLE pg_task RENAME COLUMN step TO task_step") - .execute(&pool) - .await - .unwrap(); - - let err = sqlx::query!("SELECT step FROM pg_task") - .fetch_one(&pool) - .await - .unwrap_err(); - - let worker = Worker::::new(pool); - let err = worker - .handle_recv_task_error(Error::Db(err, "test".into())) - .await - .unwrap_err(); - - assert!(matches!(err, Error::Db(sqlx::Error::Database(_), _))); - } - - #[tokio::test] - async fn handle_recv_task_error_retries_pool_timeouts() { - init_tracing(); - let worker = Worker::::new( - PgPoolOptions::new() - .connect_lazy("postgres:///pg_task") - .unwrap(), - ); - - worker - 
.handle_recv_task_error(Error::Db(sqlx::Error::PoolTimedOut, "test".into())) - .await - .unwrap(); - } - - #[sqlx::test(migrations = "./migrations")] - async fn handle_recv_task_error_waits_for_reconnection_after_connection_errors(pool: PgPool) { - init_tracing(); - let worker = Worker::::new(pool); - - worker - .handle_recv_task_error(connection_error()) - .await - .unwrap(); - } - - #[sqlx::test(migrations = "./migrations")] - async fn handle_recv_task_error_returns_reconnection_failures(pool: PgPool) { - sqlx::query!("ALTER TABLE pg_task RENAME COLUMN id TO task_id") - .execute(&pool) - .await - .unwrap(); - - let worker = Worker::::new(pool); - let err = worker - .handle_recv_task_error(connection_error()) - .await - .unwrap_err(); - - assert!(matches!(err, Error::Db(sqlx::Error::Database(_), _))); - } - - #[sqlx::test(migrations = "./migrations")] - async fn recv_task_returns_listener_errors(pool: PgPool) { - let worker = Worker::::new(pool); - worker - .listener - .set_error_for_tests(Error::ListenerReceive(sqlx::Error::Protocol( - "listener failed".into(), - ))); - - let lease = worker_lease(&worker); - let err = worker.recv_task(lease).await.unwrap_err(); - - assert!(matches!( - err, - Error::ListenerReceive(sqlx::Error::Protocol(_)) - )); - } - - #[sqlx::test(migrations = "./migrations")] - async fn recv_task_stops_even_if_listener_has_failed(pool: PgPool) { - let worker = Worker::::new(pool); - worker - .listener - .set_error_for_tests(Error::ListenerReceive(sqlx::Error::Protocol( - "listener failed".into(), - ))); - worker.listener.stop_worker_for_tests(); - - let lease = worker_lease(&worker); - assert!(worker.recv_task(lease).await.unwrap().is_none()); - } - - #[sqlx::test(migrations = "./migrations")] - async fn recv_task_returns_begin_errors_when_the_pool_is_closed(pool: PgPool) { - let worker = Worker::::new(pool.clone()); - pool.close().await; - - let lease = worker_lease(&worker); - let err = worker.recv_task(lease).await.unwrap_err(); - - match 
err { - Error::Db(sqlx::Error::PoolClosed, context) => { - assert!(context.contains("begin")); - } - _ => panic!("expected a pool-closed begin error"), - } - } - - #[sqlx::test(migrations = "./migrations")] - async fn wait_for_available_task_does_not_claim_ready_tasks(pool: PgPool) { - let id = insert_task_at( - &pool, - &TestTask::Noop(Noop), - Utc::now() - ChronoDuration::milliseconds(1), - false, - ) - .await; - let worker = Worker::::new(pool.clone()); - - let availability = worker.wait_for_available_task().await.unwrap(); - - assert!(matches!(availability, TaskAvailability::Ready)); - let lease = sqlx::query!( - "SELECT locked_by, lock_expires_at FROM pg_task WHERE id = $1", - id, - ) - .fetch_one(&pool) - .await - .unwrap(); - assert!(lease.locked_by.is_none()); - assert!(lease.lock_expires_at.is_none()); - } - - #[sqlx::test(migrations = "./migrations")] - async fn recv_task_stops_while_waiting_for_work(pool: PgPool) { - let worker = Arc::new(Worker::::new(pool)); - let lease = worker_lease(&worker); - let recv = tokio::spawn({ - let worker = worker.clone(); - async move { worker.recv_task(lease).await } - }); - - sleep(Duration::from_millis(50)).await; - assert!(!recv.is_finished()); - worker.listener.stop_worker_for_tests(); - - let received = timeout(Duration::from_secs(1), recv) - .await - .unwrap() - .unwrap() - .unwrap(); - assert!(received.is_none()); - } - - #[sqlx::test(migrations = "./migrations")] - async fn recv_task_returns_listener_errors_while_waiting_for_work(pool: PgPool) { - let worker = Arc::new(Worker::::new(pool)); - let lease = worker_lease(&worker); - let recv = tokio::spawn({ - let worker = worker.clone(); - async move { worker.recv_task(lease).await } - }); - - sleep(Duration::from_millis(50)).await; - assert!(!recv.is_finished()); - worker - .listener - .set_error_and_notify_for_tests(Error::ListenerReceive(sqlx::Error::Protocol( - "listener failed".into(), - ))); - - let err = timeout(Duration::from_secs(1), recv) - .await - 
.unwrap() - .unwrap() - .unwrap_err(); - assert!(matches!( - err, - Error::ListenerReceive(sqlx::Error::Protocol(_)) - )); - } - - #[sqlx::test(migrations = "./migrations")] - async fn recv_task_skips_invalid_tasks_and_returns_next_ready_task(pool: PgPool) { - let invalid_id = insert_raw_task( - &pool, - "not-json", - Utc::now() - ChronoDuration::seconds(2), - false, - None, - ) - .await; - let expected = insert_task_at( - &pool, - &TestTask::Noop(Noop), - Utc::now() - ChronoDuration::seconds(1), - false, - ) - .await; - let worker = Worker::::new(pool.clone()); - let lease = worker_lease(&worker); - - let (task, step, _lease) = worker.recv_task(lease).await.unwrap().unwrap(); - - assert_eq!(task.id, expected); - assert!(matches!(step, TestTask::Noop(Noop))); - let invalid_row = sqlx::query!( - "SELECT locked_by, lock_expires_at, error FROM pg_task WHERE id = $1", - invalid_id, - ) - .fetch_one(&pool) - .await - .unwrap(); - assert!(invalid_row.locked_by.is_none()); - assert!(invalid_row.lock_expires_at.is_none()); - assert!(invalid_row.error.is_some()); - } - - #[sqlx::test(migrations = "./migrations")] - async fn recv_task_rechecks_locked_ready_tasks_without_notifications(pool: PgPool) { - let id = insert_task_at( - &pool, - &TestTask::Noop(Noop), - Utc::now() - ChronoDuration::milliseconds(1), - false, - ) - .await; - let mut tx = pool.begin().await.unwrap(); - let locked = sqlx::query!("SELECT id FROM pg_task WHERE id = $1 FOR UPDATE", id) - .fetch_one(&mut *tx) - .await - .unwrap(); - assert_eq!(locked.id, id); - - let worker = Worker::::new(pool); - let lease = worker_lease(&worker); - let recv = tokio::spawn(async move { worker.recv_task(lease).await }); - - sleep(Duration::from_millis(50)).await; - assert!(!recv.is_finished()); - - tx.rollback().await.unwrap(); - - let (task, step, _lease) = timeout(Duration::from_secs(1), recv) - .await - .unwrap() - .unwrap() - .unwrap() - .unwrap(); - assert_eq!(task.id, id); - assert!(matches!(step, 
TestTask::Noop(Noop))); - } - - #[sqlx::test(migrations = "./migrations")] - async fn recv_task_rechecks_leased_tasks_when_their_lease_expires(pool: PgPool) { - let id = insert_task_at( - &pool, - &TestTask::Noop(Noop), - Utc::now() - ChronoDuration::milliseconds(1), - true, - ) - .await; - let lock_expires_at = Utc::now() + ChronoDuration::seconds(1); - set_task_lease(&pool, id, lock_expires_at).await; - - let worker = Worker::::new(pool); - let lease = worker_lease(&worker); - let recv = tokio::spawn(async move { worker.recv_task(lease).await }); - - let (task, step, _lease) = timeout(Duration::from_secs(2), recv) - .await - .unwrap() - .unwrap() - .unwrap() - .unwrap(); - assert!(Utc::now() >= lock_expires_at); - assert_eq!(task.id, id); - assert!(matches!(step, TestTask::Noop(Noop))); - } - - #[sqlx::test(migrations = "./migrations")] - async fn recv_task_replaces_expired_lease_with_the_current_worker(pool: PgPool) { - let id = insert_task_at( - &pool, - &TestTask::Noop(Noop), - Utc::now() - ChronoDuration::milliseconds(1), - true, - ) - .await; - set_task_lease(&pool, id, Utc::now() - ChronoDuration::milliseconds(1)).await; - let worker = Worker::::new(pool.clone()); - let worker_id = Uuid::new_v4(); - let lease = WorkerLease::new(worker_id, worker.lease_timeout); - - let (task, step, _lease) = worker.recv_task(lease).await.unwrap().unwrap(); - - assert_eq!(task.id, id); - assert!(matches!(step, TestTask::Noop(Noop))); - let (locked_by, lock_expires_at) = fetch_task_lease(&pool, id).await.unwrap(); - assert_eq!(locked_by, worker_id); - assert!(lock_expires_at > Utc::now()); - } - - #[sqlx::test(migrations = "./migrations")] - async fn two_workers_claim_ready_tasks_once(pool: PgPool) { - let first_id = insert_task_at( - &pool, - &TestTask::Noop(Noop), - Utc::now() - ChronoDuration::milliseconds(1), - false, - ) - .await; - let second_id = insert_task_at( - &pool, - &TestTask::Noop(Noop), - Utc::now() - ChronoDuration::milliseconds(1), - false, - ) - .await; - 
let first_worker = Worker::::new(pool.clone()); - let second_worker = Worker::::new(pool.clone()); - let first_lease = worker_lease(&first_worker); - let second_lease = worker_lease(&second_worker); - - let first_recv = tokio::spawn(async move { first_worker.recv_task(first_lease).await }); - let second_recv = tokio::spawn(async move { second_worker.recv_task(second_lease).await }); - - let (first_task, first_step, _first_lease) = timeout(Duration::from_secs(1), first_recv) - .await - .unwrap() - .unwrap() - .unwrap() - .unwrap(); - let (second_task, second_step, _second_lease) = - timeout(Duration::from_secs(1), second_recv) - .await - .unwrap() - .unwrap() - .unwrap() - .unwrap(); - - assert!(matches!(first_step, TestTask::Noop(Noop))); - assert!(matches!(second_step, TestTask::Noop(Noop))); - assert_ne!(first_task.id, second_task.id); - assert!([first_id, second_id].contains(&first_task.id)); - assert!([first_id, second_id].contains(&second_task.id)); - - let running = sqlx::query!("SELECT id FROM pg_task WHERE locked_by IS NOT NULL") - .fetch_all(&pool) - .await - .unwrap(); - assert_eq!(running.len(), 2); - } - - #[sqlx::test(migrations = "./migrations")] - async fn run_renews_leases_for_running_tasks(pool: PgPool) { - let state = StepStateGuard::new(); - let id = insert_task_at( - &pool, - &TestTask::Blocking(Blocking { key: state.key() }), - Utc::now() - ChronoDuration::milliseconds(1), - false, - ) - .await; - - let worker = tokio::spawn({ - let pool = pool.clone(); - async move { - Worker::::new(pool) - .with_concurrency(nonzero(1)) - .with_lease_timeout(Duration::from_secs(1)) - .with_heartbeat_interval(Duration::from_millis(50)) - .run() - .await - } - }); - - state.state().wait_for_events(1).await; - let (locked_by, initial_expires_at) = fetch_task_lease(&pool, id).await.unwrap(); - - let (renewed_by, renewed_expires_at) = timeout(Duration::from_secs(1), async { - loop { - let (renewed_by, renewed_expires_at) = fetch_task_lease(&pool, 
id).await.unwrap(); - if renewed_expires_at > initial_expires_at { - return (renewed_by, renewed_expires_at); - } - sleep(Duration::from_millis(10)).await; - } - }) - .await - .unwrap(); - assert_eq!(renewed_by, locked_by); - assert!(renewed_expires_at > Utc::now()); - - stop_worker(&pool).await; - state.state().release(); - - timeout(Duration::from_secs(1), worker) - .await - .unwrap() - .unwrap() - .unwrap(); - } - - #[sqlx::test(migrations = "./migrations")] - async fn run_aborts_running_steps_when_heartbeat_cannot_renew_before_the_lease_expires( - pool: PgPool, - ) { - let state = StepStateGuard::new(); - insert_task( - &pool, - &TestTask::Blocking(Blocking { key: state.key() }), - false, - ) - .await; - - let worker = tokio::spawn({ - let pool = pool.clone(); - async move { - Worker::::new(pool) - .with_concurrency(nonzero(1)) - .with_lease_timeout(Duration::from_millis(250)) - .with_heartbeat_interval(Duration::from_millis(100)) - .run() - .await - } - }); - - state.state().wait_for_events(1).await; - sleep(Duration::from_millis(50)).await; - sqlx::query!("ALTER TABLE pg_task RENAME COLUMN lock_expires_at TO task_lock_expires_at") - .execute(&pool) - .await - .unwrap(); - - let err = timeout(Duration::from_secs(2), worker) - .await - .unwrap() - .unwrap() - .unwrap_err(); - - assert_eq!(state.state().events(), vec!["started"]); - assert!(matches!(err, Error::Db(sqlx::Error::Database(_), _))); - } - - #[sqlx::test(migrations = "./migrations")] - async fn run_pauses_fetching_while_heartbeat_cannot_renew(pool: PgPool) { - let worker_pool = connect_to_current_db(&pool, 1, POOL_TIMEOUT).await; - let state = StepStateGuard::new(); - - let worker = tokio::spawn(async move { - Worker::::new(worker_pool) - .with_concurrency(nonzero(1)) - .with_lease_timeout(Duration::from_millis(200)) - .with_heartbeat_interval(Duration::from_millis(50)) - .run() - .await - }); - - sleep(Duration::from_millis(100)).await; - insert_task( - &pool, - &TestTask::Complete(Complete { key: 
state.key() }), - false, - ) - .await; - sleep(Duration::from_millis(150)).await; - - stop_worker(&pool).await; - - timeout(Duration::from_secs(2), worker) - .await - .unwrap() - .unwrap() - .unwrap(); - - assert!(state.state().events().is_empty()); - } - - #[sqlx::test(migrations = "./migrations")] - async fn run_returns_listener_errors_while_fetching_is_paused(pool: PgPool) { - let worker_pool = connect_to_current_db(&pool, 1, POOL_TIMEOUT).await; - let worker = Arc::new( - Worker::::new(worker_pool) - .with_concurrency(nonzero(1)) - .with_lease_timeout(Duration::from_millis(200)) - .with_heartbeat_interval(Duration::from_millis(50)), - ); - let run = tokio::spawn({ - let worker = worker.clone(); - async move { worker.run().await } - }); - - sleep(Duration::from_millis(1250)).await; - worker - .listener - .set_error_for_tests(Error::ListenerReceive(sqlx::Error::Protocol( - "listener failed".into(), - ))); - - let err = timeout(Duration::from_secs(3), run) - .await - .unwrap() - .unwrap() - .unwrap_err(); - assert!(matches!( - err, - Error::ListenerReceive(sqlx::Error::Protocol(_)) - )); - } - - #[sqlx::test(migrations = "./migrations")] - async fn run_keeps_waiting_after_retryable_errors_while_fetching_is_paused(pool: PgPool) { - let worker_pool = connect_to_current_db(&pool, 1, POOL_TIMEOUT).await; - let worker = Arc::new( - Worker::::new(worker_pool) - .with_concurrency(nonzero(1)) - .with_lease_timeout(Duration::from_millis(200)) - .with_heartbeat_interval(Duration::from_millis(50)), - ); - let run = tokio::spawn({ - let worker = worker.clone(); - async move { worker.run().await } - }); - - sleep(Duration::from_millis(1250)).await; - worker - .listener - .set_error_for_tests(Error::Db(sqlx::Error::PoolTimedOut, "fetch task".into())); - sleep(Duration::from_millis(1300)).await; - assert!(!run.is_finished()); - - stop_worker(&pool).await; - - timeout(Duration::from_secs(3), run) - .await - .unwrap() - .unwrap() - .unwrap(); - } - - #[sqlx::test(migrations = 
"./migrations")] - async fn run_resumes_fetching_after_heartbeat_recovers(pool: PgPool) { - let worker_pool = connect_to_current_db(&pool, 2, POOL_TIMEOUT).await; - let held_connection = worker_pool.acquire().await.unwrap(); - let state = StepStateGuard::new(); - - let worker = tokio::spawn({ - let worker_pool = worker_pool.clone(); - async move { - Worker::::new(worker_pool) - .with_concurrency(nonzero(1)) - .with_lease_timeout(Duration::from_millis(300)) - .with_heartbeat_interval(Duration::from_millis(50)) - .run() - .await - } - }); - - sleep(Duration::from_millis(100)).await; - insert_task( - &pool, - &TestTask::Complete(Complete { key: state.key() }), - false, - ) - .await; - sleep(Duration::from_millis(150)).await; - - assert!(state.state().events().is_empty()); - drop(held_connection); - - timeout(Duration::from_secs(3), async { - loop { - if !state.state().events().is_empty() { - break; - } - sleep(Duration::from_millis(10)).await; - } - }) - .await - .unwrap(); - - stop_worker(&pool).await; - - timeout(Duration::from_secs(2), worker) - .await - .unwrap() - .unwrap() - .unwrap(); - - assert_eq!(state.state().events(), vec!["complete"]); - } - - #[tokio::test] - async fn finish_run_waits_for_inflight_steps_before_returning_errors() { - init_tracing(); - let worker = Arc::new( - Worker::::new( - PgPoolOptions::new() - .connect_lazy("postgres:///pg_task") - .unwrap(), - ) - .with_concurrency(nonzero(1)), - ); - let semaphore = Arc::new(Semaphore::new(1)); - let permit = semaphore.clone().acquire_owned().await.unwrap(); - - let task = tokio::spawn({ - let worker = worker.clone(); - let semaphore = semaphore.clone(); - async move { - worker - .finish_run( - Err(Error::ListenerReceive(sqlx::Error::Protocol( - "listener failed".into(), - ))), - semaphore, - idle_heartbeat(), - Arc::new(Mutex::new(Vec::new())), - false, - idle_run_events(), - ) - .await - } - }); - - sleep(Duration::from_millis(50)).await; - assert!(!task.is_finished()); - - drop(permit); - - let 
err = task.await.unwrap().unwrap_err(); - assert!(matches!( - err, - Error::ListenerReceive(sqlx::Error::Protocol(_)) - )); - } - - #[tokio::test] - async fn finish_run_returns_step_errors_received_while_draining() { - init_tracing(); - let worker = Arc::new( - Worker::::new( - PgPoolOptions::new() - .connect_lazy("postgres:///pg_task") - .unwrap(), - ) - .with_concurrency(nonzero(1)), - ); - let semaphore = Arc::new(Semaphore::new(1)); - let permit = semaphore.clone().acquire_owned().await.unwrap(); - let (step_error_sender, step_errors) = mpsc::unbounded_channel(); - - let finish = tokio::spawn({ - let worker = worker.clone(); - let semaphore = semaphore.clone(); - async move { - worker - .finish_run( - Ok(()), - semaphore, - idle_heartbeat(), - Arc::new(Mutex::new(Vec::new())), - false, - RunEvents { - heartbeat: idle_heartbeat_events(), - step_errors, - }, - ) - .await - } - }); - - sleep(Duration::from_millis(50)).await; - assert!(!finish.is_finished()); - - step_error_sender - .send(Error::Db(sqlx::Error::PoolTimedOut, "step".into())) - .unwrap(); - drop(permit); - - let err = timeout(Duration::from_secs(1), finish) - .await - .unwrap() - .unwrap() - .unwrap_err(); - assert!(matches!(err, Error::Db(sqlx::Error::PoolTimedOut, _))); - } - - #[tokio::test] - async fn finish_run_keeps_heartbeat_alive_while_waiting_for_inflight_steps() { - init_tracing(); - let worker = Arc::new( - Worker::::new( - PgPoolOptions::new() - .connect_lazy("postgres:///pg_task") - .unwrap(), - ) - .with_concurrency(nonzero(1)), - ); - let semaphore = Arc::new(Semaphore::new(1)); - let permit = semaphore.clone().acquire_owned().await.unwrap(); - let heartbeat = tokio::spawn(async { - std::future::pending::<()>().await; - }); - - let finish = tokio::spawn({ - let worker = worker.clone(); - let semaphore = semaphore.clone(); - let heartbeat_abort = heartbeat.abort_handle(); - async move { - worker - .finish_run( - Ok(()), - semaphore, - heartbeat_abort, - Arc::new(Mutex::new(Vec::new())), 
- false, - idle_run_events(), - ) - .await - } - }); - - sleep(Duration::from_millis(50)).await; - assert!(!finish.is_finished()); - assert!(!heartbeat.is_finished()); - - drop(permit); - - timeout(Duration::from_secs(1), finish) - .await - .unwrap() - .unwrap() - .unwrap(); - assert!(heartbeat.await.unwrap_err().is_cancelled()); - } - - #[tokio::test] - async fn finish_run_aborts_inflight_steps_when_lease_renewal_expires() { - init_tracing(); - let worker = Worker::::new( - PgPoolOptions::new() - .connect_lazy("postgres:///pg_task") - .unwrap(), - ) - .with_concurrency(nonzero(1)); - let semaphore = Arc::new(Semaphore::new(1)); - let permit = semaphore.clone().acquire_owned().await.unwrap(); - let running_step = tokio::spawn(async move { - let _permit = permit; - std::future::pending::<()>().await; - }); - let running_steps = Arc::new(Mutex::new(vec![running_step_entry( - Uuid::new_v4(), - running_step.abort_handle(), - )])); - - let err = timeout( - Duration::from_secs(1), - worker.finish_run( - Err(Error::Db(sqlx::Error::PoolTimedOut, "test".into())), - semaphore, - idle_heartbeat(), - running_steps, - true, - idle_run_events(), - ), - ) - .await - .unwrap() - .unwrap_err(); - - assert!(matches!(err, Error::Db(sqlx::Error::PoolTimedOut, _))); - assert!(running_step.await.unwrap_err().is_cancelled()); - } - - #[tokio::test] - async fn finish_run_aborts_inflight_steps_when_heartbeat_expires_while_draining() { - init_tracing(); - let worker = Worker::::new( - PgPoolOptions::new() - .connect_lazy("postgres:///pg_task") - .unwrap(), - ) - .with_concurrency(nonzero(1)); - let semaphore = Arc::new(Semaphore::new(1)); - let permit = semaphore.clone().acquire_owned().await.unwrap(); - let running_step = tokio::spawn(async move { - let _permit = permit; - std::future::pending::<()>().await; - }); - let running_steps = Arc::new(Mutex::new(vec![running_step_entry( - Uuid::new_v4(), - running_step.abort_handle(), - )])); - let (heartbeat_events_sender, heartbeat_events) = 
mpsc::unbounded_channel(); - heartbeat_events_sender - .send(HeartbeatEvent::Expired(Error::TaskLeaseExpired)) - .unwrap(); - - let err = timeout( - Duration::from_secs(1), - worker.finish_run( - Ok(()), - semaphore, - idle_heartbeat(), - running_steps, - false, - RunEvents { - heartbeat: heartbeat_events, - step_errors: idle_step_errors(), - }, - ), - ) - .await - .unwrap() - .unwrap_err(); - - assert!(matches!(err, Error::TaskLeaseExpired)); - assert!(running_step.await.unwrap_err().is_cancelled()); - } - - #[tokio::test] - async fn wait_for_steps_to_finish_rechecks_when_the_inflight_task_count_changes() { - init_tracing(); - let worker = Arc::new( - Worker::::new( - PgPoolOptions::new() - .connect_lazy("postgres:///pg_task") - .unwrap(), - ) - .with_concurrency(nonzero(2)), - ); - let semaphore = Arc::new(Semaphore::new(2)); - let first = semaphore.clone().acquire_owned().await.unwrap(); - let second = semaphore.clone().acquire_owned().await.unwrap(); - - let wait = tokio::spawn({ - let worker = worker.clone(); - let semaphore = semaphore.clone(); - async move { - worker.wait_for_steps_to_finish(semaphore).await; - } - }); - - sleep(Duration::from_millis(10)).await; - drop(first); - - sleep(Duration::from_millis(150)).await; - assert!(!wait.is_finished()); - - drop(second); - - timeout(Duration::from_secs(1), wait) - .await - .unwrap() - .unwrap(); - } - - #[sqlx::test(migrations = "./migrations")] - async fn run_processes_followup_steps_to_completion(pool: PgPool) { - let state = StepStateGuard::new(); - insert_task( - &pool, - &TestTask::Advance(Advance { key: state.key() }), - false, - ) - .await; - - let worker = spawn_worker(pool.clone()); - - state.state().wait_for_events(2).await; - stop_worker(&pool).await; - - timeout(Duration::from_secs(1), worker) - .await - .unwrap() - .unwrap() - .unwrap(); - - assert_eq!(state.state().events(), vec!["advance", "finish"]); - assert_eq!(task_count(&pool).await, 0); - } - - #[sqlx::test(migrations = "./migrations")] - 
async fn run_returns_listener_errors_when_the_pool_is_closed(pool: PgPool) { - let worker = spawn_worker(pool.clone()); - - sleep(Duration::from_millis(100)).await; - pool.close().await; - - let err = timeout(Duration::from_secs(2), worker) - .await - .unwrap() - .unwrap() - .unwrap_err(); - - assert!(matches!( - err, - Error::ListenerReceive(sqlx::Error::PoolClosed) - )); - } - - #[sqlx::test(migrations = "./migrations")] - async fn run_recovers_from_pool_timeouts_until_a_stop_notification_arrives(pool: PgPool) { - let worker_pool = connect_to_current_db(&pool, 1, POOL_TIMEOUT).await; - let worker = spawn_worker(worker_pool); - - sleep(Duration::from_millis(100)).await; - assert!(!worker.is_finished()); - - stop_worker(&pool).await; - - timeout(Duration::from_secs(3), worker) - .await - .unwrap() - .unwrap() - .unwrap(); - } - - #[sqlx::test(migrations = "./migrations")] - async fn run_stops_when_stop_notification_arrives_while_idle(pool: PgPool) { - let worker = spawn_worker(pool.clone()); - - sleep(Duration::from_millis(100)).await; - stop_worker(&pool).await; - - timeout(Duration::from_secs(1), worker) - .await - .unwrap() - .unwrap() - .unwrap(); - assert_eq!(task_count(&pool).await, 0); - } - - #[sqlx::test(migrations = "./migrations")] - async fn run_wakes_up_for_tasks_inserted_while_idle(pool: PgPool) { - let state = StepStateGuard::new(); - let worker = spawn_worker(pool.clone()); - - sleep(Duration::from_millis(100)).await; - insert_task( - &pool, - &TestTask::Complete(Complete { key: state.key() }), - false, - ) - .await; - - state.state().wait_for_events(1).await; - stop_worker(&pool).await; - - timeout(Duration::from_secs(1), worker) - .await - .unwrap() - .unwrap() - .unwrap(); - - assert_eq!(state.state().events(), vec!["complete"]); - assert_eq!(task_count(&pool).await, 0); - } - - #[sqlx::test(migrations = "./migrations")] - async fn run_processes_noop_tasks_to_completion(pool: PgPool) { - insert_task(&pool, &TestTask::Noop(Noop), false).await; - - 
let worker = spawn_worker(pool.clone()); - - timeout(Duration::from_secs(1), async { - loop { - if task_count(&pool).await == 0 { - break; - } - sleep(Duration::from_millis(10)).await; - } - }) - .await - .unwrap(); - - stop_worker(&pool).await; - - timeout(Duration::from_secs(1), worker) - .await - .unwrap() - .unwrap() - .unwrap(); - } - - #[sqlx::test(migrations = "./migrations")] - async fn run_waits_for_future_tasks_to_become_ready_without_notifications(pool: PgPool) { - let state = StepStateGuard::new(); - insert_task_at( - &pool, - &TestTask::Complete(Complete { key: state.key() }), - Utc::now() + ChronoDuration::milliseconds(150), - false, - ) - .await; - - let worker = spawn_worker(pool.clone()); - - assert!( - timeout(Duration::from_millis(50), state.state().wait_for_events(1)) - .await - .is_err() - ); - - state.state().wait_for_events(1).await; - stop_worker(&pool).await; - - timeout(Duration::from_secs(1), worker) - .await - .unwrap() - .unwrap() - .unwrap(); - - assert_eq!(state.state().events(), vec!["complete"]); - assert_eq!(task_count(&pool).await, 0); - } - - #[sqlx::test(migrations = "./migrations")] - async fn starting_another_worker_does_not_unlock_live_tasks(pool: PgPool) { - let state = StepStateGuard::new(); - insert_task( - &pool, - &TestTask::Blocking(Blocking { key: state.key() }), - false, - ) - .await; - - let first_worker = spawn_worker(pool.clone()); - state.state().wait_for_events(1).await; - - let second_worker = spawn_worker(pool.clone()); - sleep(Duration::from_millis(150)).await; - assert_eq!(state.state().events(), vec!["started"]); - - stop_worker(&pool).await; - state.state().release(); - - timeout(Duration::from_secs(1), first_worker) - .await - .unwrap() - .unwrap() - .unwrap(); - timeout(Duration::from_secs(1), second_worker) - .await - .unwrap() - .unwrap() - .unwrap(); - - assert_eq!(state.state().events(), vec!["started", "completed"]); - assert_eq!(task_count(&pool).await, 0); - } - - #[sqlx::test(migrations = 
"./migrations")] - async fn run_skips_invalid_tasks_and_keeps_processing_ready_tasks(pool: PgPool) { - let invalid_id = insert_raw_task( - &pool, - "not-json", - Utc::now() - ChronoDuration::milliseconds(10), - false, - None, - ) - .await; - let state = StepStateGuard::new(); - insert_task( - &pool, - &TestTask::Complete(Complete { key: state.key() }), - false, - ) - .await; - - let worker = spawn_worker(pool.clone()); - - state.state().wait_for_events(1).await; - stop_worker(&pool).await; - - timeout(Duration::from_secs(1), worker) - .await - .unwrap() - .unwrap() - .unwrap(); - - let invalid_row = sqlx::query!( - "SELECT tried, locked_by, lock_expires_at, error FROM pg_task WHERE id = $1", - invalid_id, - ) - .fetch_one(&pool) - .await - .unwrap(); - - assert_eq!(state.state().events(), vec!["complete"]); - assert_eq!(invalid_row.tried, 0); - assert!(invalid_row.locked_by.is_none()); - assert!(invalid_row.lock_expires_at.is_none()); - assert!(invalid_row.error.is_some()); - assert_eq!(task_count(&pool).await, 1); - } - - #[sqlx::test(migrations = "./migrations")] - async fn run_stops_after_running_steps_finish(pool: PgPool) { - let state = StepStateGuard::new(); - insert_task( - &pool, - &TestTask::Blocking(Blocking { key: state.key() }), - false, - ) - .await; - - let worker = spawn_worker(pool.clone()); - - state.state().wait_for_events(1).await; - stop_worker(&pool).await; - - sleep(Duration::from_millis(50)).await; - assert!(!worker.is_finished()); - - state.state().release(); - - timeout(Duration::from_secs(1), worker) - .await - .unwrap() - .unwrap() - .unwrap(); - - assert_eq!(state.state().events(), vec!["started", "completed"]); - assert_eq!(task_count(&pool).await, 0); - } - - #[sqlx::test(migrations = "./migrations")] - async fn run_returns_step_errors_received_after_stop_while_draining(pool: PgPool) { - let state = StepStateGuard::new(); - insert_task( - &pool, - &TestTask::Blocking(Blocking { key: state.key() }), - false, - ) - .await; - - let worker 
= spawn_worker_with_concurrency(pool.clone(), 2); - - state.state().wait_for_events(1).await; - stop_worker(&pool).await; - sleep(Duration::from_millis(50)).await; - assert!(!worker.is_finished()); - - sqlx::query!("ALTER TABLE pg_task RENAME COLUMN id TO task_id") - .execute(&pool) - .await - .unwrap(); - state.state().release(); - - let err = timeout(Duration::from_secs(1), worker) - .await - .unwrap() - .unwrap() - .unwrap_err(); - - assert_eq!(state.state().events(), vec!["started", "completed"]); - assert!(matches!(err, Error::Db(sqlx::Error::Database(_), _))); - } - - #[sqlx::test(migrations = "./migrations")] - async fn run_returns_spawned_step_persistence_errors(pool: PgPool) { - let state = StepStateGuard::new(); - insert_task( - &pool, - &TestTask::FailSavingError(FailSavingError { key: state.key() }), - false, - ) - .await; - - let worker = spawn_worker(pool.clone()); - - state.state().wait_for_events(2).await; - let err = timeout(Duration::from_secs(1), worker) - .await - .unwrap() - .unwrap() - .unwrap_err(); - - assert_eq!(state.state().events(), vec!["started", "save error failed"]); - assert!(matches!(err, Error::Db(sqlx::Error::Database(_), _))); - } - - #[sqlx::test(migrations = "./migrations")] - async fn rerunning_worker_does_not_renew_abandoned_leases_from_previous_runs(pool: PgPool) { - init_tracing(); - sqlx::query!("ALTER TABLE pg_task ADD CONSTRAINT reject_errors CHECK (error IS NULL)") - .execute(&pool) - .await - .unwrap(); - let state = StepStateGuard::new(); - let id = insert_task_at( - &pool, - &TestTask::FailStep(FailStep { key: state.key() }), - Utc::now() - ChronoDuration::milliseconds(1), - false, - ) - .await; - let worker = Worker::::new(pool.clone()) - .with_concurrency(nonzero(1)) - .with_lease_timeout(Duration::from_secs(1)) - .with_heartbeat_interval(Duration::from_millis(50)); - - let err = timeout(Duration::from_secs(1), worker.run()) - .await - .unwrap() - .unwrap_err(); - assert!(matches!(err, 
Error::Db(sqlx::Error::Database(_), _))); - let (abandoned_owner, abandoned_expires_at) = fetch_task_lease(&pool, id).await.unwrap(); - - let rerun = tokio::spawn({ - let worker = worker; - async move { worker.run().await } - }); - - sleep(Duration::from_millis(150)).await; - let (locked_by, lock_expires_at) = fetch_task_lease(&pool, id).await.unwrap(); - assert_eq!(locked_by, abandoned_owner); - assert_eq!(lock_expires_at, abandoned_expires_at); - - stop_worker(&pool).await; - - timeout(Duration::from_secs(1), rerun) - .await - .unwrap() - .unwrap() - .unwrap(); - } - - #[sqlx::test(migrations = "./migrations")] - async fn run_returns_step_errors_from_spawned_tasks(pool: PgPool) { - let state = StepStateGuard::new(); - sqlx::query!("ALTER TABLE pg_task ADD CONSTRAINT reject_errors CHECK (error IS NULL)") - .execute(&pool) - .await - .unwrap(); - insert_task( - &pool, - &TestTask::FailStep(FailStep { key: state.key() }), - false, - ) - .await; - - let worker = spawn_worker(pool); - - state.state().wait_for_events(1).await; - let err = timeout(Duration::from_secs(1), worker) - .await - .unwrap() - .unwrap() - .unwrap_err(); - - assert_eq!(state.state().events(), vec!["started"]); - assert!(matches!(err, Error::Db(sqlx::Error::Database(_), _))); - } - - #[sqlx::test(migrations = "./migrations")] - async fn run_processes_multiple_blocking_steps_up_to_the_concurrency_limit(pool: PgPool) { - let first = StepStateGuard::new(); - let second = StepStateGuard::new(); - insert_task( - &pool, - &TestTask::Blocking(Blocking { key: first.key() }), - false, - ) - .await; - insert_task( - &pool, - &TestTask::Blocking(Blocking { key: second.key() }), - false, - ) - .await; - - let worker = spawn_worker_with_concurrency(pool.clone(), 2); - - first.state().wait_for_events(1).await; - second.state().wait_for_events(1).await; - stop_worker(&pool).await; - - assert!(!worker.is_finished()); - - first.state().release(); - second.state().release(); - - timeout(Duration::from_secs(1), 
worker) - .await - .unwrap() - .unwrap() - .unwrap(); - - assert_eq!(first.state().events(), vec!["started", "completed"]); - assert_eq!(second.state().events(), vec!["started", "completed"]); - assert_eq!(task_count(&pool).await, 0); - } - - #[sqlx::test(migrations = "./migrations")] - async fn run_respects_the_configured_concurrency_limit(pool: PgPool) { - let first = StepStateGuard::new(); - let second = StepStateGuard::new(); - insert_task( - &pool, - &TestTask::Blocking(Blocking { key: first.key() }), - false, - ) - .await; - insert_task( - &pool, - &TestTask::Blocking(Blocking { key: second.key() }), - false, - ) - .await; - - let worker = spawn_worker_with_concurrency(pool.clone(), 1); - - timeout(Duration::from_secs(1), async { - loop { - let started_count = usize::from(!first.state().events().is_empty()) - + usize::from(!second.state().events().is_empty()); - if started_count == 1 { - break; - } - sleep(Duration::from_millis(10)).await; - } - }) - .await - .unwrap(); - - sleep(Duration::from_millis(100)).await; - let first_started = !first.state().events().is_empty(); - let second_started = !second.state().events().is_empty(); - assert_ne!(first_started, second_started); - - if first_started { - first.state().release(); - second.state().wait_for_events(1).await; - stop_worker(&pool).await; - second.state().release(); - } else { - second.state().release(); - first.state().wait_for_events(1).await; - stop_worker(&pool).await; - first.state().release(); - } - - timeout(Duration::from_secs(1), worker) - .await - .unwrap() - .unwrap() - .unwrap(); - - assert_eq!(first.state().events(), vec!["started", "completed"]); - assert_eq!(second.state().events(), vec!["started", "completed"]); - assert_eq!(task_count(&pool).await, 0); - } -} +mod tests; diff --git a/src/worker/tests.rs b/src/worker/tests.rs new file mode 100644 index 0000000..9813b6b --- /dev/null +++ b/src/worker/tests.rs @@ -0,0 +1,2225 @@ +use super::{ + HeartbeatEvent, RunEvents, RunningStep, 
TaskAvailability, Worker, DEFAULT_HEARTBEAT_INTERVAL, + DEFAULT_LEASE_TIMEOUT, +}; +use crate::{task::WorkerLease, Error, NextStep, Step}; +use chrono::{Duration as ChronoDuration, Utc}; +use parking_lot::Mutex; +use sqlx::{ + postgres::{PgConnectOptions, PgPoolOptions}, + PgPool, +}; +use std::{ + collections::HashMap, + io, + num::NonZeroUsize, + sync::{ + atomic::{AtomicU64, Ordering}, + Arc, OnceLock, + }, + time::Duration, +}; +use tokio::{ + sync::{mpsc, Notify, Semaphore}, + time::{sleep, timeout}, +}; +use uuid::Uuid; + +// Short enough to exercise PoolTimedOut, but long enough for CI to open +// the first TCP connection before the pool is intentionally exhausted. +const POOL_TIMEOUT: Duration = Duration::from_millis(100); + +fn init_tracing() { + static INIT: std::sync::Once = std::sync::Once::new(); + INIT.call_once(|| { + let _ = tracing_subscriber::fmt() + .with_max_level(tracing::Level::TRACE) + .with_test_writer() + .without_time() + .try_init(); + }); +} + +#[derive(Debug, serde::Deserialize, serde::Serialize)] +pub(super) struct Noop; + +#[derive(Debug, serde::Deserialize, serde::Serialize)] +pub(super) struct Advance { + key: u64, +} + +#[derive(Debug, serde::Deserialize, serde::Serialize)] +pub(super) struct Finish { + key: u64, +} + +#[derive(Debug, serde::Deserialize, serde::Serialize)] +pub(super) struct Complete { + key: u64, +} + +#[derive(Debug, serde::Deserialize, serde::Serialize)] +pub(super) struct Blocking { + key: u64, +} + +#[derive(Debug, serde::Deserialize, serde::Serialize)] +pub(super) struct FailSavingError { + key: u64, +} + +#[derive(Debug, serde::Deserialize, serde::Serialize)] +pub(super) struct FailStep { + key: u64, +} + +crate::task!(TestTask { + Noop, + Advance, + Finish, + Complete, + Blocking, + FailSavingError, + FailStep, +}); + +#[async_trait::async_trait] +impl Step for Noop { + async fn step(self, _db: &PgPool) -> crate::StepResult { + Ok(NextStep::None) + } +} + +#[async_trait::async_trait] +impl Step for Advance 
{ + async fn step(self, _db: &PgPool) -> crate::StepResult { + step_state(self.key).record("advance"); + NextStep::now(Finish { key: self.key }) + } +} + +#[async_trait::async_trait] +impl Step for Finish { + async fn step(self, _db: &PgPool) -> crate::StepResult { + step_state(self.key).record("finish"); + NextStep::none() + } +} + +#[async_trait::async_trait] +impl Step for Complete { + async fn step(self, _db: &PgPool) -> crate::StepResult { + step_state(self.key).record("complete"); + NextStep::none() + } +} + +#[async_trait::async_trait] +impl Step for Blocking { + async fn step(self, _db: &PgPool) -> crate::StepResult { + let state = step_state(self.key); + state.record("started"); + state.wait_for_release().await; + state.record("completed"); + NextStep::none() + } +} + +#[async_trait::async_trait] +impl Step for FailSavingError { + async fn step(self, db: &PgPool) -> crate::StepResult { + let state = step_state(self.key); + state.record("started"); + sqlx::query!("ALTER TABLE pg_task RENAME COLUMN error TO task_error") + .execute(db) + .await + .unwrap(); + state.record("save error failed"); + Err(io::Error::other("step failed").into()) + } + + fn retry_limit(&self) -> i32 { + 0 + } +} + +#[async_trait::async_trait] +impl Step for FailStep { + async fn step(self, _db: &PgPool) -> crate::StepResult { + step_state(self.key).record("started"); + Err(io::Error::other("step failed").into()) + } + + fn retry_limit(&self) -> i32 { + 0 + } +} + +struct StepState { + events: Mutex>, + events_changed: Notify, + release: Notify, +} + +impl StepState { + fn new() -> Self { + Self { + events: Mutex::new(Vec::new()), + events_changed: Notify::new(), + release: Notify::new(), + } + } + + fn record(&self, event: &'static str) { + self.events.lock().push(event); + self.events_changed.notify_waiters(); + } + + fn release(&self) { + self.release.notify_waiters(); + } + + fn events(&self) -> Vec<&'static str> { + self.events.lock().clone() + } + + async fn 
wait_for_events(&self, count: usize) { + timeout(Duration::from_secs(1), async { + loop { + if self.events.lock().len() >= count { + return; + } + self.events_changed.notified().await; + } + }) + .await + .unwrap(); + } + + async fn wait_for_release(&self) { + self.release.notified().await; + } +} + +struct StepStateGuard { + key: u64, + state: Arc, +} + +impl StepStateGuard { + fn new() -> Self { + let key = NEXT_STEP_STATE_KEY.fetch_add(1, Ordering::Relaxed); + let state = Arc::new(StepState::new()); + step_states().lock().insert(key, state.clone()); + Self { key, state } + } + + fn key(&self) -> u64 { + self.key + } + + fn state(&self) -> Arc { + self.state.clone() + } +} + +impl Drop for StepStateGuard { + fn drop(&mut self) { + step_states().lock().remove(&self.key); + } +} + +static NEXT_STEP_STATE_KEY: AtomicU64 = AtomicU64::new(1); +static STEP_STATES: OnceLock>>> = OnceLock::new(); + +fn step_states() -> &'static Mutex>> { + STEP_STATES.get_or_init(|| Mutex::new(HashMap::new())) +} + +fn step_state(key: u64) -> Arc { + step_states().lock().get(&key).cloned().unwrap() +} + +fn connection_error() -> Error { + Error::Db( + sqlx::Error::Io(io::Error::new( + io::ErrorKind::BrokenPipe, + "worker connection dropped", + )), + "test".into(), + ) +} + +async fn insert_raw_task( + pool: &PgPool, + step: &str, + wakeup_at: chrono::DateTime, + is_leased: bool, + error: Option<&str>, +) -> Uuid { + let (locked_by, lock_expires_at) = if is_leased { + ( + Some(Uuid::from_u128(1)), + Some(Utc::now() + ChronoDuration::seconds(60)), + ) + } else { + (None, None) + }; + sqlx::query!( + " + INSERT INTO pg_task (step, wakeup_at, locked_by, lock_expires_at, error) + VALUES ($1, $2, $3, $4, $5) + RETURNING id + ", + step, + wakeup_at, + locked_by, + lock_expires_at, + error, + ) + .fetch_one(pool) + .await + .unwrap() + .id +} + +async fn insert_task_at( + pool: &PgPool, + step: &TestTask, + wakeup_at: chrono::DateTime, + is_leased: bool, +) -> Uuid { + insert_raw_task( + pool, + 
&serde_json::to_string(step).unwrap(), + wakeup_at, + is_leased, + None, + ) + .await +} + +async fn insert_task(pool: &PgPool, step: &TestTask, is_leased: bool) { + insert_task_at( + pool, + step, + Utc::now() - ChronoDuration::milliseconds(1), + is_leased, + ) + .await; +} + +async fn set_task_lease(pool: &PgPool, id: Uuid, lock_expires_at: chrono::DateTime) { + set_task_lease_for_worker(pool, id, Uuid::from_u128(1), lock_expires_at).await; +} + +async fn set_task_lease_for_worker( + pool: &PgPool, + id: Uuid, + worker_id: Uuid, + lock_expires_at: chrono::DateTime, +) { + sqlx::query!( + r#" + UPDATE pg_task + SET locked_by = $2, + lock_expires_at = $3 + WHERE id = $1 + "#, + id, + worker_id, + lock_expires_at, + ) + .execute(pool) + .await + .unwrap(); +} + +async fn fetch_task_lease(pool: &PgPool, id: Uuid) -> Option<(Uuid, chrono::DateTime)> { + sqlx::query!( + " + SELECT locked_by, lock_expires_at + FROM pg_task + WHERE id = $1 + ", + id, + ) + .fetch_optional(pool) + .await + .unwrap() + .map(|row| (row.locked_by.unwrap(), row.lock_expires_at.unwrap())) +} + +fn idle_heartbeat() -> tokio::task::AbortHandle { + tokio::spawn(async { + std::future::pending::<()>().await; + }) + .abort_handle() +} + +fn idle_heartbeat_events() -> mpsc::UnboundedReceiver { + let (_sender, receiver) = mpsc::unbounded_channel(); + receiver +} + +fn idle_step_errors() -> mpsc::UnboundedReceiver { + let (_sender, receiver) = mpsc::unbounded_channel(); + receiver +} + +fn idle_run_events() -> RunEvents { + RunEvents { + heartbeat: idle_heartbeat_events(), + step_errors: idle_step_errors(), + } +} + +async fn connect_to_current_db( + pool: &PgPool, + max_connections: u32, + acquire_timeout: Duration, +) -> PgPool { + let db_name: String = sqlx::query_scalar!(r#"SELECT current_database() AS "db_name!""#) + .fetch_one(pool) + .await + .unwrap(); + + PgPoolOptions::new() + .max_connections(max_connections) + .acquire_timeout(acquire_timeout) + 
.connect_with(current_database_options(&db_name)) + .await + .unwrap() +} + +// Connect to the database created by sqlx::test while keeping the +// connection settings from DATABASE_URL. CI needs its TCP host and password; +// postgres:///{db_name} only works for local peer-auth socket setups. +fn current_database_options(db_name: &str) -> PgConnectOptions { + std::env::var("DATABASE_URL") + .expect("DATABASE_URL must be set") + .parse::() + .unwrap() + .database(db_name) +} + +async fn task_count(pool: &PgPool) -> i64 { + sqlx::query!("SELECT id FROM pg_task") + .fetch_all(pool) + .await + .unwrap() + .len() as i64 +} + +async fn stop_worker(pool: &PgPool) { + sqlx::query!("NOTIFY pg_task_changed, 'stop_worker'") + .execute(pool) + .await + .unwrap(); +} + +fn nonzero(value: usize) -> NonZeroUsize { + NonZeroUsize::new(value).unwrap() +} + +fn worker_lease(worker: &Worker) -> WorkerLease { + WorkerLease::new(Uuid::new_v4(), worker.lease_timeout) +} + +fn running_step_entry(task_id: Uuid, abort_handle: tokio::task::AbortHandle) -> RunningStep { + RunningStep { + task_id, + abort_handle, + } +} + +fn spawn_worker(pool: PgPool) -> tokio::task::JoinHandle> { + spawn_worker_with_concurrency(pool, 1) +} + +fn spawn_worker_with_concurrency( + pool: PgPool, + concurrency: usize, +) -> tokio::task::JoinHandle> { + tokio::spawn(async move { + Worker::::new(pool) + .with_concurrency(nonzero(concurrency)) + .run() + .await + }) +} + +#[tokio::test] +#[should_panic(expected = "lease timeout must be non-zero")] +async fn with_lease_timeout_rejects_zero() { + Worker::::new( + PgPoolOptions::new() + .connect_lazy("postgres:///pg_task") + .unwrap(), + ) + .with_lease_timeout(Duration::ZERO); +} + +#[tokio::test] +#[should_panic(expected = "heartbeat interval must be non-zero")] +async fn with_heartbeat_interval_rejects_zero() { + Worker::::new( + PgPoolOptions::new() + .connect_lazy("postgres:///pg_task") + .unwrap(), + ) + .with_heartbeat_interval(Duration::ZERO); +} + 
+#[tokio::test] +#[should_panic(expected = "heartbeat interval must be shorter than lease timeout")] +async fn run_rejects_lease_timeout_that_is_not_longer_than_the_heartbeat_interval() { + let worker = Worker::::new( + PgPoolOptions::new() + .connect_lazy("postgres:///pg_task") + .unwrap(), + ) + .with_lease_timeout(DEFAULT_HEARTBEAT_INTERVAL); + + let _ = worker.run().await; +} + +#[tokio::test] +#[should_panic(expected = "heartbeat interval must be shorter than lease timeout")] +async fn run_rejects_heartbeat_interval_that_is_not_shorter_than_the_lease_timeout() { + let worker = Worker::::new( + PgPoolOptions::new() + .connect_lazy("postgres:///pg_task") + .unwrap(), + ) + .with_heartbeat_interval(DEFAULT_LEASE_TIMEOUT); + + let _ = worker.run().await; +} + +#[test] +fn heartbeat_events_pause_resume_and_expire_fetching() { + let mut heartbeat_healthy = true; + Worker::::handle_heartbeat_event(HeartbeatEvent::Failed, &mut heartbeat_healthy) + .unwrap(); + assert!(!heartbeat_healthy); + Worker::::handle_heartbeat_event(HeartbeatEvent::Failed, &mut heartbeat_healthy) + .unwrap(); + assert!(!heartbeat_healthy); + + Worker::::handle_heartbeat_event(HeartbeatEvent::Recovered, &mut heartbeat_healthy) + .unwrap(); + assert!(heartbeat_healthy); + Worker::::handle_heartbeat_event(HeartbeatEvent::Recovered, &mut heartbeat_healthy) + .unwrap(); + assert!(heartbeat_healthy); + + let err = Worker::::handle_heartbeat_event( + HeartbeatEvent::Expired(Error::Db(sqlx::Error::PoolTimedOut, "test".into())), + &mut heartbeat_healthy, + ) + .unwrap_err(); + assert!(matches!(err, Error::Db(sqlx::Error::PoolTimedOut, _))); +} + +#[tokio::test] +async fn heartbeat_expiry_interrupts_retryable_fetch_error_handling() { + init_tracing(); + let worker = Worker::::new( + PgPoolOptions::new() + .connect_lazy("postgres:///pg_task") + .unwrap(), + ); + let (heartbeat_events, mut heartbeat_events_receiver) = mpsc::unbounded_channel(); + heartbeat_events + .send(HeartbeatEvent::Expired(Error::Db( 
+ sqlx::Error::PoolTimedOut, + "heartbeat".into(), + ))) + .unwrap(); + let mut heartbeat_healthy = true; + let mut abort_running_steps = false; + + let err = timeout( + Duration::from_millis(100), + worker.handle_recv_task_error_or_heartbeat( + Error::Db(sqlx::Error::PoolTimedOut, "fetch".into()), + &mut heartbeat_events_receiver, + &mut heartbeat_healthy, + &mut abort_running_steps, + ), + ) + .await + .unwrap() + .unwrap_err(); + + assert!(matches!(err, Error::Db(sqlx::Error::PoolTimedOut, _))); + assert!(abort_running_steps); +} + +#[tokio::test] +async fn heartbeat_recovery_preserves_retryable_fetch_error_handling() { + init_tracing(); + let worker = Worker::::new( + PgPoolOptions::new() + .connect_lazy("postgres:///pg_task") + .unwrap(), + ); + let (heartbeat_events, mut heartbeat_events_receiver) = mpsc::unbounded_channel(); + heartbeat_events.send(HeartbeatEvent::Failed).unwrap(); + heartbeat_events.send(HeartbeatEvent::Recovered).unwrap(); + let mut heartbeat_healthy = true; + let mut abort_running_steps = false; + + worker + .handle_recv_task_error_or_heartbeat( + Error::Db(sqlx::Error::PoolTimedOut, "fetch".into()), + &mut heartbeat_events_receiver, + &mut heartbeat_healthy, + &mut abort_running_steps, + ) + .await + .unwrap(); + + assert!(heartbeat_healthy); + assert!(!abort_running_steps); +} + +#[tokio::test] +async fn heartbeat_failure_pauses_after_retryable_fetch_error_handling() { + init_tracing(); + let worker = Worker::::new( + PgPoolOptions::new() + .connect_lazy("postgres:///pg_task") + .unwrap(), + ); + let (heartbeat_events, mut heartbeat_events_receiver) = mpsc::unbounded_channel(); + heartbeat_events.send(HeartbeatEvent::Failed).unwrap(); + let mut heartbeat_healthy = true; + let mut abort_running_steps = false; + + worker + .handle_recv_task_error_or_heartbeat( + Error::Db(sqlx::Error::PoolTimedOut, "fetch".into()), + &mut heartbeat_events_receiver, + &mut heartbeat_healthy, + &mut abort_running_steps, + ) + .await + .unwrap(); + + 
assert!(!heartbeat_healthy); + assert!(!abort_running_steps); +} + +#[sqlx::test(migrations = "./migrations")] +async fn heartbeat_skips_renewal_without_running_steps(pool: PgPool) { + init_tracing(); + sqlx::query!("ALTER TABLE pg_task RENAME COLUMN lock_expires_at TO task_lock_expires_at") + .execute(&pool) + .await + .unwrap(); + let worker = Worker::::new(pool) + .with_lease_timeout(Duration::from_millis(80)) + .with_heartbeat_interval(Duration::from_millis(20)); + let (events, mut events_receiver) = mpsc::unbounded_channel(); + let heartbeat = worker.spawn_heartbeat( + events, + Arc::new(Mutex::new(Vec::new())), + worker_lease(&worker), + ); + + assert!(timeout(Duration::from_millis(150), events_receiver.recv()) + .await + .is_err()); + + heartbeat.abort(); +} + +#[sqlx::test(migrations = "./migrations")] +async fn heartbeat_expires_when_running_steps_have_no_live_leases(pool: PgPool) { + init_tracing(); + let worker = Worker::::new(pool) + .with_lease_timeout(Duration::from_millis(80)) + .with_heartbeat_interval(Duration::from_millis(20)); + let running_step = tokio::spawn(async { + std::future::pending::<()>().await; + }); + let running_steps = Arc::new(Mutex::new(vec![running_step_entry( + Uuid::new_v4(), + running_step.abort_handle(), + )])); + let (events, mut events_receiver) = mpsc::unbounded_channel(); + let heartbeat = worker.spawn_heartbeat(events, running_steps, worker_lease(&worker)); + + let event = timeout(Duration::from_secs(1), events_receiver.recv()) + .await + .unwrap() + .unwrap(); + assert!(matches!(event, HeartbeatEvent::Failed)); + + let event = timeout(Duration::from_secs(1), events_receiver.recv()) + .await + .unwrap() + .unwrap(); + assert!(matches!( + event, + HeartbeatEvent::Expired(Error::TaskLeaseExpired) + )); + + heartbeat.abort(); + running_step.abort(); +} + +#[sqlx::test(migrations = "./migrations")] +async fn heartbeat_expires_when_any_running_step_loses_its_lease(pool: PgPool) { + init_tracing(); + let worker = 
Worker::::new(pool.clone()) + .with_lease_timeout(Duration::from_millis(80)) + .with_heartbeat_interval(Duration::from_millis(20)); + let worker_id = Uuid::new_v4(); + let lease = WorkerLease::new(worker_id, worker.lease_timeout); + let live = insert_task_at( + &pool, + &TestTask::Noop(Noop), + Utc::now() - ChronoDuration::milliseconds(1), + false, + ) + .await; + let expired = insert_task_at( + &pool, + &TestTask::Noop(Noop), + Utc::now() - ChronoDuration::milliseconds(1), + false, + ) + .await; + set_task_lease_for_worker( + &pool, + live, + worker_id, + Utc::now() + ChronoDuration::milliseconds(200), + ) + .await; + set_task_lease_for_worker( + &pool, + expired, + worker_id, + Utc::now() - ChronoDuration::milliseconds(1), + ) + .await; + let live_step = tokio::spawn(async { + std::future::pending::<()>().await; + }); + let expired_step = tokio::spawn(async { + std::future::pending::<()>().await; + }); + let running_steps = Arc::new(Mutex::new(vec![ + running_step_entry(live, live_step.abort_handle()), + running_step_entry(expired, expired_step.abort_handle()), + ])); + let (events, mut events_receiver) = mpsc::unbounded_channel(); + let heartbeat = worker.spawn_heartbeat(events, running_steps, lease); + + let event = timeout(Duration::from_secs(1), events_receiver.recv()) + .await + .unwrap() + .unwrap(); + assert!(matches!(event, HeartbeatEvent::Failed)); + + let event = timeout(Duration::from_secs(1), events_receiver.recv()) + .await + .unwrap() + .unwrap(); + assert!(matches!( + event, + HeartbeatEvent::Expired(Error::TaskLeaseExpired) + )); + + heartbeat.abort(); + live_step.abort(); + expired_step.abort(); +} + +#[sqlx::test(migrations = "./migrations")] +async fn heartbeat_skips_pool_timeouts_without_running_steps(pool: PgPool) { + init_tracing(); + let worker_pool = connect_to_current_db(&pool, 1, POOL_TIMEOUT).await; + let held_connection = worker_pool.acquire().await.unwrap(); + let worker = Worker::::new(worker_pool) + 
.with_lease_timeout(Duration::from_millis(500)) + .with_heartbeat_interval(Duration::from_millis(20)); + let (events, mut events_receiver) = mpsc::unbounded_channel(); + let heartbeat = worker.spawn_heartbeat( + events, + Arc::new(Mutex::new(Vec::new())), + worker_lease(&worker), + ); + + assert!(timeout(Duration::from_millis(150), events_receiver.recv()) + .await + .is_err()); + + drop(held_connection); + heartbeat.abort(); +} + +#[sqlx::test(migrations = "./migrations")] +async fn heartbeat_reports_recovery_after_live_leases_are_renewed(pool: PgPool) { + init_tracing(); + let worker_pool = connect_to_current_db(&pool, 1, POOL_TIMEOUT).await; + let held_connection = worker_pool.acquire().await.unwrap(); + let worker = Worker::::new(worker_pool) + .with_lease_timeout(Duration::from_millis(500)) + .with_heartbeat_interval(Duration::from_millis(20)); + let worker_id = Uuid::new_v4(); + let lease = WorkerLease::new(worker_id, worker.lease_timeout); + let id = insert_task_at( + &pool, + &TestTask::Noop(Noop), + Utc::now() - ChronoDuration::milliseconds(1), + false, + ) + .await; + let initial_expires_at = Utc::now() + ChronoDuration::milliseconds(200); + sqlx::query!( + " + UPDATE pg_task + SET locked_by = $2, + lock_expires_at = $3 + WHERE id = $1 + ", + id, + worker_id, + initial_expires_at, + ) + .execute(&pool) + .await + .unwrap(); + let running_step = tokio::spawn(async { + std::future::pending::<()>().await; + }); + let running_steps = Arc::new(Mutex::new(vec![running_step_entry( + id, + running_step.abort_handle(), + )])); + let (events, mut events_receiver) = mpsc::unbounded_channel(); + let heartbeat = worker.spawn_heartbeat(events, running_steps, lease); + + let event = timeout(Duration::from_secs(1), events_receiver.recv()) + .await + .unwrap() + .unwrap(); + assert!(matches!(event, HeartbeatEvent::Failed)); + + drop(held_connection); + + let event = timeout(Duration::from_secs(1), events_receiver.recv()) + .await + .unwrap() + .unwrap(); + 
assert!(matches!(event, HeartbeatEvent::Recovered)); + + let (_locked_by, renewed_expires_at) = fetch_task_lease(&pool, id).await.unwrap(); + assert!(renewed_expires_at > initial_expires_at); + + heartbeat.abort(); + running_step.abort(); +} + +#[tokio::test] +async fn running_step_tracking_prunes_finished_steps() { + let finished_step = tokio::spawn(async {}); + let finished_step_abort = finished_step.abort_handle(); + finished_step.await.unwrap(); + + let running_step = tokio::spawn(async { + std::future::pending::<()>().await; + }); + let running_step_abort = running_step.abort_handle(); + let running_steps = Mutex::new(vec![running_step_entry( + Uuid::new_v4(), + finished_step_abort, + )]); + + Worker::::track_running_step(&running_steps, Uuid::new_v4(), running_step_abort); + + assert_eq!(running_steps.lock().len(), 1); + assert!(Worker::::has_running_steps(&running_steps)); + + Worker::::abort_running_steps(&running_steps); + assert!(running_step.await.unwrap_err().is_cancelled()); + assert!(!Worker::::has_running_steps(&running_steps)); +} + +#[sqlx::test(migrations = "./migrations")] +async fn run_returns_listener_startup_errors(pool: PgPool) { + let worker = Worker::::new(pool); + worker.listener.fail_next_listen_for_tests(); + + let err = worker.run().await.unwrap_err(); + + assert!(matches!( + err, + Error::ListenerListen(sqlx::Error::Protocol(_)) + )); +} + +#[sqlx::test(migrations = "./migrations")] +async fn handle_recv_task_error_returns_permanent_fetch_errors(pool: PgPool) { + sqlx::query!("ALTER TABLE pg_task RENAME COLUMN step TO task_step") + .execute(&pool) + .await + .unwrap(); + + let err = sqlx::query!("SELECT step FROM pg_task") + .fetch_one(&pool) + .await + .unwrap_err(); + + let worker = Worker::::new(pool); + let err = worker + .handle_recv_task_error(Error::Db(err, "test".into())) + .await + .unwrap_err(); + + assert!(matches!(err, Error::Db(sqlx::Error::Database(_), _))); +} + +#[tokio::test] +async fn 
handle_recv_task_error_retries_pool_timeouts() { + init_tracing(); + let worker = Worker::::new( + PgPoolOptions::new() + .connect_lazy("postgres:///pg_task") + .unwrap(), + ); + + worker + .handle_recv_task_error(Error::Db(sqlx::Error::PoolTimedOut, "test".into())) + .await + .unwrap(); +} + +#[sqlx::test(migrations = "./migrations")] +async fn handle_recv_task_error_waits_for_reconnection_after_connection_errors(pool: PgPool) { + init_tracing(); + let worker = Worker::::new(pool); + + worker + .handle_recv_task_error(connection_error()) + .await + .unwrap(); +} + +#[sqlx::test(migrations = "./migrations")] +async fn handle_recv_task_error_returns_reconnection_failures(pool: PgPool) { + sqlx::query!("ALTER TABLE pg_task RENAME COLUMN id TO task_id") + .execute(&pool) + .await + .unwrap(); + + let worker = Worker::::new(pool); + let err = worker + .handle_recv_task_error(connection_error()) + .await + .unwrap_err(); + + assert!(matches!(err, Error::Db(sqlx::Error::Database(_), _))); +} + +#[sqlx::test(migrations = "./migrations")] +async fn recv_task_returns_listener_errors(pool: PgPool) { + let worker = Worker::::new(pool); + worker + .listener + .set_error_for_tests(Error::ListenerReceive(sqlx::Error::Protocol( + "listener failed".into(), + ))); + + let lease = worker_lease(&worker); + let err = worker.recv_task(lease).await.unwrap_err(); + + assert!(matches!( + err, + Error::ListenerReceive(sqlx::Error::Protocol(_)) + )); +} + +#[sqlx::test(migrations = "./migrations")] +async fn recv_task_stops_even_if_listener_has_failed(pool: PgPool) { + let worker = Worker::::new(pool); + worker + .listener + .set_error_for_tests(Error::ListenerReceive(sqlx::Error::Protocol( + "listener failed".into(), + ))); + worker.listener.stop_worker_for_tests(); + + let lease = worker_lease(&worker); + assert!(worker.recv_task(lease).await.unwrap().is_none()); +} + +#[sqlx::test(migrations = "./migrations")] +async fn recv_task_returns_begin_errors_when_the_pool_is_closed(pool: PgPool) 
{ + let worker = Worker::::new(pool.clone()); + pool.close().await; + + let lease = worker_lease(&worker); + let err = worker.recv_task(lease).await.unwrap_err(); + + match err { + Error::Db(sqlx::Error::PoolClosed, context) => { + assert!(context.contains("begin")); + } + _ => panic!("expected a pool-closed begin error"), + } +} + +#[sqlx::test(migrations = "./migrations")] +async fn wait_for_available_task_does_not_claim_ready_tasks(pool: PgPool) { + let id = insert_task_at( + &pool, + &TestTask::Noop(Noop), + Utc::now() - ChronoDuration::milliseconds(1), + false, + ) + .await; + let worker = Worker::::new(pool.clone()); + + let availability = worker.wait_for_available_task().await.unwrap(); + + assert!(matches!(availability, TaskAvailability::Ready)); + let lease = sqlx::query!( + "SELECT locked_by, lock_expires_at FROM pg_task WHERE id = $1", + id, + ) + .fetch_one(&pool) + .await + .unwrap(); + assert!(lease.locked_by.is_none()); + assert!(lease.lock_expires_at.is_none()); +} + +#[sqlx::test(migrations = "./migrations")] +async fn recv_task_stops_while_waiting_for_work(pool: PgPool) { + let worker = Arc::new(Worker::::new(pool)); + let lease = worker_lease(&worker); + let recv = tokio::spawn({ + let worker = worker.clone(); + async move { worker.recv_task(lease).await } + }); + + sleep(Duration::from_millis(50)).await; + assert!(!recv.is_finished()); + worker.listener.stop_worker_for_tests(); + + let received = timeout(Duration::from_secs(1), recv) + .await + .unwrap() + .unwrap() + .unwrap(); + assert!(received.is_none()); +} + +#[sqlx::test(migrations = "./migrations")] +async fn recv_task_returns_listener_errors_while_waiting_for_work(pool: PgPool) { + let worker = Arc::new(Worker::::new(pool)); + let lease = worker_lease(&worker); + let recv = tokio::spawn({ + let worker = worker.clone(); + async move { worker.recv_task(lease).await } + }); + + sleep(Duration::from_millis(50)).await; + assert!(!recv.is_finished()); + worker + .listener + 
.set_error_and_notify_for_tests(Error::ListenerReceive(sqlx::Error::Protocol( + "listener failed".into(), + ))); + + let err = timeout(Duration::from_secs(1), recv) + .await + .unwrap() + .unwrap() + .unwrap_err(); + assert!(matches!( + err, + Error::ListenerReceive(sqlx::Error::Protocol(_)) + )); +} + +#[sqlx::test(migrations = "./migrations")] +async fn recv_task_skips_invalid_tasks_and_returns_next_ready_task(pool: PgPool) { + let invalid_id = insert_raw_task( + &pool, + "not-json", + Utc::now() - ChronoDuration::seconds(2), + false, + None, + ) + .await; + let expected = insert_task_at( + &pool, + &TestTask::Noop(Noop), + Utc::now() - ChronoDuration::seconds(1), + false, + ) + .await; + let worker = Worker::::new(pool.clone()); + let lease = worker_lease(&worker); + + let (task, step, _lease) = worker.recv_task(lease).await.unwrap().unwrap(); + + assert_eq!(task.id, expected); + assert!(matches!(step, TestTask::Noop(Noop))); + let invalid_row = sqlx::query!( + "SELECT locked_by, lock_expires_at, error FROM pg_task WHERE id = $1", + invalid_id, + ) + .fetch_one(&pool) + .await + .unwrap(); + assert!(invalid_row.locked_by.is_none()); + assert!(invalid_row.lock_expires_at.is_none()); + assert!(invalid_row.error.is_some()); +} + +#[sqlx::test(migrations = "./migrations")] +async fn recv_task_rechecks_locked_ready_tasks_without_notifications(pool: PgPool) { + let id = insert_task_at( + &pool, + &TestTask::Noop(Noop), + Utc::now() - ChronoDuration::milliseconds(1), + false, + ) + .await; + let mut tx = pool.begin().await.unwrap(); + let locked = sqlx::query!("SELECT id FROM pg_task WHERE id = $1 FOR UPDATE", id) + .fetch_one(&mut *tx) + .await + .unwrap(); + assert_eq!(locked.id, id); + + let worker = Worker::::new(pool); + let lease = worker_lease(&worker); + let recv = tokio::spawn(async move { worker.recv_task(lease).await }); + + sleep(Duration::from_millis(50)).await; + assert!(!recv.is_finished()); + + tx.rollback().await.unwrap(); + + let (task, step, _lease) 
= timeout(Duration::from_secs(1), recv) + .await + .unwrap() + .unwrap() + .unwrap() + .unwrap(); + assert_eq!(task.id, id); + assert!(matches!(step, TestTask::Noop(Noop))); +} + +#[sqlx::test(migrations = "./migrations")] +async fn recv_task_rechecks_leased_tasks_when_their_lease_expires(pool: PgPool) { + let id = insert_task_at( + &pool, + &TestTask::Noop(Noop), + Utc::now() - ChronoDuration::milliseconds(1), + true, + ) + .await; + let lock_expires_at = Utc::now() + ChronoDuration::seconds(1); + set_task_lease(&pool, id, lock_expires_at).await; + + let worker = Worker::::new(pool); + let lease = worker_lease(&worker); + let recv = tokio::spawn(async move { worker.recv_task(lease).await }); + + let (task, step, _lease) = timeout(Duration::from_secs(2), recv) + .await + .unwrap() + .unwrap() + .unwrap() + .unwrap(); + assert!(Utc::now() >= lock_expires_at); + assert_eq!(task.id, id); + assert!(matches!(step, TestTask::Noop(Noop))); +} + +#[sqlx::test(migrations = "./migrations")] +async fn recv_task_replaces_expired_lease_with_the_current_worker(pool: PgPool) { + let id = insert_task_at( + &pool, + &TestTask::Noop(Noop), + Utc::now() - ChronoDuration::milliseconds(1), + true, + ) + .await; + set_task_lease(&pool, id, Utc::now() - ChronoDuration::milliseconds(1)).await; + let worker = Worker::::new(pool.clone()); + let worker_id = Uuid::new_v4(); + let lease = WorkerLease::new(worker_id, worker.lease_timeout); + + let (task, step, _lease) = worker.recv_task(lease).await.unwrap().unwrap(); + + assert_eq!(task.id, id); + assert!(matches!(step, TestTask::Noop(Noop))); + let (locked_by, lock_expires_at) = fetch_task_lease(&pool, id).await.unwrap(); + assert_eq!(locked_by, worker_id); + assert!(lock_expires_at > Utc::now()); +} + +#[sqlx::test(migrations = "./migrations")] +async fn two_workers_claim_ready_tasks_once(pool: PgPool) { + let first_id = insert_task_at( + &pool, + &TestTask::Noop(Noop), + Utc::now() - ChronoDuration::milliseconds(1), + false, + ) + .await; + 
let second_id = insert_task_at( + &pool, + &TestTask::Noop(Noop), + Utc::now() - ChronoDuration::milliseconds(1), + false, + ) + .await; + let first_worker = Worker::::new(pool.clone()); + let second_worker = Worker::::new(pool.clone()); + let first_lease = worker_lease(&first_worker); + let second_lease = worker_lease(&second_worker); + + let first_recv = tokio::spawn(async move { first_worker.recv_task(first_lease).await }); + let second_recv = tokio::spawn(async move { second_worker.recv_task(second_lease).await }); + + let (first_task, first_step, _first_lease) = timeout(Duration::from_secs(1), first_recv) + .await + .unwrap() + .unwrap() + .unwrap() + .unwrap(); + let (second_task, second_step, _second_lease) = timeout(Duration::from_secs(1), second_recv) + .await + .unwrap() + .unwrap() + .unwrap() + .unwrap(); + + assert!(matches!(first_step, TestTask::Noop(Noop))); + assert!(matches!(second_step, TestTask::Noop(Noop))); + assert_ne!(first_task.id, second_task.id); + assert!([first_id, second_id].contains(&first_task.id)); + assert!([first_id, second_id].contains(&second_task.id)); + + let running = sqlx::query!("SELECT id FROM pg_task WHERE locked_by IS NOT NULL") + .fetch_all(&pool) + .await + .unwrap(); + assert_eq!(running.len(), 2); +} + +#[sqlx::test(migrations = "./migrations")] +async fn run_renews_leases_for_running_tasks(pool: PgPool) { + let state = StepStateGuard::new(); + let id = insert_task_at( + &pool, + &TestTask::Blocking(Blocking { key: state.key() }), + Utc::now() - ChronoDuration::milliseconds(1), + false, + ) + .await; + + let worker = tokio::spawn({ + let pool = pool.clone(); + async move { + Worker::::new(pool) + .with_concurrency(nonzero(1)) + .with_lease_timeout(Duration::from_secs(1)) + .with_heartbeat_interval(Duration::from_millis(50)) + .run() + .await + } + }); + + state.state().wait_for_events(1).await; + let (locked_by, initial_expires_at) = fetch_task_lease(&pool, id).await.unwrap(); + + let (renewed_by, renewed_expires_at) 
= timeout(Duration::from_secs(1), async { + loop { + let (renewed_by, renewed_expires_at) = fetch_task_lease(&pool, id).await.unwrap(); + if renewed_expires_at > initial_expires_at { + return (renewed_by, renewed_expires_at); + } + sleep(Duration::from_millis(10)).await; + } + }) + .await + .unwrap(); + assert_eq!(renewed_by, locked_by); + assert!(renewed_expires_at > Utc::now()); + + stop_worker(&pool).await; + state.state().release(); + + timeout(Duration::from_secs(1), worker) + .await + .unwrap() + .unwrap() + .unwrap(); +} + +#[sqlx::test(migrations = "./migrations")] +async fn run_aborts_running_steps_when_heartbeat_cannot_renew_before_the_lease_expires( + pool: PgPool, +) { + let state = StepStateGuard::new(); + insert_task( + &pool, + &TestTask::Blocking(Blocking { key: state.key() }), + false, + ) + .await; + + let worker = tokio::spawn({ + let pool = pool.clone(); + async move { + Worker::::new(pool) + .with_concurrency(nonzero(1)) + .with_lease_timeout(Duration::from_millis(250)) + .with_heartbeat_interval(Duration::from_millis(100)) + .run() + .await + } + }); + + state.state().wait_for_events(1).await; + sleep(Duration::from_millis(50)).await; + sqlx::query!("ALTER TABLE pg_task RENAME COLUMN lock_expires_at TO task_lock_expires_at") + .execute(&pool) + .await + .unwrap(); + + let err = timeout(Duration::from_secs(2), worker) + .await + .unwrap() + .unwrap() + .unwrap_err(); + + assert_eq!(state.state().events(), vec!["started"]); + assert!(matches!(err, Error::Db(sqlx::Error::Database(_), _))); +} + +#[sqlx::test(migrations = "./migrations")] +async fn run_pauses_fetching_while_heartbeat_cannot_renew(pool: PgPool) { + let worker_pool = connect_to_current_db(&pool, 1, POOL_TIMEOUT).await; + let state = StepStateGuard::new(); + + let worker = tokio::spawn(async move { + Worker::::new(worker_pool) + .with_concurrency(nonzero(1)) + .with_lease_timeout(Duration::from_millis(200)) + .with_heartbeat_interval(Duration::from_millis(50)) + .run() + .await + }); 
+ + sleep(Duration::from_millis(100)).await; + insert_task( + &pool, + &TestTask::Complete(Complete { key: state.key() }), + false, + ) + .await; + sleep(Duration::from_millis(150)).await; + + stop_worker(&pool).await; + + timeout(Duration::from_secs(2), worker) + .await + .unwrap() + .unwrap() + .unwrap(); + + assert!(state.state().events().is_empty()); +} + +#[sqlx::test(migrations = "./migrations")] +async fn run_returns_listener_errors_while_fetching_is_paused(pool: PgPool) { + let worker_pool = connect_to_current_db(&pool, 1, POOL_TIMEOUT).await; + let worker = Arc::new( + Worker::::new(worker_pool) + .with_concurrency(nonzero(1)) + .with_lease_timeout(Duration::from_millis(200)) + .with_heartbeat_interval(Duration::from_millis(50)), + ); + let run = tokio::spawn({ + let worker = worker.clone(); + async move { worker.run().await } + }); + + sleep(Duration::from_millis(1250)).await; + worker + .listener + .set_error_for_tests(Error::ListenerReceive(sqlx::Error::Protocol( + "listener failed".into(), + ))); + + let err = timeout(Duration::from_secs(3), run) + .await + .unwrap() + .unwrap() + .unwrap_err(); + assert!(matches!( + err, + Error::ListenerReceive(sqlx::Error::Protocol(_)) + )); +} + +#[sqlx::test(migrations = "./migrations")] +async fn run_keeps_waiting_after_retryable_errors_while_fetching_is_paused(pool: PgPool) { + let worker_pool = connect_to_current_db(&pool, 1, POOL_TIMEOUT).await; + let worker = Arc::new( + Worker::::new(worker_pool) + .with_concurrency(nonzero(1)) + .with_lease_timeout(Duration::from_millis(200)) + .with_heartbeat_interval(Duration::from_millis(50)), + ); + let run = tokio::spawn({ + let worker = worker.clone(); + async move { worker.run().await } + }); + + sleep(Duration::from_millis(1250)).await; + worker + .listener + .set_error_for_tests(Error::Db(sqlx::Error::PoolTimedOut, "fetch task".into())); + sleep(Duration::from_millis(1300)).await; + assert!(!run.is_finished()); + + stop_worker(&pool).await; + + 
timeout(Duration::from_secs(3), run) + .await + .unwrap() + .unwrap() + .unwrap(); +} + +#[sqlx::test(migrations = "./migrations")] +async fn run_resumes_fetching_after_heartbeat_recovers(pool: PgPool) { + let worker_pool = connect_to_current_db(&pool, 2, POOL_TIMEOUT).await; + let held_connection = worker_pool.acquire().await.unwrap(); + let state = StepStateGuard::new(); + + let worker = tokio::spawn({ + let worker_pool = worker_pool.clone(); + async move { + Worker::::new(worker_pool) + .with_concurrency(nonzero(1)) + .with_lease_timeout(Duration::from_millis(300)) + .with_heartbeat_interval(Duration::from_millis(50)) + .run() + .await + } + }); + + sleep(Duration::from_millis(100)).await; + insert_task( + &pool, + &TestTask::Complete(Complete { key: state.key() }), + false, + ) + .await; + sleep(Duration::from_millis(150)).await; + + assert!(state.state().events().is_empty()); + drop(held_connection); + + timeout(Duration::from_secs(3), async { + loop { + if !state.state().events().is_empty() { + break; + } + sleep(Duration::from_millis(10)).await; + } + }) + .await + .unwrap(); + + stop_worker(&pool).await; + + timeout(Duration::from_secs(2), worker) + .await + .unwrap() + .unwrap() + .unwrap(); + + assert_eq!(state.state().events(), vec!["complete"]); +} + +#[tokio::test] +async fn finish_run_waits_for_inflight_steps_before_returning_errors() { + init_tracing(); + let worker = Arc::new( + Worker::::new( + PgPoolOptions::new() + .connect_lazy("postgres:///pg_task") + .unwrap(), + ) + .with_concurrency(nonzero(1)), + ); + let semaphore = Arc::new(Semaphore::new(1)); + let permit = semaphore.clone().acquire_owned().await.unwrap(); + + let task = tokio::spawn({ + let worker = worker.clone(); + let semaphore = semaphore.clone(); + async move { + worker + .finish_run( + Err(Error::ListenerReceive(sqlx::Error::Protocol( + "listener failed".into(), + ))), + semaphore, + idle_heartbeat(), + Arc::new(Mutex::new(Vec::new())), + false, + idle_run_events(), + ) + .await + 
} + }); + + sleep(Duration::from_millis(50)).await; + assert!(!task.is_finished()); + + drop(permit); + + let err = task.await.unwrap().unwrap_err(); + assert!(matches!( + err, + Error::ListenerReceive(sqlx::Error::Protocol(_)) + )); +} + +#[tokio::test] +async fn finish_run_returns_step_errors_received_while_draining() { + init_tracing(); + let worker = Arc::new( + Worker::::new( + PgPoolOptions::new() + .connect_lazy("postgres:///pg_task") + .unwrap(), + ) + .with_concurrency(nonzero(1)), + ); + let semaphore = Arc::new(Semaphore::new(1)); + let permit = semaphore.clone().acquire_owned().await.unwrap(); + let (step_error_sender, step_errors) = mpsc::unbounded_channel(); + + let finish = tokio::spawn({ + let worker = worker.clone(); + let semaphore = semaphore.clone(); + async move { + worker + .finish_run( + Ok(()), + semaphore, + idle_heartbeat(), + Arc::new(Mutex::new(Vec::new())), + false, + RunEvents { + heartbeat: idle_heartbeat_events(), + step_errors, + }, + ) + .await + } + }); + + sleep(Duration::from_millis(50)).await; + assert!(!finish.is_finished()); + + step_error_sender + .send(Error::Db(sqlx::Error::PoolTimedOut, "step".into())) + .unwrap(); + drop(permit); + + let err = timeout(Duration::from_secs(1), finish) + .await + .unwrap() + .unwrap() + .unwrap_err(); + assert!(matches!(err, Error::Db(sqlx::Error::PoolTimedOut, _))); +} + +#[tokio::test] +async fn finish_run_keeps_heartbeat_alive_while_waiting_for_inflight_steps() { + init_tracing(); + let worker = Arc::new( + Worker::::new( + PgPoolOptions::new() + .connect_lazy("postgres:///pg_task") + .unwrap(), + ) + .with_concurrency(nonzero(1)), + ); + let semaphore = Arc::new(Semaphore::new(1)); + let permit = semaphore.clone().acquire_owned().await.unwrap(); + let heartbeat = tokio::spawn(async { + std::future::pending::<()>().await; + }); + + let finish = tokio::spawn({ + let worker = worker.clone(); + let semaphore = semaphore.clone(); + let heartbeat_abort = heartbeat.abort_handle(); + async move 
{ + worker + .finish_run( + Ok(()), + semaphore, + heartbeat_abort, + Arc::new(Mutex::new(Vec::new())), + false, + idle_run_events(), + ) + .await + } + }); + + sleep(Duration::from_millis(50)).await; + assert!(!finish.is_finished()); + assert!(!heartbeat.is_finished()); + + drop(permit); + + timeout(Duration::from_secs(1), finish) + .await + .unwrap() + .unwrap() + .unwrap(); + assert!(heartbeat.await.unwrap_err().is_cancelled()); +} + +#[tokio::test] +async fn finish_run_aborts_inflight_steps_when_lease_renewal_expires() { + init_tracing(); + let worker = Worker::::new( + PgPoolOptions::new() + .connect_lazy("postgres:///pg_task") + .unwrap(), + ) + .with_concurrency(nonzero(1)); + let semaphore = Arc::new(Semaphore::new(1)); + let permit = semaphore.clone().acquire_owned().await.unwrap(); + let running_step = tokio::spawn(async move { + let _permit = permit; + std::future::pending::<()>().await; + }); + let running_steps = Arc::new(Mutex::new(vec![running_step_entry( + Uuid::new_v4(), + running_step.abort_handle(), + )])); + + let err = timeout( + Duration::from_secs(1), + worker.finish_run( + Err(Error::Db(sqlx::Error::PoolTimedOut, "test".into())), + semaphore, + idle_heartbeat(), + running_steps, + true, + idle_run_events(), + ), + ) + .await + .unwrap() + .unwrap_err(); + + assert!(matches!(err, Error::Db(sqlx::Error::PoolTimedOut, _))); + assert!(running_step.await.unwrap_err().is_cancelled()); +} + +#[tokio::test] +async fn finish_run_aborts_inflight_steps_when_heartbeat_expires_while_draining() { + init_tracing(); + let worker = Worker::::new( + PgPoolOptions::new() + .connect_lazy("postgres:///pg_task") + .unwrap(), + ) + .with_concurrency(nonzero(1)); + let semaphore = Arc::new(Semaphore::new(1)); + let permit = semaphore.clone().acquire_owned().await.unwrap(); + let running_step = tokio::spawn(async move { + let _permit = permit; + std::future::pending::<()>().await; + }); + let running_steps = Arc::new(Mutex::new(vec![running_step_entry( + 
Uuid::new_v4(), + running_step.abort_handle(), + )])); + let (heartbeat_events_sender, heartbeat_events) = mpsc::unbounded_channel(); + heartbeat_events_sender + .send(HeartbeatEvent::Expired(Error::TaskLeaseExpired)) + .unwrap(); + + let err = timeout( + Duration::from_secs(1), + worker.finish_run( + Ok(()), + semaphore, + idle_heartbeat(), + running_steps, + false, + RunEvents { + heartbeat: heartbeat_events, + step_errors: idle_step_errors(), + }, + ), + ) + .await + .unwrap() + .unwrap_err(); + + assert!(matches!(err, Error::TaskLeaseExpired)); + assert!(running_step.await.unwrap_err().is_cancelled()); +} + +#[tokio::test] +async fn wait_for_steps_to_finish_rechecks_when_the_inflight_task_count_changes() { + init_tracing(); + let worker = Arc::new( + Worker::::new( + PgPoolOptions::new() + .connect_lazy("postgres:///pg_task") + .unwrap(), + ) + .with_concurrency(nonzero(2)), + ); + let semaphore = Arc::new(Semaphore::new(2)); + let first = semaphore.clone().acquire_owned().await.unwrap(); + let second = semaphore.clone().acquire_owned().await.unwrap(); + + let wait = tokio::spawn({ + let worker = worker.clone(); + let semaphore = semaphore.clone(); + async move { + worker.wait_for_steps_to_finish(semaphore).await; + } + }); + + sleep(Duration::from_millis(10)).await; + drop(first); + + sleep(Duration::from_millis(150)).await; + assert!(!wait.is_finished()); + + drop(second); + + timeout(Duration::from_secs(1), wait) + .await + .unwrap() + .unwrap(); +} + +#[sqlx::test(migrations = "./migrations")] +async fn run_processes_followup_steps_to_completion(pool: PgPool) { + let state = StepStateGuard::new(); + insert_task( + &pool, + &TestTask::Advance(Advance { key: state.key() }), + false, + ) + .await; + + let worker = spawn_worker(pool.clone()); + + state.state().wait_for_events(2).await; + stop_worker(&pool).await; + + timeout(Duration::from_secs(1), worker) + .await + .unwrap() + .unwrap() + .unwrap(); + + assert_eq!(state.state().events(), vec!["advance", 
"finish"]); + assert_eq!(task_count(&pool).await, 0); +} + +#[sqlx::test(migrations = "./migrations")] +async fn run_returns_listener_errors_when_the_pool_is_closed(pool: PgPool) { + let worker = spawn_worker(pool.clone()); + + sleep(Duration::from_millis(100)).await; + pool.close().await; + + let err = timeout(Duration::from_secs(2), worker) + .await + .unwrap() + .unwrap() + .unwrap_err(); + + assert!(matches!( + err, + Error::ListenerReceive(sqlx::Error::PoolClosed) + )); +} + +#[sqlx::test(migrations = "./migrations")] +async fn run_recovers_from_pool_timeouts_until_a_stop_notification_arrives(pool: PgPool) { + let worker_pool = connect_to_current_db(&pool, 1, POOL_TIMEOUT).await; + let worker = spawn_worker(worker_pool); + + sleep(Duration::from_millis(100)).await; + assert!(!worker.is_finished()); + + stop_worker(&pool).await; + + timeout(Duration::from_secs(3), worker) + .await + .unwrap() + .unwrap() + .unwrap(); +} + +#[sqlx::test(migrations = "./migrations")] +async fn run_stops_when_stop_notification_arrives_while_idle(pool: PgPool) { + let worker = spawn_worker(pool.clone()); + + sleep(Duration::from_millis(100)).await; + stop_worker(&pool).await; + + timeout(Duration::from_secs(1), worker) + .await + .unwrap() + .unwrap() + .unwrap(); + assert_eq!(task_count(&pool).await, 0); +} + +#[sqlx::test(migrations = "./migrations")] +async fn run_wakes_up_for_tasks_inserted_while_idle(pool: PgPool) { + let state = StepStateGuard::new(); + let worker = spawn_worker(pool.clone()); + + sleep(Duration::from_millis(100)).await; + insert_task( + &pool, + &TestTask::Complete(Complete { key: state.key() }), + false, + ) + .await; + + state.state().wait_for_events(1).await; + stop_worker(&pool).await; + + timeout(Duration::from_secs(1), worker) + .await + .unwrap() + .unwrap() + .unwrap(); + + assert_eq!(state.state().events(), vec!["complete"]); + assert_eq!(task_count(&pool).await, 0); +} + +#[sqlx::test(migrations = "./migrations")] +async fn 
run_processes_noop_tasks_to_completion(pool: PgPool) { + insert_task(&pool, &TestTask::Noop(Noop), false).await; + + let worker = spawn_worker(pool.clone()); + + timeout(Duration::from_secs(1), async { + loop { + if task_count(&pool).await == 0 { + break; + } + sleep(Duration::from_millis(10)).await; + } + }) + .await + .unwrap(); + + stop_worker(&pool).await; + + timeout(Duration::from_secs(1), worker) + .await + .unwrap() + .unwrap() + .unwrap(); +} + +#[sqlx::test(migrations = "./migrations")] +async fn run_waits_for_future_tasks_to_become_ready_without_notifications(pool: PgPool) { + let state = StepStateGuard::new(); + insert_task_at( + &pool, + &TestTask::Complete(Complete { key: state.key() }), + Utc::now() + ChronoDuration::milliseconds(150), + false, + ) + .await; + + let worker = spawn_worker(pool.clone()); + + assert!( + timeout(Duration::from_millis(50), state.state().wait_for_events(1)) + .await + .is_err() + ); + + state.state().wait_for_events(1).await; + stop_worker(&pool).await; + + timeout(Duration::from_secs(1), worker) + .await + .unwrap() + .unwrap() + .unwrap(); + + assert_eq!(state.state().events(), vec!["complete"]); + assert_eq!(task_count(&pool).await, 0); +} + +#[sqlx::test(migrations = "./migrations")] +async fn starting_another_worker_does_not_unlock_live_tasks(pool: PgPool) { + let state = StepStateGuard::new(); + insert_task( + &pool, + &TestTask::Blocking(Blocking { key: state.key() }), + false, + ) + .await; + + let first_worker = spawn_worker(pool.clone()); + state.state().wait_for_events(1).await; + + let second_worker = spawn_worker(pool.clone()); + sleep(Duration::from_millis(150)).await; + assert_eq!(state.state().events(), vec!["started"]); + + stop_worker(&pool).await; + state.state().release(); + + timeout(Duration::from_secs(1), first_worker) + .await + .unwrap() + .unwrap() + .unwrap(); + timeout(Duration::from_secs(1), second_worker) + .await + .unwrap() + .unwrap() + .unwrap(); + + assert_eq!(state.state().events(), 
vec!["started", "completed"]); + assert_eq!(task_count(&pool).await, 0); +} + +#[sqlx::test(migrations = "./migrations")] +async fn run_skips_invalid_tasks_and_keeps_processing_ready_tasks(pool: PgPool) { + let invalid_id = insert_raw_task( + &pool, + "not-json", + Utc::now() - ChronoDuration::milliseconds(10), + false, + None, + ) + .await; + let state = StepStateGuard::new(); + insert_task( + &pool, + &TestTask::Complete(Complete { key: state.key() }), + false, + ) + .await; + + let worker = spawn_worker(pool.clone()); + + state.state().wait_for_events(1).await; + stop_worker(&pool).await; + + timeout(Duration::from_secs(1), worker) + .await + .unwrap() + .unwrap() + .unwrap(); + + let invalid_row = sqlx::query!( + "SELECT tried, locked_by, lock_expires_at, error FROM pg_task WHERE id = $1", + invalid_id, + ) + .fetch_one(&pool) + .await + .unwrap(); + + assert_eq!(state.state().events(), vec!["complete"]); + assert_eq!(invalid_row.tried, 0); + assert!(invalid_row.locked_by.is_none()); + assert!(invalid_row.lock_expires_at.is_none()); + assert!(invalid_row.error.is_some()); + assert_eq!(task_count(&pool).await, 1); +} + +#[sqlx::test(migrations = "./migrations")] +async fn run_stops_after_running_steps_finish(pool: PgPool) { + let state = StepStateGuard::new(); + insert_task( + &pool, + &TestTask::Blocking(Blocking { key: state.key() }), + false, + ) + .await; + + let worker = spawn_worker(pool.clone()); + + state.state().wait_for_events(1).await; + stop_worker(&pool).await; + + sleep(Duration::from_millis(50)).await; + assert!(!worker.is_finished()); + + state.state().release(); + + timeout(Duration::from_secs(1), worker) + .await + .unwrap() + .unwrap() + .unwrap(); + + assert_eq!(state.state().events(), vec!["started", "completed"]); + assert_eq!(task_count(&pool).await, 0); +} + +#[sqlx::test(migrations = "./migrations")] +async fn run_returns_step_errors_received_after_stop_while_draining(pool: PgPool) { + let state = StepStateGuard::new(); + insert_task( + 
&pool, + &TestTask::Blocking(Blocking { key: state.key() }), + false, + ) + .await; + + let worker = spawn_worker_with_concurrency(pool.clone(), 2); + + state.state().wait_for_events(1).await; + stop_worker(&pool).await; + sleep(Duration::from_millis(50)).await; + assert!(!worker.is_finished()); + + sqlx::query!("ALTER TABLE pg_task RENAME COLUMN id TO task_id") + .execute(&pool) + .await + .unwrap(); + state.state().release(); + + let err = timeout(Duration::from_secs(1), worker) + .await + .unwrap() + .unwrap() + .unwrap_err(); + + assert_eq!(state.state().events(), vec!["started", "completed"]); + assert!(matches!(err, Error::Db(sqlx::Error::Database(_), _))); +} + +#[sqlx::test(migrations = "./migrations")] +async fn run_returns_spawned_step_persistence_errors(pool: PgPool) { + let state = StepStateGuard::new(); + insert_task( + &pool, + &TestTask::FailSavingError(FailSavingError { key: state.key() }), + false, + ) + .await; + + let worker = spawn_worker(pool.clone()); + + state.state().wait_for_events(2).await; + let err = timeout(Duration::from_secs(1), worker) + .await + .unwrap() + .unwrap() + .unwrap_err(); + + assert_eq!(state.state().events(), vec!["started", "save error failed"]); + assert!(matches!(err, Error::Db(sqlx::Error::Database(_), _))); +} + +#[sqlx::test(migrations = "./migrations")] +async fn rerunning_worker_does_not_renew_abandoned_leases_from_previous_runs(pool: PgPool) { + init_tracing(); + sqlx::query!("ALTER TABLE pg_task ADD CONSTRAINT reject_errors CHECK (error IS NULL)") + .execute(&pool) + .await + .unwrap(); + let state = StepStateGuard::new(); + let id = insert_task_at( + &pool, + &TestTask::FailStep(FailStep { key: state.key() }), + Utc::now() - ChronoDuration::milliseconds(1), + false, + ) + .await; + let worker = Worker::::new(pool.clone()) + .with_concurrency(nonzero(1)) + .with_lease_timeout(Duration::from_secs(1)) + .with_heartbeat_interval(Duration::from_millis(50)); + + let err = timeout(Duration::from_secs(1), 
worker.run()) + .await + .unwrap() + .unwrap_err(); + assert!(matches!(err, Error::Db(sqlx::Error::Database(_), _))); + let (abandoned_owner, abandoned_expires_at) = fetch_task_lease(&pool, id).await.unwrap(); + + let rerun = tokio::spawn({ + let worker = worker; + async move { worker.run().await } + }); + + sleep(Duration::from_millis(150)).await; + let (locked_by, lock_expires_at) = fetch_task_lease(&pool, id).await.unwrap(); + assert_eq!(locked_by, abandoned_owner); + assert_eq!(lock_expires_at, abandoned_expires_at); + + stop_worker(&pool).await; + + timeout(Duration::from_secs(1), rerun) + .await + .unwrap() + .unwrap() + .unwrap(); +} + +#[sqlx::test(migrations = "./migrations")] +async fn run_returns_step_errors_from_spawned_tasks(pool: PgPool) { + let state = StepStateGuard::new(); + sqlx::query!("ALTER TABLE pg_task ADD CONSTRAINT reject_errors CHECK (error IS NULL)") + .execute(&pool) + .await + .unwrap(); + insert_task( + &pool, + &TestTask::FailStep(FailStep { key: state.key() }), + false, + ) + .await; + + let worker = spawn_worker(pool); + + state.state().wait_for_events(1).await; + let err = timeout(Duration::from_secs(1), worker) + .await + .unwrap() + .unwrap() + .unwrap_err(); + + assert_eq!(state.state().events(), vec!["started"]); + assert!(matches!(err, Error::Db(sqlx::Error::Database(_), _))); +} + +#[sqlx::test(migrations = "./migrations")] +async fn run_processes_multiple_blocking_steps_up_to_the_concurrency_limit(pool: PgPool) { + let first = StepStateGuard::new(); + let second = StepStateGuard::new(); + insert_task( + &pool, + &TestTask::Blocking(Blocking { key: first.key() }), + false, + ) + .await; + insert_task( + &pool, + &TestTask::Blocking(Blocking { key: second.key() }), + false, + ) + .await; + + let worker = spawn_worker_with_concurrency(pool.clone(), 2); + + first.state().wait_for_events(1).await; + second.state().wait_for_events(1).await; + stop_worker(&pool).await; + + assert!(!worker.is_finished()); + + 
first.state().release(); + second.state().release(); + + timeout(Duration::from_secs(1), worker) + .await + .unwrap() + .unwrap() + .unwrap(); + + assert_eq!(first.state().events(), vec!["started", "completed"]); + assert_eq!(second.state().events(), vec!["started", "completed"]); + assert_eq!(task_count(&pool).await, 0); +} + +#[sqlx::test(migrations = "./migrations")] +async fn run_respects_the_configured_concurrency_limit(pool: PgPool) { + let first = StepStateGuard::new(); + let second = StepStateGuard::new(); + insert_task( + &pool, + &TestTask::Blocking(Blocking { key: first.key() }), + false, + ) + .await; + insert_task( + &pool, + &TestTask::Blocking(Blocking { key: second.key() }), + false, + ) + .await; + + let worker = spawn_worker_with_concurrency(pool.clone(), 1); + + timeout(Duration::from_secs(1), async { + loop { + let started_count = usize::from(!first.state().events().is_empty()) + + usize::from(!second.state().events().is_empty()); + if started_count == 1 { + break; + } + sleep(Duration::from_millis(10)).await; + } + }) + .await + .unwrap(); + + sleep(Duration::from_millis(100)).await; + let first_started = !first.state().events().is_empty(); + let second_started = !second.state().events().is_empty(); + assert_ne!(first_started, second_started); + + if first_started { + first.state().release(); + second.state().wait_for_events(1).await; + stop_worker(&pool).await; + second.state().release(); + } else { + second.state().release(); + first.state().wait_for_events(1).await; + stop_worker(&pool).await; + first.state().release(); + } + + timeout(Duration::from_secs(1), worker) + .await + .unwrap() + .unwrap() + .unwrap(); + + assert_eq!(first.state().events(), vec!["started", "completed"]); + assert_eq!(second.state().events(), vec!["started", "completed"]); + assert_eq!(task_count(&pool).await, 0); +} From eee169f645f526ba87f2edb624ad410a67c8c3ef Mon Sep 17 00:00:00 2001 From: imbolc Date: Fri, 15 May 2026 14:53:49 +0600 Subject: [PATCH 41/44] Move 
task tests into child module --- src/task.rs | 1230 +-------------------------------------------- src/task/tests.rs | 1210 ++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 1211 insertions(+), 1229 deletions(-) create mode 100644 src/task/tests.rs diff --git a/src/task.rs b/src/task.rs index 80bcf90..52ade4f 100644 --- a/src/task.rs +++ b/src/task.rs @@ -435,1232 +435,4 @@ impl Task { } #[cfg(test)] -mod tests { - use super::{Task, WorkerLease}; - use crate::{NextStep, Step}; - use chrono::{DateTime, Duration as ChronoDuration, Utc}; - use sqlx::PgPool; - use std::{io, time::Duration}; - use uuid::Uuid; - - fn init_tracing() { - static INIT: std::sync::Once = std::sync::Once::new(); - INIT.call_once(|| { - let _ = tracing_subscriber::fmt() - .with_max_level(tracing::Level::TRACE) - .with_test_writer() - .without_time() - .try_init(); - }); - } - - #[derive(Debug, serde::Deserialize, serde::Serialize)] - pub(super) struct Valid; - - #[derive(Debug, serde::Deserialize, serde::Serialize)] - pub(super) struct AdvanceNow { - value: i32, - } - - #[derive(Debug, serde::Deserialize, serde::Serialize)] - pub(super) struct AdvanceLater { - value: i32, - delay_ms: u64, - } - - #[derive(Debug, serde::Deserialize, serde::Serialize)] - pub(super) struct Followup { - value: i32, - } - - #[derive(Debug, serde::Deserialize, serde::Serialize)] - pub(super) struct RetryFail; - - #[derive(Debug, serde::Deserialize, serde::Serialize)] - pub(super) struct BrokenNext; - - #[derive(Debug)] - pub(super) struct Unserializable; - - crate::task!(TestTask { - Valid, - AdvanceNow, - AdvanceLater, - Followup, - RetryFail, - BrokenNext, - Unserializable, - }); - - #[async_trait::async_trait] - impl Step for Valid { - async fn step(self, _db: &PgPool) -> crate::StepResult { - Ok(NextStep::None) - } - } - - #[async_trait::async_trait] - impl Step for AdvanceNow { - async fn step(self, _db: &PgPool) -> crate::StepResult { - NextStep::now(Followup { - value: self.value + 1, - }) - } - } - 
- #[async_trait::async_trait] - impl Step for AdvanceLater { - async fn step(self, _db: &PgPool) -> crate::StepResult { - NextStep::delay( - Followup { - value: self.value + 1, - }, - Duration::from_millis(self.delay_ms), - ) - } - } - - #[async_trait::async_trait] - impl Step for Followup { - async fn step(self, _db: &PgPool) -> crate::StepResult { - NextStep::none() - } - } - - #[async_trait::async_trait] - impl Step for RetryFail { - const RETRY_LIMIT: i32 = 2; - const RETRY_DELAY: Duration = Duration::from_millis(250); - - async fn step(self, _db: &PgPool) -> crate::StepResult { - Err(io::Error::other("retryable failure").into()) - } - } - - #[async_trait::async_trait] - impl Step for BrokenNext { - async fn step(self, _db: &PgPool) -> crate::StepResult { - NextStep::now(Unserializable) - } - } - - impl serde::Serialize for Unserializable { - fn serialize(&self, _serializer: S) -> Result - where - S: serde::Serializer, - { - Err(serde::ser::Error::custom("can't serialize test step")) - } - } - - impl<'de> serde::Deserialize<'de> for Unserializable { - fn deserialize(deserializer: D) -> Result - where - D: serde::Deserializer<'de>, - { - <() as serde::Deserialize>::deserialize(deserializer)?; - Ok(Self) - } - } - - #[async_trait::async_trait] - impl Step for Unserializable { - async fn step(self, _db: &PgPool) -> crate::StepResult { - unreachable!("the test never executes the unserializable step") - } - } - - #[derive(Debug)] - struct TaskRow { - step: String, - wakeup_at: DateTime, - tried: i32, - locked_by: Option, - lock_expires_at: Option>, - error: Option, - } - - fn serialized_step(step: &TestTask) -> String { - serde_json::to_string(step).unwrap() - } - - fn task_with_step(id: Uuid, step: &TestTask, tried: i32) -> Task { - Task { - id, - step: serialized_step(step), - tried, - } - } - - fn raw_task(id: Uuid, step: &str, tried: i32) -> Task { - Task { - id, - step: step.into(), - tried, - } - } - - fn assert_database_error(err: crate::Error) { - 
assert!(matches!(err, crate::Error::Db(sqlx::Error::Database(_), _))); - } - - fn worker_id() -> Uuid { - Uuid::from_u128(1) - } - - fn other_worker_id() -> Uuid { - Uuid::from_u128(2) - } - - fn worker_lease() -> WorkerLease { - WorkerLease::new(worker_id(), Duration::from_secs(60)) - } - - #[test] - fn worker_lease_converts_timeout_to_microsecond_interval() { - let lease = WorkerLease::new(worker_id(), Duration::from_micros(42)); - assert_eq!(lease.timeout.microseconds, 42); - } - - #[test] - fn worker_lease_rounds_timeout_up_to_microseconds() { - let lease = WorkerLease::new(worker_id(), Duration::from_nanos(1)); - assert_eq!(lease.timeout.microseconds, 1); - } - - #[test] - fn worker_lease_saturates_large_timeouts() { - let lease = WorkerLease::new(worker_id(), Duration::MAX); - assert_eq!(lease.timeout.microseconds, i64::MAX); - } - - async fn insert_task_row( - pool: &PgPool, - step: &str, - wakeup_at: DateTime, - tried: i32, - is_leased: bool, - error: Option<&str>, - ) -> Uuid { - let (locked_by, lock_expires_at) = if is_leased { - ( - Some(other_worker_id()), - Some(Utc::now() + ChronoDuration::seconds(60)), - ) - } else { - (None, None) - }; - sqlx::query!( - " - INSERT INTO pg_task (step, wakeup_at, tried, locked_by, lock_expires_at, error) - VALUES ($1, $2, $3, $4, $5, $6) - RETURNING id - ", - step, - wakeup_at, - tried, - locked_by, - lock_expires_at, - error, - ) - .fetch_one(pool) - .await - .unwrap() - .id - } - - async fn insert_task(pool: &PgPool, step: &TestTask, tried: i32, is_leased: bool) -> Uuid { - let step = serialized_step(step); - insert_task_row( - pool, - &step, - Utc::now() - ChronoDuration::milliseconds(1), - tried, - is_leased, - None, - ) - .await - } - - async fn set_task_lease( - pool: &PgPool, - id: Uuid, - worker_id: Uuid, - lock_expires_at: DateTime, - ) { - sqlx::query!( - r#" - UPDATE pg_task - SET locked_by = $2, - lock_expires_at = $3 - WHERE id = $1 - "#, - id, - worker_id, - lock_expires_at, - ) - .execute(pool) - .await 
- .unwrap(); - } - - async fn claim_task( - pool: &PgPool, - step: TestTask, - tried: i32, - ) -> (Task, TestTask, WorkerLease) { - let id = insert_task(pool, &step, tried, false).await; - let mut tx = pool.begin().await.unwrap(); - let task = Task::fetch_ready(&mut tx).await.unwrap().unwrap(); - assert_eq!(task.id, id); - let lease = worker_lease(); - let claimed = task - .claim::(&mut tx, lease) - .await - .unwrap() - .unwrap(); - tx.commit().await.unwrap(); - (task, claimed, lease) - } - - async fn fetch_task_row(pool: &PgPool, id: Uuid) -> Option { - sqlx::query!( - " - SELECT step, wakeup_at, tried, locked_by, lock_expires_at, error - FROM pg_task - WHERE id = $1 - ", - id, - ) - .fetch_optional(pool) - .await - .unwrap() - .map(|row| TaskRow { - step: row.step, - wakeup_at: row.wakeup_at, - tried: row.tried, - locked_by: row.locked_by, - lock_expires_at: row.lock_expires_at, - error: row.error, - }) - } - - fn assert_timestamp_between( - actual: DateTime, - earliest: DateTime, - latest: DateTime, - ) { - assert!( - actual >= earliest, - "timestamp {actual:?} should be after {earliest:?}", - ); - assert!( - actual <= latest, - "timestamp {actual:?} should be before {latest:?}", - ); - } - - #[test] - fn delay_until_returns_none_for_ready_times() { - assert!(Task::delay_until(Utc::now() - ChronoDuration::milliseconds(1)).is_none()); - } - - #[test] - fn delay_until_returns_duration_for_future_times() { - let delay = Task::delay_until(Utc::now() + ChronoDuration::milliseconds(250)).unwrap(); - - assert!(delay <= Duration::from_millis(250)); - assert!(delay > Duration::ZERO); - } - - #[sqlx::test(migrations = "./migrations")] - async fn claim_marks_invalid_steps_errored(pool: PgPool) { - sqlx::query!( - "INSERT INTO pg_task (step, wakeup_at) VALUES ($1, $2)", - "not-json", - Utc::now(), - ) - .execute(&pool) - .await - .unwrap(); - - let mut tx = pool.begin().await.unwrap(); - let task = Task::fetch_ready(&mut tx).await.unwrap().unwrap(); - - assert!(task - 
.claim::(&mut tx, worker_lease()) - .await - .unwrap() - .is_none()); - - tx.commit().await.unwrap(); - - let row = - sqlx::query!("SELECT tried, locked_by, lock_expires_at, error FROM pg_task LIMIT 1") - .fetch_one(&pool) - .await - .unwrap(); - - assert_eq!(row.tried, 0); - assert!(row.locked_by.is_none()); - assert!(row.lock_expires_at.is_none()); - assert!(row.error.is_some()); - } - - #[test] - fn unserializable_deserializes_from_unit() { - serde_json::from_str::("null").unwrap(); - } - - #[tokio::test] - async fn followup_step_returns_none() { - let pool = sqlx::postgres::PgPoolOptions::new() - .connect_lazy("postgres:///pg_task") - .unwrap(); - - assert!(matches!( - crate::Step::::step(TestTask::Followup(Followup { value: 7 }), &pool) - .await - .unwrap(), - NextStep::None - )); - } - - #[tokio::test] - #[should_panic(expected = "the test never executes the unserializable step")] - async fn unserializable_step_panics_if_executed() { - let pool = sqlx::postgres::PgPoolOptions::new() - .connect_lazy("postgres:///pg_task") - .unwrap(); - - let _ = - crate::Step::::step(TestTask::Unserializable(Unserializable), &pool).await; - } - - #[sqlx::test(migrations = "./migrations")] - async fn fetch_ready_returns_db_errors_for_query_failures(pool: PgPool) { - insert_task(&pool, &TestTask::Valid(Valid), 0, false).await; - sqlx::query!("ALTER TABLE pg_task RENAME COLUMN step TO task_step") - .execute(&pool) - .await - .unwrap(); - - let mut tx = pool.begin().await.unwrap(); - let err = Task::fetch_ready(&mut tx).await.unwrap_err(); - - assert_database_error(err); - } - - #[sqlx::test(migrations = "./migrations")] - async fn mark_running_returns_db_errors_for_update_failures(pool: PgPool) { - let id = insert_task(&pool, &TestTask::Valid(Valid), 0, false).await; - let task = task_with_step(id, &TestTask::Valid(Valid), 0); - sqlx::query!("ALTER TABLE pg_task RENAME COLUMN locked_by TO task_locked_by") - .execute(&pool) - .await - .unwrap(); - - let mut tx = 
pool.begin().await.unwrap(); - let err = task - .mark_running(&mut tx, worker_lease()) - .await - .unwrap_err(); - - assert_database_error(err); - } - - #[sqlx::test(migrations = "./migrations")] - async fn claim_returns_db_errors_when_saving_deserialization_failures_fails(pool: PgPool) { - let id = insert_task_row(&pool, "not-json", Utc::now(), 0, false, None).await; - let task = raw_task(id, "not-json", 0); - sqlx::query!("ALTER TABLE pg_task RENAME COLUMN error TO task_error") - .execute(&pool) - .await - .unwrap(); - - let mut tx = pool.begin().await.unwrap(); - let err = task - .claim::(&mut tx, worker_lease()) - .await - .unwrap_err(); - - assert_database_error(err); - } - - #[sqlx::test(migrations = "./migrations")] - async fn claim_marks_valid_steps_leased(pool: PgPool) { - let id = insert_task(&pool, &TestTask::Valid(Valid), 0, false).await; - - let started_at = Utc::now(); - let mut tx = pool.begin().await.unwrap(); - let task = Task::fetch_ready(&mut tx).await.unwrap().unwrap(); - let claimed = task - .claim::(&mut tx, worker_lease()) - .await - .unwrap(); - tx.commit().await.unwrap(); - let finished_at = Utc::now(); - - assert!(matches!(claimed, Some(TestTask::Valid(Valid)))); - - let row = fetch_task_row(&pool, id).await.unwrap(); - assert_eq!(row.step, serialized_step(&TestTask::Valid(Valid))); - assert_eq!(row.tried, 0); - assert_eq!(row.locked_by, Some(worker_id())); - assert_timestamp_between( - row.lock_expires_at.unwrap(), - started_at + ChronoDuration::seconds(60), - finished_at + ChronoDuration::seconds(61), - ); - assert!(row.error.is_none()); - } - - #[sqlx::test(migrations = "./migrations")] - async fn task_lease_columns_must_be_set_together(pool: PgPool) { - let valid = serialized_step(&TestTask::Valid(Valid)); - let now = Utc::now(); - - let err = sqlx::query!( - " - INSERT INTO pg_task (step, wakeup_at, locked_by) - VALUES ($1, $2, $3) - ", - &valid, - now, - worker_id(), - ) - .execute(&pool) - .await - .unwrap_err(); - 
assert!(matches!(err, sqlx::Error::Database(_))); - - let err = sqlx::query!( - " - INSERT INTO pg_task (step, wakeup_at, lock_expires_at) - VALUES ($1, $2, $3) - ", - &valid, - now, - now, - ) - .execute(&pool) - .await - .unwrap_err(); - assert!(matches!(err, sqlx::Error::Database(_))); - } - - #[sqlx::test(migrations = "./migrations")] - async fn renew_leases_extends_only_live_owned_leases(pool: PgPool) { - let now = Utc::now(); - let valid = serialized_step(&TestTask::Valid(Valid)); - let owned = insert_task_row( - &pool, - &valid, - now - ChronoDuration::seconds(1), - 0, - false, - None, - ) - .await; - let owned_expires_at = now + ChronoDuration::seconds(30); - set_task_lease(&pool, owned, worker_id(), owned_expires_at).await; - let expired = insert_task_row( - &pool, - &valid, - now - ChronoDuration::seconds(1), - 0, - false, - None, - ) - .await; - let expired_expires_at = now - ChronoDuration::seconds(1); - set_task_lease(&pool, expired, worker_id(), expired_expires_at).await; - let other_worker = insert_task_row( - &pool, - &valid, - now - ChronoDuration::seconds(1), - 0, - false, - None, - ) - .await; - let other_worker_expires_at = now + ChronoDuration::seconds(30); - set_task_lease( - &pool, - other_worker, - other_worker_id(), - other_worker_expires_at, - ) - .await; - - let started_at = Utc::now(); - let renewed = Task::renew_leases(&pool, worker_lease(), &[owned, expired, other_worker]) - .await - .unwrap(); - let finished_at = Utc::now(); - - assert_eq!(renewed, vec![owned]); - let owned = fetch_task_row(&pool, owned).await.unwrap(); - assert_timestamp_between( - owned.lock_expires_at.unwrap(), - started_at + ChronoDuration::seconds(60), - finished_at + ChronoDuration::seconds(61), - ); - let expired = fetch_task_row(&pool, expired).await.unwrap(); - assert!(expired.lock_expires_at.unwrap() < started_at); - let other_worker = fetch_task_row(&pool, other_worker).await.unwrap(); - assert!(other_worker.lock_expires_at.unwrap() < started_at + 
ChronoDuration::seconds(45)); - } - - #[sqlx::test(migrations = "./migrations")] - async fn renew_leases_returns_zero_when_no_live_owned_leases_exist(pool: PgPool) { - let now = Utc::now(); - let valid = serialized_step(&TestTask::Valid(Valid)); - let expired = insert_task_row( - &pool, - &valid, - now - ChronoDuration::seconds(1), - 0, - false, - None, - ) - .await; - set_task_lease( - &pool, - expired, - worker_id(), - now - ChronoDuration::seconds(1), - ) - .await; - let other_worker = insert_task_row( - &pool, - &valid, - now - ChronoDuration::seconds(1), - 0, - false, - None, - ) - .await; - set_task_lease( - &pool, - other_worker, - other_worker_id(), - now + ChronoDuration::seconds(30), - ) - .await; - - assert!( - Task::renew_leases(&pool, worker_lease(), &[expired, other_worker]) - .await - .unwrap() - .is_empty() - ); - } - - #[sqlx::test(migrations = "./migrations")] - async fn fetch_ready_ignores_leased_errored_and_future_tasks_and_picks_the_earliest_ready_one( - pool: PgPool, - ) { - let now = Utc::now(); - let valid = serialized_step(&TestTask::Valid(Valid)); - - let live_lease = insert_task_row( - &pool, - &valid, - now - ChronoDuration::seconds(3), - 0, - true, - None, - ) - .await; - set_task_lease( - &pool, - live_lease, - other_worker_id(), - now + ChronoDuration::seconds(60), - ) - .await; - insert_task_row( - &pool, - &valid, - now - ChronoDuration::seconds(2), - 0, - false, - Some("boom"), - ) - .await; - let expected = insert_task_row( - &pool, - &valid, - now - ChronoDuration::seconds(1), - 0, - false, - None, - ) - .await; - insert_task_row( - &pool, - &valid, - now + ChronoDuration::seconds(1), - 0, - false, - None, - ) - .await; - - let mut tx = pool.begin().await.unwrap(); - let task = Task::fetch_ready(&mut tx).await.unwrap().unwrap(); - tx.commit().await.unwrap(); - - assert_eq!(task.id, expected); - assert_eq!(task.step, valid); - assert_eq!(task.tried, 0); - } - - #[sqlx::test(migrations = "./migrations")] - async fn 
fetch_ready_returns_expired_leased_tasks(pool: PgPool) { - let now = Utc::now(); - let valid = serialized_step(&TestTask::Valid(Valid)); - let expected = insert_task_row( - &pool, - &valid, - now - ChronoDuration::seconds(1), - 0, - true, - None, - ) - .await; - set_task_lease( - &pool, - expected, - other_worker_id(), - now - ChronoDuration::seconds(1), - ) - .await; - - let mut tx = pool.begin().await.unwrap(); - let task = Task::fetch_ready(&mut tx).await.unwrap().unwrap(); - tx.commit().await.unwrap(); - - assert_eq!(task.id, expected); - } - - #[sqlx::test(migrations = "./migrations")] - async fn fetch_ready_returns_none_when_no_tasks_are_ready(pool: PgPool) { - let now = Utc::now(); - let valid = serialized_step(&TestTask::Valid(Valid)); - insert_task_row( - &pool, - &valid, - now + ChronoDuration::seconds(1), - 0, - false, - None, - ) - .await; - insert_task_row( - &pool, - &valid, - now - ChronoDuration::seconds(1), - 0, - false, - Some("boom"), - ) - .await; - - let mut tx = pool.begin().await.unwrap(); - assert!(Task::fetch_ready(&mut tx).await.unwrap().is_none()); - tx.commit().await.unwrap(); - } - - #[sqlx::test(migrations = "./migrations")] - async fn fetch_next_available_at_returns_none_when_no_tasks_are_visible(pool: PgPool) { - let mut tx = pool.begin().await.unwrap(); - assert!(Task::fetch_next_available_at(&mut tx) - .await - .unwrap() - .is_none()); - tx.commit().await.unwrap(); - - insert_task_row( - &pool, - &serialized_step(&TestTask::Valid(Valid)), - Utc::now() - ChronoDuration::seconds(1), - 0, - false, - Some("boom"), - ) - .await; - - let mut tx = pool.begin().await.unwrap(); - assert!(Task::fetch_next_available_at(&mut tx) - .await - .unwrap() - .is_none()); - tx.commit().await.unwrap(); - } - - #[sqlx::test(migrations = "./migrations")] - async fn fetch_next_available_at_returns_the_earliest_visible_eligible_task(pool: PgPool) { - let now = Utc::now(); - let valid = serialized_step(&TestTask::Valid(Valid)); - let live_lease = 
insert_task_row( - &pool, - &valid, - now - ChronoDuration::seconds(3), - 0, - true, - None, - ) - .await; - set_task_lease( - &pool, - live_lease, - other_worker_id(), - now + ChronoDuration::seconds(2), - ) - .await; - insert_task_row( - &pool, - &valid, - now - ChronoDuration::seconds(2), - 0, - false, - Some("boom"), - ) - .await; - insert_task_row( - &pool, - &valid, - now + ChronoDuration::seconds(2), - 0, - false, - None, - ) - .await; - let expected = insert_task_row( - &pool, - &valid, - now + ChronoDuration::seconds(1), - 0, - false, - None, - ) - .await; - - let mut tx = pool.begin().await.unwrap(); - let wakeup_at = Task::fetch_next_available_at(&mut tx) - .await - .unwrap() - .unwrap(); - tx.commit().await.unwrap(); - - let row = fetch_task_row(&pool, expected).await.unwrap(); - assert_eq!(wakeup_at, row.wakeup_at); - } - - #[sqlx::test(migrations = "./migrations")] - async fn fetch_next_available_at_returns_lease_expiry_for_ready_leased_tasks(pool: PgPool) { - let now = Utc::now(); - let valid = serialized_step(&TestTask::Valid(Valid)); - let expected = insert_task_row( - &pool, - &valid, - now - ChronoDuration::seconds(1), - 0, - true, - None, - ) - .await; - set_task_lease( - &pool, - expected, - other_worker_id(), - now + ChronoDuration::seconds(5), - ) - .await; - - let mut tx = pool.begin().await.unwrap(); - let wakeup_at = Task::fetch_next_available_at(&mut tx) - .await - .unwrap() - .unwrap(); - tx.commit().await.unwrap(); - - let row = fetch_task_row(&pool, expected).await.unwrap(); - assert_eq!(wakeup_at, row.lock_expires_at.unwrap()); - } - - #[sqlx::test(migrations = "./migrations")] - async fn fetch_next_available_at_keeps_future_leased_tasks_delayed_until_wakeup(pool: PgPool) { - let now = Utc::now(); - let valid = serialized_step(&TestTask::Valid(Valid)); - let expected = insert_task_row( - &pool, - &valid, - now + ChronoDuration::seconds(5), - 0, - true, - None, - ) - .await; - set_task_lease( - &pool, - expected, - other_worker_id(), - 
now + ChronoDuration::seconds(1), - ) - .await; - - let mut tx = pool.begin().await.unwrap(); - let wakeup_at = Task::fetch_next_available_at(&mut tx) - .await - .unwrap() - .unwrap(); - tx.commit().await.unwrap(); - - let row = fetch_task_row(&pool, expected).await.unwrap(); - assert_eq!(wakeup_at, row.wakeup_at); - } - - #[sqlx::test(migrations = "./migrations")] - async fn run_step_completes_tasks(pool: PgPool) { - init_tracing(); - let (task, step, lease) = claim_task(&pool, TestTask::Valid(Valid), 0).await; - - task.run_step(&pool, step, lease).await.unwrap(); - - assert!(fetch_task_row(&pool, task.id).await.is_none()); - } - - #[sqlx::test(migrations = "./migrations")] - async fn run_step_does_not_complete_tasks_after_losing_the_lease(pool: PgPool) { - init_tracing(); - let (task, step, lease) = claim_task(&pool, TestTask::Valid(Valid), 0).await; - set_task_lease( - &pool, - task.id, - other_worker_id(), - Utc::now() + ChronoDuration::seconds(60), - ) - .await; - - task.run_step(&pool, step, lease).await.unwrap(); - - let row = fetch_task_row(&pool, task.id).await.unwrap(); - assert_eq!(row.locked_by, Some(other_worker_id())); - assert!(row.error.is_none()); - } - - #[sqlx::test(migrations = "./migrations")] - async fn run_step_does_not_complete_tasks_after_the_lease_expires(pool: PgPool) { - init_tracing(); - let (task, step, lease) = claim_task(&pool, TestTask::Valid(Valid), 0).await; - set_task_lease( - &pool, - task.id, - worker_id(), - Utc::now() - ChronoDuration::seconds(1), - ) - .await; - - task.run_step(&pool, step, lease).await.unwrap(); - - let row = fetch_task_row(&pool, task.id).await.unwrap(); - assert_eq!(row.locked_by, Some(worker_id())); - assert!(row.lock_expires_at.unwrap() < Utc::now()); - assert!(row.error.is_none()); - } - - #[sqlx::test(migrations = "./migrations")] - async fn run_step_returns_db_errors_when_completing_tasks_fails(pool: PgPool) { - let step = TestTask::Valid(Valid); - let id = insert_task(&pool, &step, 0, false).await; 
- let task = task_with_step(id, &step, 0); - sqlx::query!("ALTER TABLE pg_task RENAME COLUMN id TO task_id") - .execute(&pool) - .await - .unwrap(); - - let err = task - .run_step(&pool, step, worker_lease()) - .await - .unwrap_err(); - - assert_database_error(err); - } - - #[sqlx::test(migrations = "./migrations")] - async fn run_step_saves_immediate_next_step_and_resets_retries(pool: PgPool) { - let (task, step, lease) = - claim_task(&pool, TestTask::AdvanceNow(AdvanceNow { value: 41 }), 2).await; - let started_at = Utc::now(); - - task.run_step(&pool, step, lease).await.unwrap(); - - let finished_at = Utc::now(); - let row = fetch_task_row(&pool, task.id).await.unwrap(); - assert_eq!( - row.step, - serialized_step(&TestTask::Followup(Followup { value: 42 })), - ); - assert_eq!(row.tried, 0); - assert!(row.locked_by.is_none()); - assert!(row.lock_expires_at.is_none()); - assert!(row.error.is_none()); - assert_timestamp_between( - row.wakeup_at, - started_at, - finished_at + ChronoDuration::seconds(1), - ); - } - - #[sqlx::test(migrations = "./migrations")] - async fn run_step_saves_delayed_next_step_and_resets_retries(pool: PgPool) { - let delay = Duration::from_millis(250); - let (task, step, lease) = claim_task( - &pool, - TestTask::AdvanceLater(AdvanceLater { - value: 9, - delay_ms: delay.as_millis() as u64, - }), - 3, - ) - .await; - let started_at = Utc::now(); - - task.run_step(&pool, step, lease).await.unwrap(); - - let finished_at = Utc::now(); - let row = fetch_task_row(&pool, task.id).await.unwrap(); - let delay = ChronoDuration::from_std(delay).unwrap(); - assert_eq!( - row.step, - serialized_step(&TestTask::Followup(Followup { value: 10 })), - ); - assert_eq!(row.tried, 0); - assert!(row.locked_by.is_none()); - assert!(row.lock_expires_at.is_none()); - assert!(row.error.is_none()); - assert_timestamp_between( - row.wakeup_at, - started_at + delay, - finished_at + delay + ChronoDuration::seconds(1), - ); - } - - #[sqlx::test(migrations = 
"./migrations")] - async fn run_step_does_not_save_next_step_after_losing_the_lease(pool: PgPool) { - init_tracing(); - let (task, step, lease) = - claim_task(&pool, TestTask::AdvanceNow(AdvanceNow { value: 41 }), 2).await; - set_task_lease( - &pool, - task.id, - other_worker_id(), - Utc::now() + ChronoDuration::seconds(60), - ) - .await; - - task.run_step(&pool, step, lease).await.unwrap(); - - let row = fetch_task_row(&pool, task.id).await.unwrap(); - assert_eq!( - row.step, - serialized_step(&TestTask::AdvanceNow(AdvanceNow { value: 41 })), - ); - assert_eq!(row.tried, 2); - assert_eq!(row.locked_by, Some(other_worker_id())); - assert!(row.error.is_none()); - } - - #[sqlx::test(migrations = "./migrations")] - async fn run_step_returns_db_errors_when_saving_next_step_fails(pool: PgPool) { - let step = TestTask::AdvanceNow(AdvanceNow { value: 41 }); - let id = insert_task(&pool, &step, 2, false).await; - let task = task_with_step(id, &step, 2); - sqlx::query!("ALTER TABLE pg_task RENAME COLUMN step TO task_step") - .execute(&pool) - .await - .unwrap(); - - let err = task - .run_step(&pool, step, worker_lease()) - .await - .unwrap_err(); - - assert_database_error(err); - } - - #[sqlx::test(migrations = "./migrations")] - async fn run_step_schedules_retries_before_the_retry_limit(pool: PgPool) { - init_tracing(); - let retry_delay = >::RETRY_DELAY; - let (task, step, lease) = claim_task(&pool, TestTask::RetryFail(RetryFail), 1).await; - let started_at = Utc::now(); - - task.run_step(&pool, step, lease).await.unwrap(); - - let finished_at = Utc::now(); - let row = fetch_task_row(&pool, task.id).await.unwrap(); - let retry_delay = ChronoDuration::from_std(retry_delay).unwrap(); - assert_eq!(row.step, serialized_step(&TestTask::RetryFail(RetryFail))); - assert_eq!(row.tried, 2); - assert!(row.locked_by.is_none()); - assert!(row.lock_expires_at.is_none()); - assert!(row.error.is_none()); - assert_timestamp_between( - row.wakeup_at, - started_at + retry_delay, - 
finished_at + retry_delay + ChronoDuration::seconds(1), - ); - } - - #[sqlx::test(migrations = "./migrations")] - async fn run_step_does_not_schedule_retries_after_losing_the_lease(pool: PgPool) { - init_tracing(); - let (task, step, lease) = claim_task(&pool, TestTask::RetryFail(RetryFail), 1).await; - set_task_lease( - &pool, - task.id, - other_worker_id(), - Utc::now() + ChronoDuration::seconds(60), - ) - .await; - - task.run_step(&pool, step, lease).await.unwrap(); - - let row = fetch_task_row(&pool, task.id).await.unwrap(); - assert_eq!(row.step, serialized_step(&TestTask::RetryFail(RetryFail))); - assert_eq!(row.tried, 1); - assert_eq!(row.locked_by, Some(other_worker_id())); - assert!(row.error.is_none()); - } - - #[sqlx::test(migrations = "./migrations")] - async fn run_step_returns_db_errors_when_retrying_fails(pool: PgPool) { - let step = TestTask::RetryFail(RetryFail); - let id = insert_task(&pool, &step, 1, false).await; - let task = task_with_step(id, &step, 1); - sqlx::query!("ALTER TABLE pg_task RENAME COLUMN wakeup_at TO scheduled_at") - .execute(&pool) - .await - .unwrap(); - - let err = task - .run_step(&pool, step, worker_lease()) - .await - .unwrap_err(); - - assert_database_error(err); - } - - #[sqlx::test(migrations = "./migrations")] - async fn run_step_saves_terminal_errors_after_retry_limit(pool: PgPool) { - init_tracing(); - let retry_limit = >::RETRY_LIMIT; - let (task, step, lease) = - claim_task(&pool, TestTask::RetryFail(RetryFail), retry_limit).await; - let started_at = Utc::now(); - - task.run_step(&pool, step, lease).await.unwrap(); - - let finished_at = Utc::now(); - let row = fetch_task_row(&pool, task.id).await.unwrap(); - assert_eq!(row.step, serialized_step(&TestTask::RetryFail(RetryFail))); - assert_eq!(row.tried, retry_limit + 1); - assert!(row.locked_by.is_none()); - assert!(row.lock_expires_at.is_none()); - assert!(row - .error - .as_deref() - .is_some_and(|error| error.contains("retryable failure"))); - 
assert_timestamp_between( - row.wakeup_at, - started_at, - finished_at + ChronoDuration::seconds(1), - ); - } - - #[sqlx::test(migrations = "./migrations")] - async fn run_step_does_not_save_terminal_errors_after_losing_the_lease(pool: PgPool) { - init_tracing(); - let retry_limit = >::RETRY_LIMIT; - let (task, step, lease) = - claim_task(&pool, TestTask::RetryFail(RetryFail), retry_limit).await; - set_task_lease( - &pool, - task.id, - other_worker_id(), - Utc::now() + ChronoDuration::seconds(60), - ) - .await; - - task.run_step(&pool, step, lease).await.unwrap(); - - let row = fetch_task_row(&pool, task.id).await.unwrap(); - assert_eq!(row.step, serialized_step(&TestTask::RetryFail(RetryFail))); - assert_eq!(row.tried, retry_limit); - assert_eq!(row.locked_by, Some(other_worker_id())); - assert!(row.error.is_none()); - } - - #[sqlx::test(migrations = "./migrations")] - async fn run_step_returns_db_errors_when_saving_terminal_errors_fails(pool: PgPool) { - let step = TestTask::RetryFail(RetryFail); - let retry_limit = >::RETRY_LIMIT; - let id = insert_task(&pool, &step, retry_limit, false).await; - let task = task_with_step(id, &step, retry_limit); - sqlx::query!("ALTER TABLE pg_task RENAME COLUMN error TO task_error") - .execute(&pool) - .await - .unwrap(); - - let err = task - .run_step(&pool, step, worker_lease()) - .await - .unwrap_err(); - - assert_database_error(err); - } - - #[sqlx::test(migrations = "./migrations")] - async fn run_step_saves_next_step_serialization_failures_as_errors(pool: PgPool) { - let (task, step, lease) = claim_task(&pool, TestTask::BrokenNext(BrokenNext), 0).await; - - task.run_step(&pool, step, lease).await.unwrap(); - - let row = fetch_task_row(&pool, task.id).await.unwrap(); - assert_eq!(row.step, serialized_step(&TestTask::BrokenNext(BrokenNext)),); - assert_eq!(row.tried, 1); - assert!(row.locked_by.is_none()); - assert!(row.lock_expires_at.is_none()); - assert!(row - .error - .as_deref() - .is_some_and(|error| 
error.contains("can't serialize test step"))); - } - - #[sqlx::test(migrations = "./migrations")] - async fn run_step_does_not_save_next_step_serialization_errors_after_losing_the_lease( - pool: PgPool, - ) { - init_tracing(); - let (task, step, lease) = claim_task(&pool, TestTask::BrokenNext(BrokenNext), 0).await; - set_task_lease( - &pool, - task.id, - other_worker_id(), - Utc::now() + ChronoDuration::seconds(60), - ) - .await; - - task.run_step(&pool, step, lease).await.unwrap(); - - let row = fetch_task_row(&pool, task.id).await.unwrap(); - assert_eq!(row.step, serialized_step(&TestTask::BrokenNext(BrokenNext)),); - assert_eq!(row.tried, 0); - assert_eq!(row.locked_by, Some(other_worker_id())); - assert!(row.error.is_none()); - } -} +mod tests; diff --git a/src/task/tests.rs b/src/task/tests.rs new file mode 100644 index 0000000..1aba541 --- /dev/null +++ b/src/task/tests.rs @@ -0,0 +1,1210 @@ +use super::{Task, WorkerLease}; +use crate::{NextStep, Step}; +use chrono::{DateTime, Duration as ChronoDuration, Utc}; +use sqlx::PgPool; +use std::{io, time::Duration}; +use uuid::Uuid; + +fn init_tracing() { + static INIT: std::sync::Once = std::sync::Once::new(); + INIT.call_once(|| { + let _ = tracing_subscriber::fmt() + .with_max_level(tracing::Level::TRACE) + .with_test_writer() + .without_time() + .try_init(); + }); +} + +#[derive(Debug, serde::Deserialize, serde::Serialize)] +pub(super) struct Valid; + +#[derive(Debug, serde::Deserialize, serde::Serialize)] +pub(super) struct AdvanceNow { + value: i32, +} + +#[derive(Debug, serde::Deserialize, serde::Serialize)] +pub(super) struct AdvanceLater { + value: i32, + delay_ms: u64, +} + +#[derive(Debug, serde::Deserialize, serde::Serialize)] +pub(super) struct Followup { + value: i32, +} + +#[derive(Debug, serde::Deserialize, serde::Serialize)] +pub(super) struct RetryFail; + +#[derive(Debug, serde::Deserialize, serde::Serialize)] +pub(super) struct BrokenNext; + +#[derive(Debug)] +pub(super) struct Unserializable; + 
+crate::task!(TestTask { + Valid, + AdvanceNow, + AdvanceLater, + Followup, + RetryFail, + BrokenNext, + Unserializable, +}); + +#[async_trait::async_trait] +impl Step for Valid { + async fn step(self, _db: &PgPool) -> crate::StepResult { + Ok(NextStep::None) + } +} + +#[async_trait::async_trait] +impl Step for AdvanceNow { + async fn step(self, _db: &PgPool) -> crate::StepResult { + NextStep::now(Followup { + value: self.value + 1, + }) + } +} + +#[async_trait::async_trait] +impl Step for AdvanceLater { + async fn step(self, _db: &PgPool) -> crate::StepResult { + NextStep::delay( + Followup { + value: self.value + 1, + }, + Duration::from_millis(self.delay_ms), + ) + } +} + +#[async_trait::async_trait] +impl Step for Followup { + async fn step(self, _db: &PgPool) -> crate::StepResult { + NextStep::none() + } +} + +#[async_trait::async_trait] +impl Step for RetryFail { + const RETRY_LIMIT: i32 = 2; + const RETRY_DELAY: Duration = Duration::from_millis(250); + + async fn step(self, _db: &PgPool) -> crate::StepResult { + Err(io::Error::other("retryable failure").into()) + } +} + +#[async_trait::async_trait] +impl Step for BrokenNext { + async fn step(self, _db: &PgPool) -> crate::StepResult { + NextStep::now(Unserializable) + } +} + +impl serde::Serialize for Unserializable { + fn serialize(&self, _serializer: S) -> Result + where + S: serde::Serializer, + { + Err(serde::ser::Error::custom("can't serialize test step")) + } +} + +impl<'de> serde::Deserialize<'de> for Unserializable { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + <() as serde::Deserialize>::deserialize(deserializer)?; + Ok(Self) + } +} + +#[async_trait::async_trait] +impl Step for Unserializable { + async fn step(self, _db: &PgPool) -> crate::StepResult { + unreachable!("the test never executes the unserializable step") + } +} + +#[derive(Debug)] +struct TaskRow { + step: String, + wakeup_at: DateTime, + tried: i32, + locked_by: Option, + lock_expires_at: 
Option>, + error: Option, +} + +fn serialized_step(step: &TestTask) -> String { + serde_json::to_string(step).unwrap() +} + +fn task_with_step(id: Uuid, step: &TestTask, tried: i32) -> Task { + Task { + id, + step: serialized_step(step), + tried, + } +} + +fn raw_task(id: Uuid, step: &str, tried: i32) -> Task { + Task { + id, + step: step.into(), + tried, + } +} + +fn assert_database_error(err: crate::Error) { + assert!(matches!(err, crate::Error::Db(sqlx::Error::Database(_), _))); +} + +fn worker_id() -> Uuid { + Uuid::from_u128(1) +} + +fn other_worker_id() -> Uuid { + Uuid::from_u128(2) +} + +fn worker_lease() -> WorkerLease { + WorkerLease::new(worker_id(), Duration::from_secs(60)) +} + +#[test] +fn worker_lease_converts_timeout_to_microsecond_interval() { + let lease = WorkerLease::new(worker_id(), Duration::from_micros(42)); + assert_eq!(lease.timeout.microseconds, 42); +} + +#[test] +fn worker_lease_rounds_timeout_up_to_microseconds() { + let lease = WorkerLease::new(worker_id(), Duration::from_nanos(1)); + assert_eq!(lease.timeout.microseconds, 1); +} + +#[test] +fn worker_lease_saturates_large_timeouts() { + let lease = WorkerLease::new(worker_id(), Duration::MAX); + assert_eq!(lease.timeout.microseconds, i64::MAX); +} + +async fn insert_task_row( + pool: &PgPool, + step: &str, + wakeup_at: DateTime, + tried: i32, + is_leased: bool, + error: Option<&str>, +) -> Uuid { + let (locked_by, lock_expires_at) = if is_leased { + ( + Some(other_worker_id()), + Some(Utc::now() + ChronoDuration::seconds(60)), + ) + } else { + (None, None) + }; + sqlx::query!( + " + INSERT INTO pg_task (step, wakeup_at, tried, locked_by, lock_expires_at, error) + VALUES ($1, $2, $3, $4, $5, $6) + RETURNING id + ", + step, + wakeup_at, + tried, + locked_by, + lock_expires_at, + error, + ) + .fetch_one(pool) + .await + .unwrap() + .id +} + +async fn insert_task(pool: &PgPool, step: &TestTask, tried: i32, is_leased: bool) -> Uuid { + let step = serialized_step(step); + insert_task_row( + 
pool, + &step, + Utc::now() - ChronoDuration::milliseconds(1), + tried, + is_leased, + None, + ) + .await +} + +async fn set_task_lease(pool: &PgPool, id: Uuid, worker_id: Uuid, lock_expires_at: DateTime) { + sqlx::query!( + r#" + UPDATE pg_task + SET locked_by = $2, + lock_expires_at = $3 + WHERE id = $1 + "#, + id, + worker_id, + lock_expires_at, + ) + .execute(pool) + .await + .unwrap(); +} + +async fn claim_task(pool: &PgPool, step: TestTask, tried: i32) -> (Task, TestTask, WorkerLease) { + let id = insert_task(pool, &step, tried, false).await; + let mut tx = pool.begin().await.unwrap(); + let task = Task::fetch_ready(&mut tx).await.unwrap().unwrap(); + assert_eq!(task.id, id); + let lease = worker_lease(); + let claimed = task + .claim::(&mut tx, lease) + .await + .unwrap() + .unwrap(); + tx.commit().await.unwrap(); + (task, claimed, lease) +} + +async fn fetch_task_row(pool: &PgPool, id: Uuid) -> Option { + sqlx::query!( + " + SELECT step, wakeup_at, tried, locked_by, lock_expires_at, error + FROM pg_task + WHERE id = $1 + ", + id, + ) + .fetch_optional(pool) + .await + .unwrap() + .map(|row| TaskRow { + step: row.step, + wakeup_at: row.wakeup_at, + tried: row.tried, + locked_by: row.locked_by, + lock_expires_at: row.lock_expires_at, + error: row.error, + }) +} + +fn assert_timestamp_between(actual: DateTime, earliest: DateTime, latest: DateTime) { + assert!( + actual >= earliest, + "timestamp {actual:?} should be after {earliest:?}", + ); + assert!( + actual <= latest, + "timestamp {actual:?} should be before {latest:?}", + ); +} + +#[test] +fn delay_until_returns_none_for_ready_times() { + assert!(Task::delay_until(Utc::now() - ChronoDuration::milliseconds(1)).is_none()); +} + +#[test] +fn delay_until_returns_duration_for_future_times() { + let delay = Task::delay_until(Utc::now() + ChronoDuration::milliseconds(250)).unwrap(); + + assert!(delay <= Duration::from_millis(250)); + assert!(delay > Duration::ZERO); +} + +#[sqlx::test(migrations = 
"./migrations")] +async fn claim_marks_invalid_steps_errored(pool: PgPool) { + sqlx::query!( + "INSERT INTO pg_task (step, wakeup_at) VALUES ($1, $2)", + "not-json", + Utc::now(), + ) + .execute(&pool) + .await + .unwrap(); + + let mut tx = pool.begin().await.unwrap(); + let task = Task::fetch_ready(&mut tx).await.unwrap().unwrap(); + + assert!(task + .claim::(&mut tx, worker_lease()) + .await + .unwrap() + .is_none()); + + tx.commit().await.unwrap(); + + let row = sqlx::query!("SELECT tried, locked_by, lock_expires_at, error FROM pg_task LIMIT 1") + .fetch_one(&pool) + .await + .unwrap(); + + assert_eq!(row.tried, 0); + assert!(row.locked_by.is_none()); + assert!(row.lock_expires_at.is_none()); + assert!(row.error.is_some()); +} + +#[test] +fn unserializable_deserializes_from_unit() { + serde_json::from_str::("null").unwrap(); +} + +#[tokio::test] +async fn followup_step_returns_none() { + let pool = sqlx::postgres::PgPoolOptions::new() + .connect_lazy("postgres:///pg_task") + .unwrap(); + + assert!(matches!( + crate::Step::::step(TestTask::Followup(Followup { value: 7 }), &pool) + .await + .unwrap(), + NextStep::None + )); +} + +#[tokio::test] +#[should_panic(expected = "the test never executes the unserializable step")] +async fn unserializable_step_panics_if_executed() { + let pool = sqlx::postgres::PgPoolOptions::new() + .connect_lazy("postgres:///pg_task") + .unwrap(); + + let _ = crate::Step::::step(TestTask::Unserializable(Unserializable), &pool).await; +} + +#[sqlx::test(migrations = "./migrations")] +async fn fetch_ready_returns_db_errors_for_query_failures(pool: PgPool) { + insert_task(&pool, &TestTask::Valid(Valid), 0, false).await; + sqlx::query!("ALTER TABLE pg_task RENAME COLUMN step TO task_step") + .execute(&pool) + .await + .unwrap(); + + let mut tx = pool.begin().await.unwrap(); + let err = Task::fetch_ready(&mut tx).await.unwrap_err(); + + assert_database_error(err); +} + +#[sqlx::test(migrations = "./migrations")] +async fn 
mark_running_returns_db_errors_for_update_failures(pool: PgPool) { + let id = insert_task(&pool, &TestTask::Valid(Valid), 0, false).await; + let task = task_with_step(id, &TestTask::Valid(Valid), 0); + sqlx::query!("ALTER TABLE pg_task RENAME COLUMN locked_by TO task_locked_by") + .execute(&pool) + .await + .unwrap(); + + let mut tx = pool.begin().await.unwrap(); + let err = task + .mark_running(&mut tx, worker_lease()) + .await + .unwrap_err(); + + assert_database_error(err); +} + +#[sqlx::test(migrations = "./migrations")] +async fn claim_returns_db_errors_when_saving_deserialization_failures_fails(pool: PgPool) { + let id = insert_task_row(&pool, "not-json", Utc::now(), 0, false, None).await; + let task = raw_task(id, "not-json", 0); + sqlx::query!("ALTER TABLE pg_task RENAME COLUMN error TO task_error") + .execute(&pool) + .await + .unwrap(); + + let mut tx = pool.begin().await.unwrap(); + let err = task + .claim::(&mut tx, worker_lease()) + .await + .unwrap_err(); + + assert_database_error(err); +} + +#[sqlx::test(migrations = "./migrations")] +async fn claim_marks_valid_steps_leased(pool: PgPool) { + let id = insert_task(&pool, &TestTask::Valid(Valid), 0, false).await; + + let started_at = Utc::now(); + let mut tx = pool.begin().await.unwrap(); + let task = Task::fetch_ready(&mut tx).await.unwrap().unwrap(); + let claimed = task + .claim::(&mut tx, worker_lease()) + .await + .unwrap(); + tx.commit().await.unwrap(); + let finished_at = Utc::now(); + + assert!(matches!(claimed, Some(TestTask::Valid(Valid)))); + + let row = fetch_task_row(&pool, id).await.unwrap(); + assert_eq!(row.step, serialized_step(&TestTask::Valid(Valid))); + assert_eq!(row.tried, 0); + assert_eq!(row.locked_by, Some(worker_id())); + assert_timestamp_between( + row.lock_expires_at.unwrap(), + started_at + ChronoDuration::seconds(60), + finished_at + ChronoDuration::seconds(61), + ); + assert!(row.error.is_none()); +} + +#[sqlx::test(migrations = "./migrations")] +async fn 
task_lease_columns_must_be_set_together(pool: PgPool) { + let valid = serialized_step(&TestTask::Valid(Valid)); + let now = Utc::now(); + + let err = sqlx::query!( + " + INSERT INTO pg_task (step, wakeup_at, locked_by) + VALUES ($1, $2, $3) + ", + &valid, + now, + worker_id(), + ) + .execute(&pool) + .await + .unwrap_err(); + assert!(matches!(err, sqlx::Error::Database(_))); + + let err = sqlx::query!( + " + INSERT INTO pg_task (step, wakeup_at, lock_expires_at) + VALUES ($1, $2, $3) + ", + &valid, + now, + now, + ) + .execute(&pool) + .await + .unwrap_err(); + assert!(matches!(err, sqlx::Error::Database(_))); +} + +#[sqlx::test(migrations = "./migrations")] +async fn renew_leases_extends_only_live_owned_leases(pool: PgPool) { + let now = Utc::now(); + let valid = serialized_step(&TestTask::Valid(Valid)); + let owned = insert_task_row( + &pool, + &valid, + now - ChronoDuration::seconds(1), + 0, + false, + None, + ) + .await; + let owned_expires_at = now + ChronoDuration::seconds(30); + set_task_lease(&pool, owned, worker_id(), owned_expires_at).await; + let expired = insert_task_row( + &pool, + &valid, + now - ChronoDuration::seconds(1), + 0, + false, + None, + ) + .await; + let expired_expires_at = now - ChronoDuration::seconds(1); + set_task_lease(&pool, expired, worker_id(), expired_expires_at).await; + let other_worker = insert_task_row( + &pool, + &valid, + now - ChronoDuration::seconds(1), + 0, + false, + None, + ) + .await; + let other_worker_expires_at = now + ChronoDuration::seconds(30); + set_task_lease( + &pool, + other_worker, + other_worker_id(), + other_worker_expires_at, + ) + .await; + + let started_at = Utc::now(); + let renewed = Task::renew_leases(&pool, worker_lease(), &[owned, expired, other_worker]) + .await + .unwrap(); + let finished_at = Utc::now(); + + assert_eq!(renewed, vec![owned]); + let owned = fetch_task_row(&pool, owned).await.unwrap(); + assert_timestamp_between( + owned.lock_expires_at.unwrap(), + started_at + 
ChronoDuration::seconds(60), + finished_at + ChronoDuration::seconds(61), + ); + let expired = fetch_task_row(&pool, expired).await.unwrap(); + assert!(expired.lock_expires_at.unwrap() < started_at); + let other_worker = fetch_task_row(&pool, other_worker).await.unwrap(); + assert!(other_worker.lock_expires_at.unwrap() < started_at + ChronoDuration::seconds(45)); +} + +#[sqlx::test(migrations = "./migrations")] +async fn renew_leases_returns_zero_when_no_live_owned_leases_exist(pool: PgPool) { + let now = Utc::now(); + let valid = serialized_step(&TestTask::Valid(Valid)); + let expired = insert_task_row( + &pool, + &valid, + now - ChronoDuration::seconds(1), + 0, + false, + None, + ) + .await; + set_task_lease( + &pool, + expired, + worker_id(), + now - ChronoDuration::seconds(1), + ) + .await; + let other_worker = insert_task_row( + &pool, + &valid, + now - ChronoDuration::seconds(1), + 0, + false, + None, + ) + .await; + set_task_lease( + &pool, + other_worker, + other_worker_id(), + now + ChronoDuration::seconds(30), + ) + .await; + + assert!( + Task::renew_leases(&pool, worker_lease(), &[expired, other_worker]) + .await + .unwrap() + .is_empty() + ); +} + +#[sqlx::test(migrations = "./migrations")] +async fn fetch_ready_ignores_leased_errored_and_future_tasks_and_picks_the_earliest_ready_one( + pool: PgPool, +) { + let now = Utc::now(); + let valid = serialized_step(&TestTask::Valid(Valid)); + + let live_lease = insert_task_row( + &pool, + &valid, + now - ChronoDuration::seconds(3), + 0, + true, + None, + ) + .await; + set_task_lease( + &pool, + live_lease, + other_worker_id(), + now + ChronoDuration::seconds(60), + ) + .await; + insert_task_row( + &pool, + &valid, + now - ChronoDuration::seconds(2), + 0, + false, + Some("boom"), + ) + .await; + let expected = insert_task_row( + &pool, + &valid, + now - ChronoDuration::seconds(1), + 0, + false, + None, + ) + .await; + insert_task_row( + &pool, + &valid, + now + ChronoDuration::seconds(1), + 0, + false, + None, 
+ ) + .await; + + let mut tx = pool.begin().await.unwrap(); + let task = Task::fetch_ready(&mut tx).await.unwrap().unwrap(); + tx.commit().await.unwrap(); + + assert_eq!(task.id, expected); + assert_eq!(task.step, valid); + assert_eq!(task.tried, 0); +} + +#[sqlx::test(migrations = "./migrations")] +async fn fetch_ready_returns_expired_leased_tasks(pool: PgPool) { + let now = Utc::now(); + let valid = serialized_step(&TestTask::Valid(Valid)); + let expected = insert_task_row( + &pool, + &valid, + now - ChronoDuration::seconds(1), + 0, + true, + None, + ) + .await; + set_task_lease( + &pool, + expected, + other_worker_id(), + now - ChronoDuration::seconds(1), + ) + .await; + + let mut tx = pool.begin().await.unwrap(); + let task = Task::fetch_ready(&mut tx).await.unwrap().unwrap(); + tx.commit().await.unwrap(); + + assert_eq!(task.id, expected); +} + +#[sqlx::test(migrations = "./migrations")] +async fn fetch_ready_returns_none_when_no_tasks_are_ready(pool: PgPool) { + let now = Utc::now(); + let valid = serialized_step(&TestTask::Valid(Valid)); + insert_task_row( + &pool, + &valid, + now + ChronoDuration::seconds(1), + 0, + false, + None, + ) + .await; + insert_task_row( + &pool, + &valid, + now - ChronoDuration::seconds(1), + 0, + false, + Some("boom"), + ) + .await; + + let mut tx = pool.begin().await.unwrap(); + assert!(Task::fetch_ready(&mut tx).await.unwrap().is_none()); + tx.commit().await.unwrap(); +} + +#[sqlx::test(migrations = "./migrations")] +async fn fetch_next_available_at_returns_none_when_no_tasks_are_visible(pool: PgPool) { + let mut tx = pool.begin().await.unwrap(); + assert!(Task::fetch_next_available_at(&mut tx) + .await + .unwrap() + .is_none()); + tx.commit().await.unwrap(); + + insert_task_row( + &pool, + &serialized_step(&TestTask::Valid(Valid)), + Utc::now() - ChronoDuration::seconds(1), + 0, + false, + Some("boom"), + ) + .await; + + let mut tx = pool.begin().await.unwrap(); + assert!(Task::fetch_next_available_at(&mut tx) + .await + 
.unwrap() + .is_none()); + tx.commit().await.unwrap(); +} + +#[sqlx::test(migrations = "./migrations")] +async fn fetch_next_available_at_returns_the_earliest_visible_eligible_task(pool: PgPool) { + let now = Utc::now(); + let valid = serialized_step(&TestTask::Valid(Valid)); + let live_lease = insert_task_row( + &pool, + &valid, + now - ChronoDuration::seconds(3), + 0, + true, + None, + ) + .await; + set_task_lease( + &pool, + live_lease, + other_worker_id(), + now + ChronoDuration::seconds(2), + ) + .await; + insert_task_row( + &pool, + &valid, + now - ChronoDuration::seconds(2), + 0, + false, + Some("boom"), + ) + .await; + insert_task_row( + &pool, + &valid, + now + ChronoDuration::seconds(2), + 0, + false, + None, + ) + .await; + let expected = insert_task_row( + &pool, + &valid, + now + ChronoDuration::seconds(1), + 0, + false, + None, + ) + .await; + + let mut tx = pool.begin().await.unwrap(); + let wakeup_at = Task::fetch_next_available_at(&mut tx) + .await + .unwrap() + .unwrap(); + tx.commit().await.unwrap(); + + let row = fetch_task_row(&pool, expected).await.unwrap(); + assert_eq!(wakeup_at, row.wakeup_at); +} + +#[sqlx::test(migrations = "./migrations")] +async fn fetch_next_available_at_returns_lease_expiry_for_ready_leased_tasks(pool: PgPool) { + let now = Utc::now(); + let valid = serialized_step(&TestTask::Valid(Valid)); + let expected = insert_task_row( + &pool, + &valid, + now - ChronoDuration::seconds(1), + 0, + true, + None, + ) + .await; + set_task_lease( + &pool, + expected, + other_worker_id(), + now + ChronoDuration::seconds(5), + ) + .await; + + let mut tx = pool.begin().await.unwrap(); + let wakeup_at = Task::fetch_next_available_at(&mut tx) + .await + .unwrap() + .unwrap(); + tx.commit().await.unwrap(); + + let row = fetch_task_row(&pool, expected).await.unwrap(); + assert_eq!(wakeup_at, row.lock_expires_at.unwrap()); +} + +#[sqlx::test(migrations = "./migrations")] +async fn 
fetch_next_available_at_keeps_future_leased_tasks_delayed_until_wakeup(pool: PgPool) { + let now = Utc::now(); + let valid = serialized_step(&TestTask::Valid(Valid)); + let expected = insert_task_row( + &pool, + &valid, + now + ChronoDuration::seconds(5), + 0, + true, + None, + ) + .await; + set_task_lease( + &pool, + expected, + other_worker_id(), + now + ChronoDuration::seconds(1), + ) + .await; + + let mut tx = pool.begin().await.unwrap(); + let wakeup_at = Task::fetch_next_available_at(&mut tx) + .await + .unwrap() + .unwrap(); + tx.commit().await.unwrap(); + + let row = fetch_task_row(&pool, expected).await.unwrap(); + assert_eq!(wakeup_at, row.wakeup_at); +} + +#[sqlx::test(migrations = "./migrations")] +async fn run_step_completes_tasks(pool: PgPool) { + init_tracing(); + let (task, step, lease) = claim_task(&pool, TestTask::Valid(Valid), 0).await; + + task.run_step(&pool, step, lease).await.unwrap(); + + assert!(fetch_task_row(&pool, task.id).await.is_none()); +} + +#[sqlx::test(migrations = "./migrations")] +async fn run_step_does_not_complete_tasks_after_losing_the_lease(pool: PgPool) { + init_tracing(); + let (task, step, lease) = claim_task(&pool, TestTask::Valid(Valid), 0).await; + set_task_lease( + &pool, + task.id, + other_worker_id(), + Utc::now() + ChronoDuration::seconds(60), + ) + .await; + + task.run_step(&pool, step, lease).await.unwrap(); + + let row = fetch_task_row(&pool, task.id).await.unwrap(); + assert_eq!(row.locked_by, Some(other_worker_id())); + assert!(row.error.is_none()); +} + +#[sqlx::test(migrations = "./migrations")] +async fn run_step_does_not_complete_tasks_after_the_lease_expires(pool: PgPool) { + init_tracing(); + let (task, step, lease) = claim_task(&pool, TestTask::Valid(Valid), 0).await; + set_task_lease( + &pool, + task.id, + worker_id(), + Utc::now() - ChronoDuration::seconds(1), + ) + .await; + + task.run_step(&pool, step, lease).await.unwrap(); + + let row = fetch_task_row(&pool, task.id).await.unwrap(); + 
assert_eq!(row.locked_by, Some(worker_id())); + assert!(row.lock_expires_at.unwrap() < Utc::now()); + assert!(row.error.is_none()); +} + +#[sqlx::test(migrations = "./migrations")] +async fn run_step_returns_db_errors_when_completing_tasks_fails(pool: PgPool) { + let step = TestTask::Valid(Valid); + let id = insert_task(&pool, &step, 0, false).await; + let task = task_with_step(id, &step, 0); + sqlx::query!("ALTER TABLE pg_task RENAME COLUMN id TO task_id") + .execute(&pool) + .await + .unwrap(); + + let err = task + .run_step(&pool, step, worker_lease()) + .await + .unwrap_err(); + + assert_database_error(err); +} + +#[sqlx::test(migrations = "./migrations")] +async fn run_step_saves_immediate_next_step_and_resets_retries(pool: PgPool) { + let (task, step, lease) = + claim_task(&pool, TestTask::AdvanceNow(AdvanceNow { value: 41 }), 2).await; + let started_at = Utc::now(); + + task.run_step(&pool, step, lease).await.unwrap(); + + let finished_at = Utc::now(); + let row = fetch_task_row(&pool, task.id).await.unwrap(); + assert_eq!( + row.step, + serialized_step(&TestTask::Followup(Followup { value: 42 })), + ); + assert_eq!(row.tried, 0); + assert!(row.locked_by.is_none()); + assert!(row.lock_expires_at.is_none()); + assert!(row.error.is_none()); + assert_timestamp_between( + row.wakeup_at, + started_at, + finished_at + ChronoDuration::seconds(1), + ); +} + +#[sqlx::test(migrations = "./migrations")] +async fn run_step_saves_delayed_next_step_and_resets_retries(pool: PgPool) { + let delay = Duration::from_millis(250); + let (task, step, lease) = claim_task( + &pool, + TestTask::AdvanceLater(AdvanceLater { + value: 9, + delay_ms: delay.as_millis() as u64, + }), + 3, + ) + .await; + let started_at = Utc::now(); + + task.run_step(&pool, step, lease).await.unwrap(); + + let finished_at = Utc::now(); + let row = fetch_task_row(&pool, task.id).await.unwrap(); + let delay = ChronoDuration::from_std(delay).unwrap(); + assert_eq!( + row.step, + 
serialized_step(&TestTask::Followup(Followup { value: 10 })), + ); + assert_eq!(row.tried, 0); + assert!(row.locked_by.is_none()); + assert!(row.lock_expires_at.is_none()); + assert!(row.error.is_none()); + assert_timestamp_between( + row.wakeup_at, + started_at + delay, + finished_at + delay + ChronoDuration::seconds(1), + ); +} + +#[sqlx::test(migrations = "./migrations")] +async fn run_step_does_not_save_next_step_after_losing_the_lease(pool: PgPool) { + init_tracing(); + let (task, step, lease) = + claim_task(&pool, TestTask::AdvanceNow(AdvanceNow { value: 41 }), 2).await; + set_task_lease( + &pool, + task.id, + other_worker_id(), + Utc::now() + ChronoDuration::seconds(60), + ) + .await; + + task.run_step(&pool, step, lease).await.unwrap(); + + let row = fetch_task_row(&pool, task.id).await.unwrap(); + assert_eq!( + row.step, + serialized_step(&TestTask::AdvanceNow(AdvanceNow { value: 41 })), + ); + assert_eq!(row.tried, 2); + assert_eq!(row.locked_by, Some(other_worker_id())); + assert!(row.error.is_none()); +} + +#[sqlx::test(migrations = "./migrations")] +async fn run_step_returns_db_errors_when_saving_next_step_fails(pool: PgPool) { + let step = TestTask::AdvanceNow(AdvanceNow { value: 41 }); + let id = insert_task(&pool, &step, 2, false).await; + let task = task_with_step(id, &step, 2); + sqlx::query!("ALTER TABLE pg_task RENAME COLUMN step TO task_step") + .execute(&pool) + .await + .unwrap(); + + let err = task + .run_step(&pool, step, worker_lease()) + .await + .unwrap_err(); + + assert_database_error(err); +} + +#[sqlx::test(migrations = "./migrations")] +async fn run_step_schedules_retries_before_the_retry_limit(pool: PgPool) { + init_tracing(); + let retry_delay = >::RETRY_DELAY; + let (task, step, lease) = claim_task(&pool, TestTask::RetryFail(RetryFail), 1).await; + let started_at = Utc::now(); + + task.run_step(&pool, step, lease).await.unwrap(); + + let finished_at = Utc::now(); + let row = fetch_task_row(&pool, task.id).await.unwrap(); + let 
retry_delay = ChronoDuration::from_std(retry_delay).unwrap(); + assert_eq!(row.step, serialized_step(&TestTask::RetryFail(RetryFail))); + assert_eq!(row.tried, 2); + assert!(row.locked_by.is_none()); + assert!(row.lock_expires_at.is_none()); + assert!(row.error.is_none()); + assert_timestamp_between( + row.wakeup_at, + started_at + retry_delay, + finished_at + retry_delay + ChronoDuration::seconds(1), + ); +} + +#[sqlx::test(migrations = "./migrations")] +async fn run_step_does_not_schedule_retries_after_losing_the_lease(pool: PgPool) { + init_tracing(); + let (task, step, lease) = claim_task(&pool, TestTask::RetryFail(RetryFail), 1).await; + set_task_lease( + &pool, + task.id, + other_worker_id(), + Utc::now() + ChronoDuration::seconds(60), + ) + .await; + + task.run_step(&pool, step, lease).await.unwrap(); + + let row = fetch_task_row(&pool, task.id).await.unwrap(); + assert_eq!(row.step, serialized_step(&TestTask::RetryFail(RetryFail))); + assert_eq!(row.tried, 1); + assert_eq!(row.locked_by, Some(other_worker_id())); + assert!(row.error.is_none()); +} + +#[sqlx::test(migrations = "./migrations")] +async fn run_step_returns_db_errors_when_retrying_fails(pool: PgPool) { + let step = TestTask::RetryFail(RetryFail); + let id = insert_task(&pool, &step, 1, false).await; + let task = task_with_step(id, &step, 1); + sqlx::query!("ALTER TABLE pg_task RENAME COLUMN wakeup_at TO scheduled_at") + .execute(&pool) + .await + .unwrap(); + + let err = task + .run_step(&pool, step, worker_lease()) + .await + .unwrap_err(); + + assert_database_error(err); +} + +#[sqlx::test(migrations = "./migrations")] +async fn run_step_saves_terminal_errors_after_retry_limit(pool: PgPool) { + init_tracing(); + let retry_limit = >::RETRY_LIMIT; + let (task, step, lease) = claim_task(&pool, TestTask::RetryFail(RetryFail), retry_limit).await; + let started_at = Utc::now(); + + task.run_step(&pool, step, lease).await.unwrap(); + + let finished_at = Utc::now(); + let row = fetch_task_row(&pool, 
task.id).await.unwrap(); + assert_eq!(row.step, serialized_step(&TestTask::RetryFail(RetryFail))); + assert_eq!(row.tried, retry_limit + 1); + assert!(row.locked_by.is_none()); + assert!(row.lock_expires_at.is_none()); + assert!(row + .error + .as_deref() + .is_some_and(|error| error.contains("retryable failure"))); + assert_timestamp_between( + row.wakeup_at, + started_at, + finished_at + ChronoDuration::seconds(1), + ); +} + +#[sqlx::test(migrations = "./migrations")] +async fn run_step_does_not_save_terminal_errors_after_losing_the_lease(pool: PgPool) { + init_tracing(); + let retry_limit = >::RETRY_LIMIT; + let (task, step, lease) = claim_task(&pool, TestTask::RetryFail(RetryFail), retry_limit).await; + set_task_lease( + &pool, + task.id, + other_worker_id(), + Utc::now() + ChronoDuration::seconds(60), + ) + .await; + + task.run_step(&pool, step, lease).await.unwrap(); + + let row = fetch_task_row(&pool, task.id).await.unwrap(); + assert_eq!(row.step, serialized_step(&TestTask::RetryFail(RetryFail))); + assert_eq!(row.tried, retry_limit); + assert_eq!(row.locked_by, Some(other_worker_id())); + assert!(row.error.is_none()); +} + +#[sqlx::test(migrations = "./migrations")] +async fn run_step_returns_db_errors_when_saving_terminal_errors_fails(pool: PgPool) { + let step = TestTask::RetryFail(RetryFail); + let retry_limit = >::RETRY_LIMIT; + let id = insert_task(&pool, &step, retry_limit, false).await; + let task = task_with_step(id, &step, retry_limit); + sqlx::query!("ALTER TABLE pg_task RENAME COLUMN error TO task_error") + .execute(&pool) + .await + .unwrap(); + + let err = task + .run_step(&pool, step, worker_lease()) + .await + .unwrap_err(); + + assert_database_error(err); +} + +#[sqlx::test(migrations = "./migrations")] +async fn run_step_saves_next_step_serialization_failures_as_errors(pool: PgPool) { + let (task, step, lease) = claim_task(&pool, TestTask::BrokenNext(BrokenNext), 0).await; + + task.run_step(&pool, step, lease).await.unwrap(); + + let row = 
fetch_task_row(&pool, task.id).await.unwrap(); + assert_eq!(row.step, serialized_step(&TestTask::BrokenNext(BrokenNext)),); + assert_eq!(row.tried, 1); + assert!(row.locked_by.is_none()); + assert!(row.lock_expires_at.is_none()); + assert!(row + .error + .as_deref() + .is_some_and(|error| error.contains("can't serialize test step"))); +} + +#[sqlx::test(migrations = "./migrations")] +async fn run_step_does_not_save_next_step_serialization_errors_after_losing_the_lease( + pool: PgPool, +) { + init_tracing(); + let (task, step, lease) = claim_task(&pool, TestTask::BrokenNext(BrokenNext), 0).await; + set_task_lease( + &pool, + task.id, + other_worker_id(), + Utc::now() + ChronoDuration::seconds(60), + ) + .await; + + task.run_step(&pool, step, lease).await.unwrap(); + + let row = fetch_task_row(&pool, task.id).await.unwrap(); + assert_eq!(row.step, serialized_step(&TestTask::BrokenNext(BrokenNext)),); + assert_eq!(row.tried, 0); + assert_eq!(row.locked_by, Some(other_worker_id())); + assert!(row.error.is_none()); +} From 4015df5f105316ea103ab5506c44f041eac6b70b Mon Sep 17 00:00:00 2001 From: imbolc Date: Fri, 15 May 2026 15:13:28 +0600 Subject: [PATCH 42/44] Move listener tests into child module --- src/listener.rs | 462 +----------------------------------------- src/listener/tests.rs | 454 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 455 insertions(+), 461 deletions(-) create mode 100644 src/listener/tests.rs diff --git a/src/listener.rs b/src/listener.rs index 0e68c07..2dfe8ee 100644 --- a/src/listener.rs +++ b/src/listener.rs @@ -245,464 +245,4 @@ impl<'a> Subscription<'a> { } #[cfg(test)] -mod tests { - use super::Listener; - use crate::Error; - use chrono::{DateTime, Utc}; - use parking_lot::Mutex; - use sqlx::{postgres::PgPoolOptions, types::Uuid, PgPool}; - use std::{future::pending, io, time::Duration}; - use tokio::{ - sync::Notify, - time::{sleep, timeout}, - }; - - fn connection_error() -> sqlx::Error { - sqlx::Error::Io(io::Error::new( - 
io::ErrorKind::BrokenPipe, - "listener connection dropped", - )) - } - - fn init_tracing() { - static INIT: std::sync::Once = std::sync::Once::new(); - INIT.call_once(|| { - let _ = tracing_subscriber::fmt() - .with_max_level(tracing::Level::TRACE) - .with_test_writer() - .without_time() - .try_init(); - }); - } - - #[tokio::test] - async fn listen_returns_connect_errors_for_unavailable_databases() { - let listener = Listener::new(); - let db = PgPoolOptions::new() - .connect_lazy(&format!( - "postgres:///pg_task_missing_{}", - Utc::now().timestamp_micros() - )) - .unwrap(); - - let err = listener.listen(db).await.unwrap_err(); - - assert!(matches!(err, Error::ListenerConnect(_))); - } - - #[sqlx::test(migrations = "./migrations")] - async fn listen_returns_listener_listen_errors_when_subscribing_fails(pool: PgPool) { - let listener = Listener::new(); - listener.fail_next_listen_for_tests(); - - let err = listener.listen(pool).await.unwrap_err(); - - assert!(matches!( - err, - Error::ListenerListen(sqlx::Error::Protocol(_)) - )); - } - - #[tokio::test] - async fn terminal_errors_wake_future_subscribers() { - let listener = Listener::new(); - listener.set_error_and_notify_for_tests(Error::ListenerReceive(sqlx::Error::Protocol( - "listener failed".into(), - ))); - - timeout( - Duration::from_millis(50), - listener.subscribe().wait_forever(), - ) - .await - .unwrap(); - - assert!(matches!( - listener.take_error(), - Some(Error::ListenerReceive(sqlx::Error::Protocol(_))) - )); - } - - #[tokio::test] - async fn take_error_clears_stored_errors() { - let listener = Listener::new(); - listener.set_error_for_tests(Error::ListenerReceive(sqlx::Error::Protocol( - "listener failed".into(), - ))); - - assert!(matches!( - listener.take_error(), - Some(Error::ListenerReceive(sqlx::Error::Protocol(_))) - )); - assert!(listener.take_error().is_none()); - } - - #[tokio::test] - async fn wait_for_returns_when_a_wakeup_arrives_before_the_timeout() { - let listener = Listener::new(); - 
let subscription = listener.subscribe(); - - listener.stop_worker_for_tests(); - - timeout( - Duration::from_millis(50), - subscription.wait_for(Duration::from_secs(1)), - ) - .await - .unwrap(); - assert!(listener.time_to_stop_worker()); - } - - #[tokio::test] - async fn wait_for_returns_when_the_timeout_expires_without_a_wakeup() { - let listener = Listener::new(); - - timeout( - Duration::from_millis(100), - listener.subscribe().wait_for(Duration::from_millis(10)), - ) - .await - .unwrap(); - assert!(!listener.time_to_stop_worker()); - assert!(listener.take_error().is_none()); - } - - #[tokio::test] - async fn pool_timeouts_do_not_become_terminal_listener_errors() { - init_tracing(); - let error_slot = Mutex::new(None); - let notify = Notify::new(); - let db = PgPoolOptions::new() - .connect_lazy("postgres:///pg_task") - .unwrap(); - - assert!( - !Listener::handle_recv_error(&error_slot, ¬ify, &db, sqlx::Error::PoolTimedOut) - .await - ); - assert!(error_slot.lock().is_none()); - } - - #[tokio::test] - async fn terminal_recv_errors_are_stored_and_notify_waiters() { - init_tracing(); - let error_slot = Mutex::new(None); - let notify = Notify::new(); - let subscription = notify.notified(); - let db = PgPoolOptions::new() - .connect_lazy("postgres:///pg_task") - .unwrap(); - - assert!( - Listener::handle_recv_error( - &error_slot, - ¬ify, - &db, - sqlx::Error::Protocol("listener failed".into()), - ) - .await - ); - timeout(Duration::from_millis(50), subscription) - .await - .unwrap(); - - assert!(matches!( - error_slot.lock().take(), - Some(Error::ListenerReceive(sqlx::Error::Protocol(_))) - )); - } - - #[sqlx::test(migrations = "./migrations")] - async fn connection_errors_resume_listening_after_the_database_recovers(pool: PgPool) { - init_tracing(); - let error_slot = Mutex::new(None); - let notify = Notify::new(); - let subscription = notify.notified(); - - assert!( - !Listener::handle_recv_error(&error_slot, ¬ify, &pool, connection_error(),).await - ); - 
assert!(timeout(Duration::from_millis(50), subscription) - .await - .is_err()); - assert!(error_slot.lock().is_none()); - } - - #[sqlx::test(migrations = "./migrations")] - async fn connection_errors_become_terminal_when_reconnection_fails(pool: PgPool) { - init_tracing(); - let error_slot = Mutex::new(None); - let notify = Notify::new(); - let subscription = notify.notified(); - sqlx::query!("ALTER TABLE pg_task RENAME COLUMN id TO task_id") - .execute(&pool) - .await - .unwrap(); - - assert!( - Listener::handle_recv_error(&error_slot, ¬ify, &pool, connection_error(),).await - ); - timeout(Duration::from_millis(50), subscription) - .await - .unwrap(); - assert!(matches!( - error_slot.lock().take(), - Some(Error::Db(sqlx::Error::Database(_), _)) - )); - } - - #[tokio::test] - async fn dropping_listener_aborts_background_task() { - let listener = Listener::new(); - let task = tokio::spawn(pending::<()>()); - listener.set_task_for_tests(task.abort_handle()); - - drop(listener); - - let error = timeout(Duration::from_millis(50), task) - .await - .unwrap() - .unwrap_err(); - assert!(error.is_cancelled()); - } - - #[tokio::test] - async fn replacing_listener_task_aborts_the_previous_background_task() { - let listener = Listener::new(); - let first_task = tokio::spawn(pending::<()>()); - let second_task = tokio::spawn(pending::<()>()); - listener.set_task_for_tests(first_task.abort_handle()); - listener.set_task_for_tests(second_task.abort_handle()); - - let first_error = timeout(Duration::from_millis(50), first_task) - .await - .unwrap() - .unwrap_err(); - assert!(first_error.is_cancelled()); - - drop(listener); - - let second_error = timeout(Duration::from_millis(50), second_task) - .await - .unwrap() - .unwrap_err(); - assert!(second_error.is_cancelled()); - } - - #[tokio::test] - async fn shutdown_aborts_the_background_task() { - let listener = Listener::new(); - let task = tokio::spawn(pending::<()>()); - listener.set_task_for_tests(task.abort_handle()); - - 
listener.shutdown(); - - let error = timeout(Duration::from_millis(50), task) - .await - .unwrap() - .unwrap_err(); - assert!(error.is_cancelled()); - } - - #[sqlx::test(migrations = "./migrations")] - async fn listen_wakes_subscribers_for_task_inserts(pool: PgPool) { - let listener = Listener::new(); - listener.listen(pool.clone()).await.unwrap(); - - let subscription = listener.subscribe(); - sqlx::query!( - "INSERT INTO pg_task (step, wakeup_at) VALUES ($1, $2)", - "{}", - Utc::now(), - ) - .execute(&pool) - .await - .unwrap(); - - timeout(Duration::from_secs(1), subscription.wait_forever()) - .await - .unwrap(); - - assert!(!listener.time_to_stop_worker()); - assert!(listener.take_error().is_none()); - } - - #[sqlx::test(migrations = "./migrations")] - async fn listen_wakes_subscribers_for_non_stop_notifications(pool: PgPool) { - let listener = Listener::new(); - listener.listen(pool.clone()).await.unwrap(); - - let subscription = listener.subscribe(); - sqlx::query!("NOTIFY pg_task_changed, 'wake'") - .execute(&pool) - .await - .unwrap(); - - timeout(Duration::from_secs(1), subscription.wait_forever()) - .await - .unwrap(); - - assert!(!listener.time_to_stop_worker()); - assert!(listener.take_error().is_none()); - } - - #[sqlx::test(migrations = "./migrations")] - async fn listen_wakes_subscribers_for_task_inserts_and_updates(pool: PgPool) { - let listener = Listener::new(); - listener.listen(pool.clone()).await.unwrap(); - let insert_subscription = listener.subscribe(); - let id = sqlx::query!( - "INSERT INTO pg_task (step, wakeup_at) VALUES ($1, $2) RETURNING id", - "{}", - Utc::now(), - ) - .fetch_one(&pool) - .await - .unwrap() - .id; - timeout(Duration::from_secs(1), insert_subscription.wait_forever()) - .await - .unwrap(); - - let update_subscription = listener.subscribe(); - sqlx::query!("UPDATE pg_task SET error = $2 WHERE id = $1", id, "boom",) - .execute(&pool) - .await - .unwrap(); - timeout(Duration::from_secs(1), 
update_subscription.wait_forever()) - .await - .unwrap(); - - assert!(!listener.time_to_stop_worker()); - assert!(listener.take_error().is_none()); - } - - #[sqlx::test(migrations = "./migrations")] - async fn stop_worker_notifications_wake_future_subscribers(pool: PgPool) { - let listener = Listener::new(); - listener.listen(pool.clone()).await.unwrap(); - - sqlx::query!("NOTIFY pg_task_changed, 'stop_worker'") - .execute(&pool) - .await - .unwrap(); - - timeout(Duration::from_secs(1), async { - loop { - if listener.time_to_stop_worker() { - return; - } - sleep(Duration::from_millis(10)).await; - } - }) - .await - .unwrap(); - - timeout( - Duration::from_millis(50), - listener.subscribe().wait_forever(), - ) - .await - .unwrap(); - - assert!(listener.take_error().is_none()); - } - - #[sqlx::test(migrations = "./migrations")] - async fn closing_the_pool_surfaces_listener_errors_to_subscribers(pool: PgPool) { - let listener = Listener::new(); - listener.listen(pool.clone()).await.unwrap(); - - let subscription = listener.subscribe(); - let close_pool = tokio::spawn({ - let pool = pool.clone(); - async move { - pool.close().await; - } - }); - - timeout(Duration::from_secs(1), subscription.wait_forever()) - .await - .unwrap(); - close_pool.await.unwrap(); - - assert!(matches!( - listener.take_error(), - Some(Error::ListenerReceive(sqlx::Error::PoolClosed)) - )); - } - - #[sqlx::test(migrations = "./migrations")] - async fn closing_the_pool_retains_terminal_wakeups_for_future_subscribers(pool: PgPool) { - let listener = Listener::new(); - listener.listen(pool.clone()).await.unwrap(); - - let close_pool = tokio::spawn({ - let pool = pool.clone(); - async move { - pool.close().await; - } - }); - - timeout(Duration::from_secs(1), async { - loop { - if listener.error.lock().is_some() { - return; - } - sleep(Duration::from_millis(10)).await; - } - }) - .await - .unwrap(); - - timeout( - Duration::from_millis(50), - listener.subscribe().wait_forever(), - ) - .await - 
.unwrap(); - close_pool.await.unwrap(); - - assert!(matches!( - listener.take_error(), - Some(Error::ListenerReceive(sqlx::Error::PoolClosed)) - )); - } - - #[sqlx::test(migrations = "./migrations")] - async fn updating_tasks_refreshes_updated_at(pool: PgPool) { - let row = sqlx::query!( - " - INSERT INTO pg_task (step, wakeup_at) - VALUES ($1, $2) - RETURNING id, updated_at - ", - "{}", - Utc::now(), - ) - .fetch_one(&pool) - .await - .unwrap(); - let id: Uuid = row.id; - let initial_updated_at: DateTime = row.updated_at; - - sleep(Duration::from_millis(20)).await; - - let next_updated_at: DateTime = sqlx::query!( - " - UPDATE pg_task - SET error = $2 - WHERE id = $1 - RETURNING updated_at - ", - id, - "boom", - ) - .fetch_one(&pool) - .await - .unwrap() - .updated_at; - - assert!(next_updated_at > initial_updated_at); - } -} +mod tests; diff --git a/src/listener/tests.rs b/src/listener/tests.rs new file mode 100644 index 0000000..975a86f --- /dev/null +++ b/src/listener/tests.rs @@ -0,0 +1,454 @@ +use super::Listener; +use crate::Error; +use chrono::{DateTime, Utc}; +use parking_lot::Mutex; +use sqlx::{postgres::PgPoolOptions, types::Uuid, PgPool}; +use std::{future::pending, io, time::Duration}; +use tokio::{ + sync::Notify, + time::{sleep, timeout}, +}; + +fn connection_error() -> sqlx::Error { + sqlx::Error::Io(io::Error::new( + io::ErrorKind::BrokenPipe, + "listener connection dropped", + )) +} + +fn init_tracing() { + static INIT: std::sync::Once = std::sync::Once::new(); + INIT.call_once(|| { + let _ = tracing_subscriber::fmt() + .with_max_level(tracing::Level::TRACE) + .with_test_writer() + .without_time() + .try_init(); + }); +} + +#[tokio::test] +async fn listen_returns_connect_errors_for_unavailable_databases() { + let listener = Listener::new(); + let db = PgPoolOptions::new() + .connect_lazy(&format!( + "postgres:///pg_task_missing_{}", + Utc::now().timestamp_micros() + )) + .unwrap(); + + let err = listener.listen(db).await.unwrap_err(); + + 
assert!(matches!(err, Error::ListenerConnect(_))); +} + +#[sqlx::test(migrations = "./migrations")] +async fn listen_returns_listener_listen_errors_when_subscribing_fails(pool: PgPool) { + let listener = Listener::new(); + listener.fail_next_listen_for_tests(); + + let err = listener.listen(pool).await.unwrap_err(); + + assert!(matches!( + err, + Error::ListenerListen(sqlx::Error::Protocol(_)) + )); +} + +#[tokio::test] +async fn terminal_errors_wake_future_subscribers() { + let listener = Listener::new(); + listener.set_error_and_notify_for_tests(Error::ListenerReceive(sqlx::Error::Protocol( + "listener failed".into(), + ))); + + timeout( + Duration::from_millis(50), + listener.subscribe().wait_forever(), + ) + .await + .unwrap(); + + assert!(matches!( + listener.take_error(), + Some(Error::ListenerReceive(sqlx::Error::Protocol(_))) + )); +} + +#[tokio::test] +async fn take_error_clears_stored_errors() { + let listener = Listener::new(); + listener.set_error_for_tests(Error::ListenerReceive(sqlx::Error::Protocol( + "listener failed".into(), + ))); + + assert!(matches!( + listener.take_error(), + Some(Error::ListenerReceive(sqlx::Error::Protocol(_))) + )); + assert!(listener.take_error().is_none()); +} + +#[tokio::test] +async fn wait_for_returns_when_a_wakeup_arrives_before_the_timeout() { + let listener = Listener::new(); + let subscription = listener.subscribe(); + + listener.stop_worker_for_tests(); + + timeout( + Duration::from_millis(50), + subscription.wait_for(Duration::from_secs(1)), + ) + .await + .unwrap(); + assert!(listener.time_to_stop_worker()); +} + +#[tokio::test] +async fn wait_for_returns_when_the_timeout_expires_without_a_wakeup() { + let listener = Listener::new(); + + timeout( + Duration::from_millis(100), + listener.subscribe().wait_for(Duration::from_millis(10)), + ) + .await + .unwrap(); + assert!(!listener.time_to_stop_worker()); + assert!(listener.take_error().is_none()); +} + +#[tokio::test] +async fn 
pool_timeouts_do_not_become_terminal_listener_errors() { + init_tracing(); + let error_slot = Mutex::new(None); + let notify = Notify::new(); + let db = PgPoolOptions::new() + .connect_lazy("postgres:///pg_task") + .unwrap(); + + assert!( + !Listener::handle_recv_error(&error_slot, ¬ify, &db, sqlx::Error::PoolTimedOut).await + ); + assert!(error_slot.lock().is_none()); +} + +#[tokio::test] +async fn terminal_recv_errors_are_stored_and_notify_waiters() { + init_tracing(); + let error_slot = Mutex::new(None); + let notify = Notify::new(); + let subscription = notify.notified(); + let db = PgPoolOptions::new() + .connect_lazy("postgres:///pg_task") + .unwrap(); + + assert!( + Listener::handle_recv_error( + &error_slot, + ¬ify, + &db, + sqlx::Error::Protocol("listener failed".into()), + ) + .await + ); + timeout(Duration::from_millis(50), subscription) + .await + .unwrap(); + + assert!(matches!( + error_slot.lock().take(), + Some(Error::ListenerReceive(sqlx::Error::Protocol(_))) + )); +} + +#[sqlx::test(migrations = "./migrations")] +async fn connection_errors_resume_listening_after_the_database_recovers(pool: PgPool) { + init_tracing(); + let error_slot = Mutex::new(None); + let notify = Notify::new(); + let subscription = notify.notified(); + + assert!(!Listener::handle_recv_error(&error_slot, ¬ify, &pool, connection_error(),).await); + assert!(timeout(Duration::from_millis(50), subscription) + .await + .is_err()); + assert!(error_slot.lock().is_none()); +} + +#[sqlx::test(migrations = "./migrations")] +async fn connection_errors_become_terminal_when_reconnection_fails(pool: PgPool) { + init_tracing(); + let error_slot = Mutex::new(None); + let notify = Notify::new(); + let subscription = notify.notified(); + sqlx::query!("ALTER TABLE pg_task RENAME COLUMN id TO task_id") + .execute(&pool) + .await + .unwrap(); + + assert!(Listener::handle_recv_error(&error_slot, ¬ify, &pool, connection_error(),).await); + timeout(Duration::from_millis(50), subscription) + .await + 
.unwrap(); + assert!(matches!( + error_slot.lock().take(), + Some(Error::Db(sqlx::Error::Database(_), _)) + )); +} + +#[tokio::test] +async fn dropping_listener_aborts_background_task() { + let listener = Listener::new(); + let task = tokio::spawn(pending::<()>()); + listener.set_task_for_tests(task.abort_handle()); + + drop(listener); + + let error = timeout(Duration::from_millis(50), task) + .await + .unwrap() + .unwrap_err(); + assert!(error.is_cancelled()); +} + +#[tokio::test] +async fn replacing_listener_task_aborts_the_previous_background_task() { + let listener = Listener::new(); + let first_task = tokio::spawn(pending::<()>()); + let second_task = tokio::spawn(pending::<()>()); + listener.set_task_for_tests(first_task.abort_handle()); + listener.set_task_for_tests(second_task.abort_handle()); + + let first_error = timeout(Duration::from_millis(50), first_task) + .await + .unwrap() + .unwrap_err(); + assert!(first_error.is_cancelled()); + + drop(listener); + + let second_error = timeout(Duration::from_millis(50), second_task) + .await + .unwrap() + .unwrap_err(); + assert!(second_error.is_cancelled()); +} + +#[tokio::test] +async fn shutdown_aborts_the_background_task() { + let listener = Listener::new(); + let task = tokio::spawn(pending::<()>()); + listener.set_task_for_tests(task.abort_handle()); + + listener.shutdown(); + + let error = timeout(Duration::from_millis(50), task) + .await + .unwrap() + .unwrap_err(); + assert!(error.is_cancelled()); +} + +#[sqlx::test(migrations = "./migrations")] +async fn listen_wakes_subscribers_for_task_inserts(pool: PgPool) { + let listener = Listener::new(); + listener.listen(pool.clone()).await.unwrap(); + + let subscription = listener.subscribe(); + sqlx::query!( + "INSERT INTO pg_task (step, wakeup_at) VALUES ($1, $2)", + "{}", + Utc::now(), + ) + .execute(&pool) + .await + .unwrap(); + + timeout(Duration::from_secs(1), subscription.wait_forever()) + .await + .unwrap(); + + assert!(!listener.time_to_stop_worker()); 
+ assert!(listener.take_error().is_none()); +} + +#[sqlx::test(migrations = "./migrations")] +async fn listen_wakes_subscribers_for_non_stop_notifications(pool: PgPool) { + let listener = Listener::new(); + listener.listen(pool.clone()).await.unwrap(); + + let subscription = listener.subscribe(); + sqlx::query!("NOTIFY pg_task_changed, 'wake'") + .execute(&pool) + .await + .unwrap(); + + timeout(Duration::from_secs(1), subscription.wait_forever()) + .await + .unwrap(); + + assert!(!listener.time_to_stop_worker()); + assert!(listener.take_error().is_none()); +} + +#[sqlx::test(migrations = "./migrations")] +async fn listen_wakes_subscribers_for_task_inserts_and_updates(pool: PgPool) { + let listener = Listener::new(); + listener.listen(pool.clone()).await.unwrap(); + let insert_subscription = listener.subscribe(); + let id = sqlx::query!( + "INSERT INTO pg_task (step, wakeup_at) VALUES ($1, $2) RETURNING id", + "{}", + Utc::now(), + ) + .fetch_one(&pool) + .await + .unwrap() + .id; + timeout(Duration::from_secs(1), insert_subscription.wait_forever()) + .await + .unwrap(); + + let update_subscription = listener.subscribe(); + sqlx::query!("UPDATE pg_task SET error = $2 WHERE id = $1", id, "boom",) + .execute(&pool) + .await + .unwrap(); + timeout(Duration::from_secs(1), update_subscription.wait_forever()) + .await + .unwrap(); + + assert!(!listener.time_to_stop_worker()); + assert!(listener.take_error().is_none()); +} + +#[sqlx::test(migrations = "./migrations")] +async fn stop_worker_notifications_wake_future_subscribers(pool: PgPool) { + let listener = Listener::new(); + listener.listen(pool.clone()).await.unwrap(); + + sqlx::query!("NOTIFY pg_task_changed, 'stop_worker'") + .execute(&pool) + .await + .unwrap(); + + timeout(Duration::from_secs(1), async { + loop { + if listener.time_to_stop_worker() { + return; + } + sleep(Duration::from_millis(10)).await; + } + }) + .await + .unwrap(); + + timeout( + Duration::from_millis(50), + 
listener.subscribe().wait_forever(), + ) + .await + .unwrap(); + + assert!(listener.take_error().is_none()); +} + +#[sqlx::test(migrations = "./migrations")] +async fn closing_the_pool_surfaces_listener_errors_to_subscribers(pool: PgPool) { + let listener = Listener::new(); + listener.listen(pool.clone()).await.unwrap(); + + let subscription = listener.subscribe(); + let close_pool = tokio::spawn({ + let pool = pool.clone(); + async move { + pool.close().await; + } + }); + + timeout(Duration::from_secs(1), subscription.wait_forever()) + .await + .unwrap(); + close_pool.await.unwrap(); + + assert!(matches!( + listener.take_error(), + Some(Error::ListenerReceive(sqlx::Error::PoolClosed)) + )); +} + +#[sqlx::test(migrations = "./migrations")] +async fn closing_the_pool_retains_terminal_wakeups_for_future_subscribers(pool: PgPool) { + let listener = Listener::new(); + listener.listen(pool.clone()).await.unwrap(); + + let close_pool = tokio::spawn({ + let pool = pool.clone(); + async move { + pool.close().await; + } + }); + + timeout(Duration::from_secs(1), async { + loop { + if listener.error.lock().is_some() { + return; + } + sleep(Duration::from_millis(10)).await; + } + }) + .await + .unwrap(); + + timeout( + Duration::from_millis(50), + listener.subscribe().wait_forever(), + ) + .await + .unwrap(); + close_pool.await.unwrap(); + + assert!(matches!( + listener.take_error(), + Some(Error::ListenerReceive(sqlx::Error::PoolClosed)) + )); +} + +#[sqlx::test(migrations = "./migrations")] +async fn updating_tasks_refreshes_updated_at(pool: PgPool) { + let row = sqlx::query!( + " + INSERT INTO pg_task (step, wakeup_at) + VALUES ($1, $2) + RETURNING id, updated_at + ", + "{}", + Utc::now(), + ) + .fetch_one(&pool) + .await + .unwrap(); + let id: Uuid = row.id; + let initial_updated_at: DateTime = row.updated_at; + + sleep(Duration::from_millis(20)).await; + + let next_updated_at: DateTime = sqlx::query!( + " + UPDATE pg_task + SET error = $2 + WHERE id = $1 + RETURNING 
updated_at + ", + id, + "boom", + ) + .fetch_one(&pool) + .await + .unwrap() + .updated_at; + + assert!(next_updated_at > initial_updated_at); +} From 75f0368b35757d8cb86674fedbb923a650374fdc Mon Sep 17 00:00:00 2001 From: imbolc Date: Fri, 15 May 2026 15:38:36 +0600 Subject: [PATCH 43/44] Apply lease review cleanup --- ...a2b71958cf632128d4065d3d2e9176f38b06e.json | 24 +++++++++ ...65c4400e46cb175daebcec8bc95c6c1e74f73.json | 25 --------- migrations/20260509130000_task-leases.sql | 4 +- src/task.rs | 54 +++++++------------ src/task/tests.rs | 7 +-- src/worker.rs | 4 +- 6 files changed, 50 insertions(+), 68 deletions(-) create mode 100644 .sqlx/query-5c51a5a2b19231f0eb260399d98a2b71958cf632128d4065d3d2e9176f38b06e.json delete mode 100644 .sqlx/query-bf30581629ab7f798cb4cfa403d65c4400e46cb175daebcec8bc95c6c1e74f73.json diff --git a/.sqlx/query-5c51a5a2b19231f0eb260399d98a2b71958cf632128d4065d3d2e9176f38b06e.json b/.sqlx/query-5c51a5a2b19231f0eb260399d98a2b71958cf632128d4065d3d2e9176f38b06e.json new file mode 100644 index 0000000..406d44e --- /dev/null +++ b/.sqlx/query-5c51a5a2b19231f0eb260399d98a2b71958cf632128d4065d3d2e9176f38b06e.json @@ -0,0 +1,24 @@ +{ + "db_name": "PostgreSQL", + "query": "\n UPDATE pg_task\n SET tried = tried + 1,\n error = $2,\n wakeup_at = now(),\n locked_by = NULL,\n lock_expires_at = NULL\n WHERE id = $1\n AND locked_by = $3\n AND lock_expires_at > now()\n RETURNING step::TEXT as \"step!\"\n ", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "step!", + "type_info": "Text" + } + ], + "parameters": { + "Left": [ + "Uuid", + "Text", + "Uuid" + ] + }, + "nullable": [ + false + ] + }, + "hash": "5c51a5a2b19231f0eb260399d98a2b71958cf632128d4065d3d2e9176f38b06e" +} diff --git a/.sqlx/query-bf30581629ab7f798cb4cfa403d65c4400e46cb175daebcec8bc95c6c1e74f73.json b/.sqlx/query-bf30581629ab7f798cb4cfa403d65c4400e46cb175daebcec8bc95c6c1e74f73.json deleted file mode 100644 index 2c1622f..0000000 --- 
a/.sqlx/query-bf30581629ab7f798cb4cfa403d65c4400e46cb175daebcec8bc95c6c1e74f73.json +++ /dev/null @@ -1,25 +0,0 @@ -{ - "db_name": "PostgreSQL", - "query": "\n UPDATE pg_task\n SET tried = tried + $3,\n error = $2,\n wakeup_at = now(),\n locked_by = NULL,\n lock_expires_at = NULL\n WHERE id = $1\n AND locked_by = $4\n AND lock_expires_at > now()\n RETURNING step::TEXT as \"step!\"\n ", - "describe": { - "columns": [ - { - "ordinal": 0, - "name": "step!", - "type_info": "Text" - } - ], - "parameters": { - "Left": [ - "Uuid", - "Text", - "Int4", - "Uuid" - ] - }, - "nullable": [ - false - ] - }, - "hash": "bf30581629ab7f798cb4cfa403d65c4400e46cb175daebcec8bc95c6c1e74f73" -} diff --git a/migrations/20260509130000_task-leases.sql b/migrations/20260509130000_task-leases.sql index f854cb6..eb9b5b6 100644 --- a/migrations/20260509130000_task-leases.sql +++ b/migrations/20260509130000_task-leases.sql @@ -28,7 +28,9 @@ ON pg_task (( )) WHERE error IS NULL; --- Remove `running_at` column +DROP INDEX pg_task_wakeup_at_idx; + +-- Remove `is_running` column UPDATE pg_task SET locked_by = gen_random_uuid(), lock_expires_at = now() diff --git a/src/task.rs b/src/task.rs index 52ade4f..1d1b9ab 100644 --- a/src/task.rs +++ b/src/task.rs @@ -111,13 +111,13 @@ impl Task { .map_err(db_error!()) } - /// Marks the task running - pub(crate) async fn mark_running( + /// Claims the task lease for this worker. + pub(crate) async fn claim_lease( &self, con: &mut PgConnection, lease: WorkerLease, ) -> Result<()> { - trace!("[{}] mark running", self.id); + trace!("[{}] claim lease", self.id); sqlx::query!( r#" UPDATE pg_task @@ -131,7 +131,7 @@ impl Task { ) .execute(con) .await - .map_err(db_error!())?; + .map_err(db_error!("claim lease"))?; Ok(()) } @@ -165,9 +165,9 @@ impl Task { .map_err(db_error!("renew leases")) } - /// Deserializes the current task step and marks it running. + /// Deserializes the current task step and claims its lease. 
/// If deserialization fails, stores the error instead and leaves the task - /// non-running. + /// unleased. pub(crate) async fn claim>( &self, con: &mut PgConnection, @@ -181,7 +181,7 @@ impl Task { } }; - self.mark_running(con, lease).await?; + self.claim_lease(con, lease).await?; Ok(Some(step)) } @@ -210,7 +210,7 @@ impl Task { self.retry(db, self.tried, retry_limit, retry_delay, e, lease) .await?; } else { - self.save_step_error(db, e, true, lease).await?; + self.save_step_error(db, e, lease).await?; } } Ok(NextStep::None) => self.complete(db, lease).await?, @@ -262,31 +262,23 @@ impl Task { } /// Saves the task error if the worker still owns the task. - async fn save_step_error( - &self, - db: &PgPool, - err: StepError, - increment_tried: bool, - lease: WorkerLease, - ) -> Result<()> { + async fn save_step_error(&self, db: &PgPool, err: StepError, lease: WorkerLease) -> Result<()> { let err_str = source_chain::to_string(&*err); - let tried_increment = if increment_tried { 1 } else { 0 }; let updated_task = sqlx::query!( r#" UPDATE pg_task - SET tried = tried + $3, + SET tried = tried + 1, error = $2, wakeup_at = now(), locked_by = NULL, lock_expires_at = NULL WHERE id = $1 - AND locked_by = $4 + AND locked_by = $3 AND lock_expires_at > now() RETURNING step::TEXT as "step!" 
"#, self.id, &err_str, - tried_increment, lease.worker_id, ) .fetch_optional(db) @@ -298,21 +290,13 @@ impl Task { return Ok(()); }; - if increment_tried { - let attempt = self.tried + 1; - error!( - "[{id}] resulted in an error at step {step} on {attempt} attempt: {err_str}", - id = self.id, - step = updated_task.step, - attempt = ordinal(attempt) - ); - } else { - error!( - "[{id}] couldn't deserialize step {step}: {err_str}", - id = self.id, - step = updated_task.step - ); - } + let attempt = self.tried + 1; + error!( + "[{id}] resulted in an error at step {step} on {attempt} attempt: {err_str}", + id = self.id, + step = updated_task.step, + attempt = ordinal(attempt) + ); Ok(()) } @@ -329,7 +313,7 @@ impl Task { .map_err(|e| Error::SerializeStep(e, format!("{step:?}"))) { Ok(x) => x, - Err(e) => return self.save_step_error(db, e.into(), true, lease).await, + Err(e) => return self.save_step_error(db, e.into(), lease).await, }; let result = sqlx::query!( " diff --git a/src/task/tests.rs b/src/task/tests.rs index 1aba541..d1d41b0 100644 --- a/src/task/tests.rs +++ b/src/task/tests.rs @@ -399,7 +399,7 @@ async fn fetch_ready_returns_db_errors_for_query_failures(pool: PgPool) { } #[sqlx::test(migrations = "./migrations")] -async fn mark_running_returns_db_errors_for_update_failures(pool: PgPool) { +async fn claim_lease_returns_db_errors_for_update_failures(pool: PgPool) { let id = insert_task(&pool, &TestTask::Valid(Valid), 0, false).await; let task = task_with_step(id, &TestTask::Valid(Valid), 0); sqlx::query!("ALTER TABLE pg_task RENAME COLUMN locked_by TO task_locked_by") @@ -408,10 +408,7 @@ async fn mark_running_returns_db_errors_for_update_failures(pool: PgPool) { .unwrap(); let mut tx = pool.begin().await.unwrap(); - let err = task - .mark_running(&mut tx, worker_lease()) - .await - .unwrap_err(); + let err = task.claim_lease(&mut tx, worker_lease()).await.unwrap_err(); assert_database_error(err); } diff --git a/src/worker.rs b/src/worker.rs index 
1c50fc4..323194b 100644 --- a/src/worker.rs +++ b/src/worker.rs @@ -265,7 +265,7 @@ impl + 'static> Worker { } } - /// Claims a currently available task and marks it running. + /// Claims a currently available task lease. async fn claim_available_task( &self, lease: WorkerLease, @@ -286,7 +286,7 @@ impl + 'static> Worker { Ok(Some((task, step, lease))) } - /// Waits until the next task is ready, marks it running and returns it. + /// Waits until the next task is ready, claims its lease and returns it. /// Returns `None` if the worker is stopped #[cfg(test)] async fn recv_task(&self, lease: WorkerLease) -> Result> { From 8adf8a979b943fac03dafb38462ed7418e613cf7 Mon Sep 17 00:00:00 2001 From: imbolc Date: Fri, 15 May 2026 16:14:33 +0600 Subject: [PATCH 44/44] Track heartbeat expiry per running task lease --- ...c5b5355e3caae84f7b6feee9c274b6f8717a5.json | 24 ++++ ...33efd3bbce043b99c2b308622fb67a1357d36.json | 16 --- ...b0088d84b034b529e52ae351eb320f2b1c8b.json} | 12 +- src/task.rs | 43 +++++-- src/task/tests.rs | 15 ++- src/worker.rs | 110 ++++++++++++------ src/worker/tests.rs | 87 ++++++++++---- 7 files changed, 217 insertions(+), 90 deletions(-) create mode 100644 .sqlx/query-2614ddc9b7ceb8e907d0eda7da8c5b5355e3caae84f7b6feee9c274b6f8717a5.json delete mode 100644 .sqlx/query-5fa64eb35dad3a1a8639ec4f5b933efd3bbce043b99c2b308622fb67a1357d36.json rename .sqlx/{query-0acd322b71ea8e7cfea28155ee3912a20c6b9f92396c707dbd8411b3e30d625f.json => query-bb4a0dee55de24ab3eb578b3f685b0088d84b034b529e52ae351eb320f2b1c8b.json} (66%) diff --git a/.sqlx/query-2614ddc9b7ceb8e907d0eda7da8c5b5355e3caae84f7b6feee9c274b6f8717a5.json b/.sqlx/query-2614ddc9b7ceb8e907d0eda7da8c5b5355e3caae84f7b6feee9c274b6f8717a5.json new file mode 100644 index 0000000..255faf2 --- /dev/null +++ b/.sqlx/query-2614ddc9b7ceb8e907d0eda7da8c5b5355e3caae84f7b6feee9c274b6f8717a5.json @@ -0,0 +1,24 @@ +{ + "db_name": "PostgreSQL", + "query": "\n UPDATE pg_task\n SET locked_by = $2,\n lock_expires_at = now() + 
$3::interval\n WHERE id = $1\n RETURNING lock_expires_at AS \"lock_expires_at!\"\n ", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "lock_expires_at!", + "type_info": "Timestamptz" + } + ], + "parameters": { + "Left": [ + "Uuid", + "Uuid", + "Interval" + ] + }, + "nullable": [ + true + ] + }, + "hash": "2614ddc9b7ceb8e907d0eda7da8c5b5355e3caae84f7b6feee9c274b6f8717a5" +} diff --git a/.sqlx/query-5fa64eb35dad3a1a8639ec4f5b933efd3bbce043b99c2b308622fb67a1357d36.json b/.sqlx/query-5fa64eb35dad3a1a8639ec4f5b933efd3bbce043b99c2b308622fb67a1357d36.json deleted file mode 100644 index 21d2f92..0000000 --- a/.sqlx/query-5fa64eb35dad3a1a8639ec4f5b933efd3bbce043b99c2b308622fb67a1357d36.json +++ /dev/null @@ -1,16 +0,0 @@ -{ - "db_name": "PostgreSQL", - "query": "\n UPDATE pg_task\n SET locked_by = $2,\n lock_expires_at = now() + $3::interval\n WHERE id = $1\n ", - "describe": { - "columns": [], - "parameters": { - "Left": [ - "Uuid", - "Uuid", - "Interval" - ] - }, - "nullable": [] - }, - "hash": "5fa64eb35dad3a1a8639ec4f5b933efd3bbce043b99c2b308622fb67a1357d36" -} diff --git a/.sqlx/query-0acd322b71ea8e7cfea28155ee3912a20c6b9f92396c707dbd8411b3e30d625f.json b/.sqlx/query-bb4a0dee55de24ab3eb578b3f685b0088d84b034b529e52ae351eb320f2b1c8b.json similarity index 66% rename from .sqlx/query-0acd322b71ea8e7cfea28155ee3912a20c6b9f92396c707dbd8411b3e30d625f.json rename to .sqlx/query-bb4a0dee55de24ab3eb578b3f685b0088d84b034b529e52ae351eb320f2b1c8b.json index 02dfdb1..66ae5f3 100644 --- a/.sqlx/query-0acd322b71ea8e7cfea28155ee3912a20c6b9f92396c707dbd8411b3e30d625f.json +++ b/.sqlx/query-bb4a0dee55de24ab3eb578b3f685b0088d84b034b529e52ae351eb320f2b1c8b.json @@ -1,12 +1,17 @@ { "db_name": "PostgreSQL", - "query": "\n UPDATE pg_task\n SET lock_expires_at = now() + $2::interval\n WHERE locked_by = $1\n AND id = ANY($3)\n AND lock_expires_at > now()\n AND error IS NULL\n RETURNING id\n ", + "query": "\n UPDATE pg_task\n SET lock_expires_at = now() + $2::interval\n WHERE 
locked_by = $1\n AND id = ANY($3)\n AND lock_expires_at > now()\n AND error IS NULL\n RETURNING id, lock_expires_at AS \"lock_expires_at!\"\n ", "describe": { "columns": [ { "ordinal": 0, "name": "id", "type_info": "Uuid" + }, + { + "ordinal": 1, + "name": "lock_expires_at!", + "type_info": "Timestamptz" } ], "parameters": { @@ -17,8 +22,9 @@ ] }, "nullable": [ - false + false, + true ] }, - "hash": "0acd322b71ea8e7cfea28155ee3912a20c6b9f92396c707dbd8411b3e30d625f" + "hash": "bb4a0dee55de24ab3eb578b3f685b0088d84b034b529e52ae351eb320f2b1c8b" } diff --git a/src/task.rs b/src/task.rs index 1d1b9ab..3fbdbb3 100644 --- a/src/task.rs +++ b/src/task.rs @@ -25,6 +25,18 @@ pub(crate) struct WorkerLease { timeout: PgInterval, } +#[derive(Debug)] +pub(crate) struct ClaimedStep { + pub(crate) step: S, + pub(crate) lock_expires_at: DateTime, +} + +#[derive(Debug)] +pub(crate) struct RenewedTaskLease { + pub(crate) task_id: Uuid, + pub(crate) lock_expires_at: DateTime, +} + impl WorkerLease { pub(crate) fn new(worker_id: Uuid, timeout: Duration) -> Self { Self { @@ -116,7 +128,7 @@ impl Task { &self, con: &mut PgConnection, lease: WorkerLease, - ) -> Result<()> { + ) -> Result> { trace!("[{}] claim lease", self.id); sqlx::query!( r#" @@ -124,15 +136,16 @@ impl Task { SET locked_by = $2, lock_expires_at = now() + $3::interval WHERE id = $1 + RETURNING lock_expires_at AS "lock_expires_at!" "#, self.id, lease.worker_id, lease.timeout, ) - .execute(con) + .fetch_one(con) .await - .map_err(db_error!("claim lease"))?; - Ok(()) + .map(|row| row.lock_expires_at) + .map_err(db_error!("claim lease")) } /// Renews live task leases owned by a worker. @@ -140,7 +153,7 @@ impl Task { db: &PgPool, lease: WorkerLease, task_ids: &[Uuid], - ) -> Result> { + ) -> Result> { if task_ids.is_empty() { return Ok(Vec::new()); } @@ -153,7 +166,7 @@ impl Task { AND id = ANY($3) AND lock_expires_at > now() AND error IS NULL - RETURNING id + RETURNING id, lock_expires_at AS "lock_expires_at!" 
"#, lease.worker_id, lease.timeout, @@ -161,7 +174,14 @@ impl Task { ) .fetch_all(db) .await - .map(|rows| rows.into_iter().map(|row| row.id).collect()) + .map(|rows| { + rows.into_iter() + .map(|row| RenewedTaskLease { + task_id: row.id, + lock_expires_at: row.lock_expires_at, + }) + .collect() + }) .map_err(db_error!("renew leases")) } @@ -172,7 +192,7 @@ impl Task { &self, con: &mut PgConnection, lease: WorkerLease, - ) -> Result> { + ) -> Result>> { let step = match self.parse_step() { Ok(step) => step, Err(e) => { @@ -181,8 +201,11 @@ impl Task { } }; - self.claim_lease(con, lease).await?; - Ok(Some(step)) + let lock_expires_at = self.claim_lease(con, lease).await?; + Ok(Some(ClaimedStep { + step, + lock_expires_at, + })) } /// Runs the current step of the task to completion diff --git a/src/task/tests.rs b/src/task/tests.rs index d1d41b0..5cf9963 100644 --- a/src/task/tests.rs +++ b/src/task/tests.rs @@ -273,7 +273,7 @@ async fn claim_task(pool: &PgPool, step: TestTask, tried: i32) -> (Task, TestTas .unwrap() .unwrap(); tx.commit().await.unwrap(); - (task, claimed, lease) + (task, claimed.step, lease) } async fn fetch_task_row(pool: &PgPool, id: Uuid) -> Option { @@ -445,7 +445,10 @@ async fn claim_marks_valid_steps_leased(pool: PgPool) { tx.commit().await.unwrap(); let finished_at = Utc::now(); - assert!(matches!(claimed, Some(TestTask::Valid(Valid)))); + assert!(matches!( + claimed.map(|claimed| claimed.step), + Some(TestTask::Valid(Valid)) + )); let row = fetch_task_row(&pool, id).await.unwrap(); assert_eq!(row.step, serialized_step(&TestTask::Valid(Valid))); @@ -543,7 +546,13 @@ async fn renew_leases_extends_only_live_owned_leases(pool: PgPool) { .unwrap(); let finished_at = Utc::now(); - assert_eq!(renewed, vec![owned]); + let renewed_task_ids: Vec<_> = renewed.iter().map(|lease| lease.task_id).collect(); + assert_eq!(renewed_task_ids, vec![owned]); + assert_timestamp_between( + renewed[0].lock_expires_at, + started_at + ChronoDuration::seconds(60), + 
finished_at + ChronoDuration::seconds(61), + ); let owned = fetch_task_row(&pool, owned).await.unwrap(); assert_timestamp_between( owned.lock_expires_at.unwrap(), diff --git a/src/worker.rs b/src/worker.rs index 323194b..ddebfdc 100644 --- a/src/worker.rs +++ b/src/worker.rs @@ -1,17 +1,15 @@ use crate::{ listener::Listener, - task::{Task, WorkerLease}, - util::{db_error, db_interruption, wait_for_reconnection, DbInterruption}, + task::{RenewedTaskLease, Task, WorkerLease}, + util::{ + db_error, db_interruption, std_duration_to_chrono, wait_for_reconnection, DbInterruption, + }, Error, Result, Step, LOST_CONNECTION_SLEEP, }; +use chrono::{DateTime, Utc}; use parking_lot::Mutex; use sqlx::postgres::PgPool; -use std::{ - marker::PhantomData, - num::NonZeroUsize, - sync::Arc, - time::{Duration, Instant}, -}; +use std::{marker::PhantomData, num::NonZeroUsize, sync::Arc, time::Duration}; use tokio::{ sync::{mpsc, Semaphore}, time::{interval, sleep, MissedTickBehavior}, @@ -37,6 +35,7 @@ struct RunEvents { struct RunningStep { task_id: Uuid, abort_handle: tokio::task::AbortHandle, + lock_expires_at: DateTime, } enum TaskAvailability { @@ -158,7 +157,7 @@ impl + 'static> Worker { }; match availability { Ok(TaskAvailability::Ready) => match self.claim_available_task(lease).await { - Ok(Some((task, step, lease))) => { + Ok(Some((task, step, lease, lock_expires_at))) => { let permit = reserved_permit .take() .expect("task claiming requires a reserved semaphore permit"); @@ -172,7 +171,12 @@ impl + 'static> Worker { }; drop(permit); }); - Self::track_running_step(&running_steps, task_id, step.abort_handle()); + Self::track_running_step( + &running_steps, + task_id, + step.abort_handle(), + lock_expires_at, + ); } Ok(None) => continue, Err(e) => { @@ -269,7 +273,7 @@ impl + 'static> Worker { async fn claim_available_task( &self, lease: WorkerLease, - ) -> Result> { + ) -> Result)>> { trace!("Claiming an available task"); let mut tx = 
self.db.begin().await.map_err(db_error!("begin"))?; @@ -278,18 +282,21 @@ impl + 'static> Worker { return Ok(None); }; - let Some(step) = task.claim(&mut tx, lease).await? else { + let Some(claimed) = task.claim(&mut tx, lease).await? else { tx.commit().await.map_err(db_error!("save error"))?; return Ok(None); }; - tx.commit().await.map_err(db_error!("mark running"))?; - Ok(Some((task, step, lease))) + tx.commit().await.map_err(db_error!("claim lease"))?; + Ok(Some((task, claimed.step, lease, claimed.lock_expires_at))) } /// Waits until the next task is ready, claims its lease and returns it. /// Returns `None` if the worker is stopped #[cfg(test)] - async fn recv_task(&self, lease: WorkerLease) -> Result> { + async fn recv_task( + &self, + lease: WorkerLease, + ) -> Result)>> { trace!("Receiving the next task"); loop { @@ -386,17 +393,14 @@ impl + 'static> Worker { let db = self.db.clone(); let mut heartbeat = interval(self.heartbeat_interval); let heartbeat_interval = self.heartbeat_interval; - let lease_timeout = self.lease_timeout; heartbeat.set_missed_tick_behavior(MissedTickBehavior::Delay); tokio::spawn(async move { - let mut last_renewed_at = Instant::now(); let mut renewal_failed = false; heartbeat.tick().await; loop { heartbeat.tick().await; let running_task_ids = Self::running_task_ids(&running_steps); if running_task_ids.is_empty() { - last_renewed_at = Instant::now(); if renewal_failed { let _ = events.send(HeartbeatEvent::Recovered); renewal_failed = false; @@ -404,33 +408,33 @@ impl + 'static> Worker { continue; } match Task::renew_leases(&db, lease, &running_task_ids).await { - Ok(renewed_task_ids) - if Self::renewed_all_running_leases( + Ok(renewed_leases) + if Self::update_running_lease_expirations( &running_task_ids, - &renewed_task_ids, + &renewed_leases, &running_steps, ) => { - trace!("Renewed {} task leases", renewed_task_ids.len()); - last_renewed_at = Instant::now(); + trace!("Renewed {} task leases", renewed_leases.len()); if 
renewal_failed { let _ = events.send(HeartbeatEvent::Recovered); renewal_failed = false; } } - Ok(renewed_task_ids) => { + Ok(renewed_leases) => { warn!( "Task lease renewal updated {} of {} running task leases", - renewed_task_ids.len(), + renewed_leases.len(), running_task_ids.len() ); if !renewal_failed { let _ = events.send(HeartbeatEvent::Failed); renewal_failed = true; } - if last_renewed_at.elapsed().saturating_add(heartbeat_interval) - >= lease_timeout - { + if Self::running_lease_expires_before_next_heartbeat( + &running_steps, + heartbeat_interval, + ) { let _ = events.send(HeartbeatEvent::Expired(Error::TaskLeaseExpired)); break; } @@ -444,10 +448,10 @@ impl + 'static> Worker { let _ = events.send(HeartbeatEvent::Failed); renewal_failed = true; } - if !Self::running_task_ids(&running_steps).is_empty() - && last_renewed_at.elapsed().saturating_add(heartbeat_interval) - >= lease_timeout - { + if Self::running_lease_expires_before_next_heartbeat( + &running_steps, + heartbeat_interval, + ) { let _ = events.send(HeartbeatEvent::Expired(error)); break; } @@ -462,12 +466,14 @@ impl + 'static> Worker { running_steps: &Mutex>, task_id: Uuid, abort_handle: tokio::task::AbortHandle, + lock_expires_at: DateTime, ) { let mut running_steps = running_steps.lock(); running_steps.retain(|step| !step.abort_handle.is_finished()); running_steps.push(RunningStep { task_id, abort_handle, + lock_expires_at, }); } @@ -489,15 +495,43 @@ impl + 'static> Worker { running_steps.iter().map(|step| step.task_id).collect() } - fn renewed_all_running_leases( + fn update_running_lease_expirations( running_task_ids: &[Uuid], - renewed_task_ids: &[Uuid], + renewed_leases: &[RenewedTaskLease], running_steps: &Mutex>, ) -> bool { - let still_running_task_ids = Self::running_task_ids(running_steps); - running_task_ids.iter().all(|task_id| { - renewed_task_ids.contains(task_id) || !still_running_task_ids.contains(task_id) - }) + let mut running_steps = running_steps.lock(); + 
running_steps.retain(|step| !step.abort_handle.is_finished()); + let mut all_running_leases_renewed = true; + + for step in running_steps.iter_mut() { + if !running_task_ids.contains(&step.task_id) { + continue; + } + + if let Some(renewed_lease) = renewed_leases + .iter() + .find(|renewed_lease| renewed_lease.task_id == step.task_id) + { + step.lock_expires_at = renewed_lease.lock_expires_at; + } else { + all_running_leases_renewed = false; + } + } + + all_running_leases_renewed + } + + fn running_lease_expires_before_next_heartbeat( + running_steps: &Mutex>, + heartbeat_interval: Duration, + ) -> bool { + let next_heartbeat_at = Utc::now() + std_duration_to_chrono(heartbeat_interval); + let mut running_steps = running_steps.lock(); + running_steps.retain(|step| !step.abort_handle.is_finished()); + running_steps + .iter() + .any(|step| step.lock_expires_at <= next_heartbeat_at) } async fn finish_run( diff --git a/src/worker/tests.rs b/src/worker/tests.rs index 9813b6b..5c11860 100644 --- a/src/worker/tests.rs +++ b/src/worker/tests.rs @@ -427,9 +427,18 @@ fn worker_lease(worker: &Worker) -> WorkerLease { } fn running_step_entry(task_id: Uuid, abort_handle: tokio::task::AbortHandle) -> RunningStep { + running_step_entry_with_lease_expiry(task_id, abort_handle, Utc::now()) +} + +fn running_step_entry_with_lease_expiry( + task_id: Uuid, + abort_handle: tokio::task::AbortHandle, + lock_expires_at: chrono::DateTime, +) -> RunningStep { RunningStep { task_id, abort_handle, + lock_expires_at, } } @@ -714,8 +723,16 @@ async fn heartbeat_expires_when_any_running_step_loses_its_lease(pool: PgPool) { std::future::pending::<()>().await; }); let running_steps = Arc::new(Mutex::new(vec![ - running_step_entry(live, live_step.abort_handle()), - running_step_entry(expired, expired_step.abort_handle()), + running_step_entry_with_lease_expiry( + live, + live_step.abort_handle(), + Utc::now() + ChronoDuration::milliseconds(200), + ), + running_step_entry_with_lease_expiry( + expired, 
+ expired_step.abort_handle(), + Utc::now() - ChronoDuration::milliseconds(1), + ), ])); let (events, mut events_receiver) = mpsc::unbounded_channel(); let heartbeat = worker.spawn_heartbeat(events, running_steps, lease); @@ -798,9 +815,10 @@ async fn heartbeat_reports_recovery_after_live_leases_are_renewed(pool: PgPool) let running_step = tokio::spawn(async { std::future::pending::<()>().await; }); - let running_steps = Arc::new(Mutex::new(vec![running_step_entry( + let running_steps = Arc::new(Mutex::new(vec![running_step_entry_with_lease_expiry( id, running_step.abort_handle(), + initial_expires_at, )])); let (events, mut events_receiver) = mpsc::unbounded_channel(); let heartbeat = worker.spawn_heartbeat(events, running_steps, lease); @@ -841,7 +859,12 @@ async fn running_step_tracking_prunes_finished_steps() { finished_step_abort, )]); - Worker::::track_running_step(&running_steps, Uuid::new_v4(), running_step_abort); + Worker::::track_running_step( + &running_steps, + Uuid::new_v4(), + running_step_abort, + Utc::now(), + ); assert_eq!(running_steps.lock().len(), 1); assert!(Worker::::has_running_steps(&running_steps)); @@ -851,6 +874,28 @@ async fn running_step_tracking_prunes_finished_steps() { assert!(!Worker::::has_running_steps(&running_steps)); } +#[tokio::test] +async fn running_lease_expiry_uses_the_tracked_lock_expiry() { + let running_step = tokio::spawn(async { + std::future::pending::<()>().await; + }); + let running_steps = Mutex::new(vec![running_step_entry_with_lease_expiry( + Uuid::new_v4(), + running_step.abort_handle(), + Utc::now() + ChronoDuration::milliseconds(200), + )]); + + assert!( + !Worker::::running_lease_expires_before_next_heartbeat( + &running_steps, + Duration::from_millis(20), + ) + ); + + Worker::::abort_running_steps(&running_steps); + assert!(running_step.await.unwrap_err().is_cancelled()); +} + #[sqlx::test(migrations = "./migrations")] async fn run_returns_listener_startup_errors(pool: PgPool) { let worker = 
Worker::::new(pool); @@ -1069,7 +1114,7 @@ async fn recv_task_skips_invalid_tasks_and_returns_next_ready_task(pool: PgPool) let worker = Worker::::new(pool.clone()); let lease = worker_lease(&worker); - let (task, step, _lease) = worker.recv_task(lease).await.unwrap().unwrap(); + let (task, step, _lease, _lock_expires_at) = worker.recv_task(lease).await.unwrap().unwrap(); assert_eq!(task.id, expected); assert!(matches!(step, TestTask::Noop(Noop))); @@ -1110,7 +1155,7 @@ async fn recv_task_rechecks_locked_ready_tasks_without_notifications(pool: PgPoo tx.rollback().await.unwrap(); - let (task, step, _lease) = timeout(Duration::from_secs(1), recv) + let (task, step, _lease, _lock_expires_at) = timeout(Duration::from_secs(1), recv) .await .unwrap() .unwrap() @@ -1136,7 +1181,7 @@ async fn recv_task_rechecks_leased_tasks_when_their_lease_expires(pool: PgPool) let lease = worker_lease(&worker); let recv = tokio::spawn(async move { worker.recv_task(lease).await }); - let (task, step, _lease) = timeout(Duration::from_secs(2), recv) + let (task, step, _lease, _lock_expires_at) = timeout(Duration::from_secs(2), recv) .await .unwrap() .unwrap() @@ -1161,7 +1206,7 @@ async fn recv_task_replaces_expired_lease_with_the_current_worker(pool: PgPool) let worker_id = Uuid::new_v4(); let lease = WorkerLease::new(worker_id, worker.lease_timeout); - let (task, step, _lease) = worker.recv_task(lease).await.unwrap().unwrap(); + let (task, step, _lease, _lock_expires_at) = worker.recv_task(lease).await.unwrap().unwrap(); assert_eq!(task.id, id); assert!(matches!(step, TestTask::Noop(Noop))); @@ -1194,18 +1239,20 @@ async fn two_workers_claim_ready_tasks_once(pool: PgPool) { let first_recv = tokio::spawn(async move { first_worker.recv_task(first_lease).await }); let second_recv = tokio::spawn(async move { second_worker.recv_task(second_lease).await }); - let (first_task, first_step, _first_lease) = timeout(Duration::from_secs(1), first_recv) - .await - .unwrap() - .unwrap() - .unwrap() - 
.unwrap(); - let (second_task, second_step, _second_lease) = timeout(Duration::from_secs(1), second_recv) - .await - .unwrap() - .unwrap() - .unwrap() - .unwrap(); + let (first_task, first_step, _first_lease, _first_lock_expires_at) = + timeout(Duration::from_secs(1), first_recv) + .await + .unwrap() + .unwrap() + .unwrap() + .unwrap(); + let (second_task, second_step, _second_lease, _second_lock_expires_at) = + timeout(Duration::from_secs(1), second_recv) + .await + .unwrap() + .unwrap() + .unwrap() + .unwrap(); assert!(matches!(first_step, TestTask::Noop(Noop))); assert!(matches!(second_step, TestTask::Noop(Noop)));