Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 21 additions & 3 deletions src/brpc/policy/http_rpc_protocol.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1331,7 +1331,25 @@ ParseResult ParseHttpMessage(butil::IOBuf *source, Socket *socket,
}
}

static void SendUnauthorizedResponse(const std::string& user_error_text, Socket* socket) {
static void SendUnauthorizedResponse(const std::string& user_error_text, Socket* socket, const InputMessageBase* msg) {
HttpContext* http_request = (HttpContext*)msg;
const bool is_http2 = http_request->header().is_http2();
if (is_http2) {
// for grpc client
const H2StreamContext* h2_sctx = static_cast<const H2StreamContext*>(msg);
brpc::Controller cntl;
cntl.http_response().set_status_code(200);
cntl.http_response().set_content_type("application/grpc");
cntl.SetFailed(ERPCAUTH, "%s", user_error_text.empty() ? "Fail to authenticate" : user_error_text.c_str());

SocketMessagePtr<H2UnsentResponse> h2_response(
H2UnsentResponse::New(&cntl, h2_sctx->stream_id(), true));
brpc::Socket::WriteOptions opt;
opt.ignore_eovercrowded = true;
socket->Write(h2_response, &opt);
return;
}

// Send 403(forbidden) to client.
HttpHeader header;
header.set_status_code(HTTP_STATUS_FORBIDDEN);
Expand Down Expand Up @@ -1374,7 +1392,7 @@ bool VerifyHttpRequest(const InputMessageBase* msg) {
const std::string *authorization
= http_request->header().GetHeader(common->AUTHORIZATION);
if (authorization == NULL) {
SendUnauthorizedResponse(auth->GetUnauthorizedErrorText(), socket);
SendUnauthorizedResponse(auth->GetUnauthorizedErrorText(), socket, msg);
return false;
}
butil::EndPoint user_addr;
Expand All @@ -1383,7 +1401,7 @@ bool VerifyHttpRequest(const InputMessageBase* msg) {
}
if (auth->VerifyCredential(*authorization, user_addr,
socket->mutable_auth_context()) != 0) {
SendUnauthorizedResponse(auth->GetUnauthorizedErrorText(), socket);
SendUnauthorizedResponse(auth->GetUnauthorizedErrorText(), socket, msg);
return false;
}

Expand Down
149 changes: 149 additions & 0 deletions test/brpc_http_rpc_protocol_unittest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2080,4 +2080,153 @@ TEST_F(HttpTest, http_expect) {
ASSERT_EQ(imsg_guard->header().status_code(), brpc::HTTP_STATUS_OK);
}

// Test gRPC authentication failure response format
TEST_F(HttpTest, grpc_auth_failed_response) {
// Set up an authenticator that returns authentication failure
class FailingAuthenticator : public brpc::Authenticator {
public:
int GenerateCredential(std::string*) const override { return 0; }
int VerifyCredential(const std::string&, const butil::EndPoint&, brpc::AuthContext*) const override {
return -1; // Simulate authentication failure
}
std::string GetUnauthorizedErrorText() const override {
return "Authentication failed for gRPC";
}
};

FailingAuthenticator failing_auth;
const brpc::Authenticator* original_auth = _server._options.auth;
_server._options.auth = &failing_auth;

// Test HTTP/2.0 gRPC request authentication failure using H2StreamContext
{
// Create H2Context for the connection
brpc::policy::H2Context* h2_ctx = new brpc::policy::H2Context(_socket.get(), &_server);
ASSERT_EQ(0, h2_ctx->Init());
_socket->initialize_parsing_context(&h2_ctx);

// Create H2StreamContext representing a gRPC request
brpc::policy::H2StreamContext* h2_msg = new brpc::policy::H2StreamContext(false);
h2_msg->header().set_content_type("application/grpc"); // gRPC content type
h2_msg->header().uri().set_path("/EchoService/Echo");
h2_msg->header().set_method(brpc::HTTP_METHOD_POST);

// Initialize the stream context with connection context and stream ID
h2_msg->Init(h2_ctx, 1); // stream_id = 1

// Set socket and arg using existing test pattern
if (h2_msg->_socket == NULL) {
_socket->ReAddress(&h2_msg->_socket);
}
h2_msg->_arg = &_server;

// Verify that authentication should fail for HTTP/2 gRPC request
bool verify_result = brpc::policy::VerifyHttpRequest(h2_msg);
EXPECT_FALSE(verify_result);

// Check if response has been written to pipe for HTTP/2
int bytes_in_pipe = 0;
ioctl(_pipe_fds[0], FIONREAD, &bytes_in_pipe);
EXPECT_GT(bytes_in_pipe, 0);

// Read and verify HTTP/2 response content
butil::IOPortal buf;
EXPECT_EQ((ssize_t)bytes_in_pipe, buf.append_from_file_descriptor(_pipe_fds[0], 1024));

// For HTTP/2, the response format should be different from HTTP/1.1
// Let's check if it contains HTTP/2 frame data
std::string response_str = buf.to_string();
EXPECT_GT(response_str.length(), 0);

// HTTP/2 gRPC response should contain:
// 1. grpc-status header (error code)
// 2. grpc-message header (error message)
// 3. Our authentication failure text (might be URL encoded)
EXPECT_TRUE(response_str.find("grpc-status") != std::string::npos);
EXPECT_TRUE(response_str.find("grpc-message") != std::string::npos);
EXPECT_TRUE(response_str.find("Authentication") != std::string::npos);
EXPECT_TRUE(response_str.find("failed") != std::string::npos);
EXPECT_TRUE(response_str.find("gRPC") != std::string::npos);

h2_msg->Destroy();
}

// Restore original auth settings
_server._options.auth = original_auth;
}

// Test HTTP/1.0 authentication failure response format
TEST_F(HttpTest, http10_auth_failed_response) {
// Set up an authenticator that returns authentication failure
class FailingAuthenticator : public brpc::Authenticator {
public:
int GenerateCredential(std::string*) const override { return 0; }
int VerifyCredential(const std::string&, const butil::EndPoint&, brpc::AuthContext*) const override {
return -1; // Simulate authentication failure
}
std::string GetUnauthorizedErrorText() const override {
return "Authentication failed for HTTP/1.0";
}
};

FailingAuthenticator failing_auth;
const brpc::Authenticator* original_auth = _server._options.auth;
_server._options.auth = &failing_auth;

// Test HTTP/1.0 request authentication failure (should return HTTP 403)
{
brpc::policy::HttpContext* http_msg = MakePostRequestMessage("/EchoService/Echo");
http_msg->header().set_version(1, 0); // Set to HTTP/1.0
http_msg->header().set_content_type("application/json"); // Regular HTTP request

// Use VerifyMessage to properly set up socket and arg (like other tests)
VerifyMessage(http_msg, false);

// Verify that authentication should fail for HTTP/1.0 request
bool verify_result = brpc::policy::VerifyHttpRequest(http_msg);
EXPECT_FALSE(verify_result);

// Check HTTP/1.0 response format
int bytes_in_pipe = 0;
ioctl(_pipe_fds[0], FIONREAD, &bytes_in_pipe);
EXPECT_GT(bytes_in_pipe, 0);

butil::IOPortal buf;
EXPECT_EQ((ssize_t)bytes_in_pipe, buf.append_from_file_descriptor(_pipe_fds[0], 1024));

// Parse HTTP/1.0 response and verify format
brpc::ParseResult pr = brpc::policy::ParseHttpMessage(&buf, _socket.get(), false, NULL);
EXPECT_EQ(brpc::PARSE_OK, pr.error());
brpc::policy::HttpContext* response_msg = static_cast<brpc::policy::HttpContext*>(pr.message());

// Verify HTTP/1.x response format (server may respond with 1.1 even for 1.0 requests)
EXPECT_EQ(1, response_msg->header().major_version());
EXPECT_TRUE(response_msg->header().minor_version() >= 0);
EXPECT_EQ(brpc::HTTP_STATUS_FORBIDDEN, response_msg->header().status_code());

// Check response body content
std::string body_content = response_msg->body().to_string();
EXPECT_TRUE(body_content.find("Authentication failed") != std::string::npos);
EXPECT_TRUE(body_content.find("1004") != std::string::npos); // brpc error code

// Verify HTTP headers for HTTP/1.0
const std::string* content_length = response_msg->header().GetHeader("Content-Length");
EXPECT_TRUE(content_length != NULL);
EXPECT_GT(std::stoi(*content_length), 0);

// Content-Type may not always be set for error responses, check if present
const std::string* content_type = response_msg->header().GetHeader("Content-Type");
if (content_type != NULL) {
// If present, should contain text
EXPECT_TRUE(content_type->find("text") != std::string::npos);
}

response_msg->Destroy();
http_msg->Destroy();
}

// Restore original auth settings
_server._options.auth = original_auth;
}

} //namespace
Loading