From 378076845d5167270ebebed620b62530e14773fc Mon Sep 17 00:00:00 2001 From: Sameer Mesiah Date: Fri, 19 Jun 2026 21:04:24 +0100 Subject: [PATCH] Add a test verifying that SnowparkOperator closes the Snowpark session when the user callable raises an exception. As part of enabling the new coverage, migrate the Snowpark tests away from the removed TaskInstance.run() API and update Snowpark test dependency detection to use the importable Snowpark module. --- .../provider_dependencies.json.sha256sum | 1 + .../snowflake/decorators/test_snowpark.py | 32 ++++++------ .../unit/snowflake/operators/test_snowpark.py | 52 +++++++++++++++---- .../unit/snowflake/utils/test_snowpark.py | 2 +- 4 files changed, 61 insertions(+), 26 deletions(-) create mode 100644 generated/provider_dependencies.json.sha256sum diff --git a/generated/provider_dependencies.json.sha256sum b/generated/provider_dependencies.json.sha256sum new file mode 100644 index 0000000000000..943fd0fc93e4c --- /dev/null +++ b/generated/provider_dependencies.json.sha256sum @@ -0,0 +1 @@ +93831555f2a141e481c81c147142aeb860c34ea860163ca130d045e5ecd0a83b diff --git a/providers/snowflake/tests/unit/snowflake/decorators/test_snowpark.py b/providers/snowflake/tests/unit/snowflake/decorators/test_snowpark.py index f0b36e7b41f9b..b3af126763f8e 100644 --- a/providers/snowflake/tests/unit/snowflake/decorators/test_snowpark.py +++ b/providers/snowflake/tests/unit/snowflake/decorators/test_snowpark.py @@ -23,7 +23,7 @@ import pytest -pytest.importorskip("snowflake-snowpark-python") +pytest.importorskip("snowflake.snowpark") from airflow.providers.common.compat.sdk import task @@ -72,12 +72,12 @@ def func1(session: Session): def func2(): return number - with dag_maker(dag_id=TEST_DAG_ID): + with dag_maker(dag_id=TEST_DAG_ID) as dag: _ = [func1(), func2()] dr = dag_maker.create_dagrun() - for ti in dr.get_task_instances(): - ti.run() + for task_obj in dag.tasks: + ti = dag_maker.run_ti(task_obj.task_id, dr) assert ti.xcom_pull() == number assert mock_snowflake_hook.call_count == 2 assert mock_snowflake_hook.return_value.get_snowpark_session.call_count == 2 @@ -124,12 +124,12 @@ def func2(number: int, session: Session): def func3(number: int): return number - with dag_maker(dag_id=TEST_DAG_ID): + with dag_maker(dag_id=TEST_DAG_ID) as dag: _ = [func1(number=number), func2(number=number), func3(number=number)] dr = dag_maker.create_dagrun() - for ti in dr.get_task_instances(): - ti.run() + for task_obj in dag.tasks: + ti = dag_maker.run_ti(task_obj.task_id, dr) assert ti.xcom_pull() == number assert mock_snowflake_hook.call_count == 3 assert mock_snowflake_hook.return_value.get_snowpark_session.call_count == 3 @@ -148,12 +148,12 @@ def test_snowpark_decorator_no_return(self, mock_snowflake_hook, dag_maker): def func(session: Session): assert session == mock_snowflake_hook.return_value.get_snowpark_session.return_value - with dag_maker(dag_id=TEST_DAG_ID): + with dag_maker(dag_id=TEST_DAG_ID) as dag: func() dr = dag_maker.create_dagrun() - for ti in dr.get_task_instances(): - ti.run() + for task_obj in dag.tasks: + ti = dag_maker.run_ti(task_obj.task_id, dr) assert ti.xcom_pull() is None mock_snowflake_hook.assert_called_once() mock_snowflake_hook.return_value.get_snowpark_session.assert_called_once() @@ -182,12 +182,12 @@ def func(session: Session): assert run_task.xcom.get(key="b") == "2" assert run_task.xcom.get(key="return_value") == {"a": 1, "b": "2"} else: - with dag_maker(dag_id=TEST_DAG_ID): + with dag_maker(dag_id=TEST_DAG_ID) as dag: func() dr = dag_maker.create_dagrun() - ti = dr.get_task_instances()[0] - ti.run() + task_obj = dag.tasks[0] + ti = dag_maker.run_ti(task_obj.task_id, dr) assert ti.xcom_pull(key="a") == 1 assert ti.xcom_pull(key="b") == "2" assert ti.xcom_pull() == {"a": 1, "b": "2"} @@ -217,12 +217,12 @@ def update_query_tag(new_tags): def func(session: Session): return session.query_tag - with dag_maker(dag_id=TEST_DAG_ID): + with dag_maker(dag_id=TEST_DAG_ID) as dag: func() dr = dag_maker.create_dagrun() - ti = dr.get_task_instances()[0] - ti.run() + task_obj = dag.tasks[0] + ti = dag_maker.run_ti(task_obj.task_id, dr) query_tag = ti.xcom_pull() assert query_tag == { "dag_id": TEST_DAG_ID, diff --git a/providers/snowflake/tests/unit/snowflake/operators/test_snowpark.py b/providers/snowflake/tests/unit/snowflake/operators/test_snowpark.py index 062b3112b39de..38fcdffe13fc1 100644 --- a/providers/snowflake/tests/unit/snowflake/operators/test_snowpark.py +++ b/providers/snowflake/tests/unit/snowflake/operators/test_snowpark.py @@ -23,7 +23,7 @@ import pytest -pytest.importorskip("snowflake-snowpark-python") +pytest.importorskip("snowflake.snowpark") from airflow.providers.snowflake.operators.snowpark import SnowparkOperator from airflow.utils import timezone @@ -69,9 +69,10 @@ def func2(): ] dr = dag_maker.create_dagrun() - for ti in dr.get_task_instances(): - ti.run() + for task in dag.tasks: + ti = dag_maker.run_ti(task.task_id, dr) assert ti.xcom_pull() == number + assert mock_snowflake_hook.call_count == 2 assert mock_snowflake_hook.return_value.get_snowpark_session.call_count == 2 @@ -109,8 +110,8 @@ def func3(number: int): ] dr = dag_maker.create_dagrun() - for ti in dr.get_task_instances(): - ti.run() + for task in dag.tasks: + ti = dag_maker.run_ti(task.task_id, dr) assert ti.xcom_pull() == number assert mock_snowflake_hook.call_count == 3 assert mock_snowflake_hook.return_value.get_snowpark_session.call_count == 3 @@ -135,8 +136,8 @@ def func(session: Session): ) dr = dag_maker.create_dagrun() - for ti in dr.get_task_instances(): - ti.run() + for task in dag.tasks: + ti = dag_maker.run_ti(task.task_id, dr) assert ti.xcom_pull() is None mock_snowflake_hook.assert_called_once() mock_snowflake_hook.return_value.get_snowpark_session.assert_called_once() @@ -170,8 +171,8 @@ def func(session: Session): ) dr = dag_maker.create_dagrun() - ti = dr.get_task_instances()[0] - ti.run() + task = dag.tasks[0] + ti = dag_maker.run_ti(task.task_id, dr) query_tag = ti.xcom_pull() assert query_tag == { "dag_id": TEST_DAG_ID, @@ -179,3 +180,36 @@ def func(session: Session): "task_id": TASK_ID, "operator": "SnowparkOperator", } + + @mock.patch("airflow.providers.snowflake.operators.snowpark.SnowflakeHook") + def test_snowpark_operator_closes_session_on_exception( + self, + mock_snowflake_hook, + dag_maker, + ): + mock_session = mock_snowflake_hook.return_value.get_snowpark_session.return_value + + with dag_maker(dag_id=TEST_DAG_ID) as dag: + + def func(session: Session): + raise ValueError("boom") + + SnowparkOperator( + task_id=TASK_ID, + snowflake_conn_id=CONN_ID, + python_callable=func, + warehouse="test_warehouse", + database="test_database", + schema="test_schema", + role="test_role", + authenticator="externalbrowser", + dag=dag, + ) + + dr = dag_maker.create_dagrun() + task = dag.tasks[0] + + with pytest.raises(ValueError, match="boom"): + dag_maker.run_ti(task.task_id, dr) + + mock_session.close.assert_called_once() diff --git a/providers/snowflake/tests/unit/snowflake/utils/test_snowpark.py b/providers/snowflake/tests/unit/snowflake/utils/test_snowpark.py index 181d8ed8f3aa2..c8c9c7945ea39 100644 --- a/providers/snowflake/tests/unit/snowflake/utils/test_snowpark.py +++ b/providers/snowflake/tests/unit/snowflake/utils/test_snowpark.py @@ -18,7 +18,7 @@ import pytest -pytest.importorskip("snowflake-snowpark-python") +pytest.importorskip("snowflake.snowpark") from airflow.providers.snowflake.utils.snowpark import inject_session_into_op_kwargs