From 96756dfe6c4722b7608f93192c06da96f85c9f50 Mon Sep 17 00:00:00 2001 From: Marcel P Date: Wed, 31 May 2017 11:47:10 +0200 Subject: [PATCH] New CustomSSLWebsocketServerFactory Allows you to enable/disable specific protocols and cipher suites --- ...SLServerCustomWebsocketFactoryExample.java | 67 +++ .../org/java_websocket/SSLSocketChannel2.java | 395 ++++++++++++++++++ .../CustomSSLWebSocketServerFactory.java | 67 +++ .../DefaultSSLWebSocketServerFactory.java | 4 +- 4 files changed, 531 insertions(+), 2 deletions(-) create mode 100644 src/main/example/SSLServerCustomWebsocketFactoryExample.java create mode 100644 src/main/java/org/java_websocket/SSLSocketChannel2.java create mode 100644 src/main/java/org/java_websocket/server/CustomSSLWebSocketServerFactory.java diff --git a/src/main/example/SSLServerCustomWebsocketFactoryExample.java b/src/main/example/SSLServerCustomWebsocketFactoryExample.java new file mode 100644 index 000000000..8294171d3 --- /dev/null +++ b/src/main/example/SSLServerCustomWebsocketFactoryExample.java @@ -0,0 +1,67 @@ +import org.java_websocket.WebSocketImpl; +import org.java_websocket.server.CustomSSLWebSocketServerFactory; + +import javax.net.ssl.KeyManagerFactory; +import javax.net.ssl.SSLContext; +import javax.net.ssl.SSLEngine; +import javax.net.ssl.TrustManagerFactory; +import java.io.File; +import java.io.FileInputStream; +import java.security.KeyStore; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +/** + * Example for using the CustomSSLWebSocketServerFactory to allow just specific cipher suites + */ +public class SSLServerCustomWebsocketFactoryExample { + + /* + * Keystore with certificate created like so (in JKS format): + * + *keytool -genkey -validity 3650 -keystore "keystore.jks" -storepass "storepassword" -keypass "keypassword" -alias "default" -dname "CN=127.0.0.1, OU=MyOrgUnit, O=MyOrg, L=MyCity, S=MyRegion, C=MyCountry" + */ + public static void main(String[] args) throws Exception { + WebSocketImpl.DEBUG = true; + + ChatServer chatserver = new ChatServer(8887); // Firefox does allow multible ssl connection only via port 443 //tested on FF16 + + // load up the key store + String STORETYPE = "JKS"; + String KEYSTORE = "keystore.jks"; + String STOREPASSWORD = "storepassword"; + String KEYPASSWORD = "keypassword"; + + KeyStore ks = KeyStore.getInstance(STORETYPE); + File kf = new File(KEYSTORE); + ks.load(new FileInputStream(kf), STOREPASSWORD.toCharArray()); + + KeyManagerFactory kmf = KeyManagerFactory.getInstance("SunX509"); + kmf.init(ks, KEYPASSWORD.toCharArray()); + TrustManagerFactory tmf = TrustManagerFactory.getInstance("SunX509"); + tmf.init(ks); + + SSLContext sslContext = SSLContext.getInstance("TLS"); + sslContext.init(kmf.getKeyManagers(), tmf.getTrustManagers(), null); + + //Lets remove some ciphers and protocols + SSLEngine engine = sslContext.createSSLEngine(); + List ciphers = new ArrayList( Arrays.asList(engine.getEnabledCipherSuites())); + ciphers.remove("TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256"); + List protocols = new ArrayList( Arrays.asList(engine.getEnabledProtocols())); + protocols.remove("SSLv3"); + CustomSSLWebSocketServerFactory factory = new CustomSSLWebSocketServerFactory(sslContext, protocols.toArray(new String[]{}), ciphers.toArray(new String[]{})); + + // Different example just using specific ciphers and protocols + /* + String[] enabledProtocols = {"TLSv1.2"}; + String[] enabledCipherSuites = {"TLS_RSA_WITH_AES_128_CBC_SHA", "TLS_RSA_WITH_AES_256_CBC_SHA"}; + CustomSSLWebSocketServerFactory factory = new CustomSSLWebSocketServerFactory(sslContext, enabledProtocols,enabledCipherSuites); + */ + chatserver.setWebSocketFactory(factory); + + chatserver.start(); + + } +} diff --git a/src/main/java/org/java_websocket/SSLSocketChannel2.java b/src/main/java/org/java_websocket/SSLSocketChannel2.java new file mode 100644 index 000000000..3739ad902 --- /dev/null +++ b/src/main/java/org/java_websocket/SSLSocketChannel2.java @@ -0,0 +1,395 @@ +/* + Copyright (C) 2003 Alexander Kout + Originally from the jFxp project (http://jfxp.sourceforge.net/). + Copied with permission June 11, 2012 by Femi Omojola (fomojola@ideasynthesis.com). + */ +package org.java_websocket; + +import javax.net.ssl.SSLEngine; +import javax.net.ssl.SSLEngineResult; +import javax.net.ssl.SSLEngineResult.HandshakeStatus; +import javax.net.ssl.SSLEngineResult.Status; +import javax.net.ssl.SSLException; +import javax.net.ssl.SSLSession; +import java.io.EOFException; +import java.io.IOException; +import java.net.Socket; +import java.net.SocketAddress; +import java.nio.ByteBuffer; +import java.nio.channels.ByteChannel; +import java.nio.channels.SelectableChannel; +import java.nio.channels.SelectionKey; +import java.nio.channels.SocketChannel; +import java.util.ArrayList; +import java.util.Iterator; +import java.util.List; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Future; + +/** + * Implements the relevant portions of the SocketChannel interface with the SSLEngine wrapper. + */ +public class SSLSocketChannel2 implements ByteChannel, WrappedByteChannel { + /** + * This object is used to feed the {@link SSLEngine}'s wrap and unwrap methods during the handshake phase. + **/ + protected static ByteBuffer emptybuffer = ByteBuffer.allocate( 0 ); + + protected ExecutorService exec; + + protected List> tasks; + + /** raw payload incomming */ + protected ByteBuffer inData; + /** encrypted data outgoing */ + protected ByteBuffer outCrypt; + /** encrypted data incoming */ + protected ByteBuffer inCrypt; + + /** the underlying channel */ + protected SocketChannel socketChannel; + /** used to set interestOP SelectionKey.OP_WRITE for the underlying channel */ + protected SelectionKey selectionKey; + + protected SSLEngine sslEngine; + protected SSLEngineResult readEngineResult; + protected SSLEngineResult writeEngineResult; + + /** + * Should be used to count the buffer allocations. + * But because of #190 where HandshakeStatus.FINISHED is not properly returned by nio wrap/unwrap this variable is used to check whether {@link #createBuffers(SSLSession)} needs to be called. + **/ + protected int bufferallocations = 0; + + public SSLSocketChannel2( SocketChannel channel , SSLEngine sslEngine , ExecutorService exec , SelectionKey key ) throws IOException { + if( channel == null || sslEngine == null || exec == null ) + throw new IllegalArgumentException( "parameter must not be null" ); + + this.socketChannel = channel; + this.sslEngine = sslEngine; + this.exec = exec; + + readEngineResult = writeEngineResult = new SSLEngineResult( Status.BUFFER_UNDERFLOW, sslEngine.getHandshakeStatus(), 0, 0 ); // init to prevent NPEs + + tasks = new ArrayList>( 3 ); + if( key != null ) { + key.interestOps( key.interestOps() | SelectionKey.OP_WRITE ); + this.selectionKey = key; + } + createBuffers( sslEngine.getSession() ); + // kick off handshake + socketChannel.write( wrap( emptybuffer ) );// initializes res + processHandshake(); + } + + private void consumeFutureUninterruptible( Future f ) { + try { + boolean interrupted = false; + while ( true ) { + try { + f.get(); + break; + } catch ( InterruptedException e ) { + interrupted = true; + } + } + if( interrupted ) + Thread.currentThread().interrupt(); + } catch ( ExecutionException e ) { + throw new RuntimeException( e ); + } + } + + /** + * This method will do whatever necessary to process the sslengine handshake. + * Thats why it's called both from the {@link #read(ByteBuffer)} and {@link #write(ByteBuffer)} + **/ + private synchronized void processHandshake() throws IOException { + if( sslEngine.getHandshakeStatus() == HandshakeStatus.NOT_HANDSHAKING ) + return; // since this may be called either from a reading or a writing thread and because this method is synchronized it is necessary to double check if we are still handshaking. + if( !tasks.isEmpty() ) { + Iterator> it = tasks.iterator(); + while ( it.hasNext() ) { + Future f = it.next(); + if( f.isDone() ) { + it.remove(); + } else { + if( isBlocking() ) + consumeFutureUninterruptible( f ); + return; + } + } + } + + if( sslEngine.getHandshakeStatus() == SSLEngineResult.HandshakeStatus.NEED_UNWRAP ) { + if( !isBlocking() || readEngineResult.getStatus() == Status.BUFFER_UNDERFLOW ) { + inCrypt.compact(); + int read = socketChannel.read( inCrypt ); + if( read == -1 ) { + throw new IOException( "connection closed unexpectedly by peer" ); + } + inCrypt.flip(); + } + inData.compact(); + unwrap(); + if( readEngineResult.getHandshakeStatus() == HandshakeStatus.FINISHED ) { + createBuffers( sslEngine.getSession() ); + return; + } + } + consumeDelegatedTasks(); + if( tasks.isEmpty() || sslEngine.getHandshakeStatus() == SSLEngineResult.HandshakeStatus.NEED_WRAP ) { + socketChannel.write( wrap( emptybuffer ) ); + if( writeEngineResult.getHandshakeStatus() == HandshakeStatus.FINISHED ) { + createBuffers( sslEngine.getSession() ); + return; + } + } + assert ( sslEngine.getHandshakeStatus() != HandshakeStatus.NOT_HANDSHAKING );// this function could only leave NOT_HANDSHAKING after createBuffers was called unless #190 occurs which means that nio wrap/unwrap never return HandshakeStatus.FINISHED + + bufferallocations = 1; // look at variable declaration why this line exists and #190. Without this line buffers would not be be recreated when #190 AND a rehandshake occur. + } + private synchronized ByteBuffer wrap( ByteBuffer b ) throws SSLException { + outCrypt.compact(); + writeEngineResult = sslEngine.wrap( b, outCrypt ); + outCrypt.flip(); + return outCrypt; + } + + /** + * performs the unwrap operation by unwrapping from {@link #inCrypt} to {@link #inData} + **/ + private synchronized ByteBuffer unwrap() throws SSLException { + int rem; + //There are some ssl test suites, which get around the selector.select() call, which cause an infinite unwrap and 100% cpu usage (see #459 and #458) + if(readEngineResult.getStatus() == SSLEngineResult.Status.CLOSED && sslEngine.getHandshakeStatus() == HandshakeStatus.NOT_HANDSHAKING){ + try { + close(); + } catch (IOException e) { + //Not really interesting + } + } + do { + rem = inData.remaining(); + readEngineResult = sslEngine.unwrap( inCrypt, inData ); + } while ( readEngineResult.getStatus() == SSLEngineResult.Status.OK && ( rem != inData.remaining() || sslEngine.getHandshakeStatus() == HandshakeStatus.NEED_UNWRAP ) ); + inData.flip(); + return inData; + } + + protected void consumeDelegatedTasks() { + Runnable task; + while ( ( task = sslEngine.getDelegatedTask() ) != null ) { + tasks.add( exec.submit( task ) ); + // task.run(); + } + } + + protected void createBuffers( SSLSession session ) { + int netBufferMax = session.getPacketBufferSize(); + int appBufferMax = Math.max(session.getApplicationBufferSize(), netBufferMax); + + if( inData == null ) { + inData = ByteBuffer.allocate( appBufferMax ); + outCrypt = ByteBuffer.allocate( netBufferMax ); + inCrypt = ByteBuffer.allocate( netBufferMax ); + } else { + if( inData.capacity() != appBufferMax ) + inData = ByteBuffer.allocate( appBufferMax ); + if( outCrypt.capacity() != netBufferMax ) + outCrypt = ByteBuffer.allocate( netBufferMax ); + if( inCrypt.capacity() != netBufferMax ) + inCrypt = ByteBuffer.allocate( netBufferMax ); + } + inData.rewind(); + inData.flip(); + inCrypt.rewind(); + inCrypt.flip(); + outCrypt.rewind(); + outCrypt.flip(); + bufferallocations++; + } + + public int write( ByteBuffer src ) throws IOException { + if( !isHandShakeComplete() ) { + processHandshake(); + return 0; + } + // assert ( bufferallocations > 1 ); //see #190 + //if( bufferallocations <= 1 ) { + // createBuffers( sslEngine.getSession() ); + //} + int num = socketChannel.write( wrap( src ) ); + if (writeEngineResult.getStatus() == SSLEngineResult.Status.CLOSED) { + throw new EOFException("Connection is closed"); + } + return num; + + } + + /** + * Blocks when in blocking mode until at least one byte has been decoded.
+ * When not in blocking mode 0 may be returned. + * + * @return the number of bytes read. + **/ + public int read(ByteBuffer dst) throws IOException { + while (true) { + if (!dst.hasRemaining()) + return 0; + if (!isHandShakeComplete()) { + if (isBlocking()) { + while (!isHandShakeComplete()) { + processHandshake(); + } + } else { + processHandshake(); + if (!isHandShakeComplete()) { + return 0; + } + } + } + // assert ( bufferallocations > 1 ); //see #190 + //if( bufferallocations <= 1 ) { + // createBuffers( sslEngine.getSession() ); + //} + /* 1. When "dst" is smaller than "inData" readRemaining will fill "dst" with data decoded in a previous read call. + * 2. When "inCrypt" contains more data than "inData" has remaining space, unwrap has to be called on more time(readRemaining) + */ + int purged = readRemaining(dst); + if (purged != 0) + return purged; + + /* We only continue when we really need more data from the network. + * Thats the case if inData is empty or inCrypt holds to less data than necessary for decryption + */ + assert (inData.position() == 0); + inData.clear(); + + if (!inCrypt.hasRemaining()) + inCrypt.clear(); + else + inCrypt.compact(); + + if (isBlocking() || readEngineResult.getStatus() == Status.BUFFER_UNDERFLOW) + if (socketChannel.read(inCrypt) == -1) { + return -1; + } + inCrypt.flip(); + unwrap(); + + int transfered = transfereTo(inData, dst); + if (transfered == 0 && isBlocking()) { + continue; + } + return transfered; + } + } + /** + * {@link #read(ByteBuffer)} may not be to leave all buffers(inData, inCrypt) + **/ + private int readRemaining( ByteBuffer dst ) throws SSLException { + if( inData.hasRemaining() ) { + return transfereTo( inData, dst ); + } + if( !inData.hasRemaining() ) + inData.clear(); + // test if some bytes left from last read (e.g. BUFFER_UNDERFLOW) + if( inCrypt.hasRemaining() ) { + unwrap(); + int amount = transfereTo( inData, dst ); + if (readEngineResult.getStatus() == SSLEngineResult.Status.CLOSED) { + return -1; + } + if( amount > 0 ) + return amount; + } + return 0; + } + + public boolean isConnected() { + return socketChannel.isConnected(); + } + + public void close() throws IOException { + sslEngine.closeOutbound(); + sslEngine.getSession().invalidate(); + if( socketChannel.isOpen() ) + socketChannel.write( wrap( emptybuffer ) );// FIXME what if not all bytes can be written + socketChannel.close(); + } + + private boolean isHandShakeComplete() { + HandshakeStatus status = sslEngine.getHandshakeStatus(); + return status == SSLEngineResult.HandshakeStatus.FINISHED || status == SSLEngineResult.HandshakeStatus.NOT_HANDSHAKING; + } + + public SelectableChannel configureBlocking( boolean b ) throws IOException { + return socketChannel.configureBlocking( b ); + } + + public boolean connect( SocketAddress remote ) throws IOException { + return socketChannel.connect( remote ); + } + + public boolean finishConnect() throws IOException { + return socketChannel.finishConnect(); + } + + public Socket socket() { + return socketChannel.socket(); + } + + public boolean isInboundDone() { + return sslEngine.isInboundDone(); + } + + @Override + public boolean isOpen() { + return socketChannel.isOpen(); + } + + @Override + public boolean isNeedWrite() { + return outCrypt.hasRemaining() || !isHandShakeComplete(); // FIXME this condition can cause high cpu load during handshaking when network is slow + } + + @Override + public void writeMore() throws IOException { + write( outCrypt ); + } + + @Override + public boolean isNeedRead() { + return inData.hasRemaining() || ( inCrypt.hasRemaining() && readEngineResult.getStatus() != Status.BUFFER_UNDERFLOW && readEngineResult.getStatus() != Status.CLOSED ); + } + + @Override + public int readMore( ByteBuffer dst ) throws SSLException { + return readRemaining( dst ); + } + + private int transfereTo( ByteBuffer from, ByteBuffer to ) { + int fremain = from.remaining(); + int toremain = to.remaining(); + if( fremain > toremain ) { + // FIXME there should be a more efficient transfer method + int limit = Math.min( fremain, toremain ); + for( int i = 0 ; i < limit ; i++ ) { + to.put( from.get() ); + } + return limit; + } else { + to.put( from ); + return fremain; + } + + } + + @Override + public boolean isBlocking() { + return socketChannel.isBlocking(); + } + +} \ No newline at end of file diff --git a/src/main/java/org/java_websocket/server/CustomSSLWebSocketServerFactory.java b/src/main/java/org/java_websocket/server/CustomSSLWebSocketServerFactory.java new file mode 100644 index 000000000..43982da4e --- /dev/null +++ b/src/main/java/org/java_websocket/server/CustomSSLWebSocketServerFactory.java @@ -0,0 +1,67 @@ +package org.java_websocket.server; + +import org.java_websocket.SSLSocketChannel2; + +import javax.net.ssl.SSLContext; +import javax.net.ssl.SSLEngine; +import java.io.IOException; +import java.nio.channels.ByteChannel; +import java.nio.channels.SelectionKey; +import java.nio.channels.SocketChannel; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; + +/** + * WebSocketFactory that can be configured to only support specific protocols and cipher suites. + */ +public class CustomSSLWebSocketServerFactory extends DefaultSSLWebSocketServerFactory { + + /** + * The enabled protocols saved as a String array + */ + private final String[] enabledProtocols; + + /** + * The enabled ciphersuites saved as a String array + */ + private final String[] enabledCiphersuites; + + /** + * New CustomSSLWebSocketServerFactory configured to only support given protocols and given cipher suites. + * + * @param sslContext - can not be null + * @param enabledProtocols - only these protocols are enabled, when null default settings will be used. + * @param enabledCiphersuites - only these cipher suites are enabled, when null default settings will be used. + */ + public CustomSSLWebSocketServerFactory(SSLContext sslContext, String[] enabledProtocols, String[] enabledCiphersuites) { + this(sslContext, Executors.newSingleThreadScheduledExecutor(), enabledProtocols, enabledCiphersuites); + } + + /** + * New CustomSSLWebSocketServerFactory configured to only support given protocols and given cipher suites. + * + * @param sslContext - can not be null + * @param executerService - can not be null + * @param enabledProtocols - only these protocols are enabled, when null default settings will be used. + * @param enabledCiphersuites - only these cipher suites are enabled, when null default settings will be used. + */ + public CustomSSLWebSocketServerFactory(SSLContext sslContext, ExecutorService executerService, String[] enabledProtocols, String[] enabledCiphersuites) { + super(sslContext, executerService); + this.enabledProtocols = enabledProtocols; + this.enabledCiphersuites = enabledCiphersuites; + } + + @Override + public ByteChannel wrapChannel(SocketChannel channel, SelectionKey key) throws IOException { + SSLEngine e = sslcontext.createSSLEngine(); + if (enabledProtocols != null) { + e.setEnabledProtocols(enabledProtocols); + } + if (enabledCiphersuites != null) { + e.setEnabledCipherSuites(enabledCiphersuites); + } + e.setUseClientMode(false); + return new SSLSocketChannel2(channel, e, exec, key); + } + +} \ No newline at end of file diff --git a/src/main/java/org/java_websocket/server/DefaultSSLWebSocketServerFactory.java b/src/main/java/org/java_websocket/server/DefaultSSLWebSocketServerFactory.java index f3124d082..71b6790be 100644 --- a/src/main/java/org/java_websocket/server/DefaultSSLWebSocketServerFactory.java +++ b/src/main/java/org/java_websocket/server/DefaultSSLWebSocketServerFactory.java @@ -12,7 +12,7 @@ import javax.net.ssl.SSLContext; import javax.net.ssl.SSLEngine; -import org.java_websocket.SSLSocketChannel; +import org.java_websocket.SSLSocketChannel2; import org.java_websocket.WebSocketAdapter; import org.java_websocket.WebSocketImpl; import org.java_websocket.drafts.Draft; @@ -46,7 +46,7 @@ public ByteChannel wrapChannel( SocketChannel channel, SelectionKey key ) throws ciphers.remove("TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256"); e.setEnabledCipherSuites( ciphers.toArray(new String[]{})); e.setUseClientMode( false ); - return new SSLSocketChannel( channel, e, exec, key ); + return new SSLSocketChannel2( channel, e, exec, key ); } @Override