Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion src/brpc/authenticator.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ class AuthContext {

class Authenticator {
public:
virtual ~Authenticator() {}
virtual ~Authenticator() = default;

// Implement this method to generate credential information
// into `auth_str' which will be sent to `VerifyCredential'
Expand All @@ -74,6 +74,13 @@ class Authenticator {
const butil::EndPoint& client_addr,
AuthContext* out_ctx) const = 0;

// Implement this method to unauthorized error text which
// will be sent as a part of error text in baidu_std/hulu_pbrpc
// protocol or body in http protocol.
virtual std::string GetUnauthorizedErrorText() const {
return "";
}

};

inline std::ostream& operator<<(std::ostream& os, const AuthContext& ctx) {
Expand Down
35 changes: 27 additions & 8 deletions src/brpc/policy/baidu_rpc_protocol.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -659,22 +659,41 @@ bool VerifyRpcRequest(const InputMessageBase* msg_base) {
const Server* server = static_cast<const Server*>(msg->arg());
Socket* socket = msg->socket();

RpcMeta meta;
if (!ParsePbFromIOBuf(&meta, msg->meta)) {
RpcMeta request_meta;
if (!ParsePbFromIOBuf(&request_meta, msg->meta)) {
LOG(WARNING) << "Fail to parse RpcRequestMeta";
return false;
}
const Authenticator* auth = server->options().auth;
if (NULL == auth) {
// Fast pass (no authentication)
return true;
}
if (auth->VerifyCredential(
meta.authentication_data(), socket->remote_side(),
socket->mutable_auth_context()) != 0) {
return false;
}
return true;
if (auth->VerifyCredential(request_meta.authentication_data(),
socket->remote_side(),
socket->mutable_auth_context()) == 0) {
return true;
}

// Send `ERPCAUTH' to client.
RpcMeta response_meta;
response_meta.set_correlation_id(request_meta.correlation_id());
response_meta.mutable_response()->set_error_code(ERPCAUTH);
response_meta.mutable_response()->set_error_text("Fail to authenticate");
std::string user_error_text = auth->GetUnauthorizedErrorText();
if (!user_error_text.empty()) {
response_meta.mutable_response()->mutable_error_text()->append(": ");
response_meta.mutable_response()->mutable_error_text()->append(user_error_text);
}
butil::IOBuf res_buf;
SerializeRpcHeaderAndMeta(&res_buf, response_meta, 0);
Socket::WriteOptions opt;
opt.ignore_eovercrowded = true;
if (socket->Write(&res_buf, &opt) != 0) {
PLOG_IF(WARNING, errno != EPIPE) << "Fail to write into " << *socket;
}

return false;
}

void ProcessRpcResponse(InputMessageBase* msg_base) {
Expand Down
35 changes: 30 additions & 5 deletions src/brpc/policy/http_rpc_protocol.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1253,6 +1253,26 @@ ParseResult ParseHttpMessage(butil::IOBuf *source, Socket *socket,
}
}

static void SendUnauthorizedResponse(const std::string& user_error_text, Socket* socket) {
// Send 403(forbidden) to client.
HttpHeader header;
header.set_status_code(HTTP_STATUS_FORBIDDEN);
butil::IOBuf content;
content.append(butil::string_printf("[%d]", ERPCAUTH));
content.append("Fail to authenticate");
if (!user_error_text.empty()) {
content.append(": ");
content.append(user_error_text);
}
butil::IOBuf res_buf;
MakeRawHttpResponse(&res_buf, &header, &content);
Socket::WriteOptions opt;
opt.ignore_eovercrowded = true;
if (socket->Write(&res_buf, &opt) != 0) {
PLOG_IF(WARNING, errno != EPIPE) << "Fail to write into " << *socket;
}
}

bool VerifyHttpRequest(const InputMessageBase* msg) {
Server* server = (Server*)msg->arg();
Socket* socket = msg->socket();
Expand All @@ -1265,8 +1285,7 @@ bool VerifyHttpRequest(const InputMessageBase* msg) {
}
const Server::MethodProperty* mp = FindMethodPropertyByURI(
http_request->header().uri().path(), server, NULL);
if (mp != NULL &&
mp->is_builtin_service &&
if (mp != NULL && mp->is_builtin_service &&
mp->service->GetDescriptor() != BadMethodService::descriptor()) {
// BuiltinService doesn't need authentication
// TODO: Fix backdoor that sends BuiltinService at first
Expand All @@ -1275,16 +1294,22 @@ bool VerifyHttpRequest(const InputMessageBase* msg) {
}

const std::string *authorization
= http_request->header().GetHeader("Authorization");
= http_request->header().GetHeader(common->AUTHORIZATION);
if (authorization == NULL) {
SendUnauthorizedResponse(auth->GetUnauthorizedErrorText(), socket);
return false;
}
butil::EndPoint user_addr;
if (!GetUserAddressFromHeader(http_request->header(), &user_addr)) {
user_addr = socket->remote_side();
}
return auth->VerifyCredential(*authorization, user_addr,
socket->mutable_auth_context()) == 0;
if (auth->VerifyCredential(*authorization, user_addr,
socket->mutable_auth_context()) != 0) {
SendUnauthorizedResponse(auth->GetUnauthorizedErrorText(), socket);
return false;
}

return true;
}


Expand Down
35 changes: 27 additions & 8 deletions src/brpc/policy/hulu_pbrpc_protocol.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -540,22 +540,41 @@ bool VerifyHuluRequest(const InputMessageBase* msg_base) {
Socket* socket = msg->socket();
const Server* server = static_cast<const Server*>(msg->arg());

HuluRpcRequestMeta meta;
if (!ParsePbFromIOBuf(&meta, msg->meta)) {
HuluRpcRequestMeta request_meta;
if (!ParsePbFromIOBuf(&request_meta, msg->meta)) {
LOG(WARNING) << "Fail to parse HuluRpcRequestMeta";
return false;
}
const Authenticator* auth = server->options().auth;
if (NULL == auth) {
// Fast pass (no authentication)
return true;
}
if (auth->VerifyCredential(
meta.credential_data(), socket->remote_side(),
socket->mutable_auth_context()) != 0) {
return false;
}
return true;
if (auth->VerifyCredential(request_meta.credential_data(),
socket->remote_side(),
socket->mutable_auth_context()) == 0) {
return true;
}

// Send `ERPCAUTH' to client.
HuluRpcResponseMeta response_meta;
response_meta.set_correlation_id(request_meta.correlation_id());
response_meta.set_error_code(ERPCAUTH);
std::string user_error_text = auth->GetUnauthorizedErrorText();
response_meta.set_error_text("Fail to authenticate");
if (!user_error_text.empty()) {
response_meta.mutable_error_text()->append(": ");
response_meta.mutable_error_text()->append(user_error_text);
}
butil::IOBuf res_buf;
SerializeHuluHeaderAndMeta(&res_buf, request_meta, 0);
Socket::WriteOptions opt;
opt.ignore_eovercrowded = true;
if (socket->Write(&res_buf, &opt) != 0) {
PLOG_IF(WARNING, errno != EPIPE) << "Fail to write into " << *socket;
}

return false;
}

void ProcessHuluResponse(InputMessageBase* msg_base) {
Expand Down
86 changes: 81 additions & 5 deletions test/brpc_server_unittest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,18 +80,25 @@ void* RunClosure(void* arg) {
return NULL;
}

bool g_verify_success = true;
const std::string g_unauthorized_error_text = "unauthorized";

class MyAuthenticator : public brpc::Authenticator {
public:
MyAuthenticator() {}
virtual ~MyAuthenticator() {}
int GenerateCredential(std::string*) const {
MyAuthenticator() = default;
~MyAuthenticator() override = default;
int GenerateCredential(std::string*) const override {
return 0;
}

int VerifyCredential(const std::string&,
const butil::EndPoint&,
brpc::AuthContext*) const {
return 0;
brpc::AuthContext*) const override {
return g_verify_success ? 0 : -1;
}

std::string GetUnauthorizedErrorText() const override {
return g_unauthorized_error_text;
}
};

Expand Down Expand Up @@ -1828,4 +1835,73 @@ TEST_F(ServerTest, rpc_pb_message_factory) {
ASSERT_EQ(0, server.Join());
}

void TestBaiduStdAuth(const butil::EndPoint& ep,
brpc::Controller& cntl,
int error_code, bool failed) {
brpc::Channel chan;
brpc::ChannelOptions copt;
copt.max_retry = 0;
copt.protocol = "baidu_std";
ASSERT_EQ(0, chan.Init(ep, &copt));

test::EchoRequest req;
test::EchoResponse res;
req.set_message(EXP_REQUEST);
test::EchoService_Stub stub(&chan);
stub.Echo(&cntl, &req, &res, NULL);
ASSERT_EQ(cntl.Failed(), failed) << cntl.ErrorText();
ASSERT_EQ(cntl.ErrorCode(), error_code);
}

void TestHttpAuth(const butil::EndPoint& ep,
brpc::Controller& cntl,
int status_code, bool failed) {
brpc::Channel chan;
brpc::ChannelOptions copt;
copt.max_retry = 0;
copt.protocol = "http";
ASSERT_EQ(0, chan.Init(ep, &copt));

cntl.http_request().uri() = "/EchoService/Echo";
cntl.request_attachment().append(R"({"message": "hello"})");
cntl.http_request().set_method(brpc::HTTP_METHOD_POST);
test::EchoService_Stub stub(&chan);
chan.CallMethod(NULL, &cntl, NULL, NULL, NULL);
ASSERT_EQ(cntl.Failed(), failed) << cntl.ErrorText();
ASSERT_EQ(cntl.http_response().status_code(), status_code);
}

TEST_F(ServerTest, auth) {
butil::EndPoint ep;
ASSERT_EQ(0, str2endpoint("127.0.0.1:8613", &ep));
brpc::Server server;
EchoServiceImpl service;
ASSERT_EQ(0, server.AddService(&service, brpc::SERVER_DOESNT_OWN_SERVICE));
MyAuthenticator auth;
brpc::ServerOptions opt;
opt.auth = &auth;
ASSERT_EQ(0, server.Start(ep, &opt));

brpc::Controller cntl;
TestBaiduStdAuth(ep, cntl, 0, false);

g_verify_success = false;
cntl.Reset();
TestBaiduStdAuth(ep, cntl, brpc::ERPCAUTH, true);
ASSERT_NE(cntl.ErrorText().find(g_unauthorized_error_text), std::string::npos);

cntl.Reset();
TestHttpAuth(ep, cntl, brpc::HTTP_STATUS_FORBIDDEN, true);
ASSERT_NE(cntl.response_attachment().to_string().find(g_unauthorized_error_text),
std::string::npos);

g_verify_success = true;
cntl.Reset();
cntl.http_request().SetHeader("Authorization", "123");
TestHttpAuth(ep, cntl, brpc::HTTP_STATUS_OK, false);

ASSERT_EQ(0, server.Stop(0));
ASSERT_EQ(0, server.Join());
}

} //namespace