diff --git a/keystore.jks b/keystore.jks new file mode 100644 index 000000000..743add50e Binary files /dev/null and b/keystore.jks differ diff --git a/src/main/example/simplelogger.properties b/src/main/example/simplelogger.properties index aa4322e92..f2f4dcac3 100644 --- a/src/main/example/simplelogger.properties +++ b/src/main/example/simplelogger.properties @@ -1,5 +1,5 @@ org.slf4j.simpleLogger.logFile=System.out -org.slf4j.simpleLogger.defaultLogLevel=trace +org.slf4j.simpleLogger.defaultLogLevel=error org.slf4j.simpleLogger.showDateTime=true org.slf4j.simpleLogger.dateTimeFormat=yyyy-MM-dd HH:mm:ss.SSS org.slf4j.simpleLogger.showThreadName=false diff --git a/src/main/java/org/java_websocket/SSLSocketChannel.java b/src/main/java/org/java_websocket/SSLSocketChannel.java index dcdb6aa89..441b482a8 100644 --- a/src/main/java/org/java_websocket/SSLSocketChannel.java +++ b/src/main/java/org/java_websocket/SSLSocketChannel.java @@ -178,7 +178,7 @@ public synchronized int read( ByteBuffer dst ) throws IOException { return ByteBufferUtils.transferByteBuffer( peerAppData, dst ); case BUFFER_OVERFLOW: peerAppData = enlargeApplicationBuffer( peerAppData ); - break; + return read(dst); case CLOSED: closeConnection(); dst.clear(); @@ -323,8 +323,6 @@ private boolean doHandshake() throws IOException { } break; case NEED_WRAP: - - myNetData.clear(); try { result = engine.wrap( myAppData, myNetData ); diff --git a/src/main/java/org/java_websocket/SSLSocketChannel3.java b/src/main/java/org/java_websocket/SSLSocketChannel3.java new file mode 100644 index 000000000..5be17aa7c --- /dev/null +++ b/src/main/java/org/java_websocket/SSLSocketChannel3.java @@ -0,0 +1,480 @@ +/* + * Copyright (c) 2010-2018 Nathan Rajlich + * + * Permission is hereby granted, free of charge, to any person + * obtaining a copy of this software and associated documentation + * files (the "Software"), to deal in the Software without + * restriction, including without limitation the rights to use, + * copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following + * conditions: + * + * The above copyright notice and this permission notice shall be + * included in all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES + * OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND + * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT + * HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, + * WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR + * OTHER DEALINGS IN THE SOFTWARE. + * + */ + +package org.java_websocket; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.EOFException; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.channels.ByteChannel; +import java.nio.channels.SelectableChannel; +import java.nio.channels.SelectionKey; +import java.nio.channels.Selector; +import java.util.concurrent.ExecutorService; + + +import javax.net.ssl.SSLEngine; +import javax.net.ssl.SSLEngineResult; +import javax.net.ssl.SSLException; +import javax.net.ssl.SSLSession; +import javax.net.ssl.SSLEngineResult.Status; + + +// TODO: Add more logging? +/** + * A ByteChannel that passes the data through an SSLEngine. + */ +public class SSLSocketChannel3 implements ByteChannel { + private final ByteChannel wrappedChannel; + private final SSLEngine engine; + protected final Logger logger = LoggerFactory.getLogger(SSLSocketChannel3.class); + + + private ByteBuffer inAppData; // cleartext decoded from SSL + private final ByteBuffer outAppData; // cleartext data to send + private ByteBuffer inNetData; // SSL data read from wrappedChannel + private final ByteBuffer outNetData; // SSL data to send on wrappedChannel + + + private boolean closed = false; + private int timeoutMillis = 0; + private Selector selector = null; + + + public void setTimeout(int timeoutMillis) { + this.timeoutMillis = timeoutMillis; + } + + + public int getTimeout() { + return timeoutMillis; + } + + + /** + * Creates a new instance of SSLByteChannel + * + * @param wrappedChannel + * The byte channel on which this ssl channel is built. This channel contains + * encrypted data. + * @param engine + * A SSLEngine instance that will remember SSL current context. Warning, such an + * instance CAN NOT be shared + * @param logger + * Logger for logging. + */ + public SSLSocketChannel3(ByteChannel wrappedChannel, SSLEngine engine, ExecutorService inputExecutor, SelectionKey key) { + this.wrappedChannel = wrappedChannel; + this.engine = engine; + + + SSLSession session = engine.getSession(); + + + inAppData = ByteBuffer.allocate(session.getApplicationBufferSize()); + outAppData = ByteBuffer.allocate(session.getApplicationBufferSize()); + logger.trace("app buffer size=" + session.getApplicationBufferSize()); + + + inNetData = ByteBuffer.allocate(session.getPacketBufferSize()); + outNetData = ByteBuffer.allocate(session.getPacketBufferSize()); + logger.trace("app buffer size=" + session.getPacketBufferSize()); + } + + + /** + * Ends SSL operation and close the wrapped byte channel + * + * @throws java.io.IOException + * May be raised by close operation on wrapped byte channel + */ + public void close() throws java.io.IOException { + if (!closed) { + try { + try { + engine.closeOutbound(); + handleHandshake(wrapAppData()); + if (selector != null) + selector.close(); + } catch (IOException e) { + // do nothing here + } + wrappedChannel.close(); + } finally { + closed = true; + } + } + } + + + /** + * Is the channel open ? + * + * @return true if the channel is still open + */ + public boolean isOpen() { + return !closed; + } + + + /** + * Fill the given buffer with some bytes and return the number of bytes added in the buffer.
+ * This method may return immediately with nothing added in the buffer. This method must be use + * exactly in the same way of ByteChannel read operation, so be careful with buffer position, + * limit, ... Check corresponding javadoc. + * + * @param clientBuffer + * The buffer that will received read bytes + * @return The number of bytes read + * @throws java.io.IOException + * May be raised by ByteChannel read operation + */ + public int read(ByteBuffer clientBuffer) throws IOException { + // first try to copy out anything left over from last time + int bytesCopied = copyOutClientData(clientBuffer); + if (bytesCopied > 0) + return bytesCopied; + + + fillBufferFromEngine(); + bytesCopied = copyOutClientData(clientBuffer); + if (bytesCopied > 0) + return bytesCopied; + + + return -1; + } + + + private void fillBufferFromEngine() throws IOException { + while (true) { + SSLEngineResult ser = unwrapNetData(); + if (ser.bytesProduced() > 0) + return; + + + switch (ser.getStatus()) { + case OK: + break; + + + case CLOSED: + close(); + return; + + + case BUFFER_OVERFLOW: { + int appSize = engine.getSession().getApplicationBufferSize(); + ByteBuffer b = ByteBuffer.allocate(appSize + inAppData.position()); + inAppData.flip(); + b.put(inAppData); + inAppData = b; + continue; // retry operation + } + + + case BUFFER_UNDERFLOW: { + int netSize = engine.getSession().getPacketBufferSize(); + if (netSize > inNetData.capacity()) { + ByteBuffer b = ByteBuffer.allocate(netSize); + inNetData.flip(); + b.put(inNetData); + inNetData = b; + } + + + int rc = timedRead(inNetData, timeoutMillis); + if (rc == 0 && timeoutMillis > 0) { + throw new IOException("Timeout waiting for read (" + timeoutMillis + " milliseconds)"); + } + if (rc == -1) + break; + continue; // retry operation + } + } + + + switch (ser.getHandshakeStatus()) { + case NOT_HANDSHAKING: + return; + + + default: + handleHandshake(ser); + break; + } + } + } + + + private int timedRead(ByteBuffer buf, int timeoutMillis) throws IOException { + if (timeoutMillis <= 0) + return wrappedChannel.read(buf); + + + SelectableChannel ch = (SelectableChannel)wrappedChannel; + + + synchronized (ch) { + SelectionKey key = null; + + + if (selector == null) { + selector = Selector.open(); + } + + + try { + selector.selectNow(); // Needed to clear old key state + ch.configureBlocking(false); + key = ch.register(selector, SelectionKey.OP_READ); + + + selector.select(timeoutMillis); + + + return wrappedChannel.read(buf); + } finally { + if (key != null) + key.cancel(); + ch.configureBlocking(true); + } + } + } + + + /** + * Write remaining bytes of the given byte buffer. This method may return immediately with + * nothing written. This method must be use exactly in the same way of ByteChannel write + * operation, so be careful with buffer position, limit, ... Check corresponding javadoc. + * + * @param clientBuffer + * buffer with remaining bytes to write + * @return The number of bytes written + * @throws java.io.IOException + * May be raised by ByteChannel write operation + */ + public int write(ByteBuffer clientBuffer) throws IOException { + int bytesWritten = 0; + + + while (clientBuffer.remaining() > 0) { + bytesWritten += pushToEngine(clientBuffer); + } + + + return bytesWritten; + } + + + private int pushToEngine(ByteBuffer clientBuffer) throws IOException { + int bytesWritten = 0; + + + while (clientBuffer.remaining() > 0) { + bytesWritten += copyInClientData(clientBuffer); + logger.trace("bytesWritten="+bytesWritten); + + + while (outAppData.position() > 0) { + SSLEngineResult ser = wrapAppData(); + logger.trace("ser.getStatus()="+ser.getStatus()); + logger.trace("ser.getHandshakeStatus()="+ser.getHandshakeStatus()); + logger.trace("app bytes after wrap()="+outAppData.position()); + + + switch (ser.getStatus()) { + case OK: + break; + + + case CLOSED: + pushNetData(); + close(); + return bytesWritten; + + + case BUFFER_OVERFLOW: + continue; + + + case BUFFER_UNDERFLOW: + return bytesWritten; // TODO: handshake needed here? + } + + + switch (ser.getHandshakeStatus()) { + case NOT_HANDSHAKING: + break; + + + default: + handleHandshake(ser); + break; + } + } + } + + + return bytesWritten; + } + + + private void handleHandshake(SSLEngineResult initialSer) throws IOException { + SSLEngineResult ser = initialSer; + + + while (ser.getStatus() != Status.CLOSED) { + switch (ser.getHandshakeStatus()) { + case NEED_TASK: + Runnable task; + + + while ((task = engine.getDelegatedTask()) != null) { + task.run(); + } + + + pushNetData(); + ser = wrapAppData(); + break; + + + case NEED_WRAP: + pushNetData(); + ser = wrapAppData(); + break; + + + case NEED_UNWRAP: + pushNetData(); + if (inNetData.position() == 0) { + int n = wrappedChannel.read(inNetData); + if (n<0) throw new EOFException("SSL wrapped byte channel"); + } + ser = unwrapNetData(); + break; + + + case FINISHED: + case NOT_HANDSHAKING: + return; + } + } + } + + + private SSLEngineResult unwrapNetData() throws SSLException { + SSLEngineResult ser; + inNetData.flip(); + ser = engine.unwrap(inNetData, inAppData); + inNetData.compact(); + return ser; + } + + + private SSLEngineResult wrapAppData() throws IOException { + outAppData.flip(); + + + SSLEngineResult ser = engine.wrap(outAppData, outNetData); + + + outAppData.compact(); + + + pushNetData(); + + + return ser; + } + + + private void pushNetData() throws IOException { + outNetData.flip(); + + + while (outNetData.remaining() > 0) { + wrappedChannel.write(outNetData); + } + + + outNetData.compact(); + } + + + // ------------------------------------------------------------ + + + private int copyInClientData(ByteBuffer clientBuffer) { + if (clientBuffer.remaining() == 0) { + return 0; + } + + + int posBefore; + + + posBefore = clientBuffer.position(); + + + if (clientBuffer.remaining() <= outAppData.remaining()) { + outAppData.put(clientBuffer); + } else { + while (clientBuffer.hasRemaining() && outAppData.hasRemaining()) { + outAppData.put(clientBuffer.get()); + } + } + + + return clientBuffer.position() - posBefore; + } + + + private int copyOutClientData(ByteBuffer clientBuffer) { + inAppData.flip(); + int posBefore = inAppData.position(); + + + if (inAppData.remaining() <= clientBuffer.remaining()) { + clientBuffer.put(inAppData); + } else { + while (clientBuffer.hasRemaining()) { + clientBuffer.put(inAppData.get()); + } + } + + + int posAfter = inAppData.position(); + inAppData.compact(); + + + return posAfter - posBefore; + } +} \ No newline at end of file diff --git a/src/main/java/org/java_websocket/server/DefaultSSLWebSocketServerFactory.java b/src/main/java/org/java_websocket/server/DefaultSSLWebSocketServerFactory.java index 5deddc156..ac39c00de 100644 --- a/src/main/java/org/java_websocket/server/DefaultSSLWebSocketServerFactory.java +++ b/src/main/java/org/java_websocket/server/DefaultSSLWebSocketServerFactory.java @@ -37,10 +37,7 @@ import javax.net.ssl.SSLContext; import javax.net.ssl.SSLEngine; -import org.java_websocket.SSLSocketChannel2; -import org.java_websocket.WebSocketAdapter; -import org.java_websocket.WebSocketImpl; -import org.java_websocket.WebSocketServerFactory; +import org.java_websocket.*; import org.java_websocket.drafts.Draft; public class DefaultSSLWebSocketServerFactory implements WebSocketServerFactory { @@ -71,7 +68,7 @@ public ByteChannel wrapChannel( SocketChannel channel, SelectionKey key ) throws ciphers.remove("TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256"); e.setEnabledCipherSuites( ciphers.toArray( new String[ciphers.size()] ) ); e.setUseClientMode( false ); - return new SSLSocketChannel2( channel, e, exec, key ); + return new SSLSocketChannel3( channel, e, exec, key ); } @Override diff --git a/src/main/java/org/java_websocket/server/WebSocketServer.java b/src/main/java/org/java_websocket/server/WebSocketServer.java index 2f466a5f8..d7fc0a823 100644 --- a/src/main/java/org/java_websocket/server/WebSocketServer.java +++ b/src/main/java/org/java_websocket/server/WebSocketServer.java @@ -300,34 +300,20 @@ public int getPort() { return port; } + /** + * Get the list of active drafts + * @return the available drafts for this server + */ public List getDraft() { return Collections.unmodifiableList( drafts ); } // Runnable IMPLEMENTATION ///////////////////////////////////////////////// public void run() { - synchronized ( this ) { - if( selectorthread != null ) - throw new IllegalStateException( getClass().getName() + " can only be started once." ); - selectorthread = Thread.currentThread(); - if( isclosed.get() ) { - return; - } + if (!doEnsureSingleThread()) { + return; } - selectorthread.setName( "WebSocketSelector-" + selectorthread.getId() ); - try { - server = ServerSocketChannel.open(); - server.configureBlocking( false ); - ServerSocket socket = server.socket(); - socket.setReceiveBufferSize( WebSocketImpl.RCVBUF ); - socket.setReuseAddress( isReuseAddr() ); - socket.bind( address ); - selector = Selector.open(); - server.register( selector, server.validOps() ); - startConnectionLostTimer(); - onStart(); - } catch ( IOException ex ) { - handleFatal( null, ex ); + if (!doSetupSelectorAndServerThread()) { return; } try { @@ -352,98 +338,23 @@ public void run() { conn = null; if( !key.isValid() ) { - // Object o = key.attachment(); continue; } if( key.isAcceptable() ) { - if( !onConnect( key ) ) { - key.cancel(); - continue; - } - - SocketChannel channel = server.accept(); - if(channel==null){ - continue; - } - channel.configureBlocking( false ); - Socket socket = channel.socket(); - socket.setTcpNoDelay( isTcpNoDelay() ); - socket.setKeepAlive( true ); - WebSocketImpl w = wsf.createWebSocket( this, drafts ); - w.setSelectionKey(channel.register( selector, SelectionKey.OP_READ, w )); - try { - w.channel = wsf.wrapChannel( channel, w.getSelectionKey() ); - i.remove(); - allocateBuffers( w ); - continue; - } catch (IOException ex) { - if( w.getSelectionKey() != null ) - w.getSelectionKey().cancel(); - - handleIOException( w.getSelectionKey(), null, ex ); - } + doAccept(key, i); continue; } - if( key.isReadable() ) { - conn = (WebSocketImpl) key.attachment(); - ByteBuffer buf = takeBuffer(); - if(conn.channel == null){ - if( key != null ) - key.cancel(); - - handleIOException( key, conn, new IOException() ); - continue; - } - try { - if( SocketChannelIOHelper.read( buf, conn, conn.channel ) ) { - if( buf.hasRemaining() ) { - conn.inQueue.put( buf ); - queue( conn ); - i.remove(); - if( conn.channel instanceof WrappedByteChannel ) { - if( ( (WrappedByteChannel) conn.channel ).isNeedRead() ) { - iqueue.add( conn ); - } - } - } else - pushBuffer( buf ); - } else { - pushBuffer( buf ); - } - } catch ( IOException e ) { - pushBuffer( buf ); - throw e; - } + if( key.isReadable() && !doRead(key, i)) { + continue; } + if( key.isWritable() ) { - conn = (WebSocketImpl) key.attachment(); - if( SocketChannelIOHelper.batch( conn, conn.channel ) ) { - if( key.isValid() ) - key.interestOps( SelectionKey.OP_READ ); - } + doWrite(key); } } - while ( !iqueue.isEmpty() ) { - conn = iqueue.remove( 0 ); - WrappedByteChannel c = ( (WrappedByteChannel) conn.channel ); - ByteBuffer buf = takeBuffer(); - try { - if( SocketChannelIOHelper.readMore( buf, conn, c ) ) - iqueue.add( conn ); - if( buf.hasRemaining() ) { - conn.inQueue.put( buf ); - queue( conn ); - } else { - pushBuffer( buf ); - } - } catch ( IOException e ) { - pushBuffer( buf ); - throw e; - } - - } + doAdditionalRead(); } catch ( CancelledKeyException e ) { // an other thread may cancel the key } catch ( ClosedByInterruptException e ) { @@ -453,38 +364,202 @@ public void run() { key.cancel(); handleIOException( key, conn, ex ); } catch ( InterruptedException e ) { - return;// FIXME controlled shutdown (e.g. take care of buffermanagement) + // FIXME controlled shutdown (e.g. take care of buffermanagement) + return; } } - } catch ( RuntimeException e ) { // should hopefully never occur handleFatal( null, e ); } finally { - stopConnectionLostTimer(); - if( decoders != null ) { - for( WebSocketWorker w : decoders ) { - w.interrupt(); + doServerShutdown(); + } + } + + /** + * Do an additional read + * @throws InterruptedException thrown by taking a buffer + * @throws IOException if an error happened during read + */ + private void doAdditionalRead() throws InterruptedException, IOException { + WebSocketImpl conn; + while ( !iqueue.isEmpty() ) { + conn = iqueue.remove( 0 ); + WrappedByteChannel c = ( (WrappedByteChannel) conn.channel ); + ByteBuffer buf = takeBuffer(); + try { + if( SocketChannelIOHelper.readMore( buf, conn, c ) ) + iqueue.add( conn ); + if( buf.hasRemaining() ) { + conn.inQueue.put( buf ); + queue( conn ); + } else { + pushBuffer( buf ); } + } catch ( IOException e ) { + pushBuffer( buf ); + throw e; } - if( selector != null ) { - try { - selector.close(); - } catch ( IOException e ) { - log.error( "IOException during selector.close", e ); - onError( null, e ); - } + } + } + + /** + * Execute a accept operation + * @param key the selectionkey to read off + * @param i the iterator for the selection keys + * @throws InterruptedException thrown by taking a buffer + * @throws IOException if an error happened during accept + */ + private void doAccept(SelectionKey key, Iterator i) throws IOException, InterruptedException { + if( !onConnect( key ) ) { + key.cancel(); + return; + } + + SocketChannel channel = server.accept(); + if(channel==null){ + return; + } + channel.configureBlocking( false ); + Socket socket = channel.socket(); + socket.setTcpNoDelay( isTcpNoDelay() ); + socket.setKeepAlive( true ); + WebSocketImpl w = wsf.createWebSocket( this, drafts ); + w.setSelectionKey(channel.register( selector, SelectionKey.OP_READ, w )); + try { + w.channel = wsf.wrapChannel( channel, w.getSelectionKey() ); + i.remove(); + allocateBuffers( w ); + } catch (IOException ex) { + if( w.getSelectionKey() != null ) + w.getSelectionKey().cancel(); + + handleIOException( w.getSelectionKey(), null, ex ); + } + } + + /** + * Execute a read operation + * @param key the selectionkey to read off + * @param i the iterator for the selection keys + * @return true, if the read was successful, or false if there was an error + * @throws InterruptedException thrown by taking a buffer + * @throws IOException if an error happened during read + */ + private boolean doRead(SelectionKey key, Iterator i) throws InterruptedException, IOException { + WebSocketImpl conn = (WebSocketImpl) key.attachment(); + ByteBuffer buf = takeBuffer(); + if(conn.channel == null){ + if( key != null ) + key.cancel(); + + handleIOException( key, conn, new IOException() ); + return false; + } + try { + if( SocketChannelIOHelper.read( buf, conn, conn.channel ) ) { + if( buf.hasRemaining() ) { + conn.inQueue.put( buf ); + queue( conn ); + i.remove(); + if( conn.channel instanceof WrappedByteChannel ) { + if( ( (WrappedByteChannel) conn.channel ).isNeedRead() ) { + iqueue.add( conn ); + } + } + } else + pushBuffer( buf ); + } else { + pushBuffer( buf ); } - if( server != null ) { - try { - server.close(); - } catch ( IOException e ) { - log.error( "IOException during server.close", e ); - onError( null, e ); - } + } catch ( IOException e ) { + pushBuffer( buf ); + throw e; + } + return true; + } + + /** + * Execute a write operation + * @param key the selectionkey to write on + * @throws IOException if an error happened during batch + */ + private void doWrite(SelectionKey key) throws IOException { + WebSocketImpl conn = (WebSocketImpl) key.attachment(); + if( SocketChannelIOHelper.batch( conn, conn.channel ) ) { + if( key.isValid() ) + key.interestOps( SelectionKey.OP_READ ); + } + } + + /** + * Setup the selector thread as well as basic server settings + * @return true, if everything was successful, false if some error happened + */ + private boolean doSetupSelectorAndServerThread() { + selectorthread.setName( "WebSocketSelector-" + selectorthread.getId() ); + try { + server = ServerSocketChannel.open(); + server.configureBlocking( false ); + ServerSocket socket = server.socket(); + socket.setReceiveBufferSize( WebSocketImpl.RCVBUF ); + socket.setReuseAddress( isReuseAddr() ); + socket.bind( address ); + selector = Selector.open(); + server.register( selector, server.validOps() ); + startConnectionLostTimer(); + onStart(); + } catch ( IOException ex ) { + handleFatal( null, ex ); + return false; + } + return true; + } + + /** + * The websocket server can only be started once + * @return true, if the server can be started, false if already a thread is running + */ + private boolean doEnsureSingleThread() { + synchronized ( this ) { + if( selectorthread != null ) + throw new IllegalStateException( getClass().getName() + " can only be started once." ); + selectorthread = Thread.currentThread(); + if( isclosed.get() ) { + return false; + } + } + return true; + } + + /** + * Clean up everything after a shutdown + */ + private void doServerShutdown() { + stopConnectionLostTimer(); + if( decoders != null ) { + for( WebSocketWorker w : decoders ) { + w.interrupt(); + } + } + if( selector != null ) { + try { + selector.close(); + } catch ( IOException e ) { + log.error( "IOException during selector.close", e ); + onError( null, e ); + } + } + if( server != null ) { + try { + server.close(); + } catch ( IOException e ) { + log.error( "IOException during server.close", e ); + onError( null, e ); } } } + protected void allocateBuffers( WebSocket c ) throws InterruptedException { if( queuesize.get() >= 2 * decoders.size() + 1 ) { return; @@ -632,9 +707,7 @@ public ServerHandshakeBuilder onWebsocketHandshakeReceivedAsServer( WebSocket co protected boolean addConnection( WebSocket ws ) { if( !isclosed.get() ) { synchronized ( connections ) { - boolean succ = this.connections.add( ws ); - assert ( succ ); - return succ; + return this.connections.add( ws ); } } else { // This case will happen when a new connection gets ready while the server is already stopping. @@ -698,7 +771,6 @@ public final WebSocketFactory getWebSocketFactory() { * @return Can this new connection be accepted **/ protected boolean onConnect( SelectionKey key ) { - //FIXME return true; } @@ -858,18 +930,7 @@ private void doBroadcast(Object data, Collection clients) { for( WebSocket client : clients ) { if( client != null ) { Draft draft = client.getDraft(); - if( !draftFrames.containsKey( draft ) ) { - List frames = null; - if (sData != null) { - frames = draft.createFrames( sData, false ); - } - if (bData != null) { - frames = draft.createFrames( bData, false ); - } - if (frames != null) { - draftFrames.put(draft, frames); - } - } + fillFrames(draft, draftFrames, sData, bData); try { client.sendFrame( draftFrames.get( draft ) ); } catch ( WebsocketNotConnectedException e ) { @@ -879,6 +940,28 @@ private void doBroadcast(Object data, Collection clients) { } } + /** + * Fills the draftFrames with new data for the broadcast + * @param draft The draft to use + * @param draftFrames The list of frames per draft to fill + * @param sData the string data, can be null + * @param bData the bytebuffer data, can be null + */ + private void fillFrames(Draft draft, Map> draftFrames, String sData, ByteBuffer bData) { + if( !draftFrames.containsKey( draft ) ) { + List frames = null; + if (sData != null) { + frames = draft.createFrames( sData, false ); + } + if (bData != null) { + frames = draft.createFrames( bData, false ); + } + if (frames != null) { + draftFrames.put(draft, frames); + } + } + } + /** * This class is used to process incoming data */ @@ -910,15 +993,8 @@ public void run() { ws = iqueue.take(); buf = ws.inQueue.poll(); assert ( buf != null ); - try { - ws.decode( buf ); - } catch(Exception e){ - log.error("Error while reading from remote connection", e); - } - finally { - ws = null; - pushBuffer( buf ); - } + doDecode(ws, buf); + ws = null; } } catch ( InterruptedException e ) { Thread.currentThread().interrupt(); @@ -926,5 +1002,22 @@ public void run() { handleFatal( ws, e ); } } + + /** + * call ws.decode on the bytebuffer + * @param ws the Websocket + * @param buf the buffer to decode to + * @throws InterruptedException thrown by pushBuffer + */ + private void doDecode(WebSocketImpl ws, ByteBuffer buf) throws InterruptedException { + try { + ws.decode( buf ); + } catch(Exception e){ + log.error("Error while reading from remote connection", e); + } + finally { + pushBuffer( buf ); + } + } } } diff --git a/src/test/java/org/java_websocket/extensions/AllExtensionTests.java b/src/test/java/org/java_websocket/extensions/AllExtensionTests.java index 93f1bd98d..965473203 100644 --- a/src/test/java/org/java_websocket/extensions/AllExtensionTests.java +++ b/src/test/java/org/java_websocket/extensions/AllExtensionTests.java @@ -30,7 +30,8 @@ @RunWith(Suite.class) @Suite.SuiteClasses({ - org.java_websocket.extensions.DefaultExtensionTest.class + org.java_websocket.extensions.DefaultExtensionTest.class, + org.java_websocket.extensions.CompressionExtensionTest.class }) /** * Start all tests for extensuins diff --git a/src/test/java/org/java_websocket/extensions/CompressionExtensionTest.java b/src/test/java/org/java_websocket/extensions/CompressionExtensionTest.java new file mode 100644 index 000000000..efd3bec67 --- /dev/null +++ b/src/test/java/org/java_websocket/extensions/CompressionExtensionTest.java @@ -0,0 +1,76 @@ +package org.java_websocket.extensions; + +import org.java_websocket.framing.PingFrame; +import org.java_websocket.framing.TextFrame; +import org.junit.Test; + +import static org.junit.Assert.fail; + +public class CompressionExtensionTest { + + + @Test + public void testIsFrameValid() { + CustomCompressionExtension customCompressionExtension = new CustomCompressionExtension(); + TextFrame textFrame = new TextFrame(); + try { + customCompressionExtension.isFrameValid( textFrame ); + } catch ( Exception e ) { + fail( "This frame is valid" ); + } + textFrame.setRSV1( true ); + try { + customCompressionExtension.isFrameValid( textFrame ); + } catch ( Exception e ) { + fail( "This frame is valid" ); + } + textFrame.setRSV1( false ); + textFrame.setRSV2( true ); + try { + customCompressionExtension.isFrameValid( textFrame ); + fail( "This frame is not valid" ); + } catch ( Exception e ) { + // + } + textFrame.setRSV2( false ); + textFrame.setRSV3( true ); + try { + customCompressionExtension.isFrameValid( textFrame ); + fail( "This frame is not valid" ); + } catch ( Exception e ) { + // + } + PingFrame pingFrame = new PingFrame(); + try { + customCompressionExtension.isFrameValid( pingFrame ); + } catch ( Exception e ) { + fail( "This frame is valid" ); + } + pingFrame.setRSV1( true ); + try { + customCompressionExtension.isFrameValid( pingFrame ); + fail( "This frame is not valid" ); + } catch ( Exception e ) { + // + } + pingFrame.setRSV1( false ); + pingFrame.setRSV2( true ); + try { + customCompressionExtension.isFrameValid( pingFrame ); + fail( "This frame is not valid" ); + } catch ( Exception e ) { + // + } + pingFrame.setRSV2( false ); + pingFrame.setRSV3( true ); + try { + customCompressionExtension.isFrameValid( pingFrame ); + fail( "This frame is not valid" ); + } catch ( Exception e ) { + // + } + } + + private static class CustomCompressionExtension extends CompressionExtension { + } +} diff --git a/src/test/java/org/java_websocket/issues/Issue713Test.java b/src/test/java/org/java_websocket/issues/Issue713Test.java index 9a4939e71..fe1afd429 100644 --- a/src/test/java/org/java_websocket/issues/Issue713Test.java +++ b/src/test/java/org/java_websocket/issues/Issue713Test.java @@ -35,6 +35,7 @@ import org.junit.Assert; import org.junit.Test; +import java.io.IOException; import java.net.InetSocketAddress; import java.net.URI; import java.net.URISyntaxException; @@ -43,12 +44,56 @@ import java.util.concurrent.CountDownLatch; import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; public class Issue713Test { CountDownLatch countDownLatchString = new CountDownLatch( 10 ); CountDownLatch countDownLatchConnect = new CountDownLatch( 10 ); CountDownLatch countDownLatchBytebuffer = new CountDownLatch( 10 ); + + @Test + public void testIllegalArgument() throws IOException { + WebSocketServer server = new WebSocketServer( new InetSocketAddress( SocketUtil.getAvailablePort() ) ) { + @Override + public void onOpen(WebSocket conn, ClientHandshake handshake) { + + } + + @Override + public void onClose(WebSocket conn, int code, String reason, boolean remote) { + + } + + @Override + public void onMessage(WebSocket conn, String message) { + + } + + @Override + public void onError(WebSocket conn, Exception ex) { + + } + + @Override + public void onStart() { + + } + }; + try { + server.broadcast((byte[]) null, null); + fail("IllegalArgumentException should be thrown"); + } catch (Exception e) { + // OK + } + try { + server.broadcast((String) null, null); + fail("IllegalArgumentException should be thrown"); + } catch (Exception e) { + // OK + } + } + @Test(timeout=2000) public void testIssue() throws Exception { final int port = SocketUtil.getAvailablePort(); @@ -79,12 +124,11 @@ public void onStart() { tw.connect(); } } catch (Exception e) { - Assert.fail("Exception during connect!"); + fail("Exception during connect!"); } } }; server.start(); - countDownLatchConnect.await(); server.broadcast("Hello world!"); countDownLatchString.await(); diff --git a/src/test/java/org/java_websocket/server/AllServerTests.java b/src/test/java/org/java_websocket/server/AllServerTests.java new file mode 100644 index 000000000..e50c32efa --- /dev/null +++ b/src/test/java/org/java_websocket/server/AllServerTests.java @@ -0,0 +1,41 @@ +/* + * Copyright (c) 2010-2018 Nathan Rajlich + * + * Permission is hereby granted, free of charge, to any person + * obtaining a copy of this software and associated documentation + * files (the "Software"), to deal in the Software without + * restriction, including without limitation the rights to use, + * copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following + * conditions: + * + * The above copyright notice and this permission notice shall be + * included in all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES + * OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND + * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT + * HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, + * WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR + * OTHER DEALINGS IN THE SOFTWARE. + */ + +package org.java_websocket.server; + +import org.junit.runner.RunWith; +import org.junit.runners.Suite; + + +@RunWith(Suite.class) +@Suite.SuiteClasses({ + org.java_websocket.server.DefaultWebSocketServerFactoryTest.class, + org.java_websocket.protocols.ProtoclHandshakeRejectionTest.class +}) +/** + * Start all tests for the server + */ +public class AllServerTests { +} diff --git a/src/test/java/org/java_websocket/server/CustomSSLWebSocketServerFactoryTest.java b/src/test/java/org/java_websocket/server/CustomSSLWebSocketServerFactoryTest.java new file mode 100644 index 000000000..ff8c3c676 --- /dev/null +++ b/src/test/java/org/java_websocket/server/CustomSSLWebSocketServerFactoryTest.java @@ -0,0 +1,160 @@ +package org.java_websocket.server; + +import org.java_websocket.WebSocket; +import org.java_websocket.WebSocketAdapter; +import org.java_websocket.WebSocketImpl; +import org.java_websocket.drafts.Draft; +import org.java_websocket.drafts.Draft_6455; +import org.java_websocket.handshake.Handshakedata; +import org.junit.Test; + +import javax.net.ssl.SSLContext; +import java.io.IOException; +import java.net.InetSocketAddress; +import java.nio.ByteBuffer; +import java.nio.channels.ByteChannel; +import java.nio.channels.NotYetConnectedException; +import java.nio.channels.SocketChannel; +import java.security.NoSuchAlgorithmException; +import java.util.Collections; +import java.util.concurrent.Executors; + +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.fail; + +public class CustomSSLWebSocketServerFactoryTest { + + final String[] emptyArray = new String[0]; + @Test + public void testConstructor() throws NoSuchAlgorithmException { + try { + new CustomSSLWebSocketServerFactory(null, null, null); + fail("IllegalArgumentException should be thrown"); + } catch (IllegalArgumentException e) { + // Good + } + try { + new CustomSSLWebSocketServerFactory(null, null, null, null); + fail("IllegalArgumentException should be thrown"); + } catch (IllegalArgumentException e) { + // Good + } + try { + new CustomSSLWebSocketServerFactory(SSLContext.getDefault(), null, null, null); + fail("IllegalArgumentException should be thrown"); + } catch (IllegalArgumentException e) { + } + try { + new CustomSSLWebSocketServerFactory(SSLContext.getDefault(), null, null); + } catch (IllegalArgumentException e) { + fail("IllegalArgumentException should not be thrown"); + } + try { + new CustomSSLWebSocketServerFactory(SSLContext.getDefault(), Executors.newCachedThreadPool(), null, null); + } catch (IllegalArgumentException e) { + fail("IllegalArgumentException should not be thrown"); + } + try { + new CustomSSLWebSocketServerFactory(SSLContext.getDefault(), Executors.newCachedThreadPool(), emptyArray, null); + } catch (IllegalArgumentException e) { + fail("IllegalArgumentException should not be thrown"); + } + try { + new CustomSSLWebSocketServerFactory(SSLContext.getDefault(), Executors.newCachedThreadPool(), null, emptyArray); + } catch (IllegalArgumentException e) { + fail("IllegalArgumentException should not be thrown"); + } + try { + new CustomSSLWebSocketServerFactory(SSLContext.getDefault(), Executors.newCachedThreadPool(), emptyArray, emptyArray); + } catch (IllegalArgumentException e) { + fail("IllegalArgumentException should not be thrown"); + } + } + @Test + public void testCreateWebSocket() throws NoSuchAlgorithmException { + CustomSSLWebSocketServerFactory webSocketServerFactory = new CustomSSLWebSocketServerFactory(SSLContext.getDefault(), null, null); + CustomWebSocketAdapter webSocketAdapter = new CustomWebSocketAdapter(); + WebSocketImpl webSocketImpl = webSocketServerFactory.createWebSocket(webSocketAdapter, new Draft_6455()); + assertNotNull("webSocketImpl != null", webSocketImpl); + webSocketImpl = webSocketServerFactory.createWebSocket(webSocketAdapter, Collections.singletonList(new Draft_6455())); + assertNotNull("webSocketImpl != null", webSocketImpl); + } + + @Test + public void testWrapChannel() throws IOException, NoSuchAlgorithmException { + CustomSSLWebSocketServerFactory webSocketServerFactory = new CustomSSLWebSocketServerFactory(SSLContext.getDefault(), null, null); + SocketChannel channel = SocketChannel.open(); + try { + ByteChannel result = webSocketServerFactory.wrapChannel(channel, null); + } catch (NotYetConnectedException e) { + //We do not really connect + } + channel.close(); + webSocketServerFactory = new CustomSSLWebSocketServerFactory(SSLContext.getDefault(), new String[]{"TLSv1.2"}, new String[]{"TLS_RSA_WITH_AES_128_CBC_SHA", "TLS_RSA_WITH_AES_256_CBC_SHA"}); + channel = SocketChannel.open(); + try { + ByteChannel result = webSocketServerFactory.wrapChannel(channel, null); + } catch (NotYetConnectedException e) { + //We do not really connect + } + channel.close(); + } + + @Test + public void testClose() { + DefaultWebSocketServerFactory webSocketServerFactory = new DefaultWebSocketServerFactory(); + webSocketServerFactory.close(); + } + + private static class CustomWebSocketAdapter extends WebSocketAdapter { + @Override + public void onWebsocketMessage(WebSocket conn, String message) { + + } + + @Override + public void onWebsocketMessage(WebSocket conn, ByteBuffer blob) { + + } + + @Override + public void onWebsocketOpen(WebSocket conn, Handshakedata d) { + + } + + @Override + public void onWebsocketClose(WebSocket ws, int code, String reason, boolean remote) { + + } + + @Override + public void onWebsocketClosing(WebSocket ws, int code, String reason, boolean remote) { + + } + + @Override + public void onWebsocketCloseInitiated(WebSocket ws, int code, String reason) { + + } + + @Override + public void onWebsocketError(WebSocket conn, Exception ex) { + + } + + @Override + public void onWriteDemand(WebSocket conn) { + + } + + @Override + public InetSocketAddress getLocalSocketAddress(WebSocket conn) { + return null; + } + + @Override + public InetSocketAddress getRemoteSocketAddress(WebSocket conn) { + return null; + } + } +} diff --git a/src/test/java/org/java_websocket/server/DefaultSSLWebSocketServerFactoryTest.java b/src/test/java/org/java_websocket/server/DefaultSSLWebSocketServerFactoryTest.java new file mode 100644 index 000000000..75f89b932 --- /dev/null +++ b/src/test/java/org/java_websocket/server/DefaultSSLWebSocketServerFactoryTest.java @@ -0,0 +1,139 @@ +package org.java_websocket.server; + +import org.java_websocket.WebSocket; +import org.java_websocket.WebSocketAdapter; +import org.java_websocket.WebSocketImpl; +import org.java_websocket.drafts.Draft; +import org.java_websocket.drafts.Draft_6455; +import org.java_websocket.handshake.Handshakedata; +import org.junit.Test; + +import javax.net.ssl.SSLContext; +import java.io.IOException; +import java.net.*; +import java.nio.ByteBuffer; +import java.nio.channels.ByteChannel; +import java.nio.channels.NotYetConnectedException; +import java.nio.channels.SocketChannel; +import java.nio.channels.spi.SelectorProvider; +import java.security.NoSuchAlgorithmException; +import java.util.Collections; +import java.util.Set; +import java.util.concurrent.Executors; + +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertSame; +import static org.junit.Assert.fail; + +public class DefaultSSLWebSocketServerFactoryTest { + + @Test + public void testConstructor() throws NoSuchAlgorithmException { + try { + new DefaultSSLWebSocketServerFactory(null); + fail("IllegalArgumentException should be thrown"); + } catch (IllegalArgumentException e) { + // Good + } + try { + new DefaultSSLWebSocketServerFactory(null, null); + fail("IllegalArgumentException should be thrown"); + } catch (IllegalArgumentException e) { + // Good + } + try { + new DefaultSSLWebSocketServerFactory(SSLContext.getDefault(), null); + fail("IllegalArgumentException should be thrown"); + } catch (IllegalArgumentException e) { + } + try { + new DefaultSSLWebSocketServerFactory(SSLContext.getDefault()); + } catch (IllegalArgumentException e) { + fail("IllegalArgumentException should not be thrown"); + } + try { + new DefaultSSLWebSocketServerFactory(SSLContext.getDefault(), Executors.newCachedThreadPool()); + } catch (IllegalArgumentException e) { + fail("IllegalArgumentException should not be thrown"); + } + } + @Test + public void testCreateWebSocket() throws NoSuchAlgorithmException { + DefaultSSLWebSocketServerFactory webSocketServerFactory = new DefaultSSLWebSocketServerFactory(SSLContext.getDefault()); + CustomWebSocketAdapter webSocketAdapter = new CustomWebSocketAdapter(); + WebSocketImpl webSocketImpl = webSocketServerFactory.createWebSocket(webSocketAdapter, new Draft_6455()); + assertNotNull("webSocketImpl != null", webSocketImpl); + webSocketImpl = webSocketServerFactory.createWebSocket(webSocketAdapter, Collections.singletonList(new Draft_6455())); + assertNotNull("webSocketImpl != null", webSocketImpl); + } + + @Test + public void testWrapChannel() throws IOException, NoSuchAlgorithmException { + DefaultSSLWebSocketServerFactory webSocketServerFactory = new DefaultSSLWebSocketServerFactory(SSLContext.getDefault()); + SocketChannel channel = SocketChannel.open(); + try { + ByteChannel result = webSocketServerFactory.wrapChannel(channel, null); + } catch (NotYetConnectedException e) { + //We do not really connect + } + channel.close(); + } + + @Test + public void testClose() { + DefaultWebSocketServerFactory webSocketServerFactory = new DefaultWebSocketServerFactory(); + webSocketServerFactory.close(); + } + + private static class CustomWebSocketAdapter extends WebSocketAdapter { + @Override + public void onWebsocketMessage(WebSocket conn, String message) { + + } + + @Override + public void onWebsocketMessage(WebSocket conn, ByteBuffer blob) { + + } + + @Override + public void onWebsocketOpen(WebSocket conn, Handshakedata d) { + + } + + @Override + public void onWebsocketClose(WebSocket ws, int code, String reason, boolean remote) { + + } + + @Override + public void onWebsocketClosing(WebSocket ws, int code, String reason, boolean remote) { + + } + + @Override + public void onWebsocketCloseInitiated(WebSocket ws, int code, String reason) { + + } + + @Override + public void onWebsocketError(WebSocket conn, Exception ex) { + + } + + @Override + public void onWriteDemand(WebSocket conn) { + + } + + @Override + public InetSocketAddress getLocalSocketAddress(WebSocket conn) { + return null; + } + + @Override + public InetSocketAddress getRemoteSocketAddress(WebSocket conn) { + return null; + } + } +} diff --git a/src/test/java/org/java_websocket/server/DefaultWebSocketServerFactoryTest.java b/src/test/java/org/java_websocket/server/DefaultWebSocketServerFactoryTest.java new file mode 100644 index 000000000..f3f758cd3 --- /dev/null +++ b/src/test/java/org/java_websocket/server/DefaultWebSocketServerFactoryTest.java @@ -0,0 +1,96 @@ +package org.java_websocket.server; + +import org.java_websocket.SocketChannelIOHelper; +import org.java_websocket.WebSocket; +import org.java_websocket.WebSocketAdapter; +import org.java_websocket.WebSocketImpl; +import org.java_websocket.drafts.Draft; +import org.java_websocket.drafts.Draft_6455; +import org.java_websocket.handshake.Handshakedata; +import org.junit.Test; + +import java.net.InetSocketAddress; +import java.net.Socket; +import java.nio.ByteBuffer; +import java.nio.channels.SocketChannel; +import java.util.Collections; + +import static org.junit.Assert.*; + +public class DefaultWebSocketServerFactoryTest { + + @Test + public void testCreateWebSocket() { + DefaultWebSocketServerFactory webSocketServerFactory = new DefaultWebSocketServerFactory(); + CustomWebSocketAdapter webSocketAdapter = new CustomWebSocketAdapter(); + WebSocketImpl webSocketImpl = webSocketServerFactory.createWebSocket(webSocketAdapter, new Draft_6455()); + assertNotNull("webSocketImpl != null", webSocketImpl); + webSocketImpl = webSocketServerFactory.createWebSocket(webSocketAdapter, Collections.singletonList(new Draft_6455())); + assertNotNull("webSocketImpl != null", webSocketImpl); + } + + @Test + public void testWrapChannel() { + DefaultWebSocketServerFactory webSocketServerFactory = new DefaultWebSocketServerFactory(); + SocketChannel channel = (new Socket()).getChannel(); + SocketChannel result = webSocketServerFactory.wrapChannel(channel, null); + assertSame("channel == result", channel, result); + } + @Test + public void testClose() { + DefaultWebSocketServerFactory webSocketServerFactory = new DefaultWebSocketServerFactory(); + webSocketServerFactory.close(); + } + + private static class CustomWebSocketAdapter extends WebSocketAdapter { + @Override + public void onWebsocketMessage(WebSocket conn, String message) { + + } + + @Override + public void onWebsocketMessage(WebSocket conn, ByteBuffer blob) { + + } + + @Override + public void onWebsocketOpen(WebSocket conn, Handshakedata d) { + + } + + @Override + public void onWebsocketClose(WebSocket ws, int code, String reason, boolean remote) { + + } + + @Override + public void onWebsocketClosing(WebSocket ws, int code, String reason, boolean remote) { + + } + + @Override + public void onWebsocketCloseInitiated(WebSocket ws, int code, String reason) { + + } + + @Override + public void onWebsocketError(WebSocket conn, Exception ex) { + + } + + @Override + public void onWriteDemand(WebSocket conn) { + + } + + @Override + public InetSocketAddress getLocalSocketAddress(WebSocket conn) { + return null; + } + + @Override + public InetSocketAddress getRemoteSocketAddress(WebSocket conn) { + return null; + } + } +}