From a0e08470e135622a95d3e6d6cd52d19eab34dc2f Mon Sep 17 00:00:00 2001 From: Hussein Awala Date: Mon, 4 Sep 2023 23:20:06 +0200 Subject: [PATCH] Async DB connection for Airflow --- .github/workflows/ci.yml | 6 + airflow/config_templates/config.yml | 11 ++ airflow/hooks/base.py | 14 +++ airflow/models/connection.py | 52 +++++++- airflow/models/variable.py | 103 ++++++++++++--- airflow/providers/mysql/provider.yaml | 1 + airflow/providers/postgres/provider.yaml | 1 + airflow/providers/sqlite/provider.yaml | 1 + airflow/secrets/base_secrets.py | 39 ++++++ airflow/secrets/metastore.py | 48 ++++++- airflow/settings.py | 41 ++++++ airflow/utils/session.py | 44 ++++++- .../airflow_breeze/commands/common_options.py | 7 ++ .../commands/testing_commands.py | 6 + .../commands/testing_commands_config.py | 2 + .../src/airflow_breeze/global_constants.py | 1 + .../src/airflow_breeze/params/shell_params.py | 8 ++ generated/provider_dependencies.json | 3 + images/breeze/output_testing_db-tests.svg | 118 +++++++++--------- images/breeze/output_testing_db-tests.txt | 2 +- images/breeze/output_testing_tests.svg | 110 ++++++++-------- images/breeze/output_testing_tests.txt | 2 +- scripts/ci/docker-compose/backend-mysql.yml | 1 + scripts/ci/docker-compose/backend-none.yml | 1 + .../ci/docker-compose/backend-postgres.yml | 1 + .../backend-sqlite-no-volume.yml | 1 + scripts/ci/docker-compose/backend-sqlite.yml | 1 + setup.cfg | 2 +- tests/always/test_secrets_backends.py | 24 +++- tests/conftest.py | 44 ++++++- tests/core/test_configuration.py | 1 + tests/models/test_variable.py | 6 + 32 files changed, 560 insertions(+), 142 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index a508c06f4fc44..a559089e0c4a1 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -986,6 +986,7 @@ jobs: python-version: "${{fromJson(needs.build-info.outputs.python-versions)}}" postgres-version: "${{fromJson(needs.build-info.outputs.postgres-versions)}}" exclude: 
"${{fromJson(needs.build-info.outputs.postgres-exclude)}}" + db-tests-mode: [sync, async] fail-fast: false env: RUNS_ON: "${{needs.build-info.outputs.runs-on}}" @@ -1018,6 +1019,7 @@ jobs: Tests: ${{matrix.python-version}}:${{needs.build-info.outputs.parallel-test-types-list-as-string}} run: > breeze testing db-tests + --db-tests-mode="${{matrix.db-tests-mode}}" --parallel-test-types "${{needs.build-info.outputs.parallel-test-types-list-as-string}}" - name: "Tests ARM Pytest collection: ${{matrix.python-version}}" run: breeze testing db-tests --collect-only --remove-arm-packages @@ -1193,6 +1195,7 @@ jobs: python-version: "${{fromJson(needs.build-info.outputs.python-versions)}}" mysql-version: "${{fromJson(needs.build-info.outputs.mysql-versions)}}" exclude: "${{fromJson(needs.build-info.outputs.mysql-exclude)}}" + db-tests-mode: [sync, async] fail-fast: false env: RUNS_ON: "${{needs.build-info.outputs.runs-on}}" @@ -1225,6 +1228,7 @@ jobs: Tests: ${{matrix.python-version}}:${{needs.build-info.outputs.parallel-test-types-list-as-string}} run: > breeze testing db-tests + --db-tests-mode="${{matrix.db-tests-mode}}" --parallel-test-types "${{needs.build-info.outputs.parallel-test-types-list-as-string}}" - name: > Post Tests success: MySQL" @@ -1309,6 +1313,7 @@ jobs: matrix: python-version: ${{ fromJson(needs.build-info.outputs.python-versions) }} exclude: ${{ fromJson(needs.build-info.outputs.sqlite-exclude) }} + db-tests-mode: [sync, async] fail-fast: false if: needs.build-info.outputs.run-tests == 'true' env: @@ -1340,6 +1345,7 @@ jobs: Tests: ${{matrix.python-version}}:${{needs.build-info.outputs.parallel-test-types-list-as-string}} run: > breeze testing db-tests + --db-tests-mode="${{matrix.db-tests-mode}}" --parallel-test-types "${{needs.build-info.outputs.parallel-test-types-list-as-string}}" - name: > Post Tests success: Sqlite" diff --git a/airflow/config_templates/config.yml b/airflow/config_templates/config.yml index 059467e4a0c03..60ab049f3fd79 100644 --- 
a/airflow/config_templates/config.yml +++ b/airflow/config_templates/config.yml @@ -494,6 +494,17 @@ database: sensitive: true example: ~ default: "sqlite:///{AIRFLOW_HOME}/airflow.db" + async_sql_alchemy_conn: + description: | + The SqlAlchemy connection string for async access to the metadata database (requires an async driver). + SqlAlchemy supports many different database engines. + More information here: + http://airflow.apache.org/docs/apache-airflow/stable/howto/set-up-database.html#database-uri + version_added: 2.9.0 + type: string + sensitive: true + example: ~ + default: "sqlite+aiosqlite:///{AIRFLOW_HOME}/airflow.db" sql_alchemy_engine_args: description: | Extra engine specific keyword args passed to SQLAlchemy's create_engine, as a JSON-encoded value diff --git a/airflow/hooks/base.py b/airflow/hooks/base.py index 6ec0a8938ed20..c3b2fa29f44cb 100644 --- a/airflow/hooks/base.py +++ b/airflow/hooks/base.py @@ -83,6 +83,20 @@ def get_connection(cls, conn_id: str) -> Connection: log.info("Using connection ID '%s' for task execution.", conn.conn_id) return conn + @classmethod + async def async_get_connection(cls, conn_id: str) -> Connection: + """ + Get connection, given connection id. + + :param conn_id: connection id + :return: connection + """ + from airflow.models.connection import Connection + + conn = await Connection.async_get_connection_from_secrets(conn_id) + log.info("Using connection ID '%s' for task execution.", conn.conn_id) + return conn + @classmethod def get_hook(cls, conn_id: str) -> BaseHook: """ diff --git a/airflow/models/connection.py b/airflow/models/connection.py index 521a5d880e4d4..dc05ce6d226ec 100644 --- a/airflow/models/connection.py +++ b/airflow/models/connection.py @@ -17,6 +17,7 @@ # under the License. 
from __future__ import annotations +import contextlib import json import logging import warnings @@ -446,21 +447,32 @@ def extra_dejson(self) -> dict: return obj - @classmethod - def get_connection_from_secrets(cls, conn_id: str) -> Connection: + @staticmethod + def get_connection_from_cache(conn_id: str) -> Connection | None: """ - Get connection by conn_id. + Get connection by conn_id from cache. :param conn_id: connection id :return: connection """ # check cache first # enabled only if SecretCache.init() has been called first - try: + with contextlib.suppress(SecretCache.NotPresentException): uri = SecretCache.get_connection_uri(conn_id) return Connection(conn_id=conn_id, uri=uri) - except SecretCache.NotPresentException: - pass # continue business + return None + + @classmethod + def get_connection_from_secrets(cls, conn_id: str) -> Connection: + """ + Get connection by conn_id. + + :param conn_id: connection id + :return: connection + """ + cached_conn = cls.get_connection_from_cache(conn_id) + if cached_conn: + return cached_conn # iterate over backends if not in cache (or expired) for secrets_backend in ensure_secrets_loaded(): @@ -478,6 +490,34 @@ def get_connection_from_secrets(cls, conn_id: str) -> Connection: raise AirflowNotFoundException(f"The conn_id `{conn_id}` isn't defined") + @classmethod + async def async_get_connection_from_secrets(cls, conn_id: str) -> Connection: + """ + Get connection by conn_id. + + :param conn_id: connection id + :return: connection + """ + cached_conn = cls.get_connection_from_cache(conn_id) + if cached_conn: + return cached_conn + + # iterate over backends if not in cache (or expired) + for secrets_backend in ensure_secrets_loaded(): + try: + conn = await secrets_backend.async_get_connection(conn_id=conn_id) + if conn: + SecretCache.save_connection_uri(conn_id, conn.get_uri()) + return conn + except Exception: + log.exception( + "Unable to retrieve connection from secrets backend (%s). 
" + "Checking subsequent secrets backend.", + type(secrets_backend).__name__, + ) + + raise AirflowNotFoundException(f"The conn_id `{conn_id}` isn't defined") + def to_dict(self, *, prune_empty: bool = False, validate: bool = True) -> dict[str, Any]: """ Convert Connection to json-serializable dictionary. diff --git a/airflow/models/variable.py b/airflow/models/variable.py index bfd70ae4c2b42..73c19278b6c14 100644 --- a/airflow/models/variable.py +++ b/airflow/models/variable.py @@ -17,6 +17,7 @@ # under the License. from __future__ import annotations +import contextlib import json import logging from typing import TYPE_CHECKING, Any @@ -123,19 +124,13 @@ def setdefault(cls, key, default, description=None, deserialize_json=False): return obj @classmethod - def get( + def _process_variable_value( cls, key: str, - default_var: Any = __NO_DEFAULT_SENTINEL, - deserialize_json: bool = False, - ) -> Any: - """Get a value for an Airflow Variable Key. - - :param key: Variable Key - :param default_var: Default value of the Variable if the Variable doesn't exist - :param deserialize_json: Deserialize the value to a Python dict - """ - var_val = Variable.get_variable_from_secrets(key=key) + default_var: Any, + deserialize_json: bool, + var_val: str | None, + ): if var_val is None: if default_var is not cls.__NO_DEFAULT_SENTINEL: return default_var @@ -150,6 +145,40 @@ def get( mask_secret(var_val, key) return var_val + @classmethod + def get( + cls, + key: str, + default_var: Any = __NO_DEFAULT_SENTINEL, + deserialize_json: bool = False, + ) -> Any: + """Get a value for an Airflow Variable Key. 
+ + :param key: Variable Key + :param default_var: Default value of the Variable if the Variable doesn't exist + :param deserialize_json: Deserialize the value to a Python dict + """ + var_val = Variable.get_variable_from_secrets(key=key) + return cls._process_variable_value(key, default_var, deserialize_json, var_val) + + @classmethod + async def async_get( + cls, + key: str, + default_var: Any = __NO_DEFAULT_SENTINEL, + deserialize_json: bool = False, + ) -> Any: + """Get a value for an Airflow Variable Key asynchronously. + + :param key: Variable Key + :param default_var: Default value of the Variable if the Variable doesn't exist + :param deserialize_json: Deserialize the value to a Python dict + """ + var_val = await Variable.async_get_variable_from_secrets(key=key) + return cls._process_variable_value(key, default_var, deserialize_json, var_val) + + # TODO: implement async version for the other methods + @staticmethod @provide_session @internal_api_call @@ -258,21 +287,31 @@ def check_for_write_conflict(key: str) -> None: return None @staticmethod - def get_variable_from_secrets(key: str) -> str | None: + def get_variable_from_cache(key: str) -> str | None: """ - Get Airflow Variable by iterating over all Secret Backends. + Get variable by key from cache. :param key: Variable Key :return: Variable Value """ # check cache first # enabled only if SecretCache.init() has been called first - try: + with contextlib.suppress(SecretCache.NotPresentException): return SecretCache.get_variable(key) - except SecretCache.NotPresentException: - pass # continue business + return None + + @staticmethod + def get_variable_from_secrets(key: str) -> str | None: + """ + Get Airflow Variable by iterating over all Secret Backends. 
+ + :param key: Variable Key + :return: Variable Value + """ + var_val = Variable.get_variable_from_cache(key) + if var_val is not None: + return var_val - var_val = None # iterate over backends if not in cache (or expired) for secrets_backend in ensure_secrets_loaded(): try: @@ -288,3 +327,33 @@ def get_variable_from_secrets(key: str) -> str | None: SecretCache.save_variable(key, var_val) # we save None as well return var_val + + @staticmethod + async def async_get_variable_from_secrets(key: str) -> str | None: + """ + Get Airflow Variable by iterating over all Secret Backends asynchronously. + + :param key: Variable Key + :return: Variable Value + """ + # check cache first + # enabled only if SecretCache.init() has been called first + var_val = Variable.get_variable_from_cache(key) + if var_val is not None: + return var_val + + # iterate over backends if not in cache (or expired) + for secrets_backend in ensure_secrets_loaded(): + try: + var_val = await secrets_backend.async_get_variable(key=key) + if var_val is not None: + break + except Exception: + log.exception( + "Unable to retrieve variable from secrets backend (%s). 
" + "Checking subsequent secrets backend.", + type(secrets_backend).__name__, + ) + + SecretCache.save_variable(key, var_val) # we save None as well + return var_val diff --git a/airflow/providers/mysql/provider.yaml b/airflow/providers/mysql/provider.yaml index 9712ec8676ae4..83015876acd00 100644 --- a/airflow/providers/mysql/provider.yaml +++ b/airflow/providers/mysql/provider.yaml @@ -59,6 +59,7 @@ dependencies: - apache-airflow-providers-common-sql>=1.3.1 - mysqlclient>=1.3.6 - mysql-connector-python>=8.0.11 + - aiomysql>=0.2.0 integrations: - integration-name: MySQL diff --git a/airflow/providers/postgres/provider.yaml b/airflow/providers/postgres/provider.yaml index c6e4338978fb7..617377ab5cc7b 100644 --- a/airflow/providers/postgres/provider.yaml +++ b/airflow/providers/postgres/provider.yaml @@ -59,6 +59,7 @@ dependencies: - apache-airflow>=2.6.0 - apache-airflow-providers-common-sql>=1.3.1 - psycopg2-binary>=2.8.0 + - asyncpg integrations: - integration-name: PostgreSQL diff --git a/airflow/providers/sqlite/provider.yaml b/airflow/providers/sqlite/provider.yaml index 753b433552bf3..65f62e5764b8a 100644 --- a/airflow/providers/sqlite/provider.yaml +++ b/airflow/providers/sqlite/provider.yaml @@ -51,6 +51,7 @@ versions: dependencies: - apache-airflow>=2.6.0 - apache-airflow-providers-common-sql>=1.3.1 + - aiosqlite>=0.19.0 integrations: - integration-name: SQLite diff --git a/airflow/secrets/base_secrets.py b/airflow/secrets/base_secrets.py index 3346d880f2eb5..ae88626d5a27d 100644 --- a/airflow/secrets/base_secrets.py +++ b/airflow/secrets/base_secrets.py @@ -16,6 +16,7 @@ # under the License. 
from __future__ import annotations +import contextlib import warnings from abc import ABC from typing import TYPE_CHECKING @@ -51,6 +52,18 @@ def get_conn_value(self, conn_id: str) -> str | None: """ raise NotImplementedError + async def async_get_conn_value(self, conn_id: str) -> str | None: + """ + Retrieve from Secrets Backend a string value representing the Connection object asynchronously. + + If the client your secrets backend uses already returns a python dict, you should override + ``get_connection`` instead. + + :param conn_id: connection id + """ + # if the backend doesn't implement async_get_conn_value, we fallback to get_conn_value + return self.get_conn_value(conn_id=conn_id) + def deserialize_connection(self, conn_id: str, value: str) -> Connection: """ Given a serialized representation of the airflow Connection, return an instance. @@ -117,6 +130,22 @@ def get_connection(self, conn_id: str) -> Connection | None: else: return None + async def async_get_connection(self, conn_id: str) -> Connection | None: + """ + Return connection object with a given ``conn_id`` asynchronously. + + Tries ``get_conn_value`` first and if not implemented, tries ``get_conn_uri`` + + :param conn_id: connection id + """ + value = None + with contextlib.suppress(Exception): + value = await self.async_get_conn_value(conn_id=conn_id) + if value: + return self.deserialize_connection(conn_id=conn_id, value=value) + else: + return None + def get_connections(self, conn_id: str) -> list[Connection]: """ Return connection object with a given ``conn_id``. @@ -143,6 +172,16 @@ def get_variable(self, key: str) -> str | None: """ raise NotImplementedError() + async def async_get_variable(self, key: str) -> str | None: + """ + Return value for Airflow Variable asynchronously. 
+ + :param key: Variable Key + :return: Variable Value + """ + # if the backend doesn't implement async_get_variable, we fallback to get_variable + return self.get_variable(key=key) + def get_config(self, key: str) -> str | None: """ Return value for Airflow Config Key. diff --git a/airflow/secrets/metastore.py b/airflow/secrets/metastore.py index 481b7aaf3a8fb..419a7704b344a 100644 --- a/airflow/secrets/metastore.py +++ b/airflow/secrets/metastore.py @@ -26,9 +26,15 @@ from airflow.api_internal.internal_api_call import internal_api_call from airflow.exceptions import RemovedInAirflow3Warning from airflow.secrets import BaseSecretsBackend -from airflow.utils.session import NEW_SESSION, provide_session +from airflow.utils.session import ( + NEW_ASYNC_SESSION, + NEW_SESSION, + provide_async_session, + provide_session, +) if TYPE_CHECKING: + from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import Session from airflow.models.connection import Connection @@ -41,6 +47,12 @@ class MetastoreBackend(BaseSecretsBackend): def get_connection(self, conn_id: str, session: Session = NEW_SESSION) -> Connection | None: return MetastoreBackend._fetch_connection(conn_id, session=session) + @provide_async_session + async def async_get_connection( + self, conn_id: str, session: AsyncSession = NEW_ASYNC_SESSION + ) -> Connection | None: + return await MetastoreBackend._async_fetch_connection(conn_id, session=session) + @provide_session def get_connections(self, conn_id: str, session: Session = NEW_SESSION) -> list[Connection]: warnings.warn( @@ -64,6 +76,16 @@ def get_variable(self, key: str, session: Session = NEW_SESSION) -> str | None: """ return MetastoreBackend._fetch_variable(key=key, session=session) + @provide_async_session + async def async_get_variable(self, key: str, session: AsyncSession = NEW_ASYNC_SESSION) -> str | None: + """ + Get Airflow Variable from Metadata DB asynchronously. 
+ + :param key: Variable Key + :return: Variable Value + """ + return await MetastoreBackend._async_fetch_variable(key=key, session=session) + @staticmethod @internal_api_call @provide_session @@ -74,6 +96,18 @@ def _fetch_connection(conn_id: str, session: Session = NEW_SESSION) -> Connectio session.expunge_all() return conn + @staticmethod + @internal_api_call + @provide_async_session + async def _async_fetch_connection( + conn_id: str, session: AsyncSession = NEW_ASYNC_SESSION + ) -> Connection | None: + from airflow.models.connection import Connection + + conn = await session.scalar(select(Connection).where(Connection.conn_id == conn_id).limit(1)) + session.expunge_all() + return conn + @staticmethod @internal_api_call @provide_session @@ -85,3 +119,15 @@ def _fetch_variable(key: str, session: Session = NEW_SESSION) -> str | None: if var_value: return var_value.val return None + + @staticmethod + @internal_api_call + @provide_async_session + async def _async_fetch_variable(key: str, session: AsyncSession = NEW_ASYNC_SESSION) -> str | None: + from airflow.models.variable import Variable + + var_value = await session.scalar(select(Variable).where(Variable.key == key).limit(1)) + session.expunge_all() + if var_value: + return var_value.val + return None diff --git a/airflow/settings.py b/airflow/settings.py index e86a8d557348b..9e9fadf01f94c 100644 --- a/airflow/settings.py +++ b/airflow/settings.py @@ -24,12 +24,17 @@ import os import sys import warnings +from asyncio import current_task from typing import TYPE_CHECKING, Any, Callable import pendulum import pluggy import sqlalchemy from sqlalchemy import create_engine, exc, text + +# from sqlalchemy.ext.asyncio.session import async_sessionmaker +from sqlalchemy.ext.asyncio import AsyncSession as SAAsyncSession, create_async_engine +from sqlalchemy.ext.asyncio.scoping import async_scoped_session from sqlalchemy.orm import scoped_session, sessionmaker from sqlalchemy.pool import NullPool @@ -43,6 +48,7 @@ if 
TYPE_CHECKING: from sqlalchemy.engine import Engine + from sqlalchemy.ext.asyncio import AsyncEngine from sqlalchemy.orm import Session as SASession from airflow.www.utils import UIAlert @@ -80,6 +86,7 @@ SIMPLE_LOG_FORMAT = conf.get("logging", "simple_log_format") SQL_ALCHEMY_CONN: str | None = None +ASYNC_SQL_ALCHEMY_CONN: str | None = None PLUGINS_FOLDER: str | None = None LOGGING_CLASS_PATH: str | None = None DONOT_MODIFY_HANDLERS: bool | None = None @@ -88,6 +95,9 @@ engine: Engine Session: Callable[..., SASession] +async_engine: AsyncEngine +AsyncSession: Callable[..., SAAsyncSession] + # The JSON library to use for DAG Serialization and De-Serialization json = json @@ -189,10 +199,12 @@ def load_policy_plugins(pm: pluggy.PluginManager): def configure_vars(): """Configure Global Variables from airflow.cfg.""" global SQL_ALCHEMY_CONN + global ASYNC_SQL_ALCHEMY_CONN global DAGS_FOLDER global PLUGINS_FOLDER global DONOT_MODIFY_HANDLERS SQL_ALCHEMY_CONN = conf.get("database", "SQL_ALCHEMY_CONN") + ASYNC_SQL_ALCHEMY_CONN = conf.get("database", "ASYNC_SQL_ALCHEMY_CONN") DAGS_FOLDER = os.path.expanduser(conf.get("core", "DAGS_FOLDER")) PLUGINS_FOLDER = conf.get("core", "plugins_folder", fallback=os.path.join(AIRFLOW_HOME, "plugins")) @@ -226,6 +238,9 @@ def configure_orm(disable_connection_pool=False, pool_class=None): global Session global engine + global Session + global async_engine + global AsyncSession if os.environ.get("_AIRFLOW_SKIP_DB_TESTS") == "true": # Skip DB initialization in unit tests, if DB tests are skipped Session = SkipDBTestsSession @@ -240,6 +255,7 @@ def configure_orm(disable_connection_pool=False, pool_class=None): connect_args = {} engine = create_engine(SQL_ALCHEMY_CONN, connect_args=connect_args, **engine_args, future=True) + async_engine = create_async_engine(ASYNC_SQL_ALCHEMY_CONN, connect_args=connect_args, future=True) mask_secret(engine.url.password) @@ -253,6 +269,17 @@ def configure_orm(disable_connection_pool=False, 
pool_class=None): expire_on_commit=False, ) ) + # async_sessionmaker is available only from SQLAlchemy 2.0+, so we need to use the sync version + AsyncSession = async_scoped_session( + sessionmaker( + class_=SAAsyncSession, + autocommit=False, + autoflush=False, + bind=async_engine, + expire_on_commit=False, + ), + scopefunc=current_task, + ) if engine.dialect.name == "mssql": session = Session() try: @@ -380,9 +407,23 @@ def dispose_orm(): engine = None +async def dispose_async_orm(): + global async_engine + global AsyncSession + if AsyncSession is not None: # type: ignore[truthy-function] + await AsyncSession.remove() + AsyncSession = None + if async_engine: + await async_engine.dispose() + async_engine = None + + def reconfigure_orm(disable_connection_pool=False, pool_class=None): """Properly close database connections and re-configure ORM.""" + import asyncio + dispose_orm() + asyncio.run(dispose_async_orm()) configure_orm(disable_connection_pool=disable_connection_pool, pool_class=pool_class) diff --git a/airflow/utils/session.py b/airflow/utils/session.py index b3b610d19952f..8738ba9b1aa9e 100644 --- a/airflow/utils/session.py +++ b/airflow/utils/session.py @@ -19,8 +19,9 @@ import contextlib from functools import wraps from inspect import signature -from typing import Callable, Generator, TypeVar, cast +from typing import Any, AsyncIterator, Callable, Coroutine, Generator, TypeVar, cast +from sqlalchemy.ext.asyncio import AsyncSession as ASAsyncSession from sqlalchemy.orm import Session as SASession from airflow import settings @@ -44,6 +45,23 @@ def create_session() -> Generator[SASession, None, None]: session.close() +@contextlib.asynccontextmanager +async def create_async_session() -> AsyncIterator[ASAsyncSession]: + """Contextmanager that will create and teardown an async session.""" + AsyncSession = getattr(settings, "AsyncSession", None) + if AsyncSession is None: + raise RuntimeError("AsyncSession must be set before!") + try: + async with 
AsyncSession() as async_session: + yield async_session + await async_session.commit() + except Exception: + await async_session.rollback() + raise + finally: + await async_session.close() + + PS = ParamSpec("PS") RT = TypeVar("RT") @@ -81,8 +99,32 @@ def wrapper(*args, **kwargs) -> RT: return wrapper +def provide_async_session( + func: Callable[PS, Coroutine[Any, Any, RT]], +) -> Callable[PS, Coroutine[Any, Any, RT]]: + """ + Provide an async session if it isn't provided. + + If you want to reuse a session or run the function as part of a + database transaction, you pass it to the function, if not this wrapper + will create one and close it for you. + """ + session_args_idx = find_session_idx(func) + + @wraps(func) + async def wrapper(*args, **kwargs) -> RT: + if "session" in kwargs or session_args_idx < len(args): + return await func(*args, **kwargs) + else: + async with create_async_session() as session: + return await func(*args, session=session, **kwargs) + + return wrapper + + # A fake session to use in functions decorated by provide_session. This allows # the 'session' argument to be of type Session instead of Session | None, # making it easier to type hint the function body without dealing with the None # case that can never happen at runtime. 
NEW_SESSION: SASession = cast(SASession, None) +NEW_ASYNC_SESSION: ASAsyncSession = cast(ASAsyncSession, None) diff --git a/dev/breeze/src/airflow_breeze/commands/common_options.py b/dev/breeze/src/airflow_breeze/commands/common_options.py index 12d0ee77b8d7f..5fddb35e7b777 100644 --- a/dev/breeze/src/airflow_breeze/commands/common_options.py +++ b/dev/breeze/src/airflow_breeze/commands/common_options.py @@ -307,6 +307,13 @@ def _set_default_from_parent(ctx: click.core.Context, option: click.core.Option, is_flag=True, envvar="run_db_tests_only", ) +option_db_tests_mode = click.option( + "--db-tests-mode", + help="Mode of running tests that require a database", + type=click.Choice(["sync", "async", "all"]), + default="sync", + show_default=True, +) option_run_in_parallel = click.option( "--run-in-parallel", help="Run the operation in parallel on all or selected subset of parameters.", diff --git a/dev/breeze/src/airflow_breeze/commands/testing_commands.py b/dev/breeze/src/airflow_breeze/commands/testing_commands.py index c4aff8797e796..0f49e0c5b3cc4 100644 --- a/dev/breeze/src/airflow_breeze/commands/testing_commands.py +++ b/dev/breeze/src/airflow_breeze/commands/testing_commands.py @@ -20,6 +20,7 @@ import sys from datetime import datetime from time import sleep +from typing import Literal import click from click import IntRange @@ -28,6 +29,7 @@ from airflow_breeze.commands.common_options import ( option_backend, option_db_reset, + option_db_tests_mode, option_debug_resources, option_downgrade_sqlalchemy, option_dry_run, @@ -460,6 +462,7 @@ def _verify_parallelism_parameters( @option_test_type @option_test_timeout @option_run_db_tests_only +@option_db_tests_mode @option_skip_db_tests @option_db_reset @option_run_in_parallel @@ -521,6 +524,7 @@ def command_for_tests(**kwargs): @option_verbose @option_dry_run @option_github_repository +@option_db_tests_mode def command_for_db_tests(**kwargs): _run_test_command( integration=(), @@ -603,6 +607,7 @@ def 
_run_test_command( python: str, remove_arm_packages: bool, run_db_tests_only: bool, + db_tests_mode: Literal["sync", "async", "all"] = "sync", run_in_parallel: bool, skip_cleanup: bool, skip_db_tests: bool, @@ -647,6 +652,7 @@ def _run_test_command( python=python, remove_arm_packages=remove_arm_packages, run_db_tests_only=run_db_tests_only, + db_tests_mode=db_tests_mode, skip_db_tests=skip_db_tests, skip_provider_tests=skip_provider_tests, test_type=test_type, diff --git a/dev/breeze/src/airflow_breeze/commands/testing_commands_config.py b/dev/breeze/src/airflow_breeze/commands/testing_commands_config.py index 404e0cabe0eb9..c4ea9693f2b8e 100644 --- a/dev/breeze/src/airflow_breeze/commands/testing_commands_config.py +++ b/dev/breeze/src/airflow_breeze/commands/testing_commands_config.py @@ -45,6 +45,7 @@ "options": [ "--run-db-tests-only", "--skip-db-tests", + "--db-tests-mode", ], }, { @@ -137,6 +138,7 @@ "options": [ "--parallel-test-types", "--excluded-parallel-test-types", + "--db-tests-mode", ], }, { diff --git a/dev/breeze/src/airflow_breeze/global_constants.py b/dev/breeze/src/airflow_breeze/global_constants.py index 818b3754de1db..42d9622d8dfa2 100644 --- a/dev/breeze/src/airflow_breeze/global_constants.py +++ b/dev/breeze/src/airflow_breeze/global_constants.py @@ -219,6 +219,7 @@ def get_default_platform_machine() -> str: CELERY_BROKER_URLS_MAP = {"rabbitmq": "amqp://guest:guest@rabbitmq:5672", "redis": "redis://redis:6379/0"} SQLITE_URL = "sqlite:////root/airflow/sqlite/airflow.db" +ASYNC_SQLITE_URL = "sqlite+aiosqlite:////root/airflow/sqlite/airflow.db" PYTHONDONTWRITEBYTECODE = True PRODUCTION_IMAGE = False diff --git a/dev/breeze/src/airflow_breeze/params/shell_params.py b/dev/breeze/src/airflow_breeze/params/shell_params.py index e90d3480464e2..4277eff3b3ac7 100644 --- a/dev/breeze/src/airflow_breeze/params/shell_params.py +++ b/dev/breeze/src/airflow_breeze/params/shell_params.py @@ -21,6 +21,7 @@ from dataclasses import dataclass, field from 
functools import cached_property from pathlib import Path +from typing import Literal from airflow_breeze.branch_defaults import AIRFLOW_BRANCH, DEFAULT_AIRFLOW_CONSTRAINTS_BRANCH from airflow_breeze.global_constants import ( @@ -189,6 +190,7 @@ class ShellParams: remove_arm_packages: bool = False restart: bool = False run_db_tests_only: bool = False + db_tests_mode: Literal["sync", "async", "all"] = "sync" run_system_tests: bool = os.environ.get("RUN_SYSTEM_TESTS", "false") == "true" run_tests: bool = False skip_db_tests: bool = False @@ -287,6 +289,10 @@ def backend_version(self) -> str: def sqlite_url(self) -> str: return "sqlite:////root/airflow/sqlite/airflow.db" + @cached_property + def async_sqlite_url(self) -> str: + return "sqlite+aiosqlite:////root/airflow/sqlite/airflow.db" + def print_badge_info(self): if get_verbose(): get_console().print(f"[info]Use {self.image_type} image[/]") @@ -557,6 +563,7 @@ def env_variables_for_docker_commands(self) -> dict[str, str]: _set_var(_env, "SKIP_ENVIRONMENT_INITIALIZATION", self.skip_environment_initialization) _set_var(_env, "SKIP_SSH_SETUP", self.skip_ssh_setup) _set_var(_env, "SQLITE_URL", self.sqlite_url) + _set_var(_env, "ASYNC_SQLITE_URL", self.async_sqlite_url) _set_var(_env, "SSH_PORT", None, SSH_PORT) _set_var(_env, "STANDALONE_DAG_PROCESSOR", self.standalone_dag_processor) _set_var(_env, "START_AIRFLOW", self.start_airflow) @@ -571,6 +578,7 @@ def env_variables_for_docker_commands(self) -> dict[str, str]: _set_var(_env, "VERSION_SUFFIX_FOR_PYPI", self.version_suffix_for_pypi) _set_var(_env, "WEBSERVER_HOST_PORT", None, WEBSERVER_HOST_PORT) _set_var(_env, "_AIRFLOW_RUN_DB_TESTS_ONLY", self.run_db_tests_only) + _set_var(_env, "_AIRFLOW_DB_TESTS_MODE", self.db_tests_mode) _set_var(_env, "_AIRFLOW_SKIP_DB_TESTS", self.skip_db_tests) self._generate_env_for_docker_compose_file_if_needed(_env) diff --git a/generated/provider_dependencies.json b/generated/provider_dependencies.json index 
c0131f8ca8d7c..62afced8d3b2f 100644 --- a/generated/provider_dependencies.json +++ b/generated/provider_dependencies.json @@ -680,6 +680,7 @@ }, "mysql": { "deps": [ + "aiomysql>=0.2.0", "apache-airflow-providers-common-sql>=1.3.1", "apache-airflow>=2.6.0", "mysql-connector-python>=8.0.11", @@ -824,6 +825,7 @@ "deps": [ "apache-airflow-providers-common-sql>=1.3.1", "apache-airflow>=2.6.0", + "asyncpg", "psycopg2-binary>=2.8.0" ], "cross-providers-deps": [ @@ -954,6 +956,7 @@ }, "sqlite": { "deps": [ + "aiosqlite>=0.19.0", "apache-airflow-providers-common-sql>=1.3.1", "apache-airflow>=2.6.0" ], diff --git a/images/breeze/output_testing_db-tests.svg b/images/breeze/output_testing_db-tests.svg index a4c191f60175c..6ab91535ed0b0 100644 --- a/images/breeze/output_testing_db-tests.svg +++ b/images/breeze/output_testing_db-tests.svg @@ -1,4 +1,4 @@ - +