diff --git a/providers/amazon/src/airflow/providers/amazon/aws/operators/ssm.py b/providers/amazon/src/airflow/providers/amazon/aws/operators/ssm.py index 2221e3ea1322b..2f7b8d7b3bf33 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/operators/ssm.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/operators/ssm.py @@ -106,6 +106,19 @@ def __init__( def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> str: event = validate_execute_complete_event(event) + if event["status"] == "failed": + # Command failed - raise an exception with detailed information + command_status = event.get("command_status", "Unknown") + exit_code = event.get("exit_code", -1) + instance_id = event.get("instance_id", "Unknown") + message = event.get("message", "Command failed") + + error_msg = ( + f"SSM run command {event['command_id']} failed on instance {instance_id}. " + f"Status: {command_status}, Exit code: {exit_code}. {message}" + ) + raise RuntimeError(error_msg) + if event["status"] != "success": raise RuntimeError(f"Error while running run command: {event}") diff --git a/providers/amazon/src/airflow/providers/amazon/aws/triggers/ssm.py b/providers/amazon/src/airflow/providers/amazon/aws/triggers/ssm.py index 2c66c21c12a18..360a8f3121d6f 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/triggers/ssm.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/triggers/ssm.py @@ -64,7 +64,7 @@ def __init__( waiter_args={"CommandId": command_id}, failure_message="SSM run command failed.", status_message="Status of SSM run command is", - status_queries=["status"], + status_queries=["Status"], return_key="command_id", return_value=command_id, waiter_delay=waiter_delay, @@ -105,19 +105,26 @@ async def run(self) -> AsyncIterator[TriggerEvent]: self.status_queries, ) except Exception: - if not self.fail_on_nonzero_exit: - # Enhanced mode: check if it's an AWS-level failure - invocation = await client.get_command_invocation( - 
CommandId=self.command_id, InstanceId=instance_id - ) - status = invocation.get("Status", "") + # Get detailed invocation information to determine failure type + invocation = await client.get_command_invocation( + CommandId=self.command_id, InstanceId=instance_id + ) + status = invocation.get("Status", "") + response_code = invocation.get("ResponseCode", -1) - # AWS-level failures should always raise - if SsmHook.is_aws_level_failure(status): - raise + # AWS-level failures should always raise + if SsmHook.is_aws_level_failure(status): + self.log.error( + "AWS-level failure for command %s on instance %s: status=%s", + self.command_id, + instance_id, + status, + ) + raise - # Command-level failure - tolerate it in enhanced mode - response_code = invocation.get("ResponseCode", "unknown") + # Command-level failure (non-zero exit code) + if not self.fail_on_nonzero_exit: + # Enhanced mode: tolerate command-level failures self.log.info( "Command %s completed with status %s (exit code: %s) for instance %s. 
" "Continuing due to fail_on_nonzero_exit=False", @@ -128,7 +135,25 @@ async def run(self) -> AsyncIterator[TriggerEvent]: ) continue else: - # Traditional mode: all failures raise - raise + # Traditional mode: yield failure event instead of raising + # This allows the operator to handle the failure gracefully + self.log.warning( + "Command %s failed with status %s (exit code: %s) for instance %s", + self.command_id, + status, + response_code, + instance_id, + ) + yield TriggerEvent( + { + "status": "failed", + "message": f"Command failed with status {status} (exit code: {response_code})", + "command_status": status, + "exit_code": response_code, + "instance_id": instance_id, + self.return_key: self.return_value, + } + ) + return yield TriggerEvent({"status": "success", self.return_key: self.return_value}) diff --git a/providers/amazon/tests/unit/amazon/aws/operators/test_ssm.py b/providers/amazon/tests/unit/amazon/aws/operators/test_ssm.py index 0dacf4de95da8..500b3993e73e5 100644 --- a/providers/amazon/tests/unit/amazon/aws/operators/test_ssm.py +++ b/providers/amazon/tests/unit/amazon/aws/operators/test_ssm.py @@ -237,6 +237,61 @@ def test_operator_passes_parameter_to_trigger(self, mock_trigger_class, mock_con assert call_kwargs["command_id"] == COMMAND_ID assert call_kwargs["fail_on_nonzero_exit"] is False + def test_execute_complete_success(self): + """Test execute_complete with successful event.""" + event = {"status": "success", "command_id": COMMAND_ID} + + result = self.operator.execute_complete({}, event) + + assert result == COMMAND_ID + + def test_execute_complete_failure_event(self): + """Test execute_complete with failure event from trigger.""" + event = { + "status": "failed", + "command_id": COMMAND_ID, + "command_status": "Failed", + "exit_code": 1, + "instance_id": "i-123456", + "message": "Command failed with status Failed (exit code: 1)", + } + + with pytest.raises(RuntimeError) as exc_info: + self.operator.execute_complete({}, event) + 
error_msg = str(exc_info.value) + assert COMMAND_ID in error_msg + assert "Failed" in error_msg + assert "exit code: 1" in error_msg + assert "i-123456" in error_msg + + def test_execute_complete_failure_event_with_different_exit_codes(self): + """Test execute_complete properly reports different exit codes in error messages.""" + event = { + "status": "failed", + "command_id": COMMAND_ID, + "command_status": "Failed", + "exit_code": 42, + "instance_id": "i-789012", + "message": "Command failed with status Failed (exit code: 42)", + } + + with pytest.raises(RuntimeError) as exc_info: + self.operator.execute_complete({}, event) + + error_msg = str(exc_info.value) + assert "exit code: 42" in error_msg + assert "i-789012" in error_msg + + def test_execute_complete_unknown_status(self): + """Test execute_complete with unknown status.""" + event = {"status": "unknown", "command_id": COMMAND_ID} + + with pytest.raises(RuntimeError) as exc_info: + self.operator.execute_complete({}, event) + + assert "Error while running run command" in str(exc_info.value) + class TestSsmGetCommandInvocationOperator: @pytest.fixture diff --git a/providers/amazon/tests/unit/amazon/aws/triggers/test_ssm.py b/providers/amazon/tests/unit/amazon/aws/triggers/test_ssm.py index ad328e41e374e..f2aa4c05efe0c 100644 --- a/providers/amazon/tests/unit/amazon/aws/triggers/test_ssm.py +++ b/providers/amazon/tests/unit/amazon/aws/triggers/test_ssm.py @@ -105,10 +105,14 @@ async def test_run_success(self, mock_get_waiter, mock_get_async_conn, mock_ssm_ @mock.patch.object(SsmHook, "get_async_conn") @mock.patch.object(SsmHook, "get_waiter") async def test_run_fails(self, mock_get_waiter, mock_get_async_conn, mock_ssm_list_invocations): - mock_ssm_list_invocations(mock_get_async_conn) + mock_client = mock_ssm_list_invocations(mock_get_async_conn) mock_get_waiter().wait.side_effect = WaiterError( "name", "terminal failure", {"CommandInvocations": [{"CommandId": COMMAND_ID}]} ) + # Mock get_command_invocation 
to return AWS-level failure + mock_client.get_command_invocation = mock.AsyncMock( + return_value={"Status": "TimedOut", "ResponseCode": -1} + ) trigger = SsmRunCommandTrigger(command_id=COMMAND_ID) generator = trigger.run() @@ -124,8 +128,12 @@ async def test_trigger_default_fails_on_waiter_error( self, mock_get_waiter, mock_get_async_conn, mock_async_wait, mock_ssm_list_invocations ): """Test traditional mode (fail_on_nonzero_exit=True) raises exception on waiter error.""" - mock_ssm_list_invocations(mock_get_async_conn) + mock_client = mock_ssm_list_invocations(mock_get_async_conn) mock_async_wait.side_effect = AirflowException("SSM run command failed.") + # Mock get_command_invocation to return AWS-level failure + mock_client.get_command_invocation = mock.AsyncMock( + return_value={"Status": "Cancelled", "ResponseCode": -1} + ) trigger = SsmRunCommandTrigger(command_id=COMMAND_ID, fail_on_nonzero_exit=True) generator = trigger.run() @@ -204,3 +212,80 @@ def test_trigger_serialization_includes_parameter(self): classpath, kwargs = trigger_default.serialize() assert kwargs.get("fail_on_nonzero_exit") is True + + @pytest.mark.asyncio + @mock.patch("airflow.providers.amazon.aws.triggers.ssm.async_wait") + @mock.patch.object(SsmHook, "get_async_conn") + @mock.patch.object(SsmHook, "get_waiter") + async def test_trigger_yields_failure_event_instead_of_raising( + self, mock_get_waiter, mock_get_async_conn, mock_async_wait, mock_ssm_list_invocations + ): + """Test that trigger yields failure event instead of raising exception for command failures.""" + mock_client = mock_ssm_list_invocations(mock_get_async_conn) + # Mock async_wait to raise exception (simulating waiter failure) + mock_async_wait.side_effect = AirflowException("SSM run command failed.") + # Mock get_command_invocation to return Failed status with exit code 1 + mock_client.get_command_invocation = mock.AsyncMock( + return_value={"Status": "Failed", "ResponseCode": 1} + ) + + trigger = 
SsmRunCommandTrigger(command_id=COMMAND_ID, fail_on_nonzero_exit=True) + generator = trigger.run() + response = await generator.asend(None) + + # Should yield a failure event, not raise an exception + assert response.payload["status"] == "failed" + assert response.payload["command_id"] == COMMAND_ID + assert response.payload["exit_code"] == 1 + assert response.payload["command_status"] == "Failed" + assert response.payload["instance_id"] == INSTANCE_ID_1 + assert "Command failed with status Failed (exit code: 1)" in response.payload["message"] + + @pytest.mark.asyncio + @mock.patch("airflow.providers.amazon.aws.triggers.ssm.async_wait") + @mock.patch.object(SsmHook, "get_async_conn") + @mock.patch.object(SsmHook, "get_waiter") + async def test_trigger_yields_failure_event_for_different_exit_codes( + self, mock_get_waiter, mock_get_async_conn, mock_async_wait, mock_ssm_list_invocations + ): + """Test that trigger properly captures different exit codes in failure events.""" + mock_client = mock_ssm_list_invocations(mock_get_async_conn) + mock_async_wait.side_effect = AirflowException("SSM run command failed.") + + # Test with exit code 2 + mock_client.get_command_invocation = mock.AsyncMock( + return_value={"Status": "Failed", "ResponseCode": 2} + ) + + trigger = SsmRunCommandTrigger(command_id=COMMAND_ID, fail_on_nonzero_exit=True) + generator = trigger.run() + response = await generator.asend(None) + + assert response.payload["status"] == "failed" + assert response.payload["exit_code"] == 2 + assert response.payload["command_status"] == "Failed" + + @pytest.mark.asyncio + @mock.patch("airflow.providers.amazon.aws.triggers.ssm.async_wait") + @mock.patch.object(SsmHook, "get_async_conn") + @mock.patch.object(SsmHook, "get_waiter") + async def test_trigger_stops_after_first_instance_fails( + self, mock_get_waiter, mock_get_async_conn, mock_async_wait, mock_ssm_list_invocations + ): + """Test that trigger stops after first failure and yields failure 
event.""" + mock_client = mock_ssm_list_invocations(mock_get_async_conn) + # First instance fails + mock_async_wait.side_effect = AirflowException("SSM run command failed.") + mock_client.get_command_invocation = mock.AsyncMock( + return_value={"Status": "Failed", "ResponseCode": 1} + ) + + trigger = SsmRunCommandTrigger(command_id=COMMAND_ID, fail_on_nonzero_exit=True) + generator = trigger.run() + response = await generator.asend(None) + + # Should yield failure event for first instance + assert response.payload["status"] == "failed" + assert response.payload["instance_id"] == INSTANCE_ID_1 + # Should only call get_command_invocation once (for first instance) + assert mock_client.get_command_invocation.call_count == 1