Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions airflow-core/tests/unit/api_fastapi/common/test_dagbag.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def factory(*args, **kwargs):
purge_cached_app()
yield

def test_dagbag_used_as_singleton_in_dependency(self, session, dag_maker, test_client):
def test_dagbag_used_as_singleton_in_dependency(self, session, dag_maker, fresh_test_client):
"""
Ensure DagBag is created only once and reused across multiple API requests.

Expand All @@ -76,10 +76,10 @@ def test_dagbag_used_as_singleton_in_dependency(self, session, dag_maker, test_c
BaseOperator(task_id="test_task")
session.commit()

resp1 = test_client.get(f"/api/v2/dags/{dag_id}")
resp1 = fresh_test_client.get(f"/api/v2/dags/{dag_id}")
assert resp1.status_code == 200

resp2 = test_client.get(f"/api/v2/dags/{dag_id}")
resp2 = fresh_test_client.get(f"/api/v2/dags/{dag_id}")
assert resp2.status_code == 200

assert self.dagbag_call_counter["count"] == 1
Expand Down
139 changes: 105 additions & 34 deletions airflow-core/tests/unit/api_fastapi/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@

import pytest
import time_machine
from fastapi import FastAPI
from fastapi.routing import Mount
from fastapi.testclient import TestClient

from airflow.api_fastapi.app import create_app
Expand Down Expand Up @@ -54,8 +56,16 @@ def get_api_path(request):
return API_PATHS.get(subdirectory_name, "/")


@pytest.fixture
def test_client(request):
@pytest.fixture(scope="session")
def _shared_api_app():
"""
Build the FastAPI app once per test session.

``create_app()`` rebuilds two full FastAPI apps (core + execution), registers every route and
builds the OpenAPI schema -- ~0.5s. The default ``test_client`` always uses the same config
(SimpleAuthManager), so the app structure is identical across tests; only per-test DB state and
request data differ. Building it once and reusing it removes that per-test rebuild cost.
"""
with conf_vars(
{
(
Expand All @@ -64,35 +74,85 @@ def test_client(request):
): "airflow.api_fastapi.auth.managers.simple.simple_auth_manager.SimpleAuthManager",
}
):
app = create_app()
auth_manager: SimpleAuthManager = app.state.auth_manager
# set time_very_before to 2014-01-01 00:00:00 and time_very_after to tomorrow
# to make the JWT token always valid for all test cases with time_machine
time_very_before = datetime.datetime(2014, 1, 1, 0, 0, 0)
time_after = datetime.datetime.now() + datetime.timedelta(days=1)
with time_machine.travel(time_very_before, tick=False):
token = auth_manager._get_token_signer(
expiration_time_in_seconds=(time_after - time_very_before).total_seconds()
).generate(
auth_manager.serialize_user(
SimpleAuthManagerUser(username="test", role="admin", teams=["team1"])
),
)
with mock.patch("airflow.models.revoked_token.RevokedToken.is_revoked", return_value=False):
yield TestClient(
app,
headers={"Authorization": f"Bearer {token}"},
base_url=f"{BASE_URL}{get_api_path(request)}",
)
return create_app()


def _mounted_fastapi_apps(app: FastAPI) -> list[FastAPI]:
"""Return ``app`` and every FastAPI app mounted under it, recursively (``/execution``, ``/auth``, ...)."""
apps = [app]
for route in app.routes:
if isinstance(route, Mount) and isinstance(route.app, FastAPI):
apps.extend(_mounted_fastapi_apps(route.app))
return apps


@pytest.fixture
def _isolated_shared_app(_shared_api_app):
"""
Yield the session-shared app with its mutable state snapshotted and restored around each test.

The app is built once per session, so a test that rebinds something on ``app.state`` (the auth
endpoint tests swap ``auth_manager`` for a mock) or installs a ``dependency_overrides`` entry
(extra-links/tasks/logs install a ``dag_bag_from_app`` override) would leak into later tests.
Snapshotting ``state`` and ``dependency_overrides`` on the root app and every mounted sub-app on
entry and restoring them on exit keeps the reset resilient to future mutations without having to
enumerate them.

