From 1dd84419cefeef6ae6c75afd9d653b8864156cd1 Mon Sep 17 00:00:00 2001 From: Nishieee Date: Thu, 14 May 2026 15:10:07 -0400 Subject: [PATCH 1/2] Preserve Databricks deferrable trigger caller across triggerer restarts --- .../databricks/triggers/databricks.py | 4 +++ .../databricks/triggers/test_databricks.py | 25 +++++++++++++++++++ 2 files changed, 29 insertions(+) diff --git a/providers/databricks/src/airflow/providers/databricks/triggers/databricks.py b/providers/databricks/src/airflow/providers/databricks/triggers/databricks.py index 67f8a392a0cce..916adc3dad338 100644 --- a/providers/databricks/src/airflow/providers/databricks/triggers/databricks.py +++ b/providers/databricks/src/airflow/providers/databricks/triggers/databricks.py @@ -61,6 +61,7 @@ def __init__( self.retry_args = retry_args self.run_page_url = run_page_url self.repair_run = repair_run + self.caller = caller self.hook = DatabricksHook( databricks_conn_id, retry_limit=self.retry_limit, @@ -81,6 +82,7 @@ def serialize(self) -> tuple[str, dict[str, Any]]: "retry_args": self.retry_args, "run_page_url": self.run_page_url, "repair_run": self.repair_run, + "caller": self.caller, }, ) @@ -153,6 +155,7 @@ def __init__( self.retry_limit = retry_limit self.retry_delay = retry_delay self.retry_args = retry_args + self.caller = caller self.hook = DatabricksHook( databricks_conn_id, retry_limit=self.retry_limit, @@ -172,6 +175,7 @@ def serialize(self) -> tuple[str, dict[str, Any]]: "retry_limit": self.retry_limit, "retry_delay": self.retry_delay, "retry_args": self.retry_args, + "caller": self.caller, }, ) diff --git a/providers/databricks/tests/unit/databricks/triggers/test_databricks.py b/providers/databricks/tests/unit/databricks/triggers/test_databricks.py index d2465077534bb..d8ee461b2dbb9 100644 --- a/providers/databricks/tests/unit/databricks/triggers/test_databricks.py +++ b/providers/databricks/tests/unit/databricks/triggers/test_databricks.py @@ -152,9 +152,21 @@ def test_serialize(self): "retry_args": None, "run_page_url": RUN_PAGE_URL, "repair_run": False, + "caller": "DatabricksExecutionTrigger", }, ) + def test_serialize_round_trip_caller(self): + caller = "DatabricksSubmitRunOperator" + trigger = DatabricksExecutionTrigger( + run_id=RUN_ID, + databricks_conn_id=DEFAULT_CONN_ID, + caller=caller, + ) + _, kwargs = trigger.serialize() + restored = DatabricksExecutionTrigger(**kwargs) + assert restored.caller == caller + @pytest.mark.asyncio @mock.patch("airflow.providers.databricks.hooks.databricks.DatabricksHook.a_get_run_output") @mock.patch("airflow.providers.databricks.hooks.databricks.DatabricksHook.a_get_run") @@ -299,9 +311,22 @@ def test_serialize(self): "retry_delay": 10, "retry_limit": 3, "retry_args": None, + "caller": "DatabricksSQLStatementExecutionTrigger", }, ) + def test_serialize_round_trip_caller(self): + caller = "DatabricksSqlOperator" + trigger = DatabricksSQLStatementExecutionTrigger( + statement_id=STATEMENT_ID, + databricks_conn_id=DEFAULT_CONN_ID, + end_time=self.end_time, + caller=caller, + ) + _, kwargs = trigger.serialize() + restored = DatabricksSQLStatementExecutionTrigger(**kwargs) + assert restored.caller == caller + @pytest.mark.asyncio @mock.patch("airflow.providers.databricks.hooks.databricks.DatabricksHook.a_get_sql_statement_state") async def test_run_return_success(self, mock_a_get_sql_statement_state): From 2d1f4e30b20b8d5aaf93a7a6cc612c509a876c46 Mon Sep 17 00:00:00 2001 From: Nishieee Date: Fri, 15 May 2026 10:49:13 -0400 Subject: [PATCH 2/2] Address review: add trigger docstrings and shared CALLER test constant --- .../providers/databricks/triggers/databricks.py | 3 +++ .../tests/unit/databricks/triggers/test_databricks.py | 11 +++++------ 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/providers/databricks/src/airflow/providers/databricks/triggers/databricks.py b/providers/databricks/src/airflow/providers/databricks/triggers/databricks.py index 916adc3dad338..25cade7fc80dc 100644 --- a/providers/databricks/src/airflow/providers/databricks/triggers/databricks.py +++ b/providers/databricks/src/airflow/providers/databricks/triggers/databricks.py @@ -38,6 +38,8 @@ class DatabricksExecutionTrigger(BaseTrigger): :param retry_delay: The number of seconds to wait between retries. :param retry_args: An optional dictionary with arguments passed to ``tenacity.Retrying`` class. :param run_page_url: The run page url. + :param repair_run: Repair the databricks run in case of failure. + :param caller: The name of the operator that is calling the hook. """ def __init__( @@ -134,6 +136,7 @@ class DatabricksSQLStatementExecutionTrigger(BaseTrigger): :param retry_limit: The number of times to retry the connection in case of service outages. :param retry_delay: The number of seconds to wait between retries. :param retry_args: An optional dictionary with arguments passed to ``tenacity.Retrying`` class. + :param caller: The name of the operator that is calling the hook. """ def __init__( diff --git a/providers/databricks/tests/unit/databricks/triggers/test_databricks.py b/providers/databricks/tests/unit/databricks/triggers/test_databricks.py index d8ee461b2dbb9..903173774b73a 100644 --- a/providers/databricks/tests/unit/databricks/triggers/test_databricks.py +++ b/providers/databricks/tests/unit/databricks/triggers/test_databricks.py @@ -50,6 +50,7 @@ TASK_RUN_ID3_KEY = "third_task" JOB_ID = 42 RUN_PAGE_URL = "https://XX.cloud.databricks.com/#jobs/1/runs/1" +CALLER = "DatabricksSubmitRunOperator" ERROR_MESSAGE = "error message from databricks API" GET_RUN_OUTPUT_RESPONSE = {"metadata": {}, "error": ERROR_MESSAGE, "notebook_output": {}} @@ -157,15 +158,14 @@ def test_serialize(self): ) def test_serialize_round_trip_caller(self): - caller = "DatabricksSubmitRunOperator" trigger = DatabricksExecutionTrigger( run_id=RUN_ID, databricks_conn_id=DEFAULT_CONN_ID, - caller=caller, + caller=CALLER, ) _, kwargs = trigger.serialize() restored = DatabricksExecutionTrigger(**kwargs) - assert restored.caller == caller + assert restored.caller == CALLER @pytest.mark.asyncio @mock.patch("airflow.providers.databricks.hooks.databricks.DatabricksHook.a_get_run_output") @@ -316,16 +316,15 @@ def test_serialize(self): ) def test_serialize_round_trip_caller(self): - caller = "DatabricksSqlOperator" trigger = DatabricksSQLStatementExecutionTrigger( statement_id=STATEMENT_ID, databricks_conn_id=DEFAULT_CONN_ID, end_time=self.end_time, - caller=caller, + caller=CALLER, ) _, kwargs = trigger.serialize() restored = DatabricksSQLStatementExecutionTrigger(**kwargs) - assert restored.caller == caller + assert restored.caller == CALLER @pytest.mark.asyncio @mock.patch("airflow.providers.databricks.hooks.databricks.DatabricksHook.a_get_sql_statement_state")