diff --git a/jdbc/src/main/java/org/apache/zeppelin/jdbc/JDBCInterpreter.java b/jdbc/src/main/java/org/apache/zeppelin/jdbc/JDBCInterpreter.java index c0c48aa1692..29815d08857 100644 --- a/jdbc/src/main/java/org/apache/zeppelin/jdbc/JDBCInterpreter.java +++ b/jdbc/src/main/java/org/apache/zeppelin/jdbc/JDBCInterpreter.java @@ -1,4 +1,4 @@ -/** +/* * Licensed to the Apache Software Foundation (ASF) under one or more contributor license * agreements. See the NOTICE file distributed with this work for additional information regarding * copyright ownership. The ASF licenses this file to you under the Apache License, Version 2.0 (the @@ -861,7 +861,7 @@ public List completion(String buf, int cursor, sqlCompleter = createOrUpdateSqlCompleter(sqlCompleter, connection, propertyKey, buf, cursor); sqlCompletersMap.put(sqlCompleterKey, sqlCompleter); - sqlCompleter.complete(buf, cursor, candidates); + sqlCompleter.fillCandidates(buf, cursor, candidates); return candidates; } diff --git a/jdbc/src/main/java/org/apache/zeppelin/jdbc/SqlCompleter.java b/jdbc/src/main/java/org/apache/zeppelin/jdbc/SqlCompleter.java index 9f52ecba491..77b559054a4 100644 --- a/jdbc/src/main/java/org/apache/zeppelin/jdbc/SqlCompleter.java +++ b/jdbc/src/main/java/org/apache/zeppelin/jdbc/SqlCompleter.java @@ -1,3 +1,18 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more contributor license + * agreements. See the NOTICE file distributed with this work for additional information regarding + * copyright ownership. The ASF licenses this file to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance with the License. You may obtain a + * copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + */ + package org.apache.zeppelin.jdbc; /* @@ -5,7 +20,6 @@ */ import org.apache.commons.lang.StringUtils; -import org.apache.commons.lang.math.NumberUtils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -18,6 +32,7 @@ import java.sql.SQLException; import java.util.ArrayList; import java.util.Arrays; +import java.util.Collection; import java.util.HashMap; import java.util.HashSet; import java.util.List; @@ -25,10 +40,6 @@ import java.util.Set; import java.util.StringTokenizer; import java.util.TreeSet; -import java.util.regex.Pattern; - -import jline.console.completer.ArgumentCompleter.ArgumentList; -import jline.console.completer.ArgumentCompleter.WhitespaceArgumentDelimiter; import org.apache.zeppelin.completer.CachedCompleter; import org.apache.zeppelin.completer.CompletionType; @@ -39,88 +50,73 @@ * SQL auto complete functionality for the JdbcInterpreter. */ public class SqlCompleter { + private static Logger logger = LoggerFactory.getLogger(SqlCompleter.class); /** - * Delimiter that can split SQL statement in keyword list. + * Completer for sql keywords. */ - private WhitespaceArgumentDelimiter sqlDelimiter = new WhitespaceArgumentDelimiter() { - - private Pattern pattern = Pattern.compile(","); - - @Override - public boolean isDelimiterChar(CharSequence buffer, int pos) { - return pattern.matcher("" + buffer.charAt(pos)).matches() - || super.isDelimiterChar(buffer, pos); - } - }; + private CachedCompleter keywordCompleter; /** * Schema completer. */ - private CachedCompleter schemasCompleter; + private CachedCompleter schemasCompleter; /** * Contain different completer with table list for every schema name. */ - private Map tablesCompleters = new HashMap<>(); + private Map> tablesCompleters = new HashMap<>(); /** - * Contains different completer with column list for every table name - * Table names store as schema_name.table_name. + * Contains different completer with column list for every table name. + * Table names store as schema_name.table_name */ - private Map columnsCompleters = new HashMap<>(); - - /** - * Completer for sql keywords. - */ - private CachedCompleter keywordCompleter; + private Map> columnsCompleters = new HashMap<>(); private int ttlInSeconds; - public SqlCompleter(int ttlInSeconds) { + private String defaultSchema; + + SqlCompleter(int ttlInSeconds) { this.ttlInSeconds = ttlInSeconds; } - public int complete(String buffer, int cursor, List candidates) { - logger.debug("Complete with buffer = " + buffer + ", cursor = " + cursor); - - // The delimiter breaks the buffer into separate words (arguments), separated by the - // white spaces. - ArgumentList argumentList = sqlDelimiter.delimit(buffer, cursor); - - Pattern whitespaceEndPatter = Pattern.compile("\\s$"); - String cursorArgument = null; - int argumentPosition; - if (buffer.length() == 0 || whitespaceEndPatter.matcher(buffer).find()) { - argumentPosition = buffer.length() - 1; - } else { - cursorArgument = argumentList.getCursorArgument(); - argumentPosition = argumentList.getArgumentPosition(); + /** + * Return schema for tables that can be written without schema in query. + * Typically it is enough to use getSchema() on connection, + * but for Oracle - getUserName() from DatabaseMetaData. + */ + private String getDefaultSchema(Connection conn, DatabaseMetaData meta) { + String defaultSchema = null; + try { + if ((defaultSchema = conn.getSchema()) == null) { + defaultSchema = conn.getCatalog(); + } + } catch (SQLException | AbstractMethodError e) { + logger.debug("Default schema is not defined" + e.getMessage()); + try { + defaultSchema = meta.getUserName(); + } catch (Exception ee) { + logger.debug("User name is not defined" + ee.getMessage()); + } } - - int complete = completeName(cursorArgument, argumentPosition, candidates, - findAliasesInSQL(argumentList.getArguments())); - - logger.debug("complete:" + complete + ", size:" + candidates.size()); - return complete; + return defaultSchema; } /** * Return list of schema names within the database. * - * @param meta metadata from connection to database + * @param meta metadata from connection to database * @param schemaFilters a schema name patterns; must match the schema name - * as it is stored in the database; "" retrieves those without a schema; - * null means that the schema name should not be used to narrow - * the search; supports '%'; for example "prod_v_%" + * as it is stored in the database; "" retrieves those without a schema; + * null means that the schema name should not be used to narrow + * the search; supports '%'; for example "prod_v_%" * @return set of all schema names in the database */ private static Set getSchemaNames(DatabaseMetaData meta, List schemaFilters) { Set res = new HashSet<>(); - try { - ResultSet schemas = meta.getSchemas(); - + try (ResultSet schemas = meta.getSchemas()) { try { while (schemas.next()) { String schemaName = schemas.getString("TABLE_SCHEM"); @@ -145,18 +141,17 @@ private static Set getSchemaNames(DatabaseMetaData meta, List sc /** * Return list of catalog names within the database. * - * @param meta metadata from connection to database + * @param meta metadata from connection to database * @param schemaFilters a catalog name patterns; must match the catalog name - * as it is stored in the database; "" retrieves those without a catalog; - * null means that the schema name should not be used to narrow - * the search; supports '%'; for example "prod_v_%" + * as it is stored in the database; "" retrieves those without a catalog; + * null means that the schema name should not be used to narrow + * the search; supports '%'; for example "prod_v_%" * @return set of all catalog names in the database */ private static Set getCatalogNames(DatabaseMetaData meta, List schemaFilters) { Set res = new HashSet<>(); try { - ResultSet schemas = meta.getCatalogs(); - try { + try (ResultSet schemas = meta.getCatalogs()) { while (schemas.next()) { String schemaName = schemas.getString("TABLE_CAT"); for (String schemaFilter : schemaFilters) { @@ -165,8 +160,6 @@ private static Set getCatalogNames(DatabaseMetaData meta, List s } } } - } finally { - schemas.close(); } } catch (SQLException t) { logger.error("Failed to retrieve the schema names", t); @@ -176,10 +169,15 @@ private static Set getCatalogNames(DatabaseMetaData meta, List s private static void fillTableNames(String schema, DatabaseMetaData meta, Set tables) { try (ResultSet tbls = meta.getTables(schema, schema, "%", - new String[]{"TABLE", "VIEW", "ALIAS", "SYNONYM", "GLOBAL TEMPORARY", "LOCAL TEMPORARY"})) { - while (tbls.next()) { - String table = tbls.getString("TABLE_NAME"); - tables.add(table); + new String[]{"TABLE", "VIEW", "ALIAS", "SYNONYM", "GLOBAL TEMPORARY", + "LOCAL TEMPORARY"})) { + if (!tbls.isBeforeFirst()) { + logger.debug("There is no tables for schema " + schema); + } else { + while (tbls.next()) { + String table = tbls.getString("TABLE_NAME"); + tables.add(table); + } } } catch (Throwable t) { logger.error("Failed to retrieve the table name", t); @@ -189,14 +187,14 @@ private static void fillTableNames(String schema, DatabaseMetaData meta, Set columns) { + Set columns) { try (ResultSet cols = meta.getColumns(schema, schema, table, "%")) { while (cols.next()) { String column = cols.getString("COLUMN_NAME"); @@ -207,31 +205,33 @@ private static void fillColumnNames(String schema, String table, DatabaseMetaDat } } - public static Set getSqlKeywordsCompletions(DatabaseMetaData meta) throws IOException, - SQLException { + private static Set getSqlKeywordsCompletions(DatabaseMetaData meta) throws IOException, + SQLException { + // Add the default SQL completions String keywords = - new BufferedReader(new InputStreamReader( - SqlCompleter.class.getResourceAsStream("/ansi.sql.keywords"))).readLine(); + new BufferedReader(new InputStreamReader( + SqlCompleter.class.getResourceAsStream("/ansi.sql.keywords"))).readLine(); Set completions = new TreeSet<>(); if (null != meta) { + // Add the driver specific SQL completions String driverSpecificKeywords = - "/" + meta.getDriverName().replace(" ", "-").toLowerCase() + "-sql.keywords"; + "/" + meta.getDriverName().replace(" ", "-").toLowerCase() + "-sql.keywords"; logger.info("JDBC DriverName:" + driverSpecificKeywords); try { if (SqlCompleter.class.getResource(driverSpecificKeywords) != null) { String driverKeywords = - new BufferedReader(new InputStreamReader( - SqlCompleter.class.getResourceAsStream(driverSpecificKeywords))) - .readLine(); + new BufferedReader(new InputStreamReader( + SqlCompleter.class.getResourceAsStream(driverSpecificKeywords))) + .readLine(); keywords += "," + driverKeywords.toUpperCase(); } } catch (Exception e) { - logger.debug("fail to get driver specific SQL completions for " + - driverSpecificKeywords + " : " + e, e); + logger.debug("fail to get driver specific SQL completions for " + + driverSpecificKeywords + " : " + e, e); } // Add the keywords from the current JDBC connection @@ -263,6 +263,7 @@ public static Set getSqlKeywordsCompletions(DatabaseMetaData meta) throw // Set all keywords to lower-case versions keywords = keywords.toLowerCase(); + } StringTokenizer tok = new StringTokenizer(keywords, ", "); @@ -273,119 +274,134 @@ public static Set getSqlKeywordsCompletions(DatabaseMetaData meta) throw return completions; } + private SqlStatement getStatementParameters(String buffer, int cursor) { + Collection schemas = schemasCompleter.getCompleter().getStrings(); + Collection keywords = keywordCompleter.getCompleter().getStrings(); + + Collection tablesInDefaultSchema = new TreeSet<>(); + if (tablesCompleters.containsKey(defaultSchema)) { + tablesInDefaultSchema = tablesCompleters.get(defaultSchema) + .getCompleter().getStrings(); + } + + + + return new SqlStatement(buffer, cursor, defaultSchema, schemas, + tablesInDefaultSchema, keywords); + } + /** * Initializes all local completers from database connection. * - * @param connection database connection + * @param connection database connection * @param schemaFiltersString a comma separated schema name patterns, supports '%' symbol; - * for example "prod_v_%,prod_t_%" + * for example "prod_v_%,prod_t_%" */ - public void createOrUpdateFromConnection(Connection connection, String schemaFiltersString, - String buffer, int cursor) { + + void createOrUpdateFromConnection(Connection connection, String schemaFiltersString, + String buffer, int cursor) { try (Connection c = connection) { if (schemaFiltersString == null) { schemaFiltersString = StringUtils.EMPTY; } List schemaFilters = Arrays.asList(schemaFiltersString.split(",")); - CursorArgument cursorArgument = parseCursorArgument(buffer, cursor); - Set tables = new HashSet<>(); - Set columns = new HashSet<>(); - Set schemas = new HashSet<>(); - Set catalogs = new HashSet<>(); - Set keywords = new HashSet<>(); if (c != null) { DatabaseMetaData databaseMetaData = c.getMetaData(); - if (keywordCompleter == null || keywordCompleter.getCompleter() == null - || keywordCompleter.isExpired()) { - keywords = getSqlKeywordsCompletions(databaseMetaData); + + //TODO(mebelousov): put defaultSchema in cache + if (defaultSchema == null) { + defaultSchema = getDefaultSchema(connection, databaseMetaData); + } + + if (keywordCompleter == null || keywordCompleter.getCompleter() == null) { + Set keywords = getSqlKeywordsCompletions(databaseMetaData); initKeywords(keywords); + logger.info("Keyword completer initialized with " + keywords.size() + " keywords"); } - if (cursorArgument.needLoadSchemas() && - (schemasCompleter == null || schemasCompleter.getCompleter() == null - || schemasCompleter.isExpired())) { - schemas = getSchemaNames(databaseMetaData, schemaFilters); - catalogs = getCatalogNames(databaseMetaData, schemaFilters); + + if (schemasCompleter == null || schemasCompleter.getCompleter() == null + || schemasCompleter.isExpired()) { + Set schemas = getSchemaNames(databaseMetaData, schemaFilters); + Set catalogs = getCatalogNames(databaseMetaData, schemaFilters); if (schemas.size() == 0) { schemas.addAll(catalogs); } initSchemas(schemas); + logger.info("Schema completer initialized with " + schemas.size() + " schemas"); } - CachedCompleter tablesCompleter = tablesCompleters.get(cursorArgument.getSchema()); - if (cursorArgument.needLoadTables() && - (tablesCompleter == null || tablesCompleter.isExpired())) { - fillTableNames(cursorArgument.getSchema(), databaseMetaData, tables); - initTables(cursorArgument.getSchema(), tables); - } + CachedCompleter tablesCompleterInDefaultSchema = tablesCompleters + .get(defaultSchema); - String schemaTable = - String.format("%s.%s", cursorArgument.getSchema(), cursorArgument.getTable()); - CachedCompleter columnsCompleter = columnsCompleters.get(schemaTable); + if (tablesCompleterInDefaultSchema == null || tablesCompleterInDefaultSchema.isExpired()) { + Set tables = new HashSet<>(); + fillTableNames(defaultSchema, databaseMetaData, tables); + initTables(defaultSchema, tables); + } - if (cursorArgument.needLoadColumns() && - (columnsCompleter == null || columnsCompleter.isExpired())) { - fillColumnNames(cursorArgument.getSchema(), cursorArgument.getTable(), databaseMetaData, - columns); - initColumns(schemaTable, columns); + SqlStatement sqlStatement = getStatementParameters(buffer, cursor); + + if (sqlStatement.needLoadTables()) { + String schema = sqlStatement.getSchema(); + CachedCompleter tablesCompleter = tablesCompleters.get(schema); + if (tablesCompleter == null || tablesCompleter.isExpired()) { + Set tables = new HashSet<>(); + fillTableNames(schema, databaseMetaData, tables); + initTables(schema, tables); + logger.info("Tables completer for schema " + schema + " initialized with " + + tables.size() + " tables"); + } } - logger.info("Completer initialized with " + schemas.size() + " schemas, " + - columns.size() + " tables and " + keywords.size() + " keywords"); + for (String schemaTable : sqlStatement.getActiveSchemaTables()) { + CachedCompleter columnsCompleter = columnsCompleters.get(schemaTable); + if (columnsCompleter == null || columnsCompleter.isExpired()) { + int pointPos = schemaTable.indexOf('.'); + Set columns = new HashSet<>(); + fillColumnNames(schemaTable.substring(0, pointPos), schemaTable.substring(pointPos + 1), + databaseMetaData, columns); + initColumns(schemaTable, columns); + logger.info("Completer for schemaTable " + schemaTable + " initialized with " + + columns.size() + " columns."); + } + } } - } catch (SQLException | IOException e) { - logger.error("Failed to update the metadata completions", e); + logger.error("Failed to update the metadata completions" + e.getMessage()); } } - public void initKeywords(Set keywords) { + void initKeywords(Set keywords) { if (keywords != null && !keywords.isEmpty()) { keywordCompleter = new CachedCompleter(new StringsCompleter(keywords), 0); } } - public void initSchemas(Set schemas) { + void initSchemas(Set schemas) { if (schemas != null && !schemas.isEmpty()) { schemasCompleter = new CachedCompleter( new StringsCompleter(new TreeSet<>(schemas)), ttlInSeconds); } } - public void initTables(String schema, Set tables) { + void initTables(String schema, Set tables) { if (tables != null && !tables.isEmpty()) { tablesCompleters.put(schema, new CachedCompleter( new StringsCompleter(new TreeSet<>(tables)), ttlInSeconds)); } } - public void initColumns(String schemaTable, Set columns) { + void initColumns(String schemaTable, Set columns) { if (columns != null && !columns.isEmpty()) { columnsCompleters.put(schemaTable, new CachedCompleter(new StringsCompleter(columns), ttlInSeconds)); } } - /** - * Find aliases in sql command. - * - * @param sqlArguments sql command divided on arguments - * @return for every alias contains table name in format schema_name.table_name - */ - public Map findAliasesInSQL(String[] sqlArguments) { - Map res = new HashMap<>(); - for (int i = 0; i < sqlArguments.length - 1; i++) { - if (columnsCompleters.keySet().contains(sqlArguments[i]) && - sqlArguments[i + 1].matches("[a-zA-Z]+")) { - res.put(sqlArguments[i + 1], sqlArguments[i]); - } - } - return res; - } - /** * Complete buffer in case it is a keyword. * @@ -436,160 +452,72 @@ private int completeColumn(String schema, String table, String buffer, int curso } /** - * Complete buffer with a single name. Function will decide what it is: - * a schema, a table of a column or a keyword - * - * @param aliases for every alias contains table name in format schema_name.table_name - * @return -1 in case of no candidates found, 0 otherwise + * Fill candidates for statement. */ - public int completeName(String buffer, int cursor, List candidates, - Map aliases) { - CursorArgument cursorArgument = parseCursorArgument(buffer, cursor); + void fillCandidates(String statement, int cursor, List candidates) { + SqlStatement sqlStatement = getStatementParameters(statement, cursor); + + logger.debug("Complete with buffer = " + statement + ", cursor = " + cursor); - // find schema and table name if they are - String schema; - String table; - String column; - if (cursorArgument.getSchema() == null) { // process all - List keywordsCandidates = new ArrayList(); + String schema = sqlStatement.getSchema(); + int cursorPosition = sqlStatement.getCursorPosition(); + + if (schema == null) { // process all + final String buffer; + if (cursorPosition > 0) { + buffer = sqlStatement.getCursorString(); + } else { + buffer = ""; + } + + int allColumnsRes = 0; + List columnCandidates = new ArrayList<>(); + for (String schemaTable : sqlStatement.getActiveSchemaTables()) { + int pointPos = schemaTable.indexOf('.'); + int columnRes = completeColumn(schemaTable.substring(0, pointPos), + schemaTable.substring(pointPos + 1), buffer, cursorPosition, columnCandidates); + addCompletions(candidates, columnCandidates, CompletionType.column.name()); + allColumnsRes = allColumnsRes + columnRes; + } + + List tableInDefaultSchemaCandidates = new ArrayList<>(); + int tableRes = completeTable(defaultSchema, buffer, cursorPosition, + tableInDefaultSchemaCandidates); + addCompletions(candidates, tableInDefaultSchemaCandidates, CompletionType.table.name()); + List schemaCandidates = new ArrayList<>(); - int keywordsRes = completeKeyword(buffer, cursor, keywordsCandidates); - int schemaRes = completeSchema(buffer, cursor, schemaCandidates); - addCompletions(candidates, keywordsCandidates, CompletionType.keyword.name()); + int schemaRes = completeSchema(buffer, cursorPosition, schemaCandidates); addCompletions(candidates, schemaCandidates, CompletionType.schema.name()); - return NumberUtils.max(new int[]{keywordsRes, schemaRes}); + + List keywordsCandidates = new ArrayList<>(); + int keywordsRes = completeKeyword(buffer, cursorPosition, keywordsCandidates); + addCompletions(candidates, keywordsCandidates, CompletionType.keyword.name()); + + logger.debug("Complete for buffer with " + keywordsRes + schemaRes + + tableRes + allColumnsRes + "candidates"); } else { - schema = cursorArgument.getSchema(); - if (aliases.containsKey(schema)) { // process alias case - String alias = aliases.get(schema); - int pointPos = alias.indexOf('.'); - schema = alias.substring(0, pointPos); - table = alias.substring(pointPos + 1); - column = cursorArgument.getColumn(); - List columnCandidates = new ArrayList(); - int columnRes = completeColumn(schema, table, column, cursorArgument.getCursorPosition(), - columnCandidates); - addCompletions(candidates, columnCandidates, CompletionType.column.name()); - // process schema.table case - } else if (cursorArgument.getTable() != null && cursorArgument.getColumn() == null) { - List tableCandidates = new ArrayList(); - table = cursorArgument.getTable(); - int tableRes = completeTable(schema, table, cursorArgument.getCursorPosition(), - tableCandidates); + String table = sqlStatement.getTable(); + String column = sqlStatement.getColumn(); + if (column == null) { + List tableCandidates = new ArrayList<>(); + int tableRes = completeTable(schema, table, cursorPosition, tableCandidates); addCompletions(candidates, tableCandidates, CompletionType.table.name()); - return tableRes; - } else { - List columnCandidates = new ArrayList(); - table = cursorArgument.getTable(); - column = cursorArgument.getColumn(); - int columnRes = completeColumn(schema, table, column, cursorArgument.getCursorPosition(), - columnCandidates); + logger.debug("Complete for tables with " + tableRes + "candidates"); + } else { // process schema.table and alias case + List columnCandidates = new ArrayList<>(); + int columnRes = completeColumn(schema, table, column, cursorPosition, columnCandidates); addCompletions(candidates, columnCandidates, CompletionType.column.name()); + logger.debug("Complete for tables with " + columnRes + "candidates"); } } - - return -1; - } - - // test purpose only - WhitespaceArgumentDelimiter getSqlDelimiter() { - return this.sqlDelimiter; } private void addCompletions(List interpreterCompletions, - List candidates, String meta) { + List candidates, String meta) { for (CharSequence candidate : candidates) { interpreterCompletions.add(new InterpreterCompletion(candidate.toString(), candidate.toString(), meta)); } } - - private CursorArgument parseCursorArgument(String buffer, int cursor) { - CursorArgument result = new CursorArgument(); - if (buffer != null && buffer.length() >= cursor) { - String buf = buffer.substring(0, cursor); - if (StringUtils.isNotBlank(buf)) { - ArgumentList argumentList = sqlDelimiter.delimit(buf, cursor); - String cursorArgument = argumentList.getCursorArgument(); - if (cursorArgument != null) { - int pointPos1 = cursorArgument.indexOf('.'); - int pointPos2 = cursorArgument.indexOf('.', pointPos1 + 1); - if (pointPos1 > -1) { - result.setSchema(cursorArgument.substring(0, pointPos1).trim()); - if (pointPos2 > -1) { - result.setTable(cursorArgument.substring(pointPos1 + 1, pointPos2)); - result.setColumn(cursorArgument.substring(pointPos2 + 1)); - result.setCursorPosition(cursor - pointPos2 - 1); - } else { - result.setTable(cursorArgument.substring(pointPos1 + 1)); - result.setCursorPosition(cursor - pointPos1 - 1); - } - } - } - } - } - - return result; - } - - private class CursorArgument { - private String schema; - private String table; - private String column; - private int cursorPosition; - - public String getSchema() { - return schema; - } - - public void setSchema(String schema) { - this.schema = schema; - } - - public String getTable() { - return table; - } - - public void setTable(String table) { - this.table = table; - } - - public String getColumn() { - return column; - } - - public void setColumn(String column) { - this.column = column; - } - - public int getCursorPosition() { - return cursorPosition; - } - - public void setCursorPosition(int cursorPosition) { - this.cursorPosition = cursorPosition; - } - - public boolean needLoadSchemas() { - if (table == null && column == null) { - return true; - } - return false; - } - - public boolean needLoadTables() { - if (schema != null && table != null && column == null) { - return true; - } - return false; - } - - public boolean needLoadColumns() { - if (schema != null && table != null && column != null) { - return true; - } - return false; - } - } } diff --git a/jdbc/src/main/java/org/apache/zeppelin/jdbc/SqlStatement.java b/jdbc/src/main/java/org/apache/zeppelin/jdbc/SqlStatement.java new file mode 100644 index 00000000000..c550aa979e5 --- /dev/null +++ b/jdbc/src/main/java/org/apache/zeppelin/jdbc/SqlStatement.java @@ -0,0 +1,188 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more contributor license + * agreements. See the NOTICE file distributed with this work for additional information regarding + * copyright ownership. The ASF licenses this file to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance with the License. You may obtain a + * copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + */ + +package org.apache.zeppelin.jdbc; + +import org.apache.commons.lang.StringUtils; +import jline.console.completer.ArgumentCompleter; + +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.regex.Pattern; + +public class SqlStatement { + private int cursorPosition; + private String cursorString; + private String schema; + private String table; + private String column; + private HashMap aliases = new HashMap<>(); + private HashSet activeSchemaTables = new HashSet<>(); + + private static final String KEYWORD_AS = "as"; + + /** + * Delimiter that can split SQL statement in keyword list. + */ + private ArgumentCompleter.WhitespaceArgumentDelimiter sqlDelimiter = new ArgumentCompleter. + WhitespaceArgumentDelimiter() { + + private Pattern pattern = Pattern.compile(","); + + @Override + public boolean isDelimiterChar(CharSequence buffer, int pos) { + return pattern.matcher("" + buffer.charAt(pos)).matches() + || super.isDelimiterChar(buffer, pos); + } + }; + + // test purpose only + ArgumentCompleter.WhitespaceArgumentDelimiter getSqlDelimiter() { + return this.sqlDelimiter; + } + + SqlStatement(String statement, int cursor, String defaultSchema, + Collection schemas, Collection tablesInDefaultSchema, + Collection keywords) { + if (StringUtils.isNotBlank(statement) && statement.length() >= cursor) { + ArgumentCompleter.ArgumentList argumentList = sqlDelimiter.delimit(statement, cursor); + String cursorArgument = argumentList.getCursorArgument(); + int argumentPosition = argumentList.getArgumentPosition(); + + setStatementArguments(argumentList.getArguments(), cursorArgument, defaultSchema, schemas, + tablesInDefaultSchema, keywords); + setCursorArgument(cursorArgument, argumentPosition, defaultSchema, schemas, + tablesInDefaultSchema); + } + } + + private void setStatementArguments(String[] sqlArguments, String currentArgument, + String defaultSchema, Collection schemas, + Collection tablesInDefaultSchema, + Collection keywords) { + String schemaTable = null; + boolean isTable = false; + + for (int i = 0; i < sqlArguments.length; i++) { + if (!currentArgument.equals(sqlArguments[i])) { + int pointPos1 = sqlArguments[i].indexOf('.'); + if (pointPos1 > -1 + && sqlArguments[i].substring(pointPos1 + 1).indexOf('.') == -1 + && schemas.contains(sqlArguments[i].substring(0, pointPos1))) { + schemaTable = sqlArguments[i]; + isTable = true; + } else if (tablesInDefaultSchema.contains(sqlArguments[i])) { + schemaTable = defaultSchema + "." + sqlArguments[i]; + isTable = true; + } + if (isTable) { + isTable = false; + this.activeSchemaTables.add(schemaTable); + if (i + 2 < sqlArguments.length + && sqlArguments[i + 1].toLowerCase().equals(KEYWORD_AS)) { + this.aliases.put(sqlArguments[i + 2], schemaTable); + i += 2; + } else if (i + 1 < sqlArguments.length + && sqlArguments[i + 1].matches("[a-zA-Z0-9]+") + && !keywords.contains(sqlArguments[i + 1])) { + this.aliases.put(sqlArguments[i + 1], schemaTable); + i++; + } + } + } + } + } + + private void setCursorArgument(String string, int cursor, + String defaultSchema, Collection schemas, + Collection tablesInDefaultSchema) { + boolean defineColumns = false; + + if (cursor == 0) { + this.cursorPosition = 0; + return; + } + if (string != null) { + int pointPos1 = string.indexOf('.'); + int pointPos2 = string.indexOf('.', pointPos1 + 1); + if (pointPos1 > -1) { + String string1 = string.substring(0, pointPos1).trim(); + if (schemas.contains(string1)) { + this.schema = string1; + if (pointPos2 > -1) { + this.table = string.substring(pointPos1 + 1, pointPos2); + this.column = string.substring(pointPos2 + 1); + this.cursorPosition = cursor - pointPos2 - 1; + } else { + this.table = string.substring(pointPos1 + 1); + this.cursorPosition = cursor - pointPos1 - 1; + } + } else { + if (this.aliases.containsKey(string1)) { + String schemaTable = this.aliases.get(string1); + int pointSchemaTable = schemaTable.indexOf('.'); + this.schema = schemaTable.substring(0, pointSchemaTable); + this.table = schemaTable.substring(pointSchemaTable + 1); + defineColumns = true; + } else if (tablesInDefaultSchema.contains(string1)) { + this.schema = defaultSchema; + this.table = string1; + defineColumns = true; + } + if (defineColumns) { + this.column = string.substring(pointPos1 + 1); + this.cursorPosition = cursor - pointPos1 - 1; + } + } + } else { + this.cursorString = string; + this.cursorPosition = cursor; + } + } + } + + String getCursorString() { + return cursorString; + } + + String getSchema() { + return schema; + } + + String getTable() { + return table; + } + + String getColumn() { + return column; + } + + int getCursorPosition() { + return cursorPosition; + } + + boolean needLoadTables() { + return (schema != null) && (table != null) && (column == null); + } + + Collection getActiveSchemaTables() { + if ((schema != null) && (table != null) && (column != null)) { + activeSchemaTables.add(schema + "." + table); + } + return Collections.unmodifiableSet(activeSchemaTables); + } +} diff --git a/jdbc/src/main/resources/postgresql-native-driver-sql.keywords b/jdbc/src/main/resources/postgresql-native-driver-sql.keywords index bcd00c8ac8a..7254b45fbbb 100644 --- a/jdbc/src/main/resources/postgresql-native-driver-sql.keywords +++ b/jdbc/src/main/resources/postgresql-native-driver-sql.keywords @@ -1 +1 @@ -a,abort,abs,absent,absolute,access,according,action,ada,add,admin,after,aggregate,all,allocate,also,alter,always,analyse,analyze,and,any,are,array,array_agg,array_max_cardinality,as,asc,asensitive,assertion,assignment,asymmetric,at,atomic,attribute,attributes,authorization,avg,backward,base64,before,begin,begin_frame,begin_partition,bernoulli,between,bigint,binary,bit,bit_length,blob,blocked,bom,boolean,both,breadth,by,c,cache,call,called,cardinality,cascade,cascaded,case,cast,catalog,catalog_name,ceil,ceiling,chain,char,character,characteristics,characters,character_length,character_set_catalog,character_set_name,character_set_schema,char_length,check,checkpoint,class,class_origin,clob,close,cluster,coalesce,cobol,collate,collation,collation_catalog,collation_name,collation_schema,collect,column,columns,column_name,command_function,command_function_code,comment,comments,commit,committed,concurrently,condition,condition_number,configuration,connect,connection,connection_name,constraint,constraints,constraint_catalog,constraint_name,constraint_schema,constructor,contains,content,continue,control,conversion,convert,copy,corr,corresponding,cost,count,covar_pop,covar_samp,create,cross,csv,cube,cume_dist,current,current_catalog,current_date,current_default_transform_group,current_path,current_role,current_row,current_schema,current_time,current_timestamp,current_transform_group_for_type,current_user,cursor,cursor_name,cycle,data,database,datalink,date,datetime_interval_code,datetime_interval_precision,day,db,deallocate,dec,decimal,declare,default,defaults,deferrable,deferred,defined,definer,degree,delete,delimiter,delimiters,dense_rank,depth,deref,derived,desc,describe,descriptor,deterministic,diagnostics,dictionary,disable,discard,disconnect,dispatch,distinct,dlnewcopy,dlpreviouscopy,dlurlcomplete,dlurlcompleteonly,dlurlcompletewrite,dlurlpath,dlurlpathonly,dlurlpathwrite,dlurlscheme,dlurlserver,dlvalue,do,document,domain,double,drop,dynamic,dynamic_function,dynamic_function_code,each,element,else,empty,enable,encoding,encrypted,end,end-exec,end_frame,end_partition,enforced,enum,equals,escape,event,every,except,exception,exclude,excluding,exclusive,exec,execute,exists,exp,explain,expression,extension,external,extract,false,family,fetch,file,filter,final,first,first_value,flag,float,floor,following,for,force,foreign,fortran,forward,found,frame_row,free,freeze,from,fs,full,function,functions,fusion,g,general,generated,get,global,go,goto,grant,granted,greatest,group,grouping,groups,handler,having,header,hex,hierarchy,hold,hour,id,identity,if,ignore,ilike,immediate,immediately,immutable,implementation,implicit,import,in,including,increment,indent,index,indexes,indicator,inherit,inherits,initially,inline,inner,inout,input,insensitive,insert,instance,instantiable,instead,int,integer,integrity,intersect,intersection,interval,into,invoker,is,isnull,isolation,join,k,key,key_member,key_type,label,lag,language,large,last,last_value,lateral,lc_collate,lc_ctype,lead,leading,leakproof,least,left,length,level,library,like,like_regex,limit,link,listen,ln,load,local,localtime,localtimestamp,location,locator,lock,lower,m,map,mapping,match,matched,materialized,max,maxvalue,max_cardinality,member,merge,message_length,message_octet_length,message_text,method,min,minute,minvalue,mod,mode,modifies,module,month,more,move,multiset,mumps,name,names,namespace,national,natural,nchar,nclob,nesting,new,next,nfc,nfd,nfkc,nfkd,nil,no,none,normalize,normalized,not,nothing,notify,notnull,nowait,nth_value,ntile,null,nullable,nullif,nulls,number,numeric,object,occurrences_regex,octets,octet_length,of,off,offset,oids,old,on,only,open,operator,option,options,or,order,ordering,ordinality,others,out,outer,output,over,overlaps,overlay,overriding,owned,owner,p,pad,parameter,parameter_mode,parameter_name,parameter_ordinal_position,parameter_specific_catalog,parameter_specific_name,parameter_specific_schema,parser,partial,partition,pascal,passing,passthrough,password,path,percent,percentile_cont,percentile_disc,percent_rank,period,permission,placing,plans,pli,portion,position,position_regex,power,precedes,preceding,precision,prepare,prepared,preserve,primary,prior,privileges,procedural,procedure,program,public,quote,range,rank,read,reads,real,reassign,recheck,recovery,recursive,ref,references,referencing,refresh,regr_avgx,regr_avgy,regr_count,regr_intercept,regr_r2,regr_slope,regr_sxx,regr_sxy,regr_syy,reindex,relative,release,rename,repeatable,replace,replica,requiring,reset,respect,restart,restore,restrict,result,return,returned_cardinality,returned_length,returned_octet_length,returned_sqlstate,returning,returns,revoke,right,role,rollback,rollup,routine,routine_catalog,routine_name,routine_schema,row,rows,row_count,row_number,rule,savepoint,scale,schema,schema_name,scope,scope_catalog,scope_name,scope_schema,scroll,search,second,section,security,select,selective,self,sensitive,sequence,sequences,serializable,server,server_name,session,session_user,set,setof,sets,share,show,similar,simple,size,smallint,snapshot,some,source,space,specific,specifictype,specific_name,sql,sqlcode,sqlerror,sqlexception,sqlstate,sqlwarning,sqrt,stable,standalone,start,state,statement,static,statistics,stddev_pop,stddev_samp,stdin,stdout,storage,strict,strip,structure,style,subclass_origin,submultiset,substring,substring_regex,succeeds,sum,symmetric,sysid,system,system_time,system_user,t,table,tables,tablesample,tablespace,table_name,temp,template,temporary,text,then,ties,time,timestamp,timezone_hour,timezone_minute,to,token,top_level_count,trailing,transaction,transactions_committed,transactions_rolled_back,transaction_active,transform,transforms,translate,translate_regex,translation,treat,trigger,trigger_catalog,trigger_name,trigger_schema,trim,trim_array,true,truncate,trusted,type,types,uescape,unbounded,uncommitted,under,unencrypted,union,unique,unknown,unlink,unlisten,unlogged,unnamed,unnest,until,untyped,update,upper,uri,usage,user,user_defined_type_catalog,user_defined_type_code,user_defined_type_name,user_defined_type_schema,using,vacuum,valid,validate,validator,value,values,value_of,varbinary,varchar,variadic,varying,var_pop,var_samp,verbose,version,versioning,view,views,volatile,when,whenever,where,whitespace,width_bucket,window,with,within,without,work,wrapper,write,xml,xmlagg,xmlattributes,xmlbinary,xmlcast,xmlcomment,xmlconcat,xmldeclaration,xmldocument,xmlelement,xmlexists,xmlforest,xmliterate,xmlnamespaces,xmlparse,xmlpi,xmlquery,xmlroot,xmlschema,xmlserialize,xmltable,xmltext,xmlvalidate,year,yes,zone +abs,all,allocate,alter,analyse,analyze,and,any,are,array,array_agg,array_max_cardinality,as,asc,asensitive,asymmetric,at,atomic,authorization,avg,begin,begin_frame,begin_partition,between,bigint,binary,blob,boolean,both,by,call,called,cardinality,cascaded,case,cast,ceil,ceiling,char,character,character_length,char_length,check,clob,close,coalesce,collate,collation,collect,column,commit,concurrently,condition,connect,constraint,contains,convert,corr,corresponding,count,covar_pop,covar_samp,create,cross,cube,cume_dist,current,current_catalog,current_date,current_default_transform_group,current_path,current_role,current_row,current_schema,current_time,current_timestamp,current_transform_group_for_type,current_user,cursor,cycle,datalink,date,day,deallocate,dec,decimal,declare,default,deferrable,delete,dense_rank,deref,desc,describe,deterministic,disconnect,distinct,dlnewcopy,dlpreviouscopy,dlurlcomplete,dlurlcompleteonly,dlurlcompletewrite,dlurlpath,dlurlpathonly,dlurlpathwrite,dlurlscheme,dlurlserver,dlvalue,do,double,drop,dynamic,each,element,else,end,end-exec,end_frame,end_partition,equals,escape,every,except,exec,execute,exists,exp,external,extract,false,fetch,filter,first_value,float,floor,for,foreign,frame_row,free,freeze,from,full,function,fusion,get,global,grant,group,grouping,groups,having,hold,hour,identity,ilike,import,in,indicator,initially,inner,inout,insensitive,insert,int,integer,intersect,intersection,interval,into,is,isnull,join,lag,language,large,last_value,lateral,lead,leading,left,like,like_regex,limit,ln,local,localtime,localtimestamp,lower,match,max,member,merge,method,min,minute,mod,modifies,module,month,multiset,national,natural,nchar,nclob,new,no,none,normalize,not,notnull,nth_value,ntile,null,nullif,numeric,occurrences_regex,octet_length,of,offset,old,on,only,open,or,order,out,outer,over,overlaps,overlay,parameter,partition,percent,percentile_cont,percentile_disc,percent_rank,period,placing,portion,position,position_regex,power,precedes,precision,prepare,primary,procedure,range,rank,reads,real,recursive,ref,references,referencing,regr_avgx,regr_avgy,regr_count,regr_intercept,regr_r2,regr_slope,regr_sxx,regr_sxy,regr_syy,release,result,return,returning,returns,revoke,right,rollback,rollup,row,rows,row_number,savepoint,scope,scroll,search,second,select,sensitive,session_user,set,similar,smallint,some,specific,specifictype,sql,sqlexception,sqlstate,sqlwarning,sqrt,start,static,stddev_pop,stddev_samp,submultiset,substring,substring_regex,succeeds,sum,symmetric,system,system_time,system_user,table,tablesample,then,time,timestamp,timezone_hour,timezone_minute,to,trailing,translate,translate_regex,translation,treat,trigger,trim,trim_array,true,truncate,uescape,union,unique,unknown,unnest,update,upper,user,using,value,values,value_of,varbinary,varchar,variadic,varying,var_pop,var_samp,verbose,versioning,when,whenever,where,width_bucket,window,with,within,without,xml,xmlagg,xmlattributes,xmlbinary,xmlcast,xmlcomment,xmlconcat,xmldocument,xmlelement,xmlexists,xmlforest,xmliterate,xmlnamespaces,xmlparse,xmlpi,xmlquery,xmlserialize,xmltable,xmltext,xmlvalidate,year diff --git a/jdbc/src/test/java/org/apache/zeppelin/jdbc/SqlCompleterTest.java b/jdbc/src/test/java/org/apache/zeppelin/jdbc/SqlCompleterTest.java index 1ec3ae4a41a..59eccae25d8 100644 --- a/jdbc/src/test/java/org/apache/zeppelin/jdbc/SqlCompleterTest.java +++ b/jdbc/src/test/java/org/apache/zeppelin/jdbc/SqlCompleterTest.java @@ -1,4 +1,4 @@ -/** +/* * Licensed to the Apache Software Foundation (ASF) under one or more contributor license * agreements. See the NOTICE file distributed with this work for additional information regarding * copyright ownership. The ASF licenses this file to you under the Apache License, Version 2.0 (the @@ -12,35 +12,26 @@ * or implied. See the License for the specific language governing permissions and limitations under * the License. */ -package org.apache.zeppelin.jdbc; - -import static com.google.common.collect.Sets.newHashSet; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertTrue; +package org.apache.zeppelin.jdbc; import com.google.common.base.Joiner; - import org.apache.commons.lang.StringUtils; -import org.junit.Assert; +import org.apache.zeppelin.completer.CompletionType; +import org.apache.zeppelin.interpreter.thrift.InterpreterCompletion; import org.junit.Before; import org.junit.Test; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import java.io.IOException; -import java.sql.SQLException; import java.util.ArrayList; -import java.util.HashMap; import java.util.HashSet; import java.util.List; -import java.util.Map; import java.util.Set; -import jline.console.completer.ArgumentCompleter; - -import org.apache.zeppelin.completer.CompletionType; -import org.apache.zeppelin.interpreter.thrift.InterpreterCompletion; +import static com.google.common.collect.Sets.newHashSet; +import static junit.framework.TestCase.assertTrue; +import static org.junit.Assert.assertEquals; /** * SQL completer unit tests. @@ -54,51 +45,51 @@ public class CompleterTester { private int toCursor; private Set expectedCompletions; - public CompleterTester(SqlCompleter completer) { + CompleterTester(SqlCompleter completer) { this.completer = completer; } - public CompleterTester buffer(String buffer) { + CompleterTester buffer(String buffer) { this.buffer = buffer; return this; } - public CompleterTester from(int fromCursor) { + CompleterTester from(int fromCursor) { this.fromCursor = fromCursor; return this; } - public CompleterTester to(int toCursor) { + CompleterTester to(int toCursor) { this.toCursor = toCursor; return this; } - public CompleterTester expect(Set expectedCompletions) { + CompleterTester expect(Set expectedCompletions) { this.expectedCompletions = expectedCompletions; return this; } - public void test() { + void test() { for (int c = fromCursor; c <= toCursor; c++) { expectedCompletions(buffer, c, expectedCompletions); } } private void expectedCompletions(String buffer, int cursor, - Set expected) { + Set expected) { if (StringUtils.isNotEmpty(buffer) && buffer.length() > cursor) { buffer = buffer.substring(0, cursor); } List candidates = new ArrayList<>(); - completer.complete(buffer, cursor, candidates); + completer.fillCandidates(buffer, cursor, candidates); String explain = explain(buffer, cursor, candidates); logger.info(explain); - Assert.assertEquals("Buffer [" + buffer.replace(" ", ".") + "] and Cursor[" + cursor + "] " + assertEquals("Buffer [" + buffer.replace(" ", ".") + "] and Cursor[" + cursor + "] " + explain, expected, newHashSet(candidates)); } @@ -134,17 +125,14 @@ private String explain(String buffer, int cursor, List ca private Logger logger = LoggerFactory.getLogger(SqlCompleterTest.class); - private static final Set EMPTY = new HashSet<>(); + //private static final Set EMPTY = new HashSet<>(); private CompleterTester tester; - private ArgumentCompleter.WhitespaceArgumentDelimiter delimiter = - new ArgumentCompleter.WhitespaceArgumentDelimiter(); - private SqlCompleter sqlCompleter = new SqlCompleter(0); @Before - public void beforeTest() throws IOException, SQLException { + public void beforeTest() { Set schemas = new HashSet<>(); Set keywords = new HashSet<>(); @@ -177,6 +165,7 @@ public void beforeTest() throws IOException, SQLException { Set prodDdsFinancialAccountColumns = new HashSet<>(); prodDdsFinancialAccountColumns.add("account_rk"); prodDdsFinancialAccountColumns.add("account_id"); + prodDdsFinancialAccountColumns.add("open_dt"); sqlCompleter.initColumns("prod_dds.financial_account", prodDdsFinancialAccountColumns); @@ -197,30 +186,64 @@ public void beforeTest() throws IOException, SQLException { } @Test - public void testFindAliasesInSQL_Simple() { - String sql = "select * from prod_emart.financial_account a"; - Map res = sqlCompleter.findAliasesInSQL( - delimiter.delimit(sql, 0).getArguments()); - assertEquals(1, res.size()); - assertTrue(res.get("a").equals("prod_emart.financial_account")); + public void testSchemaAndTable() { + String buffer = "select * from prod_emart.fi"; + tester.buffer(buffer).from(20).to(23).expect(newHashSet( + new InterpreterCompletion("prod_emart", "prod_emart", + CompletionType.schema.name()))).test(); + tester.buffer(buffer).from(25).to(27).expect(newHashSet( + new InterpreterCompletion("financial_account", "financial_account", + CompletionType.table.name()))).test(); + } + + @Test + public void testEdges() { + String buffer = " ORDER "; + tester.buffer(buffer).from(3).to(7).expect(newHashSet( + new InterpreterCompletion("ORDER", "ORDER", CompletionType.keyword.name()))).test(); + tester.buffer(buffer).from(0).to(1).expect(newHashSet( + new InterpreterCompletion("ORDER", "ORDER", CompletionType.keyword.name()), + new InterpreterCompletion("SUBCLASS_ORIGIN", "SUBCLASS_ORIGIN", + CompletionType.keyword.name()), + new InterpreterCompletion("SUBSTRING", "SUBSTRING", CompletionType.keyword.name()), + new InterpreterCompletion("prod_emart", "prod_emart", CompletionType.schema.name()), + new InterpreterCompletion("LIMIT", "LIMIT", CompletionType.keyword.name()), + new InterpreterCompletion("SUM", "SUM", CompletionType.keyword.name()), + new InterpreterCompletion("prod_dds", "prod_dds", CompletionType.schema.name()), + new InterpreterCompletion("SELECT", "SELECT", CompletionType.keyword.name()), + new InterpreterCompletion("FROM", "FROM", CompletionType.keyword.name()) + )).test(); + } + + @Test + public void testMultipleWords() { + String buffer = "SELE FRO LIM"; + tester.buffer(buffer).from(2).to(4).expect(newHashSet( + new InterpreterCompletion("SELECT", "SELECT", CompletionType.keyword.name()))).test(); + tester.buffer(buffer).from(6).to(8).expect(newHashSet( + new InterpreterCompletion("FROM", "FROM", CompletionType.keyword.name()))).test(); + tester.buffer(buffer).from(10).to(12).expect(newHashSet( + new InterpreterCompletion("LIMIT", "LIMIT", CompletionType.keyword.name()))).test(); } @Test - public void testFindAliasesInSQL_Two() { - String sql = "select * from prod_dds.financial_account a, prod_dds.customer b"; - Map res = sqlCompleter.findAliasesInSQL( - sqlCompleter.getSqlDelimiter().delimit(sql, 0).getArguments()); - assertEquals(2, res.size()); - assertTrue(res.get("a").equals("prod_dds.financial_account")); - assertTrue(res.get("b").equals("prod_dds.customer")); + public void testMultiLineBuffer() { + String buffer = " \n SELE\nFRO"; + tester.buffer(buffer).from(5).to(7).expect(newHashSet( + new InterpreterCompletion("SELECT", "SELECT", CompletionType.keyword.name()))).test(); + tester.buffer(buffer).from(9).to(11).expect(newHashSet( + new InterpreterCompletion("FROM", "FROM", CompletionType.keyword.name()))).test(); } @Test - public void testFindAliasesInSQL_WrongTables() { - String sql = "select * from prod_ddsxx.financial_account a, prod_dds.customerxx b"; - Map res = sqlCompleter.findAliasesInSQL( - sqlCompleter.getSqlDelimiter().delimit(sql, 0).getArguments()); - assertEquals(0, res.size()); + public void testMultipleCompletionSuggestions() { + String buffer = "SU"; + tester.buffer(buffer).from(2).to(2).expect(newHashSet( + new InterpreterCompletion("SUBCLASS_ORIGIN", "SUBCLASS_ORIGIN", + CompletionType.keyword.name()), + new InterpreterCompletion("SUM", "SUM", CompletionType.keyword.name()), + new InterpreterCompletion("SUBSTRING", "SUBSTRING", CompletionType.keyword.name())) + ).test(); } @Test @@ -228,41 +251,39 @@ public void testCompleteName_Empty() { String buffer = ""; int cursor = 0; List candidates = new ArrayList<>(); - Map aliases = new HashMap<>(); - sqlCompleter.completeName(buffer, cursor, candidates, aliases); + sqlCompleter.fillCandidates(buffer, cursor, candidates); assertEquals(9, candidates.size()); assertTrue(candidates.contains(new InterpreterCompletion("prod_dds", "prod_dds", - CompletionType.schema.name()))); + CompletionType.schema.name()))); assertTrue(candidates.contains(new InterpreterCompletion("prod_emart", "prod_emart", - CompletionType.schema.name()))); + CompletionType.schema.name()))); assertTrue(candidates.contains(new InterpreterCompletion("SUM", "SUM", - CompletionType.keyword.name()))); + CompletionType.keyword.name()))); assertTrue(candidates.contains(new InterpreterCompletion("SUBSTRING", "SUBSTRING", - CompletionType.keyword.name()))); + CompletionType.keyword.name()))); assertTrue(candidates.contains(new InterpreterCompletion("SUBCLASS_ORIGIN", "SUBCLASS_ORIGIN", - CompletionType.keyword.name()))); + CompletionType.keyword.name()))); assertTrue(candidates.contains(new InterpreterCompletion("SELECT", "SELECT", - CompletionType.keyword.name()))); + CompletionType.keyword.name()))); assertTrue(candidates.contains(new InterpreterCompletion("ORDER", "ORDER", - CompletionType.keyword.name()))); + CompletionType.keyword.name()))); assertTrue(candidates.contains(new InterpreterCompletion("LIMIT", "LIMIT", - CompletionType.keyword.name()))); + CompletionType.keyword.name()))); assertTrue(candidates.contains(new InterpreterCompletion("FROM", "FROM", - CompletionType.keyword.name()))); + CompletionType.keyword.name()))); } @Test public void testCompleteName_SimpleSchema() { - String buffer = "prod_"; - int cursor = 3; + String buffer = "select * from prod_ "; + int cursor = 19; List candidates = new ArrayList<>(); - Map aliases = new HashMap<>(); - sqlCompleter.completeName(buffer, cursor, candidates, aliases); + sqlCompleter.fillCandidates(buffer, cursor, candidates); assertEquals(2, candidates.size()); assertTrue(candidates.contains(new InterpreterCompletion("prod_dds", "prod_dds", - CompletionType.schema.name()))); + CompletionType.schema.name()))); assertTrue(candidates.contains(new InterpreterCompletion("prod_emart", "prod_emart", - CompletionType.schema.name()))); + CompletionType.schema.name()))); } @Test @@ -270,12 +291,11 @@ public void testCompleteName_SimpleTable() { String buffer = "prod_dds.fin"; int cursor = 11; List candidates = new ArrayList<>(); - Map aliases = new HashMap<>(); - sqlCompleter.completeName(buffer, cursor, candidates, aliases); + sqlCompleter.fillCandidates(buffer, cursor, candidates); assertEquals(1, candidates.size()); assertTrue(candidates.contains( - new InterpreterCompletion("financial_account", "financial_account", - CompletionType.table.name()))); + new InterpreterCompletion("financial_account", "financial_account", + CompletionType.table.name()))); } @Test @@ -283,111 +303,39 @@ public void testCompleteName_SimpleColumn() { String buffer = "prod_dds.financial_account.acc"; int cursor = 30; List candidates = new ArrayList<>(); - Map aliases = new HashMap<>(); - sqlCompleter.completeName(buffer, cursor, candidates, aliases); + sqlCompleter.fillCandidates(buffer, cursor, candidates); assertEquals(2, candidates.size()); assertTrue(candidates.contains(new InterpreterCompletion("account_rk", "account_rk", - CompletionType.column.name()))); + CompletionType.column.name()))); assertTrue(candidates.contains(new InterpreterCompletion("account_id", "account_id", - CompletionType.column.name()))); + CompletionType.column.name()))); } @Test public void testCompleteName_WithAlias() { - String buffer = "a.acc"; + String buffer = "a.acc from prod_dds.financial_account a"; int cursor = 4; List candidates = new ArrayList<>(); - Map aliases = new HashMap<>(); - aliases.put("a", "prod_dds.financial_account"); - sqlCompleter.completeName(buffer, cursor, candidates, aliases); + sqlCompleter.fillCandidates(buffer, cursor, candidates); assertEquals(2, candidates.size()); assertTrue(candidates.contains(new InterpreterCompletion("account_rk", "account_rk", - CompletionType.column.name()))); + CompletionType.column.name()))); assertTrue(candidates.contains(new InterpreterCompletion("account_id", "account_id", - CompletionType.column.name()))); + CompletionType.column.name()))); } @Test public void testCompleteName_WithAliasAndPoint() { - String buffer = "a."; + String buffer = "a. from prod_dds.financial_account a"; int cursor = 2; List candidates = new ArrayList<>(); - Map aliases = new HashMap<>(); - aliases.put("a", "prod_dds.financial_account"); - sqlCompleter.completeName(buffer, cursor, candidates, aliases); - assertEquals(2, candidates.size()); + sqlCompleter.fillCandidates(buffer, cursor, candidates); + assertEquals(3, candidates.size()); assertTrue(candidates.contains(new InterpreterCompletion("account_rk", "account_rk", - CompletionType.column.name()))); + CompletionType.column.name()))); assertTrue(candidates.contains(new InterpreterCompletion("account_id", "account_id", - CompletionType.column.name()))); - } - - @Test - public void testSchemaAndTable() { - String buffer = "select * from prod_emart.fi"; - tester.buffer(buffer).from(20).to(23).expect(newHashSet( - new InterpreterCompletion("prod_emart", "prod_emart", - CompletionType.schema.name()))).test(); - tester.buffer(buffer).from(25).to(27).expect(newHashSet( - new InterpreterCompletion("financial_account", "financial_account", - CompletionType.table.name()))).test(); - } - - @Test - public void testEdges() { - String buffer = " ORDER "; - tester.buffer(buffer).from(3).to(7).expect(newHashSet( - new InterpreterCompletion("ORDER", "ORDER", CompletionType.keyword.name()))).test(); - tester.buffer(buffer).from(0).to(1).expect(newHashSet( - new InterpreterCompletion("ORDER", "ORDER", CompletionType.keyword.name()), - new InterpreterCompletion("SUBCLASS_ORIGIN", "SUBCLASS_ORIGIN", - CompletionType.keyword.name()), - new InterpreterCompletion("SUBSTRING", "SUBSTRING", CompletionType.keyword.name()), - new InterpreterCompletion("prod_emart", "prod_emart", CompletionType.schema.name()), - new InterpreterCompletion("LIMIT", "LIMIT", CompletionType.keyword.name()), - new InterpreterCompletion("SUM", "SUM", CompletionType.keyword.name()), - new InterpreterCompletion("prod_dds", "prod_dds", CompletionType.schema.name()), - new InterpreterCompletion("SELECT", "SELECT", CompletionType.keyword.name()), - new InterpreterCompletion("FROM", "FROM", CompletionType.keyword.name()) - )).test(); - } - - @Test - public void testMultipleWords() { - String buffer = "SELE FRO LIM"; - tester.buffer(buffer).from(2).to(4).expect(newHashSet( - new InterpreterCompletion("SELECT", "SELECT", CompletionType.keyword.name()))).test(); - tester.buffer(buffer).from(6).to(8).expect(newHashSet( - new InterpreterCompletion("FROM", "FROM", CompletionType.keyword.name()))).test(); - tester.buffer(buffer).from(10).to(12).expect(newHashSet( - new InterpreterCompletion("LIMIT", "LIMIT", CompletionType.keyword.name()))).test(); - } - - @Test - public void testMultiLineBuffer() { - String buffer = " \n SELE\nFRO"; - tester.buffer(buffer).from(5).to(7).expect(newHashSet( - new InterpreterCompletion("SELECT", "SELECT", CompletionType.keyword.name()))).test(); - tester.buffer(buffer).from(9).to(11).expect(newHashSet( - new InterpreterCompletion("FROM", "FROM", CompletionType.keyword.name()))).test(); - } - - @Test - public void testMultipleCompletionSuggestions() { - String buffer = "SU"; - tester.buffer(buffer).from(2).to(2).expect(newHashSet( - new InterpreterCompletion("SUBCLASS_ORIGIN", "SUBCLASS_ORIGIN", - CompletionType.keyword.name()), - new InterpreterCompletion("SUM", "SUM", CompletionType.keyword.name()), - new InterpreterCompletion("SUBSTRING", "SUBSTRING", CompletionType.keyword.name())) - ).test(); - } - - @Test - public void testSqlDelimiterCharacters() { - assertTrue(sqlCompleter.getSqlDelimiter().isDelimiterChar("r,", 1)); - assertTrue(sqlCompleter.getSqlDelimiter().isDelimiterChar("SS,", 2)); - assertTrue(sqlCompleter.getSqlDelimiter().isDelimiterChar(",", 0)); - assertTrue(sqlCompleter.getSqlDelimiter().isDelimiterChar("ttt,", 3)); + CompletionType.column.name()))); + assertTrue(candidates.contains(new InterpreterCompletion("open_dt", "open_dt", + CompletionType.column.name()))); } } diff --git a/jdbc/src/test/java/org/apache/zeppelin/jdbc/SqlStatementTest.java b/jdbc/src/test/java/org/apache/zeppelin/jdbc/SqlStatementTest.java new file mode 100644 index 00000000000..d3305c99cb7 --- /dev/null +++ b/jdbc/src/test/java/org/apache/zeppelin/jdbc/SqlStatementTest.java @@ -0,0 +1,165 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more contributor license + * agreements. See the NOTICE file distributed with this work for additional information regarding + * copyright ownership. The ASF licenses this file to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance with the License. You may obtain a + * copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + */ + +package org.apache.zeppelin.jdbc; + +import org.junit.Before; +import org.junit.Test; + +import java.util.Collection; +import java.util.TreeSet; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertTrue; + +public class SqlStatementTest { + private String defaultSchema = "prod_dds"; + private Collection schemas = new TreeSet<>(); + private Collection tablesInDefaultSchema = new TreeSet<>(); + private Collection keywords = new TreeSet<>(); + + @Before + public void beforeTest() { + schemas.add("prod_dds"); + schemas.add("prod_emart"); + tablesInDefaultSchema.add("customer"); + tablesInDefaultSchema.add("account"); + keywords.add("select"); + keywords.add("where"); + keywords.add("order"); + keywords.add("limit"); + keywords.add("inner"); + keywords.add("left"); + keywords.add("join"); + } + + @Test + public void testSqlDelimiterCharacters() { + String statement = "select * from table;"; + int cursor = 7; + + SqlStatement sqlStatement = new SqlStatement(statement, cursor, defaultSchema, + schemas, tablesInDefaultSchema, keywords); + + assertTrue(sqlStatement.getSqlDelimiter().isDelimiterChar("r,", 1)); + assertTrue(sqlStatement.getSqlDelimiter().isDelimiterChar("SS,", 2)); + assertTrue(sqlStatement.getSqlDelimiter().isDelimiterChar(",", 0)); + assertTrue(sqlStatement.getSqlDelimiter().isDelimiterChar("ttt,", 3)); + } + + @Test + public void testSqlStatementActiveSchemaTables() { + String statement = "select from account acc left join prod_emart.application as app on ;"; + int cursor = 7; + + SqlStatement sqlStatementActiveSchemaTables = new SqlStatement(statement, cursor, defaultSchema, + schemas, tablesInDefaultSchema, keywords); + + assertTrue(sqlStatementActiveSchemaTables.getActiveSchemaTables().contains("prod_dds.account")); + assertTrue(sqlStatementActiveSchemaTables.getActiveSchemaTables() + .contains("prod_emart.application")); + assertEquals(0, sqlStatementActiveSchemaTables.getCursorPosition()); + } + + @Test + public void testSqlStatementAliasForTableInDefaultSchema() { + String statement = "select acc.z from account acc left join prod_emart.application app on ;"; + int cursor = 12; + + SqlStatement sqlStatementAliasForTableInDefaultSchema = new SqlStatement(statement, cursor, + defaultSchema, schemas, tablesInDefaultSchema, keywords); + + assertEquals(defaultSchema, sqlStatementAliasForTableInDefaultSchema.getSchema()); + assertEquals("account", sqlStatementAliasForTableInDefaultSchema.getTable()); + assertEquals("z", sqlStatementAliasForTableInDefaultSchema.getColumn()); + assertEquals(1, sqlStatementAliasForTableInDefaultSchema.getCursorPosition()); + } + + @Test + public void testSqlStatementAliasSchemaTable() { + String statement = "select app.y from account acc left join prod_emart.application app on ;"; + int cursor = 12; + + SqlStatement sqlStatementAliasSchemaTable = new SqlStatement(statement, cursor, defaultSchema, + schemas, tablesInDefaultSchema, keywords); + + assertEquals("prod_emart", sqlStatementAliasSchemaTable.getSchema()); + assertEquals("application", sqlStatementAliasSchemaTable.getTable()); + assertEquals("y", sqlStatementAliasSchemaTable.getColumn()); + assertEquals(1, sqlStatementAliasSchemaTable.getCursorPosition()); + } + + @Test + public void testSqlStatementSchemaTableColumn() { + String statement = "select prod_emart.application.yy from account acc" + + "left join prod_emart.application app on ;"; + int cursor = 32; + + SqlStatement sqlStatementSchemaTableColumn = new SqlStatement(statement, cursor, defaultSchema, + schemas, tablesInDefaultSchema, keywords); + + assertEquals("prod_emart", sqlStatementSchemaTableColumn.getSchema()); + assertEquals("application", sqlStatementSchemaTableColumn.getTable()); + assertEquals("yy", sqlStatementSchemaTableColumn.getColumn()); + assertEquals(2, sqlStatementSchemaTableColumn.getCursorPosition()); + } + + @Test + public void testSqlStatementSchemaTable() { + String statement = "select \n from prod_emart.ap \nlimit 300"; + int cursor = 27; + + SqlStatement sqlStatementSchemaTable = new SqlStatement(statement, cursor, defaultSchema, + schemas, tablesInDefaultSchema, keywords); + + assertEquals("prod_emart", sqlStatementSchemaTable.getSchema()); + assertEquals("ap", sqlStatementSchemaTable.getTable()); + assertEquals(2, sqlStatementSchemaTable.getCursorPosition()); + assertTrue(sqlStatementSchemaTable.needLoadTables()); + assertNull(sqlStatementSchemaTable.getColumn()); + assertNull(sqlStatementSchemaTable.getCursorString()); + } + + @Test + public void testSqlStatementSchema() { + String statement = "select \n from prod_emart. ;"; + int cursor = 25; + + SqlStatement sqlStatementSchema = new SqlStatement(statement, cursor, defaultSchema, + schemas, tablesInDefaultSchema, keywords); + + assertEquals("prod_emart", sqlStatementSchema.getSchema()); + assertEquals("", sqlStatementSchema.getTable()); + assertEquals(0, sqlStatementSchema.getCursorPosition()); + assertNull(sqlStatementSchema.getColumn()); + assertNull(sqlStatementSchema.getCursorString()); + } + + @Test + public void testSqlStatementCustom() { + String statement = "select pro from account acc left join prod_emart.application app on ;"; + int cursor = 10; + + SqlStatement sqlStatementCustom = new SqlStatement(statement, cursor, defaultSchema, + schemas, tablesInDefaultSchema, keywords); + + assertEquals("pro", sqlStatementCustom.getCursorString()); + assertNull(sqlStatementCustom.getSchema()); + assertNull(sqlStatementCustom.getTable()); + assertNull(sqlStatementCustom.getColumn()); + assertEquals(3, sqlStatementCustom.getCursorPosition()); + } +} diff --git a/zeppelin-interpreter/src/main/java/org/apache/zeppelin/completer/CachedCompleter.java b/zeppelin-interpreter/src/main/java/org/apache/zeppelin/completer/CachedCompleter.java index ef2223eb2cf..c25604b3f0c 100644 --- a/zeppelin-interpreter/src/main/java/org/apache/zeppelin/completer/CachedCompleter.java +++ b/zeppelin-interpreter/src/main/java/org/apache/zeppelin/completer/CachedCompleter.java @@ -1,4 +1,4 @@ -/** +/* * Licensed to the Apache Software Foundation (ASF) under one or more contributor license * agreements. See the NOTICE file distributed with this work for additional information regarding * copyright ownership. The ASF licenses this file to you under the Apache License, Version 2.0 (the @@ -17,14 +17,14 @@ import jline.console.completer.Completer; /** - * Completer with time to live + * Completer with time to live. */ -public class CachedCompleter { - private Completer completer; +public class CachedCompleter { + private T completer; private int ttlInSeconds; private long createdAt; - public CachedCompleter(Completer completer, int ttlInSeconds) { + public CachedCompleter(T completer, int ttlInSeconds) { this.completer = completer; this.ttlInSeconds = ttlInSeconds; this.createdAt = System.currentTimeMillis(); @@ -38,7 +38,7 @@ public boolean isExpired() { return false; } - public Completer getCompleter() { + public T getCompleter() { return completer; } }