diff --git a/airflow/api_fastapi/execution_api/routes/assets.py b/airflow/api_fastapi/execution_api/routes/assets.py index 213c599befb3e..573ecce4cf9f3 100644 --- a/airflow/api_fastapi/execution_api/routes/assets.py +++ b/airflow/api_fastapi/execution_api/routes/assets.py @@ -19,11 +19,12 @@ from typing import Annotated -from fastapi import HTTPException, Query, status +from fastapi import Depends, HTTPException, Query, status from sqlalchemy import select from airflow.api_fastapi.common.db.common import SessionDep from airflow.api_fastapi.common.router import AirflowRouter +from airflow.api_fastapi.core_api.security import requires_access_asset from airflow.api_fastapi.execution_api.datamodels.asset import AssetResponse from airflow.models.asset import AssetModel @@ -33,6 +34,7 @@ status.HTTP_404_NOT_FOUND: {"description": "Asset not found"}, status.HTTP_401_UNAUTHORIZED: {"description": "Unauthorized"}, }, + dependencies=[Depends(requires_access_asset("GET"))], ) diff --git a/tests/api_fastapi/execution_api/routes/test_assets.py b/tests/api_fastapi/execution_api/routes/test_assets.py index 2cf34f8dd7bc7..2b0975e896d0e 100644 --- a/tests/api_fastapi/execution_api/routes/test_assets.py +++ b/tests/api_fastapi/execution_api/routes/test_assets.py @@ -28,7 +28,7 @@ class TestGetAssetByName: - def test_get_asset_by_name(self, client, session): + def test_get_asset_by_name(self, test_client, session): asset = AssetModel( id=1, name="test_get_asset_by_name", @@ -44,7 +44,7 @@ def test_get_asset_by_name(self, client, session): session.add_all([asset, asset_active]) session.commit() - response = client.get("/execution/assets/by-name", params={"name": "test_get_asset_by_name"}) + response = test_client.get("/execution/assets/by-name", params={"name": "test_get_asset_by_name"}) assert response.status_code == 200 assert response.json() == { @@ -58,8 +58,8 @@ def test_get_asset_by_name(self, client, session): session.delete(asset_active) session.commit() - def test_asset_name_not_found(self, client): - response = client.get("/execution/assets/by-name", params={"name": "non_existent"}) + def test_asset_name_not_found(self, test_client): + response = test_client.get("/execution/assets/by-name", params={"name": "non_existent"}) assert response.status_code == 404 assert response.json() == { @@ -69,9 +69,21 @@ def test_asset_name_not_found(self, client): } } + def test_get_config_should_response_401(self, unauthenticated_test_client): + response = unauthenticated_test_client.get( + "/execution/assets/by-name", params={"name": "test_get_asset_by_name"} + ) + assert response.status_code == 401 + + def test_get_config_should_response_403(self, unauthorized_test_client): + response = unauthorized_test_client.get( + "/execution/assets/by-name", params={"name": "test_get_asset_by_name"} + ) + assert response.status_code == 403 + class TestGetAssetByUri: - def test_get_asset_by_uri(self, client, session): + def test_get_asset_by_uri(self, test_client, session): asset = AssetModel( name="test_get_asset_by_uri", uri="s3://bucket/key", @@ -84,7 +96,7 @@ def test_get_asset_by_uri(self, client, session): session.add_all([asset, asset_active]) session.commit() - response = client.get("/execution/assets/by-uri", params={"uri": "s3://bucket/key"}) + response = test_client.get("/execution/assets/by-uri", params={"uri": "s3://bucket/key"}) assert response.status_code == 200 assert response.json() == { @@ -98,8 +110,8 @@ def test_get_asset_by_uri(self, client, session): session.delete(asset_active) session.commit() - def test_asset_uri_not_found(self, client): - response = client.get("/execution/assets/by-uri", params={"uri": "non_existent"}) + def test_asset_uri_not_found(self, test_client): + response = test_client.get("/execution/assets/by-uri", params={"uri": "non_existent"}) assert response.status_code == 404 assert response.json() == { @@ -108,3 +120,13 @@ def test_asset_uri_not_found(self, client): "reason": "not_found", } } + + def test_get_config_should_response_401(self, unauthenticated_test_client): + response = unauthenticated_test_client.get( + "/execution/assets/by-uri", params={"uri": "s3://bucket/key"} + ) + assert response.status_code == 401 + + def test_get_config_should_response_403(self, unauthorized_test_client): + response = unauthorized_test_client.get("/execution/assets/by-uri", params={"uri": "s3://bucket/key"}) + assert response.status_code == 403