``app.state.dag_bag`` is the exception: tests mutate the DagBag object *in place* (its cache of
deserialized Dags fills as requests resolve them), which a state snapshot can't undo, so its
cache is cleared explicitly. A leaked warm entry would otherwise let a later test skip a
serialized-Dag DB read and break query-count assertions (e.g. the grid ``ti_summaries`` stream
tests) depending on execution order.
"""
apps = _mounted_fastapi_apps(_shared_api_app)
# ``app.state._state`` is Starlette's backing dict for ``State`` -- the only way to enumerate it.
saved = [(app, dict(app.state._state), dict(app.dependency_overrides)) for app in apps]
_shared_api_app.state.dag_bag.clear_cache()
try:
yield _shared_api_app
finally:
for app, state, overrides in saved:
app.state._state.clear()
app.state._state.update(state)
app.dependency_overrides.clear()
app.dependency_overrides.update(overrides)


def _authed_test_client(app: FastAPI, request):
auth_manager: SimpleAuthManager = app.state.auth_manager
# set time_very_before to 2014-01-01 00:00:00 and time_very_after to tomorrow
# to make the JWT token always valid for all test cases with time_machine
time_very_before = datetime.datetime(2014, 1, 1, 0, 0, 0)
time_after = datetime.datetime.now() + datetime.timedelta(days=1)
with time_machine.travel(time_very_before, tick=False):
token = auth_manager._get_token_signer(
expiration_time_in_seconds=(time_after - time_very_before).total_seconds()
).generate(
auth_manager.serialize_user(
SimpleAuthManagerUser(username="test", role="admin", teams=["team1"])
),
)
with mock.patch("airflow.models.revoked_token.RevokedToken.is_revoked", return_value=False):
yield TestClient(
app,
headers={"Authorization": f"Bearer {token}"},
base_url=f"{BASE_URL}{get_api_path(request)}",
)


@pytest.fixture
def unauthenticated_test_client(request):
return TestClient(create_app(), base_url=f"{BASE_URL}{get_api_path(request)}")
def test_client(request, _isolated_shared_app):
yield from _authed_test_client(_isolated_shared_app, request)


@pytest.fixture
def unauthorized_test_client(request):
def fresh_test_client(request):
"""
Like ``test_client`` but backed by a freshly built app instead of the session-shared one.

For the rare tests that patch app construction (e.g. counting ``DBDagBag`` instantiation) and
so need the app built *after* their patch is applied.
"""
with conf_vars(
{
(
Expand All @@ -102,16 +162,27 @@ def unauthorized_test_client(request):
}
):
app = create_app()
auth_manager: SimpleAuthManager = app.state.auth_manager
token = auth_manager._get_token_signer().generate(
auth_manager.serialize_user(SimpleAuthManagerUser(username="dummy", role=None))
yield from _authed_test_client(app, request)


@pytest.fixture
def unauthenticated_test_client(request, _isolated_shared_app):
return TestClient(_isolated_shared_app, base_url=f"{BASE_URL}{get_api_path(request)}")


@pytest.fixture
def unauthorized_test_client(request, _isolated_shared_app):
app = _isolated_shared_app
auth_manager: SimpleAuthManager = app.state.auth_manager
token = auth_manager._get_token_signer().generate(
auth_manager.serialize_user(SimpleAuthManagerUser(username="dummy", role=None))
)
with mock.patch("airflow.models.revoked_token.RevokedToken.is_revoked", return_value=False):
yield TestClient(
app,
headers={"Authorization": f"Bearer {token}"},
base_url=f"{BASE_URL}{get_api_path(request)}",
)
with mock.patch("airflow.models.revoked_token.RevokedToken.is_revoked", return_value=False):
yield TestClient(
app,
headers={"Authorization": f"Bearer {token}"},
base_url=f"{BASE_URL}{get_api_path(request)}",
)


@pytest.fixture
Expand Down
Loading