diff --git a/docs/interpreter/spark.md b/docs/interpreter/spark.md index a19eda22212..f5febda275e 100644 --- a/docs/interpreter/spark.md +++ b/docs/interpreter/spark.md @@ -115,6 +115,11 @@ You can also set other Spark properties which are not listed in the table. For a Python binary executable to use for PySpark in driver only (default is PYSPARK_PYTHON). Property spark.pyspark.driver.python take precedence if it is set + + zeppelin.pyspark.precode + + Snippet of code that is executed when the interpreter initializes. Variables, methods, classes, etc. defined in the snippet are available in your paragraphs. + zeppelin.spark.concurrentSQL false diff --git a/spark/src/main/java/org/apache/zeppelin/spark/PySparkInterpreter.java b/spark/src/main/java/org/apache/zeppelin/spark/PySparkInterpreter.java index db52a5342a7..29e16c36963 100644 --- a/spark/src/main/java/org/apache/zeppelin/spark/PySparkInterpreter.java +++ b/spark/src/main/java/org/apache/zeppelin/spark/PySparkInterpreter.java @@ -20,7 +20,6 @@ import java.io.BufferedWriter; import java.io.ByteArrayOutputStream; import java.io.File; -import java.io.FileOutputStream; import java.io.IOException; import java.io.OutputStreamWriter; import java.io.PipedInputStream; @@ -34,7 +33,6 @@ import java.util.Map; import java.util.Properties; -import org.apache.commons.compress.utils.IOUtils; import org.apache.commons.exec.CommandLine; import org.apache.commons.exec.DefaultExecutor; import org.apache.commons.exec.ExecuteException; @@ -42,12 +40,22 @@ import org.apache.commons.exec.ExecuteWatchdog; import org.apache.commons.exec.PumpStreamHandler; import org.apache.commons.exec.environment.EnvironmentUtils; +import org.apache.commons.io.FileUtils; +import org.apache.commons.io.IOUtils; +import org.apache.commons.lang.StringUtils; import org.apache.spark.SparkConf; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.sql.SQLContext; -import org.apache.zeppelin.interpreter.*; +import 
org.apache.zeppelin.interpreter.Interpreter; +import org.apache.zeppelin.interpreter.InterpreterContext; +import org.apache.zeppelin.interpreter.InterpreterException; +import org.apache.zeppelin.interpreter.InterpreterGroup; import org.apache.zeppelin.interpreter.InterpreterHookRegistry.HookType; +import org.apache.zeppelin.interpreter.InterpreterResult; import org.apache.zeppelin.interpreter.InterpreterResult.Code; +import org.apache.zeppelin.interpreter.InterpreterResultMessage; +import org.apache.zeppelin.interpreter.LazyOpenInterpreter; +import org.apache.zeppelin.interpreter.WrappedInterpreter; import org.apache.zeppelin.interpreter.thrift.InterpreterCompletion; import org.apache.zeppelin.interpreter.util.InterpreterOutputStream; import org.apache.zeppelin.spark.dep.SparkDependencyContext; @@ -55,14 +63,16 @@ import org.slf4j.LoggerFactory; import com.google.gson.Gson; - import py4j.GatewayServer; /** * */ public class PySparkInterpreter extends Interpreter implements ExecuteResultHandler { - Logger logger = LoggerFactory.getLogger(PySparkInterpreter.class); + + public static final String ZEPPELIN_PYSPARK_PRECODE_KEY = "zeppelin.pyspark.precode"; + public static Logger logger = LoggerFactory.getLogger(PySparkInterpreter.class); + private GatewayServer gatewayServer; private DefaultExecutor executor; private int port; @@ -96,11 +106,13 @@ private void createPythonScript() { } try { - FileOutputStream outStream = new FileOutputStream(out); - IOUtils.copy( - classLoader.getResourceAsStream("python/zeppelin_pyspark.py"), - outStream); - outStream.close(); + String pythonScript = IOUtils.toString( + classLoader.getResourceAsStream("python/zeppelin_pyspark.py"), "UTF-8"); + String precode = getProperty(ZEPPELIN_PYSPARK_PRECODE_KEY); + if (StringUtils.isNotBlank(precode)) { + pythonScript = pythonScript.replace("#precode#", StringUtils.trim(precode)); + } + FileUtils.writeStringToFile(out, pythonScript, "UTF-8"); } catch (IOException e) { throw new 
InterpreterException(e); } @@ -524,8 +536,7 @@ private String getCompletionTargetString(String text, int cursor) { String completionScriptText = ""; try { completionScriptText = text.substring(0, cursor); - } - catch (Exception e) { + } catch (Exception e) { logger.error(e.toString()); return null; } @@ -544,13 +555,11 @@ private String getCompletionTargetString(String text, int cursor) { if (completionStartPosition == completionEndPosition) { completionStartPosition = 0; - } - else - { + } else { completionStartPosition = completionEndPosition - completionStartPosition; } resultCompletionText = completionScriptText.substring( - completionStartPosition , completionEndPosition); + completionStartPosition, completionEndPosition); return resultCompletionText; } diff --git a/spark/src/main/resources/interpreter-setting.json b/spark/src/main/resources/interpreter-setting.json index 2b78d1627cd..0362e929c14 100644 --- a/spark/src/main/resources/interpreter-setting.json +++ b/spark/src/main/resources/interpreter-setting.json @@ -128,6 +128,12 @@ "propertyName": null, "defaultValue": "python", "description": "Python command to run pyspark with" + }, + "zeppelin.pyspark.precode": { + "envName": null, + "propertyName": null, + "defaultValue": "", + "description": "Snippet of code which executes when the interpreter initializes" } }, "editor": { diff --git a/spark/src/main/resources/python/zeppelin_pyspark.py b/spark/src/main/resources/python/zeppelin_pyspark.py index d9c68c28970..d74e825e4a5 100644 --- a/spark/src/main/resources/python/zeppelin_pyspark.py +++ b/spark/src/main/resources/python/zeppelin_pyspark.py @@ -282,6 +282,9 @@ def getCompletion(self, text_value): z = PyZeppelinContext(intp.getZeppelinContext()) z._setup_matplotlib() +# Place to set precode if precode is set +#precode# + while True : req = intp.getStatements() try: diff --git a/spark/src/test/java/org/apache/zeppelin/spark/PySparkInterpreterTest.java 
b/spark/src/test/java/org/apache/zeppelin/spark/PySparkInterpreterTest.java index 36975126418..78e85727dc5 100644 --- a/spark/src/test/java/org/apache/zeppelin/spark/PySparkInterpreterTest.java +++ b/spark/src/test/java/org/apache/zeppelin/spark/PySparkInterpreterTest.java @@ -17,26 +17,39 @@ package org.apache.zeppelin.spark; +import java.io.IOException; +import java.util.HashMap; +import java.util.LinkedList; +import java.util.List; +import java.util.Properties; +import java.util.regex.Matcher; +import java.util.regex.Pattern; + import org.apache.zeppelin.display.AngularObjectRegistry; import org.apache.zeppelin.display.GUI; -import org.apache.zeppelin.interpreter.*; +import org.apache.zeppelin.interpreter.Interpreter; +import org.apache.zeppelin.interpreter.InterpreterContext; +import org.apache.zeppelin.interpreter.InterpreterContextRunner; +import org.apache.zeppelin.interpreter.InterpreterGroup; +import org.apache.zeppelin.interpreter.InterpreterOutput; +import org.apache.zeppelin.interpreter.InterpreterResult; import org.apache.zeppelin.interpreter.thrift.InterpreterCompletion; import org.apache.zeppelin.resource.LocalResourcePool; import org.apache.zeppelin.user.AuthenticationInfo; -import org.junit.*; +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.junit.ClassRule; +import org.junit.FixMethodOrder; +import org.junit.Test; import org.junit.rules.TemporaryFolder; import org.junit.runners.MethodSorters; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import java.io.IOException; -import java.util.HashMap; -import java.util.LinkedList; -import java.util.List; -import java.util.Properties; -import java.util.regex.Matcher; -import java.util.regex.Pattern; -import static org.junit.Assert.*; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; @FixMethodOrder(MethodSorters.NAME_ASCENDING) public class 
PySparkInterpreterTest { @@ -58,6 +71,7 @@ private static Properties getPySparkTestProperties() throws IOException { p.setProperty("zeppelin.spark.maxResult", "1000"); p.setProperty("zeppelin.spark.importImplicit", "true"); p.setProperty("zeppelin.pyspark.python", "python"); + p.setProperty("zeppelin.pyspark.precode", "precodeVar = 'test'"); p.setProperty("zeppelin.dep.localrepo", tmpDir.newFolder().getAbsolutePath()); return p; } @@ -123,6 +137,13 @@ public void testCompletion() { } } + @Test + public void testPrecode() { + assertEquals(InterpreterResult.Code.SUCCESS, pySparkInterpreter.interpret("print(precodeVar)\n", context).code()); + } + + + private class infinityPythonJob implements Runnable { @Override public void run() {