diff --git a/src/main/java/com/rethinkdb/net/Connection.java b/src/main/java/com/rethinkdb/net/Connection.java index 1765d385..21504e61 100644 --- a/src/main/java/com/rethinkdb/net/Connection.java +++ b/src/main/java/com/rethinkdb/net/Connection.java @@ -540,9 +540,8 @@ public Builder(@NotNull URI uri) { return this; } - public @NotNull Builder certFile(@Nullable InputStream val) { - sslContext = Crypto.readCertFile(val); - return this; + public @NotNull Builder certFile(@NotNull InputStream val) { + return sslContext(Internals.readCertFile(val)); } public @NotNull Builder sslContext(@Nullable SSLContext val) { diff --git a/src/main/java/com/rethinkdb/net/Crypto.java b/src/main/java/com/rethinkdb/net/Crypto.java deleted file mode 100644 index 6a2e4eee..00000000 --- a/src/main/java/com/rethinkdb/net/Crypto.java +++ /dev/null @@ -1,161 +0,0 @@ -package com.rethinkdb.net; - -import com.rethinkdb.gen.exc.ReqlDriverError; -import org.jetbrains.annotations.Nullable; - -import javax.crypto.Mac; -import javax.crypto.SecretKeyFactory; -import javax.crypto.spec.PBEKeySpec; -import javax.crypto.spec.SecretKeySpec; -import javax.net.ssl.SSLContext; -import javax.net.ssl.TrustManagerFactory; -import java.io.IOException; -import java.io.InputStream; -import java.nio.charset.StandardCharsets; -import java.security.*; -import java.security.cert.CertificateException; -import java.security.cert.CertificateFactory; -import java.security.cert.X509Certificate; -import java.security.spec.InvalidKeySpecException; -import java.util.Arrays; -import java.util.Base64; -import java.util.Map; -import java.util.concurrent.ConcurrentHashMap; - -class Crypto { - private static final String DEFAULT_SSL_PROTOCOL = "TLSv1.2"; - private static final String HMAC_SHA_256 = "HmacSHA256"; - private static final String PBKDF2_ALGORITHM = "PBKDF2WithHmacSHA256"; - - private static final Base64.Encoder encoder = Base64.getEncoder(); - private static final Base64.Decoder decoder = Base64.getDecoder(); - private static final SecureRandom secureRandom = new SecureRandom(); - private static final Map pbkdf2Cache = new ConcurrentHashMap<>(); - private static final int NONCE_BYTES = 18; - - private static class PasswordLookup { - final byte[] password; - final byte[] salt; - final int iterations; - - PasswordLookup(byte[] password, byte[] salt, int iterations) { - this.password = password; - this.salt = salt; - this.iterations = iterations; - } - - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; - - PasswordLookup that = (PasswordLookup) o; - - if (iterations != that.iterations) return false; - if (!Arrays.equals(password, that.password)) return false; - return Arrays.equals(salt, that.salt); - - } - - @Override - public int hashCode() { - int result = Arrays.hashCode(password); - result = 31 * result + Arrays.hashCode(salt); - result = 31 * result + iterations; - return result; - } - } - - private static byte[] cacheLookup(byte[] password, byte[] salt, int iterations) { - return pbkdf2Cache.get(new PasswordLookup(password, salt, iterations)); - } - - private static void setCache(byte[] password, byte[] salt, int iterations, byte[] result) { - pbkdf2Cache.put(new PasswordLookup(password, salt, iterations), result); - } - - static byte[] sha256(byte[] clientKey) { - try { - MessageDigest digest = MessageDigest.getInstance("SHA-256"); - return digest.digest(clientKey); - } catch (NoSuchAlgorithmException e) { - throw new ReqlDriverError(e); - } - } - - static byte[] hmac(byte[] key, String string) { - try { - Mac mac = Mac.getInstance(HMAC_SHA_256); - SecretKeySpec secretKey = new SecretKeySpec(key, HMAC_SHA_256); - mac.init(secretKey); - return mac.doFinal(string.getBytes(StandardCharsets.UTF_8)); - } catch (InvalidKeyException | NoSuchAlgorithmException e) { - throw new ReqlDriverError(e); - } - } - - static byte[] pbkdf2(byte[] password, byte[] salt, Integer iterationCount) { - final byte[] cachedValue = cacheLookup(password, salt, iterationCount); - if (cachedValue != null) { - return cachedValue; - } - final PBEKeySpec spec = new PBEKeySpec( - new String(password, StandardCharsets.UTF_8).toCharArray(), - salt, iterationCount, 256 - ); - final SecretKeyFactory skf; - try { - skf = SecretKeyFactory.getInstance(PBKDF2_ALGORITHM); - final byte[] calculatedValue = skf.generateSecret(spec).getEncoded(); - setCache(password, salt, iterationCount, calculatedValue); - return calculatedValue; - } catch (NoSuchAlgorithmException | InvalidKeySpecException e) { - throw new ReqlDriverError(e); - } - } - - static String makeNonce() { - byte[] rawNonce = new byte[NONCE_BYTES]; - secureRandom.nextBytes(rawNonce); - return toBase64(rawNonce); - } - - static byte[] xor(byte[] a, byte[] b) { - if (a.length != b.length) { - throw new ReqlDriverError("arrays must be the same length"); - } - byte[] result = new byte[a.length]; - for (int i = 0; i < result.length; i++) { - result[i] = (byte) (a[i] ^ b[i]); - } - return result; - } - - static String toBase64(byte[] bytes) { - return new String(encoder.encode(bytes), StandardCharsets.UTF_8); - } - - static byte[] fromBase64(String string) { - return decoder.decode(string); - } - - static SSLContext readCertFile(@Nullable InputStream certFile) { - try { - final CertificateFactory cf = CertificateFactory.getInstance("X.509"); - final X509Certificate caCert = (X509Certificate) cf.generateCertificate(certFile); - - final TrustManagerFactory tmf = TrustManagerFactory - .getInstance(TrustManagerFactory.getDefaultAlgorithm()); - KeyStore ks = KeyStore.getInstance(KeyStore.getDefaultType()); - ks.load(null); // You don't need the KeyStore instance to come from a file. - ks.setCertificateEntry("caCert", caCert); - tmf.init(ks); - - final SSLContext ssc = SSLContext.getInstance(DEFAULT_SSL_PROTOCOL); - ssc.init(null, tmf.getTrustManagers(), null); - return ssc; - } catch (IOException | CertificateException | NoSuchAlgorithmException | KeyStoreException | KeyManagementException e) { - throw new ReqlDriverError(e); - } - } -} diff --git a/src/main/java/com/rethinkdb/net/HandshakeProtocol.java b/src/main/java/com/rethinkdb/net/HandshakeProtocol.java index 8cc4cee5..b272649a 100644 --- a/src/main/java/com/rethinkdb/net/HandshakeProtocol.java +++ b/src/main/java/com/rethinkdb/net/HandshakeProtocol.java @@ -7,18 +7,27 @@ import com.rethinkdb.utils.Internals; import org.jetbrains.annotations.Nullable; +import javax.crypto.Mac; +import javax.crypto.SecretKeyFactory; +import javax.crypto.spec.PBEKeySpec; +import javax.crypto.spec.SecretKeySpec; import java.nio.ByteBuffer; import java.nio.ByteOrder; import java.nio.charset.StandardCharsets; +import java.security.InvalidKeyException; import java.security.MessageDigest; -import java.util.Map; - -import static com.rethinkdb.net.Crypto.*; +import java.security.NoSuchAlgorithmException; +import java.security.SecureRandom; +import java.security.spec.InvalidKeySpecException; +import java.util.*; +import java.util.concurrent.ConcurrentHashMap; /** * Internal class used by {@link Connection#connect()} to do a proper handshake with the server. */ -abstract class HandshakeProtocol { +class HandshakeProtocol { + private static final HandshakeProtocol FINISHED = new HandshakeProtocol(); + public static final Version VERSION = Version.V1_0; public static final Long SUB_PROTOCOL_VERSION = 0L; public static final Protocol PROTOCOL = Protocol.JSON; @@ -26,12 +35,15 @@ abstract class HandshakeProtocol { public static final String CLIENT_KEY = "Client Key"; public static final String SERVER_KEY = "Server Key"; + private HandshakeProtocol() { + } + static void doHandshake(ConnectionSocket socket, String username, String password, Long timeout) { // initialize handshake HandshakeProtocol handshake = new WaitingForProtocolRange(username, password); // Sit in the handshake until it's completed. Exceptions will be thrown if // anything goes wrong. - while (!handshake.isFinished()) { + while (handshake != FINISHED) { ByteBuffer toWrite = handshake.toSend(); if (toWrite != null) { socket.write(toWrite); @@ -40,15 +52,14 @@ static void doHandshake(ConnectionSocket socket, String username, String passwor } } - private HandshakeProtocol() { - } - - protected abstract HandshakeProtocol nextState(String response); - @Nullable - protected abstract ByteBuffer toSend(); + protected ByteBuffer toSend() { + throw new IllegalStateException(); + } - protected abstract boolean isFinished(); + protected HandshakeProtocol nextState(String response) { + throw new IllegalStateException(); + } private static void throwIfFailure(Map json) { if (!(boolean) json.get("success")) { @@ -62,19 +73,23 @@ private static void throwIfFailure(Map json) { } static class WaitingForProtocolRange extends HandshakeProtocol { + private static final SecureRandom secureRandom = new SecureRandom(); + private static final int NONCE_BYTES = 18; + private final String nonce; - private final ByteBuffer message; private final ScramAttributes clientFirstMessageBare; private final byte[] password; WaitingForProtocolRange(String username, String password) { this.password = password.getBytes(StandardCharsets.UTF_8); this.nonce = makeNonce(); - - // We could use a json serializer, but it's fairly straightforward this.clientFirstMessageBare = new ScramAttributes() .username(username) .nonce(nonce); + } + + @Override + public ByteBuffer toSend() { byte[] jsonBytes = ("{" + "\"protocol_version\":" + SUB_PROTOCOL_VERSION + "," + "\"authentication_method\":\"SCRAM-SHA-256\"," + @@ -83,12 +98,10 @@ static class WaitingForProtocolRange extends HandshakeProtocol { // Creating the ByteBuffer over an underlying array makes // it easier to turn into a string later. //return ByteBuffer.wrap(new byte[capacity]).order(ByteOrder.LITTLE_ENDIAN); - // size of VERSION - // json auth payload - // terminating null byte - this.message = ByteBuffer.allocate(Integer.BYTES + // size of VERSION - jsonBytes.length + // json auth payload - 1).order(ByteOrder.LITTLE_ENDIAN).putInt(VERSION.value) + // size of VERSION + json auth payload + terminating null byte + return ByteBuffer.allocate(Integer.BYTES + jsonBytes.length + 1) + .order(ByteOrder.LITTLE_ENDIAN) + .putInt(VERSION.value) .put(jsonBytes) .put(new byte[1]); } @@ -100,22 +113,17 @@ public HandshakeProtocol nextState(String response) { long minVersion = (long) json.get("min_protocol_version"); long maxVersion = (long) json.get("max_protocol_version"); if (SUB_PROTOCOL_VERSION < minVersion || SUB_PROTOCOL_VERSION > maxVersion) { - throw new ReqlDriverError( - "Unsupported protocol version " + SUB_PROTOCOL_VERSION + - ", expected between " + minVersion + " and " + maxVersion); + throw new ReqlDriverError("Unsupported protocol version " + SUB_PROTOCOL_VERSION + ", expected between " + minVersion + " and " + maxVersion); } return new WaitingForAuthResponse(nonce, password, clientFirstMessageBare); } - @Override - public ByteBuffer toSend() { - return message; + static String makeNonce() { + byte[] rawNonce = new byte[NONCE_BYTES]; + secureRandom.nextBytes(rawNonce); + return Base64.getEncoder().encodeToString(rawNonce); } - @Override - public boolean isFinished() { - return false; - } } static class WaitingForAuthResponse extends HandshakeProtocol { @@ -130,116 +138,152 @@ static class WaitingForAuthResponse extends HandshakeProtocol { this.clientFirstMessageBare = clientFirstMessageBare; } + @Override + public ByteBuffer toSend() { + return null; + } + @Override public HandshakeProtocol nextState(String response) { Map json = Internals.readJson(response); throwIfFailure(json); - String serverFirstMessage = (String) json.get("authentication"); - ScramAttributes serverAuth = ScramAttributes.from(serverFirstMessage); - if (!serverAuth.nonce().startsWith(nonce)) { + ScramAttributes serverScram = ScramAttributes.from((String) json.get("authentication")); + if (!Objects.requireNonNull(serverScram._nonce).startsWith(nonce)) { throw new ReqlAuthError("Invalid nonce from server"); } - ScramAttributes clientFinalMessageWithoutProof = new ScramAttributes() + ScramAttributes clientScram = new ScramAttributes() .headerAndChannelBinding("biws") - .nonce(serverAuth.nonce()); + .nonce(serverScram._nonce); // SaltedPassword := Hi(Normalize(password), salt, i) - byte[] saltedPassword = pbkdf2( - password, serverAuth.salt(), serverAuth.iterationCount()); - // ClientKey := HMAC(SaltedPassword, "Client Key") - byte[] clientKey = hmac(saltedPassword, CLIENT_KEY); - // StoredKey := H(ClientKey) + byte[] saltedPassword = PBKDF2.compute(password, serverScram._salt, serverScram._iterationCount); + byte[] clientKey = hmac(saltedPassword, CLIENT_KEY); byte[] storedKey = sha256(clientKey); // AuthMessage := client-first-message-bare + "," + // server-first-message + "," + // client-final-message-without-proof - String authMessage = - clientFirstMessageBare + "," + - serverFirstMessage + "," + - clientFinalMessageWithoutProof; + String authMessage = clientFirstMessageBare + "," + serverScram + "," + clientScram; // ClientSignature := HMAC(StoredKey, AuthMessage) - byte[] clientSignature = hmac(storedKey, authMessage); - // ClientProof := ClientKey XOR ClientSignature - byte[] clientProof = xor(clientKey, clientSignature); - // ServerKey := HMAC(SaltedPassword, "Server Key") - byte[] serverKey = hmac(saltedPassword, SERVER_KEY); - // ServerSignature := HMAC(ServerKey, AuthMessage) + byte[] clientSignature = hmac(storedKey, authMessage); + byte[] clientProof = xor(clientKey, clientSignature); + byte[] serverKey = hmac(saltedPassword, SERVER_KEY); byte[] serverSignature = hmac(serverKey, authMessage); - ScramAttributes auth = clientFinalMessageWithoutProof - .clientProof(clientProof); - byte[] authJson = ("{\"authentication\":\"" + auth + "\"}").getBytes(StandardCharsets.UTF_8); - - ByteBuffer message = ByteBuffer.allocate(authJson.length + 1).order(ByteOrder.LITTLE_ENDIAN) - .put(authJson) - .put(new byte[1]); - return new WaitingForAuthSuccess(serverSignature, message); + return new WaitingForAuthSuccess(serverSignature, clientScram.clientProof(clientProof)); } - @Override - public ByteBuffer toSend() { - return null; + static byte[] sha256(byte[] clientKey) { + try { + return MessageDigest.getInstance("SHA-256").digest(clientKey); + } catch (NoSuchAlgorithmException e) { + throw new ReqlDriverError(e); + } } - @Override - public boolean isFinished() { - return false; + static byte[] hmac(byte[] key, String string) { + try { + Mac mac = Mac.getInstance("HmacSHA256"); + mac.init(new SecretKeySpec(key, "HmacSHA256")); + return mac.doFinal(string.getBytes(StandardCharsets.UTF_8)); + } catch (InvalidKeyException | NoSuchAlgorithmException e) { + throw new ReqlDriverError(e); + } } - } - static class HandshakeSuccess extends HandshakeProtocol { - @Override - public HandshakeProtocol nextState(String response) { - return this; + static byte[] xor(byte[] a, byte[] b) { + if (a.length != b.length) { + throw new ReqlDriverError("arrays must be the same length"); + } + byte[] result = new byte[a.length]; + for (int i = 0; i < result.length; i++) { + result[i] = (byte) (a[i] ^ b[i]); + } + return result; } - @Override - public ByteBuffer toSend() { - return null; - } + private static class PBKDF2 { + static byte[] compute(byte[] password, byte[] salt, Integer iterationCount) { + return cache.computeIfAbsent(new PBKDF2(password, salt, iterationCount), PBKDF2::compute); + } - @Override - public boolean isFinished() { - return true; + private static final Map cache = new ConcurrentHashMap<>(); + + final byte[] password; + final byte[] salt; + final int iterations; + + PBKDF2(byte[] password, byte[] salt, int iterations) { + this.password = password; + this.salt = salt; + this.iterations = iterations; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + PBKDF2 that = (PBKDF2) o; + + if (iterations != that.iterations) return false; + if (!Arrays.equals(password, that.password)) return false; + return Arrays.equals(salt, that.salt); + } + + @Override + public int hashCode() { + int result = Arrays.hashCode(password); + result = 31 * result + Arrays.hashCode(salt); + result = 31 * result + iterations; + return result; + } + + public byte[] compute() { + final PBEKeySpec spec = new PBEKeySpec( + new String(password, StandardCharsets.UTF_8).toCharArray(), + salt, iterations, 256 + ); + try { + return SecretKeyFactory.getInstance("PBKDF2WithHmacSHA256").generateSecret(spec).getEncoded(); + } catch (NoSuchAlgorithmException | InvalidKeySpecException e) { + throw new ReqlDriverError(e); + } + } } } static class WaitingForAuthSuccess extends HandshakeProtocol { private final byte[] serverSignature; - private final ByteBuffer message; + private final ScramAttributes auth; - public WaitingForAuthSuccess(byte[] serverSignature, ByteBuffer message) { + public WaitingForAuthSuccess(byte[] serverSignature, ScramAttributes auth) { this.serverSignature = serverSignature; - this.message = message; + this.auth = auth; + } + + @Override + public ByteBuffer toSend() { + byte[] authJson = ("{\"authentication\":\"" + auth + "\"}").getBytes(StandardCharsets.UTF_8); + return ByteBuffer.allocate(authJson.length + 1).order(ByteOrder.LITTLE_ENDIAN) + .put(authJson) + .put(new byte[1]); } @Override public HandshakeProtocol nextState(String response) { Map json = Internals.readJson(response); throwIfFailure(json); - ScramAttributes auth = ScramAttributes - .from((String) json.get("authentication")); - if (!MessageDigest.isEqual(auth.serverSignature(), serverSignature)) { + ScramAttributes auth = ScramAttributes.from((String) json.get("authentication")); + if (!MessageDigest.isEqual(auth._serverSignature, serverSignature)) { throw new ReqlAuthError("Invalid server signature"); } - return new HandshakeSuccess(); - } - - @Override - public ByteBuffer toSend() { - return message; - } - - @Override - public boolean isFinished() { - return false; + return FINISHED; } } @@ -299,7 +343,7 @@ private void setAttribute(String key, String val) { _headerAndChannelBinding = val; break; case "s": - _salt = Crypto.fromBase64(val); + _salt = Base64.getDecoder().decode(val); break; case "i": _iterationCount = Integer.parseInt(val); @@ -308,7 +352,7 @@ private void setAttribute(String key, String val) { _clientProof = val; break; case "v": - _serverSignature = Crypto.fromBase64(val); + _serverSignature = Base64.getDecoder().decode(val); break; case "e": _error = val; @@ -322,24 +366,20 @@ public String toString() { if (_originalString != null) { return _originalString; } - String output = ""; + StringJoiner j = new StringJoiner(","); if (_username != null) { - output += ",n=" + _username; + j.add("n=" + _username); } if (_nonce != null) { - output += ",r=" + _nonce; + j.add("r=" + _nonce); } if (_headerAndChannelBinding != null) { - output += ",c=" + _headerAndChannelBinding; + j.add("c=" + _headerAndChannelBinding); } if (_clientProof != null) { - output += ",p=" + _clientProof; - } - if (output.startsWith(",")) { - return output.substring(1); - } else { - return output; + j.add("p=" + _clientProof); } + return j.toString(); } // Setters with coercion @@ -363,45 +403,8 @@ ScramAttributes headerAndChannelBinding(String hacb) { ScramAttributes clientProof(byte[] clientProof) { ScramAttributes next = ScramAttributes.from(this); - next._clientProof = Crypto.toBase64(clientProof); + next._clientProof = Base64.getEncoder().encodeToString(clientProof); return next; } - - // Getters - String authIdentity() { - return _authIdentity; - } - - String username() { - return _username; - } - - String nonce() { - return _nonce; - } - - String headerAndChannelBinding() { - return _headerAndChannelBinding; - } - - byte[] salt() { - return _salt; - } - - Integer iterationCount() { - return _iterationCount; - } - - String clientProof() { - return _clientProof; - } - - byte[] serverSignature() { - return _serverSignature; - } - - String error() { - return _error; - } } } diff --git a/src/main/java/com/rethinkdb/utils/Internals.java b/src/main/java/com/rethinkdb/utils/Internals.java index 8af8e873..b42a9dc6 100644 --- a/src/main/java/com/rethinkdb/utils/Internals.java +++ b/src/main/java/com/rethinkdb/utils/Internals.java @@ -10,10 +10,21 @@ import com.rethinkdb.gen.exc.ReqlDriverCompileError; import com.rethinkdb.gen.exc.ReqlDriverError; import com.rethinkdb.model.*; +import org.jetbrains.annotations.NotNull; +import javax.net.ssl.SSLContext; +import javax.net.ssl.TrustManagerFactory; import java.io.IOException; +import java.io.InputStream; import java.lang.reflect.Array; import java.nio.ByteBuffer; +import java.security.KeyManagementException; +import java.security.KeyStore; +import java.security.KeyStoreException; +import java.security.NoSuchAlgorithmException; +import java.security.cert.Certificate; +import java.security.cert.CertificateException; +import java.security.cert.CertificateFactory; import java.time.*; import java.time.format.DateTimeFormatter; import java.time.temporal.Temporal; @@ -26,6 +37,7 @@ * Methods and fields are subject to change at any moment. */ public class Internals { + private static final String DEFAULT_SSL_PROTOCOL = "TLSv1.2"; private static final TypeReference> mapTypeRef = Types.mapOf(String.class, Object.class); private static final ObjectMapper internalMapper = new ObjectMapper() .disable(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES) @@ -241,6 +253,25 @@ private static ReqlAst toReqlAst(Object val, int remainingDepth) { return toReqlAst(RethinkDB.getResultMapper().convertValue(val, Map.class), remainingDepth - 1); } + public static SSLContext readCertFile(@NotNull InputStream certFile) { + try { + TrustManagerFactory tmf = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()); + KeyStore ks = KeyStore.getInstance(KeyStore.getDefaultType()); + Certificate cert = CertificateFactory.getInstance("X.509").generateCertificate(certFile); + + ks.load(null); + ks.setCertificateEntry("caCert", cert); + tmf.init(ks); + + SSLContext ctx = SSLContext.getInstance(DEFAULT_SSL_PROTOCOL); + ctx.init(null, tmf.getTrustManagers(), null); + certFile.close(); + return ctx; + } catch (IOException | CertificateException | NoSuchAlgorithmException | KeyStoreException | KeyManagementException e) { + throw new ReqlDriverError(e); + } + } + public static class FormatOptions { public final boolean rawTime; public final boolean rawGroups;