From 513f4960e410ff821b81cedaa82097ba719dd902 Mon Sep 17 00:00:00 2001 From: David Blain Date: Thu, 8 Feb 2024 13:55:50 +0100 Subject: [PATCH 1/4] refactor: Added executemany parameter to insert_rows method so the fast executemany method on the cursor can be used to achieve better performance when inserting in bulk. Also check if dialect is SAP Hana in _generate_insert_sql method so the UPSERT statement can be used as REPLACE INTO doesn't exist on SAP Hana. --- airflow/providers/common/sql/hooks/sql.py | 65 ++++++++++------ .../providers/common/sql/hooks/test_dbapi.py | 76 +++++++++++++++---- 2 files changed, 102 insertions(+), 39 deletions(-) diff --git a/airflow/providers/common/sql/hooks/sql.py b/airflow/providers/common/sql/hooks/sql.py index 5ffc8f89b6eaa..a8cb5dfc44bdd 100644 --- a/airflow/providers/common/sql/hooks/sql.py +++ b/airflow/providers/common/sql/hooks/sql.py @@ -37,6 +37,7 @@ import sqlparse from deprecated import deprecated +from more_itertools import chunked from sqlalchemy import create_engine from airflow.exceptions import AirflowException, AirflowProviderDeprecationWarning @@ -48,7 +49,6 @@ from airflow.providers.openlineage.extractors import OperatorLineage from airflow.providers.openlineage.sqlparser import DatabaseInfo - T = TypeVar("T") @@ -486,17 +486,17 @@ def _generate_insert_sql(self, table, values, target_fields, replace, **kwargs) """ Generate the INSERT SQL statement. - The REPLACE variant is specific to MySQL syntax. + The REPLACE variant is specific to MySQL syntax, the UPSERT variant is specific to SAP Hana syntax :param table: Name of the target table :param values: The row to insert into the table :param target_fields: The names of the columns to fill in the table - :param replace: Whether to replace instead of insert - :return: The generated INSERT or REPLACE SQL statement + :param replace: Whether to replace/upsert instead of insert + :return: The generated INSERT or REPLACE/UPSERT SQL statement """ placeholders = [ - self.placeholder, - ] * len(values) + self.placeholder, + ] * len(values) if target_fields: target_fields = ", ".join(target_fields) @@ -504,14 +504,18 @@ def _generate_insert_sql(self, table, values, target_fields, replace, **kwargs) else: target_fields = "" + sql = f"{table} {target_fields} VALUES ({','.join(placeholders)})" + if not replace: - sql = "INSERT INTO " - else: - sql = "REPLACE INTO " - sql += f"{table} {target_fields} VALUES ({','.join(placeholders)})" - return sql + return f"INSERT INTO {sql}" + + if self.get_sqlalchemy_engine().dialect.name == "hana": + return f"UPSERT {sql} WITH PRIMARY KEY" - def insert_rows(self, table, rows, target_fields=None, commit_every=1000, replace=False, **kwargs): + return f"REPLACE INTO {sql}" + + def insert_rows(self, table, rows, target_fields=None, commit_every=1000, replace=False, + executemany=False, **kwargs): """Insert a collection of tuples into a table. Rows are inserted in chunks, each chunk (of size ``commit_every``) is @@ -523,6 +527,8 @@ def insert_rows(self, table, rows, target_fields=None, commit_every=1000, replac :param commit_every: The maximum number of rows to insert in one transaction. Set to 0 to insert all rows in one transaction. :param replace: Whether to replace instead of insert + :param executemany: Insert all rows at once in chunks defined by the commit_every parameter, only + works if all rows have same number of column names but leads to better performance """ i = 0 with closing(self.get_conn()) as conn: @@ -532,19 +538,30 @@ def insert_rows(self, table, rows, target_fields=None, commit_every=1000, replac conn.commit() with closing(conn.cursor()) as cur: - for i, row in enumerate(rows, 1): - lst = [] - for cell in row: - lst.append(self._serialize_cell(cell, conn)) - values = tuple(lst) - sql = self._generate_insert_sql(table, values, target_fields, replace, **kwargs) - self.log.debug("Generated sql: %s", sql) - cur.execute(sql, values) - if commit_every and i % commit_every == 0: + if executemany: + for chunked_rows in chunked(rows, commit_every): + values = list(map(lambda row: tuple(map(lambda cell: self._serialize_cell(cell, conn), row)), chunked_rows)) + sql = self._generate_insert_sql(table, values[0], target_fields, replace, **kwargs) + self.log.debug("Generated sql: %s", sql) + cur.fast_executemany = True + cur.executemany(sql, values) conn.commit() - self.log.info("Loaded %s rows into %s so far", i, table) - - conn.commit() + self.log.info("Loaded %s rows into %s so far", len(chunked_rows), table) + else: + for i, row in enumerate(rows, 1): + lst = [] + for cell in row: + lst.append(self._serialize_cell(cell, conn)) + values = tuple(lst) + sql = self._generate_insert_sql(table, values, target_fields, replace, **kwargs) + self.log.debug("Generated sql: %s", sql) + cur.execute(sql, values) + if commit_every and i % commit_every == 0: + conn.commit() + self.log.info("Loaded %s rows into %s so far", i, table) + + if not executemany: + conn.commit() self.log.info("Done loading. Loaded a total of %s rows into %s", i, table) @staticmethod diff --git a/tests/providers/common/sql/hooks/test_dbapi.py b/tests/providers/common/sql/hooks/test_dbapi.py index 45b40d8e48fdb..0e822be3b5f74 100644 --- a/tests/providers/common/sql/hooks/test_dbapi.py +++ b/tests/providers/common/sql/hooks/test_dbapi.py @@ -19,8 +19,10 @@ import json from unittest import mock +from unittest.mock import patch import pytest +from sqlalchemy.engine import Engine, Dialect from airflow.hooks.base import BaseHook from airflow.models import Connection @@ -38,22 +40,36 @@ class NonDbApiHook(BaseHook): class TestDbApiHook: def setup_method(self): self.cur = mock.MagicMock( - rowcount=0, spec=["description", "rowcount", "execute", "fetchall", "fetchone", "close"] + rowcount=0, spec=["description", "rowcount", "execute", "executemany", "fetchall", "fetchone", "close"] ) self.conn = mock.MagicMock() self.conn.cursor.return_value = self.cur + self.conn.schema.return_value = "test_schema" conn = self.conn - class UnitTestDbApiHook(DbApiHook): + class DbApiHookMock(DbApiHook): conn_name_attr = "test_conn_id" log = mock.MagicMock() + @classmethod + def get_connection(cls, conn_id: str) -> Connection: + return conn + def get_conn(self): return conn - self.db_hook = UnitTestDbApiHook() - self.db_hook_no_log_sql = UnitTestDbApiHook(log_sql=False) - self.db_hook_schema_override = UnitTestDbApiHook(schema="schema-override") + self.db_hook = DbApiHookMock() + self.db_hook_no_log_sql = DbApiHookMock(log_sql=False) + self.db_hook_schema_override = DbApiHookMock(schema="schema-override") + + @staticmethod + def create_engine(dialect_name: str = "sqlite") -> mock.MagicMock: + # Mocking create_engine to return a mock engine + mock_dialect = mock.MagicMock(spec=Dialect) + mock_dialect.name = dialect_name + mock_engine = mock.MagicMock(spec=Engine) + mock_engine.dialect = mock_dialect + return mock.MagicMock(return_value=mock_engine) def test_get_records(self): statement = "SQL" @@ -108,20 +124,22 @@ def test_insert_rows(self): self.cur.execute.assert_any_call(sql, row) def test_insert_rows_replace(self): - table = "table" - rows = [("hello",), ("world",)] + # Patching create_engine in the module where it is used + with patch(f"{DbApiHook.__module__}.create_engine", self.create_engine()): + table = "table" + rows = [("hello",), ("world",)] - self.db_hook.insert_rows(table, rows, replace=True) + self.db_hook.insert_rows(table, rows, replace=True) - assert self.conn.close.call_count == 1 - assert self.cur.close.call_count == 1 + assert self.conn.close.call_count == 1 + assert self.cur.close.call_count == 1 - commit_count = 2 # The first and last commit - assert commit_count == self.conn.commit.call_count + commit_count = 2 # The first and last commit + assert commit_count == self.conn.commit.call_count - sql = f"REPLACE INTO {table} VALUES (%s)" - for row in rows: - self.cur.execute.assert_any_call(sql, row) + sql = f"REPLACE INTO {table} VALUES (%s)" + for row in rows: + self.cur.execute.assert_any_call(sql, row) def test_insert_rows_target_fields(self): table = "table" @@ -157,6 +175,34 @@ def test_insert_rows_commit_every(self): for row in rows: self.cur.execute.assert_any_call(sql, row) + def test_insert_rows_executemany(self): + table = "table" + rows = [("hello",), ("world",)] + + self.db_hook.insert_rows(table, rows, executemany=True) + + assert self.conn.close.call_count == 1 + assert self.cur.close.call_count == 1 + assert self.conn.commit.call_count == 2 + + sql = f"INSERT INTO {table} VALUES (%s)" + self.cur.executemany.assert_any_call(sql, rows) + + def test_insert_rows_replace_executemany_hana_dialect(self): + # Patching create_engine in the module where it is used + with patch(f"{DbApiHook.__module__}.create_engine", self.create_engine(dialect_name="hana")): + table = "table" + rows = [("hello",), ("world",)] + + self.db_hook.insert_rows(table, rows, replace=True, executemany=True) + + assert self.conn.close.call_count == 1 + assert self.cur.close.call_count == 1 + assert self.conn.commit.call_count == 2 + + sql = f"UPSERT {table} VALUES (%s) WITH PRIMARY KEY" + self.cur.executemany.assert_any_call(sql, rows) + def test_get_uri_schema_not_none(self): self.db_hook.get_connection = mock.MagicMock( return_value=Connection( From 8f8e6d11ba5a1e12e915f220bb6e5f647958c7aa Mon Sep 17 00:00:00 2001 From: David Blain Date: Thu, 8 Feb 2024 16:13:29 +0100 Subject: [PATCH 2/4] refactor: Reformatted code to be static check compliant --- airflow/providers/common/sql/hooks/sql.py | 16 +++++++++++----- tests/providers/common/sql/hooks/test_dbapi.py | 5 +++-- 2 files changed, 14 insertions(+), 7 deletions(-) diff --git a/airflow/providers/common/sql/hooks/sql.py b/airflow/providers/common/sql/hooks/sql.py index a8cb5dfc44bdd..917a659f4c96c 100644 --- a/airflow/providers/common/sql/hooks/sql.py +++ b/airflow/providers/common/sql/hooks/sql.py @@ -495,8 +495,8 @@ def _generate_insert_sql(self, table, values, target_fields, replace, **kwargs) :return: The generated INSERT or REPLACE/UPSERT SQL statement """ placeholders = [ - self.placeholder, - ] * len(values) + self.placeholder, + ] * len(values) if target_fields: target_fields = ", ".join(target_fields) @@ -514,8 +514,9 @@ def _generate_insert_sql(self, table, values, target_fields, replace, **kwargs) return f"REPLACE INTO {sql}" - def insert_rows(self, table, rows, target_fields=None, commit_every=1000, replace=False, - executemany=False, **kwargs): + def insert_rows( + self, table, rows, target_fields=None, commit_every=1000, replace=False,executemany=False, **kwargs + ): """Insert a collection of tuples into a table. Rows are inserted in chunks, each chunk (of size ``commit_every``) is @@ -540,7 +541,12 @@ def insert_rows(self, table, rows, target_fields=None, commit_every=1000, replac with closing(conn.cursor()) as cur: if executemany: for chunked_rows in chunked(rows, commit_every): - values = list(map(lambda row: tuple(map(lambda cell: self._serialize_cell(cell, conn), row)), chunked_rows)) + values = list( + map( + lambda row: tuple(map(lambda cell: self._serialize_cell(cell, conn), row)), + chunked_rows, + ) + ) sql = self._generate_insert_sql(table, values[0], target_fields, replace, **kwargs) self.log.debug("Generated sql: %s", sql) cur.fast_executemany = True diff --git a/tests/providers/common/sql/hooks/test_dbapi.py b/tests/providers/common/sql/hooks/test_dbapi.py index 0e822be3b5f74..2607652823482 100644 --- a/tests/providers/common/sql/hooks/test_dbapi.py +++ b/tests/providers/common/sql/hooks/test_dbapi.py @@ -22,7 +22,7 @@ from unittest.mock import patch import pytest -from sqlalchemy.engine import Engine, Dialect +from sqlalchemy.engine import Dialect, Engine from airflow.hooks.base import BaseHook from airflow.models import Connection @@ -40,7 +40,8 @@ class NonDbApiHook(BaseHook): class TestDbApiHook: def setup_method(self): self.cur = mock.MagicMock( - rowcount=0, spec=["description", "rowcount", "execute", "executemany", "fetchall", "fetchone", "close"] + rowcount=0, + spec=["description", "rowcount", "execute", "executemany", "fetchall", "fetchone", "close"], ) self.conn = mock.MagicMock() self.conn.cursor.return_value = self.cur From eae37b073579ad9d520914d6cee65c480b896ecb Mon Sep 17 00:00:00 2001 From: David Blain Date: Thu, 8 Feb 2024 20:12:25 +0100 Subject: [PATCH 3/4] refactor: Refactored DbApiHook so that insert and replace statements are parametrized # Conflicts: # tests/providers/common/sql/hooks/test_dbapi.py --- airflow/providers/common/sql/hooks/sql.py | 13 ++--- .../providers/common/sql/hooks/test_dbapi.py | 56 +++++++------------ 2 files changed, 27 insertions(+), 42 deletions(-) diff --git a/airflow/providers/common/sql/hooks/sql.py b/airflow/providers/common/sql/hooks/sql.py index 917a659f4c96c..1001780cc22ab 100644 --- a/airflow/providers/common/sql/hooks/sql.py +++ b/airflow/providers/common/sql/hooks/sql.py @@ -165,6 +165,10 @@ def __init__(self, *args, schema: str | None = None, log_sql: bool = True, **kwa self.log_sql = log_sql self.descriptions: list[Sequence[Sequence] | None] = [] self._placeholder: str = "%s" + self._insert_statement_format: str = kwargs.get("insert_statement_format", + "INSERT INTO {} {} VALUES ({})") + self._replace_statement_format: str = kwargs.get("replace_statement_format", + "REPLACE INTO {} {} VALUES ({})") @property def placeholder(self) -> str: @@ -504,15 +508,10 @@ def _generate_insert_sql(self, table, values, target_fields, replace, **kwargs) else: target_fields = "" - sql = f"{table} {target_fields} VALUES ({','.join(placeholders)})" - if not replace: - return f"INSERT INTO {sql}" - - if self.get_sqlalchemy_engine().dialect.name == "hana": - return f"UPSERT {sql} WITH PRIMARY KEY" + return self._insert_statement_format.format(table, target_fields, ",".join(placeholders)) - return f"REPLACE INTO {sql}" + return self._replace_statement_format.format(table, target_fields, ",".join(placeholders)) def insert_rows( self, table, rows, target_fields=None, commit_every=1000, replace=False,executemany=False, **kwargs diff --git a/tests/providers/common/sql/hooks/test_dbapi.py b/tests/providers/common/sql/hooks/test_dbapi.py index 2607652823482..fd9886345fd6b 100644 --- a/tests/providers/common/sql/hooks/test_dbapi.py +++ b/tests/providers/common/sql/hooks/test_dbapi.py @@ -19,10 +19,8 @@ import json from unittest import mock -from unittest.mock import patch import pytest -from sqlalchemy.engine import Dialect, Engine from airflow.hooks.base import BaseHook from airflow.models import Connection @@ -38,7 +36,7 @@ class NonDbApiHook(BaseHook): class TestDbApiHook: - def setup_method(self): + def setup_method(self, **kwargs): self.cur = mock.MagicMock( rowcount=0, spec=["description", "rowcount", "execute", "executemany", "fetchall", "fetchone", "close"], @@ -59,19 +57,10 @@ def get_connection(cls, conn_id: str) -> Connection: def get_conn(self): return conn - self.db_hook = DbApiHookMock() + self.db_hook = DbApiHookMock(**kwargs) self.db_hook_no_log_sql = DbApiHookMock(log_sql=False) self.db_hook_schema_override = DbApiHookMock(schema="schema-override") - @staticmethod - def create_engine(dialect_name: str = "sqlite") -> mock.MagicMock: - # Mocking create_engine to return a mock engine - mock_dialect = mock.MagicMock(spec=Dialect) - mock_dialect.name = dialect_name - mock_engine = mock.MagicMock(spec=Engine) - mock_engine.dialect = mock_dialect - return mock.MagicMock(return_value=mock_engine) - def test_get_records(self): statement = "SQL" rows = [("hello",), ("world",)] @@ -125,22 +114,20 @@ def test_insert_rows(self): self.cur.execute.assert_any_call(sql, row) def test_insert_rows_replace(self): - # Patching create_engine in the module where it is used - with patch(f"{DbApiHook.__module__}.create_engine", self.create_engine()): - table = "table" - rows = [("hello",), ("world",)] + table = "table" + rows = [("hello",), ("world",)] - self.db_hook.insert_rows(table, rows, replace=True) + self.db_hook.insert_rows(table, rows, replace=True) - assert self.conn.close.call_count == 1 - assert self.cur.close.call_count == 1 + assert self.conn.close.call_count == 1 + assert self.cur.close.call_count == 1 - commit_count = 2 # The first and last commit - assert commit_count == self.conn.commit.call_count + commit_count = 2 # The first and last commit + assert commit_count == self.conn.commit.call_count - sql = f"REPLACE INTO {table} VALUES (%s)" - for row in rows: - self.cur.execute.assert_any_call(sql, row) + sql = f"REPLACE INTO {table} VALUES (%s)" + for row in rows: + self.cur.execute.assert_any_call(sql, row) def test_insert_rows_target_fields(self): table = "table" @@ -190,19 +177,18 @@ def test_insert_rows_executemany(self): self.cur.executemany.assert_any_call(sql, rows) def test_insert_rows_replace_executemany_hana_dialect(self): - # Patching create_engine in the module where it is used - with patch(f"{DbApiHook.__module__}.create_engine", self.create_engine(dialect_name="hana")): - table = "table" - rows = [("hello",), ("world",)] + self.setup_method(replace_statement_format="UPSERT {} {} VALUES ({}) WITH PRIMARY KEY") + table = "table" + rows = [("hello",), ("world",)] - self.db_hook.insert_rows(table, rows, replace=True, executemany=True) + self.db_hook.insert_rows(table, rows, replace=True, executemany=True) - assert self.conn.close.call_count == 1 - assert self.cur.close.call_count == 1 - assert self.conn.commit.call_count == 2 + assert self.conn.close.call_count == 1 + assert self.cur.close.call_count == 1 + assert self.conn.commit.call_count == 2 - sql = f"UPSERT {table} VALUES (%s) WITH PRIMARY KEY" - self.cur.executemany.assert_any_call(sql, rows) + sql = f"UPSERT {table} VALUES (%s) WITH PRIMARY KEY" + self.cur.executemany.assert_any_call(sql, rows) def test_get_uri_schema_not_none(self): self.db_hook.get_connection = mock.MagicMock( From 1ae15fd65a6ca7e908911b19bba69f50feae81d3 Mon Sep 17 00:00:00 2001 From: David Blain Date: Fri, 9 Feb 2024 16:52:07 +0100 Subject: [PATCH 4/4] refactor: Fixed some formatting in DbApiHook --- airflow/providers/common/sql/hooks/sql.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/airflow/providers/common/sql/hooks/sql.py b/airflow/providers/common/sql/hooks/sql.py index 1001780cc22ab..1da7d2c739a84 100644 --- a/airflow/providers/common/sql/hooks/sql.py +++ b/airflow/providers/common/sql/hooks/sql.py @@ -165,10 +165,12 @@ def __init__(self, *args, schema: str | None = None, log_sql: bool = True, **kwa self.log_sql = log_sql self.descriptions: list[Sequence[Sequence] | None] = [] self._placeholder: str = "%s" - self._insert_statement_format: str = kwargs.get("insert_statement_format", - "INSERT INTO {} {} VALUES ({})") - self._replace_statement_format: str = kwargs.get("replace_statement_format", - "REPLACE INTO {} {} VALUES ({})") + self._insert_statement_format: str = kwargs.get( + "insert_statement_format", "INSERT INTO {} {} VALUES ({})" + ) + self._replace_statement_format: str = kwargs.get( + "replace_statement_format", "REPLACE INTO {} {} VALUES ({})" + ) @property def placeholder(self) -> str: @@ -514,7 +516,7 @@ def _generate_insert_sql(self, table, values, target_fields, replace, **kwargs) return self._replace_statement_format.format(table, target_fields, ",".join(placeholders)) def insert_rows( - self, table, rows, target_fields=None, commit_every=1000, replace=False,executemany=False, **kwargs + self, table, rows, target_fields=None, commit_every=1000, replace=False, executemany=False, **kwargs ): """Insert a collection of tuples into a table.