diff --git a/tests/deprecations_ignore.yml b/tests/deprecations_ignore.yml index cf3ad35905505..4d451038ca02b 100644 --- a/tests/deprecations_ignore.yml +++ b/tests/deprecations_ignore.yml @@ -235,8 +235,4 @@ - tests/providers/mysql/operators/test_mysql.py::TestMySql::test_mysql_operator_test_multi - tests/providers/mysql/operators/test_mysql.py::TestMySql::test_overwrite_schema - tests/providers/mysql/operators/test_mysql.py::test_execute_openlineage_events -- tests/providers/snowflake/operators/test_snowflake.py::TestSnowflakeOperator::test_snowflake_operator -- tests/providers/snowflake/operators/test_snowflake.py::TestSnowflakeOperatorForParams::test_overwrite_params -- tests/providers/snowflake/operators/test_snowflake_sql.py::test_exec_success -- tests/providers/snowflake/operators/test_snowflake_sql.py::test_execute_openlineage_events - tests/providers/trino/operators/test_trino.py::test_execute_openlineage_events diff --git a/tests/providers/snowflake/operators/test_snowflake.py b/tests/providers/snowflake/operators/test_snowflake.py index 7f429277b9268..e24e8ca9db6f7 100644 --- a/tests/providers/snowflake/operators/test_snowflake.py +++ b/tests/providers/snowflake/operators/test_snowflake.py @@ -26,10 +26,10 @@ from airflow.models.dag import DAG from airflow.models.dagrun import DagRun from airflow.models.taskinstance import TaskInstance +from airflow.providers.common.sql.operators.sql import SQLExecuteQueryOperator from airflow.providers.snowflake.operators.snowflake import ( SnowflakeCheckOperator, SnowflakeIntervalCheckOperator, - SnowflakeOperator, SnowflakeSqlApiOperator, SnowflakeValueCheckOperator, ) @@ -68,7 +68,9 @@ def test_snowflake_operator(self, mock_get_db_hook): dummy VARCHAR(50) ); """ - operator = SnowflakeOperator(task_id="basic_snowflake", sql=sql, dag=self.dag, do_xcom_push=False) + operator = SQLExecuteQueryOperator( + task_id="basic_snowflake", sql=sql, dag=self.dag, do_xcom_push=False, conn_id="snowflake_default" + ) # do_xcom_push=False because otherwise the XCom test will fail due to the mocking (it actually works) operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True) @@ -77,16 +79,18 @@ class TestSnowflakeOperatorForParams: @mock.patch("airflow.providers.common.sql.operators.sql.BaseSQLOperator.__init__") def test_overwrite_params(self, mock_base_op): sql = "Select * from test_table" - SnowflakeOperator( + SQLExecuteQueryOperator( sql=sql, task_id="snowflake_params_check", - snowflake_conn_id="snowflake_default", - warehouse="test_warehouse", - database="test_database", - role="test_role", - schema="test_schema", - authenticator="oath", - session_parameters={"QUERY_TAG": "test_tag"}, + conn_id="snowflake_default", + hook_params={ + "warehouse": "test_warehouse", + "database": "test_database", + "role": "test_role", + "schema": "test_schema", + "authenticator": "oath", + "session_parameters": {"QUERY_TAG": "test_tag"}, + }, ) mock_base_op.assert_called_once_with( conn_id="snowflake_default", diff --git a/tests/providers/snowflake/operators/test_snowflake_sql.py b/tests/providers/snowflake/operators/test_snowflake_sql.py index 87d77ca813df7..e3955d4540470 100644 --- a/tests/providers/snowflake/operators/test_snowflake_sql.py +++ b/tests/providers/snowflake/operators/test_snowflake_sql.py @@ -23,6 +23,8 @@ import pytest from _pytest.outcomes import importorskip +from airflow.providers.common.sql.operators.sql import SQLExecuteQueryOperator + databricks = importorskip("databricks") try: @@ -45,7 +47,6 @@ def Row(*args, **kwargs): from airflow.models.connection import Connection from airflow.providers.common.sql.hooks.sql import fetch_all_handler from airflow.providers.snowflake.hooks.snowflake import SnowflakeHook -from airflow.providers.snowflake.operators.snowflake import SnowflakeOperator DATE = "2017-04-20" TASK_ID = "databricks-sql-operator" @@ -61,7 +62,7 @@ def Row(*args, **kwargs): True, [Row(id=1, value="value1"), Row(id=2, value="value2")], [[("id",), ("value",)]], - ([{"id": 1, "value": "value1"}, {"id": 2, "value": "value2"}]), + ([Row(id=1, value="value1"), Row(id=2, value="value2")]), id="Scalar: Single SQL statement, return_last, split statement", ), pytest.param( @@ -70,7 +71,7 @@ def Row(*args, **kwargs): True, [Row(id=1, value="value1"), Row(id=2, value="value2")], [[("id",), ("value",)]], - ([{"id": 1, "value": "value1"}, {"id": 2, "value": "value2"}]), + ([Row(id=1, value="value1"), Row(id=2, value="value2")]), id="Scalar: Multiple SQL statements, return_last, split statement", ), pytest.param( @@ -79,7 +80,7 @@ def Row(*args, **kwargs): False, [Row(id=1, value="value1"), Row(id=2, value="value2")], [[("id",), ("value",)]], - ([{"id": 1, "value": "value1"}, {"id": 2, "value": "value2"}]), + ([Row(id=1, value="value1"), Row(id=2, value="value2")]), id="Scalar: Single SQL statements, no return_last (doesn't matter), no split statement", ), pytest.param( @@ -88,7 +89,7 @@ def Row(*args, **kwargs): False, [Row(id=1, value="value1"), Row(id=2, value="value2")], [[("id",), ("value",)]], - ([{"id": 1, "value": "value1"}, {"id": 2, "value": "value2"}]), + ([Row(id=1, value="value1"), Row(id=2, value="value2")]), id="Scalar: Single SQL statements, return_last (doesn't matter), no split statement", ), pytest.param( @@ -97,7 +98,7 @@ def Row(*args, **kwargs): False, [[Row(id=1, value="value1"), Row(id=2, value="value2")]], [[("id",), ("value",)]], - [([{"id": 1, "value": "value1"}, {"id": 2, "value": "value2"}])], + [([Row(id=1, value="value1"), Row(id=2, value="value2")])], id="Non-Scalar: Single SQL statements in list, no return_last, no split statement", ), pytest.param( @@ -110,8 +111,8 @@ def Row(*args, **kwargs): ], [[("id",), ("value",)], [("id2",), ("value2",)]], [ - ([{"id": 1, "value": "value1"}, {"id": 2, "value": "value2"}]), - ([{"id2": 1, "value2": "value1"}, {"id2": 2, "value2": "value2"}]), + ([Row(id=1, value="value1"), Row(id=2, value="value2")]), + ([Row(id2=1, value2="value1"), Row(id2=2, value2="value2")]), ], id="Non-Scalar: Multiple SQL statements in list, no return_last (no matter), no split statement", ), @@ -125,8 +126,8 @@ def Row(*args, **kwargs): ], [[("id",), ("value",)], [("id2",), ("value2",)]], [ - ([{"id": 1, "value": "value1"}, {"id": 2, "value": "value2"}]), - ([{"id2": 1, "value2": "value1"}, {"id2": 2, "value2": "value2"}]), + ([Row(id=1, value="value1"), Row(id=2, value="value2")]), + ([Row(id2=1, value2="value1"), Row(id2=2, value2="value2")]), ], id="Non-Scalar: Multiple SQL statements in list, return_last (no matter), no split statement", ), @@ -137,12 +138,13 @@ def test_exec_success(sql, return_last, split_statement, hook_results, hook_desc Test the execute function in case where SQL query was successful. """ with patch("airflow.providers.common.sql.operators.sql.BaseSQLOperator.get_db_hook") as get_db_hook_mock: - op = SnowflakeOperator( + op = SQLExecuteQueryOperator( task_id=TASK_ID, sql=sql, do_xcom_push=True, return_last=return_last, split_statements=split_statement, + conn_id="snowflake_default", ) dbapi_hook = MagicMock() get_db_hook_mock.return_value = dbapi_hook @@ -177,7 +179,7 @@ class SnowflakeHookForTests(SnowflakeHook): dbapi_hook = SnowflakeHookForTests() - class SnowflakeOperatorForTest(SnowflakeOperator): + class SnowflakeOperatorForTest(SQLExecuteQueryOperator): def get_db_hook(self): return dbapi_hook