From 3cd47b702b09833cfbeecfc8a47ce5bc8d5715be Mon Sep 17 00:00:00 2001 From: old-bear Date: Tue, 10 Apr 2018 14:51:55 +0800 Subject: [PATCH] + Enable SSL request for single endpoint channel --- example/http_c++/http_client.cpp | 2 +- example/multi_threaded_echo_c++/cert.pem | 26 ++ example/multi_threaded_echo_c++/client.cpp | 2 + example/multi_threaded_echo_c++/key.pem | 27 ++ example/multi_threaded_echo_c++/server.cpp | 2 + src/brpc/acceptor.h | 1 + src/brpc/channel.cpp | 25 +- src/brpc/channel.h | 5 + src/brpc/controller.cpp | 6 + src/brpc/controller.h | 13 +- src/brpc/details/naming_service_thread.cpp | 13 +- src/brpc/details/ssl_helper.cpp | 299 ++++++++++++++---- src/brpc/details/ssl_helper.h | 42 ++- src/brpc/global.cpp | 3 + src/brpc/input_messenger.cpp | 2 +- src/brpc/nshead_service.h | 2 - src/brpc/parallel_channel.cpp | 1 + src/brpc/policy/baidu_rpc_protocol.cpp | 23 +- src/brpc/policy/baidu_rpc_protocol.h | 2 +- src/brpc/policy/hulu_pbrpc_protocol.cpp | 21 +- src/brpc/policy/mongo_protocol.cpp | 23 +- src/brpc/policy/nshead_protocol.cpp | 13 +- src/brpc/policy/sofa_pbrpc_protocol.cpp | 18 +- src/brpc/rtmp.cpp | 8 +- src/brpc/selective_channel.cpp | 1 + src/brpc/server.cpp | 16 +- src/brpc/server.h | 76 +---- src/brpc/socket.cpp | 299 +++++++++++------- src/brpc/socket.h | 23 +- src/brpc/socket_inl.h | 1 + src/brpc/socket_map.cpp | 122 ++++++-- src/brpc/socket_map.h | 82 ++++- src/brpc/socket_message.h | 2 + src/brpc/ssl_option.cpp | 38 +++ src/brpc/ssl_option.h | 162 ++++++++++ src/brpc/stream.cpp | 4 +- src/brpc/stream_impl.h | 2 +- src/butil/iobuf.cpp | 103 +++++-- src/butil/iobuf.h | 12 +- test/brpc_channel_unittest.cpp | 8 +- test/brpc_load_balancer_unittest.cpp | 2 + test/brpc_server_unittest.cpp | 102 ------- test/brpc_socket_map_unittest.cpp | 25 +- test/brpc_ssl_unittest.cpp | 334 +++++++++++++++++++++ 44 files changed, 1486 insertions(+), 507 deletions(-) create mode 100644 example/multi_threaded_echo_c++/cert.pem create mode 100644 example/multi_threaded_echo_c++/key.pem mode change 100755 => 100644 src/brpc/controller.cpp create mode 100644 src/brpc/ssl_option.cpp create mode 100644 src/brpc/ssl_option.h create mode 100644 test/brpc_ssl_unittest.cpp diff --git a/example/http_c++/http_client.cpp b/example/http_c++/http_client.cpp index 135f7d9e89..477088ec18 100644 --- a/example/http_c++/http_client.cpp +++ b/example/http_c++/http_client.cpp @@ -38,7 +38,7 @@ int main(int argc, char* argv[]) { GFLAGS_NS::ParseCommandLineFlags(&argc, &argv, true); if (argc != 2) { - LOG(ERROR) << "Usage: ./http_client \"www.foo.com\""; + LOG(ERROR) << "Usage: ./http_client \"http(s)://www.foo.com\""; return -1; } char* url = argv[1]; diff --git a/example/multi_threaded_echo_c++/cert.pem b/example/multi_threaded_echo_c++/cert.pem new file mode 100644 index 0000000000..28bcc21e4b --- /dev/null +++ b/example/multi_threaded_echo_c++/cert.pem @@ -0,0 +1,26 @@ +-----BEGIN CERTIFICATE----- +MIIEUTCCAzmgAwIBAgIBADANBgkqhkiG9w0BAQQFADB9MQswCQYDVQQGEwJDTjER +MA8GA1UECBMIU2hhbmdoYWkxETAPBgNVBAcTCFNoYW5naGFpMQ4wDAYDVQQKEwVC +YWlkdTEMMAoGA1UECxMDSU5GMQwwCgYDVQQDEwNTQVQxHDAaBgkqhkiG9w0BCQEW +DXNhdEBiYWlkdS5jb20wHhcNMTUwNzE2MDMxOTUxWhcNMTgwNTA1MDMxOTUxWjB9 +MQswCQYDVQQGEwJDTjERMA8GA1UECBMIU2hhbmdoYWkxETAPBgNVBAcTCFNoYW5n +aGFpMQ4wDAYDVQQKEwVCYWlkdTEMMAoGA1UECxMDSU5GMQwwCgYDVQQDEwNTQVQx +HDAaBgkqhkiG9w0BCQEWDXNhdEBiYWlkdS5jb20wggEiMA0GCSqGSIb3DQEBAQUA +A4IBDwAwggEKAoIBAQCqdyAeHY39tqY1RYVbfpqZjZlJDtZb04znxjgQrX+mKmLb +mwvXgJojlfn2Qcgp4NKYFqDFb9tU/Gbb436dRvkHyWOz0RPMspR0TTRU1NIY8wRy +0A1LOCgLHsbRJHqktGjylejALdgsspFWyDY9bEfb4oWsnKGzJqcvIDXrPmMOOY4o +pbA9SufSzwRZN7Yzc5jAedpaF9SK78RQXtvV0+JfCUwBsBWPKevRFFUrN7rQBYjP +cgV/HgDuquPrqnESVSYyfEBKZba6cmNb+xzO3cB1brPTtobSXh+0o/0CtRA+2m63 +ODexxCLntgkPm42IYCJLM15xTatcfVX/3LHQ31DrAgMBAAGjgdswgdgwHQYDVR0O +BBYEFGcd7lA//bSAoSC/NbWRx/H+O1zpMIGoBgNVHSMEgaAwgZ2AFGcd7lA//bSA +oSC/NbWRx/H+O1zpoYGBpH8wfTELMAkGA1UEBhMCQ04xETAPBgNVBAgTCFNoYW5n +aGFpMREwDwYDVQQHEwhTaGFuZ2hhaTEOMAwGA1UEChMFQmFpZHUxDDAKBgNVBAsT +A0lORjEMMAoGA1UEAxMDU0FUMRwwGgYJKoZIhvcNAQkBFg1zYXRAYmFpZHUuY29t +ggEAMAwGA1UdEwQFMAMBAf8wDQYJKoZIhvcNAQEEBQADggEBAKfoCn8SpLk3uQyT +X+oygcRWfTeJtN3D5J69NCMJ7wB+QPfpEBPwiqMgdbp4bRJ98H7x5UQsHT+EDOT/ +9OmipomHInFY4W1ew11zNKwuENeRrnZwTcCiVLZsxZsAU41ZeI5Yq+2WdtxnePCR +VL1/NjKOq+WoRdb2nLSNDWgYMkLRVlt32hyzryyrBbmaxUl8BxnPqUiWduMwsZUz +HNpXkoa1xTSd+En1SHYWfMg8BOVuV0I0/fjUUG9AXVqYpuogfbjAvibVNWAmxOfo +fOjCPCGoJC1ET3AxYkgXGwioobz0pK/13k2pV+wu7W4g+6iTfz+hwZbPsUk2a/5I +f6vXFB0= +-----END CERTIFICATE----- diff --git a/example/multi_threaded_echo_c++/client.cpp b/example/multi_threaded_echo_c++/client.cpp index 05d2d62ba9..1c7bb92e84 100644 --- a/example/multi_threaded_echo_c++/client.cpp +++ b/example/multi_threaded_echo_c++/client.cpp @@ -33,6 +33,7 @@ DEFINE_string(load_balancer, "", "The algorithm for load balancing"); DEFINE_int32(timeout_ms, 100, "RPC timeout in milliseconds"); DEFINE_int32(max_retry, 3, "Max retries(not including the first RPC)"); DEFINE_bool(dont_fail, false, "Print fatal when some call failed"); +DEFINE_bool(enable_ssl, false, "Use SSL connection"); DEFINE_int32(dummy_port, -1, "Launch dummy server at this port"); DEFINE_string(http_content_type, "application/json", "Content type of http request"); @@ -94,6 +95,7 @@ int main(int argc, char* argv[]) { // Initialize the channel, NULL means using default options. brpc::ChannelOptions options; + options.ssl_options.enable = FLAGS_enable_ssl; options.protocol = FLAGS_protocol; options.connection_type = FLAGS_connection_type; options.connect_timeout_ms = std::min(FLAGS_timeout_ms / 2, 100); diff --git a/example/multi_threaded_echo_c++/key.pem b/example/multi_threaded_echo_c++/key.pem new file mode 100644 index 0000000000..e3f64d1e17 --- /dev/null +++ b/example/multi_threaded_echo_c++/key.pem @@ -0,0 +1,27 @@ +-----BEGIN RSA PRIVATE KEY----- +MIIEogIBAAKCAQEAqncgHh2N/bamNUWFW36amY2ZSQ7WW9OM58Y4EK1/pipi25sL +14CaI5X59kHIKeDSmBagxW/bVPxm2+N+nUb5B8ljs9ETzLKUdE00VNTSGPMEctAN +SzgoCx7G0SR6pLRo8pXowC3YLLKRVsg2PWxH2+KFrJyhsyanLyA16z5jDjmOKKWw +PUrn0s8EWTe2M3OYwHnaWhfUiu/EUF7b1dPiXwlMAbAVjynr0RRVKze60AWIz3IF +fx4A7qrj66pxElUmMnxASmW2unJjW/sczt3AdW6z07aG0l4ftKP9ArUQPtputzg3 +scQi57YJD5uNiGAiSzNecU2rXH1V/9yx0N9Q6wIDAQABAoIBADN3khflnnhKzDXr +To9IU08nRG+dbjT9U16rJ0RJze+SfpSFZHblWiSCZJzoUZHrUkofEt1pn1QyfK/J +KPI9enTSZirlZk/4XwAaS0GNm/1yahZsIIdkZhqtaSO+GtVdrw4HGuXjMZCVPXJx +MocrCSsnYmqyQ9P+SJ3e4Mis5mVllwDiUVlnTIamSSt16qkPdamLSJrxvI4LirQK +9MZWNLoDFpRU1MJxQ/QzrEC3ONTq4j++AfbGzYTmDDtLeM8OSH5o72YXZ2JkaA4c +xCzHFT+NaJYxF7esn/ctzGg50LYl8IF2UQtzOkX2l3l/OktIB1w+jGV6ONb1EWx5 +4zkkzNkCgYEA2EXj7GMsyNE3OYdMw8zrqQKUMON2CNnD+mBseGlr22/bhXtzpqK8 +uNel8WF1ezOnVvNsU8pml/W/mKUu6KQt5JfaDzen3OKjzTABVlbJxwFhPvwAeaIA +q/tmSKyqiCgOMbR7Cq4UEwGf2A9/RII4JEC0/aipRU5srF65OYPUOJcCgYEAycco +DFVG6jUw9w68t/X4f7NT4IYP96hSAqLUPuVz2fWwXKLWEX8JiMI+Ue3PbMz6mPcs +4vMu364u4R3IuzrrI+PRK9iTa/pahBP6eF6ZpbY1ObI8CVLTrqUS9p22rr9lBm8V +EZA9hwcHLYt+PWzaKcsFpbP4+AeY7nBBbL9CAM0CgYAzuJsmeB1ItUgIuQOxu7sM +AzLfcjZTLYkBwreOIGAL7XdJN9nTmw2ZAvGLhWwsF5FIaRSaAUiBxOKaJb7PIhxb +k7kxdHTvjT/xHS7ksAK3VewkvO18KTMR7iBq9ugdgb7LQkc+qZzhYr0QVbxw7Ndy +TAs8sm4wxe2VV13ilFVXZwKBgDfU6ZnwBr1Llo7l/wYQA4CiSDU6IzTt2DNuhrgY +mWPX/cLEM+OHeUXkKYZV/S0n0rd8vWjWzUOLWOFlcmOMPAAkS36MYM5h6aXeOVIR +KwaVUkjyrnYN+xC6EHM41JGp1/RdzECd3sh8A1pw3K92bS9fQ+LD18IZqBFh8lh6 +23KJAoGAe48SwAsaGvqRO61Taww/Wf+YpGc9lnVbCvNFGScYaycPMqaRBUBmz/U3 +QQgpQY8T7JIECbA8sf78SlAZ9x93r0UQ70RekV3WzKAQHfHK8nqTjd3T0+i4aySO +yQpYYCgE24zYO6rQgwrhzI0S4rWe7izDDlg0RmLtQh7Xw+rlkAQ= +-----END RSA PRIVATE KEY----- diff --git a/example/multi_threaded_echo_c++/server.cpp b/example/multi_threaded_echo_c++/server.cpp index caaaaa12b3..f90794ac98 100644 --- a/example/multi_threaded_echo_c++/server.cpp +++ b/example/multi_threaded_echo_c++/server.cpp @@ -82,6 +82,8 @@ int main(int argc, char* argv[]) { // Start the server. brpc::ServerOptions options; + options.ssl_options.default_cert.certificate = "cert.pem"; + options.ssl_options.default_cert.private_key = "key.pem"; options.idle_timeout_sec = FLAGS_idle_timeout_s; options.max_concurrency = FLAGS_max_concurrency; options.internal_port = FLAGS_internal_port; diff --git a/src/brpc/acceptor.h b/src/brpc/acceptor.h index 6b961aa30a..25c12bef8d 100644 --- a/src/brpc/acceptor.h +++ b/src/brpc/acceptor.h @@ -18,6 +18,7 @@ #ifndef BRPC_ACCEPTOR_H #define BRPC_ACCEPTOR_H +#include "bthread/bthread.h" // bthread_t #include "butil/synchronization/condition_variable.h" #include "butil/containers/flat_map.h" #include "brpc/input_messenger.h" diff --git a/src/brpc/channel.cpp b/src/brpc/channel.cpp index cb10fa0b16..27b15a811e 100644 --- a/src/brpc/channel.cpp +++ b/src/brpc/channel.cpp @@ -19,17 +19,17 @@ #include #include #include -#include "butil/time.h" // milliseconds_from_now +#include "butil/time.h" // milliseconds_from_now #include "butil/logging.h" #include "bthread/unstable.h" // bthread_timer_add -#include "brpc/socket_map.h" // SocketMapInsert +#include "brpc/socket_map.h" // SocketMapInsert #include "brpc/compress.h" #include "brpc/global.h" #include "brpc/span.h" #include "brpc/details/load_balancer_with_naming.h" #include "brpc/controller.h" #include "brpc/channel.h" -#include "brpc/details/usercode_backup_pool.h" // TooManyUserCode +#include "brpc/details/usercode_backup_pool.h" // TooManyUserCode #include "brpc/policy/esp_authenticator.h" @@ -62,7 +62,9 @@ Channel::Channel(ProfilerLinker) Channel::~Channel() { if (_server_id != (SocketId)-1) { - SocketMapRemove(_server_address); + SocketMapRemove(SocketMapKey(_server_address, + _options.ssl_options, + _options.auth)); } } @@ -121,6 +123,15 @@ int Channel::InitChannelOptions(const ChannelOptions* options) { if (_options.auth == NULL) { _options.auth = policy::global_esp_authenticator(); } + } else if (_options.protocol == brpc::PROTOCOL_HTTP) { + if (_raw_server_address.compare(0, 5, "https") == 0) { + _options.ssl_options.enable = true; + if (_options.ssl_options.sni_name.empty()) { + int port; + ParseHostAndPortFromURL(_raw_server_address.c_str(), + &_options.ssl_options.sni_name, &port); + } + } } return 0; @@ -152,6 +163,7 @@ int Channel::Init(const char* server_addr_and_port, return -1; } } + _raw_server_address.assign(server_addr_and_port); return Init(point, options); } @@ -174,6 +186,7 @@ int Channel::Init(const char* server_addr, int port, return -1; } } + _raw_server_address.assign(server_addr); return Init(point, options); } @@ -189,7 +202,9 @@ int Channel::Init(butil::EndPoint server_addr_and_port, return -1; } _server_address = server_addr_and_port; - if (SocketMapInsert(server_addr_and_port, &_server_id) != 0) { + if (SocketMapInsert(SocketMapKey(server_addr_and_port, + _options.ssl_options, + _options.auth), &_server_id) != 0) { LOG(ERROR) << "Fail to insert into SocketMap"; return -1; } diff --git a/src/brpc/channel.h b/src/brpc/channel.h index 196202367f..2679872893 100644 --- a/src/brpc/channel.h +++ b/src/brpc/channel.h @@ -24,6 +24,7 @@ #include // std::ostream #include "bthread/errno.h" // Redefine errno #include "butil/intrusive_ptr.hpp" // butil::intrusive_ptr +#include "brpc/ssl_option.h" // ChannelSSLOptions #include "brpc/channel_base.h" // ChannelBase #include "brpc/adaptive_protocol_type.h" // AdaptiveProtocolType #include "brpc/adaptive_connection_type.h" // AdaptiveConnectionType @@ -87,6 +88,9 @@ struct ChannelOptions { // Print a log when above situation happens. // Default: true. bool log_succeed_without_server; + + // SSL related options. Refer to `ChannelSSLOptions' for details + ChannelSSLOptions ssl_options; // Turn on authentication for this channel if `auth' is not NULL. // Note `auth' will not be deleted by channel and must remain valid when @@ -185,6 +189,7 @@ friend class SelectiveChannel; int InitChannelOptions(const ChannelOptions* options); + std::string _raw_server_address; butil::EndPoint _server_address; SocketId _server_id; Protocol::SerializeRequest _serialize_request; diff --git a/src/brpc/controller.cpp b/src/brpc/controller.cpp old mode 100755 new mode 100644 index 2e25c3cb0a..ecf8c47537 --- a/src/brpc/controller.cpp +++ b/src/brpc/controller.cpp @@ -66,6 +66,7 @@ BAIDU_REGISTER_ERRNO(brpc::ERTMPPUBLISHABLE, "RtmpRetryingClientStream is publis BAIDU_REGISTER_ERRNO(brpc::ERTMPCREATESTREAM, "createStream was rejected by the RTMP server"); BAIDU_REGISTER_ERRNO(brpc::EEOF, "Got EOF"); BAIDU_REGISTER_ERRNO(brpc::EUNUSED, "The socket was not needed"); +BAIDU_REGISTER_ERRNO(brpc::ESSL, "SSL related operation failed"); BAIDU_REGISTER_ERRNO(brpc::EINTERNAL, "General internal error"); BAIDU_REGISTER_ERRNO(brpc::ERESPONSE, "Bad response"); @@ -1369,6 +1370,11 @@ bool Controller::is_ssl() const { return s ? (s->ssl_state() == SSL_CONNECTED) : false; } +x509_st* Controller::get_peer_certificate() const { + Socket* s = _current_call.sending_sock.get(); + return s ? s->GetPeerCertificate() : NULL; +} + #if defined(OS_MACOSX) typedef sig_t SignalHandler; #else diff --git a/src/brpc/controller.h b/src/brpc/controller.h index 9ac8b54527..dd733f72a7 100644 --- a/src/brpc/controller.h +++ b/src/brpc/controller.h @@ -46,6 +46,10 @@ #define EAUTH ERPCAUTH #endif +extern "C" { +struct x509_st; +} + namespace brpc { class Span; class Server; @@ -306,6 +310,12 @@ friend void policy::ProcessMongoRequest(InputMessageBase*); // Returns the authenticated result. NULL if there is no authentication const AuthContext* auth_context() const { return _auth_context; } + // Whether the underlying channel is using SSL + bool is_ssl() const; + + // Get the peer certificate, which can be printed by ostream + x509_st* get_peer_certificate() const; + // Mutable header of http response. HttpHeader& http_response() { if (_http_response == NULL) { @@ -380,9 +390,6 @@ friend void policy::ProcessMongoRequest(InputMessageBase*); // Protocol of the request sent by client or received by server. ProtocolType request_protocol() const { return _request_protocol; } - // Whether the underlying channel is using SSL - bool is_ssl() const; - // Resets the Controller to its initial state so that it may be reused in // a new call. Must NOT be called while an RPC is in progress. void Reset() { InternalReset(false); } diff --git a/src/brpc/details/naming_service_thread.cpp b/src/brpc/details/naming_service_thread.cpp index 8f2f87533f..6f9d27426f 100644 --- a/src/brpc/details/naming_service_thread.cpp +++ b/src/brpc/details/naming_service_thread.cpp @@ -58,7 +58,7 @@ NamingServiceThread::Actions::~Actions() { // Remove all sockets from SocketMap for (std::vector::const_iterator it = _last_servers.begin(); it != _last_servers.end(); ++it) { - SocketMapRemove(it->addr); + SocketMapRemove(SocketMapKey(it->addr)); } EndWait(0); } @@ -107,7 +107,10 @@ void NamingServiceThread::Actions::ResetServers( for (size_t i = 0; i < _added.size(); ++i) { ServerNodeWithId tagged_id; tagged_id.node = _added[i]; - CHECK_EQ(SocketMapInsert(_added[i].addr, &tagged_id.id), 0); + // TODO: For each unique SocketMapKey (i.e. SSL settings), insert a new + // Socket. SocketMapKey may be passed through AddWatcher. Make sure + // to pick those Sockets with the right settings during OnAddedServers + CHECK_EQ(SocketMapInsert(SocketMapKey(_added[i].addr), &tagged_id.id), 0); _added_sockets.push_back(tagged_id); } @@ -115,7 +118,7 @@ void NamingServiceThread::Actions::ResetServers( for (size_t i = 0; i < _removed.size(); ++i) { ServerNodeWithId tagged_id; tagged_id.node = _removed[i]; - CHECK_EQ(0, SocketMapFind(_removed[i].addr, &tagged_id.id)); + CHECK_EQ(0, SocketMapFind(SocketMapKey(_removed[i].addr), &tagged_id.id)); _removed_sockets.push_back(tagged_id); } @@ -164,7 +167,9 @@ void NamingServiceThread::Actions::ResetServers( } for (size_t i = 0; i < _removed.size(); ++i) { - SocketMapRemove(_removed[i].addr); + // TODO: Remove all Sockets that have the same address in SocketMapKey.peer + // We may need another data structure to avoid linear cost + SocketMapRemove(SocketMapKey(_removed[i].addr)); } if (!_removed.empty() || !_added.empty()) { diff --git a/src/brpc/details/ssl_helper.cpp b/src/brpc/details/ssl_helper.cpp index 7f0606ac34..e29c32b19b 100644 --- a/src/brpc/details/ssl_helper.cpp +++ b/src/brpc/details/ssl_helper.cpp @@ -22,6 +22,7 @@ #include "butil/unique_ptr.h" #include "butil/logging.h" #include "butil/ssl_compat.h" +#include "butil/string_splitter.h" #include "brpc/socket.h" #include "brpc/details/ssl_helper.h" @@ -59,6 +60,29 @@ const char* SSLStateToString(SSLState s) { return "Bad SSLState"; } +static int ParseSSLProtocols(const std::string& str_protocol) { + int protocol_flag = 0; + butil::StringSplitter sp(str_protocol.data(), + str_protocol.data() + str_protocol.size(), ','); + for (; sp; ++sp) { + butil::StringPiece protocol(sp.field(), sp.length()); + protocol.trim_spaces(); + if (strncasecmp(protocol.data(), "SSLv3", protocol.size()) == 0) { + protocol_flag |= SSLv3; + } else if (strncasecmp(protocol.data(), "TLSv1", protocol.size()) == 0) { + protocol_flag |= TLSv1; + } else if (strncasecmp(protocol.data(), "TLSv1.1", protocol.size()) == 0) { + protocol_flag |= TLSv1_1; + } else if (strncasecmp(protocol.data(), "TLSv1.2", protocol.size()) == 0) { + protocol_flag |= TLSv1_2; + } else { + LOG(ERROR) << "Unknown SSL protocol=" << protocol; + return -1; + } + } + return protocol_flag; +} + std::ostream& operator<<(std::ostream& os, const SSLError& ssl) { char buf[128]; // Should be enough ERR_error_string_n(ssl.error, buf, sizeof(buf)); @@ -105,9 +129,7 @@ static void SSLInfoCallback(const SSL* ssl, int where, int ret) { } if (where & SSL_CB_HANDSHAKE_START) { - if (s->ssl_state() == SSL_CONNECTING) { - s->set_ssl_state(SSL_CONNECTED); - } else if (s->ssl_state() == SSL_CONNECTED) { + if (s->ssl_state() == SSL_CONNECTED) { // Disable renegotiation (CVE-2009-3555) LOG(ERROR) << "Close " << *s << " due to insecure " << "renegotiation detected (CVE-2009-3555)"; @@ -180,7 +202,7 @@ static DH* SSLGetDHCallback(SSL* ssl, int exp, int keylen) { } #endif // OPENSSL_NO_DH -static void ExtractHostnames(X509* x, std::vector* hostnames) { +void ExtractHostnames(X509* x, std::vector* hostnames) { #ifdef SSL_CTRL_SET_TLSEXT_HOSTNAME STACK_OF(GENERAL_NAME)* names = (STACK_OF(GENERAL_NAME)*) X509_get_ext_d2i(x, NID_subject_alt_name, NULL, NULL); @@ -207,7 +229,7 @@ static void ExtractHostnames(X509* x, std::vector* hostnames) { char* str = NULL; X509_NAME_ENTRY* entry = X509_NAME_get_entry(xname, i); const int len = ASN1_STRING_to_UTF8((unsigned char**)&str, - X509_NAME_ENTRY_get_data(entry)); + X509_NAME_ENTRY_get_data(entry)); if (len >= 0) { std::string hostname(str, len); hostnames->push_back(hostname); @@ -216,10 +238,10 @@ static void ExtractHostnames(X509* x, std::vector* hostnames) { } } -struct FreeSSLCTX { - inline void operator()(SSL_CTX* ctx) const { - if (ctx != NULL) { - SSL_CTX_free(ctx); +struct FreeSSL { + inline void operator()(SSL* ssl) const { + if (ssl != NULL) { + SSL_free(ssl); } } }; @@ -248,35 +270,28 @@ struct FreeEVPKEY { } }; -SSL_CTX* CreateSSLContext(const std::string& certificate, - const std::string& private_key, - const SSLOptions& options, - std::vector* hostnames) { - std::unique_ptr ssl_ctx( - SSL_CTX_new(SSLv23_server_method())); - if (!ssl_ctx) { - LOG(ERROR) << "Fail to new SSL_CTX: " << SSLError(ERR_get_error()); - return NULL; - } - +static int LoadCertificate(SSL_CTX* ctx, + const std::string& certificate, + const std::string& private_key, + std::vector* hostnames) { // Load the private key if (IsPemString(private_key)) { std::unique_ptr kbio( BIO_new_mem_buf((void*)private_key.c_str(), -1)); std::unique_ptr key( PEM_read_bio_PrivateKey(kbio.get(), NULL, 0, NULL)); - if (SSL_CTX_use_PrivateKey(ssl_ctx.get(), key.get()) != 1) { + if (SSL_CTX_use_PrivateKey(ctx, key.get()) != 1) { LOG(ERROR) << "Fail to load " << private_key << ": " << SSLError(ERR_get_error()); - return NULL; + return -1; } } else { if (SSL_CTX_use_PrivateKey_file( - ssl_ctx.get(), private_key.c_str(), SSL_FILETYPE_PEM) != 1) { + ctx, private_key.c_str(), SSL_FILETYPE_PEM) != 1) { LOG(ERROR) << "Fail to load " << private_key << ": " << SSLError(ERR_get_error()); - return NULL; + return -1; } } @@ -289,7 +304,7 @@ SSL_CTX* CreateSSLContext(const std::string& certificate, if (BIO_read_filename(cbio.get(), certificate.c_str()) <= 0) { LOG(ERROR) << "Fail to read " << certificate << ": " << SSLError(ERR_get_error()); - return NULL; + return -1; } } std::unique_ptr x( @@ -297,32 +312,32 @@ SSL_CTX* CreateSSLContext(const std::string& certificate, if (!x) { LOG(ERROR) << "Fail to parse " << certificate << ": " << SSLError(ERR_get_error()); - return NULL; + return -1; } // Load the main certficate - if (SSL_CTX_use_certificate(ssl_ctx.get(), x.get()) != 1) { + if (SSL_CTX_use_certificate(ctx, x.get()) != 1) { LOG(ERROR) << "Fail to load " << certificate << ": " << SSLError(ERR_get_error()); - return NULL; + return -1; } // Load the certificate chain #if (OPENSSL_VERSION_NUMBER >= 0x10002000L) - SSL_CTX_clear_chain_certs(ssl_ctx.get()); + SSL_CTX_clear_chain_certs(ctx); #else - if (ssl_ctx->extra_certs != NULL) { - sk_X509_pop_free(ssl_ctx->extra_certs, X509_free); - ssl_ctx->extra_certs = NULL; + if (ctx->extra_certs != NULL) { + sk_X509_pop_free(ctx->extra_certs, X509_free); + ctx->extra_certs = NULL; } #endif X509* ca = NULL; while ((ca = PEM_read_bio_X509(cbio.get(), NULL, 0, NULL))) { - if (SSL_CTX_add_extra_chain_cert(ssl_ctx.get(), ca) != 1) { + if (SSL_CTX_add_extra_chain_cert(ctx, ca) != 1) { LOG(ERROR) << "Fail to load chain certificate in " << certificate << ": " << SSLError(ERR_get_error()); X509_free(ca); - return NULL; + return -1; } } @@ -331,17 +346,25 @@ SSL_CTX* CreateSSLContext(const std::string& certificate, || ERR_GET_REASON(err) != PEM_R_NO_START_LINE)) { LOG(ERROR) << "Fail to read chain certificate in " << certificate << ": " << SSLError(ERR_get_error()); - return NULL; + return -1; } ERR_clear_error(); // Validate certificate and private key - if (SSL_CTX_check_private_key(ssl_ctx.get()) != 1) { + if (SSL_CTX_check_private_key(ctx) != 1) { LOG(ERROR) << "Fail to verify " << private_key << ": " << SSLError(ERR_get_error()); - return NULL; + return -1; } + if (hostnames != NULL) { + ExtractHostnames(x.get(), hostnames); + } + return 0; +} + +static int SetSSLOptions(SSL_CTX* ctx, const std::string& ciphers, + int protocols, const VerifyOptions& verify) { long ssloptions = SSL_OP_ALL // All known workarounds for bugs | SSL_OP_NO_SSLv2 #ifdef SSL_OP_NO_COMPRESSION @@ -349,37 +372,136 @@ SSL_CTX* CreateSSLContext(const std::string& certificate, #endif // SSL_OP_NO_COMPRESSION | SSL_OP_CIPHER_SERVER_PREFERENCE | SSL_OP_NO_SESSION_RESUMPTION_ON_RENEGOTIATION; - if (options.disable_ssl3) { + + if (!(protocols & SSLv3)) { ssloptions |= SSL_OP_NO_SSLv3; } + if (!(protocols & TLSv1)) { + ssloptions |= SSL_OP_NO_TLSv1; + } + +#ifdef SSL_OP_NO_TLSv1_1 + if (!(protocols & TLSv1_1)) { + ssloptions |= SSL_OP_NO_TLSv1_1; + } +#endif // SSL_OP_NO_TLSv1_1 + +#ifdef SSL_OP_NO_TLSv1_2 + if (!(protocols & TLSv1_2)) { + ssloptions |= SSL_OP_NO_TLSv1_2; + } +#endif // SSL_OP_NO_TLSv1_2 + SSL_CTX_set_options(ctx, ssloptions); long sslmode = SSL_MODE_ENABLE_PARTIAL_WRITE -#ifdef SSL_MODE_RELEASE_BUFFERS - | SSL_MODE_RELEASE_BUFFERS -#endif // SSL_MODE_RELEASE_BUFFERS | SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER; - SSL_CTX_set_options(ssl_ctx.get(), ssloptions); - SSL_CTX_set_mode(ssl_ctx.get(), sslmode); + SSL_CTX_set_mode(ctx, sslmode); - // TODO: Support client certification validation - SSL_CTX_set_verify(ssl_ctx.get(), SSL_VERIFY_NONE, NULL); - - if (!options.ciphers.empty() && - SSL_CTX_set_cipher_list(ssl_ctx.get(), - options.ciphers.c_str()) != 1) { - LOG(ERROR) << "Fail to set cipher list to " << options.ciphers + if (!ciphers.empty() && + SSL_CTX_set_cipher_list(ctx, ciphers.c_str()) != 1) { + LOG(ERROR) << "Fail to set cipher list to " << ciphers << ": " << SSLError(ERR_get_error()); - return NULL; + return -1; } - SSL_CTX_set_timeout(ssl_ctx.get(), options.session_lifetime_s); - SSL_CTX_sess_set_cache_size(ssl_ctx.get(), options.session_cache_size); + // TODO: Verify the CNAME in certificate matches the requesting host + if (verify.verify_depth > 0) { + SSL_CTX_set_verify(ctx, (SSL_VERIFY_PEER + | SSL_VERIFY_FAIL_IF_NO_PEER_CERT), NULL); + SSL_CTX_set_verify_depth(ctx, verify.verify_depth); + std::string cafile = verify.ca_file_path; + if (cafile.empty()) { + cafile = X509_get_default_cert_area() + std::string("/cert.pem"); + } + if (SSL_CTX_load_verify_locations(ctx, cafile.c_str(), NULL) == 0) { + if (verify.ca_file_path.empty()) { + LOG(WARNING) << "Fail to load default CA file " << cafile + << ": " << SSLError(ERR_get_error()); + } else { + LOG(ERROR) << "Fail to load CA file " << cafile + << ": " << SSLError(ERR_get_error()); + return -1; + } + } + } else { + SSL_CTX_set_verify(ctx, SSL_VERIFY_NONE, NULL); + } - SSL_CTX_set_info_callback(ssl_ctx.get(), SSLInfoCallback); + SSL_CTX_set_info_callback(ctx, SSLInfoCallback); #if OPENSSL_VERSION_NUMBER >= 0x00907000L - SSL_CTX_set_msg_callback(ssl_ctx.get(), SSLMessageCallback); + // To detect and protect from heartbleed attack + SSL_CTX_set_msg_callback(ctx, SSLMessageCallback); #endif + return 0; +} + +SSL_CTX* CreateClientSSLContext(const ChannelSSLOptions& options) { + if (!options.enable) { + return NULL; + } + + std::unique_ptr ssl_ctx( + SSL_CTX_new(SSLv23_client_method())); + if (!ssl_ctx) { + LOG(ERROR) << "Fail to new SSL_CTX: " << SSLError(ERR_get_error()); + return NULL; + } + + if (!options.client_cert.certificate.empty() + && LoadCertificate(ssl_ctx.get(), + options.client_cert.certificate, + options.client_cert.private_key, NULL) != 0) { + return NULL; + } + + int protocols = ParseSSLProtocols(options.protocols); + if (protocols < 0 + || SetSSLOptions(ssl_ctx.get(), options.ciphers, + protocols, options.verify) != 0) { + return NULL; + } + + SSL_CTX_set_session_cache_mode(ssl_ctx.get(), SSL_SESS_CACHE_CLIENT); + return ssl_ctx.release(); +} + +SSL_CTX* CreateServerSSLContext(const std::string& certificate, + const std::string& private_key, + const ServerSSLOptions& options, + std::vector* hostnames) { + std::unique_ptr ssl_ctx( + SSL_CTX_new(SSLv23_server_method())); + if (!ssl_ctx) { + LOG(ERROR) << "Fail to new SSL_CTX: " << SSLError(ERR_get_error()); + return NULL; + } + + if (LoadCertificate(ssl_ctx.get(), certificate, + private_key, hostnames) != 0) { + return NULL; + } + + int protocols = TLSv1 | TLSv1_1 | TLSv1_2; + if (!options.disable_ssl3) { + protocols |= SSLv3; + } + if (SetSSLOptions(ssl_ctx.get(), options.ciphers, + protocols, options.verify) != 0) { + return NULL; + } + +#ifdef SSL_MODE_RELEASE_BUFFERS + if (options.release_buffer) { + long sslmode = SSL_CTX_get_mode(ssl_ctx.get()); + sslmode |= SSL_MODE_RELEASE_BUFFERS; + SSL_CTX_set_mode(ssl_ctx.get(), sslmode); + } +#endif // SSL_MODE_RELEASE_BUFFERS + + SSL_CTX_set_timeout(ssl_ctx.get(), options.session_lifetime_s); + SSL_CTX_sess_set_cache_size(ssl_ctx.get(), options.session_cache_size); + #ifndef OPENSSL_NO_DH SSL_CTX_set_tmp_dh_callback(ssl_ctx.get(), SSLGetDHCallback); @@ -398,13 +520,9 @@ SSL_CTX* CreateSSLContext(const std::string& certificate, #endif // OPENSSL_NO_DH - if (hostnames != NULL) { - ExtractHostnames(x.get(), hostnames); - } return ssl_ctx.release(); } - SSL* CreateSSLSession(SSL_CTX* ctx, SocketId id, int fd, bool server_mode) { if (ctx == NULL) { LOG(WARNING) << "Lack SSL_ctx to create an SSL session"; @@ -420,6 +538,7 @@ SSL* CreateSSLSession(SSL_CTX* ctx, SocketId id, int fd, bool server_mode) { SSL_free(ssl); return NULL; } + if (server_mode) { SSL_set_accept_state(ssl); } else { @@ -429,6 +548,21 @@ SSL* CreateSSLSession(SSL_CTX* ctx, SocketId id, int fd, bool server_mode) { return ssl; } +void AddBIOBuffer(SSL* ssl, int fd, int bufsize) { + BIO* rbio = BIO_new(BIO_f_buffer()); + BIO_set_buffer_size(rbio, bufsize); + BIO* rfd = BIO_new(BIO_s_fd()); + BIO_set_fd(rfd, fd, 0); + rbio = BIO_push(rbio, rfd); + + BIO* wbio = BIO_new(BIO_f_buffer()); + BIO_set_buffer_size(wbio, bufsize); + BIO* wfd = BIO_new(BIO_s_fd()); + BIO_set_fd(wfd, fd, 0); + wbio = BIO_push(wbio, wfd); + SSL_set_bio(ssl, rbio, wbio); +} + SSLState DetectSSLState(int fd, int* error_code) { // Peek the first few bytes inside socket to detect whether // it's an SSL connection. If it is, create an SSL session @@ -637,3 +771,50 @@ int SSLDHInit() { } } // namespace brpc + +std::ostream& operator<<(std::ostream& os, SSL* ssl) { + os << "[SSL HANDSHAKE]" + << "\n* cipher: " << SSL_get_cipher(ssl) + << "\n* protocol: " << SSL_get_version(ssl) + << "\n* verify: " << (SSL_get_verify_mode(ssl) & SSL_VERIFY_PEER + ? "success" : "none") + << "\n"; + + X509* cert = SSL_get_peer_certificate(ssl); + if (cert) { + os << "\n" << cert; + } + return os; +} + +std::ostream& operator<<(std::ostream& os, X509* cert) { + BIO* buf = BIO_new(BIO_s_mem()); + if (buf == NULL) { + return os; + } + BIO_printf(buf, "[CERTIFICATE]"); + + BIO_printf(buf, "\n* subject: "); + X509_NAME_print(buf, X509_get_subject_name(cert), 0); + BIO_printf(buf, "\n* start date: "); + ASN1_TIME_print(buf, X509_get_notBefore(cert)); + BIO_printf(buf, "\n* expire date: "); + ASN1_TIME_print(buf, X509_get_notAfter(cert)); + + BIO_printf(buf, "\n* common name: "); + std::vector hostnames; + brpc::ExtractHostnames(cert, &hostnames); + for (size_t i = 0; i < hostnames.size(); ++i) { + BIO_printf(buf, "%s; ", hostnames[i].c_str()); + } + + BIO_printf(buf, "\n* issuer: "); + X509_NAME_print(buf, X509_get_issuer_name(cert), 0); + + BIO_printf(buf, "\n"); + + char* bufp = NULL; + int len = BIO_get_mem_data(buf, &bufp); + os << butil::StringPiece(bufp, len); + return os; +} diff --git a/src/brpc/details/ssl_helper.h b/src/brpc/details/ssl_helper.h index 185de82c9b..7ff525eb70 100644 --- a/src/brpc/details/ssl_helper.h +++ b/src/brpc/details/ssl_helper.h @@ -21,8 +21,8 @@ #include // For some versions of openssl, SSL_* are defined inside this header #include -#include "brpc/server.h" // SSLOptions #include "brpc/socket_id.h" // SocketId +#include "brpc/ssl_option.h" // SSLOptions namespace brpc { @@ -34,6 +34,21 @@ enum SSLState { SSL_CONNECTED = 3, // SSL handshake completed }; +enum SSLProtocol { + SSLv3 = 1 << 0, + TLSv1 = 1 << 1, + TLSv1_1 = 1 << 2, + TLSv1_2 = 1 << 3, +}; + +struct FreeSSLCTX { + inline void operator()(SSL_CTX* ctx) const { + if (ctx != NULL) { + SSL_CTX_free(ctx); + } + } +}; + struct SSLError { explicit SSLError(unsigned long e) : error(e) { } unsigned long error; @@ -51,18 +66,27 @@ int SSLThreadInit(); // Return 0 on success, -1 otherwise int SSLDHInit(); -// Create a new SSL_CTX using `certificate_file' and `private_key_file' -// and then set the right options onto it according `options'. Finally, -// extract hostnames from CN/subject fields into `hostnames' -SSL_CTX* CreateSSLContext(const std::string& certificate_file, - const std::string& private_key_file, - const SSLOptions& options, - std::vector* hostnames); +// Create a new SSL_CTX in client mode and +// set the right options according `options' +SSL_CTX* CreateClientSSLContext(const ChannelSSLOptions& options); + +// Create a new SSL_CTX in server mode using `certificate_file' +// and `private_key_file' and then set the right options onto it +// according `options'. Finally, extract hostnames from CN/subject +// fields into `hostnames' +SSL_CTX* CreateServerSSLContext(const std::string& certificate_file, + const std::string& private_key_file, + const SSLOptions& options, + std::vector* hostnames); // Create a new SSL (per connection object) using configurations in `ctx'. // Set the required `fd' and mode. `id' will be set into SSL as app data. SSL* CreateSSLSession(SSL_CTX* ctx, SocketId id, int fd, bool server_mode); +// Add a buffer layer of BIO in front of the socket fd layer, +// which can reduce the total number of calls to system read/write +void AddBIOBuffer(SSL* ssl, int fd, int bufsize); + // Judge whether the underlying channel of `fd' is using SSL // If the return value is SSL_UNKNOWN, `error_code' will be // set to indicate the reason (0 for EOF) @@ -70,5 +94,7 @@ SSLState DetectSSLState(int fd, int* error_code); } // namespace brpc +std::ostream& operator<<(std::ostream& os, SSL* ssl); +std::ostream& operator<<(std::ostream& os, X509* cert); #endif // BRPC_SSL_HELPER_H diff --git a/src/brpc/global.cpp b/src/brpc/global.cpp index 7c6a549684..3ca655e670 100644 --- a/src/brpc/global.cpp +++ b/src/brpc/global.cpp @@ -15,6 +15,7 @@ // Authors: Ge,Jun (gejun@baidu.com) #include +#include #include #include // O_RDONLY #include @@ -298,6 +299,8 @@ static void GlobalInitializeOrDieImpl() { // Initialize openssl library SSL_library_init(); + // Load the openssl.cnf under the default location + OPENSSL_config(NULL); SSL_load_error_strings(); if (SSLThreadInit() != 0 || SSLDHInit() != 0) { exit(1); diff --git a/src/brpc/input_messenger.cpp b/src/brpc/input_messenger.cpp index 1355e49ef2..924fa4782b 100644 --- a/src/brpc/input_messenger.cpp +++ b/src/brpc/input_messenger.cpp @@ -19,6 +19,7 @@ #include "butil/logging.h" // CHECK #include "butil/time.h" // cpuwide_time_us #include "butil/fd_utility.h" // make_non_blocking +#include "bthread/bthread.h" // bthread_start_background #include "bthread/unstable.h" // bthread_flush #include "bvar/bvar.h" // bvar::Adder #include "brpc/options.pb.h" // ProtocolType @@ -173,7 +174,6 @@ void InputMessenger::OnNewMessages(Socket* m) { // is batched(notice the BTHREAD_NOSIGNAL and bthread_flush). // - Verify will always be called in this bthread at most once and before // any process. - InputMessenger* messenger = static_cast(m->user()); const InputMessageHandler* handlers = messenger->_handlers; int progress = Socket::PROGRESS_INIT; diff --git a/src/brpc/nshead_service.h b/src/brpc/nshead_service.h index f686a1a33f..e36dd69138 100644 --- a/src/brpc/nshead_service.h +++ b/src/brpc/nshead_service.h @@ -24,7 +24,6 @@ namespace brpc { -class Socket; class Server; class MethodStatus; class StatusService; @@ -61,7 +60,6 @@ friend class DeleteNsheadClosure; // Only callable by Run(). ~NsheadClosure(); - Socket* _socket_ptr; const Server* _server; int64_t _start_parse_us; NsheadMessage _request; diff --git a/src/brpc/parallel_channel.cpp b/src/brpc/parallel_channel.cpp index 1e05407164..583f4124fc 100644 --- a/src/brpc/parallel_channel.cpp +++ b/src/brpc/parallel_channel.cpp @@ -14,6 +14,7 @@ // Authors: Ge,Jun (gejun@baidu.com) +#include "bthread/bthread.h" // bthread_id_xx #include "bthread/unstable.h" // bthread_timer_add #include "butil/atomicops.h" #include "butil/time.h" diff --git a/src/brpc/policy/baidu_rpc_protocol.cpp b/src/brpc/policy/baidu_rpc_protocol.cpp index 553fa65d18..cfe9064ef0 100644 --- a/src/brpc/policy/baidu_rpc_protocol.cpp +++ b/src/brpc/policy/baidu_rpc_protocol.cpp @@ -137,7 +137,6 @@ void SendRpcResponse(int64_t correlation_id, Controller* cntl, const google::protobuf::Message* req, const google::protobuf::Message* res, - Socket* socket_raw, const Server* server, MethodStatus* method_status_raw, long start_parse_us) { @@ -146,7 +145,7 @@ void SendRpcResponse(int64_t correlation_id, if (span) { span->set_start_send_us(butil::cpuwide_time_us()); } - SocketUniquePtr sock(socket_raw); + Socket* sock = accessor.get_sending_socket(); ScopedMethodStatus method_status(method_status_raw); std::unique_ptr recycle_cntl(cntl); std::unique_ptr recycle_req(req); @@ -211,7 +210,7 @@ void SendRpcResponse(int64_t correlation_id, if (Socket::Address(response_stream_id, &stream_ptr) == 0) { Stream* s = (Stream*)stream_ptr->conn(); s->FillSettings(meta.mutable_stream_settings()); - s->SetHostSocket(sock.get()); + s->SetHostSocket(sock); } else { LOG(WARNING) << "Stream=" << response_stream_id << " was closed before sending response"; @@ -234,7 +233,7 @@ void SendRpcResponse(int64_t correlation_id, CHECK(accessor.remote_stream_settings() != NULL); // Send the response over stream to notify that this stream connection // is successfully built. - if (SendStreamData(sock.get(), &res_buf, + if (SendStreamData(sock, &res_buf, accessor.remote_stream_settings()->stream_id(), accessor.response_stream()) != 0) { const int errcode = errno; @@ -308,7 +307,8 @@ void EndRunningCallMethodInPool( void ProcessRpcRequest(InputMessageBase* msg_base) { const int64_t start_parse_us = butil::cpuwide_time_us(); DestroyingPtr msg(static_cast(msg_base)); - SocketUniquePtr socket(msg->ReleaseSocket()); + SocketUniquePtr socket_guard(msg->ReleaseSocket()); + Socket* socket = socket_guard.get(); const Server* server = static_cast(msg_base->arg()); ScopedNonServiceError non_service_error(server); @@ -355,8 +355,9 @@ void ProcessRpcRequest(InputMessageBase* msg_base) { .set_remote_side(socket->remote_side()) .set_local_side(socket->local_side()) .set_auth_context(socket->auth_context()) - .set_request_protocol(PROTOCOL_BAIDU_STD); - + .set_request_protocol(PROTOCOL_BAIDU_STD) + .move_in_server_receiving_sock(socket_guard); + if (meta.has_stream_settings()) { accessor.set_remote_stream_settings(meta.release_stream_settings()); } @@ -373,7 +374,7 @@ void ProcessRpcRequest(InputMessageBase* msg_base) { request_meta.parent_span_id(), msg->base_real_us()); accessor.set_span(span); span->set_log_id(request_meta.log_id()); - span->set_remote_side(socket->remote_side()); + span->set_remote_side(cntl->remote_side()); span->set_protocol(PROTOCOL_BAIDU_STD); span->set_received_us(msg->received_us()); span->set_start_parse_us(start_parse_us); @@ -484,10 +485,10 @@ void ProcessRpcRequest(InputMessageBase* msg_base) { // `socket' will be held until response has been sent google::protobuf::Closure* done = ::brpc::NewCallback< int64_t, Controller*, const google::protobuf::Message*, - const google::protobuf::Message*, Socket*, const Server*, + const google::protobuf::Message*, const Server*, MethodStatus*, long>( &SendRpcResponse, meta.correlation_id(), cntl.get(), - req.get(), res.get(), socket.release(), server, + req.get(), res.get(), server, method_status, start_parse_us); if (span) { span->set_start_callback_us(butil::cpuwide_time_us()); @@ -511,7 +512,7 @@ void ProcessRpcRequest(InputMessageBase* msg_base) { // `cntl', `req' and `res' will be deleted inside `SendRpcResponse' // `socket' will be held until response has been sent SendRpcResponse(meta.correlation_id(), cntl.release(), - req.release(), res.release(), socket.release(), server, + req.release(), res.release(), server, method_status, -1); } diff --git a/src/brpc/policy/baidu_rpc_protocol.h b/src/brpc/policy/baidu_rpc_protocol.h index 90f405d2a3..159f730937 100644 --- a/src/brpc/policy/baidu_rpc_protocol.h +++ b/src/brpc/policy/baidu_rpc_protocol.h @@ -43,7 +43,7 @@ void PackRpcRequest(butil::IOBuf* buf, const google::protobuf::MethodDescriptor* method, Controller* controller, const butil::IOBuf& request, - const Authenticator* auth); + const Authenticator* auth); } // namespace policy } // namespace brpc diff --git a/src/brpc/policy/hulu_pbrpc_protocol.cpp b/src/brpc/policy/hulu_pbrpc_protocol.cpp index fb8770f81b..7b78cc5aaf 100644 --- a/src/brpc/policy/hulu_pbrpc_protocol.cpp +++ b/src/brpc/policy/hulu_pbrpc_protocol.cpp @@ -48,7 +48,8 @@ namespace policy { // problems on machines with different byte order) // 3. Use service->name() (rather than service->full_name()) + method_index // to locate method defined in .proto file -// 4. `user_message_size' is set iff request/response has attachment +// 4. 'user_message_size' is the size of protobuf request, +// and should be set iff request/response has attachment // 5. Not supported: // chunk_info - hulu doesn't support either // TalkType - nobody has use this so far in hulu @@ -222,7 +223,6 @@ static void SendHuluResponse(int64_t correlation_id, HuluController* cntl, const google::protobuf::Message* req, const google::protobuf::Message* res, - Socket* socket_ptr, const Server* server, MethodStatus* method_status_raw, long start_parse_us) { @@ -232,7 +232,7 @@ static void SendHuluResponse(int64_t correlation_id, span->set_start_send_us(butil::cpuwide_time_us()); } ScopedMethodStatus method_status(method_status_raw); - SocketUniquePtr sock(socket_ptr); + Socket* sock = accessor.get_sending_socket(); std::unique_ptr recycle_cntl(cntl); std::unique_ptr recycle_req(req); std::unique_ptr recycle_res(res); @@ -336,7 +336,8 @@ void EndRunningCallMethodInPool( void ProcessHuluRequest(InputMessageBase* msg_base) { const int64_t start_parse_us = butil::cpuwide_time_us(); DestroyingPtr msg(static_cast(msg_base)); - SocketUniquePtr socket(msg->ReleaseSocket()); + SocketUniquePtr socket_guard(msg->ReleaseSocket()); + Socket* socket = socket_guard.get(); const Server* server = static_cast(msg_base->arg()); ScopedNonServiceError non_service_error(server); @@ -382,7 +383,8 @@ void ProcessHuluRequest(InputMessageBase* msg_base) { .set_remote_side(socket->remote_side()) .set_local_side(socket->local_side()) .set_auth_context(socket->auth_context()) - .set_request_protocol(PROTOCOL_HULU_PBRPC); + .set_request_protocol(PROTOCOL_HULU_PBRPC) + .move_in_server_receiving_sock(socket_guard); if (meta.has_user_data()) { cntl->set_request_user_data(meta.user_data()); @@ -405,12 +407,13 @@ void ProcessHuluRequest(InputMessageBase* msg_base) { msg->base_real_us()); accessor.set_span(span); span->set_log_id(meta.log_id()); - span->set_remote_side(socket->remote_side()); + span->set_remote_side(cntl->remote_side()); span->set_protocol(PROTOCOL_HULU_PBRPC); span->set_received_us(msg->received_us()); span->set_start_parse_us(start_parse_us); span->set_request_size(msg->payload.size() + msg->meta.size() + 12); } + MethodStatus* method_status = NULL; do { if (!server->IsRunning()) { @@ -492,10 +495,10 @@ void ProcessHuluRequest(InputMessageBase* msg_base) { // `socket' will be held until response has been sent google::protobuf::Closure* done = ::brpc::NewCallback< int64_t, HuluController*, const google::protobuf::Message*, - const google::protobuf::Message*, Socket*, const Server*, + const google::protobuf::Message*, const Server*, MethodStatus *, long>( &SendHuluResponse, correlation_id, cntl.get(), - req.get(), res.get(), socket.release(), server, + req.get(), res.get(), server, method_status, start_parse_us); if (span) { span->set_start_callback_us(butil::cpuwide_time_us()); @@ -519,7 +522,7 @@ void ProcessHuluRequest(InputMessageBase* msg_base) { // `cntl', `req' and `res' will be deleted inside `SendHuluResponse' // `socket' will be held until response has been sent SendHuluResponse(correlation_id, cntl.release(), - req.release(), res.release(), socket.release(), server, + req.release(), res.release(), server, method_status, -1); } diff --git a/src/brpc/policy/mongo_protocol.cpp b/src/brpc/policy/mongo_protocol.cpp index 558961912a..6e0d9742ca 100644 --- a/src/brpc/policy/mongo_protocol.cpp +++ b/src/brpc/policy/mongo_protocol.cpp @@ -39,18 +39,16 @@ namespace brpc { namespace policy { struct SendMongoResponse : public google::protobuf::Closure { - SendMongoResponse(const Server *server, Socket *socket) : + SendMongoResponse(const Server *server) : status(NULL), start_callback_us(0L), - server(server), - socket(socket) {} + server(server) {} ~SendMongoResponse(); void Run(); MethodStatus* status; long start_callback_us; const Server *server; - SocketUniquePtr socket; Controller cntl; MongoRequest req; MongoResponse res; @@ -63,6 +61,7 @@ SendMongoResponse::~SendMongoResponse() { void SendMongoResponse::Run() { std::unique_ptr delete_self(this); ScopedMethodStatus method_status(status); + Socket* socket = ControllerPrivateAccessor(&cntl).get_sending_socket(); if (cntl.IsCloseConnection()) { socket->SetFailed(); @@ -174,7 +173,8 @@ void EndRunningCallMethodInPool( void ProcessMongoRequest(InputMessageBase* msg_base) { DestroyingPtr msg(static_cast(msg_base)); - SocketUniquePtr socket(msg->ReleaseSocket()); + SocketUniquePtr socket_guard(msg->ReleaseSocket()); + Socket* socket = socket_guard.get(); const Server* server = static_cast(msg_base->arg()); ScopedNonServiceError non_service_error(server); @@ -199,17 +199,18 @@ void ProcessMongoRequest(InputMessageBase* msg_base) { return; } - SendMongoResponse* mongo_done = new SendMongoResponse(server, socket.release()); + SendMongoResponse* mongo_done = new SendMongoResponse(server); mongo_done->cntl.set_mongo_session_data(context_msg->context()); ControllerPrivateAccessor accessor(&(mongo_done->cntl)); accessor.set_server(server) .set_security_mode(server->options().security_mode()) - .set_peer_id(mongo_done->socket->id()) - .set_remote_side(mongo_done->socket->remote_side()) - .set_local_side(mongo_done->socket->local_side()) - .set_auth_context(mongo_done->socket->auth_context()) - .set_request_protocol(PROTOCOL_MONGO); + .set_peer_id(socket->id()) + .set_remote_side(socket->remote_side()) + .set_local_side(socket->local_side()) + .set_auth_context(socket->auth_context()) + .set_request_protocol(PROTOCOL_MONGO) + .move_in_server_receiving_sock(socket_guard); // Tag the bthread with this server's key for // thread_local_data(). diff --git a/src/brpc/policy/nshead_protocol.cpp b/src/brpc/policy/nshead_protocol.cpp index bc1dca6edc..dd23d3e188 100644 --- a/src/brpc/policy/nshead_protocol.cpp +++ b/src/brpc/policy/nshead_protocol.cpp @@ -39,8 +39,7 @@ void bthread_assign_data(void* data); namespace brpc { NsheadClosure::NsheadClosure(void* additional_space) - : _socket_ptr(NULL) - , _server(NULL) + : _server(NULL) , _start_parse_us(0) , _do_respond(true) , _additional_space(additional_space) { @@ -65,7 +64,6 @@ class DeleteNsheadClosure { void NsheadClosure::Run() { // Recycle itself after `Run' std::unique_ptr recycle_ctx(this); - SocketUniquePtr sock(_socket_ptr); ScopedRemoveConcurrency remove_concurrency_dummy(_server, &_controller); ControllerPrivateAccessor accessor(&_controller); @@ -73,6 +71,7 @@ void NsheadClosure::Run() { if (span) { span->set_start_send_us(butil::cpuwide_time_us()); } + Socket* sock = accessor.get_sending_socket(); ScopedMethodStatus method_status(_server->options().nshead_service->_status); if (!method_status) { // Judge errors belongings. @@ -208,7 +207,8 @@ void ProcessNsheadRequest(InputMessageBase* msg_base) { const int64_t start_parse_us = butil::cpuwide_time_us(); DestroyingPtr msg(static_cast(msg_base)); - SocketUniquePtr socket(msg->ReleaseSocket()); + SocketUniquePtr socket_guard(msg->ReleaseSocket()); + Socket* socket = socket_guard.get(); const Server* server = static_cast(msg_base->arg()); ScopedNonServiceError non_service_error(server); @@ -250,7 +250,6 @@ void ProcessNsheadRequest(InputMessageBase* msg_base) { req->head = *req_head; msg->payload.swap(req->body); nshead_done->_start_parse_us = start_parse_us; - nshead_done->_socket_ptr = socket.get(); nshead_done->_server = server; ServerPrivateAccessor server_accessor(server); @@ -266,7 +265,8 @@ void ProcessNsheadRequest(InputMessageBase* msg_base) { .set_peer_id(socket->id()) .set_remote_side(socket->remote_side()) .set_local_side(socket->local_side()) - .set_request_protocol(PROTOCOL_NSHEAD); + .set_request_protocol(PROTOCOL_NSHEAD) + .move_in_server_receiving_sock(socket_guard); // Tag the bthread with this server's key for thread_local_data(). if (server->thread_local_options().thread_local_data_factory) { @@ -309,7 +309,6 @@ void ProcessNsheadRequest(InputMessageBase* msg_base) { msg.reset(); // optional, just release resourse ASAP // `socket' will be held until response has been sent - socket.release(); if (span) { span->ResetServerSpanName(service->_cached_name); span->set_start_callback_us(butil::cpuwide_time_us()); diff --git a/src/brpc/policy/sofa_pbrpc_protocol.cpp b/src/brpc/policy/sofa_pbrpc_protocol.cpp index 2479431411..85c2fc0e3c 100644 --- a/src/brpc/policy/sofa_pbrpc_protocol.cpp +++ b/src/brpc/policy/sofa_pbrpc_protocol.cpp @@ -207,7 +207,6 @@ static void SendSofaResponse(int64_t correlation_id, Controller* cntl, const google::protobuf::Message* req, const google::protobuf::Message* res, - Socket* socket_raw, const Server* server, MethodStatus* method_status_raw, long start_parse_us) { @@ -217,7 +216,7 @@ static void SendSofaResponse(int64_t correlation_id, span->set_start_send_us(butil::cpuwide_time_us()); } ScopedMethodStatus method_status(method_status_raw); - SocketUniquePtr sock(socket_raw); + Socket* sock = accessor.get_sending_socket(); std::unique_ptr recycle_cntl(cntl); std::unique_ptr recycle_req(req); std::unique_ptr recycle_res(res); @@ -313,7 +312,8 @@ void EndRunningCallMethodInPool( void ProcessSofaRequest(InputMessageBase* msg_base) { const int64_t start_parse_us = butil::cpuwide_time_us(); DestroyingPtr msg(static_cast(msg_base)); - SocketUniquePtr socket(msg->ReleaseSocket()); + SocketUniquePtr socket_guard(msg->ReleaseSocket()); + Socket* socket = socket_guard.get(); const Server* server = static_cast(msg_base->arg()); ScopedNonServiceError non_service_error(server); @@ -356,7 +356,8 @@ void ProcessSofaRequest(InputMessageBase* msg_base) { .set_remote_side(socket->remote_side()) .set_local_side(socket->local_side()) .set_auth_context(socket->auth_context()) - .set_request_protocol(PROTOCOL_SOFA_PBRPC); + .set_request_protocol(PROTOCOL_SOFA_PBRPC) + .move_in_server_receiving_sock(socket_guard); // Tag the bthread with this server's key for thread_local_data(). if (server->thread_local_options().thread_local_data_factory) { @@ -369,12 +370,13 @@ void ProcessSofaRequest(InputMessageBase* msg_base) { 0/*meta.trace_id()*/, 0/*meta.span_id()*/, 0/*meta.parent_span_id()*/, msg->base_real_us()); accessor.set_span(span); - span->set_remote_side(socket->remote_side()); + span->set_remote_side(cntl->remote_side()); span->set_protocol(PROTOCOL_SOFA_PBRPC); span->set_received_us(msg->received_us()); span->set_start_parse_us(start_parse_us); span->set_request_size(msg->meta.size() + msg->payload.size() + 24); } + MethodStatus* method_status = NULL; do { if (!server->IsRunning()) { @@ -436,10 +438,10 @@ void ProcessSofaRequest(InputMessageBase* msg_base) { // `socket' will be held until response has been sent google::protobuf::Closure* done = ::brpc::NewCallback< int64_t, Controller*, const google::protobuf::Message*, - const google::protobuf::Message*, Socket*, const Server*, + const google::protobuf::Message*, const Server*, MethodStatus *, long>( &SendSofaResponse, correlation_id, cntl.get(), - req.get(), res.get(), socket.release(), server, + req.get(), res.get(), server, method_status, start_parse_us); // `cntl', `req' and `res' will be deleted inside `done' if (span) { @@ -464,7 +466,7 @@ void ProcessSofaRequest(InputMessageBase* msg_base) { // `cntl', `req' and `res' will be deleted inside `SendSofaResponse' // `socket' will be held until response has been sent SendSofaResponse(correlation_id, cntl.release(), - req.release(), res.release(), socket.release(), server, + req.release(), res.release(), server, method_status, -1); } diff --git a/src/brpc/rtmp.cpp b/src/brpc/rtmp.cpp index d0800e400e..23ab6e7ca6 100644 --- a/src/brpc/rtmp.cpp +++ b/src/brpc/rtmp.cpp @@ -17,6 +17,7 @@ #include #include // StringOutputStream +#include "bthread/bthread.h" // bthread_id_xx #include "bthread/unstable.h" // bthread_timer_del #include "brpc/log.h" #include "brpc/callback.h" // Closure @@ -1056,9 +1057,8 @@ class RtmpSocketCreator : public SocketCreator { : _connect_options(connect_options) { } - int CreateSocket(const butil::EndPoint& pt, SocketId* id) { - SocketOptions sock_opt; - sock_opt.remote_side = pt; + int CreateSocket(const SocketOptions& opt, SocketId* id) { + SocketOptions sock_opt = opt; sock_opt.app_connect = new RtmpConnect; sock_opt.initial_parsing_context = new policy::RtmpContext(&_connect_options, NULL); return get_client_side_messenger()->Create(sock_opt, id); @@ -1661,7 +1661,7 @@ void RtmpClientStream::ReplaceSocketForStream( } } else { if (_client_impl->socket_map().Insert( - (*inout)->remote_side(), &esid) != 0) { + SocketMapKey((*inout)->remote_side()), &esid) != 0) { cntl->SetFailed(EINVAL, "Fail to get the RTMP socket"); return; } diff --git a/src/brpc/selective_channel.cpp b/src/brpc/selective_channel.cpp index e8c8d4c1e2..fd1006cd28 100644 --- a/src/brpc/selective_channel.cpp +++ b/src/brpc/selective_channel.cpp @@ -16,6 +16,7 @@ #include #include +#include "bthread/bthread.h" // bthread_id_xx #include "brpc/socket.h" // SocketUser #include "brpc/load_balancer.h" // LoadBalancer #include "brpc/details/controller_private_accessor.h" // RPCSender diff --git a/src/brpc/server.cpp b/src/brpc/server.cpp index 3fdf4537cd..de6285a6c5 100644 --- a/src/brpc/server.cpp +++ b/src/brpc/server.cpp @@ -37,7 +37,7 @@ #include "brpc/global.h" #include "brpc/socket_map.h" // SocketMapList #include "brpc/acceptor.h" // Acceptor -#include "brpc/details/ssl_helper.h" // CreateSSLContext +#include "brpc/details/ssl_helper.h" // CreateServerSSLContext #include "brpc/protocol.h" // ListProtocols #include "brpc/nshead_service.h" // NsheadService #include "brpc/builtin/bad_method_service.h" // BadMethodService @@ -116,14 +116,6 @@ const int INITIAL_CERT_MAP = 64; // compilation units is undefined. const int s_ncore = sysconf(_SC_NPROCESSORS_ONLN); -SSLOptions::SSLOptions() - : strict_sni(false) - , disable_ssl3(true) - , session_lifetime_s(300) - , session_cache_size(20480) - , ecdhe_curve_name("prime256v1") -{} - ServerOptions::ServerOptions() : idle_timeout_sec(-1) , nshead_service(NULL) @@ -1738,8 +1730,8 @@ int Server::AddCertificate(const CertInfo& cert) { SSLContext ssl_ctx; ssl_ctx.filters = cert.sni_filters; - ssl_ctx.ctx = CreateSSLContext(cert.certificate, cert.private_key, - _options.ssl_options, &ssl_ctx.filters); + ssl_ctx.ctx = CreateServerSSLContext(cert.certificate, cert.private_key, + _options.ssl_options, &ssl_ctx.filters); if (ssl_ctx.ctx == NULL) { return -1; } @@ -1853,7 +1845,7 @@ int Server::ResetCertificates(const std::vector& certs) { SSLContext ssl_ctx; ssl_ctx.filters = certs[i].sni_filters; - ssl_ctx.ctx = CreateSSLContext( + ssl_ctx.ctx = CreateServerSSLContext( certs[i].certificate, certs[i].private_key, _options.ssl_options, &ssl_ctx.filters); if (ssl_ctx.ctx == NULL) { diff --git a/src/brpc/server.h b/src/brpc/server.h index 1e9ef1b325..007463ee0b 100644 --- a/src/brpc/server.h +++ b/src/brpc/server.h @@ -30,6 +30,7 @@ #include "bvar/bvar.h" #include "butil/containers/case_ignored_flat_map.h" // [CaseIgnored]FlatMap #include "brpc/controller.h" // brpc::Controller +#include "brpc/ssl_option.h" // ServerSSLOptions #include "brpc/describable.h" // User often needs this #include "brpc/data_factory.h" // DataFactory #include "brpc/builtin/tabbed.h" @@ -50,77 +51,6 @@ class MongoServiceAdaptor; class RestfulMap; class RtmpService; -struct CertInfo { - // Certificate in PEM format. - // Note that CN and alt subjects will be extracted from the certificate, - // and will be used as hostnames. Requests to this hostname (provided SNI - // extension supported) will be encrypted using this certifcate. - // Supported both file path and raw string - std::string certificate; - - // Private key in PEM format. - // Supported both file path and raw string based on prefix: - std::string private_key; - - // Additional hostnames besides those inside the certificate. Wildcards - // are supported but it can only appear once at the beginning (i.e. *.xxx.com). - std::vector sni_filters; -}; - -struct SSLOptions { - // Constructed with default options - SSLOptions(); - - // Default certificate which will be loaded into server. Requests - // without hostname or whose hostname doesn't have a corresponding - // certificate will use this certificate. MUST be set to enable SSL. - CertInfo default_cert; - - // Additional certificates which will be loaded into server. These - // provide extra bindings between hostnames and certificates so that - // we can choose different certificates according to different hostnames. - // See `CertInfo' for detail. - std::vector certs; - - // When set, requests without hostname or whose hostname can't be found in - // any of the cerficates above will be dropped. Otherwise, `default_cert' - // will be used. - // Default: false - bool strict_sni; - - // When set, SSLv3 requests will be dropped. Strongly recommended since - // SSLv3 has been found suffering from severe security problems. Note that - // some old versions of browsers may use SSLv3 by default such as IE6.0 - // Default: true - bool disable_ssl3; - - // Maximum lifetime for a session to be cached inside OpenSSL in seconds. - // A session can be reused (initiated by client) to save handshake before - // it reaches this timeout. - // Default: 300 - int session_lifetime_s; - - // Maximum number of cached sessions. When cache is full, no more new - // session will be added into the cache until SSL_CTX_flush_sessions is - // called (automatically by SSL_read/write). A special value is 0, which - // means no limit. - // Default: 20480 - int session_cache_size; - - // Cipher suites allowed for each SSL handshake. The format of this string - // should follow that in `man 1 cipers'. If empty, OpenSSL will choose - // a default cipher based on the certificate information - // Default: "" - std::string ciphers; - - // Name of the elliptic curve used to generate ECDH ephemerial keys - // Default: prime256v1 - std::string ecdhe_curve_name; - - // TODO: Support NPN & ALPN - // TODO: Support OSCP stapling -}; - struct ServerOptions { // Constructed with default options. ServerOptions(); @@ -260,8 +190,8 @@ struct ServerOptions { // Enable more secured code which protects internal information from exposure. bool security_mode() const { return internal_port >= 0 || !has_builtin_services; } - // SSL related options. Refer to `SSLOptions' for details - SSLOptions ssl_options; + // SSL related options. Refer to `ServerSSLOptions' for details + ServerSSLOptions ssl_options; // [CAUTION] This option is for implementing specialized http proxies, // most users don't need it. Don't change this option unless you fully diff --git a/src/brpc/socket.cpp b/src/brpc/socket.cpp index 0760f154c2..7b7f76ab55 100644 --- a/src/brpc/socket.cpp +++ b/src/brpc/socket.cpp @@ -35,6 +35,7 @@ #include "brpc/errno.pb.h" #include "brpc/event_dispatcher.h" // RemoveConsumer #include "brpc/socket.h" +#include "brpc/describable.h" // Describable #include "brpc/input_messenger.h" #include "brpc/details/sparse_minute_counter.h" #include "brpc/stream_impl.h" @@ -66,6 +67,8 @@ DEFINE_int32(socket_recv_buffer_size, -1, DEFINE_int32(socket_send_buffer_size, -1, "Set send buffer size of sockets if this value is positive"); +DEFINE_int32(ssl_bio_buffer_size, 16*1024, "Set buffer size for SSL read/write"); + DEFINE_int64(socket_max_unwritten_bytes, 64 * 1024 * 1024, "Max unwritten bytes in each socket, if the limit is reached," " Socket.Write fails with EOVERCROWDED"); @@ -79,6 +82,8 @@ DEFINE_int32(connect_timeout_as_unreachable, 3, "times *continuously*, the error is changed to ENETUNREACH which " "fails the main socket as well when this socket is pooled."); +DECLARE_bool(http_verbose); + static bool validate_connect_timeout_as_unreachable(const char*, int32_t v) { return v >= 2 && v < 1000/*large enough*/; } @@ -96,7 +101,7 @@ const int WAIT_EPOLLOUT_TIMEOUT_MS = 50; class BAIDU_CACHELINE_ALIGNMENT SocketPool { public: - explicit SocketPool(const butil::EndPoint& pt); + explicit SocketPool(const SocketOptions& opt); ~SocketPool(); // Get an address-able socket. If the pool is empty, create one. @@ -111,6 +116,8 @@ class BAIDU_CACHELINE_ALIGNMENT SocketPool { void ListSockets(std::vector* list, size_t max_count); private: + // options used to create this instance + SocketOptions _options; butil::Mutex _mutex; std::vector _pool; butil::EndPoint _remote_side; @@ -438,7 +445,6 @@ Socket::Socket(Forbidden) , _auth_id(INVALID_BTHREAD_ID) , _auth_context(NULL) , _ssl_state(SSL_UNKNOWN) - , _ssl_ctx(NULL) , _ssl_session(NULL) , _connection_type_for_progressive_read(CONNECTION_TYPE_UNKNOWN) , _controller_released_socket(false) @@ -623,7 +629,6 @@ int Socket::Create(const SocketOptions& options, SocketId* id) { } // Disable SSL check if there is no SSL context m->_ssl_state = (options.ssl_ctx == NULL ? SSL_OFF : SSL_UNKNOWN); - m->_ssl_ctx = options.ssl_ctx; m->_ssl_session = NULL; m->_connection_type_for_progressive_read = CONNECTION_TYPE_UNKNOWN; m->_controller_released_socket.store(false, butil::memory_order_relaxed); @@ -654,6 +659,7 @@ int Socket::Create(const SocketOptions& options, SocketId* id) { return -1; } *id = m->_this_id; + m->_options = options; return 0; } @@ -699,6 +705,7 @@ int Socket::WaitAndReset(int32_t expected_nref) { SSL_free(_ssl_session); _ssl_session = NULL; } + _ssl_state = SSL_UNKNOWN; _nevent.store(0, butil::memory_order_relaxed); // parsing_context is very likely to be associated with the fd, // removing it is a safer choice and required by http2. @@ -713,10 +720,6 @@ int Socket::WaitAndReset(int32_t expected_nref) { LOG(FATAL) << "Fail to create _auth_id, " << berror(rc); return -1; } - // Client side(doing HC) does not support SSL now. - CHECK_NE(SSL_CONNECTED, _ssl_state); - CHECK_EQ((SSL_CTX*)NULL, _ssl_ctx); - CHECK_EQ((SSL*)NULL, _ssl_session); const int64_t cpuwide_now = butil::cpuwide_time_us(); _last_readtime_us.store(cpuwide_now, butil::memory_order_relaxed); @@ -1032,6 +1035,10 @@ void Socket::OnRecycle() { SSL_free(_ssl_session); _ssl_session = NULL; } + + if (_options.owns_ssl_ctx && _options.ssl_ctx) { + SSL_CTX_free(_options.ssl_ctx); + } delete _pipeline_q; _pipeline_q = NULL; @@ -1143,10 +1150,8 @@ int Socket::WaitEpollOut(int fd, bool pollin, const timespec* abstime) { int Socket::Connect(const timespec* abstime, int (*on_connect)(int, int, void*), void* data) { - if (_ssl_ctx) { - LOG(FATAL) << "Currently client doesn't support SSL"; - errno = EINVAL; - return -1; + if (_options.ssl_ctx) { + _ssl_state = SSL_CONNECTING; } else { _ssl_state = SSL_OFF; } @@ -1263,7 +1268,8 @@ int Socket::CheckConnected(int sockfd) { if (CreatedByConnect()) { s_vars->channel_conn << 1; } - return 0; + // Doing SSL handshake after TCP connected + return SSLHandshake(sockfd, false); } int Socket::ConnectIfNot(const timespec* abstime, WriteRequest* req) { @@ -1368,13 +1374,41 @@ void Socket::AfterAppConnected(int err, void* data) { err = ENETUNREACH; } } + s->SetFailed(err, "Fail to connect %s: %s", s->description().c_str(), berror(err)); s->ReleaseAllFailedWriteRequests(req); } } +static void* RunClosure(void* arg) { + google::protobuf::Closure* done = (google::protobuf::Closure*)arg; + done->Run(); + return NULL; +} + int Socket::KeepWriteIfConnected(int fd, int err, void* data) { + WriteRequest* req = static_cast(data); + Socket* s = req->socket; + if (err == 0 && s->ssl_state() == SSL_CONNECTING) { + // Run ssl connect in a new bthread to avoid blocking + // the current bthread (thus blocking the EventDispatcher) + bthread_t th; + google::protobuf::Closure* thrd_func = brpc::NewCallback( + Socket::CheckConnectedAndKeepWrite, fd, err, data); + if ((err = bthread_start_background(&th, &BTHREAD_ATTR_NORMAL, + RunClosure, thrd_func)) == 0) { + return 0; + } else { + PLOG(ERROR) << "Fail to start bthread"; + // Fall through with non zero `err' + } + } + CheckConnectedAndKeepWrite(fd, err, data); + return 0; +} + +void Socket::CheckConnectedAndKeepWrite(int fd, int err, void* data) { butil::fd_guard sockfd(fd); WriteRequest* req = static_cast(data); Socket* s = req->socket; @@ -1391,13 +1425,12 @@ int Socket::KeepWriteIfConnected(int fd, int err, void* data) { sockfd.release(); } else { if (err == 0) { - err = errno; + err = errno ? errno : -1; } AfterAppConnected(err, req); } - return 0; } - + inline int SetError(bthread_id_t id_wait, int ec) { if (id_wait != INVALID_BTHREAD_ID) { bthread_id_error(id_wait, ec); @@ -1428,6 +1461,13 @@ int Socket::ConductError(bthread_id_t id_wait) { } } +X509* Socket::GetPeerCertificate() const { + if (ssl_state() != SSL_CONNECTED) { + return NULL; + } + return SSL_get_peer_certificate(_ssl_session); +} + int Socket::Write(butil::IOBuf* data, const WriteOptions* options_in) { WriteOptions opt; if (options_in) { @@ -1671,14 +1711,15 @@ void* Socket::KeepWrite(void* void_arg) { } ssize_t Socket::DoWrite(WriteRequest* req) { + // Group butil::IOBuf in the list into a batch array. + butil::IOBuf* data_list[DATA_LIST_MAX]; + size_t ndata = 0; + for (WriteRequest* p = req; p != NULL && ndata < DATA_LIST_MAX; + p = p->next) { + data_list[ndata++] = &p->data; + } + if (ssl_state() == SSL_OFF) { - // Group butil::IOBuf in the list into a batch array. - butil::IOBuf* data_list[DATA_LIST_MAX]; - size_t ndata = 0; - for (WriteRequest* p = req; p != NULL && ndata < DATA_LIST_MAX; - p = p->next) { - data_list[ndata++] = &p->data; - } // Write IOBuf in the batch array into the fd. if (_conn) { return _conn->CutMessageIntoFileDescriptor(fd(), data_list, ndata); @@ -1687,52 +1728,109 @@ ssize_t Socket::DoWrite(WriteRequest* req) { fd(), data_list, ndata); return nw; } - } else if (ssl_state() == SSL_UNKNOWN) { - LOG(FATAL) << "Impossible! SSL state MUST have been set before"; - errno = EINVAL; - return -1; } + CHECK(ssl_state() == SSL_CONNECTED); - - ssize_t nw = 0; + if (_conn) { + // TODO: Separate SSL stuff from SocketConnection + return _conn->CutMessageIntoSSLChannel(_ssl_session, data_list, ndata); + } int ssl_error = 0; - bool need_continue = false; - do { - need_continue = false; - if (_conn) { - nw = _conn->CutMessageIntoSSLChannel(&req->data, _ssl_session, &ssl_error); - } else { - nw = req->data.cut_into_SSL_channel(_ssl_session, &ssl_error); + ssize_t nw = butil::IOBuf::cut_multiple_into_SSL_channel( + _ssl_session, data_list, ndata, &ssl_error); + switch (ssl_error) { + case SSL_ERROR_NONE: + break; + + case SSL_ERROR_WANT_READ: + // Disable renegotiation + errno = EPROTO; + return -1; + + case SSL_ERROR_WANT_WRITE: + errno = EAGAIN; + break; + + default: { + const unsigned long e = ERR_get_error(); + if (e != 0) { + LOG(WARNING) << "Fail to write into ssl_fd=" << fd() << ": " + << SSLError(ERR_get_error()); + errno = ESSL; + } else { + // System error with corresponding errno set + PLOG(WARNING) << "Fail to write into ssl_fd=" << fd(); } + break; + } + } + return nw; +} + +int Socket::SSLHandshake(int fd, bool server_mode) { + if (_options.ssl_ctx == NULL) { + return 0; + } + + // TODO: Reuse ssl session id for client + if (_ssl_session) { + // Free the last session, which may be deprecated when socket failed + SSL_free(_ssl_session); + } + _ssl_session = CreateSSLSession(_options.ssl_ctx, id(), fd, server_mode); + if (_ssl_session == NULL) { + return -1; + } +#ifdef SSL_CTRL_SET_TLSEXT_HOSTNAME + if (!_options.sni_name.empty()) { + SSL_set_tlsext_host_name(_ssl_session, _options.sni_name.c_str()); + } +#endif // SSL_CTRL_SET_TLSEXT_HOSTNAME + _ssl_state = SSL_CONNECTING; + + // Loop until SSL handshake has completed. For SSL_ERROR_WANT_READ/WRITE, + // we use bthread_fd_wait as polling mechanism instead of EventDispatcher + // as it may confuse the origin event processing code. + while (true) { + int rc = SSL_do_handshake(_ssl_session); + if (rc == 1) { + _ssl_state = SSL_CONNECTED; + if (FLAGS_http_verbose) { + std::cerr << _ssl_session << std::endl; + } + AddBIOBuffer(_ssl_session, fd, FLAGS_ssl_bio_buffer_size); + return 0; + } + + int ssl_error = SSL_get_error(_ssl_session, rc); switch (ssl_error) { - case SSL_ERROR_NONE: // `nw' > 0 - break; - case SSL_ERROR_WANT_READ: - // Wait for EPOLLIN to finish renegotiation - if (bthread_fd_wait(fd(), EPOLLIN) == 0) { - need_continue = true; + if (bthread_fd_wait(fd, EPOLLIN) != 0) { + return -1; } break; - + case SSL_ERROR_WANT_WRITE: - // Regard this error as EAGAIN - errno = EAGAIN; + if (bthread_fd_wait(fd, EPOLLOUT) != 0) { + return -1; + } break; - + default: { - // For write operations, regard EOF as error const unsigned long e = ERR_get_error(); - if (e != 0) { - LOG(WARNING) << "Fail to write into ssl_fd=" << fd() - << ": " << SSLError(e); + if (ssl_error == SSL_ERROR_ZERO_RETURN || e == 0) { + errno = ECONNRESET; + LOG(ERROR) << "SSL connection was shutdown by peer: " << _remote_side; + } else if (ssl_error == SSL_ERROR_SYSCALL) { + PLOG(ERROR) << "Fail to SSL_do_handshake"; + } else { + errno = ESSL; + LOG(ERROR) << "Fail to SSL_do_handshake: " << SSLError(e); } - errno = ESSL; - break; + return -1; } } - } while (need_continue); - return nw; + } } ssize_t Socket::DoRead(size_t size_hint) { @@ -1749,12 +1847,7 @@ ssize_t Socket::DoRead(size_t size_hint) { } case SSL_CONNECTING: - if (_ssl_session != NULL) { - // Free the last SSL session - SSL_free(_ssl_session); - } - _ssl_session = CreateSSLSession(_ssl_ctx, id(), fd(), true); - if (_ssl_session == NULL) { + if (SSLHandshake(fd(), true) != 0) { errno = EINVAL; return -1; } @@ -1773,47 +1866,38 @@ ssize_t Socket::DoRead(size_t size_hint) { return _read_buf.append_from_file_descriptor(fd(), size_hint); } - // Doing SSL handshake inside `append_from_SSL_channel' - CHECK(ssl_state() == SSL_CONNECTING || ssl_state() == SSL_CONNECTED); - ssize_t nr = 0; + CHECK(ssl_state() == SSL_CONNECTED); int ssl_error = 0; - bool need_continue = false; - do { - need_continue = false; - nr = _read_buf.append_from_SSL_channel(_ssl_session, &ssl_error); - switch (ssl_error) { - case SSL_ERROR_NONE: // `nr' > 0 - break; + ssize_t nr = _read_buf.append_from_SSL_channel(_ssl_session, &ssl_error, size_hint); + switch (ssl_error) { + case SSL_ERROR_NONE: // `nr' > 0 + break; - case SSL_ERROR_WANT_READ: - // Regard this error as EAGAIN - errno = EAGAIN; - break; + case SSL_ERROR_WANT_READ: + // Regard this error as EAGAIN + errno = EAGAIN; + break; - case SSL_ERROR_WANT_WRITE: - // Wait for EPOLLOUT to finish renegotiation - if (bthread_fd_wait(fd(), EPOLLOUT) == 0) { - need_continue = true; - } - break; + case SSL_ERROR_WANT_WRITE: + // Disable renegotiation + errno = EPROTO; + return -1; - default: { - const unsigned long e = ERR_get_error(); - if (nr == 0) { - // Socket EOF or SSL session EOF - // TODO(jiangrujie): DO NOT close the socket when - // receiving SSL session EOF - } else if (e != 0) { - LOG(WARNING) << "Fail to read from ssl_fd=" << fd() - << ": " << SSLError(e); - errno = ESSL; - } else { - // System error with corresponding errno set - } - break; - } + default: { + const unsigned long e = ERR_get_error(); + if (nr == 0) { + // Socket EOF or SSL session EOF + } else if (e != 0) { + LOG(WARNING) << "Fail to read from ssl_fd=" << fd() + << ": " << SSLError(e); + errno = ESSL; + } else { + // System error with corresponding errno set + PLOG(WARNING) << "Fail to read from ssl_fd=" << fd(); } - } while (need_continue); + break; + } + } return nr; } @@ -2051,12 +2135,15 @@ void Socket::DebugSocket(std::ostream& os, SocketId id) { << "\nauth_id=" << ptr->_auth_id.value << "\nauth_context=" << ptr->_auth_context << "\nssl_state=" << SSLStateToString(ptr->_ssl_state) - << "\nssl_ctx=" << (void*)ptr->_ssl_ctx - << "\nssl_session=" << (void*)ptr->_ssl_session // TODO: print SSL internal + << "\nssl_ctx=" << (void*)ptr->_options.ssl_ctx + << "\nssl_session=" << (void*)ptr->_ssl_session << "\nlogoff_flag=" << ptr->_logoff_flag.load(butil::memory_order_relaxed) << "\nrecycle_flag=" << ptr->_recycle_flag.load(butil::memory_order_relaxed) << "\ncid=" << ptr->_correlation_id << "\nwrite_head=" << ptr->_write_head.load(butil::memory_order_relaxed); + if (ptr->ssl_state() == SSL_CONNECTED) { + os << "\n\n" << ptr->_ssl_session; + } #if defined(OS_MACOSX) struct tcp_connection_info ti; socklen_t len = sizeof(ti); @@ -2182,8 +2269,8 @@ void SocketUser::AfterRevived(Socket* ptr) { ////////// SocketPool ////////////// -inline SocketPool::SocketPool(const butil::EndPoint& pt) - : _remote_side(pt), _count(0) { +inline SocketPool::SocketPool(const SocketOptions& opt) + : _options(opt), _remote_side(opt.remote_side), _count(0) { } inline SocketPool::~SocketPool() { @@ -2233,7 +2320,11 @@ inline int SocketPool::GetSocket(SocketUniquePtr* ptr) { } } // Not found in pool - if (get_client_side_messenger()->Create(_remote_side, -1, &sid) == 0) { + SocketOptions opt = _options; + // Only main socket can be the owner of ssl_ctx + opt.owns_ssl_ctx = false; + opt.health_check_interval_s = -1; + if (get_client_side_messenger()->Create(opt, &sid) == 0) { return Socket::Address(sid, ptr); } return -1; @@ -2321,7 +2412,7 @@ int Socket::GetPooledSocket(Socket* main_socket, // Create socket_pool optimistically. SocketPool* socket_pool = main_sp->socket_pool.load(butil::memory_order_consume); if (socket_pool == NULL) { - socket_pool = new SocketPool(main_socket->remote_side()); + socket_pool = new SocketPool(main_socket->_options); SocketPool* expected = NULL; if (!main_sp->socket_pool.compare_exchange_strong( expected, socket_pool, butil::memory_order_acq_rel)) { @@ -2389,7 +2480,11 @@ int Socket::GetShortSocket(Socket* main_socket, return -1; } SocketId id; - if (get_client_side_messenger()->Create(main_socket->remote_side(), -1, &id) != 0) { + SocketOptions opt = main_socket->_options; + // Only main socket can be the owner of ssl_ctx + opt.owns_ssl_ctx = false; + opt.health_check_interval_s = -1; + if (get_client_side_messenger()->Create(opt, &id) != 0) { return -1; } if (Socket::Address(id, short_socket) != 0) { diff --git a/src/brpc/socket.h b/src/brpc/socket.h index 9d36a1a8fa..53f16919cb 100644 --- a/src/brpc/socket.h +++ b/src/brpc/socket.h @@ -23,13 +23,14 @@ #include // std::deque #include // std::set #include "butil/atomicops.h" // butil::atomic -#include "bthread/types.h" // bthread_id_t +#include "bthread/types.h" // bthread_id_t #include "butil/iobuf.h" // butil::IOBuf, IOPortal #include "butil/macros.h" // DISALLOW_COPY_AND_ASSIGN #include "butil/endpoint.h" // butil::EndPoint #include "butil/resource_pool.h" // butil::ResourceId -#include "bthread/butex.h" // butex_create_checked +#include "bthread/butex.h" // butex_create_checked #include "brpc/authenticator.h" // Authenticator +#include "brpc/errno.pb.h" // EFAILEDSOCKET #include "brpc/details/ssl_helper.h" // SSLState #include "brpc/stream.h" // StreamId #include "brpc/destroyable.h" // Destroyable @@ -86,7 +87,7 @@ class SocketConnection { // Cut IOBufs into fd or SSL Channel virtual ssize_t CutMessageIntoFileDescriptor(int, butil::IOBuf**, size_t) = 0; - virtual ssize_t CutMessageIntoSSLChannel(butil::IOBuf *, SSL*, int*) = 0; + virtual ssize_t CutMessageIntoSSLChannel(SSL*, butil::IOBuf**, size_t) = 0; }; // Application-level connect. After TCP connected, the client sends some @@ -154,7 +155,9 @@ struct SocketOptions { // one thread at any time. void (*on_edge_triggered_events)(Socket*); int health_check_interval_s; + bool owns_ssl_ctx; SSL_CTX* ssl_ctx; + std::string sni_name; bthread_keytable_pool_t* keytable_pool; SocketConnection* conn; AppConnect* app_connect; @@ -388,7 +391,7 @@ friend class schan::ChannelBalancer; void CheckEOF(); SSLState ssl_state() const { return _ssl_state; } - void set_ssl_state(SSLState s) { _ssl_state = s; } + X509* GetPeerCertificate() const; // Print debugging inforamtion of `id' into the ostream. static void DebugSocket(std::ostream&, SocketId id); @@ -466,6 +469,13 @@ friend void DereferenceSocket(Socket*); static int Status(SocketId, int32_t* nref = NULL); // for unit-test. + // Perform SSL handshake after TCP connection has been established. + // Create SSL session inside and block (in bthread) until handshake + // has completed. Application layer I/O is forbidden during this + // process to avoid concurrent I/O on the underlying fd + // Returns 0 on success, -1 otherwise + int SSLHandshake(int fd, bool server_mode); + // Based upon whether the underlying channel is using SSL (if // SSLState is SSL_UNKNOWN, try to detect at first), read data // using the corresponding method into `_read_buf'. Returns read @@ -544,6 +554,7 @@ friend void DereferenceSocket(Socket*); // Callback when connection event reaches (succeeded or not) // This callback will be passed to `Connect' static int KeepWriteIfConnected(int fd, int err, void* data); + static void CheckConnectedAndKeepWrite(int fd, int err, void* data); static void AfterAppConnected(int err, void* data); static void CreateVarsOnce(); @@ -627,6 +638,9 @@ friend void DereferenceSocket(Socket*); // carefully before implementing the callback. void (*_on_edge_triggered_events)(Socket*); + // Original options used to create this Socket + SocketOptions _options; + // A set of callbacks to monitor important events of this socket. // Initialized by SocketOptions.user SocketUser* _user; @@ -693,7 +707,6 @@ friend void DereferenceSocket(Socket*); AuthContext* _auth_context; SSLState _ssl_state; - SSL_CTX* _ssl_ctx; // not owner SSL* _ssl_session; // owner // Pass from controller, for progressive reading. diff --git a/src/brpc/socket_inl.h b/src/brpc/socket_inl.h index dfdf4b2719..cb79406603 100644 --- a/src/brpc/socket_inl.h +++ b/src/brpc/socket_inl.h @@ -54,6 +54,7 @@ inline SocketOptions::SocketOptions() , user(NULL) , on_edge_triggered_events(NULL) , health_check_interval_s(-1) + , owns_ssl_ctx(false) , ssl_ctx(NULL) , keytable_pool(NULL) , conn(NULL) diff --git a/src/brpc/socket_map.cpp b/src/brpc/socket_map.cpp index 631c5e7bc0..68835ea9aa 100644 --- a/src/brpc/socket_map.cpp +++ b/src/brpc/socket_map.cpp @@ -17,14 +17,17 @@ #include #include +#include "bthread/bthread.h" #include "butil/time.h" #include "butil/scoped_lock.h" #include "butil/logging.h" +#include "butil/third_party/murmurhash3/murmurhash3.h" #include "brpc/log.h" #include "brpc/protocol.h" #include "brpc/input_messenger.h" #include "brpc/reloadable_flags.h" #include "brpc/socket_map.h" +#include "brpc/details/ssl_helper.h" // CreateClientSSLContext namespace brpc { @@ -54,9 +57,10 @@ static butil::static_atomic g_socket_map = BUTIL_STATIC_ATOMIC_INIT( class GlobalSocketCreator : public SocketCreator { public: - int CreateSocket(const butil::EndPoint& pt, SocketId* id) { - return get_client_side_messenger()->Create( - pt, FLAGS_health_check_interval, id); + int CreateSocket(const SocketOptions& opt, SocketId* id) { + SocketOptions sock_opt = opt; + sock_opt.health_check_interval_s = FLAGS_health_check_interval; + return get_client_side_messenger()->Create(sock_opt, id); } }; @@ -84,26 +88,80 @@ SocketMap* get_or_new_client_side_socket_map() { return g_socket_map.load(butil::memory_order_consume); } -int SocketMapInsert(butil::EndPoint pt, SocketId* id) { - return get_or_new_client_side_socket_map()->Insert(pt, id); +void ComputeSocketMapKeyChecksum(const SocketMapKey& key, + unsigned char* checksum) { + butil::MurmurHash3_x64_128_Context mm_ctx; + butil::MurmurHash3_x64_128_Init(&mm_ctx, 0); + + const int BUFSIZE = 1024; // Should be enough + char buf[BUFSIZE]; + int cur_len = 0; + +#define SAFE_MEMCOPY(dst, cur_len, src, size) \ + do { \ + int copy_len = std::min((int)size, BUFSIZE - cur_len); \ + if (copy_len > 0) { \ + memcpy(dst + cur_len, src, copy_len); \ + cur_len += copy_len; \ + } \ + } while (0); + + std::size_t ephash = butil::DefaultHasher()(key.peer); + SAFE_MEMCOPY(buf, cur_len, &ephash, sizeof(ephash)); + SAFE_MEMCOPY(buf, cur_len, &key.auth, sizeof(key.auth)); + + const ChannelSSLOptions& ssl = key.ssl_options; + SAFE_MEMCOPY(buf, cur_len, &ssl.enable, sizeof(ssl.enable)); + if (ssl.enable) { + SAFE_MEMCOPY(buf, cur_len, ssl.ciphers.data(), ssl.ciphers.size()); + SAFE_MEMCOPY(buf, cur_len, ssl.protocols.data(), ssl.protocols.size()); + SAFE_MEMCOPY(buf, cur_len, ssl.sni_name.data(), ssl.sni_name.size()); + + const VerifyOptions& verify = ssl.verify; + SAFE_MEMCOPY(buf, cur_len, &verify.verify_depth, + sizeof(verify.verify_depth)); + if (verify.verify_depth > 0) { + SAFE_MEMCOPY(buf, cur_len, verify.ca_file_path.data(), + verify.ca_file_path.size()); + } + } else { + // All disabled ChannelSSLOptions are the same + } +#undef SAFE_MEMCOPY + + butil::MurmurHash3_x64_128_Update(&mm_ctx, buf, cur_len); + const CertInfo& cert = ssl.client_cert; + if (ssl.enable && !cert.certificate.empty()) { + // Certificate may be too long (PEM string) to fit into `buf' + butil::MurmurHash3_x64_128_Update( + &mm_ctx, cert.certificate.data(), cert.certificate.size()); + butil::MurmurHash3_x64_128_Update( + &mm_ctx, cert.private_key.data(), cert.private_key.size()); + // sni_filters has no effect in ChannelSSLOptions + } + butil::MurmurHash3_x64_128_Final(checksum, &mm_ctx); +} + +int SocketMapInsert(const SocketMapKey& key, SocketId* id) { + return get_or_new_client_side_socket_map()->Insert(key, id); } -int SocketMapFind(butil::EndPoint pt, SocketId* id) { +int SocketMapFind(const SocketMapKey& key, SocketId* id) { SocketMap* m = get_client_side_socket_map(); if (m) { - return m->Find(pt, id); + return m->Find(key, id); } return -1; } -void SocketMapRemove(butil::EndPoint pt) { +void SocketMapRemove(const SocketMapKey& key) { SocketMap* m = get_client_side_socket_map(); if (m) { // TODO: We don't have expected_id to pass right now since the callsite // at NamingServiceThread is hard to be fixed right now. As long as // FLAGS_health_check_interval is limited to positive, SocketMapInsert // never replaces the sockets, skipping comparison is still right. - m->Remove(pt, (SocketId)-1); + m->Remove(key, (SocketId)-1); } } @@ -204,9 +262,10 @@ void SocketMap::PrintSocketMap(std::ostream& os, void* arg) { static_cast(arg)->Print(os); } -int SocketMap::Insert(const butil::EndPoint& pt, SocketId* id) { +int SocketMap::Insert(const SocketMapKey& key, SocketId* id) { + SocketMapKeyChecksum ck(key); std::unique_lock mu(_mutex); - SingleConnection* sc = _map.seek(pt); + SingleConnection* sc = _map.seek(ck); if (sc) { if (!sc->socket->Failed() || sc->socket->health_check_interval() > 0/*HC enabled*/) { @@ -216,29 +275,40 @@ int SocketMap::Insert(const butil::EndPoint& pt, SocketId* id) { } // A socket w/o HC is failed (permanently), replace it. SocketUniquePtr ptr(sc->socket); // Remove the ref added at insertion. - _map.erase(pt); // in principle, we can override the entry in map w/o + _map.erase(ck); // in principle, we can override the entry in map w/o // removing and inserting it again. But this would make error branches // below have to remove the entry before returning, which is // error-prone. We prefer code maintainability here. sc = NULL; } + std::unique_ptr ssl_ctx( + CreateClientSSLContext(key.ssl_options)); + if (key.ssl_options.enable && !ssl_ctx) { + return -1; + } SocketId tmp_id; - if (_options.socket_creator->CreateSocket(pt, &tmp_id) != 0) { - mu.unlock(); - PLOG(FATAL) << "Fail to create socket to " << pt; + SocketOptions opt; + opt.remote_side = key.peer; + // Can't save SSL_CTX in SocketMap since SingleConnection's desctruction + // may happen before Socket's destruction (remove Channel before RPC complete) + opt.owns_ssl_ctx = true; + opt.ssl_ctx = ssl_ctx.get(); + opt.sni_name = key.ssl_options.sni_name; + if (_options.socket_creator->CreateSocket(opt, &tmp_id) != 0) { + PLOG(FATAL) << "Fail to create socket to " << key.peer; return -1; } + ssl_ctx.release(); // Add a reference to make sure that sc->socket is always accessible. Not // use SocketUniquePtr which cannot put into containers before c++11. // The ref will be removed at entry's removal. SocketUniquePtr ptr; if (Socket::Address(tmp_id, &ptr) != 0) { - mu.unlock(); LOG(FATAL) << "Fail to address SocketId=" << tmp_id; return -1; } SingleConnection new_sc = { 1, ptr.release(), 0 }; - _map[pt] = new_sc; + _map[ck] = new_sc; *id = tmp_id; bool need_to_create_bvar = false; if (FLAGS_show_socketmap_in_vars && !_exposed_in_bvar) { @@ -255,15 +325,16 @@ int SocketMap::Insert(const butil::EndPoint& pt, SocketId* id) { return 0; } -void SocketMap::Remove(const butil::EndPoint& pt, SocketId expected_id) { - return RemoveInternal(pt, expected_id, false); +void SocketMap::Remove(const SocketMapKey& key, SocketId expected_id) { + return RemoveInternal(key, expected_id, false); } -void SocketMap::RemoveInternal(const butil::EndPoint& pt, +void SocketMap::RemoveInternal(const SocketMapKey& key, SocketId expected_id, bool remove_orphan) { + SocketMapKeyChecksum ck(key); std::unique_lock mu(_mutex); - SingleConnection* sc = _map.seek(pt); + SingleConnection* sc = _map.seek(ck); if (!sc) { return; } @@ -281,7 +352,7 @@ void SocketMap::RemoveInternal(const butil::EndPoint& pt, sc->no_ref_us = butil::cpuwide_time_us(); } else { Socket* const s = sc->socket; - _map.erase(pt); + _map.erase(ck); bool need_to_create_bvar = false; if (FLAGS_show_socketmap_in_vars && !_exposed_in_bvar) { _exposed_in_bvar = true; @@ -300,9 +371,10 @@ void SocketMap::RemoveInternal(const butil::EndPoint& pt, } } -int SocketMap::Find(const butil::EndPoint& pt, SocketId* id) { +int SocketMap::Find(const SocketMapKey& key, SocketId* id) { + SocketMapKeyChecksum ck(key); BAIDU_SCOPED_LOCK(_mutex); - SingleConnection* sc = _map.seek(pt); + SingleConnection* sc = _map.seek(ck); if (sc) { *id = sc->socket->id(); return 0; @@ -333,7 +405,7 @@ void SocketMap::ListOrphans(int64_t defer_us, std::vector* out) for (Map::iterator it = _map.begin(); it != _map.end(); ++it) { SingleConnection& sc = it->second; if (sc.ref_count == 0 && now - sc.no_ref_us >= defer_us) { - out->push_back(it->first); + out->push_back(it->first.peer); } } } diff --git a/src/brpc/socket_map.h b/src/brpc/socket_map.h index c3ca30712b..4035302f88 100644 --- a/src/brpc/socket_map.h +++ b/src/brpc/socket_map.h @@ -14,10 +14,15 @@ // Authors: Ge,Jun (gejun@baidu.com) +#ifndef BRPC_SOCKET_MAP_H +#define BRPC_SOCKET_MAP_H + #include // std::vector -#include "butil/containers/flat_map.h" // FlatMap +#include "bvar/bvar.h" // bvar::PassiveStatus +#include "butil/containers/flat_map.h" // FlatMap #include "brpc/socket_id.h" // SockdetId #include "brpc/options.pb.h" // ProtocolType +#include "brpc/ssl_option.h" // ChannelSSLOptions #include "brpc/input_messenger.h" // InputMessageHandler @@ -25,18 +30,36 @@ namespace brpc { // Global mapping from remote-side to out-going sockets created by Channels. -// Try to share the Socket to `pt'. If the Socket does not exist, create one. +// The following fields uniquely define a Socket. In other word, +// Socket can't be shared between 2 different SocketMapKeys +struct SocketMapKey { + SocketMapKey(const butil::EndPoint& pt, + ChannelSSLOptions ssl = ChannelSSLOptions(), + const Authenticator* auth2 = NULL) + : peer(pt), ssl_options(ssl), auth(auth2) + {} + + butil::EndPoint peer; + ChannelSSLOptions ssl_options; + const Authenticator* auth; +}; + +// Calculate an 128-bit hashcode for SocketMapKey +void ComputeSocketMapKeyChecksum(const SocketMapKey& key, + unsigned char* checksum); + +// Try to share the Socket to `key'. If the Socket does not exist, create one. // The corresponding SocketId is written to `*id'. If this function returns // successfully, SocketMapRemove() MUST be called when the Socket is not needed. // Return 0 on success, -1 otherwise. -int SocketMapInsert(butil::EndPoint pt, SocketId* id); +int SocketMapInsert(const SocketMapKey& key, SocketId* id); -// Find the SocketId associated with `pt'. +// Find the SocketId associated with `key'. // Return 0 on found, -1 otherwise. -int SocketMapFind(butil::EndPoint pt, SocketId* id); +int SocketMapFind(const SocketMapKey& key, SocketId* id); // Called once when the Socket returned by SocketMapInsert() is not needed. -void SocketMapRemove(butil::EndPoint pt); +void SocketMapRemove(const SocketMapKey& key); // Put all existing Sockets into `ids' void SocketMapList(std::vector* ids); @@ -49,7 +72,7 @@ void SocketMapList(std::vector* ids); class SocketCreator { public: virtual ~SocketCreator() {} - virtual int CreateSocket(const butil::EndPoint& pt, SocketId* id) = 0; + virtual int CreateSocket(const SocketOptions& opt, SocketId* id) = 0; }; struct SocketMapOptions { @@ -87,15 +110,15 @@ class SocketMap { SocketMap(); ~SocketMap(); int Init(const SocketMapOptions&); - int Insert(const butil::EndPoint& pt, SocketId* id); - void Remove(const butil::EndPoint& pt, SocketId expected_id); - int Find(const butil::EndPoint& pt, SocketId* id); + int Insert(const SocketMapKey& key, SocketId* id); + void Remove(const SocketMapKey& key, SocketId expected_id); + int Find(const SocketMapKey& key, SocketId* id); void List(std::vector* ids); void List(std::vector* pts); const SocketMapOptions& options() const { return _options; } private: - void RemoveInternal(const butil::EndPoint& pt, SocketId id, + void RemoveInternal(const SocketMapKey& key, SocketId id, bool remove_orphan); void ListOrphans(int64_t defer_us, std::vector* out); void WatchConnections(); @@ -109,9 +132,40 @@ class SocketMap { Socket* socket; int64_t no_ref_us; }; + + // Store checksum of SocketMapKey instead of itself in order to: + // 1. Save precious space of key field in FlatMap + // 2. Simplify equivalence logic between SocketMapKeys + // (regard the hash collision to be zero) + struct SocketMapKeyChecksum { + explicit SocketMapKeyChecksum(const SocketMapKey& key) + : peer(key.peer) { + ComputeSocketMapKeyChecksum(key, checksum); + } + + butil::EndPoint peer; + unsigned char checksum[16]; + + inline bool operator==(const SocketMapKeyChecksum& rhs) const { + return this->peer == rhs.peer + && memcmp(this->checksum, rhs.checksum, sizeof(checksum)) == 0; + } + }; + + struct Checksum2Hash { + std::size_t operator()(const SocketMapKeyChecksum& key) const { + // Slice a subset of checksum over an evenly distributed hash + // won't affect the overall balance + std::size_t hash; + memcpy(&hash, key.checksum, sizeof(hash)); + return hash; + } + }; + // TODO: When RpcChannels connecting to one EndPoint are frequently created - // and destroyed, a single map+mutex may become hot-spots. - typedef butil::FlatMap Map; + // and destroyed, a single map+mutex may become hot-spots. + typedef butil::FlatMap Map; SocketMapOptions _options; butil::Mutex _mutex; Map _map; @@ -122,3 +176,5 @@ class SocketMap { }; } // namespace brpc + +#endif // BRPC_SOCKET_MAP_H diff --git a/src/brpc/socket_message.h b/src/brpc/socket_message.h index 24fa4f855a..47a8fa2456 100644 --- a/src/brpc/socket_message.h +++ b/src/brpc/socket_message.h @@ -17,6 +17,8 @@ #ifndef BRPC_SOCKET_MESSAGE_H #define BRPC_SOCKET_MESSAGE_H +#include "butil/status.h" // butil::Status + namespace brpc { diff --git a/src/brpc/ssl_option.cpp b/src/brpc/ssl_option.cpp new file mode 100644 index 0000000000..fb5116e91d --- /dev/null +++ b/src/brpc/ssl_option.cpp @@ -0,0 +1,38 @@ +// Copyright (c) 2014 baidu-rpc authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Authors: Rujie Jiang (jiangrujie@baidu.com) + +#include "brpc/ssl_option.h" + +namespace brpc { + +VerifyOptions::VerifyOptions() : verify_depth(0) {} + +ChannelSSLOptions::ChannelSSLOptions() + : enable(false) + , ciphers("DEFAULT") + , protocols("TLSv1, TLSv1.1, TLSv1.2") +{} + +ServerSSLOptions::ServerSSLOptions() + : strict_sni(false) + , disable_ssl3(true) + , release_buffer(false) + , session_lifetime_s(300) + , session_cache_size(20480) + , ecdhe_curve_name("prime256v1") +{} + +} // namespace brpc diff --git a/src/brpc/ssl_option.h b/src/brpc/ssl_option.h new file mode 100644 index 0000000000..48924d4ef0 --- /dev/null +++ b/src/brpc/ssl_option.h @@ -0,0 +1,162 @@ +// Copyright (c) 2014 baidu-rpc authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Authors: Rujie Jiang (jiangrujie@baidu.com) + +#ifndef BRPC_SSL_OPTION_H +#define BRPC_SSL_OPTION_H + +#include +#include + +namespace brpc { + +struct CertInfo { + // Certificate in PEM format. + // Note that CN and alt subjects will be extracted from the certificate, + // and will be used as hostnames. Requests to this hostname (provided SNI + // extension supported) will be encrypted using this certifcate. + // Supported both file path and raw string + std::string certificate; + + // Private key in PEM format. + // Supported both file path and raw string based on prefix: + std::string private_key; + + // Additional hostnames besides those inside the certificate. Wildcards + // are supported but it can only appear once at the beginning (i.e. *.xxx.com). + std::vector sni_filters; +}; + +struct VerifyOptions { + // Constructed with default options + VerifyOptions(); + + // Set the maximum depth of the certificate chain for verification + // If 0, turn off the verification + // Default: 0 + int verify_depth; + + // Set the trusted CA file to verify the peer's certificate + // If empty, use the system default CA files + // Default: "" + std::string ca_file_path; +}; + +// SSL options at client side +struct ChannelSSLOptions { + // Constructed with default options + ChannelSSLOptions(); + + // Whether to enable SSL on the channel. + // Default: false + bool enable; + + // Cipher suites used for SSL handshake. + // The format of this string should follow that in `man 1 cipers'. + // Default: "DEFAULT" + std::string ciphers; + + // SSL protocols used for SSL handshake, separated by comma. + // Available protocols: SSLv3, TLSv1, TLSv1.1, TLSv1.2 + // Default: TLSv1, TLSv1.1, TLSv1.2 + std::string protocols; + + // When set, fill this into the SNI extension field during handshake, + // which can be used by the server to locate the right certificate. + // Default: empty + std::string sni_name; + + // Certificate used for client authentication + // Default: empty + CertInfo client_cert; + + // Options used to verify the server's certificate + // Default: see above + VerifyOptions verify; + + // TODO: Support CRL +}; + +// SSL options at server side +struct ServerSSLOptions { + // Constructed with default options + ServerSSLOptions(); + + // Default certificate which will be loaded into server. Requests + // without hostname or whose hostname doesn't have a corresponding + // certificate will use this certificate. MUST be set to enable SSL. + CertInfo default_cert; + + // Additional certificates which will be loaded into server. These + // provide extra bindings between hostnames and certificates so that + // we can choose different certificates according to different hostnames. + // See `CertInfo' for detail. + std::vector certs; + + // When set, requests without hostname or whose hostname can't be found in + // any of the cerficates above will be dropped. Otherwise, `default_cert' + // will be used. + // Default: false + bool strict_sni; + + // When set, SSLv3 requests will be dropped. Strongly recommended since + // SSLv3 has been found suffering from severe security problems. Note that + // some old versions of browsers may use SSLv3 by default such as IE6.0 + // Default: true + bool disable_ssl3; + + // Flag for SSL_MODE_RELEASE_BUFFERS. When set, release read/write buffers + // when SSL connection is idle, which saves 34KB memory per connection. + // On the other hand, it introduces additional latency and reduces throughput + // Default: false + bool release_buffer; + + // Maximum lifetime for a session to be cached inside OpenSSL in seconds. + // A session can be reused (initiated by client) to save handshake before + // it reaches this timeout. + // Default: 300 + int session_lifetime_s; + + // Maximum number of cached sessions. When cache is full, no more new + // session will be added into the cache until SSL_CTX_flush_sessions is + // called (automatically by SSL_read/write). A special value is 0, which + // means no limit. + // Default: 20480 + int session_cache_size; + + // Cipher suites allowed for each SSL handshake. The format of this string + // should follow that in `man 1 cipers'. If empty, OpenSSL will choose + // a default cipher based on the certificate information + // Default: "" + std::string ciphers; + + // Name of the elliptic curve used to generate ECDH ephemerial keys + // Default: prime256v1 + std::string ecdhe_curve_name; + + // Options used to verify the client's certificate + // Default: see above + VerifyOptions verify; + + // TODO: Support NPN & ALPN + // TODO: Support OSCP stapling +}; + +// Legacy name defined in server.h +typedef ServerSSLOptions SSLOptions; + +} // namespace brpc + +#endif // BRPC_SSL_OPTION_H diff --git a/src/brpc/stream.cpp b/src/brpc/stream.cpp index 473d7a6746..a358ee2a2c 100644 --- a/src/brpc/stream.cpp +++ b/src/brpc/stream.cpp @@ -159,9 +159,9 @@ void Stream::WriteToHostSocket(butil::IOBuf* b) { BRPC_HANDLE_EOVERCROWDED(_host_socket->Write(b)); } -ssize_t Stream::CutMessageIntoSSLChannel(butil::IOBuf*, SSL*, int* error) { +ssize_t Stream::CutMessageIntoSSLChannel(SSL*, butil::IOBuf**, size_t) { CHECK(false) << "Stream does support SSL"; - *error = SSL_ERROR_SSL; + errno = EINVAL; return -1; } diff --git a/src/brpc/stream_impl.h b/src/brpc/stream_impl.h index b55e6870a2..ddcc99ebbe 100644 --- a/src/brpc/stream_impl.h +++ b/src/brpc/stream_impl.h @@ -35,7 +35,7 @@ class BAIDU_CACHELINE_ALIGNMENT Stream : public SocketConnection { int (*on_connect)(int, int, void *), void *data); ssize_t CutMessageIntoFileDescriptor(int, butil::IOBuf **data_list, size_t size); - ssize_t CutMessageIntoSSLChannel(butil::IOBuf*, SSL*, int*); + ssize_t CutMessageIntoSSLChannel(SSL*, butil::IOBuf**, size_t); void BeforeRecycle(Socket *); // --------------------- SocketConnection -------------- diff --git a/src/butil/iobuf.cpp b/src/butil/iobuf.cpp index deeac54b19..2656fee771 100644 --- a/src/butil/iobuf.cpp +++ b/src/butil/iobuf.cpp @@ -966,6 +966,51 @@ ssize_t IOBuf::cut_into_SSL_channel(SSL* ssl, int* ssl_error) { return nw; } +ssize_t IOBuf::cut_multiple_into_SSL_channel(SSL* ssl, IOBuf* const* pieces, + size_t count, int* ssl_error) { + ssize_t nw = 0; + *ssl_error = SSL_ERROR_NONE; + for (size_t i = 0; i < count; ) { + if (pieces[i]->empty()) { + ++i; + continue; + } + + ssize_t rc = pieces[i]->cut_into_SSL_channel(ssl, ssl_error); + if (rc > 0) { + nw += rc; + } else { + if (rc < 0) { + if (*ssl_error == SSL_ERROR_WANT_WRITE + || (*ssl_error == SSL_ERROR_SYSCALL + && BIO_fd_non_fatal_error(errno) == 1)) { + // Non fatal error, tell caller to write again + *ssl_error = SSL_ERROR_WANT_WRITE; + } else { + // Other errors are fatal + return rc; + } + } + if (nw == 0) { + nw = rc; // Nothing written yet, overwrite nw + } + break; + } + } + + // Flush remaining data inside the BIO buffer layer + BIO* wbio = SSL_get_wbio(ssl); + if (BIO_wpending(wbio) > 0) { + int rc = BIO_flush(wbio); + if (rc <= 0 && BIO_fd_non_fatal_error(errno) == 0) { + // Fatal error during BIO_flush + *ssl_error = SSL_ERROR_SYSCALL; + return rc; + } + } + return nw; +} + ssize_t IOBuf::pcut_multiple_into_file_descriptor( int fd, off_t offset, IOBuf* const* pieces, size_t count) { if (BAIDU_UNLIKELY(count == 0)) { @@ -1571,27 +1616,47 @@ ssize_t IOPortal::pappend_from_file_descriptor( return nr; } -ssize_t IOPortal::append_from_SSL_channel(SSL* ssl, int* ssl_error) { - if (!_block) { - _block = iobuf::acquire_tls_block(); - if (BAIDU_UNLIKELY(!_block)) { - errno = ENOMEM; - *ssl_error = SSL_ERROR_SYSCALL; - return -1; +ssize_t IOPortal::append_from_SSL_channel( + SSL* ssl, int* ssl_error, size_t max_count) { + size_t nr = 0; + do { + if (!_block) { + _block = iobuf::acquire_tls_block(); + if (BAIDU_UNLIKELY(!_block)) { + errno = ENOMEM; + *ssl_error = SSL_ERROR_SYSCALL; + return -1; + } } - } - const int nr = SSL_read(ssl, _block->data + _block->size, _block->left_space()); - *ssl_error = SSL_get_error(ssl, nr); - if (nr > 0) { - const IOBuf::BlockRef r = { (uint32_t)_block->size, (uint32_t)nr, _block }; - _push_back_ref(r); - _block->size += nr; - if (_block->full()) { - Block* const saved_next = _block->portal_next; - _block->dec_ref(); // _block may be deleted - _block = saved_next; + + const size_t read_len = std::min(_block->left_space(), max_count - nr); + const int rc = SSL_read(ssl, _block->data + _block->size, read_len); + *ssl_error = SSL_get_error(ssl, rc); + if (rc > 0) { + const IOBuf::BlockRef r = { (uint32_t)_block->size, (uint32_t)rc, _block }; + _push_back_ref(r); + _block->size += rc; + if (_block->full()) { + Block* const saved_next = _block->portal_next; + _block->dec_ref(); // _block may be deleted + _block = saved_next; + } + nr += rc; + } else { + if (rc < 0) { + if (*ssl_error == SSL_ERROR_WANT_READ + || (*ssl_error == SSL_ERROR_SYSCALL + && BIO_fd_non_fatal_error(errno) == 1)) { + // Non fatal error, tell caller to read again + *ssl_error = SSL_ERROR_WANT_READ; + } else { + // Other errors are fatal + return rc; + } + } + return (nr > 0 ? nr : rc); } - } + } while (nr < max_count); return nr; } diff --git a/src/butil/iobuf.h b/src/butil/iobuf.h index d5d1db0ac6..118af31428 100644 --- a/src/butil/iobuf.h +++ b/src/butil/iobuf.h @@ -163,6 +163,11 @@ friend class IOBufAsZeroCopyOutputStream; // and the ssl error code will be filled into `ssl_error' ssize_t cut_into_SSL_channel(struct ssl_st* ssl, int* ssl_error); + // Cut `count' number of `pieces' into SSL channel `ssl'. + // Returns bytes cut on success, -1 otherwise and errno is set. + static ssize_t cut_multiple_into_SSL_channel( + struct ssl_st* ssl, IOBuf* const* pieces, size_t count, int* ssl_error); + // Cut `count' number of `pieces' into file descriptor `fd'. // Returns bytes cut on success, -1 otherwise and errno is set. static ssize_t cut_multiple_into_file_descriptor( @@ -431,9 +436,10 @@ class IOPortal : public IOBuf { // If `offset' is negative, does exactly what append_from_file_descriptor does. ssize_t pappend_from_file_descriptor(int fd, off_t offset, size_t max_count); - // Read from SSL channel `ssl'. Returns what `SSL_read' returns - // and the ssl error code will be filled into `ssl_error' - ssize_t append_from_SSL_channel(struct ssl_st* ssl, int* ssl_error); + // Read as many bytes as possible from SSL channel `ssl', and stop until `max_count'. + // Returns total bytes read and the ssl error code will be filled into `ssl_error' + ssize_t append_from_SSL_channel(struct ssl_st* ssl, int* ssl_error, + size_t max_count = 1024*1024); // Remove all data inside and return cached blocks. void clear(); diff --git a/test/brpc_channel_unittest.cpp b/test/brpc_channel_unittest.cpp index 1267229387..9566f143cf 100644 --- a/test/brpc_channel_unittest.cpp +++ b/test/brpc_channel_unittest.cpp @@ -36,8 +36,7 @@ namespace policy { void SendRpcResponse(int64_t correlation_id, Controller* cntl, const google::protobuf::Message* req, const google::protobuf::Message* res, - Socket* socket_ptr, const Server* server_raw, - MethodStatus *, long); + const Server* server_raw, MethodStatus *, long); } // policy } // brpc @@ -219,6 +218,8 @@ class ChannelTest : public ::testing::Test{ EXPECT_TRUE(req->ParseFromZeroCopyStream(&wrapper2)); } brpc::Controller* cntl = new brpc::Controller(); + cntl->_current_call.peer_id = ptr->id(); + cntl->_current_call.sending_sock.reset(ptr.release()); google::protobuf::Message* res = ts->_svc.GetResponsePrototype(method).New(); @@ -227,12 +228,11 @@ class ChannelTest : public ::testing::Test{ int64_t, brpc::Controller*, const google::protobuf::Message*, const google::protobuf::Message*, - brpc::Socket*, const brpc::Server*, brpc::MethodStatus*, long>( &brpc::policy::SendRpcResponse, meta.correlation_id(), cntl, NULL, res, - ptr.release(), &ts->_dummy, NULL, -1); + &ts->_dummy, NULL, -1); ts->_svc.CallMethod(method, cntl, req, res, done); } diff --git a/test/brpc_load_balancer_unittest.cpp b/test/brpc_load_balancer_unittest.cpp index 8c61f245ba..dd058e4bbb 100644 --- a/test/brpc_load_balancer_unittest.cpp +++ b/test/brpc_load_balancer_unittest.cpp @@ -7,8 +7,10 @@ #include #include #include +#include "bthread/bthread.h" #include "butil/gperftools_profiler.h" #include "butil/time.h" +#include "butil/fast_rand.h" #include "butil/containers/doubly_buffered_data.h" #include "brpc/describable.h" #include "brpc/socket.h" diff --git a/test/brpc_server_unittest.cpp b/test/brpc_server_unittest.cpp index ffac5331bf..f1f85dd7ee 100644 --- a/test/brpc_server_unittest.cpp +++ b/test/brpc_server_unittest.cpp @@ -1283,108 +1283,6 @@ TEST_F(ServerTest, too_big_message) { server.Join(); } -struct EchoOpensslMsg {}; -inline std::ostream& operator<<(std::ostream& os, EchoOpensslMsg) { - std::ifstream t("openssl.msg"); - return os << "============ The output of previous openssl ============\n" - << t.rdbuf() - << "\n============ The output ends here ============\n"; -} -void CheckCert(const char* cname, const char* cert) { - std::string cmd = butil::string_printf( - "echo 'Q' | openssl s_client -connect localhost:8613 " - "-servername %s > openssl.msg && grep %s openssl.msg", cname, cert); - ASSERT_EQ(0, system(cmd.c_str())) << EchoOpensslMsg(); -} - -std::string GetRawPemString(const char* fname) { - butil::ScopedFILE fp(fname, "r"); - char buf[4096]; - int size = read(fileno(fp), buf, sizeof(buf)); - std::string raw; - raw.append(buf, size); - return raw; -} - -TEST_F(ServerTest, ssl_sni) { - brpc::Server server; - brpc::ServerOptions options; - { - brpc::CertInfo cert; - cert.certificate = "cert1.crt"; - cert.private_key = "cert1.key"; - cert.sni_filters.push_back("localhost"); - options.ssl_options.default_cert = cert; - } - { - brpc::CertInfo cert; - cert.certificate = GetRawPemString("cert2.crt"); - cert.private_key = GetRawPemString("cert2.key"); - cert.sni_filters.push_back("*.localdomain"); - options.ssl_options.certs.push_back(cert); - } - ASSERT_EQ(0, server.Start(8613, &options)); - CheckCert("localhost", "cert1"); - -#ifdef SSL_CTRL_SET_TLSEXT_HOSTNAME - CheckCert("localhost.localdomain", "cert2"); -#endif // SSL_CTRL_SET_TLSEXT_HOSTNAME - - server.Stop(0); - server.Join(); -} - -TEST_F(ServerTest, ssl_reload) { - brpc::Server server; - brpc::ServerOptions options; - { - brpc::CertInfo cert; - cert.certificate = "cert1.crt"; - cert.private_key = "cert1.key"; - cert.sni_filters.push_back("localhost"); - options.ssl_options.default_cert = cert; - } - ASSERT_EQ(0, server.Start(8613, &options)); - CheckCert("localhost", "cert1"); - - { - brpc::CertInfo cert; - cert.certificate = GetRawPemString("cert2.crt"); - cert.private_key = GetRawPemString("cert2.key"); - cert.sni_filters.push_back("*.localdomain"); - ASSERT_EQ(0, server.AddCertificate(cert)); - } -#ifdef SSL_CTRL_SET_TLSEXT_HOSTNAME - CheckCert("localhost.localdomain", "cert2"); -#endif // SSL_CTRL_SET_TLSEXT_HOSTNAME - - { - brpc::CertInfo cert; - cert.certificate = GetRawPemString("cert2.crt"); - cert.private_key = GetRawPemString("cert2.key"); - ASSERT_EQ(0, server.RemoveCertificate(cert)); - } -#ifdef SSL_CTRL_SET_TLSEXT_HOSTNAME - CheckCert("localhost.localdomain", "cert1"); -#endif // SSL_CTRL_SET_TLSEXT_HOSTNAME - - { - brpc::CertInfo cert; - cert.certificate = GetRawPemString("cert2.crt"); - cert.private_key = GetRawPemString("cert2.key"); - cert.sni_filters.push_back("*.localdomain"); - std::vector certs; - certs.push_back(cert); - ASSERT_EQ(0, server.ResetCertificates(certs)); - } -#ifdef SSL_CTRL_SET_TLSEXT_HOSTNAME - CheckCert("localhost.localdomain", "cert2"); -#endif // SSL_CTRL_SET_TLSEXT_HOSTNAME - - server.Stop(0); - server.Join(); -} - TEST_F(ServerTest, max_concurrency) { const int port = 9200; brpc::Server server1; diff --git a/test/brpc_socket_map_unittest.cpp b/test/brpc_socket_map_unittest.cpp index e48a0e9738..50b35ea4d2 100644 --- a/test/brpc_socket_map_unittest.cpp +++ b/test/brpc_socket_map_unittest.cpp @@ -17,6 +17,7 @@ DECLARE_int32(max_connection_pool_size); namespace { butil::EndPoint g_endpoint; +brpc::SocketMapKey g_key(g_endpoint); void* worker(void*) { const int ROUND = 2; @@ -25,9 +26,9 @@ void* worker(void*) { for (int i = 0; i < ROUND * 2; ++i) { for (int j = 0; j < COUNT; ++j) { if (i % 2 == 0) { - EXPECT_EQ(0, brpc::SocketMapInsert(g_endpoint, &id)); + EXPECT_EQ(0, brpc::SocketMapInsert(g_key, &id)); } else { - brpc::SocketMapRemove(g_endpoint); + brpc::SocketMapRemove(g_key); } } } @@ -55,23 +56,23 @@ TEST_F(SocketMapTest, idle_timeout) { } brpc::SocketId id; // Socket still exists since it has not reached timeout yet - ASSERT_EQ(0, brpc::SocketMapFind(g_endpoint, &id)); + ASSERT_EQ(0, brpc::SocketMapFind(g_key, &id)); usleep(TIMEOUT * 1000000L + 1100000L); // Socket should be removed after timeout - ASSERT_EQ(-1, brpc::SocketMapFind(g_endpoint, &id)); + ASSERT_EQ(-1, brpc::SocketMapFind(g_key, &id)); brpc::FLAGS_defer_close_second = TIMEOUT * 10; - ASSERT_EQ(0, brpc::SocketMapInsert(g_endpoint, &id)); - brpc::SocketMapRemove(g_endpoint); - ASSERT_EQ(0, brpc::SocketMapFind(g_endpoint, &id)); + ASSERT_EQ(0, brpc::SocketMapInsert(g_key, &id)); + brpc::SocketMapRemove(g_key); + ASSERT_EQ(0, brpc::SocketMapFind(g_key, &id)); // Change `FLAGS_idle_timeout_second' to 0 to disable checking brpc::FLAGS_defer_close_second = 0; usleep(1100000L); // And then Socket should be removed - ASSERT_EQ(-1, brpc::SocketMapFind(g_endpoint, &id)); + ASSERT_EQ(-1, brpc::SocketMapFind(g_key, &id)); brpc::SocketId main_id; - ASSERT_EQ(0, brpc::SocketMapInsert(g_endpoint, &main_id)); + ASSERT_EQ(0, brpc::SocketMapInsert(g_key, &main_id)); brpc::FLAGS_idle_timeout_second = TIMEOUT; brpc::SocketUniquePtr main_ptr; brpc::SocketUniquePtr ptr; @@ -91,7 +92,7 @@ TEST_F(SocketMapTest, idle_timeout) { ASSERT_TRUE(main_ptr.get()); main_ptr.reset(); ASSERT_NE(id, ptr->id()); - brpc::SocketMapRemove(g_endpoint); + brpc::SocketMapRemove(g_key); } TEST_F(SocketMapTest, max_pool_size) { @@ -100,7 +101,7 @@ TEST_F(SocketMapTest, max_pool_size) { brpc::FLAGS_max_connection_pool_size = MAXSIZE; brpc::SocketId main_id; - ASSERT_EQ(0, brpc::SocketMapInsert(g_endpoint, &main_id)); + ASSERT_EQ(0, brpc::SocketMapInsert(g_key, &main_id)); brpc::SocketUniquePtr ptrs[TOTALSIZE]; for (int i = 0; i < TOTALSIZE; ++i) { @@ -126,7 +127,7 @@ TEST_F(SocketMapTest, max_pool_size) { } //namespace int main(int argc, char* argv[]) { - butil::str2endpoint("127.0.0.1:12345", &g_endpoint); + butil::str2endpoint("127.0.0.1:12345", &g_key.peer); testing::InitGoogleTest(&argc, argv); return RUN_ALL_TESTS(); } diff --git a/test/brpc_ssl_unittest.cpp b/test/brpc_ssl_unittest.cpp new file mode 100644 index 0000000000..c644ea034c --- /dev/null +++ b/test/brpc_ssl_unittest.cpp @@ -0,0 +1,334 @@ +// Baidu RPC - A framework to host and access services throughout Baidu. +// Copyright (c) 2014 baidu-rpc authors + +// Date: Sun Jul 13 15:04:18 CST 2014 + +#include +#include +#include +#include +#include +#include +#include +#include "brpc/global.h" +#include "brpc/socket.h" +#include "brpc/server.h" +#include "brpc/channel.h" +#include "brpc/socket_map.h" +#include "brpc/controller.h" +#include "echo.pb.h" + +namespace brpc { +void ExtractHostnames(X509* x, std::vector* hostnames); +} // namespace brpc + + +int main(int argc, char* argv[]) { + testing::InitGoogleTest(&argc, argv); + google::ParseCommandLineFlags(&argc, &argv, true); + brpc::GlobalInitializeOrDie(); + return RUN_ALL_TESTS(); +} + +bool g_delete = false; +const std::string EXP_REQUEST = "hello"; +const std::string EXP_RESPONSE = "world"; + +class EchoServiceImpl : public test::EchoService { +public: + EchoServiceImpl() : count(0) {} + virtual ~EchoServiceImpl() { g_delete = true; } + virtual void Echo(google::protobuf::RpcController* cntl_base, + const test::EchoRequest* request, + test::EchoResponse* response, + google::protobuf::Closure* done) { + brpc::ClosureGuard done_guard(done); + brpc::Controller* cntl = (brpc::Controller*)cntl_base; + count.fetch_add(1, butil::memory_order_relaxed); + EXPECT_EQ(EXP_REQUEST, request->message()); + EXPECT_TRUE(cntl->is_ssl()); + + response->set_message(EXP_RESPONSE); + if (request->sleep_us() > 0) { + LOG(INFO) << "Sleep " << request->sleep_us() << " us, protocol=" + << cntl->request_protocol(); + bthread_usleep(request->sleep_us()); + } + } + + butil::atomic count; +}; + +class SSLTest : public ::testing::Test{ +protected: + SSLTest() {}; + virtual ~SSLTest(){}; + virtual void SetUp() {}; + virtual void TearDown() {}; +}; + +void* RunClosure(void* arg) { + google::protobuf::Closure* done = (google::protobuf::Closure*)arg; + done->Run(); + return NULL; +} + +void SendMultipleRPC(brpc::Channel* channel, int count) { + for (int i = 0; i < count; ++i) { + brpc::Controller cntl; + test::EchoRequest req; + test::EchoResponse res; + req.set_message(EXP_REQUEST); + test::EchoService_Stub stub(channel); + stub.Echo(&cntl, &req, &res, NULL); + + EXPECT_EQ(EXP_RESPONSE, res.message()) << cntl.ErrorText(); + } +} + +TEST_F(SSLTest, sanity) { + // Test RPC based on SSL + brpc protocol + const int port = 8613; + brpc::Server server; + brpc::ServerOptions options; + + brpc::CertInfo cert; + cert.certificate = "cert1.crt"; + cert.private_key = "cert1.key"; + options.ssl_options.default_cert = cert; + + EchoServiceImpl echo_svc; + ASSERT_EQ(0, server.AddService( + &echo_svc, brpc::SERVER_DOESNT_OWN_SERVICE)); + ASSERT_EQ(0, server.Start(port, &options)); + + test::EchoRequest req; + test::EchoResponse res; + req.set_message(EXP_REQUEST); + { + brpc::Channel channel; + brpc::ChannelOptions coptions; + coptions.ssl_options.enable = true; + ASSERT_EQ(0, channel.Init("localhost", port, &coptions)); + + brpc::Controller cntl; + test::EchoService_Stub stub(&channel); + stub.Echo(&cntl, &req, &res, NULL); + EXPECT_EQ(EXP_RESPONSE, res.message()) << cntl.ErrorText(); + } + + // stress test + const int NUM = 5; + const int COUNT = 3000; + pthread_t tids[NUM]; + { + brpc::Channel channel; + brpc::ChannelOptions coptions; + coptions.ssl_options.enable = true; + ASSERT_EQ(0, channel.Init("127.0.0.1", port, &coptions)); + for (int i = 0; i < NUM; ++i) { + google::protobuf::Closure* thrd_func = + brpc::NewCallback(SendMultipleRPC, &channel, COUNT); + EXPECT_EQ(0, pthread_create(&tids[i], NULL, RunClosure, thrd_func)); + } + for (int i = 0; i < NUM; ++i) { + pthread_join(tids[i], NULL); + } + } + { + // Use HTTP + brpc::Channel channel; + brpc::ChannelOptions coptions; + coptions.protocol = "http"; + coptions.ssl_options.enable = true; + ASSERT_EQ(0, channel.Init("127.0.0.1", port, &coptions)); + for (int i = 0; i < NUM; ++i) { + google::protobuf::Closure* thrd_func = + brpc::NewCallback(SendMultipleRPC, &channel, COUNT); + EXPECT_EQ(0, pthread_create(&tids[i], NULL, RunClosure, thrd_func)); + } + for (int i = 0; i < NUM; ++i) { + pthread_join(tids[i], NULL); + } + } + + ASSERT_EQ(0, server.Stop(0)); + ASSERT_EQ(0, server.Join()); +} + +void CheckCert(const char* cname, const char* cert) { + const int port = 8613; + brpc::Channel channel; + brpc::ChannelOptions coptions; + coptions.ssl_options.enable = true; + coptions.ssl_options.sni_name = cname; + ASSERT_EQ(0, channel.Init("127.0.0.1", port, &coptions)); + + SendMultipleRPC(&channel, 1); + // client has no access to the sending socket + std::vector ids; + brpc::SocketMapList(&ids); + ASSERT_EQ(1u, ids.size()); + brpc::SocketUniquePtr sock; + ASSERT_EQ(0, brpc::Socket::Address(ids[0], &sock)); + + X509* x509 = sock->GetPeerCertificate(); + ASSERT_TRUE(x509 != NULL); + std::vector cnames; + brpc::ExtractHostnames(x509, &cnames); + ASSERT_EQ(cert, cnames[0]) << x509; +} + +std::string GetRawPemString(const char* fname) { + butil::ScopedFILE fp(fname, "r"); + char buf[4096]; + int size = read(fileno(fp), buf, sizeof(buf)); + std::string raw; + raw.append(buf, size); + return raw; +} + +#ifdef SSL_CTRL_SET_TLSEXT_HOSTNAME + +TEST_F(SSLTest, ssl_sni) { + const int port = 8613; + brpc::Server server; + brpc::ServerOptions options; + { + brpc::CertInfo cert; + cert.certificate = "cert1.crt"; + cert.private_key = "cert1.key"; + cert.sni_filters.push_back("cert1.com"); + options.ssl_options.default_cert = cert; + } + { + brpc::CertInfo cert; + cert.certificate = GetRawPemString("cert2.crt"); + cert.private_key = GetRawPemString("cert2.key"); + cert.sni_filters.push_back("*.cert2.com"); + options.ssl_options.certs.push_back(cert); + } + EchoServiceImpl echo_svc; + ASSERT_EQ(0, server.AddService( + &echo_svc, brpc::SERVER_DOESNT_OWN_SERVICE)); + ASSERT_EQ(0, server.Start(port, &options)); + + CheckCert("cert1.com", "cert1"); + CheckCert("www.cert2.com", "cert2"); + CheckCert("noexist", "cert1"); // default cert + + server.Stop(0); + server.Join(); +} + +TEST_F(SSLTest, ssl_reload) { + const int port = 8613; + brpc::Server server; + brpc::ServerOptions options; + { + brpc::CertInfo cert; + cert.certificate = "cert1.crt"; + cert.private_key = "cert1.key"; + cert.sni_filters.push_back("cert1.com"); + options.ssl_options.default_cert = cert; + } + EchoServiceImpl echo_svc; + ASSERT_EQ(0, server.AddService( + &echo_svc, brpc::SERVER_DOESNT_OWN_SERVICE)); + ASSERT_EQ(0, server.Start(port, &options)); + + CheckCert("cert2.com", "cert1"); // default cert + { + brpc::CertInfo cert; + cert.certificate = GetRawPemString("cert2.crt"); + cert.private_key = GetRawPemString("cert2.key"); + cert.sni_filters.push_back("cert2.com"); + ASSERT_EQ(0, server.AddCertificate(cert)); + } + CheckCert("cert2.com", "cert2"); + + { + brpc::CertInfo cert; + cert.certificate = GetRawPemString("cert2.crt"); + cert.private_key = GetRawPemString("cert2.key"); + ASSERT_EQ(0, server.RemoveCertificate(cert)); + } + CheckCert("cert2.com", "cert1"); // default cert after remove cert2 + + { + brpc::CertInfo cert; + cert.certificate = GetRawPemString("cert2.crt"); + cert.private_key = GetRawPemString("cert2.key"); + cert.sni_filters.push_back("cert2.com"); + std::vector certs; + certs.push_back(cert); + ASSERT_EQ(0, server.ResetCertificates(certs)); + } + CheckCert("cert2.com", "cert2"); + + server.Stop(0); + server.Join(); +} + +#endif // SSL_CTRL_SET_TLSEXT_HOSTNAME + +const int BUFSIZE[] = {64, 128, 256, 1024, 4096}; +const int REP = 100000; + +void* ssl_perf_client(void* arg) { + SSL* ssl = (SSL*)arg; + EXPECT_EQ(1, SSL_do_handshake(ssl)); + + char buf[4096]; + butil::Timer tm; + for (size_t i = 0; i < ARRAY_SIZE(BUFSIZE); ++i) { + int size = BUFSIZE[i]; + tm.start(); + for (int j = 0; j < REP; ++j) { + SSL_write(ssl, buf, size); + } + tm.stop(); + LOG(INFO) << "SSL_write(" << size << ") tp=" + << size * REP / tm.u_elapsed() << "M/s" + << ", latency=" << tm.u_elapsed() / REP << "us"; + } + return NULL; +} + +void* ssl_perf_server(void* arg) { + SSL* ssl = (SSL*)arg; + EXPECT_EQ(1, SSL_do_handshake(ssl)); + char buf[4096]; + for (size_t i = 0; i < ARRAY_SIZE(BUFSIZE); ++i) { + int size = BUFSIZE[i]; + for (int j = 0; j < REP; ++j) { + SSL_read(ssl, buf, size); + } + } + return NULL; +} + +TEST_F(SSLTest, ssl_perf) { + const butil::EndPoint ep(butil::IP_ANY, 5961); + butil::fd_guard listenfd(butil::tcp_listen(ep, false)); + ASSERT_GT(listenfd, 0); + int clifd = tcp_connect(ep, NULL); + ASSERT_GT(clifd, 0); + int servfd = accept(listenfd, NULL, NULL); + ASSERT_GT(servfd, 0); + + brpc::ChannelSSLOptions opt; + opt.enable = true; + SSL_CTX* cli_ctx = brpc::CreateClientSSLContext(opt); + SSL_CTX* serv_ctx = + brpc::CreateServerSSLContext("cert1.crt", "cert1.key", + brpc::SSLOptions(), NULL); + SSL* cli_ssl = brpc::CreateSSLSession(cli_ctx, 0, clifd, false); + SSL* serv_ssl = brpc::CreateSSLSession(serv_ctx, 0, servfd, true); + pthread_t cpid; + pthread_t spid; + ASSERT_EQ(0, pthread_create(&cpid, NULL, ssl_perf_client, cli_ssl)); + ASSERT_EQ(0, pthread_create(&spid, NULL, ssl_perf_server , serv_ssl)); + ASSERT_EQ(0, pthread_join(cpid, NULL)); + ASSERT_EQ(0, pthread_join(spid, NULL)); +}