diff --git a/sql_metadata/parser.py b/sql_metadata/parser.py index 00938bb8..887c5053 100644 --- a/sql_metadata/parser.py +++ b/sql_metadata/parser.py @@ -18,7 +18,7 @@ TABLE_ADJUSTMENT_KEYWORDS, WITH_ENDING_KEYWORDS, ) -from sql_metadata.token import SQLToken +from sql_metadata.token import SQLToken, EmptyToken from sql_metadata.utils import UniqueList @@ -123,6 +123,23 @@ def columns(self) -> List[str]: subqueries_names = self.subqueries_names for token in self.tokens: + # handle CREATE TABLE queries (#35) + if token.is_name and self._is_create_table_query: + # previous token is either ( or , -> indicates the column name + if token.is_in_parenthesis and token.previous_token.is_punctuation: + columns.append(str(token)) + continue + + # we're in CREATE TABLE query with the columns + # ignore any annotations outside the parenthesis with the list of columns + # e.g. ) CHARACTER SET utf8; + if ( + not token.is_in_parenthesis + and token.find_nearest_token("SELECT", value_attribute="normalized") + is EmptyToken + ): + continue + if token.is_name and not token.next_token.is_dot: # analyze the name tokens, column names and where condition values if ( @@ -335,6 +352,11 @@ def tables(self) -> List[str]: and token.previous_token.normalized not in ["AS", "WITH"] and token.normalized not in ["AS", "SELECT"] ): + # handle CREATE TABLE queries (#35) + # skip keyword that are withing parenthesis-wrapped list of column + if self._is_create_table_query and token.is_in_parenthesis: + continue + if token.next_token.is_dot: pass # part of the qualified name elif token.is_in_parenthesis and ( @@ -728,3 +750,16 @@ def _preprocess_query(self) -> str: query = re.sub(r"`([^`]+)`\.`([^`]+)`", r"\1.\2", query) return query + + @property + def _is_create_table_query(self) -> bool: + """ + Return True if the query begins with "CREATE TABLE" statement + """ + if ( + self.tokens[0].normalized == "CREATE" + and self.tokens[1].normalized == "TABLE" + ): + return True + + return False diff --git a/test/test_create_table.py b/test/test_create_table.py new file mode 100644 index 00000000..ac0a544b --- /dev/null +++ b/test/test_create_table.py @@ -0,0 +1,43 @@ +from sql_metadata import Parser + + +def test_is_create_table_query(): + assert Parser("BEGIN")._is_create_table_query is False + assert Parser("SELECT * FROM `foo` ()")._is_create_table_query is False + + assert Parser("CREATE TABLE `foo` ()")._is_create_table_query is True + assert ( + Parser( + "create table abc.foo as SELECT pqr.foo1 , ab.foo2 FROM foo pqr, bar ab" + )._is_create_table_query + is True + ) + + +def test_create_table(): + parser = Parser( + """ +CREATE TABLE `new_table` ( + `item_id` int(9) NOT NULL AUTO_INCREMENT, + `foo` varchar(16) NOT NULL DEFAULT '', + PRIMARY KEY (`item_id`,`foo`), + KEY `idx_foo` (`foo`) +) CHARACTER SET utf8; + """ + ) + + assert parser.tables == ["new_table"] + assert parser.columns == ["item_id", "foo"] + + +def test_create_table_as_select(): + parser = Parser( + """ +create table abc.foo + as SELECT pqr.foo1 , ab.foo2 + FROM foo pqr, bar ab; + """ + ) + + assert parser.tables == ["abc.foo", "foo", "bar"] + assert parser.columns == ["foo.foo1", "bar.foo2"]