diff --git a/airflow/api_fastapi/core_api/openapi/v1-generated.yaml b/airflow/api_fastapi/core_api/openapi/v1-generated.yaml index 0eab45da49794..35230f3417559 100644 --- a/airflow/api_fastapi/core_api/openapi/v1-generated.yaml +++ b/airflow/api_fastapi/core_api/openapi/v1-generated.yaml @@ -2081,6 +2081,8 @@ paths: - DagRun summary: Get Dag Run operationId: get_dag_run + security: + - OAuth2PasswordBearer: [] parameters: - name: dag_id in: path @@ -2131,6 +2133,8 @@ paths: summary: Delete Dag Run description: Delete a DAG Run entry. operationId: delete_dag_run + security: + - OAuth2PasswordBearer: [] parameters: - name: dag_id in: path @@ -2183,6 +2187,8 @@ paths: summary: Patch Dag Run description: Modify a DAG Run. operationId: patch_dag_run + security: + - OAuth2PasswordBearer: [] parameters: - name: dag_id in: path @@ -2309,6 +2315,8 @@ paths: - DagRun summary: Clear Dag Run operationId: clear_dag_run + security: + - OAuth2PasswordBearer: [] parameters: - name: dag_id in: path @@ -2373,6 +2381,8 @@ paths: This endpoint allows specifying `~` as the dag_id to retrieve Dag Runs for all DAGs.' operationId: get_dag_runs + security: + - OAuth2PasswordBearer: [] parameters: - name: dag_id in: path @@ -2538,6 +2548,8 @@ paths: summary: Trigger Dag Run description: Trigger a DAG. operationId: trigger_dag_run + security: + - OAuth2PasswordBearer: [] parameters: - name: dag_id in: path @@ -2600,6 +2612,8 @@ paths: summary: Get List Dag Runs Batch description: Get a list of DAG Runs. operationId: get_list_dag_runs_batch + security: + - OAuth2PasswordBearer: [] parameters: - name: dag_id in: path @@ -2652,6 +2666,8 @@ paths: summary: Get Dag Source description: Get source code using file token. operationId: get_dag_source + security: + - OAuth2PasswordBearer: [] parameters: - name: dag_id in: path @@ -2732,6 +2748,8 @@ paths: summary: Get Dag Stats description: Get Dag statistics. operationId: get_dag_stats + security: + - OAuth2PasswordBearer: [] parameters: - name: dag_ids in: query @@ -2990,6 +3008,8 @@ paths: summary: List Dag Warnings description: Get a list of DAG warnings. operationId: list_dag_warnings + security: + - OAuth2PasswordBearer: [] parameters: - name: dag_id in: query @@ -4517,6 +4537,8 @@ paths: summary: Get Xcom Entry description: Get an XCom entry. operationId: get_xcom_entry + security: + - OAuth2PasswordBearer: [] parameters: - name: dag_id in: path @@ -4610,6 +4632,8 @@ paths: summary: Update Xcom Entry description: Update an existing XCom entry. operationId: update_xcom_entry + security: + - OAuth2PasswordBearer: [] parameters: - name: dag_id in: path @@ -4689,6 +4713,8 @@ paths: This endpoint allows specifying `~` as the dag_id, dag_run_id, task_id to retrieve XCom entries for all DAGs.' operationId: get_xcom_entries + security: + - OAuth2PasswordBearer: [] parameters: - name: dag_id in: path @@ -4784,6 +4810,8 @@ paths: summary: Create Xcom Entry description: Create an XCom entry. operationId: create_xcom_entry + security: + - OAuth2PasswordBearer: [] parameters: - name: dag_id in: path @@ -4853,6 +4881,8 @@ paths: summary: Get Task Instance description: Get task instance. operationId: get_task_instance + security: + - OAuth2PasswordBearer: [] parameters: - name: dag_id in: path @@ -4909,6 +4939,8 @@ paths: summary: Patch Task Instance description: Update a task instance. operationId: patch_task_instance + security: + - OAuth2PasswordBearer: [] parameters: - name: dag_id in: path @@ -5001,6 +5033,8 @@ paths: summary: Get Mapped Task Instances description: Get list of mapped task instances. operationId: get_mapped_task_instances + security: + - OAuth2PasswordBearer: [] parameters: - name: dag_id in: path @@ -5227,6 +5261,8 @@ paths: summary: Get Task Instance Dependencies description: Get dependencies blocking task from getting scheduled. operationId: get_task_instance_dependencies + security: + - OAuth2PasswordBearer: [] parameters: - name: dag_id in: path @@ -5290,6 +5326,8 @@ paths: summary: Get Task Instance Dependencies description: Get dependencies blocking task from getting scheduled. operationId: get_task_instance_dependencies + security: + - OAuth2PasswordBearer: [] parameters: - name: dag_id in: path @@ -5354,6 +5392,8 @@ paths: summary: Get Task Instance Tries description: Get list of task instances history. operationId: get_task_instance_tries + security: + - OAuth2PasswordBearer: [] parameters: - name: dag_id in: path @@ -5417,6 +5457,8 @@ paths: - Task Instance summary: Get Mapped Task Instance Tries operationId: get_mapped_task_instance_tries + security: + - OAuth2PasswordBearer: [] parameters: - name: dag_id in: path @@ -5480,6 +5522,8 @@ paths: summary: Get Mapped Task Instance description: Get task instance. operationId: get_mapped_task_instance + security: + - OAuth2PasswordBearer: [] parameters: - name: dag_id in: path @@ -5542,6 +5586,8 @@ paths: summary: Patch Task Instance description: Update a task instance. operationId: patch_task_instance + security: + - OAuth2PasswordBearer: [] parameters: - name: dag_id in: path @@ -5639,6 +5685,8 @@ paths: and DAG runs.' operationId: get_task_instances + security: + - OAuth2PasswordBearer: [] parameters: - name: dag_id in: path @@ -5875,6 +5923,8 @@ paths: summary: Get Task Instances Batch description: Get list of task instances. operationId: get_task_instances_batch + security: + - OAuth2PasswordBearer: [] parameters: - name: dag_id in: path @@ -5934,6 +5984,8 @@ paths: summary: Get Task Instance Try Details description: Get task instance details by try number. operationId: get_task_instance_try_details + security: + - OAuth2PasswordBearer: [] parameters: - name: dag_id in: path @@ -6003,6 +6055,8 @@ paths: - Task Instance summary: Get Mapped Task Instance Try Details operationId: get_mapped_task_instance_try_details + security: + - OAuth2PasswordBearer: [] parameters: - name: dag_id in: path @@ -6072,6 +6126,8 @@ paths: summary: Post Clear Task Instances description: Clear task instances. operationId: post_clear_task_instances + security: + - OAuth2PasswordBearer: [] parameters: - name: dag_id in: path @@ -6123,6 +6179,8 @@ paths: summary: Patch Task Instance Dry Run description: Update a task instance dry_run mode. operationId: patch_task_instance_dry_run + security: + - OAuth2PasswordBearer: [] parameters: - name: dag_id in: path @@ -6208,6 +6266,8 @@ paths: summary: Patch Task Instance Dry Run description: Update a task instance dry_run mode. operationId: patch_task_instance_dry_run + security: + - OAuth2PasswordBearer: [] parameters: - name: dag_id in: path @@ -6820,6 +6880,8 @@ paths: summary: Reparse Dag File description: Request re-parsing a DAG file. operationId: reparse_dag_file + security: + - OAuth2PasswordBearer: [] parameters: - name: file_token in: path diff --git a/airflow/api_fastapi/core_api/routes/public/dag_parsing.py b/airflow/api_fastapi/core_api/routes/public/dag_parsing.py index f5b7f2c359335..a4deb3b18de02 100644 --- a/airflow/api_fastapi/core_api/routes/public/dag_parsing.py +++ b/airflow/api_fastapi/core_api/routes/public/dag_parsing.py @@ -27,6 +27,7 @@ from airflow.api_fastapi.common.db.common import SessionDep from airflow.api_fastapi.common.router import AirflowRouter from airflow.api_fastapi.core_api.openapi.exceptions import create_openapi_http_exception_doc +from airflow.api_fastapi.core_api.security import requires_access_dag from airflow.api_fastapi.logging.decorators import action_logging from airflow.models.dag import DagModel from airflow.models.dagbag import DagPriorityParsingRequest @@ -41,7 +42,7 @@ "", responses=create_openapi_http_exception_doc([status.HTTP_404_NOT_FOUND]), status_code=status.HTTP_201_CREATED, - dependencies=[Depends(action_logging())], + dependencies=[Depends(requires_access_dag(method="PUT")), Depends(action_logging())], ) def reparse_dag_file( file_token: str, diff --git a/airflow/api_fastapi/core_api/routes/public/dag_run.py b/airflow/api_fastapi/core_api/routes/public/dag_run.py index 8703fe42e423a..4e98f0b5aa1cf 100644 --- a/airflow/api_fastapi/core_api/routes/public/dag_run.py +++ b/airflow/api_fastapi/core_api/routes/public/dag_run.py @@ -29,6 +29,7 @@ set_dag_run_state_to_queued, set_dag_run_state_to_success, ) +from airflow.api_fastapi.auth.managers.models.resource_details import DagAccessEntity from airflow.api_fastapi.common.db.common import SessionDep, paginated_select from airflow.api_fastapi.common.parameters import ( FilterOptionEnum, @@ -59,7 +60,11 @@ TaskInstanceResponse, ) from airflow.api_fastapi.core_api.openapi.exceptions import create_openapi_http_exception_doc -from airflow.api_fastapi.core_api.security import requires_access_asset +from airflow.api_fastapi.core_api.security import ( + ReadableDagRunsFilterDep, + requires_access_asset, + requires_access_dag, +) from airflow.api_fastapi.logging.decorators import action_logging from airflow.exceptions import ParamValidationError from airflow.listeners.listener import get_listener_manager @@ -78,6 +83,7 @@ status.HTTP_404_NOT_FOUND, ] ), + dependencies=[Depends(requires_access_dag(method="GET", access_entity=DagAccessEntity.RUN))], ) def get_dag_run(dag_id: str, dag_run_id: str, session: SessionDep) -> DAGRunResponse: dag_run = session.scalar(select(DagRun).filter_by(dag_id=dag_id, run_id=dag_run_id)) @@ -99,7 +105,10 @@ def get_dag_run(dag_id: str, dag_run_id: str, session: SessionDep) -> DAGRunResp status.HTTP_404_NOT_FOUND, ], ), - dependencies=[Depends(action_logging())], + dependencies=[ + Depends(requires_access_dag(method="DELETE", access_entity=DagAccessEntity.RUN)), + Depends(action_logging()), + ], ) def delete_dag_run(dag_id: str, dag_run_id: str, session: SessionDep): """Delete a DAG Run entry.""" @@ -121,7 +130,10 @@ def delete_dag_run(dag_id: str, dag_run_id: str, session: SessionDep): status.HTTP_404_NOT_FOUND, ], ), - dependencies=[Depends(action_logging())], + dependencies=[ + Depends(requires_access_dag(method="PUT", access_entity=DagAccessEntity.RUN)), + Depends(action_logging()), + ], ) def patch_dag_run( dag_id: str, @@ -190,7 +202,10 @@ def patch_dag_run( status.HTTP_404_NOT_FOUND, ] ), - dependencies=[Depends(requires_access_asset(method="GET"))], + dependencies=[ + Depends(requires_access_asset(method="GET")), + Depends(requires_access_dag(method="GET", access_entity=DagAccessEntity.RUN)), + ], ) def get_upstream_asset_events( dag_id: str, dag_run_id: str, session: SessionDep @@ -217,7 +232,10 @@ def get_upstream_asset_events( @dag_run_router.post( "/{dag_run_id}/clear", responses=create_openapi_http_exception_doc([status.HTTP_404_NOT_FOUND]), - dependencies=[Depends(action_logging())], + dependencies=[ + Depends(requires_access_dag(method="PUT", access_entity=DagAccessEntity.RUN)), + Depends(action_logging()), + ], ) def clear_dag_run( dag_id: str, @@ -263,7 +281,11 @@ def clear_dag_run( return dag_run_cleared -@dag_run_router.get("", responses=create_openapi_http_exception_doc([status.HTTP_404_NOT_FOUND])) +@dag_run_router.get( + "", + responses=create_openapi_http_exception_doc([status.HTTP_404_NOT_FOUND]), + dependencies=[Depends(requires_access_dag(method="GET", access_entity=DagAccessEntity.RUN))], +) def get_dag_runs( dag_id: str, limit: QueryLimit, @@ -295,6 +317,7 @@ def get_dag_runs( ).dynamic_depends(default="id") ), ], + readable_dag_runs_filter: ReadableDagRunsFilterDep, session: SessionDep, request: Request, ) -> DAGRunCollectionResponse: @@ -314,7 +337,15 @@ def get_dag_runs( dag_run_select, total_entries = paginated_select( statement=query, - filters=[run_after, logical_date, start_date_range, end_date_range, update_at_range, state], + filters=[ + run_after, + logical_date, + start_date_range, + end_date_range, + update_at_range, + state, + readable_dag_runs_filter, + ], order_by=order_by, offset=offset, limit=limit, @@ -337,7 +368,10 @@ def get_dag_runs( status.HTTP_409_CONFLICT, ] ), - dependencies=[Depends(action_logging())], + dependencies=[ + Depends(requires_access_dag(method="POST", access_entity=DagAccessEntity.RUN)), + Depends(action_logging()), + ], ) def trigger_dag_run( dag_id, @@ -383,9 +417,16 @@ def trigger_dag_run( raise HTTPException(status.HTTP_400_BAD_REQUEST, str(e)) -@dag_run_router.post("/list", responses=create_openapi_http_exception_doc([status.HTTP_404_NOT_FOUND])) +@dag_run_router.post( + "/list", + responses=create_openapi_http_exception_doc([status.HTTP_404_NOT_FOUND]), + dependencies=[Depends(requires_access_dag(method="GET", access_entity=DagAccessEntity.RUN))], +) def get_list_dag_runs_batch( - dag_id: Literal["~"], body: DAGRunsBatchBody, session: SessionDep + dag_id: Literal["~"], + body: DAGRunsBatchBody, + readable_dag_runs_filter: ReadableDagRunsFilterDep, + session: SessionDep, ) -> DAGRunCollectionResponse: """Get a list of DAG Runs.""" dag_ids = FilterParam(DagRun.dag_id, body.dag_ids, FilterOptionEnum.IN) @@ -430,7 +471,7 @@ def get_list_dag_runs_batch( base_query = select(DagRun) dag_runs_select, total_entries = paginated_select( statement=base_query, - filters=[dag_ids, logical_date, run_after, start_date, end_date, state], + filters=[dag_ids, logical_date, run_after, start_date, end_date, state, readable_dag_runs_filter], order_by=order_by, offset=offset, limit=limit, diff --git a/airflow/api_fastapi/core_api/routes/public/dag_sources.py b/airflow/api_fastapi/core_api/routes/public/dag_sources.py index 4337c92107322..fb00a4dd553b5 100644 --- a/airflow/api_fastapi/core_api/routes/public/dag_sources.py +++ b/airflow/api_fastapi/core_api/routes/public/dag_sources.py @@ -16,14 +16,16 @@ # under the License. from __future__ import annotations -from fastapi import HTTPException, Response, status +from fastapi import Depends, HTTPException, Response, status +from airflow.api_fastapi.auth.managers.models.resource_details import DagAccessEntity from airflow.api_fastapi.common.db.common import SessionDep from airflow.api_fastapi.common.headers import HeaderAcceptJsonOrText from airflow.api_fastapi.common.router import AirflowRouter from airflow.api_fastapi.common.types import Mimetype from airflow.api_fastapi.core_api.datamodels.dag_sources import DAGSourceResponse from airflow.api_fastapi.core_api.openapi.exceptions import create_openapi_http_exception_doc +from airflow.api_fastapi.core_api.security import requires_access_dag from airflow.models.dag_version import DagVersion dag_sources_router = AirflowRouter(tags=["DagSource"], prefix="/dagSources") @@ -47,6 +49,7 @@ }, }, response_model=DAGSourceResponse, + dependencies=[Depends(requires_access_dag(method="GET", access_entity=DagAccessEntity.CODE))], ) def get_dag_source( accept: HeaderAcceptJsonOrText, diff --git a/airflow/api_fastapi/core_api/routes/public/dag_stats.py b/airflow/api_fastapi/core_api/routes/public/dag_stats.py index a6aa6063c263b..124221571e1c0 100644 --- a/airflow/api_fastapi/core_api/routes/public/dag_stats.py +++ b/airflow/api_fastapi/core_api/routes/public/dag_stats.py @@ -21,6 +21,7 @@ from fastapi import Depends, status +from airflow.api_fastapi.auth.managers.models.resource_details import DagAccessEntity from airflow.api_fastapi.common.db.common import ( SessionDep, paginated_select, @@ -38,6 +39,7 @@ DagStatsStateResponse, ) from airflow.api_fastapi.core_api.openapi.exceptions import create_openapi_http_exception_doc +from airflow.api_fastapi.core_api.security import ReadableDagRunsFilterDep, requires_access_dag from airflow.models.dagrun import DagRun from airflow.utils.state import DagRunState @@ -52,8 +54,10 @@ status.HTTP_404_NOT_FOUND, ] ), + dependencies=[Depends(requires_access_dag(method="GET", access_entity=DagAccessEntity.RUN))], ) def get_dag_stats( + readable_dag_runs_filter: ReadableDagRunsFilterDep, session: SessionDep, dag_ids: Annotated[ FilterParam[list[str]], @@ -63,7 +67,7 @@ def get_dag_stats( """Get Dag statistics.""" dagruns_select, _ = paginated_select( statement=dagruns_select_with_state_count, - filters=[dag_ids], + filters=[dag_ids, readable_dag_runs_filter], session=session, return_total_entries=False, ) diff --git a/airflow/api_fastapi/core_api/routes/public/dag_warning.py b/airflow/api_fastapi/core_api/routes/public/dag_warning.py index 2c964efce7e45..cf82d2c8dde11 100644 --- a/airflow/api_fastapi/core_api/routes/public/dag_warning.py +++ b/airflow/api_fastapi/core_api/routes/public/dag_warning.py @@ -22,6 +22,7 @@ from fastapi import Depends from sqlalchemy import select +from airflow.api_fastapi.auth.managers.models.resource_details import DagAccessEntity from airflow.api_fastapi.common.db.common import ( SessionDep, paginated_select, @@ -37,6 +38,7 @@ from airflow.api_fastapi.core_api.datamodels.dag_warning import ( DAGWarningCollectionResponse, ) +from airflow.api_fastapi.core_api.security import ReadableDagWarningsFilterDep, requires_access_dag from airflow.models.dagwarning import DagWarning, DagWarningType dag_warning_router = AirflowRouter(tags=["DagWarning"]) @@ -44,6 +46,7 @@ @dag_warning_router.get( "/dagWarnings", + dependencies=[Depends(requires_access_dag(method="GET", access_entity=DagAccessEntity.WARNING))], ) def list_dag_warnings( dag_id: Annotated[FilterParam[str | None], Depends(filter_param_factory(DagWarning.dag_id, str | None))], @@ -57,12 +60,13 @@ def list_dag_warnings( SortParam, Depends(SortParam(["dag_id", "warning_type", "message", "timestamp"], DagWarning).dynamic_depends()), ], + readable_dag_warning_filter: ReadableDagWarningsFilterDep, session: SessionDep, ) -> DAGWarningCollectionResponse: """Get a list of DAG warnings.""" dag_warnings_select, total_entries = paginated_select( statement=select(DagWarning), - filters=[warning_type, dag_id], + filters=[warning_type, dag_id, readable_dag_warning_filter], order_by=order_by, offset=offset, limit=limit, diff --git a/airflow/api_fastapi/core_api/routes/public/task_instances.py b/airflow/api_fastapi/core_api/routes/public/task_instances.py index 7d130bc1e37e9..3e9cbc4d795bc 100644 --- a/airflow/api_fastapi/core_api/routes/public/task_instances.py +++ b/airflow/api_fastapi/core_api/routes/public/task_instances.py @@ -27,6 +27,7 @@ from sqlalchemy.orm import joinedload from sqlalchemy.sql.selectable import Select +from airflow.api_fastapi.auth.managers.models.resource_details import DagAccessEntity from airflow.api_fastapi.common.db.common import SessionDep, paginated_select from airflow.api_fastapi.common.parameters import ( FilterOptionEnum, @@ -60,6 +61,7 @@ TaskInstancesBatchBody, ) from airflow.api_fastapi.core_api.openapi.exceptions import create_openapi_http_exception_doc +from airflow.api_fastapi.core_api.security import ReadableTIFilterDep, requires_access_dag from airflow.api_fastapi.logging.decorators import action_logging from airflow.exceptions import TaskNotFound from airflow.models import Base, DagRun @@ -78,6 +80,7 @@ @task_instances_router.get( task_instances_prefix + "/{task_id}", responses=create_openapi_http_exception_doc([status.HTTP_404_NOT_FOUND]), + dependencies=[Depends(requires_access_dag(method="GET", access_entity=DagAccessEntity.TASK_INSTANCE))], ) def get_task_instance( dag_id: str, dag_run_id: str, task_id: str, session: SessionDep @@ -108,6 +111,7 @@ def get_task_instance( @task_instances_router.get( task_instances_prefix + "/{task_id}/listMapped", responses=create_openapi_http_exception_doc([status.HTTP_404_NOT_FOUND]), + dependencies=[Depends(requires_access_dag(method="GET", access_entity=DagAccessEntity.TASK_INSTANCE))], ) def get_mapped_task_instances( dag_id: str, @@ -211,10 +215,12 @@ def get_mapped_task_instances( @task_instances_router.get( task_instances_prefix + "/{task_id}/dependencies", responses=create_openapi_http_exception_doc([status.HTTP_404_NOT_FOUND]), + dependencies=[Depends(requires_access_dag(method="GET", access_entity=DagAccessEntity.TASK_INSTANCE))], ) @task_instances_router.get( task_instances_prefix + "/{task_id}/{map_index}/dependencies", responses=create_openapi_http_exception_doc([status.HTTP_404_NOT_FOUND]), + dependencies=[Depends(requires_access_dag(method="GET", access_entity=DagAccessEntity.TASK_INSTANCE))], ) def get_task_instance_dependencies( dag_id: str, @@ -265,6 +271,7 @@ def get_task_instance_dependencies( @task_instances_router.get( task_instances_prefix + "/{task_id}/tries", responses=create_openapi_http_exception_doc([status.HTTP_404_NOT_FOUND]), + dependencies=[Depends(requires_access_dag(method="GET", access_entity=DagAccessEntity.TASK_INSTANCE))], ) def get_task_instance_tries( dag_id: str, @@ -308,6 +315,7 @@ def _query(orm_object: Base) -> Select: @task_instances_router.get( task_instances_prefix + "/{task_id}/{map_index}/tries", responses=create_openapi_http_exception_doc([status.HTTP_404_NOT_FOUND]), + dependencies=[Depends(requires_access_dag(method="GET", access_entity=DagAccessEntity.TASK_INSTANCE))], ) def get_mapped_task_instance_tries( dag_id: str, @@ -328,6 +336,7 @@ def get_mapped_task_instance_tries( @task_instances_router.get( task_instances_prefix + "/{task_id}/{map_index}", responses=create_openapi_http_exception_doc([status.HTTP_404_NOT_FOUND]), + dependencies=[Depends(requires_access_dag(method="GET", access_entity=DagAccessEntity.TASK_INSTANCE))], ) def get_mapped_task_instance( dag_id: str, @@ -358,6 +367,7 @@ def get_mapped_task_instance( @task_instances_router.get( task_instances_prefix, responses=create_openapi_http_exception_doc([status.HTTP_404_NOT_FOUND]), + dependencies=[Depends(requires_access_dag(method="GET", access_entity=DagAccessEntity.TASK_INSTANCE))], ) def get_task_instances( dag_id: str, @@ -406,6 +416,7 @@ def get_task_instances( ).dynamic_depends(default="map_index") ), ], + readable_ti_filter: ReadableTIFilterDep, session: SessionDep, ) -> TaskInstanceCollectionResponse: """ @@ -447,6 +458,7 @@ def get_task_instances( task_id, task_display_name_pattern, version_number, + readable_ti_filter, ], order_by=order_by, offset=offset, @@ -464,12 +476,16 @@ def get_task_instances( @task_instances_router.post( task_instances_prefix + "/list", responses=create_openapi_http_exception_doc([status.HTTP_404_NOT_FOUND]), - dependencies=[Depends(action_logging())], + dependencies=[ + Depends(action_logging()), + Depends(requires_access_dag(method="GET", access_entity=DagAccessEntity.TASK_INSTANCE)), + ], ) def get_task_instances_batch( dag_id: Literal["~"], dag_run_id: Literal["~"], body: TaskInstancesBatchBody, + readable_ti_filter: ReadableTIFilterDep, session: SessionDep, ) -> TaskInstanceCollectionResponse: """Get list of task instances.""" @@ -525,6 +541,7 @@ def get_task_instances_batch( pool, queue, executor, + readable_ti_filter, ], order_by=order_by, offset=offset, @@ -546,6 +563,7 @@ def get_task_instances_batch( @task_instances_router.get( task_instances_prefix + "/{task_id}/tries/{task_try_number}", responses=create_openapi_http_exception_doc([status.HTTP_404_NOT_FOUND]), + dependencies=[Depends(requires_access_dag(method="GET", access_entity=DagAccessEntity.TASK_INSTANCE))], ) def get_task_instance_try_details( dag_id: str, @@ -581,6 +599,7 @@ def _query(orm_object: Base) -> TI | TIH | None: @task_instances_router.get( task_instances_prefix + "/{task_id}/{map_index}/tries/{task_try_number}", responses=create_openapi_http_exception_doc([status.HTTP_404_NOT_FOUND]), + dependencies=[Depends(requires_access_dag(method="GET", access_entity=DagAccessEntity.TASK_INSTANCE))], ) def get_mapped_task_instance_try_details( dag_id: str, @@ -603,7 +622,10 @@ def get_mapped_task_instance_try_details( @task_instances_router.post( "/clearTaskInstances", responses=create_openapi_http_exception_doc([status.HTTP_404_NOT_FOUND]), - dependencies=[Depends(action_logging())], + dependencies=[ + Depends(action_logging()), + Depends(requires_access_dag(method="PUT", access_entity=DagAccessEntity.TASK_INSTANCE)), + ], ) def post_clear_task_instances( dag_id: str, @@ -748,12 +770,14 @@ def _patch_ti_validate_request( responses=create_openapi_http_exception_doc( [status.HTTP_404_NOT_FOUND, status.HTTP_400_BAD_REQUEST], ), + dependencies=[Depends(requires_access_dag(method="PUT", access_entity=DagAccessEntity.TASK_INSTANCE))], ) @task_instances_router.patch( task_instances_prefix + "/{task_id}/{map_index}/dry_run", responses=create_openapi_http_exception_doc( [status.HTTP_404_NOT_FOUND, status.HTTP_400_BAD_REQUEST], ), + dependencies=[Depends(requires_access_dag(method="PUT", access_entity=DagAccessEntity.TASK_INSTANCE))], ) def patch_task_instance_dry_run( dag_id: str, @@ -808,14 +832,20 @@ def patch_task_instance_dry_run( responses=create_openapi_http_exception_doc( [status.HTTP_404_NOT_FOUND, status.HTTP_400_BAD_REQUEST, status.HTTP_409_CONFLICT], ), - dependencies=[Depends(action_logging())], + dependencies=[ + Depends(action_logging()), + Depends(requires_access_dag(method="PUT", access_entity=DagAccessEntity.TASK_INSTANCE)), + ], ) @task_instances_router.patch( task_instances_prefix + "/{task_id}/{map_index}", responses=create_openapi_http_exception_doc( [status.HTTP_404_NOT_FOUND, status.HTTP_400_BAD_REQUEST, status.HTTP_409_CONFLICT], ), - dependencies=[Depends(action_logging())], + dependencies=[ + Depends(action_logging()), + Depends(requires_access_dag(method="PUT", access_entity=DagAccessEntity.TASK_INSTANCE)), + ], ) def patch_task_instance( dag_id: str, diff --git a/airflow/api_fastapi/core_api/routes/public/xcom.py b/airflow/api_fastapi/core_api/routes/public/xcom.py index 3da163f3e4033..c5a4028eef59b 100644 --- a/airflow/api_fastapi/core_api/routes/public/xcom.py +++ b/airflow/api_fastapi/core_api/routes/public/xcom.py @@ -19,9 +19,10 @@ import copy from typing import Annotated -from fastapi import HTTPException, Query, Request, status +from fastapi import Depends, HTTPException, Query, Request, status from sqlalchemy import and_, select +from airflow.api_fastapi.auth.managers.models.resource_details import DagAccessEntity from airflow.api_fastapi.common.db.common import SessionDep, paginated_select from airflow.api_fastapi.common.parameters import QueryLimit, QueryOffset from airflow.api_fastapi.common.router import AirflowRouter @@ -33,6 +34,7 @@ XComUpdateBody, ) from airflow.api_fastapi.core_api.openapi.exceptions import create_openapi_http_exception_doc +from airflow.api_fastapi.core_api.security import ReadableXComFilterDep, requires_access_dag from airflow.exceptions import TaskNotFound from airflow.models import DAG, DagRun as DR, XCom from airflow.settings import conf @@ -50,6 +52,7 @@ status.HTTP_404_NOT_FOUND, ] ), + dependencies=[Depends(requires_access_dag(method="GET", access_entity=DagAccessEntity.XCOM))], ) def get_xcom_entry( dag_id: str, @@ -105,6 +108,7 @@ def get_xcom_entry( status.HTTP_404_NOT_FOUND, ] ), + dependencies=[Depends(requires_access_dag(method="GET", access_entity=DagAccessEntity.XCOM))], ) def get_xcom_entries( dag_id: str, @@ -112,6 +116,7 @@ def get_xcom_entries( task_id: str, limit: QueryLimit, offset: QueryOffset, + readable_xcom_filter: ReadableXComFilterDep, session: SessionDep, xcom_key: Annotated[str | None, Query()] = None, map_index: Annotated[int | None, Query(ge=-1)] = None, @@ -137,6 +142,7 @@ def get_xcom_entries( query, total_entries = paginated_select( statement=query, + filters=[readable_xcom_filter], offset=offset, limit=limit, session=session, @@ -155,6 +161,7 @@ def get_xcom_entries( status.HTTP_404_NOT_FOUND, ] ), + dependencies=[Depends(requires_access_dag(method="POST", access_entity=DagAccessEntity.XCOM))], ) def create_xcom_entry( dag_id: str, @@ -234,6 +241,7 @@ def create_xcom_entry( status.HTTP_404_NOT_FOUND, ] ), + dependencies=[Depends(requires_access_dag(method="PUT", access_entity=DagAccessEntity.XCOM))], ) def update_xcom_entry( dag_id: str, diff --git a/airflow/api_fastapi/core_api/security.py b/airflow/api_fastapi/core_api/security.py index 3d40f78523450..8c3d11d5425db 100644 --- a/airflow/api_fastapi/core_api/security.py +++ b/airflow/api_fastapi/core_api/security.py @@ -40,7 +40,10 @@ ) from airflow.api_fastapi.core_api.base import OrmClause from airflow.configuration import conf -from airflow.models.dag import DagModel +from airflow.models.dag import DagModel, DagRun +from airflow.models.dagwarning import DagWarning +from airflow.models.taskinstance import TaskInstance as TI +from airflow.models.xcom import XCom from airflow.utils.jwt_signer import JWTSigner, get_signing_key if TYPE_CHECKING: @@ -114,7 +117,37 @@ def to_orm(self, select: Select) -> Select: return select.where(DagModel.dag_id.in_(self.value)) -def permitted_dag_filter_factory(method: ResourceMethod) -> Callable[[Request, BaseUser], PermittedDagFilter]: +class PermittedDagRunFilter(PermittedDagFilter): + """A parameter that filters the permitted dag runs for the user.""" + + def to_orm(self, select: Select) -> Select: + return select.where(DagRun.dag_id.in_(self.value)) + + +class PermittedDagWarningFilter(PermittedDagFilter): + """A parameter that filters the permitted dag warnings for the user.""" + + def to_orm(self, select: Select) -> Select: + return select.where(DagWarning.dag_id.in_(self.value)) + + +class PermittedTIFilter(PermittedDagFilter): + """A parameter that filters the permitted task instances for the user.""" + + def to_orm(self, select: Select) -> Select: + return select.where(TI.dag_id.in_(self.value)) + + +class PermittedXComFilter(PermittedDagFilter): + """A parameter that filters the permitted XComs for the user.""" + + def to_orm(self, select: Select) -> Select: + return select.where(XCom.dag_id.in_(self.value)) + + +def permitted_dag_filter_factory( + method: ResourceMethod, filter_class=PermittedDagFilter +) -> Callable[[Request, BaseUser], PermittedDagFilter]: """ Create a callable for Depends in FastAPI that returns a filter of the permitted dags for the user. @@ -128,13 +161,25 @@ def depends_permitted_dags_filter( ) -> PermittedDagFilter: auth_manager: BaseAuthManager = request.app.state.auth_manager permitted_dags: set[str] = auth_manager.get_permitted_dag_ids(user=user, method=method) - return PermittedDagFilter(permitted_dags) + return filter_class(permitted_dags) return depends_permitted_dags_filter EditableDagsFilterDep = Annotated[PermittedDagFilter, Depends(permitted_dag_filter_factory("PUT"))] ReadableDagsFilterDep = Annotated[PermittedDagFilter, Depends(permitted_dag_filter_factory("GET"))] +ReadableDagRunsFilterDep = Annotated[ + PermittedDagRunFilter, Depends(permitted_dag_filter_factory("GET", PermittedDagRunFilter)) +] +ReadableDagWarningsFilterDep = Annotated[ + PermittedDagWarningFilter, Depends(permitted_dag_filter_factory("GET", PermittedDagWarningFilter)) +] +ReadableTIFilterDep = Annotated[ + PermittedTIFilter, Depends(permitted_dag_filter_factory("GET", PermittedTIFilter)) +] +ReadableXComFilterDep = Annotated[ + PermittedXComFilter, Depends(permitted_dag_filter_factory("GET", PermittedXComFilter)) +] def requires_access_pool(method: ResourceMethod) -> Callable[[Request, BaseUser], None]: diff --git a/tests/api_fastapi/core_api/routes/public/test_dag_parsing.py b/tests/api_fastapi/core_api/routes/public/test_dag_parsing.py index 0943684e32192..00ab7dbbaf324 100644 --- a/tests/api_fastapi/core_api/routes/public/test_dag_parsing.py +++ b/tests/api_fastapi/core_api/routes/public/test_dag_parsing.py @@ -65,6 +65,18 @@ def test_201_and_400_requests(self, url_safe_serializer, session, test_client): assert parsing_requests[0].fileloc == test_dag.fileloc _check_last_log(session, dag_id=None, event="reparse_dag_file", logical_date=None) + def test_should_respond_401(self, unauthenticated_test_client): + response = unauthenticated_test_client.put( + "/public/parseDagFile/token", headers={"Accept": "application/json"} + ) + assert response.status_code == 401 + + def test_should_respond_403(self, unauthorized_test_client): + response = unauthorized_test_client.put( + "/public/parseDagFile/token", headers={"Accept": "application/json"} + ) + assert response.status_code == 403 + def test_bad_file_request(self, url_safe_serializer, session, test_client): url = f"/public/parseDagFile/{url_safe_serializer.dumps('/some/random/file.py')}" response = test_client.put(url, headers={"Accept": "application/json"}) diff --git a/tests/api_fastapi/core_api/routes/public/test_dag_run.py b/tests/api_fastapi/core_api/routes/public/test_dag_run.py index f62bff0665ec6..61199fc1f37a5 100644 --- a/tests/api_fastapi/core_api/routes/public/test_dag_run.py +++ b/tests/api_fastapi/core_api/routes/public/test_dag_run.py @@ -244,6 +244,14 @@ def test_get_dag_run_not_found(self, test_client): body = response.json() assert body["detail"] == "The DagRun with dag_id: `test_dag1` and run_id: `invalid` was not found" + def test_should_respond_401(self, unauthenticated_test_client): + response = unauthenticated_test_client.get(f"/public/dags/{DAG1_ID}/dagRuns/invalid") + assert response.status_code == 401 + + def test_should_respond_403(self, unauthorized_test_client): + response = unauthorized_test_client.get(f"/public/dags/{DAG1_ID}/dagRuns/invalid") + assert response.status_code == 403 + class TestGetDagRuns: @pytest.mark.parametrize("dag_id, total_entries", [(DAG1_ID, 2), (DAG2_ID, 2), ("~", 4)]) @@ -277,6 +285,14 @@ def test_invalid_order_by_raises_400(self, test_client): == "Ordering with 'invalid' is disallowed or the attribute does not exist on the model" ) + def test_should_respond_401(self, unauthenticated_test_client): + response = unauthenticated_test_client.get("/public/dags/test_dag1/dagRuns?order_by=invalid") + assert response.status_code == 401 + + def test_should_respond_403(self, unauthorized_test_client): + response = unauthorized_test_client.get("/public/dags/test_dag1/dagRuns?order_by=invalid") + assert response.status_code == 403 + @pytest.mark.parametrize( "order_by,expected_order", [ @@ -550,6 +566,14 @@ def test_list_dag_runs_return_200(self, test_client, session): expected = get_dag_run_dict(run) assert each == expected + def test_should_respond_401(self, unauthenticated_test_client): + response = unauthenticated_test_client.post("/public/dags/~/dagRuns/list", json={}) + assert response.status_code == 401 + + def test_should_respond_403(self, unauthorized_test_client): + response = unauthorized_test_client.post("/public/dags/~/dagRuns/list", json={}) + assert response.status_code == 403 + def test_list_dag_runs_with_invalid_dag_id(self, test_client): response = test_client.post("/public/dags/invalid/dagRuns/list", json={}) assert response.status_code == 422 @@ -909,6 +933,14 @@ def test_patch_dag_run(self, test_client, dag_id, run_id, patch_body, response_b assert body.get("note") == response_body.get("note") _check_last_log(session, dag_id=dag_id, event="patch_dag_run", logical_date=None) + def test_should_respond_401(self, unauthenticated_test_client): + response = unauthenticated_test_client.patch("/public/dags/dag_1/dagRuns/run_1", json={}) + assert response.status_code == 401 + + def test_should_respond_403(self, unauthorized_test_client): + response = unauthorized_test_client.patch("/public/dags/dag_1/dagRuns/run_1", json={}) + assert response.status_code == 403 + @pytest.mark.parametrize( "query_params, patch_body, response_body, expected_status_code", [ @@ -1008,6 +1040,14 @@ def test_delete_dag_run_not_found(self, test_client): body = response.json() assert body["detail"] == "The DagRun with dag_id: `test_dag1` and run_id: `invalid` was not found" + def test_should_respond_401(self, unauthenticated_test_client): + response = unauthenticated_test_client.delete(f"/public/dags/{DAG1_ID}/dagRuns/invalid") + assert response.status_code == 401 + + def test_should_respond_403(self, unauthorized_test_client): + response = unauthorized_test_client.delete(f"/public/dags/{DAG1_ID}/dagRuns/invalid") + assert response.status_code == 403 + class TestGetDagRunAssetTriggerEvents: @pytest.mark.usefixtures("configure_git_connection_for_dag_bundle") @@ -1115,6 +1155,20 @@ def test_clear_dag_run(self, test_client, session): logical_date=None, ) + def test_should_respond_401(self, unauthenticated_test_client): + response = unauthenticated_test_client.post( + f"/public/dags/{DAG1_ID}/dagRuns/{DAG1_RUN1_ID}/clear", + json={"dry_run": False}, + ) + assert response.status_code == 401 + + def test_should_respond_403(self, unauthorized_test_client): + response = unauthorized_test_client.post( + f"/public/dags/{DAG1_ID}/dagRuns/{DAG1_RUN1_ID}/clear", + json={"dry_run": False}, + ) + assert response.status_code == 403 + @pytest.mark.parametrize( "body, dag_run_id, expected_state", [ @@ -1246,6 +1300,20 @@ def test_should_respond_200( assert response.json() == expected_response_json _check_last_log(session, dag_id=DAG1_ID, event="trigger_dag_run", logical_date=None) + def test_should_respond_401(self, unauthenticated_test_client): + response = unauthenticated_test_client.post( + f"/public/dags/{DAG1_ID}/dagRuns", + json={}, + ) + assert response.status_code == 401 + + def test_should_respond_403(self, unauthorized_test_client): + response = unauthorized_test_client.post( + f"/public/dags/{DAG1_ID}/dagRuns", + json={}, + ) + assert response.status_code == 403 + @pytest.mark.parametrize( "post_body, expected_detail", [ diff --git a/tests/api_fastapi/core_api/routes/public/test_dag_sources.py b/tests/api_fastapi/core_api/routes/public/test_dag_sources.py index 6da8902b232a0..8bd53ea2f97b9 100644 --- a/tests/api_fastapi/core_api/routes/public/test_dag_sources.py +++ b/tests/api_fastapi/core_api/routes/public/test_dag_sources.py @@ -77,6 +77,18 @@ def test_should_respond_200_text(self, test_client, test_dag): json.loads(response.content.decode()) assert response.headers["Content-Type"].startswith("text/plain") + def test_should_respond_401(self, unauthenticated_test_client): + response = unauthenticated_test_client.get( + f"{API_PREFIX}/{TEST_DAG_ID}", headers={"Accept": "text/plain"} + ) + assert response.status_code == 401 + + def test_should_respond_403(self, unauthorized_test_client): + response = unauthorized_test_client.get( + f"{API_PREFIX}/{TEST_DAG_ID}", headers={"Accept": "text/plain"} + ) + assert response.status_code == 403 + @pytest.mark.parametrize( "headers", [{"Accept": "application/json"}, {"Accept": "application/json; charset=utf-8"}, {}] ) diff --git a/tests/api_fastapi/core_api/routes/public/test_dag_stats.py b/tests/api_fastapi/core_api/routes/public/test_dag_stats.py index 8a0dae3604c1d..e7264bd48d145 100644 --- a/tests/api_fastapi/core_api/routes/public/test_dag_stats.py +++ b/tests/api_fastapi/core_api/routes/public/test_dag_stats.py @@ -129,7 +129,7 @@ def teardown_method(self) -> None: class TestGetDagStats(TestDagStatsEndpoint): """Unit tests for Get DAG Stats.""" - def test_should_respond_200(self, client, session): + def test_should_respond_200(self, test_client, session): self._create_dag_and_runs(session) exp_payload = { "dags": [ @@ -179,13 +179,21 @@ def test_should_respond_200(self, client, session): "total_entries": 2, } - response = client().get(f"{API_PREFIX}?dag_ids={DAG1_ID}&dag_ids={DAG2_ID}") + response = test_client.get(f"{API_PREFIX}?dag_ids={DAG1_ID}&dag_ids={DAG2_ID}") assert response.status_code == 200 res_json = response.json() assert res_json["total_entries"] == len(res_json["dags"]) assert res_json == exp_payload - def test_all_dags_should_respond_200(self, client, session): + def test_should_respond_401(self, unauthenticated_test_client): + response = unauthenticated_test_client.get(f"{API_PREFIX}?dag_ids={DAG1_ID}&dag_ids={DAG2_ID}") + assert response.status_code == 401 + + def test_should_respond_403(self, unauthorized_test_client): + response = unauthorized_test_client.get(f"{API_PREFIX}?dag_ids={DAG1_ID}&dag_ids={DAG2_ID}") + assert response.status_code == 403 + + def test_all_dags_should_respond_200(self, test_client, session): self._create_dag_and_runs(session) exp_payload = { "dags": [ @@ -256,7 +264,7 @@ def test_all_dags_should_respond_200(self, client, session): "total_entries": 3, } - response = client().get(API_PREFIX) + response = test_client.get(API_PREFIX) assert response.status_code == 200 res_json = response.json() assert res_json["total_entries"] == len(res_json["dags"]) @@ -403,9 +411,9 @@ def test_all_dags_should_respond_200(self, client, session): ), ], ) - def test_single_dag_in_dag_ids(self, client, session, url, params, exp_payload): + def test_single_dag_in_dag_ids(self, test_client, session, url, params, exp_payload): self._create_dag_and_runs(session) - response = client().get(url, params=params) + response = test_client.get(url, params=params) assert response.status_code == 200 res_json = response.json() assert res_json["total_entries"] == len(res_json["dags"]) diff --git a/tests/api_fastapi/core_api/routes/public/test_dag_warning.py b/tests/api_fastapi/core_api/routes/public/test_dag_warning.py index 61237bd10299a..c3dfc304cd8a3 100644 --- a/tests/api_fastapi/core_api/routes/public/test_dag_warning.py +++ b/tests/api_fastapi/core_api/routes/public/test_dag_warning.py @@ -78,6 +78,14 @@ def test_get_dag_warnings(self, test_client, query_params, expected_total_entrie assert len(response_json["dag_warnings"]) == len(expected_messages) assert [dag_warning["message"] for dag_warning in response_json["dag_warnings"]] == expected_messages + def test_should_respond_401(self, unauthenticated_test_client): + response = unauthenticated_test_client.get("/public/dagWarnings", params={}) + assert response.status_code == 401 + + def test_should_respond_403(self, unauthorized_test_client): + response = unauthorized_test_client.get("/public/dagWarnings", params={}) + assert response.status_code == 403 + def test_get_dag_warnings_bad_request(self, test_client): response = test_client.get("/public/dagWarnings", params={"warning_type": "invalid"}) response_json = response.json() diff --git a/tests/api_fastapi/core_api/routes/public/test_task_instances.py b/tests/api_fastapi/core_api/routes/public/test_task_instances.py index d7e185081c9a6..4403bbc1428ce 100644 --- a/tests/api_fastapi/core_api/routes/public/test_task_instances.py +++ b/tests/api_fastapi/core_api/routes/public/test_task_instances.py @@ -207,6 +207,18 @@ def test_should_respond_200(self, test_client, session): "triggerer_job": None, } + def test_should_respond_401(self, unauthenticated_test_client): + response = unauthenticated_test_client.get( + "/public/dags/example_python_operator/dagRuns/TEST_DAG_RUN_ID/taskInstances/print_the_context" + ) + assert response.status_code == 401 + + def test_should_respond_403(self, unauthorized_test_client): + response = unauthorized_test_client.get( + "/public/dags/example_python_operator/dagRuns/TEST_DAG_RUN_ID/taskInstances/print_the_context" + ) + assert response.status_code == 403 + @pytest.mark.parametrize( "run_id, expected_version_number", [ @@ -525,6 +537,18 @@ def test_should_respond_200_mapped_task_instance_with_rtif(self, test_client, se "triggerer_job": None, } + def test_should_respond_401(self, unauthenticated_test_client): + response = unauthenticated_test_client.get( + "/public/dags/example_python_operator/dagRuns/TEST_DAG_RUN_ID/taskInstances/print_the_context/1", + ) + assert response.status_code == 401 + + def test_should_respond_403(self, unauthorized_test_client): + response = unauthorized_test_client.get( + "/public/dags/example_python_operator/dagRuns/TEST_DAG_RUN_ID/taskInstances/print_the_context/1", + ) + assert response.status_code == 403 + def test_should_respond_404_wrong_map_index(self, test_client, session): self.create_task_instances(session) @@ -665,6 +689,18 @@ def one_task_with_zero_mapped_tis(self, dag_maker, session): }, ) + def test_should_respond_401(self, unauthenticated_test_client): + response = unauthenticated_test_client.get( + "/public/dags/mapped_tis/dagRuns/run_mapped_tis/taskInstances/task_2/listMapped", + ) + assert response.status_code == 401 + + def test_should_respond_403(self, unauthorized_test_client): + response = unauthorized_test_client.get( + "/public/dags/mapped_tis/dagRuns/run_mapped_tis/taskInstances/task_2/listMapped", + ) + assert response.status_code == 403 + def test_should_respond_404(self, test_client): response = test_client.get( "/public/dags/mapped_tis/dagRuns/run_mapped_tis/taskInstances/task_2/listMapped", @@ -1068,6 +1104,18 @@ def test_should_respond_200( assert response.json()["total_entries"] == expected_ti assert len(response.json()["task_instances"]) == expected_ti + def test_should_respond_401(self, unauthenticated_test_client): + response = unauthenticated_test_client.get( + "/public/dags/example_python_operator/dagRuns/~/taskInstances", + ) + assert response.status_code == 401 + + def test_should_respond_403(self, unauthorized_test_client): + response = unauthorized_test_client.get( + "/public/dags/example_python_operator/dagRuns/~/taskInstances", + ) + assert response.status_code == 403 + def test_not_found(self, test_client): response = test_client.get("/public/dags/invalid/dagRuns/~/taskInstances") assert response.status_code == 404 @@ -1085,7 +1133,6 @@ def test_bad_state(self, test_client): == f"Invalid value for state. Valid values are {', '.join(TaskInstanceState)}" ) - @pytest.mark.xfail(reason="permissions not implemented yet.") def test_return_TI_only_from_readable_dags(self, test_client, session): task_instances = { "example_python_operator": 1, @@ -1294,6 +1341,20 @@ def test_should_respond_dependencies_mapped(self, test_client, session): ) assert response.status_code == 200, response.text + def test_should_respond_401(self, unauthenticated_test_client): + response = unauthenticated_test_client.get( + "/public/dags/example_python_operator/dagRuns/TEST_DAG_RUN_ID/taskInstances/" + "print_the_context/0/dependencies", + ) + assert response.status_code == 401 + + def test_should_respond_403(self, unauthorized_test_client): + response = unauthorized_test_client.get( + "/public/dags/example_python_operator/dagRuns/TEST_DAG_RUN_ID/taskInstances/" + "print_the_context/0/dependencies", + ) + assert response.status_code == 403 + class TestGetTaskInstancesBatch(TestTaskInstanceEndpoint): @pytest.mark.parametrize( @@ -1508,6 +1569,20 @@ def test_should_raise_400_for_no_json(self, test_client): }, ] + def test_should_respond_401(self, unauthenticated_test_client): + response = unauthenticated_test_client.post( + "/public/dags/~/dagRuns/~/taskInstances/list", + json={}, + ) + assert response.status_code == 401 + + def test_should_respond_403(self, unauthorized_test_client): + response = unauthorized_test_client.post( + "/public/dags/~/dagRuns/~/taskInstances/list", + json={}, + ) + assert response.status_code == 403 + def test_should_respond_422_for_non_wildcard_path_parameters(self, test_client): response = test_client.post( "/public/dags/non_wildcard/dagRuns/~/taskInstances/list", @@ -1820,6 +1895,18 @@ def test_should_respond_200_with_task_state_in_removed(self, test_client, sessio "dag_version": None, } + def test_should_respond_401(self, unauthenticated_test_client): + response = unauthenticated_test_client.get( + "/public/dags/example_python_operator/dagRuns/TEST_DAG_RUN_ID/taskInstances/print_the_context/tries/1", + ) + assert response.status_code == 401 + + def test_should_respond_403(self, unauthorized_test_client): + response = unauthorized_test_client.get( + "/public/dags/example_python_operator/dagRuns/TEST_DAG_RUN_ID/taskInstances/print_the_context/tries/1", + ) + assert response.status_code == 403 + def test_raises_404_for_nonexistent_task_instance(self, test_client, session): self.create_task_instances(session) response = test_client.get( @@ -2090,6 +2177,20 @@ def test_dag_run_with_future_or_past_flag_returns_400(self, test_client, session in response.json()["detail"] ) + def test_should_respond_401(self, unauthenticated_test_client): + response = unauthenticated_test_client.post( + "/public/dags/dag_id/clearTaskInstances", + json={}, + ) + assert response.status_code == 401 + + def test_should_respond_403(self, unauthorized_test_client): + response = unauthorized_test_client.post( + "/public/dags/dag_id/clearTaskInstances", + json={}, + ) + assert response.status_code == 403 + @pytest.mark.parametrize( "main_dag, task_instances, request_dag, payload, expected_ti", [ @@ -2714,6 +2815,18 @@ def test_should_respond_200(self, test_client, session): "total_entries": 2, } + def test_should_respond_401(self, unauthenticated_test_client): + response = unauthenticated_test_client.get( + "/public/dags/example_python_operator/dagRuns/TEST_DAG_RUN_ID/taskInstances/print_the_context/tries" + ) + assert response.status_code == 401 + + def test_should_respond_403(self, unauthorized_test_client): + response = unauthorized_test_client.get( + "/public/dags/example_python_operator/dagRuns/TEST_DAG_RUN_ID/taskInstances/print_the_context/tries" + ) + assert response.status_code == 403 + def test_ti_in_retry_state_not_returned(self, test_client, session): self.create_task_instances( session=session, task_instances=[{"state": State.SUCCESS}], with_ti_history=True @@ -3029,6 +3142,24 @@ def test_should_update_mapped_task_instance_state(self, test_client, session): assert response2.status_code == 200 assert response2.json()["state"] == self.NEW_STATE + def test_should_respond_401(self, unauthenticated_test_client): + response = unauthenticated_test_client.patch( + self.ENDPOINT_URL, + json={ + "new_state": self.NEW_STATE, + }, + ) + assert response.status_code == 401 + + def test_should_respond_403(self, unauthorized_test_client): + response = unauthorized_test_client.patch( + self.ENDPOINT_URL, + json={ + "new_state": self.NEW_STATE, + }, + ) + assert response.status_code == 403 + @pytest.mark.parametrize( "error, code, payload", [ @@ -3529,6 +3660,20 @@ def test_should_not_update(self, test_client, session, payload): assert task_before == task_after + def test_should_respond_401(self, unauthenticated_test_client): + response = unauthenticated_test_client.patch( + f"{self.ENDPOINT_URL}/dry_run", + json={}, + ) + assert response.status_code == 401 + + def test_should_respond_403(self, unauthorized_test_client): + response = unauthorized_test_client.patch( + f"{self.ENDPOINT_URL}/dry_run", + json={}, + ) + assert response.status_code == 403 + def test_should_not_update_mapped_task_instance(self, test_client, session): map_index = 1 tis = self.create_task_instances(session) diff --git a/tests/api_fastapi/core_api/routes/public/test_xcom.py b/tests/api_fastapi/core_api/routes/public/test_xcom.py index dd5a073c1baae..c665a5299ced3 100644 --- a/tests/api_fastapi/core_api/routes/public/test_xcom.py +++ b/tests/api_fastapi/core_api/routes/public/test_xcom.py @@ -149,6 +149,18 @@ def test_should_respond_200_native(self, test_client): "value": TEST_XCOM_VALUE, } + def test_should_respond_401(self, unauthenticated_test_client): + response = unauthenticated_test_client.get( + f"/public/dags/{TEST_DAG_ID}/dagRuns/{run_id}/taskInstances/{TEST_TASK_ID}/xcomEntries/{TEST_XCOM_KEY}" + ) + assert response.status_code == 401 + + def test_should_respond_403(self, unauthorized_test_client): + response = unauthorized_test_client.get( + f"/public/dags/{TEST_DAG_ID}/dagRuns/{run_id}/taskInstances/{TEST_TASK_ID}/xcomEntries/{TEST_XCOM_KEY}" + ) + assert response.status_code == 403 + def test_should_raise_404_for_non_existent_xcom(self, test_client): response = test_client.get( f"/public/dags/{TEST_DAG_ID}/dagRuns/{run_id}/taskInstances/{TEST_TASK_ID}/xcomEntries/{TEST_XCOM_KEY_2}" @@ -437,6 +449,20 @@ def _create_xcom_entries(self, dag_id, run_id, logical_date, task_id, mapped_ti= map_index=map_index, ) + def test_should_respond_401(self, unauthenticated_test_client): + response = unauthenticated_test_client.get( + "/public/dags/~/dagRuns/~/taskInstances/~/xcomEntries", + params={}, + ) + assert response.status_code == 401 + + def test_should_respond_403(self, unauthorized_test_client): + response = unauthorized_test_client.get( + "/public/dags/~/dagRuns/~/taskInstances/~/xcomEntries", + params={}, + ) + assert response.status_code == 403 + class TestPaginationGetXComEntries(TestXComEndpoint): @pytest.mark.parametrize( @@ -583,6 +609,20 @@ def test_create_xcom_entry( assert current_data["run_id"] == dag_run_id assert current_data["map_index"] == request_body.map_index + def test_should_respond_401(self, unauthenticated_test_client): + response = unauthenticated_test_client.post( + "/public/dags/dag_id/dagRuns/dag_run_id/taskInstances/task_id/xcomEntries", + json={}, + ) + assert response.status_code == 401 + + def test_should_respond_403(self, unauthorized_test_client): + response = unauthorized_test_client.post( + "/public/dags/dag_id/dagRuns/dag_run_id/taskInstances/task_id/xcomEntries", + json={}, + ) + assert response.status_code == 403 + class TestPatchXComEntry(TestXComEndpoint): @pytest.mark.parametrize( @@ -623,3 +663,17 @@ def test_patch_xcom_entry(self, key, patch_body, expected_status, expected_detai assert response.json()["value"] == XCom.serialize_value(new_value) else: assert response.json()["detail"] == expected_detail + + def test_should_respond_401(self, unauthenticated_test_client): + response = unauthenticated_test_client.patch( + f"/public/dags/{TEST_DAG_ID}/dagRuns/run_id/taskInstances/TEST_TASK_ID/xcomEntries/key", + json={}, + ) + assert response.status_code == 401 + + def test_should_respond_403(self, unauthorized_test_client): + response = unauthorized_test_client.patch( + f"/public/dags/{TEST_DAG_ID}/dagRuns/run_id/taskInstances/TEST_TASK_ID/xcomEntries/key", + json={}, + ) + assert response.status_code == 403