From 5aaa7b766119f1c77cb35401634ecdea1e9ad091 Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Wed, 12 Mar 2025 16:52:04 +0800 Subject: [PATCH] feat(AIP-84): add auth to /execution/task-instances --- .../execution_api/routes/task_instances.py | 10 +- .../routes/test_task_instances.py | 274 ++++++++++++++---- 2 files changed, 228 insertions(+), 56 deletions(-) diff --git a/airflow/api_fastapi/execution_api/routes/task_instances.py b/airflow/api_fastapi/execution_api/routes/task_instances.py index e6fcc78e46171..92c3feca085bc 100644 --- a/airflow/api_fastapi/execution_api/routes/task_instances.py +++ b/airflow/api_fastapi/execution_api/routes/task_instances.py @@ -22,14 +22,16 @@ from typing import Annotated from uuid import UUID -from fastapi import Body, HTTPException, status +from fastapi import Body, Depends, HTTPException, status from pydantic import JsonValue from sqlalchemy import func, update from sqlalchemy.exc import NoResultFound, SQLAlchemyError from sqlalchemy.sql import select +from airflow.api_fastapi.auth.managers.models.resource_details import DagAccessEntity from airflow.api_fastapi.common.db.common import SessionDep from airflow.api_fastapi.common.router import AirflowRouter +from airflow.api_fastapi.core_api.security import requires_access_dag from airflow.api_fastapi.execution_api.datamodels.taskinstance import ( PrevSuccessfulDagRunResponse, TIDeferredStatePayload, @@ -67,6 +69,7 @@ status.HTTP_422_UNPROCESSABLE_ENTITY: {"description": "Invalid payload for the state transition"}, }, response_model_exclude_unset=True, + dependencies=[Depends(requires_access_dag(method="POST", access_entity=DagAccessEntity.TASK_INSTANCE))], ) def ti_run( task_instance_id: UUID, ti_run_payload: Annotated[TIEnterRunningPayload, Body()], session: SessionDep @@ -247,6 +250,7 @@ def ti_run( status.HTTP_409_CONFLICT: {"description": "The TI is already in the requested state"}, status.HTTP_422_UNPROCESSABLE_ENTITY: {"description": "Invalid payload for the state transition"}, }, + dependencies=[Depends(requires_access_dag(method="POST", access_entity=DagAccessEntity.TASK_INSTANCE))], ) def ti_update_state( task_instance_id: UUID, @@ -401,6 +405,7 @@ def ti_update_state( }, status.HTTP_422_UNPROCESSABLE_ENTITY: {"description": "Invalid payload for the state transition"}, }, + dependencies=[Depends(requires_access_dag(method="PUT", access_entity=DagAccessEntity.TASK_INSTANCE))], ) def ti_heartbeat( task_instance_id: UUID, @@ -465,6 +470,7 @@ def ti_heartbeat( "description": "Invalid payload for the setting rendered task instance fields" }, }, + dependencies=[Depends(requires_access_dag(method="PUT", access_entity=DagAccessEntity.TASK_INSTANCE))], ) def ti_put_rtif( task_instance_id: UUID, @@ -489,6 +495,7 @@ def ti_put_rtif( responses={ status.HTTP_404_NOT_FOUND: {"description": "Task Instance or Dag Run not found"}, }, + dependencies=[Depends(requires_access_dag(method="GET", access_entity=DagAccessEntity.TASK_INSTANCE))], ) def get_previous_successful_dagrun( task_instance_id: UUID, session: SessionDep @@ -534,6 +541,7 @@ def get_previous_successful_dagrun( "description": "Invalid payload for requested runtime checks on the Task Instance." }, }, + dependencies=[Depends(requires_access_dag(method="POST", access_entity=DagAccessEntity.TASK_INSTANCE))], ) def ti_runtime_checks( task_instance_id: UUID, diff --git a/tests/api_fastapi/execution_api/routes/test_task_instances.py b/tests/api_fastapi/execution_api/routes/test_task_instances.py index 19701aa6a4b03..1260a34c47496 100644 --- a/tests/api_fastapi/execution_api/routes/test_task_instances.py +++ b/tests/api_fastapi/execution_api/routes/test_task_instances.py @@ -63,7 +63,7 @@ def setup_method(self): def teardown_method(self): clear_db_runs() - def test_ti_run_state_to_running(self, client, session, create_task_instance, time_machine): + def test_ti_run_state_to_running(self, test_client, session, create_task_instance, time_machine): """ Test that the Task Instance state is updated to running when the Task Instance is in a state where it can be marked as running. @@ -80,7 +80,7 @@ def test_ti_run_state_to_running(self, client, session, create_task_instance, ti ) session.commit() - response = client.patch( + response = test_client.patch( f"/execution/task-instances/{ti.id}/run", json={ "state": "running", @@ -124,7 +124,7 @@ def test_ti_run_state_to_running(self, client, session, create_task_instance, ti # Test that if we make a second request (simulating a network glitch so the client issues a retry) # that it is accepted and we get the same info back - response = client.patch( + response = test_client.patch( f"/execution/task-instances/{ti.id}/run", json={ "state": "running", @@ -138,7 +138,7 @@ def test_ti_run_state_to_running(self, client, session, create_task_instance, ti assert response.json() == response1 # But that for a different pid on the same host (etc) it fails - response = client.patch( + response = test_client.patch( f"/execution/task-instances/{ti.id}/run", json={ "state": "running", @@ -150,7 +150,7 @@ def test_ti_run_state_to_running(self, client, session, create_task_instance, ti ) assert response.status_code == 409 - def test_next_kwargs_still_encoded(self, client, session, create_task_instance, time_machine): + def test_next_kwargs_still_encoded(self, test_client, session, create_task_instance, time_machine): instant_str = "2024-09-30T12:00:00Z" instant = timezone.parse(instant_str) time_machine.move_to(instant, tick=False) @@ -168,7 +168,7 @@ def test_next_kwargs_still_encoded(self, client, session, create_task_instance, session.commit() - response = client.patch( + response = test_client.patch( f"/execution/task-instances/{ti.id}/run", json={ "state": "running", @@ -195,7 +195,7 @@ def test_next_kwargs_still_encoded(self, client, session, create_task_instance, @pytest.mark.parametrize("initial_ti_state", [s for s in TaskInstanceState if s != State.QUEUED]) def test_ti_run_state_conflict_if_not_queued( - self, client, session, create_task_instance, initial_ti_state + self, test_client, session, create_task_instance, initial_ti_state ): """ Test that a 409 error is returned when the Task Instance is not in a state where it can be marked as @@ -207,7 +207,7 @@ def test_ti_run_state_conflict_if_not_queued( ) session.commit() - response = client.patch( + response = test_client.patch( f"/execution/task-instances/{ti.id}/run", json={ "state": "running", @@ -229,7 +229,7 @@ def test_ti_run_state_conflict_if_not_queued( assert session.scalar(select(TaskInstance.state).where(TaskInstance.id == ti.id)) == initial_ti_state - def test_xcom_cleared_when_ti_runs(self, client, session, create_task_instance, time_machine): + def test_xcom_cleared_when_ti_runs(self, test_client, session, create_task_instance, time_machine): """ Test that the xcoms are cleared when the Task Instance state is updated to running. """ @@ -248,7 +248,7 @@ def test_xcom_cleared_when_ti_runs(self, client, session, create_task_instance, # Lets stage a xcom push ti.xcom_push(key="key", value="value") - response = client.patch( + response = test_client.patch( f"/execution/task-instances/{ti.id}/run", json={ "state": "running", @@ -263,7 +263,7 @@ def test_xcom_cleared_when_ti_runs(self, client, session, create_task_instance, # Once the task is running, we can check if xcom is cleared assert ti.xcom_pull(task_ids="test_xcom_cleared_when_ti_runs", key="key") is None - def test_xcom_not_cleared_for_deferral(self, client, session, create_task_instance, time_machine): + def test_xcom_not_cleared_for_deferral(self, test_client, session, create_task_instance, time_machine): """ Test that the xcoms are not cleared when the Task Instance state is re-running after deferral. """ @@ -288,7 +288,7 @@ def test_xcom_not_cleared_for_deferral(self, client, session, create_task_instan "next_method": "execute_callback", } - response = client.patch(f"/execution/task-instances/{ti.id}/state", json=payload) + response = test_client.patch(f"/execution/task-instances/{ti.id}/state", json=payload) assert response.status_code == 204 assert response.text == "" session.expire_all() @@ -301,7 +301,7 @@ def test_xcom_not_cleared_for_deferral(self, client, session, create_task_instan # Lets stage a xcom push ti.xcom_push(key="key", value="value") - response = client.patch( + response = test_client.patch( f"/execution/task-instances/{ti.id}/run", json={ "state": "running", @@ -315,6 +315,44 @@ def test_xcom_not_cleared_for_deferral(self, client, session, create_task_instan assert response.status_code == 200 assert ti.xcom_pull(task_ids="test_xcom_not_cleared_for_deferral", key="key") == "value" + def test_get_config_should_response_401(self, unauthenticated_test_client, session, create_task_instance): + ti = create_task_instance( + task_id="test_401", + state=State.QUEUED, + ) + session.commit() + + response = unauthenticated_test_client.patch( + f"/execution/task-instances/{ti.id}/run", + json={ + "state": "running", + "hostname": "random-hostname", + "unixname": "random-unixname", + "pid": 100, + "start_date": "2024-10-31T12:00:00Z", + }, + ) + assert response.status_code == 401 + + def test_get_config_should_response_403(self, unauthorized_test_client, session, create_task_instance): + ti = create_task_instance( + task_id="test_403", + state=State.QUEUED, + ) + session.commit() + + response = unauthorized_test_client.patch( + f"/execution/task-instances/{ti.id}/run", + json={ + "state": "running", + "hostname": "random-hostname", + "unixname": "random-unixname", + "pid": 100, + "start_date": "2024-10-31T12:00:00Z", + }, + ) + assert response.status_code == 403 + class TestTIUpdateState: def setup_method(self): @@ -332,7 +370,7 @@ def teardown_method(self): ], ) def test_ti_update_state_to_terminal( - self, client, session, create_task_instance, state, end_date, expected_state + self, test_client, session, create_task_instance, state, end_date, expected_state ): ti = create_task_instance( task_id="test_ti_update_state_to_terminal", @@ -341,7 +379,7 @@ def test_ti_update_state_to_terminal( ) session.commit() - response = client.patch( + response = test_client.patch( f"/execution/task-instances/{ti.id}/state", json={ "state": state, @@ -384,7 +422,7 @@ def test_ti_update_state_to_terminal( ], ) def test_ti_update_state_to_success_with_asset_events( - self, client, session, create_task_instance, task_outlets, outlet_events + self, test_client, session, create_task_instance, task_outlets, outlet_events ): clear_db_assets() clear_db_runs() @@ -412,7 +450,7 @@ def test_ti_update_state_to_success_with_asset_events( ) session.commit() - response = client.patch( + response = test_client.patch( f"/execution/task-instances/{ti.id}/state", json={ "state": "success", @@ -439,7 +477,7 @@ def test_ti_update_state_to_success_with_asset_events( if asset_type == "AssetAlias": assert event[0].source_aliases == [AssetAliasModel(name="example-alias")] - def test_ti_update_state_not_found(self, client, session): + def test_ti_update_state_not_found(self, test_client, session): """ Test that a 404 error is returned when the Task Instance does not exist. """ @@ -450,14 +488,14 @@ def test_ti_update_state_not_found(self, client, session): payload = {"state": "success", "end_date": "2024-10-31T12:30:00Z"} - response = client.patch(f"/execution/task-instances/{task_instance_id}/state", json=payload) + response = test_client.patch(f"/execution/task-instances/{task_instance_id}/state", json=payload) assert response.status_code == 404 assert response.json()["detail"] == { "reason": "not_found", "message": "Task Instance not found", } - def test_ti_update_state_running_errors(self, client, session, create_task_instance, time_machine): + def test_ti_update_state_running_errors(self, test_client, session, create_task_instance, time_machine): """ Test that a 422 error is returned when the Task Instance state is RUNNING in the payload. @@ -473,11 +511,11 @@ def test_ti_update_state_running_errors(self, client, session, create_task_insta session.commit() - response = client.patch(f"/execution/task-instances/{ti.id}/state", json={"state": "running"}) + response = test_client.patch(f"/execution/task-instances/{ti.id}/state", json={"state": "running"}) assert response.status_code == 422 - def test_ti_update_state_database_error(self, client, session, create_task_instance): + def test_ti_update_state_database_error(self, test_client, session, create_task_instance): """ Test that a database error is handled correctly when updating the Task Instance state. """ @@ -505,11 +543,11 @@ def test_ti_update_state_database_error(self, client, session, create_task_insta ) as mock_register_asset_changes_in_db, ): mock_register_asset_changes_in_db.return_value = None - response = client.patch(f"/execution/task-instances/{ti.id}/state", json=payload) + response = test_client.patch(f"/execution/task-instances/{ti.id}/state", json=payload) assert response.status_code == 500 assert response.json()["detail"] == "Database error occurred" - def test_ti_update_state_to_deferred(self, client, session, create_task_instance, time_machine): + def test_ti_update_state_to_deferred(self, test_client, session, create_task_instance, time_machine): """ Test that tests if the transition to deferred state is handled correctly. """ @@ -540,7 +578,7 @@ def test_ti_update_state_to_deferred(self, client, session, create_task_instance }, } - response = client.patch(f"/execution/task-instances/{ti.id}/state", json=payload) + response = test_client.patch(f"/execution/task-instances/{ti.id}/state", json=payload) assert response.status_code == 204 assert response.text == "" @@ -567,7 +605,7 @@ def test_ti_update_state_to_deferred(self, client, session, create_task_instance "moment": datetime(2024, 12, 18, 00, 00, 1, tzinfo=timezone.utc), } - def test_ti_update_state_to_reschedule(self, client, session, create_task_instance, time_machine): + def test_ti_update_state_to_reschedule(self, test_client, session, create_task_instance, time_machine): """ Test that tests if the transition to reschedule state is handled correctly. """ @@ -589,7 +627,7 @@ def test_ti_update_state_to_reschedule(self, client, session, create_task_instan "end_date": DEFAULT_END_DATE.isoformat(), } - response = client.patch(f"/execution/task-instances/{ti.id}/state", json=payload) + response = test_client.patch(f"/execution/task-instances/{ti.id}/state", json=payload) assert response.status_code == 204 assert response.text == "" @@ -624,7 +662,7 @@ def test_ti_update_state_to_reschedule(self, client, session, create_task_instan ], ) def test_ti_update_state_to_failed_with_retries( - self, client, session, create_task_instance, retries, expected_state + self, test_client, session, create_task_instance, retries, expected_state ): ti = create_task_instance( task_id="test_ti_update_state_to_retry", @@ -635,7 +673,7 @@ def test_ti_update_state_to_failed_with_retries( ti.max_tries = retries session.commit() - response = client.patch( + response = test_client.patch( f"/execution/task-instances/{ti.id}/state", json={ "state": TerminalTIState.FAILED, @@ -663,7 +701,7 @@ def test_ti_update_state_to_failed_with_retries( assert tih.try_id assert tih.try_id != ti.try_id - def test_ti_update_state_when_ti_is_restarting(self, client, session, create_task_instance): + def test_ti_update_state_when_ti_is_restarting(self, test_client, session, create_task_instance): ti = create_task_instance( task_id="test_ti_update_state_when_ti_is_restarting", state=State.RUNNING, @@ -672,7 +710,7 @@ def test_ti_update_state_when_ti_is_restarting(self, client, session, create_tas ti.state = State.RESTARTING session.commit() - response = client.patch( + response = test_client.patch( f"/execution/task-instances/{ti.id}/state", json={ "state": TerminalTIState.FAILED, @@ -692,7 +730,7 @@ def test_ti_update_state_when_ti_is_restarting(self, client, session, create_tas assert ti.next_kwargs is None def test_ti_update_state_when_ti_has_higher_tries_than_retries( - self, client, session, create_task_instance + self, test_client, session, create_task_instance ): ti = create_task_instance( task_id="test_ti_update_state_when_ti_has_higher_tries_than_retries", @@ -703,7 +741,7 @@ def test_ti_update_state_when_ti_has_higher_tries_than_retries( ti.try_number = 3 session.commit() - response = client.patch( + response = test_client.patch( f"/execution/task-instances/{ti.id}/state", json={ "state": TerminalTIState.FAILED, @@ -722,7 +760,9 @@ def test_ti_update_state_when_ti_has_higher_tries_than_retries( assert ti.next_method is None assert ti.next_kwargs is None - def test_ti_update_state_to_failed_without_retry_table_check(self, client, session, create_task_instance): + def test_ti_update_state_to_failed_without_retry_table_check( + self, test_client, session, create_task_instance + ): # we just want to fail in this test, no need to retry ti = create_task_instance( task_id="test_ti_update_state_to_failed_table_check", @@ -731,7 +771,7 @@ def test_ti_update_state_to_failed_without_retry_table_check(self, client, sessi ti.start_date = DEFAULT_START_DATE session.commit() - response = client.patch( + response = test_client.patch( f"/execution/task-instances/{ti.id}/state", json={ "state": TerminalTIState.FAIL_WITHOUT_RETRY, @@ -760,7 +800,7 @@ def test_ti_update_state_to_failed_without_retry_table_check(self, client, sessi ], ) def test_ti_runtime_checks_success( - self, client, session, create_task_instance, state, expected_status_code + self, test_client, session, create_task_instance, state, expected_status_code ): ti = create_task_instance( task_id="test_ti_runtime_checks", @@ -772,7 +812,7 @@ def test_ti_runtime_checks_success( "airflow.models.taskinstance.TaskInstance.validate_inlet_outlet_assets_activeness" ) as mock_validate_inlet_outlet_assets_activeness: mock_validate_inlet_outlet_assets_activeness.return_value = None - response = client.post( + response = test_client.post( f"/execution/task-instances/{ti.id}/runtime-checks", json={ "inlets": [], @@ -784,7 +824,7 @@ def test_ti_runtime_checks_success( session.expire_all() - def test_ti_runtime_checks_failure(self, client, session, create_task_instance): + def test_ti_runtime_checks_failure(self, test_client, session, create_task_instance): ti = create_task_instance( task_id="test_ti_runtime_checks_failure", state=State.RUNNING, @@ -797,7 +837,7 @@ def test_ti_runtime_checks_failure(self, client, session, create_task_instance): mock_validate_inlet_outlet_assets_activeness.side_effect = ( AirflowInactiveAssetInInletOrOutletException([AssetUniqueKey(name="abc", uri="something")]) ) - response = client.post( + response = test_client.post( f"/execution/task-instances/{ti.id}/runtime-checks", json={ "inlets": [], @@ -809,6 +849,38 @@ def test_ti_runtime_checks_failure(self, client, session, create_task_instance): session.expire_all() + def test_get_config_should_response_401(self, unauthenticated_test_client, session, create_task_instance): + ti = create_task_instance( + task_id="test_ti_runtime_checks_failure", + state=State.RUNNING, + ) + session.commit() + + response = unauthenticated_test_client.post( + f"/execution/task-instances/{ti.id}/runtime-checks", + json={ + "inlets": [], + "outlets": [], + }, + ) + assert response.status_code == 401 + + def test_get_config_should_response_403(self, unauthorized_test_client, session, create_task_instance): + ti = create_task_instance( + task_id="test_ti_runtime_checks_failure", + state=State.RUNNING, + ) + session.commit() + + response = unauthorized_test_client.post( + f"/execution/task-instances/{ti.id}/runtime-checks", + json={ + "inlets": [], + "outlets": [], + }, + ) + assert response.status_code == 403 + class TestTIHealthEndpoint: def setup_method(self): @@ -850,7 +922,7 @@ def teardown_method(self): ) def test_ti_heartbeat( self, - client, + test_client, session, create_task_instance, hostname, @@ -878,7 +950,7 @@ def test_ti_heartbeat( # Pre-condition: TI heartbeat is NONE assert ti.last_heartbeat_at is None - response = client.put( + response = test_client.put( f"/execution/task-instances/{task_instance_id}/heartbeat", json={"hostname": hostname, "pid": pid}, ) @@ -894,7 +966,7 @@ def test_ti_heartbeat( # If there's an error, check the error detail assert response.json()["detail"] == expected_detail - def test_ti_heartbeat_non_existent_task(self, client, session, create_task_instance): + def test_ti_heartbeat_non_existent_task(self, test_client, session, create_task_instance): """Test that a 404 error is returned when the Task Instance does not exist.""" task_instance_id = "0182e924-0f1e-77e6-ab50-e977118bc139" @@ -902,7 +974,7 @@ def test_ti_heartbeat_non_existent_task(self, client, session, create_task_insta # Pre-condition: the Task Instance does not exist assert session.get(TaskInstance, task_instance_id) is None - response = client.put( + response = test_client.put( f"/execution/task-instances/{task_instance_id}/heartbeat", json={"hostname": "random-hostname", "pid": 1547}, ) @@ -917,7 +989,7 @@ def test_ti_heartbeat_non_existent_task(self, client, session, create_task_insta "ti_state", [State.SUCCESS, State.FAILED], ) - def test_ti_heartbeat_when_task_not_running(self, client, session, create_task_instance, ti_state): + def test_ti_heartbeat_when_task_not_running(self, test_client, session, create_task_instance, ti_state): """Test that a 409 error is returned when the Task Instance is not in RUNNING state.""" ti = create_task_instance( @@ -930,7 +1002,7 @@ def test_ti_heartbeat_when_task_not_running(self, client, session, create_task_i session.commit() task_instance_id = ti.id - response = client.put( + response = test_client.put( f"/execution/task-instances/{task_instance_id}/heartbeat", json={"hostname": "random-hostname", "pid": 1547}, ) @@ -942,7 +1014,7 @@ def test_ti_heartbeat_when_task_not_running(self, client, session, create_task_i "current_state": ti_state, } - def test_ti_heartbeat_update(self, client, session, create_task_instance, time_machine): + def test_ti_heartbeat_update(self, test_client, session, create_task_instance, time_machine): """Test that the Task Instance heartbeat is updated when the Task Instance is running.""" # Set initial time for the test @@ -967,7 +1039,7 @@ def test_ti_heartbeat_update(self, client, session, create_task_instance, time_m new_time = time_now.add(minutes=10) time_machine.move_to(new_time, tick=False) - response = client.put( + response = test_client.put( f"/execution/task-instances/{task_instance_id}/heartbeat", json={"hostname": "random-hostname", "pid": 1547}, ) @@ -978,6 +1050,39 @@ def test_ti_heartbeat_update(self, client, session, create_task_instance, time_m session.refresh(ti) assert ti.last_heartbeat_at == time_now.add(minutes=10) + def test_get_config_should_response_401(self, unauthenticated_test_client, session, create_task_instance): + ti = create_task_instance( + task_id="test_ti_heartbeat_update", + state=State.RUNNING, + hostname="random-hostname", + pid=1547, + last_heartbeat_at=timezone.parse("2024-10-31T12:00:00Z"), + session=session, + ) + session.commit() + + response = unauthenticated_test_client.put( + f"/execution/task-instances/{ti.id}/heartbeat", + json={"hostname": "random-hostname", "pid": 1547}, + ) + assert response.status_code == 401 + + def test_get_config_should_response_403(self, unauthorized_test_client, session, create_task_instance): + ti = create_task_instance( + task_id="test_ti_heartbeat_update", + state=State.RUNNING, + hostname="random-hostname", + pid=1547, + last_heartbeat_at=timezone.parse("2024-10-31T12:00:00Z"), + session=session, + ) + session.commit() + response = unauthorized_test_client.put( + f"/execution/task-instances/{ti.id}/heartbeat", + json={"hostname": "random-hostname", "pid": 1547}, + ) + assert response.status_code == 403 + class TestTIPutRTIF: def setup_method(self): @@ -1006,14 +1111,14 @@ def teardown_method(self): }, ], ) - def test_ti_put_rtif_success(self, client, session, create_task_instance, payload): + def test_ti_put_rtif_success(self, test_client, session, create_task_instance, payload): ti = create_task_instance( task_id="test_ti_put_rtif_success", state=State.RUNNING, session=session, ) session.commit() - response = client.put(f"/execution/task-instances/{ti.id}/rtif", json=payload) + response = test_client.put(f"/execution/task-instances/{ti.id}/rtif", json=payload) assert response.status_code == 201 assert response.json() == {"message": "Rendered task instance fields successfully set"} @@ -1028,7 +1133,7 @@ def test_ti_put_rtif_success(self, client, session, create_task_instance, payloa assert rtifs[0].map_index == -1 assert rtifs[0].rendered_fields == payload - def test_ti_put_rtif_missing_ti(self, client, session, create_task_instance): + def test_ti_put_rtif_missing_ti(self, test_client, session, create_task_instance): create_task_instance( task_id="test_ti_put_rtif_missing_ti", state=State.RUNNING, @@ -1039,10 +1144,38 @@ def test_ti_put_rtif_missing_ti(self, client, session, create_task_instance): payload = {"field1": "rendered_value1", "field2": "rendered_value2"} random_id = uuid6.uuid7() - response = client.put(f"/execution/task-instances/{random_id}/rtif", json=payload) + response = test_client.put(f"/execution/task-instances/{random_id}/rtif", json=payload) assert response.status_code == 404 assert response.json()["detail"] == "Not Found" + def test_get_config_should_response_401(self, unauthenticated_test_client, create_task_instance, session): + create_task_instance( + task_id="test_ti_put_rtif_missing_ti", + state=State.RUNNING, + session=session, + ) + session.commit() + + response = unauthenticated_test_client.put( + f"/execution/task-instances/{uuid6.uuid7()}/rtif", + json={"field1": "rendered_value1", "field2": "rendered_value2"}, + ) + assert response.status_code == 401 + + def test_get_config_should_response_403(self, unauthorized_test_client, create_task_instance, session): + create_task_instance( + task_id="test_ti_put_rtif_missing_ti", + state=State.RUNNING, + session=session, + ) + session.commit() + + response = unauthorized_test_client.put( + f"/execution/task-instances/{uuid6.uuid7()}/rtif", + json={"field1": "rendered_value1", "field2": "rendered_value2"}, + ) + assert response.status_code == 403 + class TestPreviousDagRun: def setup_method(self): @@ -1051,7 +1184,7 @@ def setup_method(self): def teardown_method(self): clear_db_runs() - def test_ti_previous_dag_run(self, client, session, create_task_instance, dag_maker): + def test_ti_previous_dag_run(self, test_client, session, create_task_instance, dag_maker): """Test that the previous dag run is returned correctly for a task instance.""" ti = create_task_instance( task_id="test_ti_previous_dag_run", @@ -1085,7 +1218,7 @@ def test_ti_previous_dag_run(self, client, session, create_task_instance, dag_ma session.commit() - response = client.get(f"/execution/task-instances/{ti.id}/previous-successful-dagrun") + response = test_client.get(f"/execution/task-instances/{ti.id}/previous-successful-dagrun") assert response.status_code == 200 assert response.json() == { "data_interval_start": "2025-01-18T00:00:00Z", @@ -1094,12 +1227,12 @@ def test_ti_previous_dag_run(self, client, session, create_task_instance, dag_ma "end_date": "2025-01-18T01:00:00Z", } - def test_ti_previous_dag_run_not_found(self, client, session): + def test_ti_previous_dag_run_not_found(self, test_client, session): ti_id = "0182e924-0f1e-77e6-ab50-e977118bc139" assert session.get(TaskInstance, ti_id) is None - response = client.get(f"/execution/task-instances/{ti_id}/previous-successful-dagrun") + response = test_client.get(f"/execution/task-instances/{ti_id}/previous-successful-dagrun") assert response.status_code == 200 assert response.json() == { "data_interval_start": None, @@ -1107,3 +1240,34 @@ def test_ti_previous_dag_run_not_found(self, client, session): "start_date": None, "end_date": None, } + + def test_get_config_should_response_401(self, unauthenticated_test_client, create_task_instance, session): + ti = create_task_instance( + task_id="test_ti_previous_dag_run", + dag_id="test_dag", + logical_date=timezone.datetime(2025, 1, 19), + state=State.RUNNING, + start_date=timezone.datetime(2024, 1, 17), + session=session, + ) + session.commit() + + response = unauthenticated_test_client.get( + f"/execution/task-instances/{ti.id}/previous-successful-dagrun" + ) + assert response.status_code == 401 + + def test_get_config_should_response_403(self, unauthorized_test_client, create_task_instance, session): + ti = create_task_instance( + task_id="test_ti_previous_dag_run", + dag_id="test_dag", + logical_date=timezone.datetime(2025, 1, 19), + state=State.RUNNING, + start_date=timezone.datetime(2024, 1, 17), + session=session, + ) + session.commit() + response = unauthorized_test_client.get( + f"/execution/task-instances/{ti.id}/previous-successful-dagrun" + ) + assert response.status_code == 403