Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions tests/deprecations_ignore.yml
Original file line number Diff line number Diff line change
Expand Up @@ -235,8 +235,4 @@
- tests/providers/mysql/operators/test_mysql.py::TestMySql::test_mysql_operator_test_multi
- tests/providers/mysql/operators/test_mysql.py::TestMySql::test_overwrite_schema
- tests/providers/mysql/operators/test_mysql.py::test_execute_openlineage_events
- tests/providers/snowflake/operators/test_snowflake.py::TestSnowflakeOperator::test_snowflake_operator
- tests/providers/snowflake/operators/test_snowflake.py::TestSnowflakeOperatorForParams::test_overwrite_params
- tests/providers/snowflake/operators/test_snowflake_sql.py::test_exec_success
- tests/providers/snowflake/operators/test_snowflake_sql.py::test_execute_openlineage_events
- tests/providers/trino/operators/test_trino.py::test_execute_openlineage_events
24 changes: 14 additions & 10 deletions tests/providers/snowflake/operators/test_snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,10 @@
from airflow.models.dag import DAG
from airflow.models.dagrun import DagRun
from airflow.models.taskinstance import TaskInstance
from airflow.providers.common.sql.operators.sql import SQLExecuteQueryOperator
from airflow.providers.snowflake.operators.snowflake import (
SnowflakeCheckOperator,
SnowflakeIntervalCheckOperator,
SnowflakeOperator,
SnowflakeSqlApiOperator,
SnowflakeValueCheckOperator,
)
Expand Down Expand Up @@ -68,7 +68,9 @@ def test_snowflake_operator(self, mock_get_db_hook):
dummy VARCHAR(50)
);
"""
operator = SnowflakeOperator(task_id="basic_snowflake", sql=sql, dag=self.dag, do_xcom_push=False)
operator = SQLExecuteQueryOperator(
task_id="basic_snowflake", sql=sql, dag=self.dag, do_xcom_push=False, conn_id="snowflake_default"
)
# do_xcom_push=False because otherwise the XCom test will fail due to the mocking (it actually works)
operator.run(start_date=DEFAULT_DATE, end_date=DEFAULT_DATE, ignore_ti_state=True)

Expand All @@ -77,16 +79,18 @@ class TestSnowflakeOperatorForParams:
@mock.patch("airflow.providers.common.sql.operators.sql.BaseSQLOperator.__init__")
def test_overwrite_params(self, mock_base_op):
sql = "Select * from test_table"
SnowflakeOperator(
SQLExecuteQueryOperator(
sql=sql,
task_id="snowflake_params_check",
snowflake_conn_id="snowflake_default",
warehouse="test_warehouse",
database="test_database",
role="test_role",
schema="test_schema",
authenticator="oath",
session_parameters={"QUERY_TAG": "test_tag"},
conn_id="snowflake_default",
hook_params={
"warehouse": "test_warehouse",
"database": "test_database",
"role": "test_role",
"schema": "test_schema",
"authenticator": "oath",
"session_parameters": {"QUERY_TAG": "test_tag"},
},
)
mock_base_op.assert_called_once_with(
conn_id="snowflake_default",
Expand Down
26 changes: 14 additions & 12 deletions tests/providers/snowflake/operators/test_snowflake_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
import pytest
from _pytest.outcomes import importorskip

from airflow.providers.common.sql.operators.sql import SQLExecuteQueryOperator

databricks = importorskip("databricks")

try:
Expand All @@ -45,7 +47,6 @@ def Row(*args, **kwargs):
from airflow.models.connection import Connection
from airflow.providers.common.sql.hooks.sql import fetch_all_handler
from airflow.providers.snowflake.hooks.snowflake import SnowflakeHook
from airflow.providers.snowflake.operators.snowflake import SnowflakeOperator

DATE = "2017-04-20"
TASK_ID = "databricks-sql-operator"
Expand All @@ -61,7 +62,7 @@ def Row(*args, **kwargs):
True,
[Row(id=1, value="value1"), Row(id=2, value="value2")],
[[("id",), ("value",)]],
([{"id": 1, "value": "value1"}, {"id": 2, "value": "value2"}]),
([Row(id=1, value="value1"), Row(id=2, value="value2")]),
id="Scalar: Single SQL statement, return_last, split statement",
),
pytest.param(
Expand All @@ -70,7 +71,7 @@ def Row(*args, **kwargs):
True,
[Row(id=1, value="value1"), Row(id=2, value="value2")],
[[("id",), ("value",)]],
([{"id": 1, "value": "value1"}, {"id": 2, "value": "value2"}]),
([Row(id=1, value="value1"), Row(id=2, value="value2")]),
id="Scalar: Multiple SQL statements, return_last, split statement",
),
pytest.param(
Expand All @@ -79,7 +80,7 @@ def Row(*args, **kwargs):
False,
[Row(id=1, value="value1"), Row(id=2, value="value2")],
[[("id",), ("value",)]],
([{"id": 1, "value": "value1"}, {"id": 2, "value": "value2"}]),
([Row(id=1, value="value1"), Row(id=2, value="value2")]),
id="Scalar: Single SQL statements, no return_last (doesn't matter), no split statement",
),
pytest.param(
Expand All @@ -88,7 +89,7 @@ def Row(*args, **kwargs):
False,
[Row(id=1, value="value1"), Row(id=2, value="value2")],
[[("id",), ("value",)]],
([{"id": 1, "value": "value1"}, {"id": 2, "value": "value2"}]),
([Row(id=1, value="value1"), Row(id=2, value="value2")]),
id="Scalar: Single SQL statements, return_last (doesn't matter), no split statement",
),
pytest.param(
Expand All @@ -97,7 +98,7 @@ def Row(*args, **kwargs):
False,
[[Row(id=1, value="value1"), Row(id=2, value="value2")]],
[[("id",), ("value",)]],
[([{"id": 1, "value": "value1"}, {"id": 2, "value": "value2"}])],
[([Row(id=1, value="value1"), Row(id=2, value="value2")])],
id="Non-Scalar: Single SQL statements in list, no return_last, no split statement",
),
pytest.param(
Expand All @@ -110,8 +111,8 @@ def Row(*args, **kwargs):
],
[[("id",), ("value",)], [("id2",), ("value2",)]],
[
([{"id": 1, "value": "value1"}, {"id": 2, "value": "value2"}]),
([{"id2": 1, "value2": "value1"}, {"id2": 2, "value2": "value2"}]),
([Row(id=1, value="value1"), Row(id=2, value="value2")]),
([Row(id2=1, value2="value1"), Row(id2=2, value2="value2")]),
],
id="Non-Scalar: Multiple SQL statements in list, no return_last (no matter), no split statement",
),
Expand All @@ -125,8 +126,8 @@ def Row(*args, **kwargs):
],
[[("id",), ("value",)], [("id2",), ("value2",)]],
[
([{"id": 1, "value": "value1"}, {"id": 2, "value": "value2"}]),
([{"id2": 1, "value2": "value1"}, {"id2": 2, "value2": "value2"}]),
([Row(id=1, value="value1"), Row(id=2, value="value2")]),
([Row(id2=1, value2="value1"), Row(id2=2, value2="value2")]),
],
id="Non-Scalar: Multiple SQL statements in list, return_last (no matter), no split statement",
),
Expand All @@ -137,12 +138,13 @@ def test_exec_success(sql, return_last, split_statement, hook_results, hook_desc
Test the execute function in case where SQL query was successful.
"""
with patch("airflow.providers.common.sql.operators.sql.BaseSQLOperator.get_db_hook") as get_db_hook_mock:
op = SnowflakeOperator(
op = SQLExecuteQueryOperator(
task_id=TASK_ID,
sql=sql,
do_xcom_push=True,
return_last=return_last,
split_statements=split_statement,
conn_id="snowflake_default",
)
dbapi_hook = MagicMock()
get_db_hook_mock.return_value = dbapi_hook
Expand Down Expand Up @@ -177,7 +179,7 @@ class SnowflakeHookForTests(SnowflakeHook):

dbapi_hook = SnowflakeHookForTests()

class SnowflakeOperatorForTest(SnowflakeOperator):
class SnowflakeOperatorForTest(SQLExecuteQueryOperator):
def get_db_hook(self):
return dbapi_hook

Expand Down