diff --git a/airbyte_cdk/sources/declarative/incremental/concurrent_partition_cursor.py b/airbyte_cdk/sources/declarative/incremental/concurrent_partition_cursor.py index ab667c655..ddcba0470 100644 --- a/airbyte_cdk/sources/declarative/incremental/concurrent_partition_cursor.py +++ b/airbyte_cdk/sources/declarative/incremental/concurrent_partition_cursor.py @@ -5,6 +5,7 @@ import copy import logging import threading +import time from collections import OrderedDict from copy import deepcopy from datetime import timedelta @@ -58,7 +59,8 @@ class ConcurrentPerPartitionCursor(Cursor): CurrentPerPartitionCursor expects the state of the ConcurrentCursor to follow the format {cursor_field: cursor_value}. """ - DEFAULT_MAX_PARTITIONS_NUMBER = 10000 + DEFAULT_MAX_PARTITIONS_NUMBER = 25_000 + SWITCH_TO_GLOBAL_LIMIT = 10_000 _NO_STATE: Mapping[str, Any] = {} _NO_CURSOR_STATE: Mapping[str, Any] = {} _GLOBAL_STATE_KEY = "state" @@ -99,9 +101,11 @@ def __init__( self._new_global_cursor: Optional[StreamState] = None self._lookback_window: int = 0 self._parent_state: Optional[StreamState] = None - self._over_limit: int = 0 + self._number_of_partitions: int = 0 self._use_global_cursor: bool = False self._partition_serializer = PerPartitionKeySerializer() + # Track the last time a state message was emitted + self._last_emission_time: float = 0.0 self._set_initial_state(stream_state) @@ -141,21 +145,16 @@ def close_partition(self, partition: Partition) -> None: raise ValueError("stream_slice cannot be None") partition_key = self._to_partition_key(stream_slice.partition) - self._cursor_per_partition[partition_key].close_partition(partition=partition) with self._lock: self._semaphore_per_partition[partition_key].acquire() - cursor = self._cursor_per_partition[partition_key] - if ( - partition_key in self._finished_partitions - and self._semaphore_per_partition[partition_key]._value == 0 - ): + if not self._use_global_cursor: + self._cursor_per_partition[partition_key].close_partition(partition=partition) + cursor = self._cursor_per_partition[partition_key] if ( - self._new_global_cursor is None - or self._new_global_cursor[self.cursor_field.cursor_field_key] - < cursor.state[self.cursor_field.cursor_field_key] + partition_key in self._finished_partitions + and self._semaphore_per_partition[partition_key]._value == 0 ): - self._new_global_cursor = copy.deepcopy(cursor.state) - if not self._use_global_cursor: + self._update_global_cursor(cursor.state[self.cursor_field.cursor_field_key]) self._emit_state_message() def ensure_at_least_one_state_emitted(self) -> None: @@ -169,9 +168,23 @@ def ensure_at_least_one_state_emitted(self) -> None: self._global_cursor = self._new_global_cursor self._lookback_window = self._timer.finish() self._parent_state = self._partition_router.get_stream_state() - self._emit_state_message() + self._emit_state_message(throttle=False) - def _emit_state_message(self) -> None: + def _throttle_state_message(self) -> Optional[float]: + """ + Throttles the state message emission to once every 60 seconds. + """ + current_time = time.time() + if current_time - self._last_emission_time <= 60: + return None + return current_time + + def _emit_state_message(self, throttle: bool = True) -> None: + if throttle: + current_time = self._throttle_state_message() + if current_time is None: + return + self._last_emission_time = current_time self._connector_state_manager.update_state_for_stream( self._stream_name, self._stream_namespace, @@ -202,6 +215,7 @@ def _generate_slices_from_partition(self, partition: StreamSlice) -> Iterable[St self._lookback_window if self._global_cursor else 0, ) with self._lock: + self._number_of_partitions += 1 self._cursor_per_partition[self._to_partition_key(partition.partition)] = cursor self._semaphore_per_partition[self._to_partition_key(partition.partition)] = ( threading.Semaphore(0) @@ -232,9 +246,15 @@ def _ensure_partition_limit(self) -> None: - Logs a warning each time a partition is removed, indicating whether it was finished or removed due to being the oldest. """ + if not self._use_global_cursor and self.limit_reached(): + logger.info( + f"Exceeded the 'SWITCH_TO_GLOBAL_LIMIT' of {self.SWITCH_TO_GLOBAL_LIMIT}. " + f"Switching to global cursor for {self._stream_name}." + ) + self._use_global_cursor = True + with self._lock: while len(self._cursor_per_partition) > self.DEFAULT_MAX_PARTITIONS_NUMBER - 1: - self._over_limit += 1 # Try removing finished partitions first for partition_key in list(self._cursor_per_partition.keys()): if ( @@ -245,7 +265,7 @@ def _ensure_partition_limit(self) -> None: partition_key ) # Remove the oldest partition logger.warning( - f"The maximum number of partitions has been reached. Dropping the oldest finished partition: {oldest_partition}. Over limit: {self._over_limit}." + f"The maximum number of partitions has been reached. Dropping the oldest finished partition: {oldest_partition}. Over limit: {self._number_of_partitions - self.DEFAULT_MAX_PARTITIONS_NUMBER}." ) break else: @@ -254,7 +274,7 @@ def _ensure_partition_limit(self) -> None: 1 ] # Remove the oldest partition logger.warning( - f"The maximum number of partitions has been reached. Dropping the oldest partition: {oldest_partition}. Over limit: {self._over_limit}." + f"The maximum number of partitions has been reached. Dropping the oldest partition: {oldest_partition}. Over limit: {self._number_of_partitions - self.DEFAULT_MAX_PARTITIONS_NUMBER}." ) def _set_initial_state(self, stream_state: StreamState) -> None: @@ -314,6 +334,7 @@ def _set_initial_state(self, stream_state: StreamState) -> None: self._lookback_window = int(stream_state.get("lookback_window", 0)) for state in stream_state.get(self._PERPARTITION_STATE_KEY, []): + self._number_of_partitions += 1 self._cursor_per_partition[self._to_partition_key(state["partition"])] = ( self._create_cursor(state["cursor"]) ) @@ -354,16 +375,26 @@ def _set_global_state(self, stream_state: Mapping[str, Any]) -> None: self._new_global_cursor = deepcopy(fixed_global_state) def observe(self, record: Record) -> None: - if not self._use_global_cursor and self.limit_reached(): - self._use_global_cursor = True - if not record.associated_slice: raise ValueError( "Invalid state as stream slices that are emitted should refer to an existing cursor" ) - self._cursor_per_partition[ - self._to_partition_key(record.associated_slice.partition) - ].observe(record) + + record_cursor = self._connector_state_converter.output_format( + self._connector_state_converter.parse_value(self._cursor_field.extract_value(record)) + ) + self._update_global_cursor(record_cursor) + if not self._use_global_cursor: + self._cursor_per_partition[ + self._to_partition_key(record.associated_slice.partition) + ].observe(record) + + def _update_global_cursor(self, value: Any) -> None: + if ( + self._new_global_cursor is None + or self._new_global_cursor[self.cursor_field.cursor_field_key] < value + ): + self._new_global_cursor = {self.cursor_field.cursor_field_key: copy.deepcopy(value)} def _to_partition_key(self, partition: Mapping[str, Any]) -> str: return self._partition_serializer.to_partition_key(partition) @@ -397,4 +428,4 @@ def _get_cursor(self, record: Record) -> ConcurrentCursor: return cursor def limit_reached(self) -> bool: - return self._over_limit > self.DEFAULT_MAX_PARTITIONS_NUMBER + return self._number_of_partitions > self.SWITCH_TO_GLOBAL_LIMIT diff --git a/unit_tests/sources/declarative/incremental/test_concurrent_perpartitioncursor.py b/unit_tests/sources/declarative/incremental/test_concurrent_perpartitioncursor.py index ef06676f5..767d24874 100644 --- a/unit_tests/sources/declarative/incremental/test_concurrent_perpartitioncursor.py +++ b/unit_tests/sources/declarative/incremental/test_concurrent_perpartitioncursor.py @@ -3,6 +3,7 @@ from copy import deepcopy from datetime import datetime, timedelta from typing import Any, List, Mapping, MutableMapping, Optional, Union +from unittest.mock import MagicMock, patch from urllib.parse import unquote import pytest @@ -18,6 +19,7 @@ from airbyte_cdk.sources.declarative.concurrent_declarative_source import ( ConcurrentDeclarativeSource, ) +from airbyte_cdk.sources.declarative.incremental import ConcurrentPerPartitionCursor from airbyte_cdk.test.catalog_builder import CatalogBuilder, ConfiguredAirbyteStreamBuilder from airbyte_cdk.test.entrypoint_wrapper import EntrypointOutput, read @@ -1181,14 +1183,18 @@ def test_incremental_parent_state( initial_state, expected_state, ): - run_incremental_parent_state_test( - manifest, - mock_requests, - expected_records, - num_intermediate_states, - initial_state, - [expected_state], - ) + # Patch `_throttle_state_message` so it always returns a float (indicating "no throttle") + with patch.object( + ConcurrentPerPartitionCursor, "_throttle_state_message", return_value=9999999.0 + ): + run_incremental_parent_state_test( + manifest, + mock_requests, + expected_records, + num_intermediate_states, + initial_state, + [expected_state], + ) STATE_MIGRATION_EXPECTED_STATE = { @@ -2967,3 +2973,47 @@ def test_incremental_substream_request_options_provider( expected_records, expected_state, ) + + +def test_state_throttling(mocker): + """ + Verifies that _emit_state_message does not emit a new state if less than 60s + have passed since last emission, and does emit once 60s or more have passed. + """ + cursor = ConcurrentPerPartitionCursor( + cursor_factory=MagicMock(), + partition_router=MagicMock(), + stream_name="test_stream", + stream_namespace=None, + stream_state={}, + message_repository=MagicMock(), + connector_state_manager=MagicMock(), + connector_state_converter=MagicMock(), + cursor_field=MagicMock(), + ) + + mock_connector_manager = cursor._connector_state_manager + mock_repo = cursor._message_repository + + # Set the last emission time to "0" so we can control offset from that + cursor._last_emission_time = 0 + + mock_time = mocker.patch("time.time") + + # First attempt: only 10 seconds passed => NO emission + mock_time.return_value = 10 + cursor._emit_state_message() + mock_connector_manager.update_state_for_stream.assert_not_called() + mock_repo.emit_message.assert_not_called() + + # Second attempt: 30 seconds passed => still NO emission + mock_time.return_value = 30 + cursor._emit_state_message() + mock_connector_manager.update_state_for_stream.assert_not_called() + mock_repo.emit_message.assert_not_called() + + # Advance time: 70 seconds => exceed 60s => MUST emit + mock_time.return_value = 70 + cursor._emit_state_message() + mock_connector_manager.update_state_for_stream.assert_called_once() + mock_repo.emit_message.assert_called_once()