From 347dbe92211db64cf862e314f1f69df47a373f8a Mon Sep 17 00:00:00 2001 From: David Blain Date: Wed, 27 Mar 2024 10:36:55 +0100 Subject: [PATCH 01/10] refactor: Moved placeholder property from OdbcHook class to parent DbApiHook class so that the same logic can also be used with the JdbcHook --- airflow/providers/common/sql/hooks/sql.py | 16 +++- airflow/providers/odbc/hooks/odbc.py | 16 ---- tests/providers/common/sql/hooks/test_sql.py | 35 ++++---- tests/providers/conftest.py | 30 +++++++ tests/providers/odbc/hooks/test_odbc.py | 92 +++++++------------- 5 files changed, 95 insertions(+), 94 deletions(-) diff --git a/airflow/providers/common/sql/hooks/sql.py b/airflow/providers/common/sql/hooks/sql.py index 3b9e6fba1e63a..5c389dd9a45d6 100644 --- a/airflow/providers/common/sql/hooks/sql.py +++ b/airflow/providers/common/sql/hooks/sql.py @@ -50,6 +50,7 @@ from airflow.providers.openlineage.sqlparser import DatabaseInfo T = TypeVar("T") +DEFAULT_SQL_PLACEHOLDERS = frozenset({"%s", "?"}) def return_single_query_results(sql: str | Iterable[str], return_last: bool, split_statements: bool): @@ -173,8 +174,19 @@ def __init__(self, *args, schema: str | None = None, log_sql: bool = True, **kwa ) @property - def placeholder(self) -> str: - return self._placeholder + def placeholder(self): + conn = self.get_connection(getattr(self, self.conn_name_attr)) + placeholder = conn.extra_dejson.get("placeholder") + if placeholder in DEFAULT_SQL_PLACEHOLDERS: + return placeholder + else: + self.log.warning( + "Placeholder defined in Connection '%s' is not listed in 'DEFAULT_SQL_PLACEHOLDERS' " + "and got ignored. 
Falling back to the default placeholder '%s'.", + placeholder, + self._placeholder, + ) + return self._placeholder def get_conn(self): """Return a connection object.""" diff --git a/airflow/providers/odbc/hooks/odbc.py b/airflow/providers/odbc/hooks/odbc.py index a14e64d6df658..8cf95bf095f31 100644 --- a/airflow/providers/odbc/hooks/odbc.py +++ b/airflow/providers/odbc/hooks/odbc.py @@ -27,8 +27,6 @@ from airflow.providers.common.sql.hooks.sql import DbApiHook from airflow.utils.helpers import merge_dicts -DEFAULT_ODBC_PLACEHOLDERS = frozenset({"%s", "?"}) - class OdbcHook(DbApiHook): """ @@ -202,20 +200,6 @@ def get_conn(self) -> Connection: conn = connect(self.odbc_connection_string, **self.connect_kwargs) return conn - @property - def placeholder(self): - placeholder = self.connection.extra_dejson.get("placeholder") - if placeholder in DEFAULT_ODBC_PLACEHOLDERS: - return placeholder - else: - self.log.warning( - "Placeholder defined in Connection '%s' is not listed in 'DEFAULT_ODBC_PLACEHOLDERS' " - "and got ignored. 
Falling back to the default placeholder '%s'.", - placeholder, - self._placeholder, - ) - return self._placeholder - def get_uri(self) -> str: """URI invoked in :meth:`~airflow.providers.common.sql.hooks.sql.DbApiHook.get_sqlalchemy_engine`.""" quoted_conn_str = quote_plus(self.odbc_connection_string) diff --git a/tests/providers/common/sql/hooks/test_sql.py b/tests/providers/common/sql/hooks/test_sql.py index eea912340241a..0277d1703d62c 100644 --- a/tests/providers/common/sql/hooks/test_sql.py +++ b/tests/providers/common/sql/hooks/test_sql.py @@ -25,6 +25,7 @@ from airflow.models import Connection from airflow.providers.common.sql.hooks.sql import DbApiHook, fetch_all_handler from airflow.utils.session import provide_session +from tests.providers.conftest import mock_hook TASK_ID = "sql-operator" HOST = "host" @@ -211,18 +212,22 @@ def mock_execute(*args, **kwargs): dbapi_hook.get_conn.return_value.cursor.return_value.close.assert_called() -@pytest.mark.db_test -@pytest.mark.parametrize( - "empty_statement", - [ - pytest.param([], id="Empty list"), - pytest.param("", id="Empty string"), - pytest.param("\n", id="Only EOL"), - ], -) -def test_no_query(empty_statement): - dbapi_hook = DBApiHookForTests() - dbapi_hook.get_conn.return_value.cursor.rowcount = 0 - with pytest.raises(ValueError) as err: - dbapi_hook.run(sql=empty_statement) - assert err.value.args[0] == "List of SQL statements is empty" +class TestDbApiHook: + @pytest.mark.db_test + @pytest.mark.parametrize( + "empty_statement", + [ + pytest.param([], id="Empty list"), + pytest.param("", id="Empty string"), + pytest.param("\n", id="Only EOL"), + ], + ) + def test_no_query(self, empty_statement): + dbapi_hook = mock_hook(DbApiHook) + with pytest.raises(ValueError) as err: + dbapi_hook.run(sql=empty_statement) + assert err.value.args[0] == "List of SQL statements is empty" + + def test_placeholder_config_from_extra(self): + dbapi_hook = mock_hook(DbApiHook, conn_params={"extra": {"placeholder": "?"}}) + 
assert dbapi_hook.placeholder == "?" diff --git a/tests/providers/conftest.py b/tests/providers/conftest.py index 7dd0079ae6c7a..6b65175aee5fb 100644 --- a/tests/providers/conftest.py +++ b/tests/providers/conftest.py @@ -21,6 +21,7 @@ import pytest +from airflow.hooks.base import BaseHook from airflow.models import Connection @@ -55,3 +56,32 @@ def hook_conn(request): ) yield m + + +def mock_hook(hook_class: type[BaseHook], hook_params=None, conn_params=None): + hook_params = hook_params or {} + conn_params = conn_params or {} + connection = Connection( + **{ + **dict(login="login", password="password", host="host", schema="schema", port=1234), + **conn_params, + } + ) + + cursor = mock.MagicMock( + rowcount=0, spec=["description", "rowcount", "execute", "fetchall", "fetchone", "close"] + ) + conn = mock.MagicMock() + conn.cursor.return_value = cursor + + class MockedHook(hook_class): + conn_name_attr = "test_conn_id" + + @classmethod + def get_connection(cls, conn_id: str): + return connection + + def get_conn(self): + return conn + + return MockedHook(**hook_params) diff --git a/tests/providers/odbc/hooks/test_odbc.py b/tests/providers/odbc/hooks/test_odbc.py index 64683d21e1971..20c44bd899903 100644 --- a/tests/providers/odbc/hooks/test_odbc.py +++ b/tests/providers/odbc/hooks/test_odbc.py @@ -27,8 +27,8 @@ import pyodbc import pytest -from airflow.models import Connection from airflow.providers.odbc.hooks.odbc import OdbcHook +from tests.providers.conftest import mock_hook @pytest.fixture @@ -77,38 +77,10 @@ class PyodbcRow(metaclass=PyodbcRowMeta): class TestOdbcHook: - def get_hook(self=None, hook_params=None, conn_params=None): - hook_params = hook_params or {} - conn_params = conn_params or {} - connection = Connection( - **{ - **dict(login="login", password="password", host="host", schema="schema", port=1234), - **conn_params, - } - ) - - cursor = mock.MagicMock( - rowcount=0, spec=["description", "rowcount", "execute", "fetchall", "fetchone", "close"] - 
) - conn = mock.MagicMock() - conn.cursor.return_value = cursor - - class UnitTestOdbcHook(OdbcHook): - conn_name_attr = "test_conn_id" - - @classmethod - def get_connection(cls, conn_id: str): - return connection - - def get_conn(self): - return conn - - return UnitTestOdbcHook(**hook_params) - def test_driver_in_extra_not_used(self): conn_params = dict(extra=json.dumps(dict(Driver="Fake Driver", Fake_Param="Fake Param"))) hook_params = {"driver": "ParamDriver"} - hook = self.get_hook(conn_params=conn_params, hook_params=hook_params) + hook = mock_hook(OdbcHook, conn_params=conn_params, hook_params=hook_params) expected = ( "DRIVER={ParamDriver};" "SERVER=host;" @@ -123,7 +95,7 @@ def test_driver_in_extra_not_used(self): def test_driver_in_both(self): conn_params = dict(extra=json.dumps(dict(Driver="Fake Driver", Fake_Param="Fake Param"))) hook_params = dict(driver="ParamDriver") - hook = self.get_hook(hook_params=hook_params, conn_params=conn_params) + hook = mock_hook(OdbcHook, conn_params=conn_params, hook_params=hook_params) expected = ( "DRIVER={ParamDriver};" "SERVER=host;" @@ -137,7 +109,7 @@ def test_driver_in_both(self): def test_dsn_in_extra(self): conn_params = dict(extra=json.dumps(dict(DSN="MyDSN", Fake_Param="Fake Param"))) - hook = self.get_hook(conn_params=conn_params) + hook = mock_hook(OdbcHook, conn_params=conn_params) expected = ( "DSN=MyDSN;SERVER=host;DATABASE=schema;UID=login;PWD=password;PORT=1234;Fake_Param=Fake Param;" ) @@ -146,7 +118,7 @@ def test_dsn_in_extra(self): def test_dsn_in_both(self): conn_params = dict(extra=json.dumps(dict(DSN="MyDSN", Fake_Param="Fake Param"))) hook_params = dict(driver="ParamDriver", dsn="ParamDSN") - hook = self.get_hook(hook_params=hook_params, conn_params=conn_params) + hook = mock_hook(OdbcHook, conn_params=conn_params, hook_params=hook_params) expected = ( "DRIVER={ParamDriver};" "DSN=ParamDSN;" @@ -162,7 +134,7 @@ def test_dsn_in_both(self): def test_get_uri(self): conn_params = 
dict(extra=json.dumps(dict(DSN="MyDSN", Fake_Param="Fake Param"))) hook_params = dict(dsn="ParamDSN") - hook = self.get_hook(hook_params=hook_params, conn_params=conn_params) + hook = mock_hook(OdbcHook, conn_params=conn_params, hook_params=hook_params) uri_param = quote_plus( "DSN=ParamDSN;SERVER=host;DATABASE=schema;UID=login;PWD=password;PORT=1234;Fake_Param=Fake Param;" ) @@ -170,7 +142,8 @@ def test_get_uri(self): assert hook.get_uri() == expected def test_connect_kwargs_from_hook(self): - hook = self.get_hook( + hook = mock_hook( + OdbcHook, hook_params=dict( connect_kwargs={ "attrs_before": { @@ -202,7 +175,7 @@ def test_connect_kwargs_from_conn(self): ) ) - hook = self.get_hook(conn_params=dict(extra=extra)) + hook = mock_hook(OdbcHook, conn_params=dict(extra=extra)) assert hook.connect_kwargs == { "attrs_before": {1: 2, pyodbc.SQL_TXN_ISOLATION: pyodbc.SQL_TXN_READ_UNCOMMITTED}, "readonly": True, @@ -219,7 +192,7 @@ def test_connect_kwargs_from_conn_and_hook(self): connect_kwargs={"attrs_before": {3: 5, pyodbc.SQL_TXN_ISOLATION: 0}, "readonly": True} ) - hook = self.get_hook(conn_params=dict(extra=conn_extra), hook_params=hook_params) + hook = mock_hook(OdbcHook, conn_params=dict(extra=conn_extra), hook_params=hook_params) assert hook.connect_kwargs == { "attrs_before": {1: 2, 3: 5, pyodbc.SQL_TXN_ISOLATION: 0}, "readonly": True, @@ -230,74 +203,71 @@ def test_connect_kwargs_bool_from_uri(self): Bools will be parsed from uri as strings """ conn_extra = json.dumps(dict(connect_kwargs={"ansi": True})) - hook = self.get_hook(conn_params=dict(extra=conn_extra)) + hook = mock_hook(OdbcHook, conn_params=dict(extra=conn_extra)) assert hook.connect_kwargs == { "ansi": True, } def test_driver(self): - hook = self.get_hook(hook_params=dict(driver="Blah driver")) + hook = mock_hook(OdbcHook, hook_params=dict(driver="Blah driver")) assert hook.driver == "Blah driver" - hook = self.get_hook(hook_params=dict(driver="{Blah driver}")) + hook = mock_hook(OdbcHook, 
hook_params=dict(driver="{Blah driver}")) assert hook.driver == "Blah driver" def test_driver_extra_raises_warning_by_default(self, caplog): with caplog.at_level(logging.WARNING, logger="airflow.providers.odbc.hooks.test_odbc"): - driver = self.get_hook(conn_params=dict(extra='{"driver": "Blah driver"}')).driver + driver = mock_hook(OdbcHook, conn_params=dict(extra='{"driver": "Blah driver"}')).driver assert "You have supplied 'driver' via connection extra but it will not be used" in caplog.text assert driver is None @mock.patch.dict("os.environ", {"AIRFLOW__PROVIDERS_ODBC__ALLOW_DRIVER_IN_EXTRA": "TRUE"}) def test_driver_extra_works_when_allow_driver_extra(self): - hook = self.get_hook( - conn_params=dict(extra='{"driver": "Blah driver"}'), hook_params=dict(allow_driver_extra=True) + hook = mock_hook( + OdbcHook, + conn_params=dict(extra='{"driver": "Blah driver"}'), + hook_params=dict(allow_driver_extra=True), ) assert hook.driver == "Blah driver" def test_default_driver_set(self): with patch.object(OdbcHook, "default_driver", "Blah driver"): - hook = self.get_hook() + hook = mock_hook(OdbcHook) assert hook.driver == "Blah driver" def test_driver_extra_works_when_default_driver_set(self): with patch.object(OdbcHook, "default_driver", "Blah driver"): - hook = self.get_hook() + hook = mock_hook(OdbcHook) assert hook.driver == "Blah driver" def test_driver_none_by_default(self): - hook = self.get_hook() + hook = mock_hook(OdbcHook) assert hook.driver is None def test_driver_extra_raises_warning_and_returns_default_driver_by_default(self, caplog): with patch.object(OdbcHook, "default_driver", "Blah driver"): with caplog.at_level(logging.WARNING, logger="airflow.providers.odbc.hooks.test_odbc"): - driver = self.get_hook(conn_params=dict(extra='{"driver": "Blah driver2"}')).driver + driver = mock_hook(OdbcHook, conn_params=dict(extra='{"driver": "Blah driver2"}')).driver assert "have supplied 'driver' via connection extra but it will not be used" in caplog.text assert 
driver == "Blah driver" - def test_placeholder_config_from_extra(self): - conn_params = dict(extra=json.dumps(dict(placeholder="?"))) - hook = self.get_hook(conn_params=conn_params) - assert hook.placeholder == "?" - def test_database(self): - hook = self.get_hook(hook_params=dict(database="abc")) + hook = mock_hook(OdbcHook, hook_params=dict(database="abc")) assert hook.database == "abc" - hook = self.get_hook() + hook = mock_hook(OdbcHook) assert hook.database == "schema" def test_sqlalchemy_scheme_default(self): - hook = self.get_hook() + hook = mock_hook(OdbcHook) uri = hook.get_uri() assert urlsplit(uri).scheme == "mssql+pyodbc" def test_sqlalchemy_scheme_param(self): - hook = self.get_hook(hook_params=dict(sqlalchemy_scheme="my-scheme")) + hook = mock_hook(OdbcHook, hook_params=dict(sqlalchemy_scheme="my-scheme")) uri = hook.get_uri() assert urlsplit(uri).scheme == "my-scheme" def test_sqlalchemy_scheme_extra(self): - hook = self.get_hook(conn_params=dict(extra=json.dumps(dict(sqlalchemy_scheme="my-scheme")))) + hook = mock_hook(OdbcHook, conn_params=dict(extra=json.dumps(dict(sqlalchemy_scheme="my-scheme")))) uri = hook.get_uri() assert urlsplit(uri).scheme == "my-scheme" @@ -323,7 +293,7 @@ def test_query_return_serializable_result_with_fetchall( def mock_handler(*_): return pyodbc_result - hook = self.get_hook() + hook = mock_hook(OdbcHook) with monkeypatch.context() as patcher: patcher.setattr("pyodbc.Row", pyodbc_instancecheck) result = hook.run("SQL", handler=mock_handler) @@ -340,7 +310,7 @@ def test_query_return_serializable_result_empty(self, pyodbc_row_mock, monkeypat def mock_handler(*_): return pyodbc_result - hook = self.get_hook() + hook = mock_hook(OdbcHook) with monkeypatch.context() as patcher: patcher.setattr("pyodbc.Row", pyodbc_instancecheck) result = hook.run("SQL", handler=mock_handler) @@ -359,13 +329,13 @@ def test_query_return_serializable_result_with_fetchone( def mock_handler(*_): return pyodbc_result - hook = self.get_hook() + hook 
= mock_hook(OdbcHook) with monkeypatch.context() as patcher: patcher.setattr("pyodbc.Row", pyodbc_instancecheck) result = hook.run("SQL", handler=mock_handler) assert hook_result == result def test_query_no_handler_return_none(self): - hook = self.get_hook() + hook = mock_hook(OdbcHook) result = hook.run("SQL") assert result is None From 22b8efee0122f0e8c1a48ea9ce744844b996a80d Mon Sep 17 00:00:00 2001 From: David Blain Date: Wed, 27 Mar 2024 11:15:43 +0100 Subject: [PATCH 02/10] refactor: Import BaseHook under type checking block --- tests/providers/conftest.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/providers/conftest.py b/tests/providers/conftest.py index 6b65175aee5fb..6a2b6e7f4be28 100644 --- a/tests/providers/conftest.py +++ b/tests/providers/conftest.py @@ -17,13 +17,16 @@ from __future__ import annotations +from typing import TYPE_CHECKING from unittest import mock import pytest -from airflow.hooks.base import BaseHook from airflow.models import Connection +if TYPE_CHECKING: + from airflow.hooks.base import BaseHook + @pytest.fixture def hook_conn(request): From faf832b8812d73968cf99656b0eed6ba983e2d4d Mon Sep 17 00:00:00 2001 From: David Blain Date: Wed, 27 Mar 2024 13:42:51 +0100 Subject: [PATCH 03/10] refactor: Marked test_placeholder_config_from_extra as a db test --- tests/providers/common/sql/hooks/test_sql.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/providers/common/sql/hooks/test_sql.py b/tests/providers/common/sql/hooks/test_sql.py index 0277d1703d62c..4d9e4d12e5e39 100644 --- a/tests/providers/common/sql/hooks/test_sql.py +++ b/tests/providers/common/sql/hooks/test_sql.py @@ -228,6 +228,7 @@ def test_no_query(self, empty_statement): dbapi_hook.run(sql=empty_statement) assert err.value.args[0] == "List of SQL statements is empty" + @pytest.mark.db_test def test_placeholder_config_from_extra(self): dbapi_hook = mock_hook(DbApiHook, conn_params={"extra": {"placeholder": "?"}}) assert 
dbapi_hook.placeholder == "?" From 1e6daa9afd60be9f7d504998485b247656c4e3ee Mon Sep 17 00:00:00 2001 From: David Blain Date: Wed, 27 Mar 2024 13:47:29 +0100 Subject: [PATCH 04/10] refactor: Moved mock_hook from conftest to test_utils module under common sql --- tests/providers/common/sql/hooks/test_sql.py | 2 +- tests/providers/common/sql/test_utils.py | 55 ++++++++++++++++++++ tests/providers/conftest.py | 33 ------------ tests/providers/odbc/hooks/test_odbc.py | 2 +- 4 files changed, 57 insertions(+), 35 deletions(-) create mode 100644 tests/providers/common/sql/test_utils.py diff --git a/tests/providers/common/sql/hooks/test_sql.py b/tests/providers/common/sql/hooks/test_sql.py index 4d9e4d12e5e39..b244f7ae1896a 100644 --- a/tests/providers/common/sql/hooks/test_sql.py +++ b/tests/providers/common/sql/hooks/test_sql.py @@ -25,7 +25,7 @@ from airflow.models import Connection from airflow.providers.common.sql.hooks.sql import DbApiHook, fetch_all_handler from airflow.utils.session import provide_session -from tests.providers.conftest import mock_hook +from tests.providers.common.sql.test_utils import mock_hook TASK_ID = "sql-operator" HOST = "host" diff --git a/tests/providers/common/sql/test_utils.py b/tests/providers/common/sql/test_utils.py new file mode 100644 index 0000000000000..c3bc4c356556f --- /dev/null +++ b/tests/providers/common/sql/test_utils.py @@ -0,0 +1,55 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +from typing import TYPE_CHECKING +from unittest import mock + +from airflow.models import Connection + +if TYPE_CHECKING: + from airflow.hooks.base import BaseHook + + +def mock_hook(hook_class: type[BaseHook], hook_params=None, conn_params=None): + hook_params = hook_params or {} + conn_params = conn_params or {} + connection = Connection( + **{ + **dict(login="login", password="password", host="host", schema="schema", port=1234), + **conn_params, + } + ) + + cursor = mock.MagicMock( + rowcount=0, spec=["description", "rowcount", "execute", "fetchall", "fetchone", "close"] + ) + conn = mock.MagicMock() + conn.cursor.return_value = cursor + + class MockedHook(hook_class): + conn_name_attr = "test_conn_id" + + @classmethod + def get_connection(cls, conn_id: str): + return connection + + def get_conn(self): + return conn + + return MockedHook(**hook_params) diff --git a/tests/providers/conftest.py b/tests/providers/conftest.py index 6a2b6e7f4be28..7dd0079ae6c7a 100644 --- a/tests/providers/conftest.py +++ b/tests/providers/conftest.py @@ -17,16 +17,12 @@ from __future__ import annotations -from typing import TYPE_CHECKING from unittest import mock import pytest from airflow.models import Connection -if TYPE_CHECKING: - from airflow.hooks.base import BaseHook - @pytest.fixture def hook_conn(request): @@ -59,32 +55,3 @@ def hook_conn(request): ) yield m - - -def mock_hook(hook_class: type[BaseHook], hook_params=None, conn_params=None): - hook_params = hook_params or {} - conn_params = conn_params or {} - connection = 
Connection( - **{ - **dict(login="login", password="password", host="host", schema="schema", port=1234), - **conn_params, - } - ) - - cursor = mock.MagicMock( - rowcount=0, spec=["description", "rowcount", "execute", "fetchall", "fetchone", "close"] - ) - conn = mock.MagicMock() - conn.cursor.return_value = cursor - - class MockedHook(hook_class): - conn_name_attr = "test_conn_id" - - @classmethod - def get_connection(cls, conn_id: str): - return connection - - def get_conn(self): - return conn - - return MockedHook(**hook_params) diff --git a/tests/providers/odbc/hooks/test_odbc.py b/tests/providers/odbc/hooks/test_odbc.py index 20c44bd899903..bddd2ffd996b6 100644 --- a/tests/providers/odbc/hooks/test_odbc.py +++ b/tests/providers/odbc/hooks/test_odbc.py @@ -28,7 +28,7 @@ import pytest from airflow.providers.odbc.hooks.odbc import OdbcHook -from tests.providers.conftest import mock_hook +from tests.providers.common.sql.test_utils import mock_hook @pytest.fixture From 7b3e53557e886333c43fff9e7575d6898baef8c1 Mon Sep 17 00:00:00 2001 From: David Blain Date: Wed, 27 Mar 2024 13:52:41 +0100 Subject: [PATCH 05/10] refactor: Removed unnecessary else statement in placeholder property --- airflow/providers/common/sql/hooks/sql.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/airflow/providers/common/sql/hooks/sql.py b/airflow/providers/common/sql/hooks/sql.py index 5c389dd9a45d6..d7b7f7f4ddea3 100644 --- a/airflow/providers/common/sql/hooks/sql.py +++ b/airflow/providers/common/sql/hooks/sql.py @@ -179,14 +179,13 @@ def placeholder(self): placeholder = conn.extra_dejson.get("placeholder") if placeholder in DEFAULT_SQL_PLACEHOLDERS: return placeholder - else: - self.log.warning( - "Placeholder defined in Connection '%s' is not listed in 'DEFAULT_SQL_PLACEHOLDERS' " - "and got ignored. 
Falling back to the default placeholder '%s'.", - placeholder, - self._placeholder, - ) - return self._placeholder + self.log.warning( + "Placeholder defined in Connection '%s' is not listed in 'DEFAULT_SQL_PLACEHOLDERS' " + "and got ignored. Falling back to the default placeholder '%s'.", + placeholder, + self._placeholder, + ) + return self._placeholder def get_conn(self): """Return a connection object.""" From e9ad88a2a5be563ae71304f9ac88f1e8cbc76986 Mon Sep 17 00:00:00 2001 From: David Blain Date: Wed, 27 Mar 2024 13:55:57 +0100 Subject: [PATCH 06/10] refactor: Default placeholder can be a class/static variable as its only purpose is to define a default SQL placeholder, the actual placeholder will always be retrieved through the property --- airflow/providers/common/sql/hooks/sql.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/airflow/providers/common/sql/hooks/sql.py b/airflow/providers/common/sql/hooks/sql.py index d7b7f7f4ddea3..01129c1f3430c 100644 --- a/airflow/providers/common/sql/hooks/sql.py +++ b/airflow/providers/common/sql/hooks/sql.py @@ -147,6 +147,8 @@ class DbApiHook(BaseHook): connector: ConnectorProtocol | None = None # Override with db-specific query to check connection _test_connection_sql = "select 1" + # Default SQL placeholder + _placeholder: str = "%s" def __init__(self, *args, schema: str | None = None, log_sql: bool = True, **kwargs): super().__init__() @@ -165,7 +167,6 @@ def __init__(self, *args, schema: str | None = None, log_sql: bool = True, **kwa self.__schema = schema self.log_sql = log_sql self.descriptions: list[Sequence[Sequence] | None] = [] - self._placeholder: str = "%s" self._insert_statement_format: str = kwargs.get( "insert_statement_format", "INSERT INTO {} {} VALUES ({})" ) From a79de0b4d66a17f36a5c9a0987f4717c00789ac1 Mon Sep 17 00:00:00 2001 From: David Blain Date: Fri, 29 Mar 2024 09:20:15 +0100 Subject: [PATCH 07/10] refactor: Updated sql test with changes from main --- 
tests/providers/common/sql/hooks/test_sql.py | 23 ++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/tests/providers/common/sql/hooks/test_sql.py b/tests/providers/common/sql/hooks/test_sql.py index b244f7ae1896a..b8c9c403704de 100644 --- a/tests/providers/common/sql/hooks/test_sql.py +++ b/tests/providers/common/sql/hooks/test_sql.py @@ -18,9 +18,12 @@ # from __future__ import annotations +import warnings +from typing import Any from unittest.mock import MagicMock import pytest +from airflow.exceptions import AirflowProviderDeprecationWarning from airflow.models import Connection from airflow.providers.common.sql.hooks.sql import DbApiHook, fetch_all_handler @@ -228,6 +231,26 @@ def test_no_query(self, empty_statement): dbapi_hook.run(sql=empty_statement) assert err.value.args[0] == "List of SQL statements is empty" + @pytest.mark.db_test + def test_make_common_data_structure_hook_has_deprecated_method(self): + """If hook implements ``_make_serializable`` warning should be raised on call.""" + + class DBApiHookForMakeSerializableTests(DBApiHookForTests): + def _make_serializable(self, result: Any): + return result + + hook = DBApiHookForMakeSerializableTests() + with pytest.warns(AirflowProviderDeprecationWarning, + match="`_make_serializable` method is deprecated"): + hook._make_common_data_structure(["foo", "bar", "baz"]) + + @pytest.mark.db_test + def test_make_common_data_structure_no_deprecated_method(self): + """If hook not implements ``_make_serializable`` there is no warning should be raised on call.""" + with warnings.catch_warnings(): + warnings.simplefilter("error", AirflowProviderDeprecationWarning) + DBApiHookForTests()._make_common_data_structure(["foo", "bar", "baz"]) + @pytest.mark.db_test def test_placeholder_config_from_extra(self): dbapi_hook = mock_hook(DbApiHook, conn_params={"extra": {"placeholder": "?"}}) From 20cae1353c8327e4e63d265c05b9cf5ca520cbd8 Mon Sep 17 00:00:00 2001 From: David Blain Date: Sat, 30 Mar 2024 09:37:32 
+0100 Subject: [PATCH 08/10] refactor: Reformatted test --- tests/providers/common/sql/hooks/test_sql.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/providers/common/sql/hooks/test_sql.py b/tests/providers/common/sql/hooks/test_sql.py index 9decf4edfbee0..4bd5bdcc549f3 100644 --- a/tests/providers/common/sql/hooks/test_sql.py +++ b/tests/providers/common/sql/hooks/test_sql.py @@ -235,8 +235,9 @@ def test_make_common_data_structure_hook_has_deprecated_method(self): """If hook implements ``_make_serializable`` warning should be raised on call.""" hook = mock_hook(DbApiHook) hook._make_serializable = lambda result: result - with pytest.warns(AirflowProviderDeprecationWarning, - match="`_make_serializable` method is deprecated"): + with pytest.warns( + AirflowProviderDeprecationWarning, match="`_make_serializable` method is deprecated" + ): hook._make_common_data_structure(["foo", "bar", "baz"]) @pytest.mark.db_test From aee9483ada525c65f172af8cdf841be8c5366c11 Mon Sep 17 00:00:00 2001 From: David Blain Date: Thu, 4 Apr 2024 11:16:09 +0200 Subject: [PATCH 09/10] Update airflow/providers/common/sql/hooks/sql.py Co-authored-by: Elad Kalif <45845474+eladkal@users.noreply.github.com> --- airflow/providers/common/sql/hooks/sql.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/airflow/providers/common/sql/hooks/sql.py b/airflow/providers/common/sql/hooks/sql.py index 71b399a0b1e05..cdc85fe1bc793 100644 --- a/airflow/providers/common/sql/hooks/sql.py +++ b/airflow/providers/common/sql/hooks/sql.py @@ -50,7 +50,7 @@ from airflow.providers.openlineage.sqlparser import DatabaseInfo T = TypeVar("T") -DEFAULT_SQL_PLACEHOLDERS = frozenset({"%s", "?"}) +SQL_PLACEHOLDERS = frozenset({"%s", "?"}) def return_single_query_results(sql: str | Iterable[str], return_last: bool, split_statements: bool): From ad3d8983861cdcdaa0fbe3d4c7b5ef1f32f7b2d5 Mon Sep 17 00:00:00 2001 From: David Blain Date: Thu, 4 Apr 2024 15:19:47 +0200 Subject: 
[PATCH 10/10] fix: Fixed name of constant SQL_PLACEHOLDERS being checked in placeholder property --- airflow/providers/common/sql/hooks/sql.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/airflow/providers/common/sql/hooks/sql.py b/airflow/providers/common/sql/hooks/sql.py index cdc85fe1bc793..3f324e4f697a3 100644 --- a/airflow/providers/common/sql/hooks/sql.py +++ b/airflow/providers/common/sql/hooks/sql.py @@ -178,7 +178,7 @@ def __init__(self, *args, schema: str | None = None, log_sql: bool = True, **kwa def placeholder(self): conn = self.get_connection(getattr(self, self.conn_name_attr)) placeholder = conn.extra_dejson.get("placeholder") - if placeholder in DEFAULT_SQL_PLACEHOLDERS: + if placeholder in SQL_PLACEHOLDERS: return placeholder self.log.warning( "Placeholder defined in Connection '%s' is not listed in 'DEFAULT_SQL_PLACEHOLDERS' "