diff --git a/docs/interpreter/jdbc.md b/docs/interpreter/jdbc.md index 32adcba94be..346fcbb2f3e 100644 --- a/docs/interpreter/jdbc.md +++ b/docs/interpreter/jdbc.md @@ -167,6 +167,10 @@ There are more JDBC interpreter properties you can specify like below. default.jceks.credentialKey jceks credential key + + zeppelin.jdbc.precode + Some SQL which executes while opening connection + You can also add more properties by using this [method](http://docs.oracle.com/javase/7/docs/api/java/sql/DriverManager.html#getConnection%28java.lang.String,%20java.util.Properties%29). diff --git a/jdbc/src/main/java/org/apache/zeppelin/jdbc/JDBCInterpreter.java b/jdbc/src/main/java/org/apache/zeppelin/jdbc/JDBCInterpreter.java index c43e3920ea5..d4952246c91 100644 --- a/jdbc/src/main/java/org/apache/zeppelin/jdbc/JDBCInterpreter.java +++ b/jdbc/src/main/java/org/apache/zeppelin/jdbc/JDBCInterpreter.java @@ -14,14 +14,7 @@ */ package org.apache.zeppelin.jdbc; -import static org.apache.commons.lang.StringUtils.containsIgnoreCase; -import static org.apache.commons.lang.StringUtils.isEmpty; -import static org.apache.commons.lang.StringUtils.isNotEmpty; -import static org.apache.hadoop.security.UserGroupInformation.AuthenticationMethod.KERBEROS; -import java.io.ByteArrayOutputStream; -import java.io.PrintStream; import java.io.IOException; -import java.nio.charset.StandardCharsets; import java.security.PrivilegedExceptionAction; import java.sql.Connection; import java.sql.DriverManager; @@ -37,11 +30,11 @@ import java.util.Properties; import java.util.Set; -import com.google.common.base.Throwables; import org.apache.commons.dbcp2.ConnectionFactory; import org.apache.commons.dbcp2.DriverManagerConnectionFactory; import org.apache.commons.dbcp2.PoolableConnectionFactory; import org.apache.commons.dbcp2.PoolingDriver; +import org.apache.commons.lang.StringUtils; import org.apache.commons.pool2.ObjectPool; import org.apache.commons.pool2.impl.GenericObjectPool; import org.apache.hadoop.conf.Configuration; @@ -49,7 +42,10 @@ import org.apache.hadoop.security.alias.CredentialProvider; import org.apache.hadoop.security.alias.CredentialProviderFactory; import org.apache.thrift.transport.TTransportException; -import org.apache.zeppelin.interpreter.*; +import org.apache.zeppelin.interpreter.Interpreter; +import org.apache.zeppelin.interpreter.InterpreterContext; +import org.apache.zeppelin.interpreter.InterpreterException; +import org.apache.zeppelin.interpreter.InterpreterResult; import org.apache.zeppelin.interpreter.InterpreterResult.Code; import org.apache.zeppelin.interpreter.thrift.InterpreterCompletion; import org.apache.zeppelin.jdbc.security.JDBCSecurityImpl; @@ -61,9 +57,13 @@ import org.slf4j.LoggerFactory; import com.google.common.base.Function; +import com.google.common.base.Throwables; import com.google.common.collect.Lists; -import com.google.common.collect.Sets; -import com.google.common.collect.Sets.SetView; + +import static org.apache.commons.lang.StringUtils.containsIgnoreCase; +import static org.apache.commons.lang.StringUtils.isEmpty; +import static org.apache.commons.lang.StringUtils.isNotEmpty; +import static org.apache.hadoop.security.UserGroupInformation.AuthenticationMethod.KERBEROS; /** * JDBC interpreter for Zeppelin. This interpreter can also be used for accessing HAWQ, @@ -103,6 +103,7 @@ public class JDBCInterpreter extends Interpreter { static final String PASSWORD_KEY = "password"; static final String JDBC_JCEKS_FILE = "jceks.file"; static final String JDBC_JCEKS_CREDENTIAL_KEY = "jceks.credentialKey"; + static final String ZEPPELIN_JDBC_PRECODE_KEY = "zeppelin.jdbc.precode"; static final String DOT = "."; private static final char WHITESPACE = ' '; @@ -340,6 +341,9 @@ private Connection getConnectionFromPool(String url, String user, String propert if (!getJDBCConfiguration(user).isConnectionInDBDriverPool(propertyKey)) { createConnectionPool(url, user, propertyKey, properties); + try (Connection connection = DriverManager.getConnection(jdbcDriver)) { + executePrecode(connection); + } } return DriverManager.getConnection(jdbcDriver); } @@ -540,6 +544,20 @@ protected ArrayList splitSqlQueries(String sql) { return queries; } + private void executePrecode(Connection connection) throws SQLException { + String precode = getProperty(ZEPPELIN_JDBC_PRECODE_KEY); + if (StringUtils.isNotBlank(precode)) { + precode = StringUtils.trim(precode); + logger.info("Run SQL precode '{}'", precode); + try (Statement statement = connection.createStatement()) { + statement.execute(precode); + if (!connection.getAutoCommit()) { + connection.commit(); + } + } + } + } + private InterpreterResult executeSql(String propertyKey, String sql, InterpreterContext interpreterContext) { Connection connection; @@ -761,4 +779,3 @@ int getMaxConcurrentConnection() { } } } - diff --git a/jdbc/src/main/resources/interpreter-setting.json b/jdbc/src/main/resources/interpreter-setting.json index 20a900f9d79..6134243502e 100644 --- a/jdbc/src/main/resources/interpreter-setting.json +++ b/jdbc/src/main/resources/interpreter-setting.json @@ -63,6 +63,12 @@ "propertyName": "zeppelin.jdbc.principal", "defaultValue": "", "description": "Kerberos principal" + }, + "zeppelin.jdbc.precode": { + "envName": null, + "propertyName": "zeppelin.jdbc.precode", + "defaultValue": "", + "description": "SQL which executes while opening connection" } }, "editor": { diff --git a/jdbc/src/test/java/org/apache/zeppelin/jdbc/JDBCInterpreterTest.java b/jdbc/src/test/java/org/apache/zeppelin/jdbc/JDBCInterpreterTest.java index 9a041f923fd..197c368154c 100644 --- a/jdbc/src/test/java/org/apache/zeppelin/jdbc/JDBCInterpreterTest.java +++ b/jdbc/src/test/java/org/apache/zeppelin/jdbc/JDBCInterpreterTest.java @@ -43,6 +43,9 @@ import org.junit.Test; import com.mockrunner.jdbc.BasicJDBCTestCaseAdapter; + +import static org.apache.zeppelin.jdbc.JDBCInterpreter.ZEPPELIN_JDBC_PRECODE_KEY; + /** * JDBC interpreter unit tests */ @@ -386,4 +389,43 @@ public void testMultiTenant() throws SQLException, IOException { assertNull(user2JDBC2Conf.getPropertyMap("default").get("password")); jdbc2.close(); } + + @Test + public void testPrecode() throws SQLException, IOException { + Properties properties = new Properties(); + properties.setProperty("default.driver", "org.h2.Driver"); + properties.setProperty("default.url", getJdbcConnection()); + properties.setProperty("default.user", ""); + properties.setProperty("default.password", ""); + properties.setProperty(ZEPPELIN_JDBC_PRECODE_KEY, "SET @testVariable=1"); + JDBCInterpreter jdbcInterpreter = new JDBCInterpreter(properties); + jdbcInterpreter.open(); + + String sqlQuery = "select @testVariable"; + + InterpreterResult interpreterResult = jdbcInterpreter.interpret(sqlQuery, interpreterContext); + + assertEquals(InterpreterResult.Code.SUCCESS, interpreterResult.code()); + assertEquals(InterpreterResult.Type.TABLE, interpreterResult.message().get(0).getType()); + assertEquals("@TESTVARIABLE\n1\n", interpreterResult.message().get(0).getData()); + } + + @Test + public void testIncorrectPrecode() throws SQLException, IOException { + Properties properties = new Properties(); + properties.setProperty("default.driver", "org.h2.Driver"); + properties.setProperty("default.url", getJdbcConnection()); + properties.setProperty("default.user", ""); + properties.setProperty("default.password", ""); + properties.setProperty(ZEPPELIN_JDBC_PRECODE_KEY, "incorrect command"); + JDBCInterpreter jdbcInterpreter = new JDBCInterpreter(properties); + jdbcInterpreter.open(); + + String sqlQuery = "select 1"; + + InterpreterResult interpreterResult = jdbcInterpreter.interpret(sqlQuery, interpreterContext); + + assertEquals(InterpreterResult.Code.ERROR, interpreterResult.code()); + assertEquals(InterpreterResult.Type.TEXT, interpreterResult.message().get(0).getType()); + } }