From 069cd090cee8029e51004d507cafceaa59c8469b Mon Sep 17 00:00:00 2001 From: Dmytro Kazanzhy Date: Sun, 21 Aug 2022 19:00:26 +0300 Subject: [PATCH 1/2] Discard semicolon stripping in SQL hook --- airflow/providers/common/sql/hooks/sql.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/airflow/providers/common/sql/hooks/sql.py b/airflow/providers/common/sql/hooks/sql.py index c94e16e0f25d6..0cac27188ca64 100644 --- a/airflow/providers/common/sql/hooks/sql.py +++ b/airflow/providers/common/sql/hooks/sql.py @@ -231,10 +231,6 @@ def get_first(self, sql: Union[str, List[str]], parameters=None): cur.execute(sql) return cur.fetchone() - @staticmethod - def strip_sql_string(sql: str) -> str: - return sql.strip().rstrip(';') - @staticmethod def split_sql_string(sql: str) -> List[str]: """ @@ -244,9 +240,7 @@ def split_sql_string(sql: str) -> List[str]: :return: list of individual expressions """ splits = sqlparse.split(sqlparse.format(sql, strip_comments=True)) - statements: List[str] = list( - filter(None, [s.rstrip(';').strip() if s.endswith(';') else s.strip() for s in splits]) - ) + statements: List[str] = list(filter(None, splits)) return statements def run( @@ -278,7 +272,7 @@ def run( if split_statements: sql = self.split_sql_string(sql) else: - sql = [self.strip_sql_string(sql)] + sql = [sql] if sql: self.log.debug("Executing following statements against DB: %s", list(sql)) From 58fe7cd7e7e301c3510e9ec337059c7db9af43e0 Mon Sep 17 00:00:00 2001 From: Dmytro Kazanzhy Date: Mon, 22 Aug 2022 00:58:29 +0300 Subject: [PATCH 2/2] Fixed tests --- airflow/providers/common/sql/hooks/sql.py | 4 ++++ tests/providers/common/sql/hooks/test_sqlparse.py | 13 +++++-------- .../databricks/hooks/test_databricks_sql.py | 2 +- tests/providers/oracle/hooks/test_oracle.py | 8 ++++---- 4 files changed, 14 insertions(+), 13 deletions(-) diff --git a/airflow/providers/common/sql/hooks/sql.py b/airflow/providers/common/sql/hooks/sql.py index 0cac27188ca64..56a43f92bb258 100644 --- a/airflow/providers/common/sql/hooks/sql.py +++ b/airflow/providers/common/sql/hooks/sql.py @@ -231,6 +231,10 @@ def get_first(self, sql: Union[str, List[str]], parameters=None): cur.execute(sql) return cur.fetchone() + @staticmethod + def strip_sql_string(sql: str) -> str: + return sql.strip().rstrip(';') + @staticmethod def split_sql_string(sql: str) -> List[str]: """ diff --git a/tests/providers/common/sql/hooks/test_sqlparse.py b/tests/providers/common/sql/hooks/test_sqlparse.py index 3137f3cd21942..142eed9d25b77 100644 --- a/tests/providers/common/sql/hooks/test_sqlparse.py +++ b/tests/providers/common/sql/hooks/test_sqlparse.py @@ -24,16 +24,13 @@ "line,parsed_statements", [ ('SELECT * FROM table', ['SELECT * FROM table']), - ('SELECT * FROM table;', ['SELECT * FROM table']), - ('SELECT * FROM table; # comment', ['SELECT * FROM table']), - ('SELECT * FROM table; # comment;', ['SELECT * FROM table']), - (' SELECT * FROM table ; # comment;', ['SELECT * FROM table']), - ('SELECT * FROM table;;;;;', ['SELECT * FROM table']), - ('SELECT * FROM table;;# comment;;;', ['SELECT * FROM table']), - ('SELECT * FROM table;;# comment;; ;', ['SELECT * FROM table']), + ('SELECT * FROM table;', ['SELECT * FROM table;']), + ('SELECT * FROM table; # comment', ['SELECT * FROM table;']), + ('SELECT * FROM table; # comment;', ['SELECT * FROM table;']), + (' SELECT * FROM table ; # comment;', ['SELECT * FROM table ;']), ( 'SELECT * FROM table; SELECT * FROM table2 # comment', - ['SELECT * FROM table', 'SELECT * FROM table2'], + ['SELECT * FROM table;', 'SELECT * FROM table2'], ), ], ) diff --git a/tests/providers/databricks/hooks/test_databricks_sql.py b/tests/providers/databricks/hooks/test_databricks_sql.py index d70d203e33f66..8d1b6cf6fb0eb 100644 --- a/tests/providers/databricks/hooks/test_databricks_sql.py +++ b/tests/providers/databricks/hooks/test_databricks_sql.py @@ -83,7 +83,7 @@ def test_query(self, mock_requests, mock_conn): assert schema == test_schema assert results == [] - cur.execute.assert_has_calls([mock.call(q) for q in [query.rstrip(';')]]) + cur.execute.assert_has_calls([mock.call(q) for q in [query]]) cur.close.assert_called() def test_no_query(self): diff --git a/tests/providers/oracle/hooks/test_oracle.py b/tests/providers/oracle/hooks/test_oracle.py index 254514bc9fe86..d33dbf79f721a 100644 --- a/tests/providers/oracle/hooks/test_oracle.py +++ b/tests/providers/oracle/hooks/test_oracle.py @@ -268,7 +268,7 @@ def getvalue(self): self.cur.bindvars = None result = self.db_hook.callproc('proc', True, parameters) - assert self.cur.execute.mock_calls == [mock.call('BEGIN proc(); END')] + assert self.cur.execute.mock_calls == [mock.call('BEGIN proc(); END;')] assert result == parameters def test_callproc_dict(self): @@ -280,7 +280,7 @@ def getvalue(self): self.cur.bindvars = {k: bindvar(v) for k, v in parameters.items()} result = self.db_hook.callproc('proc', True, parameters) - assert self.cur.execute.mock_calls == [mock.call('BEGIN proc(:a,:b,:c); END', parameters)] + assert self.cur.execute.mock_calls == [mock.call('BEGIN proc(:a,:b,:c); END;', parameters)] assert result == parameters def test_callproc_list(self): @@ -292,7 +292,7 @@ def getvalue(self): self.cur.bindvars = list(map(bindvar, parameters)) result = self.db_hook.callproc('proc', True, parameters) - assert self.cur.execute.mock_calls == [mock.call('BEGIN proc(:1,:2,:3); END', parameters)] + assert self.cur.execute.mock_calls == [mock.call('BEGIN proc(:1,:2,:3); END;', parameters)] assert result == parameters def test_callproc_out_param(self): @@ -306,7 +306,7 @@ def bindvar(value): self.cur.bindvars = [bindvar(p() if type(p) is type else p) for p in parameters] result = self.db_hook.callproc('proc', True, parameters) expected = [1, 0, 0.0, False, ''] - assert self.cur.execute.mock_calls == [mock.call('BEGIN proc(:1,:2,:3,:4,:5); END', expected)] + assert self.cur.execute.mock_calls == [mock.call('BEGIN proc(:1,:2,:3,:4,:5); END;', expected)] assert result == expected def test_test_connection_use_dual_table(self):