diff --git a/src/brpc/authenticator.h b/src/brpc/authenticator.h index 501e427e7f..e08a3f1424 100644 --- a/src/brpc/authenticator.h +++ b/src/brpc/authenticator.h @@ -57,7 +57,7 @@ class AuthContext { class Authenticator { public: - virtual ~Authenticator() {} + virtual ~Authenticator() = default; // Implement this method to generate credential information // into `auth_str' which will be sent to `VerifyCredential' @@ -74,6 +74,13 @@ class Authenticator { const butil::EndPoint& client_addr, AuthContext* out_ctx) const = 0; + // Implement this method to unauthorized error text which + // will be sent as a part of error text in baidu_std/hulu_pbrpc + // protocol or body in http protocol. + virtual std::string GetUnauthorizedErrorText() const { + return ""; + } + }; inline std::ostream& operator<<(std::ostream& os, const AuthContext& ctx) { diff --git a/src/brpc/policy/baidu_rpc_protocol.cpp b/src/brpc/policy/baidu_rpc_protocol.cpp index 0fb439a82d..504895b185 100644 --- a/src/brpc/policy/baidu_rpc_protocol.cpp +++ b/src/brpc/policy/baidu_rpc_protocol.cpp @@ -659,8 +659,8 @@ bool VerifyRpcRequest(const InputMessageBase* msg_base) { const Server* server = static_cast(msg->arg()); Socket* socket = msg->socket(); - RpcMeta meta; - if (!ParsePbFromIOBuf(&meta, msg->meta)) { + RpcMeta request_meta; + if (!ParsePbFromIOBuf(&request_meta, msg->meta)) { LOG(WARNING) << "Fail to parse RpcRequestMeta"; return false; } @@ -668,13 +668,32 @@ bool VerifyRpcRequest(const InputMessageBase* msg_base) { if (NULL == auth) { // Fast pass (no authentication) return true; - } - if (auth->VerifyCredential( - meta.authentication_data(), socket->remote_side(), - socket->mutable_auth_context()) != 0) { - return false; } - return true; + if (auth->VerifyCredential(request_meta.authentication_data(), + socket->remote_side(), + socket->mutable_auth_context()) == 0) { + return true; + } + + // Send `ERPCAUTH' to client. + RpcMeta response_meta; + response_meta.set_correlation_id(request_meta.correlation_id()); + response_meta.mutable_response()->set_error_code(ERPCAUTH); + response_meta.mutable_response()->set_error_text("Fail to authenticate"); + std::string user_error_text = auth->GetUnauthorizedErrorText(); + if (!user_error_text.empty()) { + response_meta.mutable_response()->mutable_error_text()->append(": "); + response_meta.mutable_response()->mutable_error_text()->append(user_error_text); + } + butil::IOBuf res_buf; + SerializeRpcHeaderAndMeta(&res_buf, response_meta, 0); + Socket::WriteOptions opt; + opt.ignore_eovercrowded = true; + if (socket->Write(&res_buf, &opt) != 0) { + PLOG_IF(WARNING, errno != EPIPE) << "Fail to write into " << *socket; + } + + return false; } void ProcessRpcResponse(InputMessageBase* msg_base) { diff --git a/src/brpc/policy/http_rpc_protocol.cpp b/src/brpc/policy/http_rpc_protocol.cpp index 52fff2c9cf..37ea8002dd 100644 --- a/src/brpc/policy/http_rpc_protocol.cpp +++ b/src/brpc/policy/http_rpc_protocol.cpp @@ -1253,6 +1253,26 @@ ParseResult ParseHttpMessage(butil::IOBuf *source, Socket *socket, } } +static void SendUnauthorizedResponse(const std::string& user_error_text, Socket* socket) { + // Send 403(forbidden) to client. + HttpHeader header; + header.set_status_code(HTTP_STATUS_FORBIDDEN); + butil::IOBuf content; + content.append(butil::string_printf("[%d]", ERPCAUTH)); + content.append("Fail to authenticate"); + if (!user_error_text.empty()) { + content.append(": "); + content.append(user_error_text); + } + butil::IOBuf res_buf; + MakeRawHttpResponse(&res_buf, &header, &content); + Socket::WriteOptions opt; + opt.ignore_eovercrowded = true; + if (socket->Write(&res_buf, &opt) != 0) { + PLOG_IF(WARNING, errno != EPIPE) << "Fail to write into " << *socket; + } +} + bool VerifyHttpRequest(const InputMessageBase* msg) { Server* server = (Server*)msg->arg(); Socket* socket = msg->socket(); @@ -1265,8 +1285,7 @@ bool VerifyHttpRequest(const InputMessageBase* msg) { } const Server::MethodProperty* mp = FindMethodPropertyByURI( http_request->header().uri().path(), server, NULL); - if (mp != NULL && - mp->is_builtin_service && + if (mp != NULL && mp->is_builtin_service && mp->service->GetDescriptor() != BadMethodService::descriptor()) { // BuiltinService doesn't need authentication // TODO: Fix backdoor that sends BuiltinService at first @@ -1275,16 +1294,22 @@ bool VerifyHttpRequest(const InputMessageBase* msg) { } const std::string *authorization - = http_request->header().GetHeader("Authorization"); + = http_request->header().GetHeader(common->AUTHORIZATION); if (authorization == NULL) { + SendUnauthorizedResponse(auth->GetUnauthorizedErrorText(), socket); return false; } butil::EndPoint user_addr; if (!GetUserAddressFromHeader(http_request->header(), &user_addr)) { user_addr = socket->remote_side(); } - return auth->VerifyCredential(*authorization, user_addr, - socket->mutable_auth_context()) == 0; + if (auth->VerifyCredential(*authorization, user_addr, + socket->mutable_auth_context()) != 0) { + SendUnauthorizedResponse(auth->GetUnauthorizedErrorText(), socket); + return false; + } + + return true; } diff --git a/src/brpc/policy/hulu_pbrpc_protocol.cpp b/src/brpc/policy/hulu_pbrpc_protocol.cpp index 2b63189eac..cb10aac35b 100644 --- a/src/brpc/policy/hulu_pbrpc_protocol.cpp +++ b/src/brpc/policy/hulu_pbrpc_protocol.cpp @@ -540,8 +540,8 @@ bool VerifyHuluRequest(const InputMessageBase* msg_base) { Socket* socket = msg->socket(); const Server* server = static_cast(msg->arg()); - HuluRpcRequestMeta meta; - if (!ParsePbFromIOBuf(&meta, msg->meta)) { + HuluRpcRequestMeta request_meta; + if (!ParsePbFromIOBuf(&request_meta, msg->meta)) { LOG(WARNING) << "Fail to parse HuluRpcRequestMeta"; return false; } @@ -549,13 +549,32 @@ bool VerifyHuluRequest(const InputMessageBase* msg_base) { if (NULL == auth) { // Fast pass (no authentication) return true; - } - if (auth->VerifyCredential( - meta.credential_data(), socket->remote_side(), - socket->mutable_auth_context()) != 0) { - return false; } - return true; + if (auth->VerifyCredential(request_meta.credential_data(), + socket->remote_side(), + socket->mutable_auth_context()) == 0) { + return true; + } + + // Send `ERPCAUTH' to client. + HuluRpcResponseMeta response_meta; + response_meta.set_correlation_id(request_meta.correlation_id()); + response_meta.set_error_code(ERPCAUTH); + std::string user_error_text = auth->GetUnauthorizedErrorText(); + response_meta.set_error_text("Fail to authenticate"); + if (!user_error_text.empty()) { + response_meta.mutable_error_text()->append(": "); + response_meta.mutable_error_text()->append(user_error_text); + } + butil::IOBuf res_buf; + SerializeHuluHeaderAndMeta(&res_buf, request_meta, 0); + Socket::WriteOptions opt; + opt.ignore_eovercrowded = true; + if (socket->Write(&res_buf, &opt) != 0) { + PLOG_IF(WARNING, errno != EPIPE) << "Fail to write into " << *socket; + } + + return false; } void ProcessHuluResponse(InputMessageBase* msg_base) { diff --git a/test/brpc_server_unittest.cpp b/test/brpc_server_unittest.cpp index 5f06887a52..b2e0bbaa48 100644 --- a/test/brpc_server_unittest.cpp +++ b/test/brpc_server_unittest.cpp @@ -80,18 +80,25 @@ void* RunClosure(void* arg) { return NULL; } +bool g_verify_success = true; +const std::string g_unauthorized_error_text = "unauthorized"; + class MyAuthenticator : public brpc::Authenticator { public: - MyAuthenticator() {} - virtual ~MyAuthenticator() {} - int GenerateCredential(std::string*) const { + MyAuthenticator() = default; + ~MyAuthenticator() override = default; + int GenerateCredential(std::string*) const override { return 0; } int VerifyCredential(const std::string&, const butil::EndPoint&, - brpc::AuthContext*) const { - return 0; + brpc::AuthContext*) const override { + return g_verify_success ? 0 : -1; + } + + std::string GetUnauthorizedErrorText() const override { + return g_unauthorized_error_text; } }; @@ -1828,4 +1835,73 @@ TEST_F(ServerTest, rpc_pb_message_factory) { ASSERT_EQ(0, server.Join()); } +void TestBaiduStdAuth(const butil::EndPoint& ep, + brpc::Controller& cntl, + int error_code, bool failed) { + brpc::Channel chan; + brpc::ChannelOptions copt; + copt.max_retry = 0; + copt.protocol = "baidu_std"; + ASSERT_EQ(0, chan.Init(ep, &copt)); + + test::EchoRequest req; + test::EchoResponse res; + req.set_message(EXP_REQUEST); + test::EchoService_Stub stub(&chan); + stub.Echo(&cntl, &req, &res, NULL); + ASSERT_EQ(cntl.Failed(), failed) << cntl.ErrorText(); + ASSERT_EQ(cntl.ErrorCode(), error_code); +} + +void TestHttpAuth(const butil::EndPoint& ep, + brpc::Controller& cntl, + int status_code, bool failed) { + brpc::Channel chan; + brpc::ChannelOptions copt; + copt.max_retry = 0; + copt.protocol = "http"; + ASSERT_EQ(0, chan.Init(ep, &copt)); + + cntl.http_request().uri() = "/EchoService/Echo"; + cntl.request_attachment().append(R"({"message": "hello"})"); + cntl.http_request().set_method(brpc::HTTP_METHOD_POST); + test::EchoService_Stub stub(&chan); + chan.CallMethod(NULL, &cntl, NULL, NULL, NULL); + ASSERT_EQ(cntl.Failed(), failed) << cntl.ErrorText(); + ASSERT_EQ(cntl.http_response().status_code(), status_code); +} + +TEST_F(ServerTest, auth) { + butil::EndPoint ep; + ASSERT_EQ(0, str2endpoint("127.0.0.1:8613", &ep)); + brpc::Server server; + EchoServiceImpl service; + ASSERT_EQ(0, server.AddService(&service, brpc::SERVER_DOESNT_OWN_SERVICE)); + MyAuthenticator auth; + brpc::ServerOptions opt; + opt.auth = &auth; + ASSERT_EQ(0, server.Start(ep, &opt)); + + brpc::Controller cntl; + TestBaiduStdAuth(ep, cntl, 0, false); + + g_verify_success = false; + cntl.Reset(); + TestBaiduStdAuth(ep, cntl, brpc::ERPCAUTH, true); + ASSERT_NE(cntl.ErrorText().find(g_unauthorized_error_text), std::string::npos); + + cntl.Reset(); + TestHttpAuth(ep, cntl, brpc::HTTP_STATUS_FORBIDDEN, true); + ASSERT_NE(cntl.response_attachment().to_string().find(g_unauthorized_error_text), + std::string::npos); + + g_verify_success = true; + cntl.Reset(); + cntl.http_request().SetHeader("Authorization", "123"); + TestHttpAuth(ep, cntl, brpc::HTTP_STATUS_OK, false); + + ASSERT_EQ(0, server.Stop(0)); + ASSERT_EQ(0, server.Join()); +} + } //namespace