From 56e716f5591902090005b9a6e47dbea31cff696f Mon Sep 17 00:00:00 2001 From: Pankaj Date: Wed, 3 May 2023 14:25:40 +0530 Subject: [PATCH 1/2] Add deferrable param in EmrContainerOperator Add the deferrable param in EmrContainerOperator. This will allow running EmrContainerOperator in an async way that means we only submit a job from the worker to run a job then defer to the trigger for polling and wait for a job the job status and the worker slot won't be occupied for the whole period of task execution. --- airflow/providers/amazon/aws/operators/emr.py | 22 +++++ airflow/providers/amazon/aws/triggers/emr.py | 82 +++++++++++++++++++ .../amazon/aws/waiters/emr-containers.json | 18 ++++ .../aws/operators/test_emr_containers.py | 13 ++- .../aws/triggers/test_emr_containers.py | 69 ++++++++++++++++ 5 files changed, 203 insertions(+), 1 deletion(-) create mode 100644 airflow/providers/amazon/aws/triggers/emr.py create mode 100644 airflow/providers/amazon/aws/waiters/emr-containers.json create mode 100644 tests/providers/amazon/aws/triggers/test_emr_containers.py diff --git a/airflow/providers/amazon/aws/operators/emr.py b/airflow/providers/amazon/aws/operators/emr.py index 220a4ddea0032..28cac488cfccb 100644 --- a/airflow/providers/amazon/aws/operators/emr.py +++ b/airflow/providers/amazon/aws/operators/emr.py @@ -26,6 +26,7 @@ from airflow.models import BaseOperator from airflow.providers.amazon.aws.hooks.emr import EmrContainerHook, EmrHook, EmrServerlessHook from airflow.providers.amazon.aws.links.emr import EmrClusterLink +from airflow.providers.amazon.aws.triggers.emr import EmrContainerOperatorTrigger from airflow.providers.amazon.aws.utils.waiter import waiter from airflow.utils.helpers import exactly_one, prune_dict from airflow.utils.types import NOTSET, ArgNotSet @@ -437,6 +438,7 @@ class EmrContainerOperator(BaseOperator): Defaults to None, which will poll until the job is *not* in a pending, submitted, or running state. :param tags: The tags assigned to job runs. Defaults to None + :param deferrable: Run operator in the deferrable mode. """ template_fields: Sequence[str] = ( @@ -465,6 +467,7 @@ def __init__( max_tries: int | None = None, tags: dict | None = None, max_polling_attempts: int | None = None, + deferrable: bool = False, **kwargs: Any, ) -> None: super().__init__(**kwargs) @@ -481,6 +484,7 @@ def __init__( self.max_polling_attempts = max_polling_attempts self.tags = tags self.job_id: str | None = None + self.deferrable = deferrable if max_tries: warnings.warn( @@ -513,6 +517,18 @@ def execute(self, context: Context) -> str | None: self.client_request_token, self.tags, ) + if self.deferrable: + self.defer( + timeout=self.execution_timeout, + trigger=EmrContainerOperatorTrigger( + virtual_cluster_id=self.virtual_cluster_id, + job_id=self.job_id, + max_attempts=self.max_polling_attempts, + aws_conn_id=self.aws_conn_id, + poll_interval=self.poll_interval, + ), + method_name="execute_complete", + ) if self.wait_for_completion: query_status = self.hook.poll_query_status( self.job_id, @@ -534,6 +550,12 @@ def execute(self, context: Context) -> str | None: return self.job_id + def execute_complete(self, context, event=None): + if event["status"] != "success": + raise AirflowException(f"Error while running job: {event}") + else: + self.log.info(event["message"]) + def on_kill(self) -> None: """Cancel the submitted job run""" if self.job_id: diff --git a/airflow/providers/amazon/aws/triggers/emr.py b/airflow/providers/amazon/aws/triggers/emr.py new file mode 100644 index 0000000000000..0c56f96c17487 --- /dev/null +++ b/airflow/providers/amazon/aws/triggers/emr.py @@ -0,0 +1,82 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +from typing import Any, AsyncIterator + +from airflow.compat.functools import cached_property +from airflow.providers.amazon.aws.hooks.emr import EmrContainerHook +from airflow.triggers.base import BaseTrigger, TriggerEvent + + +class EmrContainerOperatorTrigger(BaseTrigger): + """ + Poll for the status of EMR container until reaches terminal state + + :param virtual_cluster_id: Reference Emr cluster id + :param job_id: job_id to check the state + :param max_attempts: maximum try attempts for polling the status + :param aws_conn_id: Reference to AWS connection id + :param poll_interval: polling period in seconds to check for the status + """ + + def __init__( + self, + virtual_cluster_id: str, + job_id: str, + aws_conn_id: str = "aws_default", + poll_interval: int = 10, + max_attempts: int | None = None, + **kwargs: Any, + ): + self.virtual_cluster_id = virtual_cluster_id + self.job_id = job_id + self.aws_conn_id = aws_conn_id + self.poll_interval = poll_interval + self.max_attempts = max_attempts + super().__init__(**kwargs) + + @cached_property + def hook(self) -> EmrContainerHook: + return EmrContainerHook(self.aws_conn_id, virtual_cluster_id=self.virtual_cluster_id) + + def serialize(self) -> tuple[str, dict[str, Any]]: + """Serializes EmrContainerSensorTrigger arguments and classpath.""" + return ( + "airflow.providers.amazon.aws.triggers.emr.EmrContainerOperatorTrigger", + { + "virtual_cluster_id": self.virtual_cluster_id, + "job_id": self.job_id, + "aws_conn_id": self.aws_conn_id, + "max_attempts": self.max_attempts, + "poll_interval": self.poll_interval, + }, + ) + + async def run(self) -> AsyncIterator[TriggerEvent]: + async with self.hook.async_conn as client: + waiter = self.hook.get_waiter("container_job_complete", deferrable=True, client=client) + await waiter.wait( + id=self.job_id, + virtualClusterId=self.virtual_cluster_id, + WaiterConfig={ + "Delay": self.poll_interval, + "MaxAttempts": self.max_attempts, + }, + ) + yield TriggerEvent({"status": "success", "message": "Job completed."}) diff --git a/airflow/providers/amazon/aws/waiters/emr-containers.json b/airflow/providers/amazon/aws/waiters/emr-containers.json new file mode 100644 index 0000000000000..6925ccaab6c6d --- /dev/null +++ b/airflow/providers/amazon/aws/waiters/emr-containers.json @@ -0,0 +1,18 @@ +{ + "version": 2, + "waiters": { + "container_job_complete": { + "operation": "DescribeJobRun", + "delay": 30, + "maxAttempts": 60, + "acceptors": [ + { + "matcher": "path", + "argument": "jobRun.state", + "expected": "COMPLETED", + "state": "success" + } + ] + } + } +} diff --git a/tests/providers/amazon/aws/operators/test_emr_containers.py b/tests/providers/amazon/aws/operators/test_emr_containers.py index ddc11b15c56ce..8646e000fe997 100644 --- a/tests/providers/amazon/aws/operators/test_emr_containers.py +++ b/tests/providers/amazon/aws/operators/test_emr_containers.py @@ -22,9 +22,10 @@ import pytest from airflow.configuration import conf -from airflow.exceptions import AirflowException +from airflow.exceptions import AirflowException, TaskDeferred from airflow.providers.amazon.aws.hooks.emr import EmrContainerHook from airflow.providers.amazon.aws.operators.emr import EmrContainerOperator, EmrEksCreateClusterOperator +from airflow.providers.amazon.aws.triggers.emr import EmrContainerOperatorTrigger SUBMIT_JOB_SUCCESS_RETURN = { "ResponseMetadata": {"HTTPStatusCode": 200}, @@ -144,6 +145,16 @@ def test_execute_with_polling_timeout(self, mock_check_query_status): assert "Final state of EMR Containers job is SUBMITTED" in str(ctx.value) assert "Max tries of poll status exceeded" in str(ctx.value) + @mock.patch.object(EmrContainerHook, "submit_job") + def test_operator_defer(self, mock_submit_job): + self.emr_container.deferrable = True + self.emr_container.wait_for_completion = False + with pytest.raises(TaskDeferred) as exc: + self.emr_container.execute(context=None) + assert isinstance( + exc.value.trigger, EmrContainerOperatorTrigger + ), "Trigger is not a EmrContainerOperatorTrigger" + class TestEmrEksCreateClusterOperator: def setup_method(self): diff --git a/tests/providers/amazon/aws/triggers/test_emr_containers.py b/tests/providers/amazon/aws/triggers/test_emr_containers.py new file mode 100644 index 0000000000000..9965c2ddec188 --- /dev/null +++ b/tests/providers/amazon/aws/triggers/test_emr_containers.py @@ -0,0 +1,69 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import pytest + +from airflow.providers.amazon.aws.triggers.emr import EmrContainerOperatorTrigger +from airflow.triggers.base import TriggerEvent +from tests.providers.amazon.aws.utils.compat import AsyncMock, async_mock + +VIRTUAL_CLUSTER_ID = "vzwemreks" +JOB_ID = "job-1234" +AWS_CONN_ID = "aws_emr_conn" +POLL_INTERVAL = 60 +MAX_ATTEMPTS = 5 + + +class TestEmrContainerSensorTrigger: + def test_emr_container_operator_trigger_serialize(self): + emr_trigger = EmrContainerOperatorTrigger( + virtual_cluster_id=VIRTUAL_CLUSTER_ID, + job_id=JOB_ID, + aws_conn_id=AWS_CONN_ID, + poll_interval=POLL_INTERVAL, + max_attempts=MAX_ATTEMPTS, + ) + class_path, args = emr_trigger.serialize() + assert class_path == "airflow.providers.amazon.aws.triggers.emr.EmrContainerOperatorTrigger" + assert args["virtual_cluster_id"] == VIRTUAL_CLUSTER_ID + assert args["job_id"] == JOB_ID + assert args["aws_conn_id"] == AWS_CONN_ID + assert args["poll_interval"] == POLL_INTERVAL + assert args["max_attempts"] == MAX_ATTEMPTS + + @pytest.mark.asyncio + @async_mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrContainerHook.get_waiter") + @async_mock.patch("airflow.providers.amazon.aws.hooks.emr.EmrContainerHook.async_conn") + async def test_emr_container_trigger_run(self, mock_async_conn, mock_get_waiter): + mock = async_mock.MagicMock() + mock_async_conn.__aenter__.return_value = mock + + mock_get_waiter().wait = AsyncMock() + + emr_trigger = EmrContainerOperatorTrigger( + virtual_cluster_id=VIRTUAL_CLUSTER_ID, + job_id=JOB_ID, + aws_conn_id=AWS_CONN_ID, + poll_interval=POLL_INTERVAL, + max_attempts=MAX_ATTEMPTS, + ) + + generator = emr_trigger.run() + response = await generator.asend(None) + + assert response == TriggerEvent({"status": "success", "message": "Job completed."}) From f075c303ffccba192d1d0eb2b7e26c2ce50b9190 Mon Sep 17 00:00:00 2001 From: Pankaj Date: Mon, 12 Jun 2023 17:20:38 +0530 Subject: [PATCH 2/2] Add logs in trigger --- airflow/providers/amazon/aws/operators/emr.py | 9 +++- airflow/providers/amazon/aws/triggers/emr.py | 43 +++++++++++++------ .../amazon/aws/waiters/emr-containers.json | 12 ++++++ 3 files changed, 48 insertions(+), 16 deletions(-) diff --git a/airflow/providers/amazon/aws/operators/emr.py b/airflow/providers/amazon/aws/operators/emr.py index 28cac488cfccb..fc0af516d1930 100644 --- a/airflow/providers/amazon/aws/operators/emr.py +++ b/airflow/providers/amazon/aws/operators/emr.py @@ -19,6 +19,7 @@ import ast import warnings +from datetime import timedelta from typing import TYPE_CHECKING, Any, Sequence from uuid import uuid4 @@ -518,12 +519,16 @@ def execute(self, context: Context) -> str | None: self.tags, ) if self.deferrable: + timeout = ( + timedelta(seconds=self.max_polling_attempts * self.poll_interval + 60) + if self.max_polling_attempts + else self.execution_timeout + ) self.defer( - timeout=self.execution_timeout, + timeout=timeout, trigger=EmrContainerOperatorTrigger( virtual_cluster_id=self.virtual_cluster_id, job_id=self.job_id, - max_attempts=self.max_polling_attempts, aws_conn_id=self.aws_conn_id, poll_interval=self.poll_interval, ), diff --git a/airflow/providers/amazon/aws/triggers/emr.py b/airflow/providers/amazon/aws/triggers/emr.py index 0c56f96c17487..15b89172de7e9 100644 --- a/airflow/providers/amazon/aws/triggers/emr.py +++ b/airflow/providers/amazon/aws/triggers/emr.py @@ -17,8 +17,11 @@ from __future__ import annotations +import asyncio from typing import Any, AsyncIterator +from botocore.exceptions import WaiterError + from airflow.compat.functools import cached_property from airflow.providers.amazon.aws.hooks.emr import EmrContainerHook from airflow.triggers.base import BaseTrigger, TriggerEvent @@ -30,7 +33,6 @@ class EmrContainerOperatorTrigger(BaseTrigger): :param virtual_cluster_id: Reference Emr cluster id :param job_id: job_id to check the state - :param max_attempts: maximum try attempts for polling the status :param aws_conn_id: Reference to AWS connection id :param poll_interval: polling period in seconds to check for the status """ @@ -41,14 +43,12 @@ def __init__( job_id: str, aws_conn_id: str = "aws_default", poll_interval: int = 10, - max_attempts: int | None = None, **kwargs: Any, ): self.virtual_cluster_id = virtual_cluster_id self.job_id = job_id self.aws_conn_id = aws_conn_id self.poll_interval = poll_interval - self.max_attempts = max_attempts super().__init__(**kwargs) @cached_property @@ -56,14 +56,13 @@ def hook(self) -> EmrContainerHook: return EmrContainerHook(self.aws_conn_id, virtual_cluster_id=self.virtual_cluster_id) def serialize(self) -> tuple[str, dict[str, Any]]: - """Serializes EmrContainerSensorTrigger arguments and classpath.""" + """Serializes EmrContainerOperatorTrigger arguments and classpath.""" return ( "airflow.providers.amazon.aws.triggers.emr.EmrContainerOperatorTrigger", { "virtual_cluster_id": self.virtual_cluster_id, "job_id": self.job_id, "aws_conn_id": self.aws_conn_id, - "max_attempts": self.max_attempts, "poll_interval": self.poll_interval, }, ) @@ -71,12 +70,28 @@ def serialize(self) -> tuple[str, dict[str, Any]]: async def run(self) -> AsyncIterator[TriggerEvent]: async with self.hook.async_conn as client: waiter = self.hook.get_waiter("container_job_complete", deferrable=True, client=client) - await waiter.wait( - id=self.job_id, - virtualClusterId=self.virtual_cluster_id, - WaiterConfig={ - "Delay": self.poll_interval, - "MaxAttempts": self.max_attempts, - }, - ) - yield TriggerEvent({"status": "success", "message": "Job completed."}) + attempt = 0 + while True: + attempt = attempt + 1 + try: + await waiter.wait( + id=self.job_id, + virtualClusterId=self.virtual_cluster_id, + WaiterConfig={ + "Delay": self.poll_interval, + "MaxAttempts": 1, + }, + ) + break + except WaiterError as error: + if "terminal failure" in str(error): + yield TriggerEvent({"status": "failure", "message": f"Job Failed: {error}"}) + break + self.log.info( + "Job status is %s. Retrying attempt %s", + error.last_response["jobRun"]["state"], + attempt, + ) + await asyncio.sleep(int(self.poll_interval)) + + yield TriggerEvent({"status": "success", "job_id": self.job_id}) diff --git a/airflow/providers/amazon/aws/waiters/emr-containers.json b/airflow/providers/amazon/aws/waiters/emr-containers.json index 6925ccaab6c6d..a4174b0536e50 100644 --- a/airflow/providers/amazon/aws/waiters/emr-containers.json +++ b/airflow/providers/amazon/aws/waiters/emr-containers.json @@ -11,6 +11,18 @@ "argument": "jobRun.state", "expected": "COMPLETED", "state": "success" + }, + { + "matcher": "path", + "argument": "jobRun.state", + "expected": "FAILED", + "state": "failure" + }, + { + "matcher": "path", + "argument": "jobRun.state", + "expected": "CANCELLED", + "state": "failure" } ] }