diff --git a/docs/interpreter/spark.md b/docs/interpreter/spark.md
index a19eda22212..f5febda275e 100644
--- a/docs/interpreter/spark.md
+++ b/docs/interpreter/spark.md
@@ -115,6 +115,11 @@ You can also set other Spark properties which are not listed in the table. For a
Python binary executable to use for PySpark in driver only (default is PYSPARK_PYTHON).
Property spark.pyspark.driver.python take precedence if it is set |
+
+ | zeppelin.pyspark.precode |
+ |
+ Snippet of code that is executed when the interpreter initializes. Variables, methods, classes, etc. defined in the snippet are available in your paragraphs |
+
| zeppelin.spark.concurrentSQL |
false |
diff --git a/spark/src/main/java/org/apache/zeppelin/spark/PySparkInterpreter.java b/spark/src/main/java/org/apache/zeppelin/spark/PySparkInterpreter.java
index db52a5342a7..29e16c36963 100644
--- a/spark/src/main/java/org/apache/zeppelin/spark/PySparkInterpreter.java
+++ b/spark/src/main/java/org/apache/zeppelin/spark/PySparkInterpreter.java
@@ -20,7 +20,6 @@
import java.io.BufferedWriter;
import java.io.ByteArrayOutputStream;
import java.io.File;
-import java.io.FileOutputStream;
import java.io.IOException;
import java.io.OutputStreamWriter;
import java.io.PipedInputStream;
@@ -34,7 +33,6 @@
import java.util.Map;
import java.util.Properties;
-import org.apache.commons.compress.utils.IOUtils;
import org.apache.commons.exec.CommandLine;
import org.apache.commons.exec.DefaultExecutor;
import org.apache.commons.exec.ExecuteException;
@@ -42,12 +40,22 @@
import org.apache.commons.exec.ExecuteWatchdog;
import org.apache.commons.exec.PumpStreamHandler;
import org.apache.commons.exec.environment.EnvironmentUtils;
+import org.apache.commons.io.FileUtils;
+import org.apache.commons.io.IOUtils;
+import org.apache.commons.lang.StringUtils;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.SQLContext;
-import org.apache.zeppelin.interpreter.*;
+import org.apache.zeppelin.interpreter.Interpreter;
+import org.apache.zeppelin.interpreter.InterpreterContext;
+import org.apache.zeppelin.interpreter.InterpreterException;
+import org.apache.zeppelin.interpreter.InterpreterGroup;
import org.apache.zeppelin.interpreter.InterpreterHookRegistry.HookType;
+import org.apache.zeppelin.interpreter.InterpreterResult;
import org.apache.zeppelin.interpreter.InterpreterResult.Code;
+import org.apache.zeppelin.interpreter.InterpreterResultMessage;
+import org.apache.zeppelin.interpreter.LazyOpenInterpreter;
+import org.apache.zeppelin.interpreter.WrappedInterpreter;
import org.apache.zeppelin.interpreter.thrift.InterpreterCompletion;
import org.apache.zeppelin.interpreter.util.InterpreterOutputStream;
import org.apache.zeppelin.spark.dep.SparkDependencyContext;
@@ -55,14 +63,16 @@
import org.slf4j.LoggerFactory;
import com.google.gson.Gson;
-
import py4j.GatewayServer;
/**
*
*/
public class PySparkInterpreter extends Interpreter implements ExecuteResultHandler {
- Logger logger = LoggerFactory.getLogger(PySparkInterpreter.class);
+
+ public static final String ZEPPELIN_PYSPARK_PRECODE_KEY = "zeppelin.pyspark.precode";
+ public static Logger logger = LoggerFactory.getLogger(PySparkInterpreter.class);
+
private GatewayServer gatewayServer;
private DefaultExecutor executor;
private int port;
@@ -96,11 +106,13 @@ private void createPythonScript() {
}
try {
- FileOutputStream outStream = new FileOutputStream(out);
- IOUtils.copy(
- classLoader.getResourceAsStream("python/zeppelin_pyspark.py"),
- outStream);
- outStream.close();
+ String pythonScript = IOUtils.toString(
+ classLoader.getResourceAsStream("python/zeppelin_pyspark.py"), "UTF-8");
+ String precode = getProperty(ZEPPELIN_PYSPARK_PRECODE_KEY);
+ if (StringUtils.isNotBlank(precode)) {
+ pythonScript = pythonScript.replace("#precode#", StringUtils.trim(precode));
+ }
+ FileUtils.writeStringToFile(out, pythonScript, "UTF-8");
} catch (IOException e) {
throw new InterpreterException(e);
}
@@ -524,8 +536,7 @@ private String getCompletionTargetString(String text, int cursor) {
String completionScriptText = "";
try {
completionScriptText = text.substring(0, cursor);
- }
- catch (Exception e) {
+ } catch (Exception e) {
logger.error(e.toString());
return null;
}
@@ -544,13 +555,11 @@ private String getCompletionTargetString(String text, int cursor) {
if (completionStartPosition == completionEndPosition) {
completionStartPosition = 0;
- }
- else
- {
+ } else {
completionStartPosition = completionEndPosition - completionStartPosition;
}
resultCompletionText = completionScriptText.substring(
- completionStartPosition , completionEndPosition);
+ completionStartPosition, completionEndPosition);
return resultCompletionText;
}
diff --git a/spark/src/main/resources/interpreter-setting.json b/spark/src/main/resources/interpreter-setting.json
index 2b78d1627cd..0362e929c14 100644
--- a/spark/src/main/resources/interpreter-setting.json
+++ b/spark/src/main/resources/interpreter-setting.json
@@ -128,6 +128,12 @@
"propertyName": null,
"defaultValue": "python",
"description": "Python command to run pyspark with"
+ },
+ "zeppelin.pyspark.precode": {
+ "envName": null,
+ "propertyName": null,
+ "defaultValue": "",
+ "description": "Snippet of code that is executed when the interpreter initializes"
}
},
"editor": {
diff --git a/spark/src/main/resources/python/zeppelin_pyspark.py b/spark/src/main/resources/python/zeppelin_pyspark.py
index d9c68c28970..d74e825e4a5 100644
--- a/spark/src/main/resources/python/zeppelin_pyspark.py
+++ b/spark/src/main/resources/python/zeppelin_pyspark.py
@@ -282,6 +282,9 @@ def getCompletion(self, text_value):
z = PyZeppelinContext(intp.getZeppelinContext())
z._setup_matplotlib()
+# The marker line below is replaced with the value of zeppelin.pyspark.precode when that property is set
+#precode#
+
while True :
req = intp.getStatements()
try:
diff --git a/spark/src/test/java/org/apache/zeppelin/spark/PySparkInterpreterTest.java b/spark/src/test/java/org/apache/zeppelin/spark/PySparkInterpreterTest.java
index 36975126418..78e85727dc5 100644
--- a/spark/src/test/java/org/apache/zeppelin/spark/PySparkInterpreterTest.java
+++ b/spark/src/test/java/org/apache/zeppelin/spark/PySparkInterpreterTest.java
@@ -17,26 +17,39 @@
package org.apache.zeppelin.spark;
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Properties;
+import java.util.regex.Matcher;
+import java.util.regex.Pattern;
+
import org.apache.zeppelin.display.AngularObjectRegistry;
import org.apache.zeppelin.display.GUI;
-import org.apache.zeppelin.interpreter.*;
+import org.apache.zeppelin.interpreter.Interpreter;
+import org.apache.zeppelin.interpreter.InterpreterContext;
+import org.apache.zeppelin.interpreter.InterpreterContextRunner;
+import org.apache.zeppelin.interpreter.InterpreterGroup;
+import org.apache.zeppelin.interpreter.InterpreterOutput;
+import org.apache.zeppelin.interpreter.InterpreterResult;
import org.apache.zeppelin.interpreter.thrift.InterpreterCompletion;
import org.apache.zeppelin.resource.LocalResourcePool;
import org.apache.zeppelin.user.AuthenticationInfo;
-import org.junit.*;
+import org.junit.AfterClass;
+import org.junit.BeforeClass;
+import org.junit.ClassRule;
+import org.junit.FixMethodOrder;
+import org.junit.Test;
import org.junit.rules.TemporaryFolder;
import org.junit.runners.MethodSorters;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
-import java.io.IOException;
-import java.util.HashMap;
-import java.util.LinkedList;
-import java.util.List;
-import java.util.Properties;
-import java.util.regex.Matcher;
-import java.util.regex.Pattern;
-import static org.junit.Assert.*;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertTrue;
@FixMethodOrder(MethodSorters.NAME_ASCENDING)
public class PySparkInterpreterTest {
@@ -58,6 +71,7 @@ private static Properties getPySparkTestProperties() throws IOException {
p.setProperty("zeppelin.spark.maxResult", "1000");
p.setProperty("zeppelin.spark.importImplicit", "true");
p.setProperty("zeppelin.pyspark.python", "python");
+ p.setProperty("zeppelin.pyspark.precode", "precodeVar = 'test'");
p.setProperty("zeppelin.dep.localrepo", tmpDir.newFolder().getAbsolutePath());
return p;
}
@@ -123,6 +137,13 @@ public void testCompletion() {
}
}
+ @Test
+ public void testPrecode() { // precodeVar is defined by the zeppelin.pyspark.precode test property, so printing it must succeed
+ assertEquals(InterpreterResult.Code.SUCCESS, pySparkInterpreter.interpret("print(precodeVar)\n", context).code());
+ }
+
+
+
private class infinityPythonJob implements Runnable {
@Override
public void run() {