From 266ae494c13ae8c63df82d8746dcb47d7d96c650 Mon Sep 17 00:00:00 2001 From: ferruzzi Date: Mon, 1 Apr 2024 14:21:45 -0700 Subject: [PATCH 01/12] Amazon Bedrock - Customize Model Operator/Sensor/Waiter/Trigger --- airflow/providers/amazon/aws/hooks/bedrock.py | 45 +++++ .../providers/amazon/aws/operators/bedrock.py | 154 ++++++++++++++++- .../providers/amazon/aws/sensors/bedrock.py | 111 ++++++++++++ .../providers/amazon/aws/triggers/bedrock.py | 61 +++++++ .../providers/amazon/aws/waiters/bedrock.json | 42 +++++ airflow/providers/amazon/provider.yaml | 6 + .../operators/bedrock.rst | 38 +++++ .../amazon/aws/hooks/test_bedrock.py | 64 +++++++ .../amazon/aws/operators/test_bedrock.py | 158 +++++++++++++++--- .../amazon/aws/sensors/test_bedrock.py | 93 +++++++++++ .../amazon/aws/triggers/test_bedrock.py | 52 ++++++ .../providers/amazon/aws/example_bedrock.py | 106 +++++++++++- 12 files changed, 901 insertions(+), 29 deletions(-) create mode 100644 airflow/providers/amazon/aws/sensors/bedrock.py create mode 100644 airflow/providers/amazon/aws/triggers/bedrock.py create mode 100644 airflow/providers/amazon/aws/waiters/bedrock.json create mode 100644 tests/providers/amazon/aws/sensors/test_bedrock.py create mode 100644 tests/providers/amazon/aws/triggers/test_bedrock.py diff --git a/airflow/providers/amazon/aws/hooks/bedrock.py b/airflow/providers/amazon/aws/hooks/bedrock.py index 11bacd9414598..b4d3bd440c400 100644 --- a/airflow/providers/amazon/aws/hooks/bedrock.py +++ b/airflow/providers/amazon/aws/hooks/bedrock.py @@ -16,9 +16,54 @@ # under the License. from __future__ import annotations +from botocore.exceptions import ClientError + from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook +class BedrockHook(AwsBaseHook): + """ + Interact with Amazon Bedrock. + + Provide thin wrapper around :external+boto3:py:class:`boto3.client("bedrock") `. + + Additional arguments (such as ``aws_conn_id``) may be specified and + are passed down to the underlying AwsBaseHook. + + .. seealso:: + - :class:`airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook` + """ + + client_type = "bedrock" + + def __init__(self, *args, **kwargs) -> None: + kwargs["client_type"] = self.client_type + super().__init__(*args, **kwargs) + + def _get_job_by_name(self, job_name: str): + return self.conn.get_model_customization_job(jobIdentifier=job_name) + + def get_customize_model_job_state(self, job_name) -> str: + state = self._get_job_by_name(job_name)["status"] + self.log.info("Job '%s' state: %s", job_name, state) + return state + + def job_name_exists(self, job_name: str) -> bool: + try: + self._get_job_by_name(job_name) + self.log.info("Verified that job name '%s' does exist.", job_name) + return True + except ClientError as e: + if e.response["Error"]["Code"] == "ValidationException": + self.log.info("Job name '%s' does not exist.", job_name) + return False + else: + raise e + + def get_job_arn(self, job_name: str) -> str: + return self._get_job_by_name(job_name)["jobArn"] + + class BedrockRuntimeHook(AwsBaseHook): """ Interact with the Amazon Bedrock Runtime. diff --git a/airflow/providers/amazon/aws/operators/bedrock.py b/airflow/providers/amazon/aws/operators/bedrock.py index d8eaf9e5d3c23..35ed2c9dfd142 100644 --- a/airflow/providers/amazon/aws/operators/bedrock.py +++ b/airflow/providers/amazon/aws/operators/bedrock.py @@ -19,10 +19,15 @@ import json from typing import TYPE_CHECKING, Any, Sequence -from airflow.providers.amazon.aws.hooks.bedrock import BedrockRuntimeHook +from airflow.configuration import conf +from airflow.exceptions import AirflowException +from airflow.providers.amazon.aws.hooks.bedrock import BedrockHook, BedrockRuntimeHook from airflow.providers.amazon.aws.operators.base_aws import AwsBaseOperator +from airflow.providers.amazon.aws.triggers.bedrock import BedrockCustomizeModelCompletedTrigger +from airflow.providers.amazon.aws.utils import validate_execute_complete_event from airflow.providers.amazon.aws.utils.mixins import aws_template_fields from airflow.utils.helpers import prune_dict +from airflow.utils.timezone import utcnow if TYPE_CHECKING: from airflow.utils.context import Context @@ -91,3 +96,150 @@ def execute(self, context: Context) -> dict[str, str | int]: self.log.info("Bedrock %s prompt: %s", self.model_id, self.input_data) self.log.info("Bedrock model response: %s", response_body) return response_body + + +class BedrockCustomizeModelOperator(AwsBaseOperator[BedrockHook]): + """ + Create a fine-tuning job to customize a base model. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:BedrockCustomizeModelOperator` + + :param job_name: A unique name for the fine-tuning job. + :param custom_model_name: A name for the custom model being created. + :param role_arn: The Amazon Resource Name (ARN) of an IAM role that Amazon Bedrock can assume + to perform tasks on your behalf. + :param base_model_id: Name of the base model. + :param training_data_uri: The S3 URI where the training data is stored. + :param output_data_uri: The S3 URI where the output data is stored. + :param hyperparameters: Parameters related to tuning the model. + :param check_if_job_exists: If set to true, operator will check whether a model customization + job already exists for the name in the config. (Default: True) + :param action_if_job_exists: Behavior if the job name already exists. Options are "timestamp" (default), + and "fail" + :param customization_job_kwargs: Any optional parameters to pass to the API. + + :param wait_for_completion: Whether to wait for cluster to stop. (default: True) + :param waiter_delay: Time in seconds to wait between status checks. + :param waiter_max_attempts: Maximum number of attempts to check for job completion. + :param deferrable: If True, the operator will wait asynchronously for the cluster to stop. + This implies waiting for completion. This mode requires aiobotocore module to be installed. + (default: False) + :param aws_conn_id: The Airflow connection used for AWS credentials. + If this is ``None`` or empty then the default boto3 behaviour is used. If + running Airflow in a distributed manner and aws_conn_id is None or + empty, then default boto3 configuration would be used (and must be + maintained on each worker node). + :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used. + :param verify: Whether or not to verify SSL certificates. See: + https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html + :param botocore_config: Configuration dictionary (key-values) for botocore client. See: + https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html + """ + + aws_hook_class = BedrockHook + template_fields: Sequence[str] = aws_template_fields( + "job_name", + "custom_model_name", + "role_arn", + "base_model_id", + "hyperparameters", + "check_if_job_exists", + "action_if_job_exists", + "customization_job_kwargs", + ) + + def __init__( + self, + job_name: str, + custom_model_name: str, + role_arn: str, + base_model_id: str, + training_data_uri: str, + output_data_uri: str, + hyperparameters: dict[str, str], + check_if_job_exists: bool = True, + action_if_job_exists: str = "timestamp", + customization_job_kwargs: dict[str, Any] | None = None, + wait_for_completion: bool = True, + waiter_delay: int = 120, + waiter_max_attempts: int = 75, + deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), + **kwargs, + ): + super().__init__(**kwargs) + self.wait_for_completion = wait_for_completion + self.waiter_delay = waiter_delay + self.waiter_max_attempts = waiter_max_attempts + self.deferrable = deferrable + + self.job_name = job_name + self.custom_model_name = custom_model_name + self.role_arn = role_arn + self.base_model_id = base_model_id + self.training_data_config = {"s3Uri": training_data_uri} + self.output_data_config = {"s3Uri": output_data_uri} + self.hyperparameters = hyperparameters + self.check_if_job_exists = check_if_job_exists + self.customization_job_kwargs = customization_job_kwargs or {} + if action_if_job_exists in {"timestamp", "fail"}: + self.action_if_job_exists = action_if_job_exists + else: + raise AirflowException( + f"Argument action_if_job_exists accepts only 'timestamp', and 'fail'. \ + Provided value: '{action_if_job_exists}." + ) + + def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> str: + event = validate_execute_complete_event(event) + + if event["status"] != "success": + raise AirflowException(f"Error while running job: {event}") + + self.log.info("Bedrock model customization job `%s` complete.", self.job_name) + return self.hook.get_job_arn(event["job_name"]) + + def execute(self, context: Context) -> dict: + if self.check_if_job_exists and self.hook.job_name_exists(self.job_name): + if self.action_if_job_exists == "fail": + raise AirflowException(f"A Bedrock job with name {self.job_name} already exists.") + self.job_name = f"{self.job_name}-{int(utcnow().timestamp())}" + self.log.info("Changed job name to '%s' to avoid collision.", self.job_name) + + self.log.info("Creating Bedrock model customization job '%s'.", self.job_name) + + response = self.hook.conn.create_model_customization_job( + jobName=self.job_name, + customModelName=self.custom_model_name, + roleArn=self.role_arn, + baseModelIdentifier=self.base_model_id, + trainingDataConfig=self.training_data_config, + outputDataConfig=self.output_data_config, + hyperParameters=self.hyperparameters, + **self.customization_job_kwargs, + ) + + if response["ResponseMetadata"]["HTTPStatusCode"] != 201: + raise AirflowException(f"Bedrock model customization job creation failed: {response}") + + task_description = f"Bedrock model customization job {self.job_name} to complete." + if self.deferrable: + self.log.info("Deferring for %s", task_description) + self.defer( + trigger=BedrockCustomizeModelCompletedTrigger( + job_name=self.job_name, + waiter_delay=self.waiter_delay, + waiter_max_attempts=self.waiter_max_attempts, + aws_conn_id=self.aws_conn_id, + ), + method_name="execute_complete", + ) + elif self.wait_for_completion: + self.log.info("Waiting for %s", task_description) + self.hook.get_waiter("model_customization_job_complete").wait( + jobIdentifier=self.job_name, + WaiterConfig={"Delay": self.waiter_delay, "MaxAttempts": self.waiter_max_attempts}, + ) + + return response["jobArn"] diff --git a/airflow/providers/amazon/aws/sensors/bedrock.py b/airflow/providers/amazon/aws/sensors/bedrock.py new file mode 100644 index 0000000000000..25b2b00b0632e --- /dev/null +++ b/airflow/providers/amazon/aws/sensors/bedrock.py @@ -0,0 +1,111 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Sequence + +from airflow.configuration import conf +from airflow.providers.amazon.aws.sensors.base_aws import AwsBaseSensor +from airflow.providers.amazon.aws.triggers.bedrock import BedrockCustomizeModelCompletedTrigger +from airflow.providers.amazon.aws.utils.mixins import aws_template_fields + +if TYPE_CHECKING: + from airflow.utils.context import Context + +from airflow.exceptions import AirflowException, AirflowSkipException +from airflow.providers.amazon.aws.hooks.bedrock import BedrockHook + + +class BedrockCustomizeModelCompletedSensor(AwsBaseSensor[BedrockHook]): + """ + Poll the state of the model customization job until it reaches a terminal state; fails if the job fails. + + .. seealso:: + For more information on how to use this sensor, take a look at the guide: + :ref:`howto/sensor:BedrockCustomizeModelCompletedSensor` + + + :param job_name: The name of the Bedrock model customization job. + + :param deferrable: If True, the sensor will operate in deferrable more. This mode requires aiobotocore + module to be installed. + (default: False, but can be overridden in config file by setting default_deferrable to True) + :param max_retries: Number of times before returning the current state, defaults to None + :param poke_interval: Polling period in seconds to check for the status of the job. + :param aws_conn_id: The Airflow connection used for AWS credentials. + If this is ``None`` or empty then the default boto3 behaviour is used. If + running Airflow in a distributed manner and aws_conn_id is None or + empty, then default boto3 configuration would be used (and must be + maintained on each worker node). + :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used. + :param verify: Whether or not to verify SSL certificates. See: + https://boto3.amazonaws.com/v1/documentation/api/latest/reference/core/session.html + :param botocore_config: Configuration dictionary (key-values) for botocore client. See: + https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html + """ + + INTERMEDIATE_STATES = ("InProgress",) + FAILURE_STATES = ("Failed", "Stopping", "Stopped") + SUCCESS_STATES = ("Completed",) + + aws_hook_class = BedrockHook + template_fields: Sequence[str] = aws_template_fields("job_name") + ui_color = "#66c3ff" + + def __init__( + self, + *, + job_name: str, + max_retries: int = 75, + poke_interval: int = 120, + deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), + **kwargs: Any, + ) -> None: + super().__init__(**kwargs) + self.job_name = job_name + self.poke_interval = poke_interval + self.max_retries = max_retries + self.deferrable = deferrable + + def execute(self, context: Context) -> Any: + if self.deferrable: + self.defer( + trigger=BedrockCustomizeModelCompletedTrigger( + job_name=self.job_name, + waiter_delay=int(self.poke_interval), + waiter_max_attempts=self.max_retries, + aws_conn_id=self.aws_conn_id, + ), + method_name="execute_complete", + ) + else: + super().execute(context=context) + + def poke(self, context: Context) -> bool: + state = self.hook.get_customize_model_job_state(self.job_name) + + if state in self.FAILURE_STATES: + # TODO: remove this if block when min_airflow_version is set to higher than 2.7.1 + message = "Bedrock model customization job sensor failed" + if self.soft_fail: + raise AirflowSkipException(message) + raise AirflowException(message) + + if state in self.INTERMEDIATE_STATES: + return False + return True diff --git a/airflow/providers/amazon/aws/triggers/bedrock.py b/airflow/providers/amazon/aws/triggers/bedrock.py new file mode 100644 index 0000000000000..0af8eea1d52d4 --- /dev/null +++ b/airflow/providers/amazon/aws/triggers/bedrock.py @@ -0,0 +1,61 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from typing import TYPE_CHECKING + +from airflow.providers.amazon.aws.hooks.bedrock import BedrockHook +from airflow.providers.amazon.aws.triggers.base import AwsBaseWaiterTrigger + +if TYPE_CHECKING: + from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook + + +class BedrockCustomizeModelCompletedTrigger(AwsBaseWaiterTrigger): + """ + Trigger when a Bedrock model customization job is complete. + + :param job_name: The name of the Bedrock model customization job. + :param waiter_delay: The amount of time in seconds to wait between attempts. + :param waiter_max_attempts: The maximum number of attempts to be made. + :param aws_conn_id: The Airflow connection used for AWS credentials. + """ + + def __init__( + self, + *, + job_name: str, + waiter_delay: int = 120, + waiter_max_attempts: int = 75, + aws_conn_id: str | None = None, + ) -> None: + super().__init__( + serialized_fields={"job_name": job_name}, + waiter_name="model_customization_job_complete", + waiter_args={"jobIdentifier": job_name}, + failure_message="Bedrock model customization failed.", + status_message="Status of Bedrock model customization job is", + status_queries=["status"], + return_key="job_name", + return_value=job_name, + waiter_delay=waiter_delay, + waiter_max_attempts=waiter_max_attempts, + aws_conn_id=aws_conn_id, + ) + + def hook(self) -> AwsGenericHook: + return BedrockHook(aws_conn_id=self.aws_conn_id) diff --git a/airflow/providers/amazon/aws/waiters/bedrock.json b/airflow/providers/amazon/aws/waiters/bedrock.json new file mode 100644 index 0000000000000..c44b7c058917b --- /dev/null +++ b/airflow/providers/amazon/aws/waiters/bedrock.json @@ -0,0 +1,42 @@ +{ + "version": 2, + "waiters": { + "model_customization_job_complete": { + "delay": 120, + "maxAttempts": 75, + "operation": "GetModelCustomizationJob", + "acceptors": [ + { + "matcher": "path", + "argument": "status", + "expected": "InProgress", + "state": "retry" + }, + { + "matcher": "path", + "argument": "status", + "expected": "Completed", + "state": "success" + }, + { + "matcher": "path", + "argument": "status", + "expected": "Failed", + "state": "failure" + }, + { + "matcher": "path", + "argument": "status", + "expected": "Stopping", + "state": "failure" + }, + { + "matcher": "path", + "argument": "status", + "expected": "Stopped", + "state": "failure" + } + ] + } + } +} diff --git a/airflow/providers/amazon/provider.yaml b/airflow/providers/amazon/provider.yaml index 4c4f7cf5970ab..dc072b324e20e 100644 --- a/airflow/providers/amazon/provider.yaml +++ b/airflow/providers/amazon/provider.yaml @@ -454,6 +454,9 @@ sensors: - integration-name: AWS Batch python-modules: - airflow.providers.amazon.aws.sensors.batch + - integration-name: Amazon Bedrock + python-modules: + - airflow.providers.amazon.aws.sensors.bedrock - integration-name: Amazon CloudFormation python-modules: - airflow.providers.amazon.aws.sensors.cloud_formation @@ -650,6 +653,9 @@ triggers: - integration-name: AWS Batch python-modules: - airflow.providers.amazon.aws.triggers.batch + - integration-name: Amazon Bedrock + python-modules: + - airflow.providers.amazon.aws.triggers.bedrock - integration-name: Amazon EC2 python-modules: - airflow.providers.amazon.aws.triggers.ec2 diff --git a/docs/apache-airflow-providers-amazon/operators/bedrock.rst b/docs/apache-airflow-providers-amazon/operators/bedrock.rst index 3e84cbc445357..0c6278e744ee5 100644 --- a/docs/apache-airflow-providers-amazon/operators/bedrock.rst +++ b/docs/apache-airflow-providers-amazon/operators/bedrock.rst @@ -65,6 +65,44 @@ To invoke an Amazon Titan model you would use: For details on the different formats, see `Inference parameters for foundation models `__ +.. _howto/operator:BedrockCustomizeModelOperator: + +Customize an existing Amazon Bedrock Model +========================================== + +To create a fine-tuning job to customize a base model, you can use +:class:`~airflow.providers.amazon.aws.operators.bedrock.BedrockCustomizeModelOperator`. + +Model-customization jobs are asynchronous and the completion time depends on the base model +and the training/validation data size. To monitor the state of the job, you can use the +"model_customization_job_complete" Waiter, the +:class:`~airflow.providers.amazon.aws.sensors.bedrock.BedrockCustomizeModelCompletedSensor` Sensor, +or the :class:`~airflow.providers.amazon.aws.triggersBedrockCustomizeModelCompletedTrigger` Trigger. + + +.. exampleinclude:: /../../tests/system/providers/amazon/aws/example_bedrock.py + :language: python + :dedent: 4 + :start-after: [START howto_operator_customize_model] + :end-before: [END howto_operator_customize_model] + + +Sensors +------- + +.. _howto/sensor:BedrockCustomizeModelCompletedSensor: + +Wait for an Amazon Bedrock customize model job +============================================== + +To wait on the state of an AWS CloudFormation stack creation until it reaches a terminal state you can use +:class:`~airflow.providers.amazon.aws.sensors.bedrock.BedrockCustomizeModelCompletedSensor` + +.. exampleinclude:: /../../tests/system/providers/amazon/aws/example_bedrock.py + :language: python + :dedent: 4 + :start-after: [START howto_sensor_customize_model] + :end-before: [END howto_sensor_customize_model] Reference --------- diff --git a/tests/providers/amazon/aws/hooks/test_bedrock.py b/tests/providers/amazon/aws/hooks/test_bedrock.py index 73612aacbc86d..bd30de469c3e0 100644 --- a/tests/providers/amazon/aws/hooks/test_bedrock.py +++ b/tests/providers/amazon/aws/hooks/test_bedrock.py @@ -16,6 +16,70 @@ # under the License. from __future__ import annotations +from unittest import mock + +import pytest +from botocore.exceptions import ClientError + +from airflow.providers.amazon.aws.hooks.bedrock import BedrockHook, BedrockRuntimeHook + +JOB_NAME = "testJobName" +EXPECTED_STATUS = "InProgress" + + +@pytest.fixture +def mock_conn(): + with mock.patch.object(BedrockHook, "conn") as _conn: + _conn.get_model_customization_job.return_value = {"jobName": JOB_NAME, "status": EXPECTED_STATUS} + yield _conn + + +class TestBedrockHook: + def setup_method(self): + self.hook = BedrockHook() + + self.validation_exception_error = ClientError( + error_response={"Error": {"Code": "ValidationException", "Message": ""}}, + operation_name="GetModelCustomizationJob", + ) + + self.unexpected_exception = ClientError( + error_response={"Error": {"Code": "ExpiredTokenException", "Message": ""}}, + operation_name="GetModelCustomizationJob", + ) + + def test_conn_returns_a_boto3_connection(self): + assert self.hook.conn is not None + assert self.hook.conn.meta.service_model.service_name == "bedrock" + + def test_get_customize_model_job_state(self, mock_conn): + response = self.hook.get_customize_model_job_state(JOB_NAME) + + mock_conn.get_model_customization_job.assert_called_once_with(jobIdentifier=JOB_NAME) + assert response == EXPECTED_STATUS + + def test_job_name_exists_positive(self, mock_conn): + response = self.hook.job_name_exists(JOB_NAME) + + mock_conn.get_model_customization_job.assert_called_once_with(jobIdentifier=JOB_NAME) + assert response is True + + def test_job_name_exists_negative(self, mock_conn): + invalid_job_name = "invalid_job_name" + mock_conn.get_model_customization_job.side_effect = self.validation_exception_error + + response = self.hook.job_name_exists(invalid_job_name) + + mock_conn.get_model_customization_job.assert_called_once_with(jobIdentifier=invalid_job_name) + assert response is False + + def test_job_name_exists_unexpected_exception(self, mock_conn): + mock_conn.get_model_customization_job.side_effect = self.unexpected_exception + + with pytest.raises(ClientError): + self.hook.job_name_exists(JOB_NAME) + + mock_conn.get_model_customization_job.assert_called_once_with(jobIdentifier=JOB_NAME) from airflow.providers.amazon.aws.hooks.bedrock import BedrockRuntimeHook diff --git a/tests/providers/amazon/aws/operators/test_bedrock.py b/tests/providers/amazon/aws/operators/test_bedrock.py index f6274de48f0b0..956490b54c302 100644 --- a/tests/providers/amazon/aws/operators/test_bedrock.py +++ b/tests/providers/amazon/aws/operators/test_bedrock.py @@ -18,42 +18,152 @@ from __future__ import annotations import json -from typing import Generator +from typing import TYPE_CHECKING, Generator from unittest import mock import pytest from moto import mock_aws -from airflow.providers.amazon.aws.hooks.bedrock import BedrockRuntimeHook -from airflow.providers.amazon.aws.operators.bedrock import BedrockInvokeModelOperator - -MODEL_ID = "meta.llama2-13b-chat-v1" -PROMPT = "A very important question." -GENERATED_RESPONSE = "An important answer." -MOCK_RESPONSE = json.dumps( - { - "generation": GENERATED_RESPONSE, - "prompt_token_count": len(PROMPT), - "generation_token_count": len(GENERATED_RESPONSE), - "stop_reason": "stop", - } +from airflow.exceptions import AirflowException +from airflow.providers.amazon.aws.hooks.bedrock import BedrockHook, BedrockRuntimeHook +from airflow.providers.amazon.aws.operators.bedrock import ( + BedrockCustomizeModelOperator, + BedrockInvokeModelOperator, ) - -@pytest.fixture -def runtime_hook() -> Generator[BedrockRuntimeHook, None, None]: - with mock_aws(): - yield BedrockRuntimeHook(aws_conn_id="aws_default") +if TYPE_CHECKING: + from airflow.providers.amazon.aws.hooks.base_aws import BaseAwsConnection class TestBedrockInvokeModelOperator: - @mock.patch.object(BedrockRuntimeHook, "conn") - def test_invoke_model_prompt_good_combinations(self, mock_conn): - mock_conn.invoke_model.return_value["body"].read.return_value = MOCK_RESPONSE + def setup_method(self): + self.model_id = "meta.llama2-13b-chat-v1" + self.prompt = "A very important question." + self.generated_response = "An important answer." + + @pytest.fixture + def mock_runtime_conn(self) -> Generator[BaseAwsConnection, None, None]: + with mock.patch.object(BedrockRuntimeHook, "conn") as _conn: + _conn.invoke_model.return_value["body"].read.return_value = json.dumps( + { + "generation": self.generated_response, + "prompt_token_count": len(self.prompt), + "generation_token_count": len(self.generated_response), + "stop_reason": "stop", + } + ) + yield _conn + + @pytest.fixture + def runtime_hook(self) -> Generator[BedrockRuntimeHook, None, None]: + with mock_aws(): + yield BedrockRuntimeHook(aws_conn_id="aws_default") + + def test_invoke_model_prompt_good_combinations(self, mock_runtime_conn): operator = BedrockInvokeModelOperator( - task_id="test_task", model_id=MODEL_ID, input_data={"input_data": {"prompt": PROMPT}} + task_id="test_task", model_id=self.model_id, input_data={"input_data": {"prompt": self.prompt}} ) response = operator.execute({}) - assert response["generation"] == GENERATED_RESPONSE + assert response["generation"] == self.generated_response + + +class TestBedrockCustomizeModelOperator: + @pytest.fixture + def mock_conn(self) -> Generator[BaseAwsConnection, None, None]: + with mock.patch.object(BedrockHook, "conn") as _conn: + _conn.create_model_customization_job.return_value = { + "ResponseMetadata": {"HTTPStatusCode": 201}, + "jobArn": self.custom_job_arn, + } + _conn.get_model_customization_job.return_value = { + "jobName": self.customize_model_job_name, + "status": "InProgress", + } + yield _conn + + @pytest.fixture + def bedrock_hook(self) -> Generator[BedrockHook, None, None]: + with mock_aws(): + hook = BedrockHook(aws_conn_id="aws_default") + yield hook + + def setup_method(self): + self.custom_job_arn = "valid_arn" + self.customize_model_job_name = "testModelJob" + + self.operator = BedrockCustomizeModelOperator( + task_id="test_task", + job_name=self.customize_model_job_name, + custom_model_name="testModelName", + role_arn="valid_arn", + base_model_id="base_model_id", + hyperparameters={ + "epochCount": "1", + "batchSize": "1", + "learningRate": ".0005", + "learningRateWarmupSteps": "0", + }, + training_data_uri="s3://uri", + output_data_uri="s3://uri/output", + ) + self.operator.defer = mock.MagicMock() + + @pytest.mark.parametrize( + "wait_for_completion, deferrable", + [ + pytest.param(False, False, id="no_wait"), + pytest.param(True, False, id="wait"), + pytest.param(False, True, id="defer"), + ], + ) + @mock.patch.object(BedrockHook, "get_waiter") + def test_customize_model_wait_combinations( + self, _, wait_for_completion, deferrable, mock_conn, bedrock_hook + ): + self.operator.wait_for_completion = wait_for_completion + self.operator.deferrable = deferrable + + response = self.operator.execute({}) + + assert response == self.custom_job_arn + assert bedrock_hook.get_waiter.call_count == wait_for_completion + assert self.operator.defer.call_count == deferrable + + @pytest.mark.parametrize( + "action_if_job_exists, succeeds", + [ + pytest.param("timestamp", True, id="timestamp"), + pytest.param("fail", True, id="fail"), + pytest.param("call me maybe", False, id="invalid"), + ], + ) + def test_customize_model_action_if_job_exists_happy_combinations(self, action_if_job_exists, succeeds): + exception = None + + try: + operator = BedrockCustomizeModelOperator( + task_id="test_task", + job_name=self.customize_model_job_name, + custom_model_name="testModelName", + role_arn="valid_arn", + base_model_id="base_model_id", + hyperparameters={ + "epochCount": "1", + "batchSize": "1", + "learningRate": ".0005", + "learningRateWarmupSteps": "0", + }, + training_data_uri="s3://uri", + output_data_uri="s3://uri/output", + action_if_job_exists=action_if_job_exists, + ) + except Exception as e: + exception = e + + if succeeds: + assert operator.action_if_job_exists == action_if_job_exists + assert exception is None + else: + assert isinstance(exception, AirflowException) diff --git a/tests/providers/amazon/aws/sensors/test_bedrock.py b/tests/providers/amazon/aws/sensors/test_bedrock.py new file mode 100644 index 0000000000000..8d56177ca16f2 --- /dev/null +++ b/tests/providers/amazon/aws/sensors/test_bedrock.py @@ -0,0 +1,93 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from unittest import mock + +import pytest + +from airflow.exceptions import AirflowException, AirflowSkipException +from airflow.providers.amazon.aws.hooks.bedrock import BedrockHook +from airflow.providers.amazon.aws.sensors.bedrock import BedrockCustomizeModelCompletedSensor + +JOB_NAME = "test_job_name" + + +@pytest.fixture +def mock_get_job_state(): + with mock.patch.object(BedrockHook, "get_customize_model_job_state") as mock_state: + yield mock_state + + +class TestBedrockCustomizeModelCompletedSensor: + def setup_method(self): + self.default_op_kwargs = dict( + task_id="test_bedrock_customize_model_sensor", + job_name=JOB_NAME, + poke_interval=5, + max_retries=1, + ) + self.sensor = BedrockCustomizeModelCompletedSensor(**self.default_op_kwargs, aws_conn_id=None) + + def test_base_aws_op_attributes(self): + op = BedrockCustomizeModelCompletedSensor(**self.default_op_kwargs) + assert op.hook.aws_conn_id == "aws_default" + assert op.hook._region_name is None + assert op.hook._verify is None + assert op.hook._config is None + + op = BedrockCustomizeModelCompletedSensor( + **self.default_op_kwargs, + aws_conn_id="aws-test-custom-conn", + region_name="eu-west-1", + verify=False, + botocore_config={"read_timeout": 42}, + ) + assert op.hook.aws_conn_id == "aws-test-custom-conn" + assert op.hook._region_name == "eu-west-1" + assert op.hook._verify is False + assert op.hook._config is not None + assert op.hook._config.read_timeout == 42 + + @pytest.mark.parametrize("state", ["Completed"]) + def test_poke_success_states(self, state, mock_get_job_state): + mock_get_job_state.side_effect = [state] + assert self.sensor.poke({}) is True + + @pytest.mark.parametrize("state", ["InProgress"]) + def test_poke_intermediate_states(self, state, mock_get_job_state): + mock_get_job_state.side_effect = [state] + assert self.sensor.poke({}) is False + + @pytest.mark.parametrize( + "soft_fail, expected_exception", + [ + pytest.param(False, AirflowException, id="not-soft-fail"), + pytest.param(True, AirflowSkipException, id="soft-fail"), + ], + ) + @pytest.mark.parametrize("state", ["Failed", "Stopping", "Stopped"]) + def test_poke_failure_states(self, state, soft_fail, expected_exception, mock_get_job_state): + mock_get_job_state.side_effect = [state] + sensor = BedrockCustomizeModelCompletedSensor( + **self.default_op_kwargs, aws_conn_id=None, soft_fail=soft_fail + ) + message = "Bedrock model customization job sensor failed" + + with pytest.raises(expected_exception, match=message): + sensor.poke({}) diff --git a/tests/providers/amazon/aws/triggers/test_bedrock.py b/tests/providers/amazon/aws/triggers/test_bedrock.py new file mode 100644 index 0000000000000..f51d32ffceec2 --- /dev/null +++ b/tests/providers/amazon/aws/triggers/test_bedrock.py @@ -0,0 +1,52 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from unittest import mock +from unittest.mock import AsyncMock + +import pytest + +from airflow.providers.amazon.aws.hooks.bedrock import BedrockHook +from airflow.providers.amazon.aws.triggers.bedrock import BedrockCustomizeModelCompletedTrigger +from airflow.triggers.base import TriggerEvent + +JOB_NAME = "test_job" +BASE_TRIGGER_CLASSPATH = "airflow.providers.amazon.aws.triggers.bedrock." + + +class TestBedrockCustomizeModelCompletedTrigger: + def test_serialization(self): + """Assert that arguments and classpath are correctly serialized.""" + trigger = BedrockCustomizeModelCompletedTrigger(job_name=JOB_NAME) + classpath, kwargs = trigger.serialize() + assert classpath == BASE_TRIGGER_CLASSPATH + "BedrockCustomizeModelCompletedTrigger" + assert kwargs.get("job_name") == JOB_NAME + + @pytest.mark.asyncio + @mock.patch.object(BedrockHook, "get_waiter") + @mock.patch.object(BedrockHook, "async_conn") + async def test_run_success(self, mock_async_conn, mock_get_waiter): + mock_async_conn.__aenter__.return_value = mock.MagicMock() + mock_get_waiter().wait = AsyncMock() + trigger = BedrockCustomizeModelCompletedTrigger(job_name=JOB_NAME) + + generator = trigger.run() + response = await generator.asend(None) + + assert response == TriggerEvent({"status": "success", "job_name": JOB_NAME}) + assert mock_get_waiter().wait.call_count == 1 diff --git a/tests/system/providers/amazon/aws/example_bedrock.py b/tests/system/providers/amazon/aws/example_bedrock.py index e86e5a2e92b9d..12e246154733d 100644 --- a/tests/system/providers/amazon/aws/example_bedrock.py +++ b/tests/system/providers/amazon/aws/example_bedrock.py @@ -16,17 +16,61 @@ # under the License. from __future__ import annotations +import json from datetime import datetime +from botocore.exceptions import ClientError + +from airflow.decorators import task from airflow.models.baseoperator import chain from airflow.models.dag import DAG -from airflow.providers.amazon.aws.operators.bedrock import BedrockInvokeModelOperator +from airflow.providers.amazon.aws.hooks.bedrock import BedrockHook +from airflow.providers.amazon.aws.operators.bedrock import ( + BedrockCustomizeModelOperator, + BedrockInvokeModelOperator, +) +from airflow.providers.amazon.aws.operators.s3 import ( + S3CreateBucketOperator, + S3CreateObjectOperator, + S3DeleteBucketOperator, +) +from airflow.providers.amazon.aws.sensors.bedrock import BedrockCustomizeModelCompletedSensor +from airflow.utils.trigger_rule import TriggerRule from tests.system.providers.amazon.aws.utils import SystemTestContextBuilder -sys_test_context_task = SystemTestContextBuilder().build() +# Externally fetched variables: +ROLE_ARN_KEY = "ROLE_ARN" +sys_test_context_task = SystemTestContextBuilder().add_variable(ROLE_ARN_KEY).build() DAG_ID = "example_bedrock" + +# Creating a custom model takes nearly two hours. If SKIP_LONG_TASKS is True then set +# the trigger rule to an improbable state. This way we can still have the code snippets +# for docs, and we can manually run the full tests occasionally. +SKIP_LONG_TASKS = True + +LLAMA_MODEL_ID = "meta.llama2-13b-chat-v1" PROMPT = "What color is an orange?" +TITAN_MODEL_ID = "amazon.titan-text-express-v1" +TRAIN_DATA = {"prompt": "what is AWS", "completion": "it's Amazon Web Services"} +HYPERPARAMETERS = { + "epochCount": "1", + "batchSize": "1", + "learningRate": ".0005", + "learningRateWarmupSteps": "0", +} + + +@task +def delete_custom_model(model_name: str): + try: + BedrockHook().conn.delete_custom_model(modelIdentifier=model_name) + except ClientError as e: + if SKIP_LONG_TASKS and (e.response["Error"]["Code"] == "ValidationException"): + # There is no model to delete. Since we skipped making one, that's fine. + return + raise e + with DAG( dag_id=DAG_ID, @@ -37,11 +81,28 @@ ) as dag: test_context = sys_test_context_task() env_id = test_context["ENV_ID"] + bucket_name = f"{env_id}-bedrock" + input_data_s3_key = f"{env_id}/train.jsonl" + training_data_uri = f"s3://{bucket_name}/{input_data_s3_key}" + custom_model_name = f"CustomModel{env_id}" + custom_model_job_name = f"CustomizeModelJob{env_id}" + + create_bucket = S3CreateBucketOperator( + task_id="create_bucket", + bucket_name=bucket_name, + ) + + upload_training_data = S3CreateObjectOperator( + task_id="upload_data", + s3_bucket=bucket_name, + s3_key=training_data_uri, + data=json.dumps(TRAIN_DATA), + ) # [START howto_operator_invoke_llama_model] invoke_llama_model = BedrockInvokeModelOperator( task_id="invoke_llama", - model_id="meta.llama2-13b-chat-v1", + model_id=LLAMA_MODEL_ID, input_data={"prompt": PROMPT}, ) # [END howto_operator_invoke_llama_model] @@ -49,18 +110,55 @@ # [START howto_operator_invoke_titan_model] invoke_titan_model = BedrockInvokeModelOperator( task_id="invoke_titan", - model_id="amazon.titan-text-express-v1", + model_id=TITAN_MODEL_ID, input_data={"inputText": PROMPT}, ) # [END howto_operator_invoke_titan_model] + # [START howto_operator_customize_model] + customize_model = BedrockCustomizeModelOperator( + task_id="customize_model", + job_name=custom_model_job_name, + custom_model_name=custom_model_name, + role_arn=test_context[ROLE_ARN_KEY], + base_model_id=f"arn:aws:bedrock:us-east-1::foundation-model/{TITAN_MODEL_ID}", + hyperparameters=HYPERPARAMETERS, + training_data_uri=training_data_uri, + output_data_uri=f"s3://{bucket_name}/myOutputData", + ) + # [END howto_operator_customize_model] + + # [START howto_sensor_customize_model] + await_custom_model_job = BedrockCustomizeModelCompletedSensor( + task_id="await_custom_model_job", + job_name=custom_model_job_name, + ) + # [END howto_sensor_customize_model] + + if SKIP_LONG_TASKS: + customize_model.trigger_rule = TriggerRule.ALL_SKIPPED + await_custom_model_job.trigger_rule = TriggerRule.ALL_SKIPPED + + delete_bucket = S3DeleteBucketOperator( + task_id="delete_bucket", + trigger_rule=TriggerRule.ALL_DONE, + bucket_name=bucket_name, + force_delete=True, + ) + chain( # TEST SETUP test_context, + create_bucket, + upload_training_data, # TEST BODY invoke_llama_model, invoke_titan_model, + customize_model, + await_custom_model_job, # TEST TEARDOWN + delete_custom_model(custom_model_name), + delete_bucket, ) from tests.system.utils.watcher import watcher From 906542a9be8eace0eb3e5e34edd54014a21b565e Mon Sep 17 00:00:00 2001 From: ferruzzi Date: Tue, 2 Apr 2024 20:21:27 -0700 Subject: [PATCH 02/12] fix git merge oops --- tests/providers/amazon/aws/hooks/test_bedrock.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/providers/amazon/aws/hooks/test_bedrock.py b/tests/providers/amazon/aws/hooks/test_bedrock.py index bd30de469c3e0..e7b496a3fece4 100644 --- a/tests/providers/amazon/aws/hooks/test_bedrock.py +++ b/tests/providers/amazon/aws/hooks/test_bedrock.py @@ -80,7 +80,6 @@ def test_job_name_exists_unexpected_exception(self, mock_conn): self.hook.job_name_exists(JOB_NAME) mock_conn.get_model_customization_job.assert_called_once_with(jobIdentifier=JOB_NAME) -from airflow.providers.amazon.aws.hooks.bedrock import BedrockRuntimeHook class TestBedrockRuntimeHook: From 3999f688dc2e531271e28cf602a3df49f2a18884 Mon Sep 17 00:00:00 2001 From: ferruzzi Date: Wed, 3 Apr 2024 11:59:53 -0700 Subject: [PATCH 03/12] pr fixes --- airflow/providers/amazon/aws/hooks/bedrock.py | 3 +-- .../providers/amazon/aws/operators/bedrock.py | 16 +++++++++------- airflow/providers/amazon/aws/sensors/bedrock.py | 2 +- 3 files changed, 11 insertions(+), 10 deletions(-) diff --git a/airflow/providers/amazon/aws/hooks/bedrock.py b/airflow/providers/amazon/aws/hooks/bedrock.py index b4d3bd440c400..081d6e7fe4f8e 100644 --- a/airflow/providers/amazon/aws/hooks/bedrock.py +++ b/airflow/providers/amazon/aws/hooks/bedrock.py @@ -57,8 +57,7 @@ def job_name_exists(self, job_name: str) -> bool: if e.response["Error"]["Code"] == "ValidationException": self.log.info("Job name '%s' does not exist.", job_name) return False - else: - raise e + raise e def get_job_arn(self, job_name: str) -> str: return self._get_job_by_name(job_name)["jobArn"] diff --git a/airflow/providers/amazon/aws/operators/bedrock.py b/airflow/providers/amazon/aws/operators/bedrock.py index 35ed2c9dfd142..9afe9efa56d60 100644 --- a/airflow/providers/amazon/aws/operators/bedrock.py +++ b/airflow/providers/amazon/aws/operators/bedrock.py @@ -183,13 +183,9 @@ def __init__( self.hyperparameters = hyperparameters self.check_if_job_exists = check_if_job_exists self.customization_job_kwargs = customization_job_kwargs or {} - if action_if_job_exists in {"timestamp", "fail"}: - self.action_if_job_exists = action_if_job_exists - else: - raise AirflowException( - f"Argument action_if_job_exists accepts only 'timestamp', and 'fail'. \ - Provided value: '{action_if_job_exists}." - ) + self.action_if_job_exists = action_if_job_exists.lower() + + self.valid_action_if_job_exists: set[str] = {"timestamp", "fail"} def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> str: event = validate_execute_complete_event(event) @@ -201,6 +197,12 @@ def execute_complete(self, context: Context, event: dict[str, Any] | None = None return self.hook.get_job_arn(event["job_name"]) def execute(self, context: Context) -> dict: + if self.action_if_job_exists not in self.valid_action_if_job_exists: + raise AirflowException( + f"Invalid value for argument action_if_job_exists {self.action_if_job_exists}; " + f"must be one of: {self.valid_action_if_job_exists}." + ) + if self.check_if_job_exists and self.hook.job_name_exists(self.job_name): if self.action_if_job_exists == "fail": raise AirflowException(f"A Bedrock job with name {self.job_name} already exists.") diff --git a/airflow/providers/amazon/aws/sensors/bedrock.py b/airflow/providers/amazon/aws/sensors/bedrock.py index 25b2b00b0632e..6244c79a0f8c9 100644 --- a/airflow/providers/amazon/aws/sensors/bedrock.py +++ b/airflow/providers/amazon/aws/sensors/bedrock.py @@ -91,7 +91,7 @@ def execute(self, context: Context) -> Any: waiter_max_attempts=self.max_retries, aws_conn_id=self.aws_conn_id, ), - method_name="execute_complete", + method_name="poke", ) else: super().execute(context=context) From 96803db51b4e84357bb65a2fa69c90b1fcbf972b Mon Sep 17 00:00:00 2001 From: ferruzzi Date: Wed, 3 Apr 2024 15:38:59 -0700 Subject: [PATCH 04/12] fix pr fixes which broke unit tests --- .../providers/amazon/aws/operators/bedrock.py | 5 ++- .../amazon/aws/operators/test_bedrock.py | 35 ++++++++++--------- 2 files changed, 22 insertions(+), 18 deletions(-) diff --git a/airflow/providers/amazon/aws/operators/bedrock.py b/airflow/providers/amazon/aws/operators/bedrock.py index 9afe9efa56d60..a3552ef4f19ec 100644 --- a/airflow/providers/amazon/aws/operators/bedrock.py +++ b/airflow/providers/amazon/aws/operators/bedrock.py @@ -196,13 +196,16 @@ def execute_complete(self, context: Context, event: dict[str, Any] | None = None self.log.info("Bedrock model customization job `%s` complete.", self.job_name) return self.hook.get_job_arn(event["job_name"]) - def execute(self, context: Context) -> dict: + def _validate_action_if_job_exists(self): if self.action_if_job_exists not in self.valid_action_if_job_exists: raise AirflowException( f"Invalid value for argument action_if_job_exists {self.action_if_job_exists}; " f"must be one of: {self.valid_action_if_job_exists}." ) + def execute(self, context: Context) -> dict: + self._validate_action_if_job_exists() + if self.check_if_job_exists and self.hook.job_name_exists(self.job_name): if self.action_if_job_exists == "fail": raise AirflowException(f"A Bedrock job with name {self.job_name} already exists.") diff --git a/tests/providers/amazon/aws/operators/test_bedrock.py b/tests/providers/amazon/aws/operators/test_bedrock.py index 956490b54c302..89dc434b54717 100644 --- a/tests/providers/amazon/aws/operators/test_bedrock.py +++ b/tests/providers/amazon/aws/operators/test_bedrock.py @@ -139,26 +139,27 @@ def test_customize_model_wait_combinations( pytest.param("call me maybe", False, id="invalid"), ], ) - def test_customize_model_action_if_job_exists_happy_combinations(self, action_if_job_exists, succeeds): + def test_customize_model_validate_action_if_job_exists(self, action_if_job_exists, succeeds): exception = None + operator = BedrockCustomizeModelOperator( + task_id="test_task", + job_name=self.customize_model_job_name, + custom_model_name="testModelName", + role_arn="valid_arn", + base_model_id="base_model_id", + hyperparameters={ + "epochCount": "1", + "batchSize": "1", + "learningRate": ".0005", + "learningRateWarmupSteps": "0", + }, + training_data_uri="s3://uri", + output_data_uri="s3://uri/output", + action_if_job_exists=action_if_job_exists, + ) try: - operator = BedrockCustomizeModelOperator( - task_id="test_task", - job_name=self.customize_model_job_name, - custom_model_name="testModelName", - role_arn="valid_arn", - base_model_id="base_model_id", - hyperparameters={ - "epochCount": "1", - "batchSize": "1", - "learningRate": ".0005", - "learningRateWarmupSteps": "0", - }, - training_data_uri="s3://uri", - output_data_uri="s3://uri/output", - action_if_job_exists=action_if_job_exists, - ) + operator._validate_action_if_job_exists() except Exception as e: exception = e From 2fdfdf27cbaf9997e5c2d070719eeaf22c7c42c8 Mon Sep 17 00:00:00 2001 From: ferruzzi Date: Wed, 3 Apr 2024 16:52:56 -0700 Subject: [PATCH 05/12] sensor tweaks --- airflow/providers/amazon/aws/sensors/bedrock.py | 6 +++--- .../providers/amazon/aws/sensors/test_bedrock.py | 15 +++++++-------- 2 files changed, 10 insertions(+), 11 deletions(-) diff --git a/airflow/providers/amazon/aws/sensors/bedrock.py b/airflow/providers/amazon/aws/sensors/bedrock.py index 6244c79a0f8c9..a6bce8ff93f9d 100644 --- a/airflow/providers/amazon/aws/sensors/bedrock.py +++ b/airflow/providers/amazon/aws/sensors/bedrock.py @@ -62,6 +62,7 @@ class BedrockCustomizeModelCompletedSensor(AwsBaseSensor[BedrockHook]): INTERMEDIATE_STATES = ("InProgress",) FAILURE_STATES = ("Failed", "Stopping", "Stopped") SUCCESS_STATES = ("Completed",) + FAILURE_MESSAGE = "Bedrock model customization job sensor failed." aws_hook_class = BedrockHook template_fields: Sequence[str] = aws_template_fields("job_name") @@ -101,10 +102,9 @@ def poke(self, context: Context) -> bool: if state in self.FAILURE_STATES: # TODO: remove this if block when min_airflow_version is set to higher than 2.7.1 - message = "Bedrock model customization job sensor failed" if self.soft_fail: - raise AirflowSkipException(message) - raise AirflowException(message) + raise AirflowSkipException(self.FAILURE_MESSAGE) + raise AirflowException(self.FAILURE_MESSAGE) if state in self.INTERMEDIATE_STATES: return False diff --git a/tests/providers/amazon/aws/sensors/test_bedrock.py b/tests/providers/amazon/aws/sensors/test_bedrock.py index 8d56177ca16f2..a67b89b6a8bf5 100644 --- a/tests/providers/amazon/aws/sensors/test_bedrock.py +++ b/tests/providers/amazon/aws/sensors/test_bedrock.py @@ -25,8 +25,6 @@ from airflow.providers.amazon.aws.hooks.bedrock import BedrockHook from airflow.providers.amazon.aws.sensors.bedrock import BedrockCustomizeModelCompletedSensor -JOB_NAME = "test_job_name" - @pytest.fixture def mock_get_job_state(): @@ -36,9 +34,11 @@ def mock_get_job_state(): class TestBedrockCustomizeModelCompletedSensor: def setup_method(self): + self.job_name = "test_job_name" + self.default_op_kwargs = dict( task_id="test_bedrock_customize_model_sensor", - job_name=JOB_NAME, + job_name=self.job_name, poke_interval=5, max_retries=1, ) @@ -64,12 +64,12 @@ def test_base_aws_op_attributes(self): assert op.hook._config is not None assert op.hook._config.read_timeout == 42 - @pytest.mark.parametrize("state", ["Completed"]) + @pytest.mark.parametrize("state", list(BedrockCustomizeModelCompletedSensor.SUCCESS_STATES)) def test_poke_success_states(self, state, mock_get_job_state): mock_get_job_state.side_effect = [state] assert self.sensor.poke({}) is True - @pytest.mark.parametrize("state", ["InProgress"]) + @pytest.mark.parametrize("state", list(BedrockCustomizeModelCompletedSensor.INTERMEDIATE_STATES)) def test_poke_intermediate_states(self, state, mock_get_job_state): mock_get_job_state.side_effect = [state] assert self.sensor.poke({}) is False @@ -81,13 +81,12 @@ def test_poke_intermediate_states(self, state, mock_get_job_state): pytest.param(True, AirflowSkipException, id="soft-fail"), ], ) - @pytest.mark.parametrize("state", ["Failed", "Stopping", "Stopped"]) + @pytest.mark.parametrize("state", list(BedrockCustomizeModelCompletedSensor.FAILURE_STATES)) def test_poke_failure_states(self, state, soft_fail, expected_exception, mock_get_job_state): mock_get_job_state.side_effect = [state] sensor = BedrockCustomizeModelCompletedSensor( **self.default_op_kwargs, aws_conn_id=None, soft_fail=soft_fail ) - message = "Bedrock model customization job sensor failed" - with pytest.raises(expected_exception, match=message): + with pytest.raises(expected_exception, match=sensor.FAILURE_MESSAGE): sensor.poke({}) From a5a3b8f1369e42127385c083672b3e99b46cd449 Mon Sep 17 00:00:00 2001 From: ferruzzi Date: Wed, 3 Apr 2024 18:00:06 -0700 Subject: [PATCH 06/12] waiter unit tests --- .../amazon/aws/waiters/test_bedrock.py | 71 +++++++++++++++++++ 1 file changed, 71 insertions(+) create mode 100644 tests/providers/amazon/aws/waiters/test_bedrock.py diff --git a/tests/providers/amazon/aws/waiters/test_bedrock.py b/tests/providers/amazon/aws/waiters/test_bedrock.py new file mode 100644 index 0000000000000..b52b5439feaeb --- /dev/null +++ b/tests/providers/amazon/aws/waiters/test_bedrock.py @@ -0,0 +1,71 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +from unittest import mock + +import boto3 +import botocore +import pytest + +from airflow.providers.amazon.aws.hooks.bedrock import BedrockHook +from airflow.providers.amazon.aws.sensors.bedrock import BedrockCustomizeModelCompletedSensor + + +class TestBedrockCustomWaiters: + def test_service_waiters(self): + assert "model_customization_job_complete" in BedrockHook().list_waiters() + + +class TestBedrockCustomWaitersBase: + @pytest.fixture(autouse=True) + def mock_conn(self, monkeypatch): + self.client = boto3.client("bedrock") + monkeypatch.setattr(BedrockHook, "conn", self.client) + + +class TestModelCustomizationJobCompleteWaiter(TestBedrockCustomWaitersBase): + @pytest.fixture + def mock_get_job(self): + with mock.patch.object(self.client, "get_model_customization_job") as m: + yield m + + def setup_method(self): + self.waiter_name = "model_customization_job_complete" + + @pytest.mark.parametrize("state", BedrockCustomizeModelCompletedSensor.SUCCESS_STATES) + def test_model_customization_job_complete(self, state, mock_get_job): + mock_get_job.return_value = {"status": state} + + BedrockHook().get_waiter(self.waiter_name).wait(jobIdentifier="job_id") + + @pytest.mark.parametrize("state", BedrockCustomizeModelCompletedSensor.FAILURE_STATES) + def test_model_customization_job_failed(self, state, mock_get_job): + mock_get_job.return_value = {"status": state} + + with pytest.raises(botocore.exceptions.WaiterError): + BedrockHook().get_waiter(self.waiter_name).wait(jobIdentifier="job_id") + + def test_model_customization_job_wait(self, mock_get_job): + wait = {"status": "InProgress"} + success = {"status": "Completed"} + mock_get_job.side_effect = [wait, wait, success] + + BedrockHook().get_waiter(self.waiter_name).wait( + jobIdentifier="job_id", WaiterConfig={"Delay": 0.01, "MaxAttempts": 3} + ) From 49a34b8daeebfcd79e89ffe6231e77e22e022930 Mon Sep 17 00:00:00 2001 From: ferruzzi Date: Thu, 4 Apr 2024 14:39:06 -0700 Subject: [PATCH 07/12] doc fixes --- docs/apache-airflow-providers-amazon/operators/bedrock.rst | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/apache-airflow-providers-amazon/operators/bedrock.rst b/docs/apache-airflow-providers-amazon/operators/bedrock.rst index 0c6278e744ee5..411deba79ffc7 100644 --- a/docs/apache-airflow-providers-amazon/operators/bedrock.rst +++ b/docs/apache-airflow-providers-amazon/operators/bedrock.rst @@ -77,7 +77,7 @@ Model-customization jobs are asynchronous and the completion time depends on the and the training/validation data size. To monitor the state of the job, you can use the "model_customization_job_complete" Waiter, the :class:`~airflow.providers.amazon.aws.sensors.bedrock.BedrockCustomizeModelCompletedSensor` Sensor, -or the :class:`~airflow.providers.amazon.aws.triggersBedrockCustomizeModelCompletedTrigger` Trigger. +or the :class:`~airflow.providers.amazon.aws.triggers.BedrockCustomizeModelCompletedTrigger` Trigger. .. exampleinclude:: /../../tests/system/providers/amazon/aws/example_bedrock.py @@ -95,7 +95,7 @@ Sensors Wait for an Amazon Bedrock customize model job ============================================== -To wait on the state of an AWS CloudFormation stack creation until it reaches a terminal state you can use +To wait on the state of an Amazon Bedrock customize model job until it reaches a terminal state you can use :class:`~airflow.providers.amazon.aws.sensors.bedrock.BedrockCustomizeModelCompletedSensor` .. exampleinclude:: /../../tests/system/providers/amazon/aws/example_bedrock.py From eee56064b55925d1b94132e2f39e37170386f12a Mon Sep 17 00:00:00 2001 From: ferruzzi Date: Thu, 4 Apr 2024 16:58:42 -0700 Subject: [PATCH 08/12] niko fixes --- airflow/providers/amazon/aws/hooks/bedrock.py | 2 +- .../providers/amazon/aws/operators/bedrock.py | 7 ++-- .../providers/amazon/aws/sensors/bedrock.py | 6 ++-- .../providers/amazon/aws/triggers/bedrock.py | 4 +-- .../amazon/aws/hooks/test_bedrock.py | 24 ++++++------- .../amazon/aws/operators/test_bedrock.py | 35 ++++++++++--------- .../amazon/aws/waiters/test_bedrock.py | 11 +++--- 7 files changed, 45 insertions(+), 44 deletions(-) diff --git a/airflow/providers/amazon/aws/hooks/bedrock.py b/airflow/providers/amazon/aws/hooks/bedrock.py index 081d6e7fe4f8e..e938d6f505453 100644 --- a/airflow/providers/amazon/aws/hooks/bedrock.py +++ b/airflow/providers/amazon/aws/hooks/bedrock.py @@ -43,7 +43,7 @@ def __init__(self, *args, **kwargs) -> None: def _get_job_by_name(self, job_name: str): return self.conn.get_model_customization_job(jobIdentifier=job_name) - def get_customize_model_job_state(self, job_name) -> str: + def get_customize_model_job_state(self, job_name: str) -> str: state = self._get_job_by_name(job_name)["status"] self.log.info("Job '%s' state: %s", job_name, state) return state diff --git a/airflow/providers/amazon/aws/operators/bedrock.py b/airflow/providers/amazon/aws/operators/bedrock.py index a3552ef4f19ec..99062fd3c0c11 100644 --- a/airflow/providers/amazon/aws/operators/bedrock.py +++ b/airflow/providers/amazon/aws/operators/bedrock.py @@ -117,12 +117,13 @@ class BedrockCustomizeModelOperator(AwsBaseOperator[BedrockHook]): :param check_if_job_exists: If set to true, operator will check whether a model customization job already exists for the name in the config. (Default: True) :param action_if_job_exists: Behavior if the job name already exists. Options are "timestamp" (default), - and "fail" + and "fail". If "timestamp" is used and the job name already exists, the current timestamp + will be appended to the name in order to make it unique. :param customization_job_kwargs: Any optional parameters to pass to the API. :param wait_for_completion: Whether to wait for cluster to stop. (default: True) - :param waiter_delay: Time in seconds to wait between status checks. - :param waiter_max_attempts: Maximum number of attempts to check for job completion. + :param waiter_delay: Time in seconds to wait between status checks. (default: 120) + :param waiter_max_attempts: Maximum number of attempts to check for job completion. (default: 75) :param deferrable: If True, the operator will wait asynchronously for the cluster to stop. This implies waiting for completion. This mode requires aiobotocore module to be installed. (default: False) diff --git a/airflow/providers/amazon/aws/sensors/bedrock.py b/airflow/providers/amazon/aws/sensors/bedrock.py index a6bce8ff93f9d..3d48fd49e136c 100644 --- a/airflow/providers/amazon/aws/sensors/bedrock.py +++ b/airflow/providers/amazon/aws/sensors/bedrock.py @@ -42,11 +42,11 @@ class BedrockCustomizeModelCompletedSensor(AwsBaseSensor[BedrockHook]): :param job_name: The name of the Bedrock model customization job. - :param deferrable: If True, the sensor will operate in deferrable more. This mode requires aiobotocore + :param deferrable: If True, the sensor will operate in deferrable mode. This mode requires aiobotocore module to be installed. (default: False, but can be overridden in config file by setting default_deferrable to True) - :param max_retries: Number of times before returning the current state, defaults to None - :param poke_interval: Polling period in seconds to check for the status of the job. + :param max_retries: Number of times before returning the current state. (default: 75) + :param poke_interval: Polling period in seconds to check for the status of the job. (default: 120) :param aws_conn_id: The Airflow connection used for AWS credentials. If this is ``None`` or empty then the default boto3 behaviour is used. If running Airflow in a distributed manner and aws_conn_id is None or diff --git a/airflow/providers/amazon/aws/triggers/bedrock.py b/airflow/providers/amazon/aws/triggers/bedrock.py index 0af8eea1d52d4..ae4805ed70631 100644 --- a/airflow/providers/amazon/aws/triggers/bedrock.py +++ b/airflow/providers/amazon/aws/triggers/bedrock.py @@ -30,8 +30,8 @@ class BedrockCustomizeModelCompletedTrigger(AwsBaseWaiterTrigger): Trigger when a Bedrock model customization job is complete. :param job_name: The name of the Bedrock model customization job. - :param waiter_delay: The amount of time in seconds to wait between attempts. - :param waiter_max_attempts: The maximum number of attempts to be made. + :param waiter_delay: The amount of time in seconds to wait between attempts. (default: 120) + :param waiter_max_attempts: The maximum number of attempts to be made. (default: 75) :param aws_conn_id: The Airflow connection used for AWS credentials. """ diff --git a/tests/providers/amazon/aws/hooks/test_bedrock.py b/tests/providers/amazon/aws/hooks/test_bedrock.py index e7b496a3fece4..1f32e0bdeb12e 100644 --- a/tests/providers/amazon/aws/hooks/test_bedrock.py +++ b/tests/providers/amazon/aws/hooks/test_bedrock.py @@ -35,18 +35,18 @@ def mock_conn(): class TestBedrockHook: - def setup_method(self): - self.hook = BedrockHook() + VALIDATION_EXCEPTION_ERROR = ClientError( + error_response={"Error": {"Code": "ValidationException", "Message": ""}}, + operation_name="GetModelCustomizationJob", + ) - self.validation_exception_error = ClientError( - error_response={"Error": {"Code": "ValidationException", "Message": ""}}, - operation_name="GetModelCustomizationJob", - ) + UNEXPECTED_EXCEPTION = ClientError( + error_response={"Error": {"Code": "ExpiredTokenException", "Message": ""}}, + operation_name="GetModelCustomizationJob", + ) - self.unexpected_exception = ClientError( - error_response={"Error": {"Code": "ExpiredTokenException", "Message": ""}}, - operation_name="GetModelCustomizationJob", - ) + def setup_method(self): + self.hook = BedrockHook() def test_conn_returns_a_boto3_connection(self): assert self.hook.conn is not None @@ -66,7 +66,7 @@ def test_job_name_exists_positive(self, mock_conn): def test_job_name_exists_negative(self, mock_conn): invalid_job_name = "invalid_job_name" - mock_conn.get_model_customization_job.side_effect = self.validation_exception_error + mock_conn.get_model_customization_job.side_effect = self.VALIDATION_EXCEPTION_ERROR response = self.hook.job_name_exists(invalid_job_name) @@ -74,7 +74,7 @@ def test_job_name_exists_negative(self, mock_conn): assert response is False def test_job_name_exists_unexpected_exception(self, mock_conn): - mock_conn.get_model_customization_job.side_effect = self.unexpected_exception + mock_conn.get_model_customization_job.side_effect = self.UNEXPECTED_EXCEPTION with pytest.raises(ClientError): self.hook.job_name_exists(JOB_NAME) diff --git a/tests/providers/amazon/aws/operators/test_bedrock.py b/tests/providers/amazon/aws/operators/test_bedrock.py index 89dc434b54717..d5d7f8a747c93 100644 --- a/tests/providers/amazon/aws/operators/test_bedrock.py +++ b/tests/providers/amazon/aws/operators/test_bedrock.py @@ -36,19 +36,18 @@ class TestBedrockInvokeModelOperator: - def setup_method(self): - self.model_id = "meta.llama2-13b-chat-v1" - self.prompt = "A very important question." - self.generated_response = "An important answer." + MODEL_ID = "meta.llama2-13b-chat-v1" + TEST_PROMPT = "A very important question." + GENERATED_RESPONSE = "An important answer." @pytest.fixture def mock_runtime_conn(self) -> Generator[BaseAwsConnection, None, None]: with mock.patch.object(BedrockRuntimeHook, "conn") as _conn: _conn.invoke_model.return_value["body"].read.return_value = json.dumps( { - "generation": self.generated_response, - "prompt_token_count": len(self.prompt), - "generation_token_count": len(self.generated_response), + "generation": self.GENERATED_RESPONSE, + "prompt_token_count": len(self.TEST_PROMPT), + "generation_token_count": len(self.GENERATED_RESPONSE), "stop_reason": "stop", } ) @@ -61,24 +60,29 @@ def runtime_hook(self) -> Generator[BedrockRuntimeHook, None, None]: def test_invoke_model_prompt_good_combinations(self, mock_runtime_conn): operator = BedrockInvokeModelOperator( - task_id="test_task", model_id=self.model_id, input_data={"input_data": {"prompt": self.prompt}} + task_id="test_task", + model_id=self.MODEL_ID, + input_data={"input_data": {"prompt": self.TEST_PROMPT}}, ) response = operator.execute({}) - assert response["generation"] == self.generated_response + assert response["generation"] == self.GENERATED_RESPONSE class TestBedrockCustomizeModelOperator: + CUSTOMIZE_JOB_ARN = "valid_arn" + CUSTOMIZE_JOB_NAME = "testModelJob" + @pytest.fixture def mock_conn(self) -> Generator[BaseAwsConnection, None, None]: with mock.patch.object(BedrockHook, "conn") as _conn: _conn.create_model_customization_job.return_value = { "ResponseMetadata": {"HTTPStatusCode": 201}, - "jobArn": self.custom_job_arn, + "jobArn": self.CUSTOMIZE_JOB_ARN, } _conn.get_model_customization_job.return_value = { - "jobName": self.customize_model_job_name, + "jobName": self.CUSTOMIZE_JOB_NAME, "status": "InProgress", } yield _conn @@ -90,12 +94,9 @@ def bedrock_hook(self) -> Generator[BedrockHook, None, None]: yield hook def setup_method(self): - self.custom_job_arn = "valid_arn" - self.customize_model_job_name = "testModelJob" - self.operator = BedrockCustomizeModelOperator( task_id="test_task", - job_name=self.customize_model_job_name, + job_name=self.CUSTOMIZE_JOB_NAME, custom_model_name="testModelName", role_arn="valid_arn", base_model_id="base_model_id", @@ -127,7 +128,7 @@ def test_customize_model_wait_combinations( response = self.operator.execute({}) - assert response == self.custom_job_arn + assert response == self.CUSTOMIZE_JOB_ARN assert bedrock_hook.get_waiter.call_count == wait_for_completion assert self.operator.defer.call_count == deferrable @@ -143,7 +144,7 @@ def test_customize_model_validate_action_if_job_exists(self, action_if_job_exist exception = None operator = BedrockCustomizeModelOperator( task_id="test_task", - job_name=self.customize_model_job_name, + job_name=self.CUSTOMIZE_JOB_NAME, custom_model_name="testModelName", role_arn="valid_arn", base_model_id="base_model_id", diff --git a/tests/providers/amazon/aws/waiters/test_bedrock.py b/tests/providers/amazon/aws/waiters/test_bedrock.py index b52b5439feaeb..00521ee013c47 100644 --- a/tests/providers/amazon/aws/waiters/test_bedrock.py +++ b/tests/providers/amazon/aws/waiters/test_bedrock.py @@ -40,32 +40,31 @@ def mock_conn(self, monkeypatch): class TestModelCustomizationJobCompleteWaiter(TestBedrockCustomWaitersBase): + WAITER_NAME = "model_customization_job_complete" + @pytest.fixture def mock_get_job(self): with mock.patch.object(self.client, "get_model_customization_job") as m: yield m - def setup_method(self): - self.waiter_name = "model_customization_job_complete" - @pytest.mark.parametrize("state", BedrockCustomizeModelCompletedSensor.SUCCESS_STATES) def test_model_customization_job_complete(self, state, mock_get_job): mock_get_job.return_value = {"status": state} - BedrockHook().get_waiter(self.waiter_name).wait(jobIdentifier="job_id") + BedrockHook().get_waiter(self.WAITER_NAME).wait(jobIdentifier="job_id") @pytest.mark.parametrize("state", BedrockCustomizeModelCompletedSensor.FAILURE_STATES) def test_model_customization_job_failed(self, state, mock_get_job): mock_get_job.return_value = {"status": state} with pytest.raises(botocore.exceptions.WaiterError): - BedrockHook().get_waiter(self.waiter_name).wait(jobIdentifier="job_id") + BedrockHook().get_waiter(self.WAITER_NAME).wait(jobIdentifier="job_id") def test_model_customization_job_wait(self, mock_get_job): wait = {"status": "InProgress"} success = {"status": "Completed"} mock_get_job.side_effect = [wait, wait, success] - BedrockHook().get_waiter(self.waiter_name).wait( + BedrockHook().get_waiter(self.WAITER_NAME).wait( jobIdentifier="job_id", WaiterConfig={"Delay": 0.01, "MaxAttempts": 3} ) From 44bc37c332adbf1ced6cb683b8725912a97d6ab9 Mon Sep 17 00:00:00 2001 From: ferruzzi Date: Fri, 5 Apr 2024 13:24:16 -0700 Subject: [PATCH 09/12] refactor how we handle job name conflicts --- airflow/providers/amazon/aws/hooks/bedrock.py | 13 ---- .../providers/amazon/aws/operators/bedrock.py | 73 +++++++++---------- .../amazon/aws/hooks/test_bedrock.py | 23 ------ .../amazon/aws/operators/test_bedrock.py | 65 +++++++++-------- 4 files changed, 68 insertions(+), 106 deletions(-) diff --git a/airflow/providers/amazon/aws/hooks/bedrock.py b/airflow/providers/amazon/aws/hooks/bedrock.py index e938d6f505453..2ecec8feb8126 100644 --- a/airflow/providers/amazon/aws/hooks/bedrock.py +++ b/airflow/providers/amazon/aws/hooks/bedrock.py @@ -16,8 +16,6 @@ # under the License. from __future__ import annotations -from botocore.exceptions import ClientError - from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook @@ -48,17 +46,6 @@ def get_customize_model_job_state(self, job_name: str) -> str: self.log.info("Job '%s' state: %s", job_name, state) return state - def job_name_exists(self, job_name: str) -> bool: - try: - self._get_job_by_name(job_name) - self.log.info("Verified that job name '%s' does exist.", job_name) - return True - except ClientError as e: - if e.response["Error"]["Code"] == "ValidationException": - self.log.info("Job name '%s' does not exist.", job_name) - return False - raise e - def get_job_arn(self, job_name: str) -> str: return self._get_job_by_name(job_name)["jobArn"] diff --git a/airflow/providers/amazon/aws/operators/bedrock.py b/airflow/providers/amazon/aws/operators/bedrock.py index 99062fd3c0c11..a55f574896dd7 100644 --- a/airflow/providers/amazon/aws/operators/bedrock.py +++ b/airflow/providers/amazon/aws/operators/bedrock.py @@ -19,6 +19,8 @@ import json from typing import TYPE_CHECKING, Any, Sequence +from botocore.exceptions import ClientError + from airflow.configuration import conf from airflow.exceptions import AirflowException from airflow.providers.amazon.aws.hooks.bedrock import BedrockHook, BedrockRuntimeHook @@ -114,11 +116,9 @@ class BedrockCustomizeModelOperator(AwsBaseOperator[BedrockHook]): :param training_data_uri: The S3 URI where the training data is stored. :param output_data_uri: The S3 URI where the output data is stored. :param hyperparameters: Parameters related to tuning the model. - :param check_if_job_exists: If set to true, operator will check whether a model customization - job already exists for the name in the config. (Default: True) - :param action_if_job_exists: Behavior if the job name already exists. Options are "timestamp" (default), - and "fail". If "timestamp" is used and the job name already exists, the current timestamp - will be appended to the name in order to make it unique. + :param ensure_unique_job_name: If set to true, operator will check whether a model customization + job already exists for the name in the config and append the current timestamp if there is a + name conflict. (Default: True) :param customization_job_kwargs: Any optional parameters to pass to the API. :param wait_for_completion: Whether to wait for cluster to stop. (default: True) @@ -146,8 +146,7 @@ class BedrockCustomizeModelOperator(AwsBaseOperator[BedrockHook]): "role_arn", "base_model_id", "hyperparameters", - "check_if_job_exists", - "action_if_job_exists", + "ensure_unique_job_name", "customization_job_kwargs", ) @@ -160,8 +159,7 @@ def __init__( training_data_uri: str, output_data_uri: str, hyperparameters: dict[str, str], - check_if_job_exists: bool = True, - action_if_job_exists: str = "timestamp", + ensure_unique_job_name: bool = True, customization_job_kwargs: dict[str, Any] | None = None, wait_for_completion: bool = True, waiter_delay: int = 120, @@ -182,9 +180,8 @@ def __init__( self.training_data_config = {"s3Uri": training_data_uri} self.output_data_config = {"s3Uri": output_data_uri} self.hyperparameters = hyperparameters - self.check_if_job_exists = check_if_job_exists + self.ensure_unique_job_name = ensure_unique_job_name self.customization_job_kwargs = customization_job_kwargs or {} - self.action_if_job_exists = action_if_job_exists.lower() self.valid_action_if_job_exists: set[str] = {"timestamp", "fail"} @@ -197,34 +194,34 @@ def execute_complete(self, context: Context, event: dict[str, Any] | None = None self.log.info("Bedrock model customization job `%s` complete.", self.job_name) return self.hook.get_job_arn(event["job_name"]) - def _validate_action_if_job_exists(self): - if self.action_if_job_exists not in self.valid_action_if_job_exists: - raise AirflowException( - f"Invalid value for argument action_if_job_exists {self.action_if_job_exists}; " - f"must be one of: {self.valid_action_if_job_exists}." - ) - def execute(self, context: Context) -> dict: - self._validate_action_if_job_exists() - - if self.check_if_job_exists and self.hook.job_name_exists(self.job_name): - if self.action_if_job_exists == "fail": - raise AirflowException(f"A Bedrock job with name {self.job_name} already exists.") - self.job_name = f"{self.job_name}-{int(utcnow().timestamp())}" - self.log.info("Changed job name to '%s' to avoid collision.", self.job_name) - - self.log.info("Creating Bedrock model customization job '%s'.", self.job_name) - - response = self.hook.conn.create_model_customization_job( - jobName=self.job_name, - customModelName=self.custom_model_name, - roleArn=self.role_arn, - baseModelIdentifier=self.base_model_id, - trainingDataConfig=self.training_data_config, - outputDataConfig=self.output_data_config, - hyperParameters=self.hyperparameters, - **self.customization_job_kwargs, - ) + while True: + # If there is a name conflict and ensure_unique_job_name is True, append the current timestamp + # to the name and retry until there is no name conflict. + # - Break the loop when the API call returns success. + # - If the API returns an exception other than a name conflict, raise that exception. + # - If the API returns a name conflict and ensure_unique_job_name is false, raise that exception. + try: + self.log.info("Creating Bedrock model customization job '%s'.", self.job_name) + + response = self.hook.conn.create_model_customization_job( + jobName=self.job_name, + customModelName=self.custom_model_name, + roleArn=self.role_arn, + baseModelIdentifier=self.base_model_id, + trainingDataConfig=self.training_data_config, + outputDataConfig=self.output_data_config, + hyperParameters=self.hyperparameters, + **self.customization_job_kwargs, + ) + break + except ClientError as error: + if error.response["Error"]["Message"] != "The provided job name is currently in use.": + raise error + if not self.ensure_unique_job_name: + raise error + self.job_name = f"{self.job_name}-{int(utcnow().timestamp())}" + self.log.info("Changed job name to '%s' to avoid collision.", self.job_name) if response["ResponseMetadata"]["HTTPStatusCode"] != 201: raise AirflowException(f"Bedrock model customization job creation failed: {response}") diff --git a/tests/providers/amazon/aws/hooks/test_bedrock.py b/tests/providers/amazon/aws/hooks/test_bedrock.py index 1f32e0bdeb12e..65dde985563ba 100644 --- a/tests/providers/amazon/aws/hooks/test_bedrock.py +++ b/tests/providers/amazon/aws/hooks/test_bedrock.py @@ -58,29 +58,6 @@ def test_get_customize_model_job_state(self, mock_conn): mock_conn.get_model_customization_job.assert_called_once_with(jobIdentifier=JOB_NAME) assert response == EXPECTED_STATUS - def test_job_name_exists_positive(self, mock_conn): - response = self.hook.job_name_exists(JOB_NAME) - - mock_conn.get_model_customization_job.assert_called_once_with(jobIdentifier=JOB_NAME) - assert response is True - - def test_job_name_exists_negative(self, mock_conn): - invalid_job_name = "invalid_job_name" - mock_conn.get_model_customization_job.side_effect = self.VALIDATION_EXCEPTION_ERROR - - response = self.hook.job_name_exists(invalid_job_name) - - mock_conn.get_model_customization_job.assert_called_once_with(jobIdentifier=invalid_job_name) - assert response is False - - def test_job_name_exists_unexpected_exception(self, mock_conn): - mock_conn.get_model_customization_job.side_effect = self.UNEXPECTED_EXCEPTION - - with pytest.raises(ClientError): - self.hook.job_name_exists(JOB_NAME) - - mock_conn.get_model_customization_job.assert_called_once_with(jobIdentifier=JOB_NAME) - class TestBedrockRuntimeHook: def test_conn_returns_a_boto3_connection(self): diff --git a/tests/providers/amazon/aws/operators/test_bedrock.py b/tests/providers/amazon/aws/operators/test_bedrock.py index d5d7f8a747c93..2371877b4de9c 100644 --- a/tests/providers/amazon/aws/operators/test_bedrock.py +++ b/tests/providers/amazon/aws/operators/test_bedrock.py @@ -22,9 +22,9 @@ from unittest import mock import pytest +from botocore.exceptions import ClientError from moto import mock_aws -from airflow.exceptions import AirflowException from airflow.providers.amazon.aws.hooks.bedrock import BedrockHook, BedrockRuntimeHook from airflow.providers.amazon.aws.operators.bedrock import ( BedrockCustomizeModelOperator, @@ -132,40 +132,41 @@ def test_customize_model_wait_combinations( assert bedrock_hook.get_waiter.call_count == wait_for_completion assert self.operator.defer.call_count == deferrable + conflict_msg = "The provided job name is currently in use." + conflict_exception = ClientError( + error_response={"Error": {"Message": conflict_msg, "Code": "ValidationException"}}, + operation_name="UnitTest", + ) + success = {"ResponseMetadata": {"HTTPStatusCode": 201}, "jobArn": CUSTOMIZE_JOB_ARN} + @pytest.mark.parametrize( - "action_if_job_exists, succeeds", + "side_effect, ensure_unique_name", [ - pytest.param("timestamp", True, id="timestamp"), - pytest.param("fail", True, id="fail"), - pytest.param("call me maybe", False, id="invalid"), + pytest.param([conflict_exception, success], True, id="conflict_and_ensure_unique"), + pytest.param([conflict_exception, success], False, id="conflict_and_not_ensure_unique"), + pytest.param( + [conflict_exception, conflict_exception, success], + True, + id="multiple_conflict_and_ensure_unique", + ), + pytest.param( + [conflict_exception, conflict_exception, success], + False, + id="multiple_conflict_and_not_ensure_unique", + ), + pytest.param([success], True, id="no_conflict_and_ensure_unique"), + pytest.param([success], False, id="no_conflict_and_not_ensure_unique"), ], ) - def test_customize_model_validate_action_if_job_exists(self, action_if_job_exists, succeeds): - exception = None - operator = BedrockCustomizeModelOperator( - task_id="test_task", - job_name=self.CUSTOMIZE_JOB_NAME, - custom_model_name="testModelName", - role_arn="valid_arn", - base_model_id="base_model_id", - hyperparameters={ - "epochCount": "1", - "batchSize": "1", - "learningRate": ".0005", - "learningRateWarmupSteps": "0", - }, - training_data_uri="s3://uri", - output_data_uri="s3://uri/output", - action_if_job_exists=action_if_job_exists, - ) + @mock.patch.object(BedrockHook, "get_waiter") + def test_ensure_unique_job_name(self, _, side_effect, ensure_unique_name, mock_conn, bedrock_hook): + mock_conn.create_model_customization_job.side_effect = side_effect + expected_call_count = len(side_effect) if ensure_unique_name else 1 + self.operator.wait_for_completion = False - try: - operator._validate_action_if_job_exists() - except Exception as e: - exception = e + response = self.operator.execute({}) - if succeeds: - assert operator.action_if_job_exists == action_if_job_exists - assert exception is None - else: - assert isinstance(exception, AirflowException) + assert response == self.CUSTOMIZE_JOB_ARN + mock_conn.create_model_customization_job.call_count == expected_call_count + bedrock_hook.get_waiter.assert_not_called() + self.operator.defer.assert_not_called() From 09c0e696d9efa146f25f06b5bf6894bf8a3ee97e Mon Sep 17 00:00:00 2001 From: ferruzzi Date: Fri, 5 Apr 2024 15:12:17 -0700 Subject: [PATCH 10/12] replace while True loop logic --- airflow/providers/amazon/aws/operators/bedrock.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/airflow/providers/amazon/aws/operators/bedrock.py b/airflow/providers/amazon/aws/operators/bedrock.py index a55f574896dd7..45567760f8959 100644 --- a/airflow/providers/amazon/aws/operators/bedrock.py +++ b/airflow/providers/amazon/aws/operators/bedrock.py @@ -195,13 +195,17 @@ def execute_complete(self, context: Context, event: dict[str, Any] | None = None return self.hook.get_job_arn(event["job_name"]) def execute(self, context: Context) -> dict: - while True: + response = {} + retry = True + while retry: # If there is a name conflict and ensure_unique_job_name is True, append the current timestamp # to the name and retry until there is no name conflict. # - Break the loop when the API call returns success. # - If the API returns an exception other than a name conflict, raise that exception. # - If the API returns a name conflict and ensure_unique_job_name is false, raise that exception. try: + # Ensure the loop is executed at least once, and not repeat unless explicitly set to do so. + retry = False self.log.info("Creating Bedrock model customization job '%s'.", self.job_name) response = self.hook.conn.create_model_customization_job( @@ -214,12 +218,12 @@ def execute(self, context: Context) -> dict: hyperParameters=self.hyperparameters, **self.customization_job_kwargs, ) - break except ClientError as error: if error.response["Error"]["Message"] != "The provided job name is currently in use.": raise error if not self.ensure_unique_job_name: raise error + retry = True self.job_name = f"{self.job_name}-{int(utcnow().timestamp())}" self.log.info("Changed job name to '%s' to avoid collision.", self.job_name) From fec3942a6e0a9dd80c08fc9cbcd1756cf6beb682 Mon Sep 17 00:00:00 2001 From: ferruzzi Date: Fri, 5 Apr 2024 16:47:50 -0700 Subject: [PATCH 11/12] unit test tweaks --- tests/providers/amazon/aws/sensors/test_bedrock.py | 6 +++--- tests/providers/amazon/aws/triggers/test_bedrock.py | 11 ++++++----- 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/tests/providers/amazon/aws/sensors/test_bedrock.py b/tests/providers/amazon/aws/sensors/test_bedrock.py index a67b89b6a8bf5..71df06d4af9d4 100644 --- a/tests/providers/amazon/aws/sensors/test_bedrock.py +++ b/tests/providers/amazon/aws/sensors/test_bedrock.py @@ -33,12 +33,12 @@ def mock_get_job_state(): class TestBedrockCustomizeModelCompletedSensor: - def setup_method(self): - self.job_name = "test_job_name" + JOB_NAME = "test_job_name" + def setup_method(self): self.default_op_kwargs = dict( task_id="test_bedrock_customize_model_sensor", - job_name=self.job_name, + job_name=self.JOB_NAME, poke_interval=5, max_retries=1, ) diff --git a/tests/providers/amazon/aws/triggers/test_bedrock.py b/tests/providers/amazon/aws/triggers/test_bedrock.py index f51d32ffceec2..0a54c56a77889 100644 --- a/tests/providers/amazon/aws/triggers/test_bedrock.py +++ b/tests/providers/amazon/aws/triggers/test_bedrock.py @@ -25,17 +25,18 @@ from airflow.providers.amazon.aws.triggers.bedrock import BedrockCustomizeModelCompletedTrigger from airflow.triggers.base import TriggerEvent -JOB_NAME = "test_job" BASE_TRIGGER_CLASSPATH = "airflow.providers.amazon.aws.triggers.bedrock." class TestBedrockCustomizeModelCompletedTrigger: + JOB_NAME = "test_job" + def test_serialization(self): """Assert that arguments and classpath are correctly serialized.""" - trigger = BedrockCustomizeModelCompletedTrigger(job_name=JOB_NAME) + trigger = BedrockCustomizeModelCompletedTrigger(job_name=self.JOB_NAME) classpath, kwargs = trigger.serialize() assert classpath == BASE_TRIGGER_CLASSPATH + "BedrockCustomizeModelCompletedTrigger" - assert kwargs.get("job_name") == JOB_NAME + assert kwargs.get("job_name") == self.JOB_NAME @pytest.mark.asyncio @mock.patch.object(BedrockHook, "get_waiter") @@ -43,10 +44,10 @@ def test_serialization(self): async def test_run_success(self, mock_async_conn, mock_get_waiter): mock_async_conn.__aenter__.return_value = mock.MagicMock() mock_get_waiter().wait = AsyncMock() - trigger = BedrockCustomizeModelCompletedTrigger(job_name=JOB_NAME) + trigger = BedrockCustomizeModelCompletedTrigger(job_name=self.JOB_NAME) generator = trigger.run() response = await generator.asend(None) - assert response == TriggerEvent({"status": "success", "job_name": JOB_NAME}) + assert response == TriggerEvent({"status": "success", "job_name": self.JOB_NAME}) assert mock_get_waiter().wait.call_count == 1 From da7579273fbf8f50a67580b6d116f155af127b8a Mon Sep 17 00:00:00 2001 From: ferruzzi Date: Mon, 8 Apr 2024 11:11:56 -0700 Subject: [PATCH 12/12] vincent fixes --- airflow/providers/amazon/aws/hooks/bedrock.py | 11 ----------- airflow/providers/amazon/aws/operators/bedrock.py | 2 +- airflow/providers/amazon/aws/sensors/bedrock.py | 7 +++---- tests/providers/amazon/aws/hooks/test_bedrock.py | 6 ------ .../providers/amazon/aws/sensors/test_bedrock.py | 15 +++++++++------ 5 files changed, 13 insertions(+), 28 deletions(-) diff --git a/airflow/providers/amazon/aws/hooks/bedrock.py b/airflow/providers/amazon/aws/hooks/bedrock.py index 2ecec8feb8126..96636eb952aa0 100644 --- a/airflow/providers/amazon/aws/hooks/bedrock.py +++ b/airflow/providers/amazon/aws/hooks/bedrock.py @@ -38,17 +38,6 @@ def __init__(self, *args, **kwargs) -> None: kwargs["client_type"] = self.client_type super().__init__(*args, **kwargs) - def _get_job_by_name(self, job_name: str): - return self.conn.get_model_customization_job(jobIdentifier=job_name) - - def get_customize_model_job_state(self, job_name: str) -> str: - state = self._get_job_by_name(job_name)["status"] - self.log.info("Job '%s' state: %s", job_name, state) - return state - - def get_job_arn(self, job_name: str) -> str: - return self._get_job_by_name(job_name)["jobArn"] - class BedrockRuntimeHook(AwsBaseHook): """ diff --git a/airflow/providers/amazon/aws/operators/bedrock.py b/airflow/providers/amazon/aws/operators/bedrock.py index 45567760f8959..ee34a9aef7da7 100644 --- a/airflow/providers/amazon/aws/operators/bedrock.py +++ b/airflow/providers/amazon/aws/operators/bedrock.py @@ -192,7 +192,7 @@ def execute_complete(self, context: Context, event: dict[str, Any] | None = None raise AirflowException(f"Error while running job: {event}") self.log.info("Bedrock model customization job `%s` complete.", self.job_name) - return self.hook.get_job_arn(event["job_name"]) + return self.hook.conn.get_model_customization_job(jobIdentifier=event["job_name"])["jobArn"] def execute(self, context: Context) -> dict: response = {} diff --git a/airflow/providers/amazon/aws/sensors/bedrock.py b/airflow/providers/amazon/aws/sensors/bedrock.py index 3d48fd49e136c..43a8846c73959 100644 --- a/airflow/providers/amazon/aws/sensors/bedrock.py +++ b/airflow/providers/amazon/aws/sensors/bedrock.py @@ -98,7 +98,8 @@ def execute(self, context: Context) -> Any: super().execute(context=context) def poke(self, context: Context) -> bool: - state = self.hook.get_customize_model_job_state(self.job_name) + state = self.hook.conn.get_model_customization_job(jobIdentifier=self.job_name)["status"] + self.log.info("Job '%s' state: %s", self.job_name, state) if state in self.FAILURE_STATES: # TODO: remove this if block when min_airflow_version is set to higher than 2.7.1 @@ -106,6 +107,4 @@ def poke(self, context: Context) -> bool: raise AirflowSkipException(self.FAILURE_MESSAGE) raise AirflowException(self.FAILURE_MESSAGE) - if state in self.INTERMEDIATE_STATES: - return False - return True + return state not in self.INTERMEDIATE_STATES diff --git a/tests/providers/amazon/aws/hooks/test_bedrock.py b/tests/providers/amazon/aws/hooks/test_bedrock.py index 65dde985563ba..16752477d5631 100644 --- a/tests/providers/amazon/aws/hooks/test_bedrock.py +++ b/tests/providers/amazon/aws/hooks/test_bedrock.py @@ -52,12 +52,6 @@ def test_conn_returns_a_boto3_connection(self): assert self.hook.conn is not None assert self.hook.conn.meta.service_model.service_name == "bedrock" - def test_get_customize_model_job_state(self, mock_conn): - response = self.hook.get_customize_model_job_state(JOB_NAME) - - mock_conn.get_model_customization_job.assert_called_once_with(jobIdentifier=JOB_NAME) - assert response == EXPECTED_STATUS - class TestBedrockRuntimeHook: def test_conn_returns_a_boto3_connection(self): diff --git a/tests/providers/amazon/aws/sensors/test_bedrock.py b/tests/providers/amazon/aws/sensors/test_bedrock.py index 71df06d4af9d4..dab0f94ad36dd 100644 --- a/tests/providers/amazon/aws/sensors/test_bedrock.py +++ b/tests/providers/amazon/aws/sensors/test_bedrock.py @@ -65,13 +65,15 @@ def test_base_aws_op_attributes(self): assert op.hook._config.read_timeout == 42 @pytest.mark.parametrize("state", list(BedrockCustomizeModelCompletedSensor.SUCCESS_STATES)) - def test_poke_success_states(self, state, mock_get_job_state): - mock_get_job_state.side_effect = [state] + @mock.patch.object(BedrockHook, "conn") + def test_poke_success_states(self, mock_conn, state): + mock_conn.get_model_customization_job.return_value = {"status": state} assert self.sensor.poke({}) is True @pytest.mark.parametrize("state", list(BedrockCustomizeModelCompletedSensor.INTERMEDIATE_STATES)) - def test_poke_intermediate_states(self, state, mock_get_job_state): - mock_get_job_state.side_effect = [state] + @mock.patch.object(BedrockHook, "conn") + def test_poke_intermediate_states(self, mock_conn, state): + mock_conn.get_model_customization_job.return_value = {"status": state} assert self.sensor.poke({}) is False @pytest.mark.parametrize( @@ -82,8 +84,9 @@ def test_poke_intermediate_states(self, state, mock_get_job_state): ], ) @pytest.mark.parametrize("state", list(BedrockCustomizeModelCompletedSensor.FAILURE_STATES)) - def test_poke_failure_states(self, state, soft_fail, expected_exception, mock_get_job_state): - mock_get_job_state.side_effect = [state] + @mock.patch.object(BedrockHook, "conn") + def test_poke_failure_states(self, mock_conn, state, soft_fail, expected_exception): + mock_conn.get_model_customization_job.return_value = {"status": state} sensor = BedrockCustomizeModelCompletedSensor( **self.default_op_kwargs, aws_conn_id=None, soft_fail=soft_fail )