diff --git a/providers/apache/livy/docs/index.rst b/providers/apache/livy/docs/index.rst index 058669470b5c2..cc3d26534ed2d 100644 --- a/providers/apache/livy/docs/index.rst +++ b/providers/apache/livy/docs/index.rst @@ -103,7 +103,6 @@ PIP package Version required ``apache-airflow-providers-http`` ``>=5.1.0`` ``apache-airflow-providers-common-compat`` ``>=1.12.0`` ``aiohttp`` ``>=3.9.2`` -``asgiref`` ``>=2.3.0`` ========================================== ================== Cross provider package dependencies diff --git a/providers/apache/livy/pyproject.toml b/providers/apache/livy/pyproject.toml index 3000a45eb0530..9a013b15034b8 100644 --- a/providers/apache/livy/pyproject.toml +++ b/providers/apache/livy/pyproject.toml @@ -59,10 +59,9 @@ requires-python = ">=3.10" # After you modify the dependencies, and rebuild your Breeze CI image with ``breeze ci-image build`` dependencies = [ "apache-airflow>=2.11.0", - "apache-airflow-providers-http>=5.1.0", + "apache-airflow-providers-http>=5.1.0", # use next version "apache-airflow-providers-common-compat>=1.12.0", "aiohttp>=3.9.2", - "asgiref>=2.3.0", ] [dependency-groups] diff --git a/providers/apache/livy/src/airflow/providers/apache/livy/hooks/livy.py b/providers/apache/livy/src/airflow/providers/apache/livy/hooks/livy.py index 9de3582de1109..e9f8e94c74843 100644 --- a/providers/apache/livy/src/airflow/providers/apache/livy/hooks/livy.py +++ b/providers/apache/livy/src/airflow/providers/apache/livy/hooks/livy.py @@ -16,24 +16,18 @@ # under the License. from __future__ import annotations -import asyncio import json import re from collections.abc import Sequence from enum import Enum -from typing import TYPE_CHECKING, Any +from typing import Any -import aiohttp import requests from aiohttp import ClientResponseError -from airflow.providers.common.compat.connection import get_async_connection from airflow.providers.common.compat.sdk import AirflowException from airflow.providers.http.hooks.http import HttpAsyncHook, HttpHook -if TYPE_CHECKING: - from airflow.models import Connection - class BatchState(Enum): """Batch session states.""" @@ -502,101 +496,10 @@ def __init__( self.extra_options = extra_options or {} self.endpoint_prefix = sanitize_endpoint_prefix(endpoint_prefix) - async def _do_api_call_async( - self, - endpoint: str | None = None, - data: dict[str, Any] | str | None = None, - headers: dict[str, Any] | None = None, - extra_options: dict[str, Any] | None = None, - ) -> Any: - """ - Perform an asynchronous HTTP request call. - - :param endpoint: the endpoint to be called i.e. resource/v1/query? - :param data: payload to be uploaded or request parameters - :param headers: additional headers to be passed through as a dictionary - :param extra_options: Additional kwargs to pass when creating a request. - For example, ``run(json=obj)`` is passed as ``aiohttp.ClientSession().get(json=obj)`` - """ - extra_options = extra_options or {} - - # headers may be passed through directly or in the "extra" field in the connection - # definition - _headers = {} - auth = None - - if self.http_conn_id: - conn = await get_async_connection(self.http_conn_id) - - self.base_url = self._generate_base_url(conn) # type: ignore[arg-type] - if conn.login: - auth = self.auth_type(conn.login, conn.password) - if conn.extra: - try: - _headers.update(conn.extra_dejson) - except TypeError: - self.log.warning("Connection to %s has invalid extra field.", conn.host) - if headers: - _headers.update(headers) - - if self.base_url and not self.base_url.endswith("/") and endpoint and not endpoint.startswith("/"): - url = self.base_url + "/" + endpoint - else: - url = (self.base_url or "") + (endpoint or "") - - async with aiohttp.ClientSession() as session: - if self.method == "GET": - request_func = session.get - elif self.method == "POST": - request_func = session.post - elif self.method == "PATCH": - request_func = session.patch - else: - return {"Response": f"Unexpected HTTP Method: {self.method}", "status": "error"} - - for attempt_num in range(1, 1 + self.retry_limit): - response = await request_func( - url, - json=data if self.method in ("POST", "PATCH") else None, - params=data if self.method == "GET" else None, - headers=_headers or None, - auth=auth, - **extra_options, - ) - try: - response.raise_for_status() - return await response.json() - except ClientResponseError as e: - self.log.warning( - "[Try %d of %d] Request to %s failed.", - attempt_num, - self.retry_limit, - url, - ) - if not self._retryable_error_async(e) or attempt_num == self.retry_limit: - self.log.exception("HTTP error, status code: %s", e.status) - # In this case, the user probably made a mistake. - # Don't retry. - return {"Response": {e.message}, "Status Code": {e.status}, "status": "error"} - - await asyncio.sleep(self.retry_delay) - - def _generate_base_url(self, conn: Connection) -> str: - if conn.host and "://" in conn.host: - base_url: str = conn.host - else: - # schema defaults to HTTP - schema = conn.schema if conn.schema else "http" - host = conn.host if conn.host else "" - base_url = f"{schema}://{host}" - if conn.port: - base_url = f"{base_url}:{conn.port}" - return base_url - async def run_method( self, endpoint: str, - method: str = "GET", + method: str | None = None, data: Any | None = None, headers: dict[str, Any] | None = None, ) -> Any: @@ -609,16 +512,29 @@ async def run_method( :param headers: headers :return: http response """ - if method not in ("GET", "POST", "PUT", "DELETE", "HEAD"): + method = method or self.method + if method not in {"GET", "PATCH", "POST", "PUT", "DELETE", "HEAD"}: return {"status": "error", "response": f"Invalid http method {method}"} - back_method = self.method - self.method = method + endpoint = ( + f"{self.endpoint_prefix}/{endpoint}" + if self.endpoint_prefix and endpoint + else endpoint or self.endpoint_prefix + ) + try: - result = await self._do_api_call_async(endpoint, data, headers, self.extra_options) - finally: - self.method = back_method - return {"status": "success", "response": result} + async with self.session() as session: + response = await session.run( + endpoint=endpoint, + data=data, + headers={**self._def_headers, **self.extra_headers, **(headers or {})}, + extra_options=self.extra_options, + ) + + result = await response.json() + return {"status": "success", "response": result} + except ClientResponseError as e: + return {"Response": {e.message}, "Status Code": {e.status}, "status": "error"} async def get_batch_state(self, session_id: int | str) -> Any: """ diff --git a/providers/apache/livy/tests/unit/apache/livy/hooks/test_livy.py b/providers/apache/livy/tests/unit/apache/livy/hooks/test_livy.py index 3a819c38064dc..90bdcad8866a6 100644 --- a/providers/apache/livy/tests/unit/apache/livy/hooks/test_livy.py +++ b/providers/apache/livy/tests/unit/apache/livy/hooks/test_livy.py @@ -592,159 +592,163 @@ async def test_dump_batch_logs_error(self, mock_get_batch_logs): assert log_dump == {"id": 1, "log": ["mock_log_1", "mock_log_2"]} @pytest.mark.asyncio - @mock.patch("airflow.providers.apache.livy.hooks.livy.LivyAsyncHook._do_api_call_async") - async def test_run_method_success(self, mock_do_api_call_async): + @mock.patch("airflow.providers.http.hooks.http.aiohttp.ClientSession") + @mock.patch( + "airflow.providers.common.compat.connection.get_async_connection", + return_value=Connection( + conn_id=LIVY_CONN_ID, + conn_type="http", + host="http://host", + port=80, + ), + ) + async def test_run_method_success(self, mock_get_connection, mock_session): """Asserts the run_method for success response.""" - mock_do_api_call_async.return_value = {"status": "error", "response": {"id": 1}} + mock_session.return_value.__aenter__.return_value.post = AsyncMock() + mock_session.return_value.__aenter__.return_value.post.return_value.json = AsyncMock( + return_value={"id": 1} + ) hook = LivyAsyncHook(livy_conn_id=LIVY_CONN_ID) response = await hook.run_method("localhost", "GET") - assert response["status"] == "success" + assert response == {"status": "success", "response": {"id": 1}} @pytest.mark.asyncio - @mock.patch("airflow.providers.apache.livy.hooks.livy.LivyAsyncHook._do_api_call_async") - async def test_run_method_error(self, mock_do_api_call_async): + async def test_run_method_error(self): """Asserts the run_method for error response.""" - mock_do_api_call_async.return_value = {"status": "error", "response": {"id": 1}} hook = LivyAsyncHook(livy_conn_id=LIVY_CONN_ID) response = await hook.run_method("localhost", "abc") assert response == {"status": "error", "response": "Invalid http method abc"} @pytest.mark.asyncio - @mock.patch("airflow.providers.apache.livy.hooks.livy.aiohttp.ClientSession") - @mock.patch("airflow.providers.apache.livy.hooks.livy.get_async_connection") - async def test_do_api_call_async_post_method_with_success(self, mock_get_connection, mock_session): - """Asserts the _do_api_call_async for success response for POST method.""" - - async def mock_fun(arg1, arg2, arg3, arg4): - return {"status": "success"} - - mock_session.return_value.__aexit__.return_value = mock_fun + @mock.patch("airflow.providers.http.hooks.http.aiohttp.ClientSession") + @mock.patch( + "airflow.providers.common.compat.connection.get_async_connection", + return_value=Connection( + conn_id=LIVY_CONN_ID, + conn_type="http", + host="http://host", + port=80, + ), + ) + async def test_run_post_method_with_success(self, mock_get_connection, mock_session): + """Asserts the run_method for success response for POST method.""" mock_session.return_value.__aenter__.return_value.post = AsyncMock() mock_session.return_value.__aenter__.return_value.post.return_value.json = AsyncMock( - return_value={"status": "success"} + return_value={"hello": "world"} ) - GET_RUN_ENDPOINT = "api/jobs/runs/get" hook = LivyAsyncHook(livy_conn_id=LIVY_CONN_ID) - hook.http_conn_id = mock_get_connection - hook.http_conn_id.host = "https://localhost" - hook.http_conn_id.login = "login" - hook.http_conn_id.password = "PASSWORD" - response = await hook._do_api_call_async(GET_RUN_ENDPOINT) - assert response == {"status": "success"} + response = await hook.run_method("api/jobs/runs/get") + assert response["status"] == "success" + assert response["response"] == {"hello": "world"} @pytest.mark.asyncio - @mock.patch("airflow.providers.apache.livy.hooks.livy.aiohttp.ClientSession") - @mock.patch("airflow.providers.apache.livy.hooks.livy.get_async_connection") - async def test_do_api_call_async_get_method_with_success(self, mock_get_connection, mock_session): - """Asserts the _do_api_call_async for GET method.""" - - async def mock_fun(arg1, arg2, arg3, arg4): - return {"status": "success"} - - mock_session.return_value.__aexit__.return_value = mock_fun + @mock.patch("airflow.providers.http.hooks.http.aiohttp.ClientSession") + @mock.patch( + "airflow.providers.common.compat.connection.get_async_connection", + return_value=Connection( + conn_id=LIVY_CONN_ID, + conn_type="http", + host="http://host", + port=80, + ), + ) + async def test_run_get_method_with_success(self, mock_get_connection, mock_session): + """Asserts the run_method for GET method.""" mock_session.return_value.__aenter__.return_value.get = AsyncMock() mock_session.return_value.__aenter__.return_value.get.return_value.json = AsyncMock( - return_value={"status": "success"} + return_value={"hello": "world"} ) - GET_RUN_ENDPOINT = "api/jobs/runs/get" hook = LivyAsyncHook(livy_conn_id=LIVY_CONN_ID) hook.method = "GET" - hook.http_conn_id = mock_get_connection - hook.http_conn_id.host = "test.com" - hook.http_conn_id.login = "login" - hook.http_conn_id.password = "PASSWORD" - hook.http_conn_id.extra_dejson = "" - response = await hook._do_api_call_async(GET_RUN_ENDPOINT) - assert response == {"status": "success"} + response = await hook.run_method("api/jobs/runs/get") + assert response["status"] == "success" + assert response["response"] == {"hello": "world"} @pytest.mark.asyncio - @mock.patch("airflow.providers.apache.livy.hooks.livy.aiohttp.ClientSession") - @mock.patch("airflow.providers.apache.livy.hooks.livy.get_async_connection") - async def test_do_api_call_async_patch_method_with_success(self, mock_get_connection, mock_session): - """Asserts the _do_api_call_async for PATCH method.""" - - async def mock_fun(arg1, arg2, arg3, arg4): - return {"status": "success"} - - mock_session.return_value.__aexit__.return_value = mock_fun + @mock.patch("airflow.providers.http.hooks.http.aiohttp.ClientSession") + @mock.patch( + "airflow.providers.common.compat.connection.get_async_connection", + return_value=Connection( + conn_id=LIVY_CONN_ID, + conn_type="http", + host="http://host", + port=80, + ), + ) + async def test_run_patch_method_with_success(self, mock_get_connection, mock_session): + """Asserts the run_method for PATCH method.""" mock_session.return_value.__aenter__.return_value.patch = AsyncMock() mock_session.return_value.__aenter__.return_value.patch.return_value.json = AsyncMock( - return_value={"status": "success"} + return_value={"hello": "world"} ) - GET_RUN_ENDPOINT = "api/jobs/runs/get" hook = LivyAsyncHook(livy_conn_id=LIVY_CONN_ID) hook.method = "PATCH" - hook.http_conn_id = mock_get_connection - hook.http_conn_id.host = "test.com" - hook.http_conn_id.login = "login" - hook.http_conn_id.password = "PASSWORD" - hook.http_conn_id.extra_dejson = "" - response = await hook._do_api_call_async(GET_RUN_ENDPOINT) - assert response == {"status": "success"} + response = await hook.run_method("api/jobs/runs/get") + assert response["status"] == "success" + assert response["response"] == {"hello": "world"} @pytest.mark.asyncio - @mock.patch("airflow.providers.apache.livy.hooks.livy.aiohttp.ClientSession") - @mock.patch("airflow.providers.apache.livy.hooks.livy.get_async_connection") - async def test_do_api_call_async_unexpected_method_error(self, mock_get_connection, mock_session): - """Asserts the _do_api_call_async for unexpected method error""" - GET_RUN_ENDPOINT = "api/jobs/runs/get" + @mock.patch( + "airflow.providers.common.compat.connection.get_async_connection", + return_value=Connection( + conn_id=LIVY_CONN_ID, + conn_type="http", + host="http://host", + port=80, + ), + ) + async def test_run_unexpected_method_with_success(self, mock_get_connection): + """Asserts the run_method for unexpected method error""" hook = LivyAsyncHook(livy_conn_id=LIVY_CONN_ID) hook.method = "abc" - hook.http_conn_id = mock_get_connection - hook.http_conn_id.host = "test.com" - hook.http_conn_id.login = "login" - hook.http_conn_id.password = "PASSWORD" - hook.http_conn_id.extra_dejson = "" - response = await hook._do_api_call_async(endpoint=GET_RUN_ENDPOINT, headers={}) - assert response == {"Response": "Unexpected HTTP Method: abc", "status": "error"} + response = await hook.run_method(endpoint="api/jobs/runs/get", headers={}) + assert response == {"response": "Invalid http method abc", "status": "error"} @pytest.mark.asyncio - @mock.patch("airflow.providers.apache.livy.hooks.livy.aiohttp.ClientSession") - @mock.patch("airflow.providers.apache.livy.hooks.livy.get_async_connection") - async def test_do_api_call_async_with_type_error(self, mock_get_connection, mock_session): - """Asserts the _do_api_call_async for TypeError.""" + @mock.patch( + "airflow.providers.common.compat.connection.get_async_connection", + return_value=Connection( + conn_id=LIVY_CONN_ID, + conn_type="http", + host="http://host", + port=80, + ), + ) + async def test_run_put_method_with_type_error(self, mock_get_connection): + """Asserts the run_method for TypeError.""" async def mock_fun(arg1, arg2, arg3, arg4): return {"random value"} - mock_session.return_value.__aexit__.return_value = mock_fun - mock_session.return_value.__aenter__.return_value.patch.return_value.json.return_value = {} hook = LivyAsyncHook(livy_conn_id=LIVY_CONN_ID) hook.method = "PATCH" - hook.retry_limit = 1 - hook.retry_delay = 1 - hook.http_conn_id = mock_get_connection with pytest.raises(TypeError): - await hook._do_api_call_async(endpoint="", data="test", headers=mock_fun, extra_options=mock_fun) + await hook.run_method(endpoint="api/jobs/runs/get", data="test", headers=mock_fun) @pytest.mark.asyncio - @mock.patch("airflow.providers.apache.livy.hooks.livy.aiohttp.ClientSession") - @mock.patch("airflow.providers.apache.livy.hooks.livy.get_async_connection") - async def test_do_api_call_async_with_client_response_error(self, mock_get_connection, mock_session): - """Asserts the _do_api_call_async for Client Response Error.""" - - async def mock_fun(arg1, arg2, arg3, arg4): - return {"random value"} + @mock.patch("airflow.providers.http.hooks.http.aiohttp.ClientSession") + @mock.patch( + "airflow.providers.common.compat.connection.get_async_connection", + return_value=Connection( + conn_id=LIVY_CONN_ID, + conn_type="http", + host="http://host", + port=80, + ), + ) + async def test_run_method_with_client_response_error(self, mock_get_connection, mock_session): + """Asserts the run_method for Client Response Error.""" - mock_session.return_value.__aexit__.return_value = mock_fun - mock_session.return_value.__aenter__.return_value.patch = AsyncMock() - mock_session.return_value.__aenter__.return_value.patch.return_value.json.side_effect = ( - ClientResponseError( + mock_session.return_value.__aenter__.return_value.patch = AsyncMock( + side_effect=ClientResponseError( request_info=RequestInfo(url="example.com", method="PATCH", headers=multidict.CIMultiDict()), status=500, history=[], ) ) - GET_RUN_ENDPOINT = "" hook = LivyAsyncHook(livy_conn_id="livy_default") hook.method = "PATCH" - hook.base_url = "" - hook.http_conn_id = mock_get_connection - hook.http_conn_id.host = "test.com" - hook.http_conn_id.login = "login" - hook.http_conn_id.password = "PASSWORD" - hook.http_conn_id.extra_dejson = "" - response = await hook._do_api_call_async(GET_RUN_ENDPOINT) + response = await hook.run_method("") assert response["status"] == "error" @pytest.fixture @@ -764,7 +768,8 @@ def setup_livy_conn(self, create_connection_without_db): create_connection_without_db(Connection(conn_id="missing_host", conn_type="http", port=1234)) create_connection_without_db(Connection(conn_id="invalid_uri", uri="http://invalid_uri:4321")) - def test_build_get_hook(self, setup_livy_conn): + @pytest.mark.asyncio + async def test_build_get_hook(self, setup_livy_conn): connection_url_mapping = { # id, expected "default_port": "http://host", @@ -776,8 +781,8 @@ def test_build_get_hook(self, setup_livy_conn): for conn_id, expected in connection_url_mapping.items(): hook = LivyAsyncHook(livy_conn_id=conn_id) - response_conn = hook.get_connection(conn_id=conn_id) - assert hook._generate_base_url(response_conn) == expected + async with hook.session() as session: + assert session.base_url == expected def test_build_body(self): # minimal request diff --git a/providers/http/src/airflow/providers/http/hooks/http.py b/providers/http/src/airflow/providers/http/hooks/http.py index ed137a651c426..240bd24d06cc2 100644 --- a/providers/http/src/airflow/providers/http/hooks/http.py +++ b/providers/http/src/airflow/providers/http/hooks/http.py @@ -18,22 +18,25 @@ from __future__ import annotations import copy -from collections.abc import Callable +from collections.abc import AsyncGenerator, Awaitable, Callable +from contextlib import asynccontextmanager from typing import TYPE_CHECKING, Any, cast from urllib.parse import urlparse import aiohttp import tenacity from aiohttp import ClientResponseError +from pydantic import BaseModel from requests import PreparedRequest, Request, Response, Session from requests.auth import HTTPBasicAuth from requests.exceptions import ConnectionError, HTTPError from requests.models import DEFAULT_REDIRECT_LIMIT from requests_toolbelt.adapters.socket_options import TCPKeepAliveAdapter +from tenacity import retry_if_exception -from airflow.providers.common.compat.connection import get_async_connection from airflow.providers.common.compat.sdk import AirflowException, BaseHook from airflow.providers.http.exceptions import HttpErrorException, HttpMethodException +from airflow.utils.log.logging_mixin import LoggingMixin if TYPE_CHECKING: from aiohttp.client_reqrep import ClientResponse @@ -95,6 +98,28 @@ def _process_extra_options_from_connection( return conn_extra_options, passed_extra_options +def _retryable_error_async(exception: BaseException) -> bool: + """ + Determine whether an exception may successful on a subsequent attempt. + + It considers the following to be retryable: + - requests_exceptions.ConnectionError + - requests_exceptions.Timeout + - anything with a status code >= 500 + + Most retryable errors are covered by status code >= 500. + """ + if not isinstance(exception, ClientResponseError): + return False + if exception.status == 429: + # don't retry for too Many Requests + return False + if exception.status == 413: + # don't retry for payload Too Large + return False + return exception.status >= 500 + + class HttpHook(BaseHook): """ Interact with HTTP servers. @@ -399,6 +424,132 @@ def test_connection(self): return False, str(e) +class SessionConfig(BaseModel): + """Configuration container for an asynchronous HTTP session.""" + + base_url: str + headers: dict[str, Any] | None = None + auth: aiohttp.BasicAuth | None = None + extra_options: dict[str, Any] | None = None + + +class AsyncHttpSession(LoggingMixin): + """ + Wrapper around an ``aiohttp.ClientSession`` providing a session bound ``HttpAsyncHook``. + + This class binds an asynchronous HTTP client session to an ``HttpAsyncHook`` and applies connection + configuration, authentication, headers, and retry logic consistently across requests. A single + ``AsyncHttpSession`` instance is intended to be used for multiple HTTP calls within the same logical session. + + :param hook: The ``HttpAsyncHook`` instance that owns this session and provides connection-level behavior + such as retries and logging. + :param request: A callable used to perform the underlying HTTP request. This is typically a bound + ``aiohttp.ClientSession`` request method. + :param config: Resolved session configuration containing base URL, headers, and authentication settings. + """ + + def __init__( + self, + hook: HttpAsyncHook, + request: Callable[..., Awaitable[ClientResponse]], + config: SessionConfig, + ) -> None: + super().__init__() + self._hook = hook + self._request = request + self.config = config + + @property + def http_conn_id(self) -> str: + return self._hook.http_conn_id + + @property + def base_url(self) -> str: + return self.config.base_url + + @property + def method(self) -> str: + return self._hook.method + + @property + def retry_limit(self) -> int: + return self._hook.retry_limit + + @property + def retry_delay(self) -> float: + return self._hook.retry_delay + + @property + def headers(self) -> dict[str, Any] | None: + return self.config.headers + + @property + def extra_options(self) -> dict[str, Any] | None: + return self.config.extra_options + + @property + def auth(self) -> aiohttp.BasicAuth | None: + return self.config.auth + + async def run( + self, + endpoint: str | None = None, + data: dict[str, Any] | str | None = None, + json: dict[str, Any] | str | None = None, + headers: dict[str, Any] | None = None, + extra_options: dict[str, Any] | None = None, + ) -> ClientResponse: + """ + Perform an asynchronous HTTP request call. + + :param endpoint: Endpoint to be called, i.e. ``resource/v1/query?``. + :param data: Payload to be uploaded or request parameters. + :param json: Payload to be uploaded as JSON. + :param headers: Additional headers to be passed through as a dict. + :param extra_options: Additional kwargs to pass when creating a request. + For example, ``run(json=obj)`` is passed as + ``aiohttp.ClientSession().get(json=obj)``. + """ + from tenacity import AsyncRetrying, stop_after_attempt, wait_fixed + + url = _url_from_endpoint(self.base_url, endpoint) + merged_headers = {**(self.headers or {}), **(headers or {})} + extra_options = {**(self.extra_options or {}), **(extra_options or {})} + + async def request_func() -> ClientResponse: + response = await self._request( + url, + params=data if self.method == "GET" else None, + data=data if self.method in {"POST", "PUT", "PATCH"} else None, + json=json, + headers=merged_headers, + auth=self.auth, + **extra_options, + ) + response.raise_for_status() + return response + + async for attempt in AsyncRetrying( + stop=stop_after_attempt(self.retry_limit), + wait=wait_fixed(self.retry_delay), + retry=retry_if_exception(_retryable_error_async), + reraise=True, + ): + with attempt: + try: + return await request_func() + except ClientResponseError as e: + self.log.warning( + "[Try %d of %d] Request to %s failed.", + attempt.retry_state.attempt_number, + self.retry_limit, + url, + ) + raise e + + raise NotImplementedError # should not reach this, but makes mypy happy + + class HttpAsyncHook(BaseHook): """ Interact with HTTP servers asynchronously. @@ -408,6 +559,8 @@ class HttpAsyncHook(BaseHook): API url i.e https://www.google.com/ and optional authentication credentials. Default headers can also be specified in the Extra field in json format. :param auth_type: The auth type for the service + :param retry_limit: Maximum number of times to retry this job if it fails (default is 3) + :param retry_delay: Delay between retry attempts (default is 1.0) """ conn_name_attr = "http_conn_id" @@ -429,13 +582,82 @@ def __init__( self._retry_obj: Callable[..., Any] self.auth_type: Any = auth_type if retry_limit < 1: - raise ValueError("Retry limit must be greater than equal to 1") + raise ValueError("Retry limit must be greater or equal to 1") self.retry_limit = retry_limit self.retry_delay = retry_delay + self._config: SessionConfig | None = None + + def _get_request_func(self, session: aiohttp.ClientSession) -> Callable[..., Any]: + method = self.method + if method == "GET": + return session.get + if method == "POST": + return session.post + if method == "PATCH": + return session.patch + if method == "HEAD": + return session.head + if method == "PUT": + return session.put + if method == "DELETE": + return session.delete + if method == "OPTIONS": + return session.options + raise HttpMethodException(f"Unexpected HTTP Method: {method}") + + async def config(self) -> SessionConfig: + if not self._config: + from airflow.providers.common.compat.connection import get_async_connection + + base_url: str = self.base_url + auth: aiohttp.BasicAuth | None = None + headers: dict[str, Any] = {} + extra_options: dict[str, Any] = {} + + if self.http_conn_id: + conn = await get_async_connection(conn_id=self.http_conn_id) + + if conn.host and "://" in conn.host: + base_url = conn.host + else: + schema = conn.schema or "http" + base_url = f"{schema}://{conn.host or ''}" + + if conn.port: + base_url += f":{conn.port}" + + if conn.login: + auth = self.auth_type(conn.login, conn.password) + + if conn.extra: + conn_extra_options, extra_options = _process_extra_options_from_connection( + conn=conn, extra_options={} + ) + headers.update(conn_extra_options) + + self._config = SessionConfig( + base_url=base_url, + headers=headers, + auth=auth, + extra_options=extra_options, + ) + return self._config + + @asynccontextmanager + async def session(self) -> AsyncGenerator[AsyncHttpSession, None]: + """ + Create an ``AsyncHttpSession`` bound to a single ``aiohttp.ClientSession``. + + Airflow connection resolution happens exactly once here. + """ + async with aiohttp.ClientSession() as session: + request = self._get_request_func(session=session) + config = await self.config() + yield AsyncHttpSession(hook=self, request=request, config=config) async def run( self, - session: aiohttp.ClientSession, + session: aiohttp.ClientSession | None = None, endpoint: str | None = None, data: dict[str, Any] | str | None = None, json: dict[str, Any] | str | None = None, @@ -445,6 +667,7 @@ async def run( """ Perform an asynchronous HTTP request call. + :param session: ``aiohttp.ClientSession`` :param endpoint: Endpoint to be called, i.e. ``resource/v1/query?``. :param data: Payload to be uploaded or request parameters. :param json: Payload to be uploaded as JSON. @@ -453,103 +676,17 @@ async def run( For example, ``run(json=obj)`` is passed as ``aiohttp.ClientSession().get(json=obj)``. """ - extra_options = extra_options or {} - - # headers may be passed through directly or in the "extra" field in the connection - # definition - _headers = {} - auth = None - - if self.http_conn_id: - conn = await get_async_connection(self.http_conn_id) - - if conn.host and "://" in conn.host: - self.base_url = conn.host - else: - # schema defaults to HTTP - schema = conn.schema if conn.schema else "http" - host = conn.host if conn.host else "" - self.base_url = schema + "://" + host - - if conn.port: - self.base_url += f":{conn.port}" - if conn.login: - auth = self.auth_type(conn.login, conn.password) - if conn.extra: - conn_extra_options, extra_options = _process_extra_options_from_connection( - conn=conn, extra_options=extra_options + try: + if session is not None: + request = self._get_request_func(session=session) + config = await self.config() + return await AsyncHttpSession(hook=self, request=request, config=config).run( + endpoint=endpoint, data=data, json=json, headers=headers, extra_options=extra_options ) - try: - _headers.update(conn_extra_options) - except TypeError: - self.log.warning("Connection to %s has invalid extra field.", conn.host) - if headers: - _headers.update(headers) - - url = _url_from_endpoint(self.base_url, endpoint) - - if self.method == "GET": - request_func = session.get - elif self.method == "POST": - request_func = session.post - elif self.method == "PATCH": - request_func = session.patch - elif self.method == "HEAD": - request_func = session.head - elif self.method == "PUT": - request_func = session.put - elif self.method == "DELETE": - request_func = session.delete - elif self.method == "OPTIONS": - request_func = session.options - else: - raise HttpMethodException(f"Unexpected HTTP Method: {self.method}") - - for attempt in range(1, 1 + self.retry_limit): - response = await request_func( - url, - params=data if self.method == "GET" else None, - data=data if self.method in ("POST", "PUT", "PATCH") else None, - json=json, - headers=_headers, - auth=auth, - **extra_options, - ) - try: - response.raise_for_status() - except ClientResponseError as e: - self.log.warning( - "[Try %d of %d] Request to %s failed.", - attempt, - self.retry_limit, - url, + async with self.session() as http: + return await http.run( + endpoint=endpoint, data=data, json=json, headers=headers, extra_options=extra_options ) - if not self._retryable_error_async(e) or attempt == self.retry_limit: - self.log.exception("HTTP error with status: %s", e.status) - # In this case, the user probably made a mistake. - # Don't retry. - raise HttpErrorException(f"{e.status}:{e.message}") - else: - return response - - raise NotImplementedError # should not reach this, but makes mypy happy - - def _retryable_error_async(self, exception: ClientResponseError) -> bool: - """ - Determine whether an exception may successful on a subsequent attempt. - - It considers the following to be retryable: - - requests_exceptions.ConnectionError - - requests_exceptions.Timeout - - anything with a status code >= 500 - - Most retryable errors are covered by status code >= 500. - """ - if exception.status == 429: - # don't retry for too Many Requests - return False - if exception.status == 413: - # don't retry for payload Too Large - return False - return exception.status >= 500 + except ClientResponseError as e: + raise HttpErrorException(f"{e.status}:{e.message}")