diff --git a/airbyte_cdk/connector_builder/connector_builder_handler.py b/airbyte_cdk/connector_builder/connector_builder_handler.py index e6c9a3f3f..56da5f848 100644 --- a/airbyte_cdk/connector_builder/connector_builder_handler.py +++ b/airbyte_cdk/connector_builder/connector_builder_handler.py @@ -2,10 +2,11 @@ # Copyright (c) 2023 Airbyte, Inc., all rights reserved. # -import dataclasses + +from dataclasses import asdict, dataclass, field from typing import Any, List, Mapping -from airbyte_cdk.connector_builder.message_grouper import MessageGrouper +from airbyte_cdk.connector_builder.test_reader import TestReader from airbyte_cdk.models import ( AirbyteMessage, AirbyteRecordMessage, @@ -32,11 +33,11 @@ MAX_RECORDS_KEY = "max_records" -@dataclasses.dataclass +@dataclass class TestReadLimits: - max_records: int = dataclasses.field(default=DEFAULT_MAXIMUM_RECORDS) - max_pages_per_slice: int = dataclasses.field(default=DEFAULT_MAXIMUM_NUMBER_OF_PAGES_PER_SLICE) - max_slices: int = dataclasses.field(default=DEFAULT_MAXIMUM_NUMBER_OF_SLICES) + max_records: int = field(default=DEFAULT_MAXIMUM_RECORDS) + max_pages_per_slice: int = field(default=DEFAULT_MAXIMUM_NUMBER_OF_PAGES_PER_SLICE) + max_slices: int = field(default=DEFAULT_MAXIMUM_NUMBER_OF_SLICES) def get_limits(config: Mapping[str, Any]) -> TestReadLimits: @@ -73,17 +74,20 @@ def read_stream( limits: TestReadLimits, ) -> AirbyteMessage: try: - handler = MessageGrouper(limits.max_pages_per_slice, limits.max_slices, limits.max_records) - stream_name = configured_catalog.streams[ - 0 - ].stream.name # The connector builder only supports a single stream - stream_read = handler.get_message_groups( + test_read_handler = TestReader( + limits.max_pages_per_slice, limits.max_slices, limits.max_records + ) + # The connector builder only supports a single stream + stream_name = configured_catalog.streams[0].stream.name + + stream_read = test_read_handler.run_test_read( source, config, configured_catalog, state, limits.max_records ) + return AirbyteMessage( type=MessageType.RECORD, record=AirbyteRecordMessage( - data=dataclasses.asdict(stream_read), stream=stream_name, emitted_at=_emitted_at() + data=asdict(stream_read), stream=stream_name, emitted_at=_emitted_at() ), ) except Exception as exc: diff --git a/airbyte_cdk/connector_builder/message_grouper.py b/airbyte_cdk/connector_builder/message_grouper.py deleted file mode 100644 index ce43afab8..000000000 --- a/airbyte_cdk/connector_builder/message_grouper.py +++ /dev/null @@ -1,448 +0,0 @@ -# -# Copyright (c) 2023 Airbyte, Inc., all rights reserved. -# - -import json -import logging -from copy import deepcopy -from json import JSONDecodeError -from typing import Any, Dict, Iterable, Iterator, List, Mapping, Optional, Union - -from airbyte_cdk.connector_builder.models import ( - AuxiliaryRequest, - HttpRequest, - HttpResponse, - LogMessage, - StreamRead, - StreamReadPages, - StreamReadSlices, -) -from airbyte_cdk.entrypoint import AirbyteEntrypoint -from airbyte_cdk.models import ( - AirbyteControlMessage, - AirbyteLogMessage, - AirbyteMessage, - AirbyteStateMessage, - AirbyteTraceMessage, - ConfiguredAirbyteCatalog, - OrchestratorType, - TraceType, -) -from airbyte_cdk.models import Type as MessageType -from airbyte_cdk.sources.declarative.declarative_source import DeclarativeSource -from airbyte_cdk.sources.utils.slice_logger import SliceLogger -from airbyte_cdk.sources.utils.types import JsonType -from airbyte_cdk.utils import AirbyteTracedException -from airbyte_cdk.utils.datetime_format_inferrer import DatetimeFormatInferrer -from airbyte_cdk.utils.schema_inferrer import SchemaInferrer, SchemaValidationException - - -class MessageGrouper: - logger = logging.getLogger("airbyte.connector-builder") - - def __init__(self, max_pages_per_slice: int, max_slices: int, max_record_limit: int = 1000): - self._max_pages_per_slice = max_pages_per_slice - self._max_slices = max_slices - self._max_record_limit = max_record_limit - - def _pk_to_nested_and_composite_field( - self, field: Optional[Union[str, List[str], List[List[str]]]] - ) -> List[List[str]]: - if not field: - return [[]] - - if isinstance(field, str): - return [[field]] - - is_composite_key = isinstance(field[0], str) - if is_composite_key: - return [[i] for i in field] # type: ignore # the type of field is expected to be List[str] here - - return field # type: ignore # the type of field is expected to be List[List[str]] here - - def _cursor_field_to_nested_and_composite_field( - self, field: Union[str, List[str]] - ) -> List[List[str]]: - if not field: - return [[]] - - if isinstance(field, str): - return [[field]] - - is_nested_key = isinstance(field[0], str) - if is_nested_key: - return [field] - - raise ValueError(f"Unknown type for cursor field `{field}") - - def get_message_groups( - self, - source: DeclarativeSource, - config: Mapping[str, Any], - configured_catalog: ConfiguredAirbyteCatalog, - state: List[AirbyteStateMessage], - record_limit: Optional[int] = None, - ) -> StreamRead: - if record_limit is not None and not (1 <= record_limit <= self._max_record_limit): - raise ValueError( - f"Record limit must be between 1 and {self._max_record_limit}. Got {record_limit}" - ) - stream = source.streams(config)[ - 0 - ] # The connector builder currently only supports reading from a single stream at a time - schema_inferrer = SchemaInferrer( - self._pk_to_nested_and_composite_field(stream.primary_key), - self._cursor_field_to_nested_and_composite_field(stream.cursor_field), - ) - datetime_format_inferrer = DatetimeFormatInferrer() - - if record_limit is None: - record_limit = self._max_record_limit - else: - record_limit = min(record_limit, self._max_record_limit) - - slices = [] - log_messages = [] - latest_config_update: AirbyteControlMessage = None - auxiliary_requests = [] - for message_group in self._get_message_groups( - self._read_stream(source, config, configured_catalog, state), - schema_inferrer, - datetime_format_inferrer, - record_limit, - ): - if isinstance(message_group, AirbyteLogMessage): - log_messages.append( - LogMessage( - **{"message": message_group.message, "level": message_group.level.value} - ) - ) - elif isinstance(message_group, AirbyteTraceMessage): - if message_group.type == TraceType.ERROR: - log_messages.append( - LogMessage( - **{ - "message": message_group.error.message, - "level": "ERROR", - "internal_message": message_group.error.internal_message, - "stacktrace": message_group.error.stack_trace, - } - ) - ) - elif isinstance(message_group, AirbyteControlMessage): - if ( - not latest_config_update - or latest_config_update.emitted_at <= message_group.emitted_at - ): - latest_config_update = message_group - elif isinstance(message_group, AuxiliaryRequest): - auxiliary_requests.append(message_group) - elif isinstance(message_group, StreamReadSlices): - slices.append(message_group) - else: - raise ValueError(f"Unknown message group type: {type(message_group)}") - - try: - # The connector builder currently only supports reading from a single stream at a time - configured_stream = configured_catalog.streams[0] - schema = schema_inferrer.get_stream_schema(configured_stream.stream.name) - except SchemaValidationException as exception: - for validation_error in exception.validation_errors: - log_messages.append(LogMessage(validation_error, "ERROR")) - schema = exception.schema - - return StreamRead( - logs=log_messages, - slices=slices, - test_read_limit_reached=self._has_reached_limit(slices), - auxiliary_requests=auxiliary_requests, - inferred_schema=schema, - latest_config_update=self._clean_config(latest_config_update.connectorConfig.config) - if latest_config_update - else None, - inferred_datetime_formats=datetime_format_inferrer.get_inferred_datetime_formats(), - ) - - def _get_message_groups( - self, - messages: Iterator[AirbyteMessage], - schema_inferrer: SchemaInferrer, - datetime_format_inferrer: DatetimeFormatInferrer, - limit: int, - ) -> Iterable[ - Union[ - StreamReadPages, - AirbyteControlMessage, - AirbyteLogMessage, - AirbyteTraceMessage, - AuxiliaryRequest, - ] - ]: - """ - Message groups are partitioned according to when request log messages are received. Subsequent response log messages - and record messages belong to the prior request log message and when we encounter another request, append the latest - message group, until records have been read. - - Messages received from the CDK read operation will always arrive in the following order: - {type: LOG, log: {message: "request: ..."}} - {type: LOG, log: {message: "response: ..."}} - ... 0 or more record messages - {type: RECORD, record: {data: ...}} - {type: RECORD, record: {data: ...}} - Repeats for each request/response made - - Note: The exception is that normal log messages can be received at any time which are not incorporated into grouping - """ - records_count = 0 - at_least_one_page_in_group = False - current_page_records: List[Mapping[str, Any]] = [] - current_slice_descriptor: Optional[Dict[str, Any]] = None - current_slice_pages: List[StreamReadPages] = [] - current_page_request: Optional[HttpRequest] = None - current_page_response: Optional[HttpResponse] = None - latest_state_message: Optional[Dict[str, Any]] = None - - while records_count < limit and (message := next(messages, None)): - json_object = self._parse_json(message.log) if message.type == MessageType.LOG else None - if json_object is not None and not isinstance(json_object, dict): - raise ValueError( - f"Expected log message to be a dict, got {json_object} of type {type(json_object)}" - ) - json_message: Optional[Dict[str, JsonType]] = json_object - if self._need_to_close_page(at_least_one_page_in_group, message, json_message): - self._close_page( - current_page_request, - current_page_response, - current_slice_pages, - current_page_records, - ) - current_page_request = None - current_page_response = None - - if ( - at_least_one_page_in_group - and message.type == MessageType.LOG - and message.log.message.startswith(SliceLogger.SLICE_LOG_PREFIX) # type: ignore[union-attr] # AirbyteMessage with MessageType.LOG has log.message - ): - yield StreamReadSlices( - pages=current_slice_pages, - slice_descriptor=current_slice_descriptor, - state=[latest_state_message] if latest_state_message else [], - ) - current_slice_descriptor = self._parse_slice_description(message.log.message) # type: ignore[union-attr] # AirbyteMessage with MessageType.LOG has log.message - current_slice_pages = [] - at_least_one_page_in_group = False - elif message.type == MessageType.LOG and message.log.message.startswith( # type: ignore[union-attr] # None doesn't have 'message' - SliceLogger.SLICE_LOG_PREFIX - ): - # parsing the first slice - current_slice_descriptor = self._parse_slice_description(message.log.message) # type: ignore[union-attr] # AirbyteMessage with MessageType.LOG has log.message - elif message.type == MessageType.LOG: - if json_message is not None and self._is_http_log(json_message): - if self._is_auxiliary_http_request(json_message): - airbyte_cdk = json_message.get("airbyte_cdk", {}) - if not isinstance(airbyte_cdk, dict): - raise ValueError( - f"Expected airbyte_cdk to be a dict, got {airbyte_cdk} of type {type(airbyte_cdk)}" - ) - stream = airbyte_cdk.get("stream", {}) - if not isinstance(stream, dict): - raise ValueError( - f"Expected stream to be a dict, got {stream} of type {type(stream)}" - ) - title_prefix = ( - "Parent stream: " if stream.get("is_substream", False) else "" - ) - http = json_message.get("http", {}) - if not isinstance(http, dict): - raise ValueError( - f"Expected http to be a dict, got {http} of type {type(http)}" - ) - yield AuxiliaryRequest( - title=title_prefix + str(http.get("title", None)), - description=str(http.get("description", None)), - request=self._create_request_from_log_message(json_message), - response=self._create_response_from_log_message(json_message), - ) - else: - at_least_one_page_in_group = True - current_page_request = self._create_request_from_log_message(json_message) - current_page_response = self._create_response_from_log_message(json_message) - else: - yield message.log - elif message.type == MessageType.TRACE: - if message.trace.type == TraceType.ERROR: # type: ignore[union-attr] # AirbyteMessage with MessageType.TRACE has trace.type - yield message.trace - elif message.type == MessageType.RECORD: - current_page_records.append(message.record.data) # type: ignore[arg-type, union-attr] # AirbyteMessage with MessageType.RECORD has record.data - records_count += 1 - schema_inferrer.accumulate(message.record) - datetime_format_inferrer.accumulate(message.record) - elif ( - message.type == MessageType.CONTROL - and message.control.type == OrchestratorType.CONNECTOR_CONFIG # type: ignore[union-attr] # None doesn't have 'type' - ): - yield message.control - elif message.type == MessageType.STATE: - latest_state_message = message.state # type: ignore[assignment] - else: - if current_page_request or current_page_response or current_page_records: - self._close_page( - current_page_request, - current_page_response, - current_slice_pages, - current_page_records, - ) - yield StreamReadSlices( - pages=current_slice_pages, - slice_descriptor=current_slice_descriptor, - state=[latest_state_message] if latest_state_message else [], - ) - - @staticmethod - def _need_to_close_page( - at_least_one_page_in_group: bool, - message: AirbyteMessage, - json_message: Optional[Dict[str, Any]], - ) -> bool: - return ( - at_least_one_page_in_group - and message.type == MessageType.LOG - and ( - MessageGrouper._is_page_http_request(json_message) - or message.log.message.startswith("slice:") # type: ignore[union-attr] # AirbyteMessage with MessageType.LOG has log.message - ) - ) - - @staticmethod - def _is_page_http_request(json_message: Optional[Dict[str, Any]]) -> bool: - if not json_message: - return False - else: - return MessageGrouper._is_http_log( - json_message - ) and not MessageGrouper._is_auxiliary_http_request(json_message) - - @staticmethod - def _is_http_log(message: Dict[str, JsonType]) -> bool: - return bool(message.get("http", False)) - - @staticmethod - def _is_auxiliary_http_request(message: Optional[Dict[str, Any]]) -> bool: - """ - A auxiliary request is a request that is performed and will not directly lead to record for the specific stream it is being queried. - A couple of examples are: - * OAuth authentication - * Substream slice generation - """ - if not message: - return False - - is_http = MessageGrouper._is_http_log(message) - return is_http and message.get("http", {}).get("is_auxiliary", False) - - @staticmethod - def _close_page( - current_page_request: Optional[HttpRequest], - current_page_response: Optional[HttpResponse], - current_slice_pages: List[StreamReadPages], - current_page_records: List[Mapping[str, Any]], - ) -> None: - """ - Close a page when parsing message groups - """ - current_slice_pages.append( - StreamReadPages( - request=current_page_request, - response=current_page_response, - records=deepcopy(current_page_records), # type: ignore [arg-type] - ) - ) - current_page_records.clear() - - def _read_stream( - self, - source: DeclarativeSource, - config: Mapping[str, Any], - configured_catalog: ConfiguredAirbyteCatalog, - state: List[AirbyteStateMessage], - ) -> Iterator[AirbyteMessage]: - # the generator can raise an exception - # iterate over the generated messages. if next raise an exception, catch it and yield it as an AirbyteLogMessage - try: - yield from AirbyteEntrypoint(source).read( - source.spec(self.logger), config, configured_catalog, state - ) - except AirbyteTracedException as traced_exception: - # Look for this message which indicates that it is the "final exception" raised by AbstractSource. - # If it matches, don't yield this as we don't need to show this in the Builder. - # This is somewhat brittle as it relies on the message string, but if they drift then the worst case - # is that this message will be shown in the Builder. - if ( - traced_exception.message is not None - and "During the sync, the following streams did not sync successfully" - in traced_exception.message - ): - return - yield traced_exception.as_airbyte_message() - except Exception as e: - error_message = f"{e.args[0] if len(e.args) > 0 else str(e)}" - yield AirbyteTracedException.from_exception( - e, message=error_message - ).as_airbyte_message() - - @staticmethod - def _parse_json(log_message: AirbyteLogMessage) -> JsonType: - # TODO: As a temporary stopgap, the CDK emits request/response data as a log message string. Ideally this should come in the - # form of a custom message object defined in the Airbyte protocol, but this unblocks us in the immediate while the - # protocol change is worked on. - try: - json_object: JsonType = json.loads(log_message.message) - return json_object - except JSONDecodeError: - return None - - @staticmethod - def _create_request_from_log_message(json_http_message: Dict[str, Any]) -> HttpRequest: - url = json_http_message.get("url", {}).get("full", "") - request = json_http_message.get("http", {}).get("request", {}) - return HttpRequest( - url=url, - http_method=request.get("method", ""), - headers=request.get("headers"), - body=request.get("body", {}).get("content", ""), - ) - - @staticmethod - def _create_response_from_log_message(json_http_message: Dict[str, Any]) -> HttpResponse: - response = json_http_message.get("http", {}).get("response", {}) - body = response.get("body", {}).get("content", "") - return HttpResponse( - status=response.get("status_code"), body=body, headers=response.get("headers") - ) - - def _has_reached_limit(self, slices: List[StreamReadSlices]) -> bool: - if len(slices) >= self._max_slices: - return True - - record_count = 0 - - for _slice in slices: - if len(_slice.pages) >= self._max_pages_per_slice: - return True - for page in _slice.pages: - record_count += len(page.records) - if record_count >= self._max_record_limit: - return True - return False - - def _parse_slice_description(self, log_message: str) -> Dict[str, Any]: - return json.loads(log_message.replace(SliceLogger.SLICE_LOG_PREFIX, "", 1)) # type: ignore - - @staticmethod - def _clean_config(config: Dict[str, Any]) -> Dict[str, Any]: - cleaned_config = deepcopy(config) - for key in config.keys(): - if key.startswith("__"): - del cleaned_config[key] - return cleaned_config diff --git a/airbyte_cdk/connector_builder/test_reader/__init__.py b/airbyte_cdk/connector_builder/test_reader/__init__.py new file mode 100644 index 000000000..5159c657c --- /dev/null +++ b/airbyte_cdk/connector_builder/test_reader/__init__.py @@ -0,0 +1,7 @@ +# +# Copyright (c) 2023 Airbyte, Inc., all rights reserved. +# + +from .reader import TestReader + +__all__ = ("TestReader",) diff --git a/airbyte_cdk/connector_builder/test_reader/helpers.py b/airbyte_cdk/connector_builder/test_reader/helpers.py new file mode 100644 index 000000000..21b75c134 --- /dev/null +++ b/airbyte_cdk/connector_builder/test_reader/helpers.py @@ -0,0 +1,591 @@ +# +# Copyright (c) 2023 Airbyte, Inc., all rights reserved. +# + +import json +from copy import deepcopy +from json import JSONDecodeError +from typing import Any, Dict, List, Mapping, Optional + +from airbyte_cdk.connector_builder.models import ( + AuxiliaryRequest, + HttpRequest, + HttpResponse, + StreamReadPages, + StreamReadSlices, +) +from airbyte_cdk.models import ( + AirbyteLogMessage, + AirbyteMessage, + OrchestratorType, + TraceType, +) +from airbyte_cdk.models import Type as MessageType +from airbyte_cdk.sources.utils.slice_logger import SliceLogger +from airbyte_cdk.sources.utils.types import JsonType +from airbyte_cdk.utils.datetime_format_inferrer import DatetimeFormatInferrer +from airbyte_cdk.utils.schema_inferrer import ( + SchemaInferrer, +) + +from .types import LOG_MESSAGES_OUTPUT_TYPE + +# ------- +# Parsers +# ------- + + +def airbyte_message_to_json(message: AirbyteMessage) -> Optional[Dict[str, JsonType]]: + """ + Converts an AirbyteMessage to a JSON dictionary if its type is LOG. + + This function attempts to parse the 'log' field of the given AirbyteMessage when its type is MessageType.LOG. + If the parsed JSON object exists but is not a dictionary, a ValueError is raised. If the message is not of type LOG, + the function returns None. + + Parameters: + message (AirbyteMessage): The AirbyteMessage instance containing the log data. + + Returns: + Optional[Dict[str, JsonType]]: The parsed log message as a dictionary if the message type is LOG, otherwise None. + + Raises: + ValueError: If the parsed log message is not a dictionary. + """ + if is_log_message(message): + json_object = parse_json(message.log) # type: ignore + + if json_object is not None and not isinstance(json_object, dict): + raise ValueError( + f"Expected log message to be a dict, got {json_object} of type {type(json_object)}" + ) + + return json_object + return None + + +def clean_config(config: Dict[str, Any]) -> Dict[str, Any]: + """ + Cleans the configuration dictionary by removing all keys that start with a double underscore. + + This function creates a deep copy of the provided configuration dictionary and iterates + over its keys, deleting any key that begins with '__'. This is useful for filtering out + internal or meta-data fields that are not meant to be part of the final configuration. + + Args: + config (Dict[str, Any]): The input configuration dictionary containing various key-value pairs. + + Returns: + Dict[str, Any]: A deep copy of the original configuration with keys starting with '__' removed. + """ + cleaned_config = deepcopy(config) + for key in config.keys(): + if key.startswith("__"): + del cleaned_config[key] + return cleaned_config + + +def create_request_from_log_message(json_http_message: Dict[str, Any]) -> HttpRequest: + """ + Creates an HttpRequest object from the provided JSON-formatted log message. + + This function parses a dictionary that represents a logged HTTP message, extracting the URL, HTTP method, + headers, and body from nested dictionary structures. It is assumed that the expected keys and nested keys exist + or default values are used. + + Parameters: + json_http_message (Dict[str, Any]): + A dictionary containing log message details with the following expected structure: + { + "url": { + "full": "" + }, + "http": { + "request": { + "method": "", + "headers": , + "body": { + "content": "" + } + } + } + } + + Returns: + HttpRequest: + An HttpRequest instance initialized with: + - url: Extracted from json_http_message["url"]["full"], defaults to an empty string if missing. + - http_method: Extracted from json_http_message["http"]["request"]["method"], defaults to an empty string if missing. + - headers: Extracted from json_http_message["http"]["request"]["headers"]. + - body: Extracted from json_http_message["http"]["request"]["body"]["content"], defaults to an empty string if missing. + """ + url = json_http_message.get("url", {}).get("full", "") + request = json_http_message.get("http", {}).get("request", {}) + return HttpRequest( + url=url, + http_method=request.get("method", ""), + headers=request.get("headers"), + body=request.get("body", {}).get("content", ""), + ) + + +def create_response_from_log_message(json_http_message: Dict[str, Any]) -> HttpResponse: + """ + Generate an HttpResponse instance from a JSON log message containing HTTP response details. + + Parameters: + json_http_message (Dict[str, Any]): A dictionary representing a JSON-encoded HTTP message. + It should include an "http" key with a nested "response" dictionary that contains: + - "status_code": The HTTP status code. + - "body": A dictionary with a "content" key for the response body. + - "headers": The HTTP response headers. + + Returns: + HttpResponse: An HttpResponse object constructed from the extracted status code, body content, and headers. + """ + response = json_http_message.get("http", {}).get("response", {}) + body = response.get("body", {}).get("content", "") + return HttpResponse( + status=response.get("status_code"), body=body, headers=response.get("headers") + ) + + +def parse_json(log_message: AirbyteLogMessage) -> JsonType: + """ + Parse and extract a JSON object from an Airbyte log message. + + This function attempts to decode the JSON string contained in the message field + of the provided AirbyteLogMessage instance. If the decoding process fails due to + malformed JSON, the function returns None. + + Args: + log_message (AirbyteLogMessage): A log message object containing a JSON-formatted string in its 'message' attribute. + + Returns: + JsonType: The parsed JSON object if decoding is successful; otherwise, None. + """ + # TODO: As a temporary stopgap, the CDK emits request/response data as a log message string. Ideally this should come in the + # form of a custom message object defined in the Airbyte protocol, but this unblocks us in the immediate while the + # protocol change is worked on. + try: + json_object: JsonType = json.loads(log_message.message) + return json_object + except JSONDecodeError: + return None + + +def parse_slice_description(log_message: str) -> Dict[str, Any]: + """ + Parses a log message containing a JSON payload and returns it as a dictionary. + + The function removes a predefined logging prefix (defined by the constant + SliceLogger.SLICE_LOG_PREFIX) from the beginning of the log message and then + parses the remaining string as JSON. + + Parameters: + log_message (str): The log message string that includes the JSON payload, + prefixed by SliceLogger.SLICE_LOG_PREFIX. + + Returns: + Dict[str, Any]: A dictionary resulting from parsing the modified log message. + + Raises: + json.JSONDecodeError: If the log message (after prefix removal) is not a valid JSON. + """ + return json.loads(log_message.replace(SliceLogger.SLICE_LOG_PREFIX, "", 1)) # type: ignore + + +# ------- +# Conditions +# ------- + + +def should_close_page( + at_least_one_page_in_group: bool, + message: AirbyteMessage, + json_message: Optional[Dict[str, Any]], +) -> bool: + """ + Determines whether a page should be closed based on its content and state. + + Args: + at_least_one_page_in_group (bool): Indicates if there is at least one page in the group. + message (AirbyteMessage): The message object containing details such as type and log information. + json_message (Optional[Dict[str, Any]]): A JSON representation of the message that may provide additional context, + particularly for HTTP requests. + + Returns: + bool: True if all of the following conditions are met: + - There is at least one page in the group. + - The message type is MessageType.LOG. + - Either the JSON message corresponds to a page HTTP request (as determined by _is_page_http_request) + or the log message starts with "slice:". + Otherwise, returns False. + """ + return ( + at_least_one_page_in_group + and is_log_message(message) + and ( + is_page_http_request(json_message) or message.log.message.startswith("slice:") # type: ignore[union-attr] # AirbyteMessage with MessageType.LOG has log.message + ) + ) + + +def should_process_slice_descriptor(message: AirbyteMessage) -> bool: + """ + Determines whether the given AirbyteMessage should be processed as a slice descriptor. + + This function checks if the message is a log message and if its log content starts with the + specific slice log prefix. It is used to filter out messages that represent slice descriptors + for further processing. + + Parameters: + message (AirbyteMessage): The message to evaluate. + + Returns: + bool: True if the message is a log message whose log message starts with the predefined + slice log prefix, indicating it is a slice descriptor; otherwise, False. + """ + return is_log_message(message) and message.log.message.startswith( # type: ignore + SliceLogger.SLICE_LOG_PREFIX + ) + + +def should_close_page_for_slice(at_least_one_page_in_group: bool, message: AirbyteMessage) -> bool: + """ + Determines whether the current slice page should be closed. + + This function checks if there is at least one page in the current group and if further processing + of the slice descriptor is required based on the provided Airbyte message. + + Args: + at_least_one_page_in_group (bool): Indicates if at least one page already exists in the slice group. + message (AirbyteMessage): The message containing the slice descriptor information to be evaluated. + + Returns: + bool: True if both conditions are met and the slice page needs to be closed; otherwise, False. + """ + return at_least_one_page_in_group and should_process_slice_descriptor(message) + + +def is_page_http_request(json_message: Optional[Dict[str, Any]]) -> bool: + """ + Determines whether a given JSON message represents a page HTTP request. + + This function checks if the provided JSON message qualifies as a page HTTP request by verifying that: + 1. The JSON message exists. + 2. The JSON message is recognized as a valid HTTP log. + 3. The JSON message is not classified as an auxiliary HTTP request. + + Args: + json_message (Optional[Dict[str, Any]]): A dictionary containing the JSON message to be evaluated. + If None or empty, the message will not be considered a page HTTP request. + + Returns: + bool: True if the JSON message is a valid HTTP log and not an auxiliary HTTP request; otherwise, False. + """ + if not json_message: + return False + else: + return is_http_log(json_message) and not is_auxiliary_http_request(json_message) + + +def is_http_log(message: Dict[str, JsonType]) -> bool: + """ + Determine if the provided log message represents an HTTP log. + + This function inspects the given message dictionary for the presence of the "http" key. + If the key exists and its value is truthy, the function interprets the message as an HTTP log. + + Args: + message (Dict[str, JsonType]): A dictionary containing log data. It may include an "http" key + whose truthy value indicates an HTTP log. + + Returns: + bool: True if the message is an HTTP log (i.e., "http" exists and is truthy); otherwise, False. + """ + return bool(message.get("http", False)) + + +def is_auxiliary_http_request(message: Optional[Dict[str, Any]]) -> bool: + """ + Determines if the provided message represents an auxiliary HTTP request. + + A auxiliary request is a request that is performed and will not directly lead to record for the specific stream it is being queried. + + A couple of examples are: + * OAuth authentication + * Substream slice generation + + Parameters: + message (Optional[Dict[str, Any]]): A dictionary representing a log message for an HTTP request. + The dictionary may contain nested keys indicating whether the request is auxiliary. + + Returns: + bool: True if the message is an HTTP log and indicates an auxiliary request; otherwise, False. + """ + if not message: + return False + + return is_http_log(message) and message.get("http", {}).get("is_auxiliary", False) + + +def is_log_message(message: AirbyteMessage) -> bool: + """ + Determines whether the provided message is of type LOG. + + Args: + message (AirbyteMessage): The message to evaluate. + + Returns: + bool: True if the message's type is LOG, otherwise False. + """ + return message.type == MessageType.LOG # type: ignore + + +def is_trace_with_error(message: AirbyteMessage) -> bool: + """ + Determines whether the provided AirbyteMessage is a TRACE message with an error. + + This function checks if the message's type is TRACE and that its trace component is of type ERROR. + + Parameters: + message (AirbyteMessage): The Airbyte message to be evaluated. + + Returns: + bool: True if the message is a TRACE message with an error, False otherwise. + """ + return message.type == MessageType.TRACE and message.trace.type == TraceType.ERROR # type: ignore + + +def is_record_message(message: AirbyteMessage) -> bool: + """ + Determines whether the provided Airbyte message represents a record. + + Parameters: + message (AirbyteMessage): The message instance to check. It should include a 'type' attribute that is comparable to MessageType.RECORD. + + Returns: + bool: True if the message type is RECORD, otherwise False. + """ + return message.type == MessageType.RECORD # type: ignore + + +def is_config_update_message(message: AirbyteMessage) -> bool: + """ + Determine whether the provided AirbyteMessage represents a connector configuration update. + + This function evaluates if the message is a control message and if its control type + matches that of a connector configuration update (i.e., OrchestratorType.CONNECTOR_CONFIG). + It is primarily used to filter messages related to configuration updates in the data pipeline. + + Parameters: + message (AirbyteMessage): The message object to be evaluated. + + Returns: + bool: True if the message is a connector configuration update message, False otherwise. + """ + return ( # type: ignore + message.type == MessageType.CONTROL + and message.control.type == OrchestratorType.CONNECTOR_CONFIG # type: ignore + ) + + +def is_state_message(message: AirbyteMessage) -> bool: + """ + Determines whether the provided AirbyteMessage is a state message. + + Parameters: + message (AirbyteMessage): The message to inspect. + + Returns: + bool: True if the message's type is MessageType.STATE, False otherwise. + """ + return message.type == MessageType.STATE # type: ignore + + +# ------- +# Handlers +# ------- + + +def handle_current_slice( + current_slice_pages: List[StreamReadPages], + current_slice_descriptor: Optional[Dict[str, Any]] = None, + latest_state_message: Optional[Dict[str, Any]] = None, +) -> StreamReadSlices: + """ + Handles the current slice by packaging its pages, descriptor, and state into a StreamReadSlices instance. + + Args: + current_slice_pages (List[StreamReadPages]): The pages to be included in the slice. + current_slice_descriptor (Optional[Dict[str, Any]]): Descriptor for the current slice, optional. + latest_state_message (Optional[Dict[str, Any]]): The latest state message, optional. + + Returns: + StreamReadSlices: An object containing the current slice's pages, descriptor, and state. + """ + return StreamReadSlices( + pages=current_slice_pages, + slice_descriptor=current_slice_descriptor, + state=[latest_state_message] if latest_state_message else [], + ) + + +def handle_current_page( + current_page_request: Optional[HttpRequest], + current_page_response: Optional[HttpResponse], + current_slice_pages: List[StreamReadPages], + current_page_records: List[Mapping[str, Any]], +) -> tuple[None, None]: + """ + Closes the current page by appending its request, response, and records + to the list of pages and clearing the current page records. + + Args: + current_page_request (Optional[HttpRequest]): The HTTP request associated with the current page. + current_page_response (Optional[HttpResponse]): The HTTP response associated with the current page. + current_slice_pages (List[StreamReadPages]): A list to append the current page information. + current_page_records (List[Mapping[str, Any]]): The records of the current page to be cleared after processing. + + Returns: + tuple[None, None]: A tuple indicating that no values are returned. + """ + + current_slice_pages.append( + StreamReadPages( + request=current_page_request, + response=current_page_response, + records=deepcopy(current_page_records), # type: ignore [arg-type] + ) + ) + current_page_records.clear() + + return None, None + + +def handle_auxiliary_request(json_message: Dict[str, JsonType]) -> AuxiliaryRequest: + """ + Parses the provided JSON message and constructs an AuxiliaryRequest object by extracting + relevant fields from nested dictionaries. + + This function retrieves and validates the "airbyte_cdk", "stream", and "http" dictionaries + from the JSON message. It raises a ValueError if any of these are not of type dict. A title + is dynamically created by checking if the stream is a substream and then combining a prefix + with the "title" field from the "http" dictionary. The function also uses helper functions + to generate the request and response portions of the AuxiliaryRequest. + + Parameters: + json_message (Dict[str, JsonType]): A dictionary representing the JSON log message containing + auxiliary request details. + + Returns: + AuxiliaryRequest: An object containing the formatted title, description, request, and response + extracted from the JSON message. + + Raises: + ValueError: If any of the "airbyte_cdk", "stream", or "http" fields is not a dictionary. + """ + airbyte_cdk = json_message.get("airbyte_cdk", {}) + + if not isinstance(airbyte_cdk, dict): + raise ValueError( + f"Expected airbyte_cdk to be a dict, got {airbyte_cdk} of type {type(airbyte_cdk)}" + ) + + stream = airbyte_cdk.get("stream", {}) + + if not isinstance(stream, dict): + raise ValueError(f"Expected stream to be a dict, got {stream} of type {type(stream)}") + + title_prefix = "Parent stream: " if stream.get("is_substream", False) else "" + http = json_message.get("http", {}) + + if not isinstance(http, dict): + raise ValueError(f"Expected http to be a dict, got {http} of type {type(http)}") + + return AuxiliaryRequest( + title=title_prefix + str(http.get("title", None)), + description=str(http.get("description", None)), + request=create_request_from_log_message(json_message), + response=create_response_from_log_message(json_message), + ) + + +def handle_log_message( + message: AirbyteMessage, + json_message: Dict[str, JsonType] | None, + at_least_one_page_in_group: bool, + current_page_request: Optional[HttpRequest], + current_page_response: Optional[HttpResponse], +) -> LOG_MESSAGES_OUTPUT_TYPE: + """ + Process a log message by handling both HTTP-specific and auxiliary log entries. + + Parameters: + message (AirbyteMessage): The original log message received. + json_message (Dict[str, JsonType] | None): A parsed JSON representation of the log message, if available. + at_least_one_page_in_group (bool): Indicates whether at least one page within the group has been processed. + current_page_request (Optional[HttpRequest]): The HTTP request object corresponding to the current page, if any. + current_page_response (Optional[HttpResponse]): The HTTP response object corresponding to the current page, if any. + + Returns: + LOG_MESSAGES_OUTPUT_TYPE: A tuple containing: + - A boolean flag that determines whether the group contains at least one page. + - An updated HttpRequest for the current page (if applicable). + - An updated HttpResponse for the current page (if applicable). + - The auxiliary log message, which might be the original HTTP log or another log field. + + Note: + If the parsed JSON message indicates an HTTP log and represents an auxiliary HTTP request, + the auxiliary log is handled via _handle_auxiliary_request. Otherwise, if the JSON log is a standard HTTP log, + the function updates the current page's request and response objects by generating them from the log message. + """ + auxiliary_request = None + log_message = None + + if json_message is not None and is_http_log(json_message): + if is_auxiliary_http_request(json_message): + auxiliary_request = handle_auxiliary_request(json_message) + else: + at_least_one_page_in_group = True + current_page_request = create_request_from_log_message(json_message) + current_page_response = create_response_from_log_message(json_message) + else: + log_message = message.log + + return ( + at_least_one_page_in_group, + current_page_request, + current_page_response, + auxiliary_request or log_message, + ) + + +def handle_record_message( + message: AirbyteMessage, + schema_inferrer: SchemaInferrer, + datetime_format_inferrer: DatetimeFormatInferrer, + records_count: int, + current_page_records: List[Mapping[str, Any]], +) -> int: + """ + Processes an Airbyte record message by updating the current batch and accumulating schema and datetime format information. + + Parameters: + message (AirbyteMessage): The Airbyte message to process. Expected to have a 'type' attribute and, if of type RECORD, a 'record' attribute containing the record data. + schema_inferrer (SchemaInferrer): An instance responsible for inferring and accumulating schema details based on the record data. + datetime_format_inferrer (DatetimeFormatInferrer): An instance responsible for inferring and accumulating datetime format information from the record data. + records_count (int): The current count of processed records. This value is incremented if the message is a record. + current_page_records (List[Mapping[str, Any]]): A list where the data of processed record messages is accumulated. + + Returns: + int: The updated record count after processing the message. + """ + if message.type == MessageType.RECORD: + current_page_records.append(message.record.data) # type: ignore + records_count += 1 + schema_inferrer.accumulate(message.record) # type: ignore + datetime_format_inferrer.accumulate(message.record) # type: ignore + + return records_count diff --git a/airbyte_cdk/connector_builder/test_reader/message_grouper.py b/airbyte_cdk/connector_builder/test_reader/message_grouper.py new file mode 100644 index 000000000..56082e202 --- /dev/null +++ b/airbyte_cdk/connector_builder/test_reader/message_grouper.py @@ -0,0 +1,160 @@ +# +# Copyright (c) 2023 Airbyte, Inc., all rights reserved. +# + + +from typing import Any, Dict, Iterator, List, Mapping, Optional + +from airbyte_cdk.connector_builder.models import ( + HttpRequest, + HttpResponse, + StreamReadPages, +) +from airbyte_cdk.models import ( + AirbyteMessage, +) +from airbyte_cdk.utils.datetime_format_inferrer import DatetimeFormatInferrer +from airbyte_cdk.utils.schema_inferrer import ( + SchemaInferrer, +) + +from .helpers import ( + airbyte_message_to_json, + handle_current_page, + handle_current_slice, + handle_log_message, + handle_record_message, + is_config_update_message, + is_log_message, + is_record_message, + is_state_message, + is_trace_with_error, + parse_slice_description, + should_close_page, + should_close_page_for_slice, + should_process_slice_descriptor, +) +from .types import MESSAGE_GROUPS + + +def get_message_groups( + messages: Iterator[AirbyteMessage], + schema_inferrer: SchemaInferrer, + datetime_format_inferrer: DatetimeFormatInferrer, + limit: int, +) -> MESSAGE_GROUPS: + """ + Processes an iterator of AirbyteMessage objects to group and yield messages based on their type and sequence. + + This function iterates over the provided messages until the number of record messages processed reaches the specified limit. + It accumulates messages into pages and slices, handling various types of messages such as log, trace (with errors), record, + configuration update, and state messages. The function makes use of helper routines to: + - Convert messages to JSON. + - Determine when to close a page or a slice. + - Parse slice descriptors. + - Handle log messages and auxiliary requests. + - Process record messages while inferring schema and datetime formats. + + Depending on the incoming message type, it may yield: + - StreamReadSlices objects when a slice is completed. + - Auxiliary HTTP requests/responses generated from log messages. + - Error trace messages if encountered. + - Configuration update messages. + + Parameters: + messages (Iterator[AirbyteMessage]): An iterator yielding AirbyteMessage instances. + schema_inferrer (SchemaInferrer): An instance used to infer and update schema based on record messages. + datetime_format_inferrer (DatetimeFormatInferrer): An instance used to infer datetime formats from record messages. + limit (int): The maximum number of record messages to process before stopping. + + Yields: + Depending on the type of message processed, one or more of the following: + - StreamReadSlices: A grouping of pages within a slice along with the slice descriptor and state. + - HttpRequest/HttpResponse: Auxiliary request/response information derived from log messages. + - TraceMessage: Error details when a trace message with errors is encountered. + - ControlMessage: Configuration update details. + + Notes: + The function depends on several helper functions (e.g., airbyte_message_to_json, should_close_page, + handle_current_page, parse_slice_description, handle_log_message, and others) to encapsulate specific behavior. + It maintains internal state for grouping pages and slices, ensuring that related messages are correctly aggregated + and yielded as complete units. + """ + + records_count = 0 + at_least_one_page_in_group = False + current_page_records: List[Mapping[str, Any]] = [] + current_slice_descriptor: Optional[Dict[str, Any]] = None + current_slice_pages: List[StreamReadPages] = [] + current_page_request: Optional[HttpRequest] = None + current_page_response: Optional[HttpResponse] = None + latest_state_message: Optional[Dict[str, Any]] = None + + while records_count < limit and (message := next(messages, None)): + json_message = airbyte_message_to_json(message) + + if should_close_page(at_least_one_page_in_group, message, json_message): + current_page_request, current_page_response = handle_current_page( + current_page_request, + current_page_response, + current_slice_pages, + current_page_records, + ) + + if should_close_page_for_slice(at_least_one_page_in_group, message): + yield handle_current_slice( + current_slice_pages, + current_slice_descriptor, + latest_state_message, + ) + current_slice_descriptor = parse_slice_description(message.log.message) # type: ignore + current_slice_pages = [] + at_least_one_page_in_group = False + elif should_process_slice_descriptor(message): + # parsing the first slice + current_slice_descriptor = parse_slice_description(message.log.message) # type: ignore + elif is_log_message(message): + ( + at_least_one_page_in_group, + current_page_request, + current_page_response, + log_or_auxiliary_request, + ) = handle_log_message( + message, + json_message, + at_least_one_page_in_group, + current_page_request, + current_page_response, + ) + if log_or_auxiliary_request: + yield log_or_auxiliary_request + elif is_trace_with_error(message): + if message.trace is not None: + yield message.trace + elif is_record_message(message): + records_count = handle_record_message( + message, + schema_inferrer, + datetime_format_inferrer, + records_count, + current_page_records, + ) + elif is_config_update_message(message): + if message.control is not None: + yield message.control + elif is_state_message(message): + latest_state_message = message.state # type: ignore + + else: + if current_page_request or current_page_response or current_page_records: + handle_current_page( + current_page_request, + current_page_response, + current_slice_pages, + current_page_records, + ) + yield handle_current_slice( + current_slice_pages, + current_slice_descriptor, + latest_state_message, + ) diff --git a/airbyte_cdk/connector_builder/test_reader/reader.py b/airbyte_cdk/connector_builder/test_reader/reader.py new file mode 100644 index 000000000..b776811eb --- /dev/null +++ b/airbyte_cdk/connector_builder/test_reader/reader.py @@ -0,0 +1,441 @@ +# +# Copyright (c) 2023 Airbyte, Inc., all rights reserved. +# + + +import logging +from typing import Any, Dict, Iterator, List, Mapping, Optional, Union + +from airbyte_cdk.connector_builder.models import ( + AuxiliaryRequest, + LogMessage, + StreamRead, + StreamReadSlices, +) +from airbyte_cdk.entrypoint import AirbyteEntrypoint +from airbyte_cdk.models import ( + AirbyteControlMessage, + AirbyteLogMessage, + AirbyteMessage, + AirbyteStateMessage, + AirbyteTraceMessage, + ConfiguredAirbyteCatalog, + TraceType, +) +from airbyte_cdk.sources.declarative.declarative_source import DeclarativeSource +from airbyte_cdk.utils import AirbyteTracedException +from airbyte_cdk.utils.datetime_format_inferrer import DatetimeFormatInferrer +from airbyte_cdk.utils.schema_inferrer import ( + SchemaInferrer, + SchemaValidationException, +) + +from .helpers import clean_config +from .message_grouper import get_message_groups +from .types import GROUPED_MESSAGES, INFERRED_SCHEMA_OUTPUT_TYPE, MESSAGE_GROUPS + + +class TestReader: + """ + A utility class for performing test reads from a declarative data source, primarily used to validate + connector configurations by performing partial stream reads. + + Initialization: + + TestReader(max_pages_per_slice: int, max_slices: int, max_record_limit: int = 1000) + Initializes a new instance of the TestReader class with limits on pages per slice, slices, and records + per read operation. + + Public Methods: + run_test_read(source, config, configured_catalog, state, record_limit=None) -> StreamRead: + + Executes a test read operation from the given declarative source. It configures and infers the schema, + processes the read messages (including logging and error handling), and returns a StreamRead object + that contains slices of data, log messages, auxiliary requests, and any inferred schema or datetime formats. + + Parameters: + source (DeclarativeSource): The data source to read from. + config (Mapping[str, Any]): Configuration parameters for the source. + configured_catalog (ConfiguredAirbyteCatalog): Catalog containing stream configuration. + state (List[AirbyteStateMessage]): Current state information for the read. + record_limit (Optional[int]): Optional override for the maximum number of records to read. + + Returns: + StreamRead: An object encapsulating logs, data slices, auxiliary requests, and inferred metadata, + along with indicators if any configured limit was reached. + + """ + + logger = logging.getLogger("airbyte.connector-builder") + + def __init__( + self, + max_pages_per_slice: int, + max_slices: int, + max_record_limit: int = 1000, + ) -> None: + self._max_pages_per_slice = max_pages_per_slice + self._max_slices = max_slices + self._max_record_limit = max_record_limit + + def run_test_read( + self, + source: DeclarativeSource, + config: Mapping[str, Any], + configured_catalog: ConfiguredAirbyteCatalog, + state: List[AirbyteStateMessage], + record_limit: Optional[int] = None, + ) -> StreamRead: + """ + Run a test read for the connector by reading from a single stream and inferring schema and datetime formats. + + Parameters: + source (DeclarativeSource): The source instance providing the streams. + config (Mapping[str, Any]): The configuration settings to use for reading. + configured_catalog (ConfiguredAirbyteCatalog): The catalog specifying the stream configuration. + state (List[AirbyteStateMessage]): A list of state messages to resume the read. + record_limit (Optional[int], optional): Maximum number of records to read. Defaults to None. + + Returns: + StreamRead: An object containing the following attributes: + - logs (List[str]): Log messages generated during the process. + - slices (List[Any]): The data slices read from the stream. + - test_read_limit_reached (bool): Indicates whether the record limit was reached. + - auxiliary_requests (Any): Any auxiliary requests generated during reading. + - inferred_schema (Any): The schema inferred from the stream data. + - latest_config_update (Any): The latest configuration update, if applicable. + - inferred_datetime_formats (Dict[str, str]): Mapping of fields to their inferred datetime formats. + """ + + record_limit = self._check_record_limit(record_limit) + # The connector builder currently only supports reading from a single stream at a time + stream = source.streams(config)[0] + schema_inferrer = SchemaInferrer( + self._pk_to_nested_and_composite_field(stream.primary_key), + self._cursor_field_to_nested_and_composite_field(stream.cursor_field), + ) + datetime_format_inferrer = DatetimeFormatInferrer() + message_group = get_message_groups( + self._read_stream(source, config, configured_catalog, state), + schema_inferrer, + datetime_format_inferrer, + record_limit, + ) + + slices, log_messages, auxiliary_requests, latest_config_update = self._categorise_groups( + message_group + ) + schema, log_messages = self._get_infered_schema( + configured_catalog, schema_inferrer, log_messages + ) + + return StreamRead( + logs=log_messages, + slices=slices, + test_read_limit_reached=self._has_reached_limit(slices), + auxiliary_requests=auxiliary_requests, + inferred_schema=schema, + latest_config_update=self._get_latest_config_update(latest_config_update), + inferred_datetime_formats=datetime_format_inferrer.get_inferred_datetime_formats(), + ) + + def _pk_to_nested_and_composite_field( + self, field: Optional[Union[str, List[str], List[List[str]]]] + ) -> List[List[str]]: + """ + Converts a primary key definition into a nested list representation. + + The function accepts a primary key that can be a single string, a list of strings, or a list of lists of strings. + It ensures that the return value is always a list of lists of strings. + + Parameters: + field (Optional[Union[str, List[str], List[List[str]]]]): + The primary key definition. This can be: + - None or an empty value: returns a list containing an empty list. + - A single string: returns a list containing one list with the string. + - A list of strings (composite key): returns a list where each key is encapsulated in its own list. + - A list of lists of strings (nested field structure): returns as is. + + Returns: + List[List[str]]: + A nested list representation of the primary key. + """ + if not field: + return [[]] + + if isinstance(field, str): + return [[field]] + + is_composite_key = isinstance(field[0], str) + if is_composite_key: + return [[i] for i in field] # type: ignore # the type of field is expected to be List[str] here + + return field # type: ignore # the type of field is expected to be List[List[str]] here + + def _cursor_field_to_nested_and_composite_field( + self, field: Union[str, List[str]] + ) -> List[List[str]]: + """ + Transforms the cursor field input into a nested list format suitable for further processing. + + This function accepts a cursor field specification, which can be either: + - A falsy value (e.g., None or an empty string), in which case it returns a list containing an empty list. + - A string representing a simple cursor field. The string is wrapped in a nested list. + - A list of strings representing a composite or nested cursor field. The list is returned wrapped in an outer list. + + Parameters: + field (Union[str, List[str]]): The cursor field specification. It can be: + - An empty or falsy value: returns [[]]. + - A string: returns [[field]]. + - A list of strings: returns [field] if the first element is a string. + + Returns: + List[List[str]]: A nested list representation of the cursor field. + + Raises: + ValueError: If the input is a list but its first element is not a string, + indicating an unsupported type for a cursor field. + """ + if not field: + return [[]] + + if isinstance(field, str): + return [[field]] + + is_nested_key = isinstance(field[0], str) + if is_nested_key: + return [field] + + raise ValueError(f"Unknown type for cursor field `{field}") + + def _check_record_limit(self, record_limit: Optional[int] = None) -> int: + """ + Checks and adjusts the provided record limit to ensure it falls within the valid range. + + If record_limit is provided, it must be between 1 and self._max_record_limit inclusive. + If record_limit is None, it defaults to self._max_record_limit. + + Args: + record_limit (Optional[int]): The requested record limit to validate. + + Returns: + int: The validated record limit. If record_limit exceeds self._max_record_limit, the maximum allowed value is used. + + Raises: + ValueError: If record_limit is provided and is not between 1 and self._max_record_limit. + """ + if record_limit is not None and not (1 <= record_limit <= self._max_record_limit): + raise ValueError( + f"Record limit must be between 1 and {self._max_record_limit}. Got {record_limit}" + ) + + if record_limit is None: + record_limit = self._max_record_limit + else: + record_limit = min(record_limit, self._max_record_limit) + + return record_limit + + def _categorise_groups(self, message_groups: MESSAGE_GROUPS) -> GROUPED_MESSAGES: + """ + Categorizes a sequence of message groups into slices, log messages, auxiliary requests, and the latest configuration update. + + This function iterates over each message group in the provided collection and processes it based on its type: + - AirbyteLogMessage: Converts the log message into a LogMessage object and appends it to the log_messages list. + - AirbyteTraceMessage with type ERROR: Extracts error details, creates a LogMessage at the "ERROR" level, and appends it. + - AirbyteControlMessage: Updates the latest_config_update if the current message is more recent. + - AuxiliaryRequest: Appends the message to the auxiliary_requests list. + - StreamReadSlices: Appends the message to the slices list. + - Any other type: Raises a ValueError indicating an unknown message group type. + + Parameters: + message_groups (MESSAGE_GROUPS): A collection of message groups to be processed. + + Returns: + GROUPED_MESSAGES: A tuple containing four elements: + - slices: A list of StreamReadSlices objects. + - log_messages: A list of LogMessage objects. + - auxiliary_requests: A list of AuxiliaryRequest objects. + - latest_config_update: The most recent AirbyteControlMessage, or None if none was processed. + + Raises: + ValueError: If a message group of an unknown type is encountered. + """ + + slices = [] + log_messages = [] + auxiliary_requests = [] + latest_config_update: Optional[AirbyteControlMessage] = None + + for message_group in message_groups: + match message_group: + case AirbyteLogMessage(): + log_messages.append( + LogMessage(message=message_group.message, level=message_group.level.value) + ) + case AirbyteTraceMessage(): + if message_group.type == TraceType.ERROR: + log_messages.append( + LogMessage( + message=message_group.error.message, # type: ignore + level="ERROR", + internal_message=message_group.error.internal_message, # type: ignore + stacktrace=message_group.error.stack_trace, # type: ignore + ) + ) + case AirbyteControlMessage(): + if ( + not latest_config_update + or latest_config_update.emitted_at <= message_group.emitted_at + ): + latest_config_update = message_group + case AuxiliaryRequest(): + auxiliary_requests.append(message_group) + case StreamReadSlices(): + slices.append(message_group) + case _: + raise ValueError(f"Unknown message group type: {type(message_group)}") + + return slices, log_messages, auxiliary_requests, latest_config_update + + def _get_infered_schema( + self, + configured_catalog: ConfiguredAirbyteCatalog, + schema_inferrer: SchemaInferrer, + log_messages: List[LogMessage], + ) -> INFERRED_SCHEMA_OUTPUT_TYPE: + """ + Retrieves the inferred schema from the given configured catalog using the provided schema inferrer. + + This function processes a single stream from the configured catalog. It attempts to obtain the stream's + schema via the schema inferrer. If a SchemaValidationException occurs, each validation error is logged in the + provided log_messages list and the partially inferred schema (from the exception) is returned. + + Parameters: + configured_catalog (ConfiguredAirbyteCatalog): The configured catalog that contains the stream metadata. + It is assumed that only one stream is defined. + schema_inferrer (SchemaInferrer): An instance responsible for inferring the schema for a given stream. + log_messages (List[LogMessage]): A list that will be appended with log messages, especially error messages + if schema validation issues arise. + + Returns: + INFERRED_SCHEMA_OUTPUT_TYPE: A tuple consisting of the inferred schema and the updated list of log messages. + """ + + try: + # The connector builder currently only supports reading from a single stream at a time + configured_stream = configured_catalog.streams[0] + schema = schema_inferrer.get_stream_schema(configured_stream.stream.name) + except SchemaValidationException as exception: + # we update the log_messages with possible errors + for validation_error in exception.validation_errors: + log_messages.append(LogMessage(validation_error, "ERROR")) + schema = exception.schema + + return schema, log_messages + + def _get_latest_config_update( + self, + latest_config_update: AirbyteControlMessage | None, + ) -> Dict[str, Any] | None: + """ + Retrieves a cleaned configuration from the latest Airbyte control message. + + This helper function extracts the configuration from the given Airbyte control message, cleans it using the internal `Parsers.clean_config` function, + and returns the resulting dictionary. If no control message is provided (i.e., latest_config_update is None), the function returns None. + + Parameters: + latest_config_update (AirbyteControlMessage | None): The control message containing the connector configuration. May be None. + + Returns: + Dict[str, Any] | None: The cleaned configuration dictionary if available; otherwise, None. + """ + + return ( + clean_config(latest_config_update.connectorConfig.config) # type: ignore + if latest_config_update + else None + ) + + def _read_stream( + self, + source: DeclarativeSource, + config: Mapping[str, Any], + configured_catalog: ConfiguredAirbyteCatalog, + state: List[AirbyteStateMessage], + ) -> Iterator[AirbyteMessage]: + """ + Reads messages from the given DeclarativeSource using an AirbyteEntrypoint. + + This method attempts to yield messages from the source's read generator. If the generator + raises an AirbyteTracedException, it checks whether the exception message indicates a non-actionable + error (e.g., a final exception from AbstractSource that should not be logged). In that case, it stops + processing without yielding the exception as a message. For other exceptions, the exception is caught, + wrapped into an AirbyteTracedException, and yielded as an AirbyteMessage. + + Parameters: + source (DeclarativeSource): The source object that provides data reading logic. + config (Mapping[str, Any]): The configuration dictionary for the source. + configured_catalog (ConfiguredAirbyteCatalog): The catalog defining the streams and their configurations. + state (List[AirbyteStateMessage]): A list representing the current state for incremental sync. + + Yields: + AirbyteMessage: Messages yielded from the source's generator. In case of exceptions, + an AirbyteMessage encapsulating the error is yielded instead. + """ + # the generator can raise an exception + # iterate over the generated messages. if next raise an exception, catch it and yield it as an AirbyteLogMessage + try: + yield from AirbyteEntrypoint(source).read( + source.spec(self.logger), config, configured_catalog, state + ) + except AirbyteTracedException as traced_exception: + # Look for this message which indicates that it is the "final exception" raised by AbstractSource. + # If it matches, don't yield this as we don't need to show this in the Builder. + # This is somewhat brittle as it relies on the message string, but if they drift then the worst case + # is that this message will be shown in the Builder. + if ( + traced_exception.message is not None + and "During the sync, the following streams did not sync successfully" + in traced_exception.message + ): + return + yield traced_exception.as_airbyte_message() + except Exception as e: + error_message = f"{e.args[0] if len(e.args) > 0 else str(e)}" + yield AirbyteTracedException.from_exception( + e, message=error_message + ).as_airbyte_message() + + def _has_reached_limit(self, slices: List[StreamReadSlices]) -> bool: + """ + Determines whether the provided collection of slices has reached any defined limits. + + This function checks for three types of limits: + 1. If the number of slices is greater than or equal to a maximum slice limit. + 2. If any individual slice has a number of pages that meets or exceeds a maximum number of pages per slice. + 3. If the cumulative number of records across all pages in all slices reaches or exceeds a maximum record limit. + + Parameters: + slices (List[StreamReadSlices]): A list where each element represents a slice containing one or more pages, and each page has a collection of records. + + Returns: + bool: True if any of the following conditions is met: + - The number of slices is at or above the maximum allowed slices. + - Any slice contains pages at or above the maximum allowed per slice. + - The total count of records reaches or exceeds the maximum record limit. + False otherwise. + """ + if len(slices) >= self._max_slices: + return True + + record_count = 0 + + for _slice in slices: + if len(_slice.pages) >= self._max_pages_per_slice: + return True + for page in _slice.pages: + record_count += len(page.records) + if record_count >= self._max_record_limit: + return True + return False diff --git a/airbyte_cdk/connector_builder/test_reader/types.py b/airbyte_cdk/connector_builder/test_reader/types.py new file mode 100644 index 000000000..b20a009af --- /dev/null +++ b/airbyte_cdk/connector_builder/test_reader/types.py @@ -0,0 +1,75 @@ +# +# Copyright (c) 2023 Airbyte, Inc., all rights reserved. +# + +""" +This module defines type aliases utilized in the Airbyte Connector Builder's test reader. +These aliases streamline type-checking for heterogeneous message groups and schema outputs, +ensuring consistency throughout the processing of stream data and associated messages. + +Type Aliases: + MESSAGE_GROUPS: + An iterable union of message-like objects which may include: + - StreamReadSlices: Represents slices used to read data from a stream. + - AirbyteControlMessage: Represents control commands used in the Airbyte protocol. + - AirbyteLogMessage: Represents log messages generated by the system. + - AirbyteTraceMessage: Represents trace messages typically used for debugging. + - AuxiliaryRequest: Represents any supplementary request issued during processing. + + INFERRED_SCHEMA_OUTPUT_TYPE: + A tuple where: + - The first element is either an InferredSchema instance or None, denoting the inferred JSON schema. + - The second element is a list of LogMessage instances capturing logs produced during inference. + + GROUPED_MESSAGES: + A tuple representing grouped messages divided as follows: + - A list of StreamReadSlices. + - A list of LogMessage instances. + - A list of AuxiliaryRequest instances. + - An optional AirbyteControlMessage that, if present, governs control flow in message processing. +""" + +from typing import Any, Iterable, List + +from airbyte_cdk.connector_builder.models import ( + AuxiliaryRequest, + HttpRequest, + HttpResponse, + LogMessage, + StreamReadSlices, +) +from airbyte_cdk.models import ( + AirbyteControlMessage, + AirbyteLogMessage, + AirbyteTraceMessage, +) +from airbyte_cdk.utils.schema_inferrer import ( + InferredSchema, +) + +MESSAGE_GROUPS = Iterable[ + StreamReadSlices + | AirbyteControlMessage + | AirbyteLogMessage + | AirbyteTraceMessage + | AuxiliaryRequest, +] + +INFERRED_SCHEMA_OUTPUT_TYPE = tuple[ + InferredSchema | None, + List[LogMessage], +] + +GROUPED_MESSAGES = tuple[ + List[StreamReadSlices], + List[LogMessage], + List[AuxiliaryRequest], + AirbyteControlMessage | None, +] + +LOG_MESSAGES_OUTPUT_TYPE = tuple[ + bool, + HttpRequest | None, + HttpResponse | None, + AuxiliaryRequest | AirbyteLogMessage | None, +] diff --git a/unit_tests/connector_builder/test_connector_builder_handler.py b/unit_tests/connector_builder/test_connector_builder_handler.py index b0c91ce30..e6e69bd1d 100644 --- a/unit_tests/connector_builder/test_connector_builder_handler.py +++ b/unit_tests/connector_builder/test_connector_builder_handler.py @@ -550,7 +550,7 @@ def test_read(): ) limits = TestReadLimits() with patch( - "airbyte_cdk.connector_builder.message_grouper.MessageGrouper.get_message_groups", + "airbyte_cdk.connector_builder.test_reader.TestReader.run_test_read", return_value=stream_read, ) as mock: output_record = handle_connector_builder_request( @@ -1169,7 +1169,7 @@ def test_read_stream_exception_with_secrets(): # Patch the handler to raise an exception with patch( - "airbyte_cdk.connector_builder.message_grouper.MessageGrouper.get_message_groups" + "airbyte_cdk.connector_builder.test_reader.TestReader.run_test_read" ) as mock_handler: mock_handler.side_effect = Exception("Test exception with secret key: super_secret_key") diff --git a/unit_tests/connector_builder/test_message_grouper.py b/unit_tests/connector_builder/test_message_grouper.py index c3fc73308..c40514a27 100644 --- a/unit_tests/connector_builder/test_message_grouper.py +++ b/unit_tests/connector_builder/test_message_grouper.py @@ -9,7 +9,6 @@ import orjson import pytest -from airbyte_cdk.connector_builder.message_grouper import MessageGrouper from airbyte_cdk.connector_builder.models import ( HttpRequest, HttpResponse, @@ -17,6 +16,8 @@ StreamRead, StreamReadPages, ) +from airbyte_cdk.connector_builder.test_reader import TestReader +from airbyte_cdk.connector_builder.test_reader.helpers import create_response_from_log_message from airbyte_cdk.models import ( AirbyteControlConnectorConfigMessage, AirbyteControlMessage, @@ -144,7 +145,7 @@ A_SOURCE = MagicMock() -@patch("airbyte_cdk.connector_builder.message_grouper.AirbyteEntrypoint.read") +@patch("airbyte_cdk.connector_builder.test_reader.reader.AirbyteEntrypoint.read") def test_get_grouped_messages(mock_entrypoint_read: Mock) -> None: url = "https://demonslayers.com/api/v1/hashiras?era=taisho" request = { @@ -202,8 +203,8 @@ def test_get_grouped_messages(mock_entrypoint_read: Mock) -> None: ), ) - connector_builder_handler = MessageGrouper(MAX_PAGES_PER_SLICE, MAX_SLICES) - actual_response: StreamRead = connector_builder_handler.get_message_groups( + connector_builder_handler = TestReader(MAX_PAGES_PER_SLICE, MAX_SLICES) + actual_response: StreamRead = connector_builder_handler.run_test_read( source=mock_source, config=CONFIG, configured_catalog=create_configured_catalog("hashiras"), @@ -218,7 +219,7 @@ def test_get_grouped_messages(mock_entrypoint_read: Mock) -> None: assert actual_page == expected_pages[i] -@patch("airbyte_cdk.connector_builder.message_grouper.AirbyteEntrypoint.read") +@patch("airbyte_cdk.connector_builder.test_reader.reader.AirbyteEntrypoint.read") def test_get_grouped_messages_with_logs(mock_entrypoint_read: Mock) -> None: url = "https://demonslayers.com/api/v1/hashiras?era=taisho" request = { @@ -286,9 +287,9 @@ def test_get_grouped_messages_with_logs(mock_entrypoint_read: Mock) -> None: ), ) - connector_builder_handler = MessageGrouper(MAX_PAGES_PER_SLICE, MAX_SLICES) + connector_builder_handler = TestReader(MAX_PAGES_PER_SLICE, MAX_SLICES) - actual_response: StreamRead = connector_builder_handler.get_message_groups( + actual_response: StreamRead = connector_builder_handler.run_test_read( source=mock_source, config=CONFIG, configured_catalog=create_configured_catalog("hashiras"), @@ -309,7 +310,7 @@ def test_get_grouped_messages_with_logs(mock_entrypoint_read: Mock) -> None: pytest.param(3, 1, True, id="test_create_request_record_limit_exceeds_max"), ], ) -@patch("airbyte_cdk.connector_builder.message_grouper.AirbyteEntrypoint.read") +@patch("airbyte_cdk.connector_builder.test_reader.reader.AirbyteEntrypoint.read") def test_get_grouped_messages_record_limit( mock_entrypoint_read: Mock, request_record_limit: int, max_record_limit: int, should_fail: bool ) -> None: @@ -339,11 +340,11 @@ def test_get_grouped_messages_record_limit( n_records = 2 record_limit = min(request_record_limit, max_record_limit) - api = MessageGrouper(MAX_PAGES_PER_SLICE, MAX_SLICES, max_record_limit=max_record_limit) + api = TestReader(MAX_PAGES_PER_SLICE, MAX_SLICES, max_record_limit=max_record_limit) # this is the call we expect to raise an exception if should_fail: with pytest.raises(ValueError): - api.get_message_groups( + api.run_test_read( mock_source, config=CONFIG, configured_catalog=create_configured_catalog("hashiras"), @@ -351,7 +352,7 @@ def test_get_grouped_messages_record_limit( record_limit=request_record_limit, ) else: - actual_response: StreamRead = api.get_message_groups( + actual_response: StreamRead = api.run_test_read( mock_source, config=CONFIG, configured_catalog=create_configured_catalog("hashiras"), @@ -374,7 +375,7 @@ def test_get_grouped_messages_record_limit( pytest.param(1, id="test_create_request_no_record_limit_n_records_exceed_max"), ], ) -@patch("airbyte_cdk.connector_builder.message_grouper.AirbyteEntrypoint.read") +@patch("airbyte_cdk.connector_builder.test_reader.reader.AirbyteEntrypoint.read") def test_get_grouped_messages_default_record_limit( mock_entrypoint_read: Mock, max_record_limit: int ) -> None: @@ -403,8 +404,8 @@ def test_get_grouped_messages_default_record_limit( ) n_records = 2 - api = MessageGrouper(MAX_PAGES_PER_SLICE, MAX_SLICES, max_record_limit=max_record_limit) - actual_response: StreamRead = api.get_message_groups( + api = TestReader(MAX_PAGES_PER_SLICE, MAX_SLICES, max_record_limit=max_record_limit) + actual_response: StreamRead = api.run_test_read( source=mock_source, config=CONFIG, configured_catalog=create_configured_catalog("hashiras"), @@ -417,7 +418,7 @@ def test_get_grouped_messages_default_record_limit( assert total_records == min([max_record_limit, n_records]) -@patch("airbyte_cdk.connector_builder.message_grouper.AirbyteEntrypoint.read") +@patch("airbyte_cdk.connector_builder.test_reader.reader.AirbyteEntrypoint.read") def test_get_grouped_messages_limit_0(mock_entrypoint_read: Mock) -> None: url = "https://demonslayers.com/api/v1/hashiras?era=taisho" request = { @@ -442,10 +443,10 @@ def test_get_grouped_messages_limit_0(mock_entrypoint_read: Mock) -> None: ] ), ) - api = MessageGrouper(MAX_PAGES_PER_SLICE, MAX_SLICES) + api = TestReader(MAX_PAGES_PER_SLICE, MAX_SLICES) with pytest.raises(ValueError): - api.get_message_groups( + api.run_test_read( source=mock_source, config=CONFIG, configured_catalog=create_configured_catalog("hashiras"), @@ -454,7 +455,7 @@ def test_get_grouped_messages_limit_0(mock_entrypoint_read: Mock) -> None: ) -@patch("airbyte_cdk.connector_builder.message_grouper.AirbyteEntrypoint.read") +@patch("airbyte_cdk.connector_builder.test_reader.reader.AirbyteEntrypoint.read") def test_get_grouped_messages_no_records(mock_entrypoint_read: Mock) -> None: url = "https://demonslayers.com/api/v1/hashiras?era=taisho" request = { @@ -500,9 +501,9 @@ def test_get_grouped_messages_no_records(mock_entrypoint_read: Mock) -> None: ), ) - message_grouper = MessageGrouper(MAX_PAGES_PER_SLICE, MAX_SLICES) + message_grouper = TestReader(MAX_PAGES_PER_SLICE, MAX_SLICES) - actual_response: StreamRead = message_grouper.get_message_groups( + actual_response: StreamRead = message_grouper.run_test_read( source=mock_source, config=CONFIG, configured_catalog=create_configured_catalog("hashiras"), @@ -585,13 +586,10 @@ def test_create_response_from_log_message( else: response_message = log_message - connector_builder_handler = MessageGrouper(MAX_PAGES_PER_SLICE, MAX_SLICES) - actual_response = connector_builder_handler._create_response_from_log_message(response_message) + assert create_response_from_log_message(response_message) == expected_response - assert actual_response == expected_response - -@patch("airbyte_cdk.connector_builder.message_grouper.AirbyteEntrypoint.read") +@patch("airbyte_cdk.connector_builder.test_reader.reader.AirbyteEntrypoint.read") def test_get_grouped_messages_with_many_slices(mock_entrypoint_read: Mock) -> None: url = "http://a-url.com" request: Mapping[str, Any] = {} @@ -616,9 +614,9 @@ def test_get_grouped_messages_with_many_slices(mock_entrypoint_read: Mock) -> No ), ) - connector_builder_handler = MessageGrouper(MAX_PAGES_PER_SLICE, MAX_SLICES) + connector_builder_handler = TestReader(MAX_PAGES_PER_SLICE, MAX_SLICES) - stream_read: StreamRead = connector_builder_handler.get_message_groups( + stream_read: StreamRead = connector_builder_handler.run_test_read( source=mock_source, config=CONFIG, configured_catalog=create_configured_catalog("hashiras"), @@ -645,7 +643,7 @@ def test_get_grouped_messages_with_many_slices(mock_entrypoint_read: Mock) -> No ) -@patch("airbyte_cdk.connector_builder.message_grouper.AirbyteEntrypoint.read") +@patch("airbyte_cdk.connector_builder.test_reader.reader.AirbyteEntrypoint.read") def test_get_grouped_messages_given_maximum_number_of_slices_then_test_read_limit_reached( mock_entrypoint_read: Mock, ) -> None: @@ -660,9 +658,9 @@ def test_get_grouped_messages_given_maximum_number_of_slices_then_test_read_limi ), ) - api = MessageGrouper(MAX_PAGES_PER_SLICE, MAX_SLICES) + api = TestReader(MAX_PAGES_PER_SLICE, MAX_SLICES) - stream_read: StreamRead = api.get_message_groups( + stream_read: StreamRead = api.run_test_read( source=mock_source, config=CONFIG, configured_catalog=create_configured_catalog("hashiras"), @@ -672,7 +670,7 @@ def test_get_grouped_messages_given_maximum_number_of_slices_then_test_read_limi assert stream_read.test_read_limit_reached -@patch("airbyte_cdk.connector_builder.message_grouper.AirbyteEntrypoint.read") +@patch("airbyte_cdk.connector_builder.test_reader.reader.AirbyteEntrypoint.read") def test_get_grouped_messages_given_maximum_number_of_pages_then_test_read_limit_reached( mock_entrypoint_read: Mock, ) -> None: @@ -688,9 +686,9 @@ def test_get_grouped_messages_given_maximum_number_of_pages_then_test_read_limit ), ) - api = MessageGrouper(MAX_PAGES_PER_SLICE, MAX_SLICES) + api = TestReader(MAX_PAGES_PER_SLICE, MAX_SLICES) - stream_read: StreamRead = api.get_message_groups( + stream_read: StreamRead = api.run_test_read( source=mock_source, config=CONFIG, configured_catalog=create_configured_catalog("hashiras"), @@ -707,8 +705,8 @@ def test_read_stream_returns_error_if_stream_does_not_exist() -> None: full_config: Mapping[str, Any] = {**CONFIG, **{"__injected_declarative_manifest": MANIFEST}} - message_grouper = MessageGrouper(MAX_PAGES_PER_SLICE, MAX_SLICES) - actual_response = message_grouper.get_message_groups( + message_grouper = TestReader(MAX_PAGES_PER_SLICE, MAX_SLICES) + actual_response = message_grouper.run_test_read( source=mock_source, config=full_config, configured_catalog=create_configured_catalog("not_in_manifest"), @@ -720,7 +718,7 @@ def test_read_stream_returns_error_if_stream_does_not_exist() -> None: assert "ERROR" in actual_response.logs[0].level -@patch("airbyte_cdk.connector_builder.message_grouper.AirbyteEntrypoint.read") +@patch("airbyte_cdk.connector_builder.test_reader.reader.AirbyteEntrypoint.read") def test_given_control_message_then_stream_read_has_config_update( mock_entrypoint_read: Mock, ) -> None: @@ -732,8 +730,8 @@ def test_given_control_message_then_stream_read_has_config_update( + [connector_configuration_control_message(1, updated_config)] ), ) - connector_builder_handler = MessageGrouper(MAX_PAGES_PER_SLICE, MAX_SLICES) - stream_read: StreamRead = connector_builder_handler.get_message_groups( + connector_builder_handler = TestReader(MAX_PAGES_PER_SLICE, MAX_SLICES) + stream_read: StreamRead = connector_builder_handler.run_test_read( source=mock_source, config=CONFIG, configured_catalog=create_configured_catalog("hashiras"), @@ -743,7 +741,7 @@ def test_given_control_message_then_stream_read_has_config_update( assert stream_read.latest_config_update == updated_config -@patch("airbyte_cdk.connector_builder.message_grouper.AirbyteEntrypoint.read") +@patch("airbyte_cdk.connector_builder.test_reader.reader.AirbyteEntrypoint.read") def test_given_multiple_control_messages_then_stream_read_has_latest_based_on_emitted_at( mock_entrypoint_read: Mock, ) -> None: @@ -762,8 +760,8 @@ def test_given_multiple_control_messages_then_stream_read_has_latest_based_on_em ] ), ) - connector_builder_handler = MessageGrouper(MAX_PAGES_PER_SLICE, MAX_SLICES) - stream_read: StreamRead = connector_builder_handler.get_message_groups( + connector_builder_handler = TestReader(MAX_PAGES_PER_SLICE, MAX_SLICES) + stream_read: StreamRead = connector_builder_handler.run_test_read( source=mock_source, config=CONFIG, configured_catalog=create_configured_catalog("hashiras"), @@ -773,7 +771,7 @@ def test_given_multiple_control_messages_then_stream_read_has_latest_based_on_em assert stream_read.latest_config_update == latest_config -@patch("airbyte_cdk.connector_builder.message_grouper.AirbyteEntrypoint.read") +@patch("airbyte_cdk.connector_builder.test_reader.reader.AirbyteEntrypoint.read") def test_given_multiple_control_messages_with_same_timestamp_then_stream_read_has_latest_based_on_message_order( mock_entrypoint_read: Mock, ) -> None: @@ -790,8 +788,8 @@ def test_given_multiple_control_messages_with_same_timestamp_then_stream_read_ha ] ), ) - connector_builder_handler = MessageGrouper(MAX_PAGES_PER_SLICE, MAX_SLICES) - stream_read: StreamRead = connector_builder_handler.get_message_groups( + connector_builder_handler = TestReader(MAX_PAGES_PER_SLICE, MAX_SLICES) + stream_read: StreamRead = connector_builder_handler.run_test_read( source=mock_source, config=CONFIG, configured_catalog=create_configured_catalog("hashiras"), @@ -801,14 +799,14 @@ def test_given_multiple_control_messages_with_same_timestamp_then_stream_read_ha assert stream_read.latest_config_update == latest_config -@patch("airbyte_cdk.connector_builder.message_grouper.AirbyteEntrypoint.read") +@patch("airbyte_cdk.connector_builder.test_reader.reader.AirbyteEntrypoint.read") def test_given_auxiliary_requests_then_return_auxiliary_request(mock_entrypoint_read: Mock) -> None: mock_source = make_mock_source( mock_entrypoint_read, iter(any_request_and_response_with_a_record() + [auxiliary_request_log_message()]), ) - connector_builder_handler = MessageGrouper(MAX_PAGES_PER_SLICE, MAX_SLICES) - stream_read: StreamRead = connector_builder_handler.get_message_groups( + connector_builder_handler = TestReader(MAX_PAGES_PER_SLICE, MAX_SLICES) + stream_read: StreamRead = connector_builder_handler.run_test_read( source=mock_source, config=CONFIG, configured_catalog=create_configured_catalog("hashiras"), @@ -818,11 +816,11 @@ def test_given_auxiliary_requests_then_return_auxiliary_request(mock_entrypoint_ assert len(stream_read.auxiliary_requests) == 1 -@patch("airbyte_cdk.connector_builder.message_grouper.AirbyteEntrypoint.read") +@patch("airbyte_cdk.connector_builder.test_reader.reader.AirbyteEntrypoint.read") def test_given_no_slices_then_return_empty_slices(mock_entrypoint_read: Mock) -> None: mock_source = make_mock_source(mock_entrypoint_read, iter([auxiliary_request_log_message()])) - connector_builder_handler = MessageGrouper(MAX_PAGES_PER_SLICE, MAX_SLICES) - stream_read: StreamRead = connector_builder_handler.get_message_groups( + connector_builder_handler = TestReader(MAX_PAGES_PER_SLICE, MAX_SLICES) + stream_read: StreamRead = connector_builder_handler.run_test_read( source=mock_source, config=CONFIG, configured_catalog=create_configured_catalog("hashiras"), @@ -832,7 +830,7 @@ def test_given_no_slices_then_return_empty_slices(mock_entrypoint_read: Mock) -> assert len(stream_read.slices) == 0 -@patch("airbyte_cdk.connector_builder.message_grouper.AirbyteEntrypoint.read") +@patch("airbyte_cdk.connector_builder.test_reader.reader.AirbyteEntrypoint.read") def test_given_pk_then_ensure_pk_is_pass_to_schema_inferrence(mock_entrypoint_read: Mock) -> None: mock_source = make_mock_source( mock_entrypoint_read, @@ -847,9 +845,9 @@ def test_given_pk_then_ensure_pk_is_pass_to_schema_inferrence(mock_entrypoint_re mock_source.streams.return_value = [Mock()] mock_source.streams.return_value[0].primary_key = [["id"]] mock_source.streams.return_value[0].cursor_field = _NO_CURSOR_FIELD - connector_builder_handler = MessageGrouper(MAX_PAGES_PER_SLICE, MAX_SLICES) + connector_builder_handler = TestReader(MAX_PAGES_PER_SLICE, MAX_SLICES) - stream_read: StreamRead = connector_builder_handler.get_message_groups( + stream_read: StreamRead = connector_builder_handler.run_test_read( source=mock_source, config=CONFIG, configured_catalog=create_configured_catalog("hashiras"), @@ -859,7 +857,7 @@ def test_given_pk_then_ensure_pk_is_pass_to_schema_inferrence(mock_entrypoint_re assert stream_read.inferred_schema["required"] == ["id"] -@patch("airbyte_cdk.connector_builder.message_grouper.AirbyteEntrypoint.read") +@patch("airbyte_cdk.connector_builder.test_reader.reader.AirbyteEntrypoint.read") def test_given_cursor_field_then_ensure_cursor_field_is_pass_to_schema_inferrence( mock_entrypoint_read: Mock, ) -> None: @@ -876,9 +874,9 @@ def test_given_cursor_field_then_ensure_cursor_field_is_pass_to_schema_inferrenc mock_source.streams.return_value = [Mock()] mock_source.streams.return_value[0].primary_key = _NO_PK mock_source.streams.return_value[0].cursor_field = ["date"] - connector_builder_handler = MessageGrouper(MAX_PAGES_PER_SLICE, MAX_SLICES) + connector_builder_handler = TestReader(MAX_PAGES_PER_SLICE, MAX_SLICES) - stream_read: StreamRead = connector_builder_handler.get_message_groups( + stream_read: StreamRead = connector_builder_handler.run_test_read( source=mock_source, config=CONFIG, configured_catalog=create_configured_catalog("hashiras"), diff --git a/unit_tests/sources/declarative/decoders/test_decoders_memory_usage.py b/unit_tests/sources/declarative/decoders/test_decoders_memory_usage.py index 241b45822..6901c6382 100644 --- a/unit_tests/sources/declarative/decoders/test_decoders_memory_usage.py +++ b/unit_tests/sources/declarative/decoders/test_decoders_memory_usage.py @@ -16,6 +16,7 @@ ) +@pytest.mark.slow @pytest.fixture(name="large_events_response") def large_event_response_fixture(): data = {"email": "email1@example.com"} diff --git a/unit_tests/sources/declarative/decoders/test_json_decoder.py b/unit_tests/sources/declarative/decoders/test_json_decoder.py index c78d157ab..5992bf45a 100644 --- a/unit_tests/sources/declarative/decoders/test_json_decoder.py +++ b/unit_tests/sources/declarative/decoders/test_json_decoder.py @@ -48,6 +48,7 @@ def test_jsonl_decoder(requests_mock, response_body, expected_json): ) +@pytest.mark.slow @pytest.fixture(name="large_events_response") def large_event_response_fixture(): data = {"email": "email1@example.com"}