Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions cpp/src/arrow/flight/client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,7 @@ class GrpcIpcMessageReader : public ipc::MessageReader {
protected:
Status OverrideWithServerError(Status&& st) {
// Get the gRPC status if not OK, to propagate any server error message
RETURN_NOT_OK(internal::FromGrpcStatus(stream_->Finish()));
RETURN_NOT_OK(internal::FromGrpcStatus(stream_->Finish(), &rpc_->context));
return std::move(st);
}

Expand Down Expand Up @@ -458,7 +458,7 @@ class DoPutPayloadWriter : public ipc::internal::IpcPayloadWriter {
pb::PutResult message;
while (writer_->Read(&message)) {
}
RETURN_NOT_OK(internal::FromGrpcStatus(writer_->Finish()));
RETURN_NOT_OK(internal::FromGrpcStatus(writer_->Finish(), &rpc_->context));
if (!finished_writes) {
return Status::UnknownError(
"Could not finish writing record batches before closing");
Expand Down Expand Up @@ -577,7 +577,7 @@ class FlightClient::FlightClientImpl {
RETURN_NOT_OK(auth_handler_->Authenticate(&outgoing, &incoming));
// Explicitly close our side of the connection
bool finished_writes = stream->WritesDone();
RETURN_NOT_OK(internal::FromGrpcStatus(stream->Finish()));
RETURN_NOT_OK(internal::FromGrpcStatus(stream->Finish(), &rpc.context));
if (!finished_writes) {
return Status::UnknownError("Could not finish writing before closing");
}
Expand All @@ -604,7 +604,7 @@ class FlightClient::FlightClientImpl {
}

listing->reset(new SimpleFlightListing(std::move(flights)));
return internal::FromGrpcStatus(stream->Finish());
return internal::FromGrpcStatus(stream->Finish(), &rpc.context);
}

Status DoAction(const FlightCallOptions& options, const Action& action,
Expand All @@ -628,7 +628,7 @@ class FlightClient::FlightClientImpl {

*results = std::unique_ptr<ResultStream>(
new SimpleResultStream(std::move(materialized_results)));
return internal::FromGrpcStatus(stream->Finish());
return internal::FromGrpcStatus(stream->Finish(), &rpc.context);
}

Status ListActions(const FlightCallOptions& options, std::vector<ActionType>* types) {
Expand All @@ -645,7 +645,7 @@ class FlightClient::FlightClientImpl {
RETURN_NOT_OK(internal::FromProto(pb_type, &type));
types->emplace_back(std::move(type));
}
return internal::FromGrpcStatus(stream->Finish());
return internal::FromGrpcStatus(stream->Finish(), &rpc.context);
}

Status GetFlightInfo(const FlightCallOptions& options,
Expand All @@ -659,7 +659,7 @@ class FlightClient::FlightClientImpl {
ClientRpc rpc(options);
RETURN_NOT_OK(rpc.SetToken(auth_handler_.get()));
Status s = internal::FromGrpcStatus(
stub_->GetFlightInfo(&rpc.context, pb_descriptor, &pb_response));
stub_->GetFlightInfo(&rpc.context, pb_descriptor, &pb_response), &rpc.context);
RETURN_NOT_OK(s);

FlightInfo::Data info_data;
Expand All @@ -678,7 +678,7 @@ class FlightClient::FlightClientImpl {
ClientRpc rpc(options);
RETURN_NOT_OK(rpc.SetToken(auth_handler_.get()));
Status s = internal::FromGrpcStatus(
stub_->GetSchema(&rpc.context, pb_descriptor, &pb_response));
stub_->GetSchema(&rpc.context, pb_descriptor, &pb_response), &rpc.context);
RETURN_NOT_OK(s);

std::string str;
Expand Down
25 changes: 25 additions & 0 deletions cpp/src/arrow/flight/flight_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,24 @@ TEST(TestFlight, RoundtripStatus) {
MakeFlightError(FlightStatusCode::Unavailable, "Test message"));
ASSERT_NE(nullptr, detail);
ASSERT_EQ(FlightStatusCode::Unavailable, detail->code());

Status status = internal::FromGrpcStatus(
internal::ToGrpcStatus(Status::NotImplemented("Sentinel")));
ASSERT_TRUE(status.IsNotImplemented());
ASSERT_THAT(status.message(), ::testing::HasSubstr("Sentinel"));

status = internal::FromGrpcStatus(internal::ToGrpcStatus(Status::Invalid("Sentinel")));
ASSERT_TRUE(status.IsInvalid());
ASSERT_THAT(status.message(), ::testing::HasSubstr("Sentinel"));

status = internal::FromGrpcStatus(internal::ToGrpcStatus(Status::KeyError("Sentinel")));
ASSERT_TRUE(status.IsKeyError());
ASSERT_THAT(status.message(), ::testing::HasSubstr("Sentinel"));

status =
internal::FromGrpcStatus(internal::ToGrpcStatus(Status::AlreadyExists("Sentinel")));
ASSERT_TRUE(status.IsAlreadyExists());
ASSERT_THAT(status.message(), ::testing::HasSubstr("Sentinel"));
}

TEST(TestFlight, GetPort) {
Expand Down Expand Up @@ -965,6 +983,13 @@ TEST_F(TestFlightClient, DoAction) {
ASSERT_EQ(nullptr, result);
}

TEST_F(TestFlightClient, RoundTripStatus) {
const auto descr = FlightDescriptor::Command("status-outofmemory");
std::unique_ptr<FlightInfo> info;
const auto status = client_->GetFlightInfo(descr, &info);
ASSERT_RAISES(OutOfMemory, status);
}

TEST_F(TestFlightClient, Issue5095) {
// Make sure the server-side error message is reflected to the
// client
Expand Down
111 changes: 107 additions & 4 deletions cpp/src/arrow/flight/internal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include "arrow/flight/protocol_internal.h"

#include <cstddef>
#include <map>
#include <memory>
#include <sstream>
#include <string>
Expand All @@ -45,12 +46,72 @@ namespace flight {
namespace internal {

const char* kGrpcAuthHeader = "auth-token-bin";
const char* kGrpcStatusCodeHeader = "x-arrow-status";
const char* kGrpcStatusMessageHeader = "x-arrow-status-message-bin";
const char* kGrpcStatusDetailHeader = "x-arrow-status-detail-bin";

static Status StatusCodeFromString(const grpc::string_ref& code_ref, StatusCode* code) {
// Bounce through std::string to get a proper null-terminated C string
const auto code_int = std::atoi(std::string(code_ref.data(), code_ref.size()).c_str());
switch (code_int) {

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this worth refactoring out into status.h?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(this being the int<->statuscode mapping)

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm... perhaps as a helper function?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've made a helper function, but kept it internal, since there doesn't seem to be a need to serialize statuses outside of Flight.

case static_cast<int>(StatusCode::OutOfMemory):
case static_cast<int>(StatusCode::KeyError):
case static_cast<int>(StatusCode::TypeError):
case static_cast<int>(StatusCode::Invalid):
case static_cast<int>(StatusCode::IOError):
case static_cast<int>(StatusCode::CapacityError):
case static_cast<int>(StatusCode::IndexError):
case static_cast<int>(StatusCode::UnknownError):
case static_cast<int>(StatusCode::NotImplemented):
case static_cast<int>(StatusCode::SerializationError):
case static_cast<int>(StatusCode::RError):
case static_cast<int>(StatusCode::CodeGenError):
case static_cast<int>(StatusCode::ExpressionValidationError):
case static_cast<int>(StatusCode::ExecutionError):
case static_cast<int>(StatusCode::AlreadyExists): {
*code = static_cast<StatusCode>(code_int);
return Status::OK();
}
default:
// Code is invalid
return Status::UnknownError("Unknown Arrow status code", code_ref);
}
}

Status FromGrpcStatus(const grpc::Status& grpc_status) {
if (grpc_status.ok()) {
return Status::OK();
/// Try to extract a status from gRPC trailers.
/// Return Status::OK if found, an error otherwise.
static Status FromGrpcContext(const grpc::ClientContext& ctx, Status* status) {
const std::multimap<grpc::string_ref, grpc::string_ref>& trailers =
ctx.GetServerTrailingMetadata();
const auto code_val = trailers.find(kGrpcStatusCodeHeader);
if (code_val == trailers.end()) {
return Status::IOError("Status code header not found");
}

const grpc::string_ref code_ref = (*code_val).second;
StatusCode code;
RETURN_NOT_OK(StatusCodeFromString(code_ref, &code));

const auto message_val = trailers.find(kGrpcStatusMessageHeader);
if (message_val == trailers.end()) {
return Status::IOError("Status message header not found");
}

const grpc::string_ref message_ref = (*message_val).second;
std::string message = std::string(message_ref.data(), message_ref.size());
const auto detail_val = trailers.find(kGrpcStatusDetailHeader);
if (detail_val != trailers.end()) {
const grpc::string_ref detail_ref = (*detail_val).second;
message += ". Detail: ";
message += std::string(detail_ref.data(), detail_ref.size());
}
*status = Status(code, message);
return Status::OK();
}

/// Convert a gRPC status to an Arrow status, ignoring any
/// implementation-defined headers that encode further detail.
static Status FromGrpcCode(const grpc::Status& grpc_status) {
switch (grpc_status.error_code()) {
case grpc::StatusCode::OK:
return Status::OK();
Expand Down Expand Up @@ -123,7 +184,28 @@ Status FromGrpcStatus(const grpc::Status& grpc_status) {
}
}

grpc::Status ToGrpcStatus(const Status& arrow_status) {
Status FromGrpcStatus(const grpc::Status& grpc_status, grpc::ClientContext* ctx) {
const Status status = FromGrpcCode(grpc_status);

if (!status.ok() && ctx) {
Status arrow_status;

if (!FromGrpcContext(*ctx, &arrow_status).ok()) {
// If we fail to decode a more detailed status from the headers,
// proceed normally
return status;
}

if (status.detail()) {
return arrow_status.WithDetail(status.detail());
}
return arrow_status;
}
return status;
}

/// Convert an Arrow status to a gRPC status.
static grpc::Status ToRawGrpcStatus(const Status& arrow_status) {
if (arrow_status.ok()) {
return grpc::Status::OK;
}
Expand Down Expand Up @@ -164,10 +246,31 @@ grpc::Status ToGrpcStatus(const Status& arrow_status) {
grpc_code = grpc::StatusCode::UNIMPLEMENTED;
} else if (arrow_status.IsInvalid()) {
grpc_code = grpc::StatusCode::INVALID_ARGUMENT;
} else if (arrow_status.IsKeyError()) {
grpc_code = grpc::StatusCode::NOT_FOUND;
} else if (arrow_status.IsAlreadyExists()) {
grpc_code = grpc::StatusCode::ALREADY_EXISTS;
}
return grpc::Status(grpc_code, message);
}

/// Convert an Arrow status to a gRPC status, and add extra headers to
/// the response to encode the original Arrow status.
grpc::Status ToGrpcStatus(const Status& arrow_status, grpc::ServerContext* ctx) {
grpc::Status status = ToRawGrpcStatus(arrow_status);
if (!status.ok() && ctx) {
const std::string code = std::to_string(static_cast<int>(arrow_status.code()));
ctx->AddTrailingMetadata(internal::kGrpcStatusCodeHeader, code);
ctx->AddTrailingMetadata(internal::kGrpcStatusMessageHeader, arrow_status.message());
if (arrow_status.detail()) {
const std::string detail_string = arrow_status.detail()->ToString();
ctx->AddTrailingMetadata(internal::kGrpcStatusDetailHeader, detail_string);
}
}

return status;
}

// ActionType

Status FromProto(const pb::ActionType& pb_type, ActionType* type) {
Expand Down
20 changes: 18 additions & 2 deletions cpp/src/arrow/flight/internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,14 +67,30 @@ namespace internal {
ARROW_FLIGHT_EXPORT
extern const char* kGrpcAuthHeader;

/// The name of the header used to pass the exact Arrow status code.
ARROW_FLIGHT_EXPORT
extern const char* kGrpcStatusCodeHeader;

/// The name of the header used to pass the exact Arrow status message.
ARROW_FLIGHT_EXPORT
extern const char* kGrpcStatusMessageHeader;

/// The name of the header used to pass the exact Arrow status detail.
ARROW_FLIGHT_EXPORT
extern const char* kGrpcStatusDetailHeader;

ARROW_FLIGHT_EXPORT
Status SchemaToString(const Schema& schema, std::string* out);

/// Convert a gRPC status to an Arrow status. Optionally, provide a
/// ClientContext to recover the exact Arrow status if it was passed
/// over the wire.
ARROW_FLIGHT_EXPORT
Status FromGrpcStatus(const grpc::Status& grpc_status);
Status FromGrpcStatus(const grpc::Status& grpc_status,
grpc::ClientContext* ctx = nullptr);

ARROW_FLIGHT_EXPORT
grpc::Status ToGrpcStatus(const Status& arrow_status);
grpc::Status ToGrpcStatus(const Status& arrow_status, grpc::ServerContext* ctx = nullptr);

// These functions depend on protobuf types which are not exported in the Flight DLL.

Expand Down
24 changes: 14 additions & 10 deletions cpp/src/arrow/flight/server.cc
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,8 @@ class GrpcServerAuthSender : public ServerAuthSender {

class FlightServiceImpl;
class GrpcServerCallContext : public ServerCallContext {
explicit GrpcServerCallContext(grpc::ServerContext* context) : context_(context) {}

const std::string& peer_identity() const override { return peer_identity_; }

// Helper method that runs interceptors given the result of an RPC,
Expand All @@ -248,7 +250,10 @@ class GrpcServerCallContext : public ServerCallContext {
for (const auto& instance : middleware_) {
instance->CallCompleted(status);
}
return internal::ToGrpcStatus(status);

// Set custom headers to map the exact Arrow status for clients
// who want it.
return internal::ToGrpcStatus(status, context_);
}

ServerMiddleware* GetMiddleware(const std::string& key) const override {
Expand Down Expand Up @@ -334,7 +339,6 @@ class FlightServiceImpl : public FlightService::Service {
// Authenticate the client (if applicable) and construct the call context
grpc::Status CheckAuth(const FlightMethod& method, ServerContext* context,
GrpcServerCallContext& flight_context) {
flight_context.context_ = context;
if (!auth_handler_) {
flight_context.peer_identity_ = "";
} else {
Expand Down Expand Up @@ -386,7 +390,7 @@ class FlightServiceImpl : public FlightService::Service {
grpc::Status Handshake(
ServerContext* context,
grpc::ServerReaderWriter<pb::HandshakeResponse, pb::HandshakeRequest>* stream) {
GrpcServerCallContext flight_context;
GrpcServerCallContext flight_context(context);
GRPC_RETURN_NOT_GRPC_OK(
MakeCallContext(FlightMethod::Handshake, context, flight_context));

Expand All @@ -405,7 +409,7 @@ class FlightServiceImpl : public FlightService::Service {

grpc::Status ListFlights(ServerContext* context, const pb::Criteria* request,
ServerWriter<pb::FlightInfo>* writer) {
GrpcServerCallContext flight_context;
GrpcServerCallContext flight_context(context);
GRPC_RETURN_NOT_GRPC_OK(
CheckAuth(FlightMethod::ListFlights, context, flight_context));

Expand All @@ -428,7 +432,7 @@ class FlightServiceImpl : public FlightService::Service {

grpc::Status GetFlightInfo(ServerContext* context, const pb::FlightDescriptor* request,
pb::FlightInfo* response) {
GrpcServerCallContext flight_context;
GrpcServerCallContext flight_context(context);
GRPC_RETURN_NOT_GRPC_OK(
CheckAuth(FlightMethod::GetFlightInfo, context, flight_context));

Expand All @@ -453,7 +457,7 @@ class FlightServiceImpl : public FlightService::Service {

grpc::Status GetSchema(ServerContext* context, const pb::FlightDescriptor* request,
pb::SchemaResult* response) {
GrpcServerCallContext flight_context;
GrpcServerCallContext flight_context(context);
GRPC_RETURN_NOT_GRPC_OK(CheckAuth(FlightMethod::GetSchema, context, flight_context));

CHECK_ARG_NOT_NULL(flight_context, request, "FlightDescriptor cannot be null");
Expand All @@ -477,7 +481,7 @@ class FlightServiceImpl : public FlightService::Service {

grpc::Status DoGet(ServerContext* context, const pb::Ticket* request,
ServerWriter<pb::FlightData>* writer) {
GrpcServerCallContext flight_context;
GrpcServerCallContext flight_context(context);
GRPC_RETURN_NOT_GRPC_OK(CheckAuth(FlightMethod::DoGet, context, flight_context));

CHECK_ARG_NOT_NULL(flight_context, request, "ticket cannot be null");
Expand Down Expand Up @@ -517,7 +521,7 @@ class FlightServiceImpl : public FlightService::Service {

grpc::Status DoPut(ServerContext* context,
grpc::ServerReaderWriter<pb::PutResult, pb::FlightData>* reader) {
GrpcServerCallContext flight_context;
GrpcServerCallContext flight_context(context);
GRPC_RETURN_NOT_GRPC_OK(CheckAuth(FlightMethod::DoPut, context, flight_context));

auto message_reader =
Expand All @@ -532,7 +536,7 @@ class FlightServiceImpl : public FlightService::Service {

grpc::Status ListActions(ServerContext* context, const pb::Empty* request,
ServerWriter<pb::ActionType>* writer) {
GrpcServerCallContext flight_context;
GrpcServerCallContext flight_context(context);
GRPC_RETURN_NOT_GRPC_OK(
CheckAuth(FlightMethod::ListActions, context, flight_context));
// Retrieve the listing from the implementation
Expand All @@ -543,7 +547,7 @@ class FlightServiceImpl : public FlightService::Service {

grpc::Status DoAction(ServerContext* context, const pb::Action* request,
ServerWriter<pb::Result>* writer) {
GrpcServerCallContext flight_context;
GrpcServerCallContext flight_context(context);
GRPC_RETURN_NOT_GRPC_OK(CheckAuth(FlightMethod::DoAction, context, flight_context));
CHECK_ARG_NOT_NULL(flight_context, request, "Action cannot be null");
Action action;
Expand Down
Loading