diff --git a/src/main/example/simplelogger.properties b/src/main/example/simplelogger.properties index f2f4dcac3..aa4322e92 100644 --- a/src/main/example/simplelogger.properties +++ b/src/main/example/simplelogger.properties @@ -1,5 +1,5 @@ org.slf4j.simpleLogger.logFile=System.out -org.slf4j.simpleLogger.defaultLogLevel=error +org.slf4j.simpleLogger.defaultLogLevel=trace org.slf4j.simpleLogger.showDateTime=true org.slf4j.simpleLogger.dateTimeFormat=yyyy-MM-dd HH:mm:ss.SSS org.slf4j.simpleLogger.showThreadName=false diff --git a/src/main/java/org/java_websocket/SSLSocketChannel.java b/src/main/java/org/java_websocket/SSLSocketChannel.java index 6fb21dd6e..dcdb6aa89 100644 --- a/src/main/java/org/java_websocket/SSLSocketChannel.java +++ b/src/main/java/org/java_websocket/SSLSocketChannel.java @@ -265,8 +265,15 @@ private boolean doHandshake() throws IOException { peerNetData.clear(); handshakeStatus = engine.getHandshakeStatus(); - while( handshakeStatus != SSLEngineResult.HandshakeStatus.FINISHED && handshakeStatus != SSLEngineResult.HandshakeStatus.NOT_HANDSHAKING ) { + boolean handshakeComplete = false; + while( !handshakeComplete) { switch(handshakeStatus) { + case FINISHED: + handshakeComplete = !this.peerNetData.hasRemaining(); + if (handshakeComplete) + return true; + socketChannel.write(this.peerNetData); + break; case NEED_UNWRAP: if( socketChannel.read( peerNetData ) < 0 ) { if( engine.isInboundDone() && engine.isOutboundDone() ) { @@ -316,6 +323,8 @@ private boolean doHandshake() throws IOException { } break; case NEED_WRAP: + + myNetData.clear(); try { result = engine.wrap( myAppData, myNetData ); @@ -363,8 +372,7 @@ private boolean doHandshake() throws IOException { } handshakeStatus = engine.getHandshakeStatus(); break; - case FINISHED: - break; + case NOT_HANDSHAKING: break; default: diff --git a/src/main/java/org/java_websocket/WebSocketImpl.java b/src/main/java/org/java_websocket/WebSocketImpl.java index 89e1c8a92..c2e82caca 100644 --- a/src/main/java/org/java_websocket/WebSocketImpl.java +++ b/src/main/java/org/java_websocket/WebSocketImpl.java @@ -28,10 +28,7 @@ import org.java_websocket.drafts.Draft; import org.java_websocket.drafts.Draft_6455; import org.java_websocket.enums.*; -import org.java_websocket.exceptions.IncompleteHandshakeException; -import org.java_websocket.exceptions.InvalidDataException; -import org.java_websocket.exceptions.InvalidHandshakeException; -import org.java_websocket.exceptions.WebsocketNotConnectedException; +import org.java_websocket.exceptions.*; import org.java_websocket.framing.CloseFrame; import org.java_websocket.framing.Framedata; import org.java_websocket.framing.PingFrame; @@ -369,6 +366,12 @@ private void decodeFrames( ByteBuffer socketBuffer ) { log.trace( "matched frame: {}" , f ); draft.processFrame( this, f ); } + } catch ( LimitExedeedException e ) { + if (e.getLimit() == Integer.MAX_VALUE) { + log.error("Closing due to invalid size of frame", e); + wsl.onWebsocketError(this, e); + } + close(e); } catch ( InvalidDataException e ) { log.error("Closing due to invalid data in frame", e); wsl.onWebsocketError( this, e ); diff --git a/src/main/java/org/java_websocket/drafts/Draft_6455.java b/src/main/java/org/java_websocket/drafts/Draft_6455.java index 17e2e6da0..6c713f3bd 100644 --- a/src/main/java/org/java_websocket/drafts/Draft_6455.java +++ b/src/main/java/org/java_websocket/drafts/Draft_6455.java @@ -86,7 +86,7 @@ public class Draft_6455 extends Draft { /** * Attribute for the payload of the current continuous frame */ - private List byteBufferList; + private final List byteBufferList; /** * Attribute for the current incomplete frame @@ -98,6 +98,13 @@ public class Draft_6455 extends Draft { */ private final Random reuseableRandom = new Random(); + /** + * Attribute for the maximum allowed size of a frame + * + * @since 1.4.0 + */ + private int maxFrameSize; + /** * Constructor for the websocket protocol specified by RFC 6455 with default extensions * @since 1.3.5 @@ -135,7 +142,32 @@ public Draft_6455( List inputExtensions ) { * @since 1.3.7 */ public Draft_6455( List inputExtensions , List inputProtocols ) { - if (inputExtensions == null || inputProtocols == null) { + this(inputExtensions, inputProtocols, Integer.MAX_VALUE); + } + + /** + * Constructor for the websocket protocol specified by RFC 6455 with custom extensions and protocols + * + * @param inputExtensions the extensions which should be used for this draft + * @param inputMaxFrameSize the maximum allowed size of a frame (the real payload size, decoded frames can be bigger) + * + * @since 1.4.0 + */ + public Draft_6455( List inputExtensions , int inputMaxFrameSize) { + this(inputExtensions, Collections.singletonList( new Protocol( "" )), inputMaxFrameSize); + } + + /** + * Constructor for the websocket protocol specified by RFC 6455 with custom extensions and protocols + * + * @param inputExtensions the extensions which should be used for this draft + * @param inputProtocols the protocols which should be used for this draft + * @param inputMaxFrameSize the maximum allowed size of a frame (the real payload size, decoded frames can be bigger) + * + * @since 1.4.0 + */ + public Draft_6455( List inputExtensions , List inputProtocols, int inputMaxFrameSize ) { + if (inputExtensions == null || inputProtocols == null || inputMaxFrameSize < 1) { throw new IllegalArgumentException(); } knownExtensions = new ArrayList( inputExtensions.size()); @@ -153,6 +185,7 @@ public Draft_6455( List inputExtensions , List inputProto knownExtensions.add( this.knownExtensions.size(), extension ); } knownProtocols.addAll( inputProtocols ); + maxFrameSize = inputMaxFrameSize; } @Override @@ -263,6 +296,17 @@ public IProtocol getProtocol() { return protocol; } + + /** + * Getter for the maximum allowed payload size which is used by this draft + * + * @return the size, which is allowed for the payload + * @since 1.4.0 + */ + public int getMaxFrameSize() { + return maxFrameSize; + } + /** * Getter for all available protocols for this draft * @return the protocols which are enabled for this draft @@ -337,7 +381,7 @@ public Draft copyInstance() { for( IProtocol protocol : getKnownProtocols() ) { newProtocols.add( protocol.copyInstance() ); } - return new Draft_6455( newExtensions, newProtocols ); + return new Draft_6455( newExtensions, newProtocols, maxFrameSize ); } @Override @@ -440,7 +484,7 @@ public Framedata translateSingleFrame( ByteBuffer buffer ) throws IncompleteExce bytes[i] = buffer.get( /*1 + i*/ ); } long length = new BigInteger( bytes ).longValue(); - if( length > Integer.MAX_VALUE ) { + if( length > Integer.MAX_VALUE) { log.trace( "Limit exedeed: Payloadsize is to big..." ); throw new LimitExedeedException( "Payloadsize is to big..." ); } else { @@ -448,6 +492,10 @@ public Framedata translateSingleFrame( ByteBuffer buffer ) throws IncompleteExce } } } + if( payloadlength > maxFrameSize) { + log.trace( "Payload limit reached. Allowed: {0} Current: {1}" , maxFrameSize, payloadlength); + throw new LimitExedeedException( "Payload limit reached.", maxFrameSize ); + } // int maskskeystart = foff + realpacketsize; realpacketsize += ( MASK ? 4 : 0 ); @@ -682,13 +730,15 @@ public void processFrame( WebSocketImpl webSocketImpl, Framedata frame ) throws throw new InvalidDataException( CloseFrame.PROTOCOL_ERROR, "Previous continuous frame sequence not completed." ); } current_continuous_frame = frame; - byteBufferList.add( frame.getPayloadData() ); + addToBufferList(frame.getPayloadData()); + checkBufferLimit(); } else if( frame.isFin() ) { if( current_continuous_frame == null ) { log.trace( "Protocol error: Previous continuous frame sequence not completed." ); throw new InvalidDataException( CloseFrame.PROTOCOL_ERROR, "Continuous frame sequence was not started." ); } - byteBufferList.add( frame.getPayloadData() ); + addToBufferList(frame.getPayloadData()); + checkBufferLimit(); if( current_continuous_frame.getOpcode() == Opcode.TEXT ) { ((FramedataImpl1) current_continuous_frame).setPayload( getPayloadFromByteBufferList() ); ((FramedataImpl1) current_continuous_frame ).isValid(); @@ -709,7 +759,7 @@ public void processFrame( WebSocketImpl webSocketImpl, Framedata frame ) throws } } current_continuous_frame = null; - byteBufferList.clear(); + clearBufferList(); } else if( current_continuous_frame == null ) { log.error( "Protocol error: Continuous frame sequence was not started." ); throw new InvalidDataException( CloseFrame.PROTOCOL_ERROR, "Continuous frame sequence was not started." ); @@ -723,7 +773,7 @@ public void processFrame( WebSocketImpl webSocketImpl, Framedata frame ) throws } //Checking if the current continuous frame contains a correct payload with the other frames combined if( curop == Opcode.CONTINUOUS && current_continuous_frame != null ) { - byteBufferList.add( frame.getPayloadData() ); + addToBufferList(frame.getPayloadData()); } } else if( current_continuous_frame != null ) { log.error( "Protocol error: Continuous frame sequence not completed." ); @@ -748,6 +798,38 @@ public void processFrame( WebSocketImpl webSocketImpl, Framedata frame ) throws } } + /** + * Clear the current bytebuffer list + */ + private void clearBufferList() { + synchronized (byteBufferList) { + byteBufferList.clear(); + } + } + + /** + * Add a payload to the current bytebuffer list + * @param payloadData the new payload + */ + private void addToBufferList(ByteBuffer payloadData) { + synchronized (byteBufferList) { + byteBufferList.add(payloadData); + } + } + + /** + * Check the current size of the buffer and throw an exception if the size is bigger than the max allowed frame size + * @throws LimitExedeedException if the current size is bigger than the allowed size + */ + private void checkBufferLimit() throws LimitExedeedException { + long totalSize = getByteBufferListSize(); + if( totalSize > maxFrameSize ) { + clearBufferList(); + log.trace("Payload limit reached. Allowed: {0} Current: {1}", maxFrameSize, totalSize); + throw new LimitExedeedException(maxFrameSize); + } + } + @Override public CloseHandshakeType getCloseHandshakeType() { return CloseHandshakeType.TWOWAY; @@ -760,24 +842,27 @@ public String toString() { result += " extension: " + getExtension().toString(); if ( getProtocol() != null ) result += " protocol: " + getProtocol().toString(); + result += " max frame size: " + this.maxFrameSize; return result; } @Override - public boolean equals( Object o ) { - if( this == o ) return true; - if( o == null || getClass() != o.getClass() ) return false; + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; - Draft_6455 that = ( Draft_6455 ) o; + Draft_6455 that = (Draft_6455) o; - if( extension != null ? !extension.equals( that.extension ) : that.extension != null ) return false; - return protocol != null ? protocol.equals( that.protocol ) : that.protocol == null; + if (maxFrameSize != that.getMaxFrameSize()) return false; + if (extension != null ? !extension.equals(that.getExtension()) : that.getExtension() != null) return false; + return protocol != null ? protocol.equals(that.getProtocol()) : that.getProtocol() == null; } @Override public int hashCode() { int result = extension != null ? extension.hashCode() : 0; - result = 31 * result + ( protocol != null ? protocol.hashCode() : 0 ); + result = 31 * result + (protocol != null ? protocol.hashCode() : 0); + result = 31 * result + (int) (maxFrameSize ^ (maxFrameSize >>> 32)); return result; } @@ -788,18 +873,32 @@ public int hashCode() { */ private ByteBuffer getPayloadFromByteBufferList() throws LimitExedeedException { long totalSize = 0; - for (ByteBuffer buffer : byteBufferList) { - totalSize +=buffer.limit(); - } - if (totalSize > Integer.MAX_VALUE) { - log.trace( "Payloadsize is to big..."); - throw new LimitExedeedException( "Payloadsize is to big..." ); - } - ByteBuffer resultingByteBuffer = ByteBuffer.allocate( (int) totalSize ); - for (ByteBuffer buffer : byteBufferList) { - resultingByteBuffer.put( buffer ); + ByteBuffer resultingByteBuffer; + synchronized (byteBufferList) { + for (ByteBuffer buffer : byteBufferList) { + totalSize += buffer.limit(); + } + checkBufferLimit(); + resultingByteBuffer = ByteBuffer.allocate( (int) totalSize ); + for (ByteBuffer buffer : byteBufferList) { + resultingByteBuffer.put( buffer ); + } } resultingByteBuffer.flip(); return resultingByteBuffer; } + + /** + * Get the current size of the resulting bytebuffer in the bytebuffer list + * @return the size as long (to not get an integer overflow) + */ + private long getByteBufferListSize() { + long totalSize = 0; + synchronized (byteBufferList) { + for (ByteBuffer buffer : byteBufferList) { + totalSize += buffer.limit(); + } + } + return totalSize; + } } diff --git a/src/main/java/org/java_websocket/exceptions/LimitExedeedException.java b/src/main/java/org/java_websocket/exceptions/LimitExedeedException.java index a4fc08970..397bb2eb4 100644 --- a/src/main/java/org/java_websocket/exceptions/LimitExedeedException.java +++ b/src/main/java/org/java_websocket/exceptions/LimitExedeedException.java @@ -37,13 +37,38 @@ public class LimitExedeedException extends InvalidDataException { */ private static final long serialVersionUID = 6908339749836826785L; + /** + * A closer indication about the limit + */ + private final int limit; + /** * constructor for a LimitExedeedException *

* calling InvalidDataException with closecode TOOBIG */ public LimitExedeedException() { + this(Integer.MAX_VALUE); + } + + /** + * constructor for a LimitExedeedException + *

+ * calling InvalidDataException with closecode TOOBIG + */ + public LimitExedeedException(int limit) { super( CloseFrame.TOOBIG); + this.limit = limit; + } + + /** + * constructor for a LimitExedeedException + *

+ * calling InvalidDataException with closecode TOOBIG + */ + public LimitExedeedException(String s, int limit) { + super( CloseFrame.TOOBIG, s); + this.limit = limit; } /** @@ -54,7 +79,14 @@ public LimitExedeedException() { * @param s the detail message. */ public LimitExedeedException(String s) { - super( CloseFrame.TOOBIG, s); + this(s, Integer.MAX_VALUE); } + /** + * Get the limit which was hit so this exception was caused + * @return the limit as int + */ + public int getLimit() { + return limit; + } } diff --git a/src/test/java/org/java_websocket/drafts/Draft_6455Test.java b/src/test/java/org/java_websocket/drafts/Draft_6455Test.java index a3c2bfe8a..9a781f8d5 100644 --- a/src/test/java/org/java_websocket/drafts/Draft_6455Test.java +++ b/src/test/java/org/java_websocket/drafts/Draft_6455Test.java @@ -96,6 +96,18 @@ public void testConstructor() throws Exception { } catch ( IllegalArgumentException e ) { //Fine } + try { + Draft_6455 draft_6455 = new Draft_6455( Collections.emptyList(), Collections.emptyList(), -1 ); + fail( "IllegalArgumentException expected" ); + } catch ( IllegalArgumentException e ) { + //Fine + } + try { + Draft_6455 draft_6455 = new Draft_6455( Collections.emptyList(), Collections.emptyList(), 0 ); + fail( "IllegalArgumentException expected" ); + } catch ( IllegalArgumentException e ) { + //Fine + } Draft_6455 draft_6455 = new Draft_6455( Collections.emptyList(), Collections.emptyList() ); assertEquals( 1, draft_6455.getKnownExtensions().size() ); assertEquals( 0, draft_6455.getKnownProtocols().size() ); @@ -159,7 +171,7 @@ public void testCopyInstance() throws Exception { @Test public void testReset() throws Exception { - Draft_6455 draft_6455 = new Draft_6455( Collections.singletonList( new TestExtension() ) ); + Draft_6455 draft_6455 = new Draft_6455( Collections.singletonList( new TestExtension() ), 100 ); draft_6455.acceptHandshakeAsServer( handshakedataProtocolExtension ); List extensionList = new ArrayList( draft_6455.getKnownExtensions() ); List protocolList = new ArrayList( draft_6455.getKnownProtocols() ); @@ -180,17 +192,21 @@ public void testGetCloseHandshakeType() throws Exception { @Test public void testToString() throws Exception { Draft_6455 draft_6455 = new Draft_6455(); - assertEquals( "Draft_6455 extension: DefaultExtension", draft_6455.toString() ); + assertEquals( "Draft_6455 extension: DefaultExtension max frame size: 2147483647", draft_6455.toString() ); draft_6455.acceptHandshakeAsServer( handshakedataProtocolExtension ); - assertEquals( "Draft_6455 extension: DefaultExtension", draft_6455.toString() ); + assertEquals( "Draft_6455 extension: DefaultExtension max frame size: 2147483647", draft_6455.toString() ); draft_6455 = new Draft_6455( Collections.emptyList(), Collections.singletonList( new Protocol( "chat" ) ) ); - assertEquals( "Draft_6455 extension: DefaultExtension", draft_6455.toString() ); + assertEquals( "Draft_6455 extension: DefaultExtension max frame size: 2147483647", draft_6455.toString() ); draft_6455.acceptHandshakeAsServer( handshakedataProtocolExtension ); - assertEquals( "Draft_6455 extension: DefaultExtension protocol: chat", draft_6455.toString() ); + assertEquals( "Draft_6455 extension: DefaultExtension protocol: chat max frame size: 2147483647", draft_6455.toString() ); draft_6455 = new Draft_6455( Collections.singletonList( new TestExtension() ), Collections.singletonList( new Protocol( "chat" ) ) ); - assertEquals( "Draft_6455 extension: DefaultExtension", draft_6455.toString() ); + assertEquals( "Draft_6455 extension: DefaultExtension max frame size: 2147483647", draft_6455.toString() ); + draft_6455.acceptHandshakeAsServer( handshakedataProtocolExtension ); + assertEquals( "Draft_6455 extension: TestExtension protocol: chat max frame size: 2147483647", draft_6455.toString() ); + draft_6455 = new Draft_6455( Collections.emptyList(), Collections.singletonList( new Protocol( "chat" ) ) ,10); + assertEquals( "Draft_6455 extension: DefaultExtension max frame size: 10", draft_6455.toString() ); draft_6455.acceptHandshakeAsServer( handshakedataProtocolExtension ); - assertEquals( "Draft_6455 extension: TestExtension protocol: chat", draft_6455.toString() ); + assertEquals( "Draft_6455 extension: DefaultExtension protocol: chat max frame size: 10", draft_6455.toString() ); } @Test diff --git a/src/test/java/org/java_websocket/example/AutobahnServerTest.java b/src/test/java/org/java_websocket/example/AutobahnServerTest.java index 5cc82f4e3..b63ada900 100644 --- a/src/test/java/org/java_websocket/example/AutobahnServerTest.java +++ b/src/test/java/org/java_websocket/example/AutobahnServerTest.java @@ -38,10 +38,13 @@ public class AutobahnServerTest extends WebSocketServer { - private static int counter = 0; + private static int openCounter = 0; + private static int closeCounter = 0; + private int limit = Integer.MAX_VALUE; - public AutobahnServerTest( int port, Draft d ) throws UnknownHostException { + public AutobahnServerTest(int port, int limit, Draft d) throws UnknownHostException { super( new InetSocketAddress( port ), Collections.singletonList( d ) ); + this.limit = limit; } public AutobahnServerTest( InetSocketAddress address, Draft d ) { @@ -50,13 +53,17 @@ public AutobahnServerTest( InetSocketAddress address, Draft d ) { @Override public void onOpen( WebSocket conn, ClientHandshake handshake ) { - counter++; - System.out.println( "///////////Opened connection number" + counter ); + openCounter++; + System.out.println( "///////////Opened connection number" + openCounter); } @Override public void onClose( WebSocket conn, int code, String reason, boolean remote ) { + closeCounter++; System.out.println( "closed" ); + if (closeCounter >= limit) { + System.exit(0); + } } @Override @@ -81,14 +88,20 @@ public void onMessage( WebSocket conn, ByteBuffer blob ) { } public static void main( String[] args ) throws UnknownHostException { - int port; + int port, limit; try { port = new Integer( args[0] ); } catch ( Exception e ) { System.out.println( "No port specified. Defaulting to 9003" ); port = 9003; } - AutobahnServerTest test = new AutobahnServerTest( port, new Draft_6455() ); + try { + limit = new Integer( args[1] ); + } catch ( Exception e ) { + System.out.println( "No limit specified. Defaulting to MaxInteger" ); + limit = Integer.MAX_VALUE; + } + AutobahnServerTest test = new AutobahnServerTest( port, limit, new Draft_6455() ); test.setConnectionLostTimeout( 0 ); test.start(); } diff --git a/src/test/java/org/java_websocket/issues/AllIssueTests.java b/src/test/java/org/java_websocket/issues/AllIssueTests.java index 617262991..94c01716d 100644 --- a/src/test/java/org/java_websocket/issues/AllIssueTests.java +++ b/src/test/java/org/java_websocket/issues/AllIssueTests.java @@ -33,6 +33,7 @@ org.java_websocket.issues.Issue609Test.class, org.java_websocket.issues.Issue621Test.class, org.java_websocket.issues.Issue580Test.class, + org.java_websocket.issues.Issue598Test.class, org.java_websocket.issues.Issue256Test.class, org.java_websocket.issues.Issue661Test.class, org.java_websocket.issues.Issue666Test.class, diff --git a/src/test/java/org/java_websocket/issues/Issue598Test.java b/src/test/java/org/java_websocket/issues/Issue598Test.java new file mode 100644 index 000000000..e6dd4c184 --- /dev/null +++ b/src/test/java/org/java_websocket/issues/Issue598Test.java @@ -0,0 +1,221 @@ +/* + * Copyright (c) 2010-2018 Nathan Rajlich + * + * Permission is hereby granted, free of charge, to any person + * obtaining a copy of this software and associated documentation + * files (the "Software"), to deal in the Software without + * restriction, including without limitation the rights to use, + * copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following + * conditions: + * + * The above copyright notice and this permission notice shall be + * included in all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES + * OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND + * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT + * HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, + * WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR + * OTHER DEALINGS IN THE SOFTWARE. + */ + +package org.java_websocket.issues; + +import org.java_websocket.WebSocket; +import org.java_websocket.client.WebSocketClient; +import org.java_websocket.drafts.Draft; +import org.java_websocket.drafts.Draft_6455; +import org.java_websocket.enums.Opcode; +import org.java_websocket.extensions.IExtension; +import org.java_websocket.framing.CloseFrame; +import org.java_websocket.handshake.ClientHandshake; +import org.java_websocket.handshake.ServerHandshake; +import org.java_websocket.protocols.IProtocol; +import org.java_websocket.protocols.Protocol; +import org.java_websocket.server.WebSocketServer; +import org.java_websocket.util.SocketUtil; +import org.junit.Test; + +import java.net.InetSocketAddress; +import java.net.URI; +import java.nio.ByteBuffer; +import java.util.Collections; +import java.util.List; +import java.util.concurrent.CountDownLatch; + +import static org.junit.Assert.assertTrue; + +public class Issue598Test { + + private static List protocolList = Collections.singletonList( new Protocol( "" )); + private static List extensionList = Collections.emptyList(); + + private static void runTestScenario(int testCase) throws Exception { + final CountDownLatch countServerDownLatch = new CountDownLatch( 1 ); + final CountDownLatch countReceiveDownLatch = new CountDownLatch( 1 ); + final CountDownLatch countCloseDownLatch = new CountDownLatch( 1 ); + int port = SocketUtil.getAvailablePort(); + Draft draft = null; + switch (testCase) { + case 0: + case 1: + case 2: + case 3: + draft = new Draft_6455(extensionList, protocolList, 10); + break; + case 4: + case 5: + case 6: + case 7: + draft = new Draft_6455(extensionList, protocolList, 9); + break; + } + WebSocketClient webSocket = new WebSocketClient( new URI( "ws://localhost:" + port )) { + @Override + public void onOpen( ServerHandshake handshakedata ) { + + } + + @Override + public void onMessage( String message ) { + } + + + @Override + public void onClose( int code, String reason, boolean remote ) { + + } + + @Override + public void onError( Exception ex ) { + + } + }; + WebSocketServer server = new WebSocketServer( new InetSocketAddress( port ) , Collections.singletonList(draft)) { + @Override + public void onOpen( WebSocket conn, ClientHandshake handshake ) { + } + + @Override + public void onClose( WebSocket conn, int code, String reason, boolean remote ) { + if (code == CloseFrame.TOOBIG) { + countCloseDownLatch.countDown(); + } + } + + @Override + public void onMessage( WebSocket conn, String message ) { + countReceiveDownLatch.countDown(); + } + + @Override + public void onMessage( WebSocket conn, ByteBuffer message ) { + countReceiveDownLatch.countDown(); + } + + @Override + public void onError( WebSocket conn, Exception ex ) { + + } + + @Override + public void onStart() { + countServerDownLatch.countDown(); + } + }; + server.start(); + countServerDownLatch.await(); + webSocket.connectBlocking(); + switch (testCase) { + case 0: + case 4: + byte[] bArray = new byte[10]; + for (byte i = 0; i < 10; i++) { + bArray[i] = i; + } + webSocket.send(ByteBuffer.wrap(bArray)); + if (testCase == 0) + countReceiveDownLatch.await(); + if (testCase == 4) + countCloseDownLatch.await(); + break; + case 2: + case 6: + bArray = "0123456789".getBytes(); + webSocket.send(ByteBuffer.wrap(bArray)); + if (testCase == 2) + countReceiveDownLatch.await(); + if (testCase == 6) + countCloseDownLatch.await(); + break; + case 1: + case 5: + bArray = new byte[2]; + for (byte i = 0; i < 10; i++) { + bArray[i%2] = i; + if (i % 2 == 1) + webSocket.sendFragmentedFrame(Opcode.BINARY, ByteBuffer.wrap(bArray), i == 9); + } + if (testCase == 1) + countReceiveDownLatch.await(); + if (testCase == 5) + countCloseDownLatch.await(); + break; + case 3: + case 7: + for (byte i = 0; i < 10; i++) { + webSocket.sendFragmentedFrame(Opcode.TEXT, ByteBuffer.wrap((Integer.toString(i)).getBytes()), i == 9); + } + if (testCase == 3) + countReceiveDownLatch.await(); + if (testCase == 7) + countCloseDownLatch.await(); + break; + } + server.stop(); + } + + @Test(timeout = 2000) + public void runBelowLimitBytebuffer() throws Exception { + runTestScenario(0); + } + + @Test(timeout = 2000) + public void runBelowSplitLimitBytebuffer() throws Exception { + runTestScenario(1); + } + + @Test(timeout = 2000) + public void runBelowLimitString() throws Exception { + runTestScenario(2); + } + + @Test(timeout = 2000) + public void runBelowSplitLimitString() throws Exception { + runTestScenario(3); + } + + @Test(timeout = 2000) + public void runAboveLimitBytebuffer() throws Exception { + runTestScenario(4); + } + + @Test(timeout = 2000) + public void runAboveSplitLimitBytebuffer() throws Exception { + runTestScenario(5); + } + + @Test(timeout = 2000) + public void runAboveLimitString() throws Exception { + runTestScenario(6); + } + + @Test(timeout = 2000) + public void runAboveSplitLimitString() throws Exception { + runTestScenario(7); + } +}