diff --git a/zeppelin-server/src/main/java/org/apache/zeppelin/socket/NotebookServer.java b/zeppelin-server/src/main/java/org/apache/zeppelin/socket/NotebookServer.java index fe0d3912bb8..8c8b60089a2 100644 --- a/zeppelin-server/src/main/java/org/apache/zeppelin/socket/NotebookServer.java +++ b/zeppelin-server/src/main/java/org/apache/zeppelin/socket/NotebookServer.java @@ -14,19 +14,17 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - package org.apache.zeppelin.socket; - import java.io.IOException; -import java.net.InetSocketAddress; +import java.net.URI; +import java.net.URISyntaxException; +import java.net.UnknownHostException; import java.util.HashMap; import java.util.LinkedList; import java.util.List; import java.util.Map; import java.util.Set; - import javax.servlet.http.HttpServletRequest; - import org.apache.zeppelin.display.AngularObject; import org.apache.zeppelin.display.AngularObjectRegistry; import org.apache.zeppelin.display.AngularObjectRegistryListener; @@ -46,28 +44,46 @@ import org.quartz.SchedulerException; import org.slf4j.Logger; import org.slf4j.LoggerFactory; - import com.google.common.base.Strings; import com.google.gson.Gson; - /** * Zeppelin websocket service. * * @author anthonycorbacho */ public class NotebookServer extends WebSocketServlet implements - NotebookSocketListener, JobListenerFactory, AngularObjectRegistryListener { - + NotebookSocketListener, JobListenerFactory, AngularObjectRegistryListener { private static final Logger LOG = LoggerFactory - .getLogger(NotebookServer.class); - + .getLogger(NotebookServer.class); Gson gson = new Gson(); - Map> noteSocketMap = new HashMap>(); - List connectedSockets = new LinkedList(); + final Map> noteSocketMap = new HashMap<>(); + final List connectedSockets = new LinkedList<>(); private Notebook notebook() { return ZeppelinServer.notebook; } + @Override + public boolean checkOrigin(HttpServletRequest request, String origin) { + URI sourceUri = null; + String currentHost = null; + + try { + sourceUri = new URI(origin); + currentHost = java.net.InetAddress.getLocalHost().getHostName(); + } catch (UnknownHostException e) { + e.printStackTrace(); + } + catch (URISyntaxException e) { + e.printStackTrace(); + } + + String sourceHost = sourceUri.getHost(); + if (currentHost.equals(sourceHost) || "localhost".equals(sourceHost)) { + return true; + } + + return false; + } @Override public WebSocket doWebSocketConnect(HttpServletRequest req, String protocol) { @@ -153,8 +169,7 @@ public void onClose(NotebookSocket conn, int code, String reason) { } private Message deserializeMessage(String msg) { - Message m = gson.fromJson(msg, Message.class); - return m; + return gson.fromJson(msg, Message.class); } private String serializeMessage(Message m) { @@ -164,14 +179,13 @@ private String serializeMessage(Message m) { private void addConnectionToNote(String noteId, NotebookSocket socket) { synchronized (noteSocketMap) { removeConnectionFromAllNote(socket); // make sure a socket relates only a - // single note. + // single note. List socketList = noteSocketMap.get(noteId); if (socketList == null) { - socketList = new LinkedList(); + socketList = new LinkedList<>(); noteSocketMap.put(noteId, socketList); } - - if (socketList.contains(socket) == false) { + if (!socketList.contains(socket)) { socketList.add(socket); } } @@ -212,6 +226,7 @@ private String getOpenNoteId(NotebookSocket socket) { } } } + return id; } @@ -235,9 +250,7 @@ private void broadcast(String noteId, Message m) { if (socketLists == null || socketLists.size() == 0) { return; } - LOG.info("SEND >> " + m.op); - for (NotebookSocket conn : socketLists) { try { conn.send(serializeMessage(m)); @@ -267,13 +280,14 @@ private void broadcastNote(Note note) { private void broadcastNoteList() { Notebook notebook = notebook(); List notes = notebook.getAllNotes(); - List> notesInfo = new LinkedList>(); + List> notesInfo = new LinkedList<>(); for (Note note : notes) { - Map info = new HashMap(); + Map info = new HashMap<>(); info.put("id", note.id()); info.put("name", note.getName()); notesInfo.add(info); } + broadcastAll(new Message(OP.NOTES_INFO).put("notes", notesInfo)); } @@ -283,8 +297,8 @@ private void sendNote(NotebookSocket conn, Notebook notebook, if (noteId == null) { return; } - Note note = notebook.getNote(noteId); + Note note = notebook.getNote(noteId); if (note != null) { addConnectionToNote(note.id(), conn); conn.send(serializeMessage(new Message(OP.NOTE).put("note", note))); @@ -304,17 +318,17 @@ private void updateNote(WebSocket conn, Notebook notebook, Message fromMessage) if (config == null) { return; } + Note note = notebook.getNote(noteId); if (note != null) { boolean cronUpdated = isCronUpdated(config, note.getConfig()); note.setName(name); note.setConfig(config); - if (cronUpdated) { notebook.refreshCron(note.id()); } - note.persist(); + note.persist(); broadcastNote(note); broadcastNoteList(); } @@ -331,17 +345,19 @@ private boolean isCronUpdated(Map configA, } else if (configA.get("cron") != null || configB.get("cron") != null) { cronUpdated = true; } + return cronUpdated; } - private void createNote(WebSocket conn, Notebook notebook, Message message) throws IOException { Note note = notebook.createNote(); note.addParagraph(); // it's an empty note. so add one paragraph if (message != null) { String noteName = (String) message.get("name"); - if (noteName != null && !noteName.isEmpty()) + if (noteName != null && !noteName.isEmpty()){ note.setName(noteName); + } } + note.persist(); broadcastNote(note); broadcastNoteList(); @@ -353,6 +369,7 @@ private void removeNote(WebSocket conn, Notebook notebook, Message fromMessage) if (noteId == null) { return; } + Note note = notebook.getNote(noteId); notebook.removeNote(noteId); removeNote(noteId); @@ -365,6 +382,7 @@ private void updateParagraph(NotebookSocket conn, Notebook notebook, if (paragraphId == null) { return; } + Map params = (Map) fromMessage .get("params"); Map config = (Map) fromMessage @@ -385,6 +403,7 @@ private void removeParagraph(NotebookSocket conn, Notebook notebook, if (paragraphId == null) { return; } + final Note note = notebook.getNote(getOpenNoteId(conn)); /** We dont want to remove the last paragraph */ if (!note.isLastParagraph(paragraphId)) { @@ -400,7 +419,6 @@ private void completion(NotebookSocket conn, Notebook notebook, String buffer = (String) fromMessage.get("buf"); int cursor = (int) Double.parseDouble(fromMessage.get("cursor").toString()); Message resp = new Message(OP.COMPLETION_LIST).put("id", paragraphId); - if (paragraphId == null) { conn.send(serializeMessage(resp)); return; @@ -414,10 +432,10 @@ private void completion(NotebookSocket conn, Notebook notebook, /** * When angular object updated from client - * - * @param conn - * @param notebook - * @param fromMessage + * + * @param conn the web socket. + * @param notebook the notebook. + * @param fromMessage the message. */ private void angularObjectUpdated(WebSocket conn, Notebook notebook, Message fromMessage) { @@ -425,10 +443,8 @@ private void angularObjectUpdated(WebSocket conn, Notebook notebook, String interpreterGroupId = (String) fromMessage.get("interpreterGroupId"); String varName = (String) fromMessage.get("name"); Object varValue = fromMessage.get("value"); - AngularObject ao = null; boolean global = false; - // propagate change to (Remote) AngularObjectRegistry Note note = notebook.getNote(noteId); if (note != null) { @@ -438,11 +454,9 @@ private void angularObjectUpdated(WebSocket conn, Notebook notebook, if (setting.getInterpreterGroup() == null) { continue; } - if (interpreterGroupId.equals(setting.getInterpreterGroup().getId())) { AngularObjectRegistry angularObjectRegistry = setting .getInterpreterGroup().getAngularObjectRegistry(); - // first trying to get local registry ao = angularObjectRegistry.get(varName, noteId); if (ao == null) { @@ -460,14 +474,13 @@ private void angularObjectUpdated(WebSocket conn, Notebook notebook, ao.set(varValue, false); global = false; } - break; } } } if (global) { // broadcast change to all web session that uses related - // interpreter. + // interpreter. for (Note n : notebook.getAllNotes()) { List settings = note.getNoteReplLoader() .getInterpreterSettings(); @@ -475,7 +488,6 @@ private void angularObjectUpdated(WebSocket conn, Notebook notebook, if (setting.getInterpreterGroup() == null) { continue; } - if (interpreterGroupId.equals(setting.getInterpreterGroup().getId())) { AngularObjectRegistry angularObjectRegistry = setting .getInterpreterGroup().getAngularObjectRegistry(); @@ -514,8 +526,7 @@ private void moveParagraph(NotebookSocket conn, Notebook notebook, private void insertParagraph(NotebookSocket conn, Notebook notebook, Message fromMessage) throws IOException { final int index = (int) Double.parseDouble(fromMessage.get("index") - .toString()); - + .toString()); final Note note = notebook.getNote(getOpenNoteId(conn)); note.insertParagraph(index); note.persist(); @@ -540,27 +551,27 @@ private void runParagraph(NotebookSocket conn, Notebook notebook, if (paragraphId == null) { return; } + final Note note = notebook.getNote(getOpenNoteId(conn)); Paragraph p = note.getParagraph(paragraphId); String text = (String) fromMessage.get("paragraph"); p.setText(text); p.setTitle((String) fromMessage.get("title")); Map params = (Map) fromMessage - .get("params"); + .get("params"); p.settings.setParams(params); Map config = (Map) fromMessage - .get("config"); + .get("config"); p.setConfig(config); - // if it's the last paragraph, let's add a new one boolean isTheLastParagraph = note.getLastParagraph().getId() .equals(p.getId()); if (!Strings.isNullOrEmpty(text) && isTheLastParagraph) { note.addParagraph(); } + note.persist(); broadcastNote(note); - try { note.run(paragraphId); } catch (Exception ex) { @@ -581,7 +592,6 @@ private void runParagraph(NotebookSocket conn, Notebook notebook, public static class ParagraphJobListener implements JobListener { private NotebookServer notebookServer; private Note note; - public ParagraphJobListener(NotebookServer notebookServer, Note note) { this.notebookServer = notebookServer; this.note = note; @@ -606,6 +616,7 @@ public void afterStatusChange(Job job, Status before, Status after) { LOG.error("Error", job.getException()); } } + if (job.isTerminated()) { LOG.info("Job {} is finished", job.getId()); try { @@ -614,6 +625,7 @@ public void afterStatusChange(Job job, Status before, Status after) { e.printStackTrace(); } } + notebookServer.broadcastNote(note); } } @@ -622,7 +634,6 @@ public void afterStatusChange(Job job, Status before, Status after) { public JobListener getParagraphJobListener(Note note) { return new ParagraphJobListener(this, note); } - private void pong() { } @@ -667,10 +678,8 @@ public void onUpdate(String interpreterGroupId, AngularObject object) { List intpSettings = note.getNoteReplLoader() .getInterpreterSettings(); - if (intpSettings.isEmpty()) continue; - for (InterpreterSetting setting : intpSettings) { if (setting.getInterpreterGroup().getId().equals(interpreterGroupId)) { broadcast( @@ -699,9 +708,10 @@ public void onRemove(String interpreterGroupId, String name, String noteId) { broadcast( note.id(), new Message(OP.ANGULAR_OBJECT_REMOVE).put("name", name).put( - "noteId", noteId)); + "noteId", noteId)); } } } } } + diff --git a/zeppelin-server/src/test/java/org/apache/zeppelin/socket/NotebookServerTests.java b/zeppelin-server/src/test/java/org/apache/zeppelin/socket/NotebookServerTests.java new file mode 100644 index 00000000000..3ab06f0e2c5 --- /dev/null +++ b/zeppelin-server/src/test/java/org/apache/zeppelin/socket/NotebookServerTests.java @@ -0,0 +1,53 @@ +/** + * Created by joelz on 8/6/15. + * + * + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.zeppelin.socket; + +import org.apache.zeppelin.notebook.Note; +import org.apache.zeppelin.server.ZeppelinServer; +import org.junit.Assert; +import org.junit.FixMethodOrder; +import org.junit.Test; +import org.junit.runners.MethodSorters; + +import java.io.IOException; +import java.net.UnknownHostException; + +/** + * BASIC Zeppelin rest api tests + * + * + * @author joelz + * + */ + public class NotebookServerTests { + + @Test + public void CheckOrigin() throws UnknownHostException { + NotebookServer server = new NotebookServer(); + Assert.assertTrue(server.checkOrigin(new TestHttpServletRequest(), + "http://" + java.net.InetAddress.getLocalHost().getHostName() + ":8080")); + } + + @Test + public void CheckInvalidOrigin(){ + NotebookServer server = new NotebookServer(); + Assert.assertFalse(server.checkOrigin(new TestHttpServletRequest(), "http://evillocalhost:8080")); + } +} diff --git a/zeppelin-server/src/test/java/org/apache/zeppelin/socket/TestHttpServletRequest.java b/zeppelin-server/src/test/java/org/apache/zeppelin/socket/TestHttpServletRequest.java new file mode 100644 index 00000000000..9ec54baa95a --- /dev/null +++ b/zeppelin-server/src/test/java/org/apache/zeppelin/socket/TestHttpServletRequest.java @@ -0,0 +1,372 @@ +/** + * Created by joelz on 8/6/15. + * + * + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.zeppelin.socket; + +import javax.servlet.*; +import javax.servlet.http.*; +import java.io.BufferedReader; +import java.io.IOException; +import java.io.UnsupportedEncodingException; +import java.security.Principal; +import java.util.Collection; +import java.util.Enumeration; +import java.util.Locale; +import java.util.Map; + +/** + * Created by joelz on 8/6/15. + * Helps mocking a http servlet request + */ +public class TestHttpServletRequest implements HttpServletRequest { + @Override + public boolean authenticate(HttpServletResponse httpServletResponse) throws IOException, ServletException { + return false; + } + + @Override + public String getAuthType() { + return null; + } + + @Override + public String getContextPath() { + return null; + } + + @Override + public Cookie[] getCookies() { + return new Cookie[0]; + } + + @Override + public long getDateHeader(String s) { + return 0; + } + + @Override + public String getHeader(String s) { + switch (s) { + case "Origin": + return "http://localhost:8080"; + } + + return null; + } + + @Override + public Enumeration getHeaderNames() { + return null; + } + + @Override + public Enumeration getHeaders(String s) { + return null; + } + + @Override + public int getIntHeader(String s) { + return 0; + } + + @Override + public String getMethod() { + return null; + } + + @Override + public Part getPart(String s) throws IOException, ServletException { + return null; + } + + @Override + public Collection getParts() throws IOException, ServletException { + return null; + } + + @Override + public String getPathInfo() { + return null; + } + + @Override + public String getPathTranslated() { + return null; + } + + @Override + public String getQueryString() { + return null; + } + + @Override + public String getRemoteUser() { + return null; + } + + @Override + public String getRequestedSessionId() { + return null; + } + + @Override + public String getRequestURI() { + return null; + } + + @Override + public StringBuffer getRequestURL() { + return null; + } + + @Override + public String getServletPath() { + return null; + } + + @Override + public HttpSession getSession() { + return null; + } + + @Override + public HttpSession getSession(boolean b) { + return null; + } + + @Override + public Principal getUserPrincipal() { + return null; + } + + @Override + public boolean isRequestedSessionIdFromCookie() { + return false; + } + + @Override + public boolean isRequestedSessionIdFromUrl() { + return false; + } + + @Override + public boolean isRequestedSessionIdFromURL() { + return false; + } + + @Override + public boolean isRequestedSessionIdValid() { + return false; + } + + @Override + public boolean isUserInRole(String s) { + return false; + } + + @Override + public void login(String s, String s1) throws ServletException { + + } + + @Override + public void logout() throws ServletException { + + } + + @Override + public AsyncContext getAsyncContext() { + return null; + } + + @Override + public Object getAttribute(String s) { + return null; + } + + @Override + public Enumeration getAttributeNames() { + return null; + } + + @Override + public String getCharacterEncoding() { + return null; + } + + @Override + public int getContentLength() { + return 0; + } + + @Override + public String getContentType() { + return null; + } + + @Override + public DispatcherType getDispatcherType() { + return null; + } + + @Override + public ServletInputStream getInputStream() throws IOException { + return null; + } + + @Override + public String getLocalAddr() { + return null; + } + + @Override + public Locale getLocale() { + return null; + } + + @Override + public Enumeration getLocales() { + return null; + } + + @Override + public String getLocalName() { + return null; + } + + @Override + public int getLocalPort() { + return 0; + } + + @Override + public String getParameter(String s) { + return null; + } + + @Override + public Map getParameterMap() { + return null; + } + + @Override + public Enumeration getParameterNames() { + return null; + } + + @Override + public String[] getParameterValues(String s) { + return new String[0]; + } + + @Override + public String getProtocol() { + return null; + } + + @Override + public BufferedReader getReader() throws IOException { + return null; + } + + @Override + public String getRealPath(String s) { + return null; + } + + @Override + public String getRemoteAddr() { + return null; + } + + @Override + public String getRemoteHost() { + return null; + } + + @Override + public int getRemotePort() { + return 0; + } + + @Override + public RequestDispatcher getRequestDispatcher(String s) { + return null; + } + + @Override + public String getScheme() { + return null; + } + + @Override + public String getServerName() { + return null; + } + + @Override + public int getServerPort() { + return 0; + } + + @Override + public ServletContext getServletContext() { + return null; + } + + @Override + public boolean isAsyncStarted() { + return false; + } + + @Override + public boolean isAsyncSupported() { + return false; + } + + @Override + public boolean isSecure() { + return false; + } + + @Override + public void removeAttribute(String s) { + + } + + @Override + public void setAttribute(String s, Object o) { + + } + + @Override + public void setCharacterEncoding(String s) throws UnsupportedEncodingException { + + } + + @Override + public AsyncContext startAsync() { + return null; + } + + @Override + public AsyncContext startAsync(ServletRequest servletRequest, ServletResponse servletResponse) { + return null; + } +}