From e0692e6a87ac0b43174e0a3066a264a41312dc0a Mon Sep 17 00:00:00 2001 From: Bright Chen Date: Sat, 24 Aug 2024 16:40:36 +0800 Subject: [PATCH 1/2] Send unauthorized response to client when authentication fails --- src/brpc/authenticator.h | 11 ++- src/brpc/policy/baidu_rpc_protocol.cpp | 36 +++++++-- src/brpc/policy/http_rpc_protocol.cpp | 37 ++++++++- src/brpc/policy/http_rpc_protocol.h | 1 + src/brpc/policy/hulu_pbrpc_protocol.cpp | 31 ++++++-- test/brpc_server_unittest.cpp | 100 ++++++++++++++++++++++-- 6 files changed, 197 insertions(+), 19 deletions(-) diff --git a/src/brpc/authenticator.h b/src/brpc/authenticator.h index 501e427e7f..a778191b06 100644 --- a/src/brpc/authenticator.h +++ b/src/brpc/authenticator.h @@ -57,7 +57,7 @@ class AuthContext { class Authenticator { public: - virtual ~Authenticator() {} + virtual ~Authenticator() = default; // Implement this method to generate credential information // into `auth_str' which will be sent to `VerifyCredential' @@ -74,6 +74,15 @@ class Authenticator { const butil::EndPoint& client_addr, AuthContext* out_ctx) const = 0; + // Implement this method to decide whether to send a response + // to the client when authentication fails. + // Returns true to indicate a response needs to be sent, + // otherwise no response is needed. + virtual bool GetUnauthorizedResponseInfo(std::string& response_str) const { + (void)response_str; + return false; + } + }; inline std::ostream& operator<<(std::ostream& os, const AuthContext& ctx) { diff --git a/src/brpc/policy/baidu_rpc_protocol.cpp b/src/brpc/policy/baidu_rpc_protocol.cpp index 0fb439a82d..7b0125b790 100644 --- a/src/brpc/policy/baidu_rpc_protocol.cpp +++ b/src/brpc/policy/baidu_rpc_protocol.cpp @@ -668,13 +668,39 @@ bool VerifyRpcRequest(const InputMessageBase* msg_base) { if (NULL == auth) { // Fast pass (no authentication) return true; - } - if (auth->VerifyCredential( - meta.authentication_data(), socket->remote_side(), - socket->mutable_auth_context()) != 0) { + } + + bool success = auth->VerifyCredential(meta.authentication_data(), + socket->remote_side(), + socket->mutable_auth_context()) == 0; + if (success) { + return true; + } + + // Send `ERPCAUTH' to client. + std::string res_info; + if (!auth->GetUnauthorizedResponseInfo(res_info)) { return false; } - return true; + + RpcMeta temp_meta; + temp_meta.set_correlation_id(meta.correlation_id()); + RpcResponseMeta* response_meta = temp_meta.mutable_response(); + response_meta->set_error_code(ERPCAUTH); + std::string error_text = res_info.empty() ? "Fail to authenticate" : + butil::string_printf("Fail to authenticate, %s", res_info.c_str()); + response_meta->set_error_text(error_text); + + butil::IOBuf res_buf; + SerializeRpcHeaderAndMeta(&res_buf, temp_meta, 0); + + Socket::WriteOptions opt; + opt.ignore_eovercrowded = true; + if (socket->Write(&res_buf, &opt) != 0) { + PLOG_IF(WARNING, errno != EPIPE) << "Fail to write into " << *socket; + } + + return false; } void ProcessRpcResponse(InputMessageBase* msg_base) { diff --git a/src/brpc/policy/http_rpc_protocol.cpp b/src/brpc/policy/http_rpc_protocol.cpp index 52fff2c9cf..86ec882b4f 100644 --- a/src/brpc/policy/http_rpc_protocol.cpp +++ b/src/brpc/policy/http_rpc_protocol.cpp @@ -125,6 +125,7 @@ CommonStrings::CommonStrings() , CONTENT_TYPE_SPRING_PROTO("application/x-protobuf") , ERROR_CODE("x-bd-error-code") , AUTHORIZATION("authorization") + , WWW_AUTHENTICATE("www-authenticate") , ACCEPT_ENCODING("accept-encoding") , CONTENT_ENCODING("content-encoding") , CONTENT_LENGTH("content_length") @@ -1253,6 +1254,30 @@ ParseResult ParseHttpMessage(butil::IOBuf *source, Socket *socket, } } +static void SendUnauthorizedResponseIfNeed(const Authenticator* auth, Socket* socket) { + std::string www_authenticate; + if (!auth->GetUnauthorizedResponseInfo(www_authenticate)) { + return; + } + + // Send 401(unauthorized) and `ERPCAUTH' to client. + butil::IOBuf res_buf; + HttpHeader header; + header.set_status_code(HTTP_STATUS_UNAUTHORIZED); + // RFC7235 https://datatracker.ietf.org/doc/html/rfc7235#section-4.1 + // The server generating a 401 response MUST send a WWW-Authenticate + // header field (Section 4.1) containing at least one challenge + // applicable to the target resource. + header.SetHeader(common->ERROR_CODE, butil::string_printf("%d", ERPCAUTH)); + header.SetHeader(common->WWW_AUTHENTICATE, www_authenticate); + MakeRawHttpResponse(&res_buf, &header, NULL); + Socket::WriteOptions opt; + opt.ignore_eovercrowded = true; + if (socket->Write(&res_buf, &opt) != 0) { + PLOG_IF(WARNING, errno != EPIPE) << "Fail to write into " << *socket; + } +} + bool VerifyHttpRequest(const InputMessageBase* msg) { Server* server = (Server*)msg->arg(); Socket* socket = msg->socket(); @@ -1275,16 +1300,22 @@ bool VerifyHttpRequest(const InputMessageBase* msg) { } const std::string *authorization - = http_request->header().GetHeader("Authorization"); + = http_request->header().GetHeader(common->AUTHORIZATION); if (authorization == NULL) { + SendUnauthorizedResponseIfNeed(auth, socket); return false; } butil::EndPoint user_addr; if (!GetUserAddressFromHeader(http_request->header(), &user_addr)) { user_addr = socket->remote_side(); } - return auth->VerifyCredential(*authorization, user_addr, - socket->mutable_auth_context()) == 0; + if (auth->VerifyCredential(*authorization, user_addr, + socket->mutable_auth_context()) != 0) { + SendUnauthorizedResponseIfNeed(auth, socket); + return false; + } + + return true; } diff --git a/src/brpc/policy/http_rpc_protocol.h b/src/brpc/policy/http_rpc_protocol.h index 918e69d0fa..c636ebeca8 100644 --- a/src/brpc/policy/http_rpc_protocol.h +++ b/src/brpc/policy/http_rpc_protocol.h @@ -40,6 +40,7 @@ struct CommonStrings { std::string CONTENT_TYPE_SPRING_PROTO; std::string ERROR_CODE; std::string AUTHORIZATION; + std::string WWW_AUTHENTICATE; std::string ACCEPT_ENCODING; std::string CONTENT_ENCODING; std::string CONTENT_LENGTH; diff --git a/src/brpc/policy/hulu_pbrpc_protocol.cpp b/src/brpc/policy/hulu_pbrpc_protocol.cpp index 2b63189eac..5b456e542e 100644 --- a/src/brpc/policy/hulu_pbrpc_protocol.cpp +++ b/src/brpc/policy/hulu_pbrpc_protocol.cpp @@ -549,12 +549,33 @@ bool VerifyHuluRequest(const InputMessageBase* msg_base) { if (NULL == auth) { // Fast pass (no authentication) return true; - } - if (auth->VerifyCredential( - meta.credential_data(), socket->remote_side(), - socket->mutable_auth_context()) != 0) { - return false; } + bool success = auth->VerifyCredential(meta.credential_data(), + socket->remote_side(), + socket->mutable_auth_context()) == 0; + if (!success) { + std::string res_info; + if (!auth->GetUnauthorizedResponseInfo(res_info)) { + return false; + } + + // Send `ERPCAUTH' to client. + HuluRpcResponseMeta temp_meta; + temp_meta.set_correlation_id(meta.correlation_id()); + temp_meta.set_error_code(ERPCAUTH); + std::string error_text = res_info.empty() ? "Fail to authenticate" : + butil::string_printf("Fail to authenticate, %s", res_info.c_str()); + temp_meta.set_error_text(error_text); + + butil::IOBuf res_buf; + SerializeHuluHeaderAndMeta(&res_buf, meta, 0); + Socket::WriteOptions opt; + opt.ignore_eovercrowded = true; + if (socket->Write(&res_buf, &opt) != 0) { + PLOG_IF(WARNING, errno != EPIPE) << "Fail to write into " << *socket; + } + } + return true; } diff --git a/test/brpc_server_unittest.cpp b/test/brpc_server_unittest.cpp index 5f06887a52..672bc02c4e 100644 --- a/test/brpc_server_unittest.cpp +++ b/test/brpc_server_unittest.cpp @@ -80,18 +80,29 @@ void* RunClosure(void* arg) { return NULL; } +bool g_verify_success = true; +bool g_unauthorized_response = false; +const std::string g_unauthorized_response_info = "Basic"; + class MyAuthenticator : public brpc::Authenticator { public: - MyAuthenticator() {} - virtual ~MyAuthenticator() {} - int GenerateCredential(std::string*) const { + MyAuthenticator() = default; + ~MyAuthenticator() = default; + int GenerateCredential(std::string*) const override { return 0; } int VerifyCredential(const std::string&, const butil::EndPoint&, - brpc::AuthContext*) const { - return 0; + brpc::AuthContext*) const override { + return g_verify_success ? 0 : -1; + } + + bool GetUnauthorizedResponseInfo(std::string& response_str) const override { + if (g_unauthorized_response) { + response_str = "Basic"; + } + return g_unauthorized_response; } }; @@ -1828,4 +1839,83 @@ TEST_F(ServerTest, rpc_pb_message_factory) { ASSERT_EQ(0, server.Join()); } +void TestBaiduStdAuth(const butil::EndPoint& ep, + brpc::Controller& cntl, + int error_code, bool failed) { + brpc::Channel chan; + brpc::ChannelOptions copt; + copt.max_retry = 0; + copt.protocol = "baidu_std"; + ASSERT_EQ(0, chan.Init(ep, &copt)); + + test::EchoRequest req; + test::EchoResponse res; + req.set_message(EXP_REQUEST); + test::EchoService_Stub stub(&chan); + stub.Echo(&cntl, &req, &res, NULL); + ASSERT_EQ(cntl.Failed(), failed) << cntl.ErrorText(); + ASSERT_EQ(cntl.ErrorCode(), error_code); +} + +void TestHttpAuth(const butil::EndPoint& ep, + brpc::Controller& cntl, + int error_code, bool failed) { + brpc::Channel chan; + brpc::ChannelOptions copt; + copt.max_retry = 0; + copt.protocol = "http"; + ASSERT_EQ(0, chan.Init(ep, &copt)); + + cntl.http_request().uri() = "/EchoService/Echo"; + cntl.request_attachment().append(R"({"message": "hello"})"); + cntl.http_request().set_method(brpc::HTTP_METHOD_POST); + test::EchoService_Stub stub(&chan); + chan.CallMethod(NULL, &cntl, NULL, NULL, NULL); + ASSERT_EQ(cntl.Failed(), failed) << cntl.ErrorText(); + ASSERT_EQ(cntl.ErrorCode(), error_code); +} + +TEST_F(ServerTest, auth) { + butil::EndPoint ep; + ASSERT_EQ(0, str2endpoint("127.0.0.1:8613", &ep)); + brpc::Server server; + EchoServiceImpl service; + ASSERT_EQ(0, server.AddService(&service, brpc::SERVER_DOESNT_OWN_SERVICE)); + MyAuthenticator auth; + brpc::ServerOptions opt; + opt.auth = &auth; + ASSERT_EQ(0, server.Start(ep, &opt)); + + brpc::Controller cntl; + TestBaiduStdAuth(ep, cntl, 0, false); + + g_verify_success = false; + cntl.Reset(); + TestBaiduStdAuth(ep, cntl, brpc::EEOF, true); + + g_unauthorized_response = true; + cntl.Reset(); + TestBaiduStdAuth(ep, cntl, brpc::ERPCAUTH, true); + ASSERT_NE(cntl.ErrorText().find(g_unauthorized_response_info), std::string::npos); + + brpc::policy::FLAGS_use_http_error_code = true; + cntl.Reset(); + TestHttpAuth(ep, cntl, brpc::ERPCAUTH, true); + const std::string* www_authenticate = cntl.http_response().GetHeader("WWW-Authenticate"); + ASSERT_NE(nullptr, www_authenticate); + ASSERT_EQ(*www_authenticate, g_unauthorized_response_info); + + g_unauthorized_response = false; + cntl.Reset(); + TestHttpAuth(ep, cntl, brpc::EEOF, true); + + g_verify_success = true; + cntl.Reset(); + cntl.http_request().SetHeader("Authorization", "123"); + TestHttpAuth(ep, cntl, 0, false); + + ASSERT_EQ(0, server.Stop(0)); + ASSERT_EQ(0, server.Join()); +} + } //namespace From adf1e6dcaee8a825b1aae137ad071c9e67d32cdf Mon Sep 17 00:00:00 2001 From: Bright Chen Date: Tue, 27 Aug 2024 22:25:19 +0800 Subject: [PATCH 2/2] Send unauthorized response by default --- src/brpc/authenticator.h | 12 +++--- src/brpc/policy/baidu_rpc_protocol.cpp | 35 +++++++---------- src/brpc/policy/http_rpc_protocol.cpp | 34 +++++++---------- src/brpc/policy/http_rpc_protocol.h | 1 - src/brpc/policy/hulu_pbrpc_protocol.cpp | 50 ++++++++++++------------- test/brpc_server_unittest.cpp | 36 ++++++------------ 6 files changed, 68 insertions(+), 100 deletions(-) diff --git a/src/brpc/authenticator.h b/src/brpc/authenticator.h index a778191b06..e08a3f1424 100644 --- a/src/brpc/authenticator.h +++ b/src/brpc/authenticator.h @@ -74,13 +74,11 @@ class Authenticator { const butil::EndPoint& client_addr, AuthContext* out_ctx) const = 0; - // Implement this method to decide whether to send a response - // to the client when authentication fails. - // Returns true to indicate a response needs to be sent, - // otherwise no response is needed. - virtual bool GetUnauthorizedResponseInfo(std::string& response_str) const { - (void)response_str; - return false; + // Implement this method to unauthorized error text which + // will be sent as a part of error text in baidu_std/hulu_pbrpc + // protocol or body in http protocol. + virtual std::string GetUnauthorizedErrorText() const { + return ""; } }; diff --git a/src/brpc/policy/baidu_rpc_protocol.cpp b/src/brpc/policy/baidu_rpc_protocol.cpp index 7b0125b790..504895b185 100644 --- a/src/brpc/policy/baidu_rpc_protocol.cpp +++ b/src/brpc/policy/baidu_rpc_protocol.cpp @@ -659,8 +659,8 @@ bool VerifyRpcRequest(const InputMessageBase* msg_base) { const Server* server = static_cast(msg->arg()); Socket* socket = msg->socket(); - RpcMeta meta; - if (!ParsePbFromIOBuf(&meta, msg->meta)) { + RpcMeta request_meta; + if (!ParsePbFromIOBuf(&request_meta, msg->meta)) { LOG(WARNING) << "Fail to parse RpcRequestMeta"; return false; } @@ -669,31 +669,24 @@ bool VerifyRpcRequest(const InputMessageBase* msg_base) { // Fast pass (no authentication) return true; } - - bool success = auth->VerifyCredential(meta.authentication_data(), - socket->remote_side(), - socket->mutable_auth_context()) == 0; - if (success) { + if (auth->VerifyCredential(request_meta.authentication_data(), + socket->remote_side(), + socket->mutable_auth_context()) == 0) { return true; } // Send `ERPCAUTH' to client. - std::string res_info; - if (!auth->GetUnauthorizedResponseInfo(res_info)) { - return false; + RpcMeta response_meta; + response_meta.set_correlation_id(request_meta.correlation_id()); + response_meta.mutable_response()->set_error_code(ERPCAUTH); + response_meta.mutable_response()->set_error_text("Fail to authenticate"); + std::string user_error_text = auth->GetUnauthorizedErrorText(); + if (!user_error_text.empty()) { + response_meta.mutable_response()->mutable_error_text()->append(": "); + response_meta.mutable_response()->mutable_error_text()->append(user_error_text); } - - RpcMeta temp_meta; - temp_meta.set_correlation_id(meta.correlation_id()); - RpcResponseMeta* response_meta = temp_meta.mutable_response(); - response_meta->set_error_code(ERPCAUTH); - std::string error_text = res_info.empty() ? "Fail to authenticate" : - butil::string_printf("Fail to authenticate, %s", res_info.c_str()); - response_meta->set_error_text(error_text); - butil::IOBuf res_buf; - SerializeRpcHeaderAndMeta(&res_buf, temp_meta, 0); - + SerializeRpcHeaderAndMeta(&res_buf, response_meta, 0); Socket::WriteOptions opt; opt.ignore_eovercrowded = true; if (socket->Write(&res_buf, &opt) != 0) { diff --git a/src/brpc/policy/http_rpc_protocol.cpp b/src/brpc/policy/http_rpc_protocol.cpp index 86ec882b4f..37ea8002dd 100644 --- a/src/brpc/policy/http_rpc_protocol.cpp +++ b/src/brpc/policy/http_rpc_protocol.cpp @@ -125,7 +125,6 @@ CommonStrings::CommonStrings() , CONTENT_TYPE_SPRING_PROTO("application/x-protobuf") , ERROR_CODE("x-bd-error-code") , AUTHORIZATION("authorization") - , WWW_AUTHENTICATE("www-authenticate") , ACCEPT_ENCODING("accept-encoding") , CONTENT_ENCODING("content-encoding") , CONTENT_LENGTH("content_length") @@ -1254,23 +1253,19 @@ ParseResult ParseHttpMessage(butil::IOBuf *source, Socket *socket, } } -static void SendUnauthorizedResponseIfNeed(const Authenticator* auth, Socket* socket) { - std::string www_authenticate; - if (!auth->GetUnauthorizedResponseInfo(www_authenticate)) { - return; +static void SendUnauthorizedResponse(const std::string& user_error_text, Socket* socket) { + // Send 403(forbidden) to client. + HttpHeader header; + header.set_status_code(HTTP_STATUS_FORBIDDEN); + butil::IOBuf content; + content.append(butil::string_printf("[%d]", ERPCAUTH)); + content.append("Fail to authenticate"); + if (!user_error_text.empty()) { + content.append(": "); + content.append(user_error_text); } - - // Send 401(unauthorized) and `ERPCAUTH' to client. butil::IOBuf res_buf; - HttpHeader header; - header.set_status_code(HTTP_STATUS_UNAUTHORIZED); - // RFC7235 https://datatracker.ietf.org/doc/html/rfc7235#section-4.1 - // The server generating a 401 response MUST send a WWW-Authenticate - // header field (Section 4.1) containing at least one challenge - // applicable to the target resource. - header.SetHeader(common->ERROR_CODE, butil::string_printf("%d", ERPCAUTH)); - header.SetHeader(common->WWW_AUTHENTICATE, www_authenticate); - MakeRawHttpResponse(&res_buf, &header, NULL); + MakeRawHttpResponse(&res_buf, &header, &content); Socket::WriteOptions opt; opt.ignore_eovercrowded = true; if (socket->Write(&res_buf, &opt) != 0) { @@ -1290,8 +1285,7 @@ bool VerifyHttpRequest(const InputMessageBase* msg) { } const Server::MethodProperty* mp = FindMethodPropertyByURI( http_request->header().uri().path(), server, NULL); - if (mp != NULL && - mp->is_builtin_service && + if (mp != NULL && mp->is_builtin_service && mp->service->GetDescriptor() != BadMethodService::descriptor()) { // BuiltinService doesn't need authentication // TODO: Fix backdoor that sends BuiltinService at first @@ -1302,7 +1296,7 @@ bool VerifyHttpRequest(const InputMessageBase* msg) { const std::string *authorization = http_request->header().GetHeader(common->AUTHORIZATION); if (authorization == NULL) { - SendUnauthorizedResponseIfNeed(auth, socket); + SendUnauthorizedResponse(auth->GetUnauthorizedErrorText(), socket); return false; } butil::EndPoint user_addr; @@ -1311,7 +1305,7 @@ bool VerifyHttpRequest(const InputMessageBase* msg) { } if (auth->VerifyCredential(*authorization, user_addr, socket->mutable_auth_context()) != 0) { - SendUnauthorizedResponseIfNeed(auth, socket); + SendUnauthorizedResponse(auth->GetUnauthorizedErrorText(), socket); return false; } diff --git a/src/brpc/policy/http_rpc_protocol.h b/src/brpc/policy/http_rpc_protocol.h index c636ebeca8..918e69d0fa 100644 --- a/src/brpc/policy/http_rpc_protocol.h +++ b/src/brpc/policy/http_rpc_protocol.h @@ -40,7 +40,6 @@ struct CommonStrings { std::string CONTENT_TYPE_SPRING_PROTO; std::string ERROR_CODE; std::string AUTHORIZATION; - std::string WWW_AUTHENTICATE; std::string ACCEPT_ENCODING; std::string CONTENT_ENCODING; std::string CONTENT_LENGTH; diff --git a/src/brpc/policy/hulu_pbrpc_protocol.cpp b/src/brpc/policy/hulu_pbrpc_protocol.cpp index 5b456e542e..cb10aac35b 100644 --- a/src/brpc/policy/hulu_pbrpc_protocol.cpp +++ b/src/brpc/policy/hulu_pbrpc_protocol.cpp @@ -540,8 +540,8 @@ bool VerifyHuluRequest(const InputMessageBase* msg_base) { Socket* socket = msg->socket(); const Server* server = static_cast(msg->arg()); - HuluRpcRequestMeta meta; - if (!ParsePbFromIOBuf(&meta, msg->meta)) { + HuluRpcRequestMeta request_meta; + if (!ParsePbFromIOBuf(&request_meta, msg->meta)) { LOG(WARNING) << "Fail to parse HuluRpcRequestMeta"; return false; } @@ -550,33 +550,31 @@ bool VerifyHuluRequest(const InputMessageBase* msg_base) { // Fast pass (no authentication) return true; } - bool success = auth->VerifyCredential(meta.credential_data(), - socket->remote_side(), - socket->mutable_auth_context()) == 0; - if (!success) { - std::string res_info; - if (!auth->GetUnauthorizedResponseInfo(res_info)) { - return false; - } - - // Send `ERPCAUTH' to client. - HuluRpcResponseMeta temp_meta; - temp_meta.set_correlation_id(meta.correlation_id()); - temp_meta.set_error_code(ERPCAUTH); - std::string error_text = res_info.empty() ? "Fail to authenticate" : - butil::string_printf("Fail to authenticate, %s", res_info.c_str()); - temp_meta.set_error_text(error_text); + if (auth->VerifyCredential(request_meta.credential_data(), + socket->remote_side(), + socket->mutable_auth_context()) == 0) { + return true; + } - butil::IOBuf res_buf; - SerializeHuluHeaderAndMeta(&res_buf, meta, 0); - Socket::WriteOptions opt; - opt.ignore_eovercrowded = true; - if (socket->Write(&res_buf, &opt) != 0) { - PLOG_IF(WARNING, errno != EPIPE) << "Fail to write into " << *socket; - } + // Send `ERPCAUTH' to client. + HuluRpcResponseMeta response_meta; + response_meta.set_correlation_id(request_meta.correlation_id()); + response_meta.set_error_code(ERPCAUTH); + std::string user_error_text = auth->GetUnauthorizedErrorText(); + response_meta.set_error_text("Fail to authenticate"); + if (!user_error_text.empty()) { + response_meta.mutable_error_text()->append(": "); + response_meta.mutable_error_text()->append(user_error_text); + } + butil::IOBuf res_buf; + SerializeHuluHeaderAndMeta(&res_buf, request_meta, 0); + Socket::WriteOptions opt; + opt.ignore_eovercrowded = true; + if (socket->Write(&res_buf, &opt) != 0) { + PLOG_IF(WARNING, errno != EPIPE) << "Fail to write into " << *socket; } - return true; + return false; } void ProcessHuluResponse(InputMessageBase* msg_base) { diff --git a/test/brpc_server_unittest.cpp b/test/brpc_server_unittest.cpp index 672bc02c4e..b2e0bbaa48 100644 --- a/test/brpc_server_unittest.cpp +++ b/test/brpc_server_unittest.cpp @@ -81,13 +81,12 @@ void* RunClosure(void* arg) { } bool g_verify_success = true; -bool g_unauthorized_response = false; -const std::string g_unauthorized_response_info = "Basic"; +const std::string g_unauthorized_error_text = "unauthorized"; class MyAuthenticator : public brpc::Authenticator { public: MyAuthenticator() = default; - ~MyAuthenticator() = default; + ~MyAuthenticator() override = default; int GenerateCredential(std::string*) const override { return 0; } @@ -98,11 +97,8 @@ class MyAuthenticator : public brpc::Authenticator { return g_verify_success ? 0 : -1; } - bool GetUnauthorizedResponseInfo(std::string& response_str) const override { - if (g_unauthorized_response) { - response_str = "Basic"; - } - return g_unauthorized_response; + std::string GetUnauthorizedErrorText() const override { + return g_unauthorized_error_text; } }; @@ -1859,7 +1855,7 @@ void TestBaiduStdAuth(const butil::EndPoint& ep, void TestHttpAuth(const butil::EndPoint& ep, brpc::Controller& cntl, - int error_code, bool failed) { + int status_code, bool failed) { brpc::Channel chan; brpc::ChannelOptions copt; copt.max_retry = 0; @@ -1872,7 +1868,7 @@ void TestHttpAuth(const butil::EndPoint& ep, test::EchoService_Stub stub(&chan); chan.CallMethod(NULL, &cntl, NULL, NULL, NULL); ASSERT_EQ(cntl.Failed(), failed) << cntl.ErrorText(); - ASSERT_EQ(cntl.ErrorCode(), error_code); + ASSERT_EQ(cntl.http_response().status_code(), status_code); } TEST_F(ServerTest, auth) { @@ -1891,28 +1887,18 @@ TEST_F(ServerTest, auth) { g_verify_success = false; cntl.Reset(); - TestBaiduStdAuth(ep, cntl, brpc::EEOF, true); - - g_unauthorized_response = true; - cntl.Reset(); TestBaiduStdAuth(ep, cntl, brpc::ERPCAUTH, true); - ASSERT_NE(cntl.ErrorText().find(g_unauthorized_response_info), std::string::npos); - - brpc::policy::FLAGS_use_http_error_code = true; - cntl.Reset(); - TestHttpAuth(ep, cntl, brpc::ERPCAUTH, true); - const std::string* www_authenticate = cntl.http_response().GetHeader("WWW-Authenticate"); - ASSERT_NE(nullptr, www_authenticate); - ASSERT_EQ(*www_authenticate, g_unauthorized_response_info); + ASSERT_NE(cntl.ErrorText().find(g_unauthorized_error_text), std::string::npos); - g_unauthorized_response = false; cntl.Reset(); - TestHttpAuth(ep, cntl, brpc::EEOF, true); + TestHttpAuth(ep, cntl, brpc::HTTP_STATUS_FORBIDDEN, true); + ASSERT_NE(cntl.response_attachment().to_string().find(g_unauthorized_error_text), + std::string::npos); g_verify_success = true; cntl.Reset(); cntl.http_request().SetHeader("Authorization", "123"); - TestHttpAuth(ep, cntl, 0, false); + TestHttpAuth(ep, cntl, brpc::HTTP_STATUS_OK, false); ASSERT_EQ(0, server.Stop(0)); ASSERT_EQ(0, server.Join());