Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions airflow/providers/common/sql/hooks/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,9 +244,7 @@ def split_sql_string(sql: str) -> List[str]:
:return: list of individual expressions
"""
splits = sqlparse.split(sqlparse.format(sql, strip_comments=True))
statements: List[str] = list(
filter(None, [s.rstrip(';').strip() if s.endswith(';') else s.strip() for s in splits])
)

@kazanzhy kazanzhy Aug 21, 2022

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The sqlparse.split already has .strip:
return [str(stmt).strip() for stmt in stack.run(sql, encoding)]

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But this filtration is necessary I think:

splits = ["SELECT ...", None, ""]
list(filter(None, splits))
Out[3]: ['SELECT ...']

statements: List[str] = list(filter(None, splits))
return statements

def run(
Expand Down Expand Up @@ -278,7 +276,7 @@ def run(
if split_statements:
sql = self.split_sql_string(sql)
else:
sql = [self.strip_sql_string(sql)]
sql = [sql]

if sql:
self.log.debug("Executing following statements against DB: %s", list(sql))
Expand Down
13 changes: 5 additions & 8 deletions tests/providers/common/sql/hooks/test_sqlparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,16 +24,13 @@
"line,parsed_statements",
[
('SELECT * FROM table', ['SELECT * FROM table']),
('SELECT * FROM table;', ['SELECT * FROM table']),
('SELECT * FROM table; # comment', ['SELECT * FROM table']),
('SELECT * FROM table; # comment;', ['SELECT * FROM table']),
(' SELECT * FROM table ; # comment;', ['SELECT * FROM table']),
('SELECT * FROM table;;;;;', ['SELECT * FROM table']),
('SELECT * FROM table;;# comment;;;', ['SELECT * FROM table']),
('SELECT * FROM table;;# comment;; ;', ['SELECT * FROM table']),

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why did I remove these test cases?

In the previous PR, some functionality was added that tried to make SQL query clearer.
But also it caused a few bugs. I think I tried to do some jobs for a developer and it was a bad idea.
I think we could add some functionality like split_statements but the developer should decide either to use it or not. But with semicolon stripping it wasn't that.

('SELECT * FROM table;', ['SELECT * FROM table;']),
('SELECT * FROM table; # comment', ['SELECT * FROM table;']),
('SELECT * FROM table; # comment;', ['SELECT * FROM table;']),
(' SELECT * FROM table ; # comment;', ['SELECT * FROM table ;']),
(
'SELECT * FROM table; SELECT * FROM table2 # comment',
['SELECT * FROM table', 'SELECT * FROM table2'],
['SELECT * FROM table;', 'SELECT * FROM table2'],
),
],
)
Expand Down
2 changes: 1 addition & 1 deletion tests/providers/databricks/hooks/test_databricks_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def test_query(self, mock_requests, mock_conn):
assert schema == test_schema
assert results == []

cur.execute.assert_has_calls([mock.call(q) for q in [query.rstrip(';')]])
cur.execute.assert_has_calls([mock.call(q) for q in [query]])
cur.close.assert_called()

def test_no_query(self):
Expand Down
8 changes: 4 additions & 4 deletions tests/providers/oracle/hooks/test_oracle.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,7 @@ def getvalue(self):

self.cur.bindvars = None
result = self.db_hook.callproc('proc', True, parameters)
assert self.cur.execute.mock_calls == [mock.call('BEGIN proc(); END')]
assert self.cur.execute.mock_calls == [mock.call('BEGIN proc(); END;')]
assert result == parameters

def test_callproc_dict(self):
Expand All @@ -280,7 +280,7 @@ def getvalue(self):

self.cur.bindvars = {k: bindvar(v) for k, v in parameters.items()}
result = self.db_hook.callproc('proc', True, parameters)
assert self.cur.execute.mock_calls == [mock.call('BEGIN proc(:a,:b,:c); END', parameters)]
assert self.cur.execute.mock_calls == [mock.call('BEGIN proc(:a,:b,:c); END;', parameters)]
assert result == parameters

def test_callproc_list(self):
Expand All @@ -292,7 +292,7 @@ def getvalue(self):

self.cur.bindvars = list(map(bindvar, parameters))
result = self.db_hook.callproc('proc', True, parameters)
assert self.cur.execute.mock_calls == [mock.call('BEGIN proc(:1,:2,:3); END', parameters)]
assert self.cur.execute.mock_calls == [mock.call('BEGIN proc(:1,:2,:3); END;', parameters)]
assert result == parameters

def test_callproc_out_param(self):
Expand All @@ -306,7 +306,7 @@ def bindvar(value):
self.cur.bindvars = [bindvar(p() if type(p) is type else p) for p in parameters]
result = self.db_hook.callproc('proc', True, parameters)
expected = [1, 0, 0.0, False, '']
assert self.cur.execute.mock_calls == [mock.call('BEGIN proc(:1,:2,:3,:4,:5); END', expected)]
assert self.cur.execute.mock_calls == [mock.call('BEGIN proc(:1,:2,:3,:4,:5); END;', expected)]
assert result == expected

def test_test_connection_use_dual_table(self):
Expand Down