diff --git a/generated/provider_dependencies.json.sha256sum b/generated/provider_dependencies.json.sha256sum new file mode 100644 index 0000000000000..943fd0fc93e4c --- /dev/null +++ b/generated/provider_dependencies.json.sha256sum @@ -0,0 +1 @@ +93831555f2a141e481c81c147142aeb860c34ea860163ca130d045e5ecd0a83b diff --git a/providers/snowflake/tests/unit/snowflake/decorators/test_snowpark.py b/providers/snowflake/tests/unit/snowflake/decorators/test_snowpark.py index f0b36e7b41f9b..b3af126763f8e 100644 --- a/providers/snowflake/tests/unit/snowflake/decorators/test_snowpark.py +++ b/providers/snowflake/tests/unit/snowflake/decorators/test_snowpark.py @@ -23,7 +23,7 @@ import pytest -pytest.importorskip("snowflake-snowpark-python") +pytest.importorskip("snowflake.snowpark") from airflow.providers.common.compat.sdk import task @@ -72,12 +72,12 @@ def func1(session: Session): def func2(): return number - with dag_maker(dag_id=TEST_DAG_ID): + with dag_maker(dag_id=TEST_DAG_ID) as dag: _ = [func1(), func2()] dr = dag_maker.create_dagrun() - for ti in dr.get_task_instances(): - ti.run() + for task_obj in dag.tasks: + ti = dag_maker.run_ti(task_obj.task_id, dr) assert ti.xcom_pull() == number assert mock_snowflake_hook.call_count == 2 assert mock_snowflake_hook.return_value.get_snowpark_session.call_count == 2 @@ -124,12 +124,12 @@ def func2(number: int, session: Session): def func3(number: int): return number - with dag_maker(dag_id=TEST_DAG_ID): + with dag_maker(dag_id=TEST_DAG_ID) as dag: _ = [func1(number=number), func2(number=number), func3(number=number)] dr = dag_maker.create_dagrun() - for ti in dr.get_task_instances(): - ti.run() + for task_obj in dag.tasks: + ti = dag_maker.run_ti(task_obj.task_id, dr) assert ti.xcom_pull() == number assert mock_snowflake_hook.call_count == 3 assert mock_snowflake_hook.return_value.get_snowpark_session.call_count == 3 @@ -148,12 +148,12 @@ def test_snowpark_decorator_no_return(self, mock_snowflake_hook, dag_maker): def func(session: Session): assert session == mock_snowflake_hook.return_value.get_snowpark_session.return_value - with dag_maker(dag_id=TEST_DAG_ID): + with dag_maker(dag_id=TEST_DAG_ID) as dag: func() dr = dag_maker.create_dagrun() - for ti in dr.get_task_instances(): - ti.run() + for task_obj in dag.tasks: + ti = dag_maker.run_ti(task_obj.task_id, dr) assert ti.xcom_pull() is None mock_snowflake_hook.assert_called_once() mock_snowflake_hook.return_value.get_snowpark_session.assert_called_once() @@ -182,12 +182,12 @@ def func(session: Session): assert run_task.xcom.get(key="b") == "2" assert run_task.xcom.get(key="return_value") == {"a": 1, "b": "2"} else: - with dag_maker(dag_id=TEST_DAG_ID): + with dag_maker(dag_id=TEST_DAG_ID) as dag: func() dr = dag_maker.create_dagrun() - ti = dr.get_task_instances()[0] - ti.run() + task_obj = dag.tasks[0] + ti = dag_maker.run_ti(task_obj.task_id, dr) assert ti.xcom_pull(key="a") == 1 assert ti.xcom_pull(key="b") == "2" assert ti.xcom_pull() == {"a": 1, "b": "2"} @@ -217,12 +217,12 @@ def update_query_tag(new_tags): def func(session: Session): return session.query_tag - with dag_maker(dag_id=TEST_DAG_ID): + with dag_maker(dag_id=TEST_DAG_ID) as dag: func() dr = dag_maker.create_dagrun() - ti = dr.get_task_instances()[0] - ti.run() + task_obj = dag.tasks[0] + ti = dag_maker.run_ti(task_obj.task_id, dr) query_tag = ti.xcom_pull() assert query_tag == { "dag_id": TEST_DAG_ID, diff --git a/providers/snowflake/tests/unit/snowflake/operators/test_snowpark.py b/providers/snowflake/tests/unit/snowflake/operators/test_snowpark.py index 062b3112b39de..38fcdffe13fc1 100644 --- a/providers/snowflake/tests/unit/snowflake/operators/test_snowpark.py +++ b/providers/snowflake/tests/unit/snowflake/operators/test_snowpark.py @@ -23,7 +23,7 @@ import pytest -pytest.importorskip("snowflake-snowpark-python") +pytest.importorskip("snowflake.snowpark") from airflow.providers.snowflake.operators.snowpark import SnowparkOperator from airflow.utils import timezone @@ -69,9 +69,10 @@ def func2(): ] dr = dag_maker.create_dagrun() - for ti in dr.get_task_instances(): - ti.run() + for task in dag.tasks: + ti = dag_maker.run_ti(task.task_id, dr) assert ti.xcom_pull() == number + assert mock_snowflake_hook.call_count == 2 assert mock_snowflake_hook.return_value.get_snowpark_session.call_count == 2 @@ -109,8 +110,8 @@ def func3(number: int): ] dr = dag_maker.create_dagrun() - for ti in dr.get_task_instances(): - ti.run() + for task in dag.tasks: + ti = dag_maker.run_ti(task.task_id, dr) assert ti.xcom_pull() == number assert mock_snowflake_hook.call_count == 3 assert mock_snowflake_hook.return_value.get_snowpark_session.call_count == 3 @@ -135,8 +136,8 @@ def func(session: Session): ) dr = dag_maker.create_dagrun() - for ti in dr.get_task_instances(): - ti.run() + for task in dag.tasks: + ti = dag_maker.run_ti(task.task_id, dr) assert ti.xcom_pull() is None mock_snowflake_hook.assert_called_once() mock_snowflake_hook.return_value.get_snowpark_session.assert_called_once() @@ -170,8 +171,8 @@ def func(session: Session): ) dr = dag_maker.create_dagrun() - ti = dr.get_task_instances()[0] - ti.run() + task = dag.tasks[0] + ti = dag_maker.run_ti(task.task_id, dr) query_tag = ti.xcom_pull() assert query_tag == { "dag_id": TEST_DAG_ID, @@ -179,3 +180,36 @@ def func(session: Session): "task_id": TASK_ID, "operator": "SnowparkOperator", } + + @mock.patch("airflow.providers.snowflake.operators.snowpark.SnowflakeHook") + def test_snowpark_operator_closes_session_on_exception( + self, + mock_snowflake_hook, + dag_maker, + ): + mock_session = mock_snowflake_hook.return_value.get_snowpark_session.return_value + + with dag_maker(dag_id=TEST_DAG_ID) as dag: + + def func(session: Session): + raise ValueError("boom") + + SnowparkOperator( + task_id=TASK_ID, + snowflake_conn_id=CONN_ID, + python_callable=func, + warehouse="test_warehouse", + database="test_database", + schema="test_schema", + role="test_role", + authenticator="externalbrowser", + dag=dag, + ) + + dr = dag_maker.create_dagrun() + task = dag.tasks[0] + + with pytest.raises(ValueError, match="boom"): + dag_maker.run_ti(task.task_id, dr) + + mock_session.close.assert_called_once() diff --git a/providers/snowflake/tests/unit/snowflake/utils/test_snowpark.py b/providers/snowflake/tests/unit/snowflake/utils/test_snowpark.py index 181d8ed8f3aa2..c8c9c7945ea39 100644 --- a/providers/snowflake/tests/unit/snowflake/utils/test_snowpark.py +++ b/providers/snowflake/tests/unit/snowflake/utils/test_snowpark.py @@ -18,7 +18,7 @@ import pytest -pytest.importorskip("snowflake-snowpark-python") +pytest.importorskip("snowflake.snowpark") from airflow.providers.snowflake.utils.snowpark import inject_session_into_op_kwargs