diff --git a/cpp/src/arrow/flight/client.cc b/cpp/src/arrow/flight/client.cc index 036d6d334d4a..eae31a525534 100644 --- a/cpp/src/arrow/flight/client.cc +++ b/cpp/src/arrow/flight/client.cc @@ -306,7 +306,7 @@ class GrpcIpcMessageReader : public ipc::MessageReader { protected: Status OverrideWithServerError(Status&& st) { // Get the gRPC status if not OK, to propagate any server error message - RETURN_NOT_OK(internal::FromGrpcStatus(stream_->Finish())); + RETURN_NOT_OK(internal::FromGrpcStatus(stream_->Finish(), &rpc_->context)); return std::move(st); } @@ -458,7 +458,7 @@ class DoPutPayloadWriter : public ipc::internal::IpcPayloadWriter { pb::PutResult message; while (writer_->Read(&message)) { } - RETURN_NOT_OK(internal::FromGrpcStatus(writer_->Finish())); + RETURN_NOT_OK(internal::FromGrpcStatus(writer_->Finish(), &rpc_->context)); if (!finished_writes) { return Status::UnknownError( "Could not finish writing record batches before closing"); @@ -577,7 +577,7 @@ class FlightClient::FlightClientImpl { RETURN_NOT_OK(auth_handler_->Authenticate(&outgoing, &incoming)); // Explicitly close our side of the connection bool finished_writes = stream->WritesDone(); - RETURN_NOT_OK(internal::FromGrpcStatus(stream->Finish())); + RETURN_NOT_OK(internal::FromGrpcStatus(stream->Finish(), &rpc.context)); if (!finished_writes) { return Status::UnknownError("Could not finish writing before closing"); } @@ -604,7 +604,7 @@ class FlightClient::FlightClientImpl { } listing->reset(new SimpleFlightListing(std::move(flights))); - return internal::FromGrpcStatus(stream->Finish()); + return internal::FromGrpcStatus(stream->Finish(), &rpc.context); } Status DoAction(const FlightCallOptions& options, const Action& action, @@ -628,7 +628,7 @@ class FlightClient::FlightClientImpl { *results = std::unique_ptr( new SimpleResultStream(std::move(materialized_results))); - return internal::FromGrpcStatus(stream->Finish()); + return internal::FromGrpcStatus(stream->Finish(), &rpc.context); } Status ListActions(const FlightCallOptions& options, std::vector* types) { @@ -645,7 +645,7 @@ class FlightClient::FlightClientImpl { RETURN_NOT_OK(internal::FromProto(pb_type, &type)); types->emplace_back(std::move(type)); } - return internal::FromGrpcStatus(stream->Finish()); + return internal::FromGrpcStatus(stream->Finish(), &rpc.context); } Status GetFlightInfo(const FlightCallOptions& options, @@ -659,7 +659,7 @@ class FlightClient::FlightClientImpl { ClientRpc rpc(options); RETURN_NOT_OK(rpc.SetToken(auth_handler_.get())); Status s = internal::FromGrpcStatus( - stub_->GetFlightInfo(&rpc.context, pb_descriptor, &pb_response)); + stub_->GetFlightInfo(&rpc.context, pb_descriptor, &pb_response), &rpc.context); RETURN_NOT_OK(s); FlightInfo::Data info_data; @@ -678,7 +678,7 @@ class FlightClient::FlightClientImpl { ClientRpc rpc(options); RETURN_NOT_OK(rpc.SetToken(auth_handler_.get())); Status s = internal::FromGrpcStatus( - stub_->GetSchema(&rpc.context, pb_descriptor, &pb_response)); + stub_->GetSchema(&rpc.context, pb_descriptor, &pb_response), &rpc.context); RETURN_NOT_OK(s); std::string str; diff --git a/cpp/src/arrow/flight/flight_test.cc b/cpp/src/arrow/flight/flight_test.cc index 7e0d414a07cf..5b1b273f3beb 100644 --- a/cpp/src/arrow/flight/flight_test.cc +++ b/cpp/src/arrow/flight/flight_test.cc @@ -258,6 +258,24 @@ TEST(TestFlight, RoundtripStatus) { MakeFlightError(FlightStatusCode::Unavailable, "Test message")); ASSERT_NE(nullptr, detail); ASSERT_EQ(FlightStatusCode::Unavailable, detail->code()); + + Status status = internal::FromGrpcStatus( + internal::ToGrpcStatus(Status::NotImplemented("Sentinel"))); + ASSERT_TRUE(status.IsNotImplemented()); + ASSERT_THAT(status.message(), ::testing::HasSubstr("Sentinel")); + + status = internal::FromGrpcStatus(internal::ToGrpcStatus(Status::Invalid("Sentinel"))); + ASSERT_TRUE(status.IsInvalid()); + ASSERT_THAT(status.message(), ::testing::HasSubstr("Sentinel")); + + status = internal::FromGrpcStatus(internal::ToGrpcStatus(Status::KeyError("Sentinel"))); + ASSERT_TRUE(status.IsKeyError()); + ASSERT_THAT(status.message(), ::testing::HasSubstr("Sentinel")); + + status = + internal::FromGrpcStatus(internal::ToGrpcStatus(Status::AlreadyExists("Sentinel"))); + ASSERT_TRUE(status.IsAlreadyExists()); + ASSERT_THAT(status.message(), ::testing::HasSubstr("Sentinel")); } TEST(TestFlight, GetPort) { @@ -965,6 +983,13 @@ TEST_F(TestFlightClient, DoAction) { ASSERT_EQ(nullptr, result); } +TEST_F(TestFlightClient, RoundTripStatus) { + const auto descr = FlightDescriptor::Command("status-outofmemory"); + std::unique_ptr info; + const auto status = client_->GetFlightInfo(descr, &info); + ASSERT_RAISES(OutOfMemory, status); +} + TEST_F(TestFlightClient, Issue5095) { // Make sure the server-side error message is reflected to the // client diff --git a/cpp/src/arrow/flight/internal.cc b/cpp/src/arrow/flight/internal.cc index 39c0b92e263e..7e90c76f06bb 100644 --- a/cpp/src/arrow/flight/internal.cc +++ b/cpp/src/arrow/flight/internal.cc @@ -20,6 +20,7 @@ #include "arrow/flight/protocol_internal.h" #include +#include #include #include #include @@ -45,12 +46,72 @@ namespace flight { namespace internal { const char* kGrpcAuthHeader = "auth-token-bin"; +const char* kGrpcStatusCodeHeader = "x-arrow-status"; +const char* kGrpcStatusMessageHeader = "x-arrow-status-message-bin"; +const char* kGrpcStatusDetailHeader = "x-arrow-status-detail-bin"; + +static Status StatusCodeFromString(const grpc::string_ref& code_ref, StatusCode* code) { + // Bounce through std::string to get a proper null-terminated C string + const auto code_int = std::atoi(std::string(code_ref.data(), code_ref.size()).c_str()); + switch (code_int) { + case static_cast(StatusCode::OutOfMemory): + case static_cast(StatusCode::KeyError): + case static_cast(StatusCode::TypeError): + case static_cast(StatusCode::Invalid): + case static_cast(StatusCode::IOError): + case static_cast(StatusCode::CapacityError): + case static_cast(StatusCode::IndexError): + case static_cast(StatusCode::UnknownError): + case static_cast(StatusCode::NotImplemented): + case static_cast(StatusCode::SerializationError): + case static_cast(StatusCode::RError): + case static_cast(StatusCode::CodeGenError): + case static_cast(StatusCode::ExpressionValidationError): + case static_cast(StatusCode::ExecutionError): + case static_cast(StatusCode::AlreadyExists): { + *code = static_cast(code_int); + return Status::OK(); + } + default: + // Code is invalid + return Status::UnknownError("Unknown Arrow status code", code_ref); + } +} -Status FromGrpcStatus(const grpc::Status& grpc_status) { - if (grpc_status.ok()) { - return Status::OK(); +/// Try to extract a status from gRPC trailers. +/// Return Status::OK if found, an error otherwise. +static Status FromGrpcContext(const grpc::ClientContext& ctx, Status* status) { + const std::multimap& trailers = + ctx.GetServerTrailingMetadata(); + const auto code_val = trailers.find(kGrpcStatusCodeHeader); + if (code_val == trailers.end()) { + return Status::IOError("Status code header not found"); } + const grpc::string_ref code_ref = (*code_val).second; + StatusCode code; + RETURN_NOT_OK(StatusCodeFromString(code_ref, &code)); + + const auto message_val = trailers.find(kGrpcStatusMessageHeader); + if (message_val == trailers.end()) { + return Status::IOError("Status message header not found"); + } + + const grpc::string_ref message_ref = (*message_val).second; + std::string message = std::string(message_ref.data(), message_ref.size()); + const auto detail_val = trailers.find(kGrpcStatusDetailHeader); + if (detail_val != trailers.end()) { + const grpc::string_ref detail_ref = (*detail_val).second; + message += ". Detail: "; + message += std::string(detail_ref.data(), detail_ref.size()); + } + *status = Status(code, message); + return Status::OK(); +} + +/// Convert a gRPC status to an Arrow status, ignoring any +/// implementation-defined headers that encode further detail. +static Status FromGrpcCode(const grpc::Status& grpc_status) { switch (grpc_status.error_code()) { case grpc::StatusCode::OK: return Status::OK(); @@ -123,7 +184,28 @@ Status FromGrpcStatus(const grpc::Status& grpc_status) { } } -grpc::Status ToGrpcStatus(const Status& arrow_status) { +Status FromGrpcStatus(const grpc::Status& grpc_status, grpc::ClientContext* ctx) { + const Status status = FromGrpcCode(grpc_status); + + if (!status.ok() && ctx) { + Status arrow_status; + + if (!FromGrpcContext(*ctx, &arrow_status).ok()) { + // If we fail to decode a more detailed status from the headers, + // proceed normally + return status; + } + + if (status.detail()) { + return arrow_status.WithDetail(status.detail()); + } + return arrow_status; + } + return status; +} + +/// Convert an Arrow status to a gRPC status. +static grpc::Status ToRawGrpcStatus(const Status& arrow_status) { if (arrow_status.ok()) { return grpc::Status::OK; } @@ -164,10 +246,31 @@ grpc::Status ToGrpcStatus(const Status& arrow_status) { grpc_code = grpc::StatusCode::UNIMPLEMENTED; } else if (arrow_status.IsInvalid()) { grpc_code = grpc::StatusCode::INVALID_ARGUMENT; + } else if (arrow_status.IsKeyError()) { + grpc_code = grpc::StatusCode::NOT_FOUND; + } else if (arrow_status.IsAlreadyExists()) { + grpc_code = grpc::StatusCode::ALREADY_EXISTS; } return grpc::Status(grpc_code, message); } +/// Convert an Arrow status to a gRPC status, and add extra headers to +/// the response to encode the original Arrow status. +grpc::Status ToGrpcStatus(const Status& arrow_status, grpc::ServerContext* ctx) { + grpc::Status status = ToRawGrpcStatus(arrow_status); + if (!status.ok() && ctx) { + const std::string code = std::to_string(static_cast(arrow_status.code())); + ctx->AddTrailingMetadata(internal::kGrpcStatusCodeHeader, code); + ctx->AddTrailingMetadata(internal::kGrpcStatusMessageHeader, arrow_status.message()); + if (arrow_status.detail()) { + const std::string detail_string = arrow_status.detail()->ToString(); + ctx->AddTrailingMetadata(internal::kGrpcStatusDetailHeader, detail_string); + } + } + + return status; +} + // ActionType Status FromProto(const pb::ActionType& pb_type, ActionType* type) { diff --git a/cpp/src/arrow/flight/internal.h b/cpp/src/arrow/flight/internal.h index 515637dd9b74..d165c90cd824 100644 --- a/cpp/src/arrow/flight/internal.h +++ b/cpp/src/arrow/flight/internal.h @@ -67,14 +67,30 @@ namespace internal { ARROW_FLIGHT_EXPORT extern const char* kGrpcAuthHeader; +/// The name of the header used to pass the exact Arrow status code. +ARROW_FLIGHT_EXPORT +extern const char* kGrpcStatusCodeHeader; + +/// The name of the header used to pass the exact Arrow status message. +ARROW_FLIGHT_EXPORT +extern const char* kGrpcStatusMessageHeader; + +/// The name of the header used to pass the exact Arrow status detail. +ARROW_FLIGHT_EXPORT +extern const char* kGrpcStatusDetailHeader; + ARROW_FLIGHT_EXPORT Status SchemaToString(const Schema& schema, std::string* out); +/// Convert a gRPC status to an Arrow status. Optionally, provide a +/// ClientContext to recover the exact Arrow status if it was passed +/// over the wire. ARROW_FLIGHT_EXPORT -Status FromGrpcStatus(const grpc::Status& grpc_status); +Status FromGrpcStatus(const grpc::Status& grpc_status, + grpc::ClientContext* ctx = nullptr); ARROW_FLIGHT_EXPORT -grpc::Status ToGrpcStatus(const Status& arrow_status); +grpc::Status ToGrpcStatus(const Status& arrow_status, grpc::ServerContext* ctx = nullptr); // These functions depend on protobuf types which are not exported in the Flight DLL. diff --git a/cpp/src/arrow/flight/server.cc b/cpp/src/arrow/flight/server.cc index 615d8bd2798b..891b27e33116 100644 --- a/cpp/src/arrow/flight/server.cc +++ b/cpp/src/arrow/flight/server.cc @@ -234,6 +234,8 @@ class GrpcServerAuthSender : public ServerAuthSender { class FlightServiceImpl; class GrpcServerCallContext : public ServerCallContext { + explicit GrpcServerCallContext(grpc::ServerContext* context) : context_(context) {} + const std::string& peer_identity() const override { return peer_identity_; } // Helper method that runs interceptors given the result of an RPC, @@ -248,7 +250,10 @@ class GrpcServerCallContext : public ServerCallContext { for (const auto& instance : middleware_) { instance->CallCompleted(status); } - return internal::ToGrpcStatus(status); + + // Set custom headers to map the exact Arrow status for clients + // who want it. + return internal::ToGrpcStatus(status, context_); } ServerMiddleware* GetMiddleware(const std::string& key) const override { @@ -334,7 +339,6 @@ class FlightServiceImpl : public FlightService::Service { // Authenticate the client (if applicable) and construct the call context grpc::Status CheckAuth(const FlightMethod& method, ServerContext* context, GrpcServerCallContext& flight_context) { - flight_context.context_ = context; if (!auth_handler_) { flight_context.peer_identity_ = ""; } else { @@ -386,7 +390,7 @@ class FlightServiceImpl : public FlightService::Service { grpc::Status Handshake( ServerContext* context, grpc::ServerReaderWriter* stream) { - GrpcServerCallContext flight_context; + GrpcServerCallContext flight_context(context); GRPC_RETURN_NOT_GRPC_OK( MakeCallContext(FlightMethod::Handshake, context, flight_context)); @@ -405,7 +409,7 @@ class FlightServiceImpl : public FlightService::Service { grpc::Status ListFlights(ServerContext* context, const pb::Criteria* request, ServerWriter* writer) { - GrpcServerCallContext flight_context; + GrpcServerCallContext flight_context(context); GRPC_RETURN_NOT_GRPC_OK( CheckAuth(FlightMethod::ListFlights, context, flight_context)); @@ -428,7 +432,7 @@ class FlightServiceImpl : public FlightService::Service { grpc::Status GetFlightInfo(ServerContext* context, const pb::FlightDescriptor* request, pb::FlightInfo* response) { - GrpcServerCallContext flight_context; + GrpcServerCallContext flight_context(context); GRPC_RETURN_NOT_GRPC_OK( CheckAuth(FlightMethod::GetFlightInfo, context, flight_context)); @@ -453,7 +457,7 @@ class FlightServiceImpl : public FlightService::Service { grpc::Status GetSchema(ServerContext* context, const pb::FlightDescriptor* request, pb::SchemaResult* response) { - GrpcServerCallContext flight_context; + GrpcServerCallContext flight_context(context); GRPC_RETURN_NOT_GRPC_OK(CheckAuth(FlightMethod::GetSchema, context, flight_context)); CHECK_ARG_NOT_NULL(flight_context, request, "FlightDescriptor cannot be null"); @@ -477,7 +481,7 @@ class FlightServiceImpl : public FlightService::Service { grpc::Status DoGet(ServerContext* context, const pb::Ticket* request, ServerWriter* writer) { - GrpcServerCallContext flight_context; + GrpcServerCallContext flight_context(context); GRPC_RETURN_NOT_GRPC_OK(CheckAuth(FlightMethod::DoGet, context, flight_context)); CHECK_ARG_NOT_NULL(flight_context, request, "ticket cannot be null"); @@ -517,7 +521,7 @@ class FlightServiceImpl : public FlightService::Service { grpc::Status DoPut(ServerContext* context, grpc::ServerReaderWriter* reader) { - GrpcServerCallContext flight_context; + GrpcServerCallContext flight_context(context); GRPC_RETURN_NOT_GRPC_OK(CheckAuth(FlightMethod::DoPut, context, flight_context)); auto message_reader = @@ -532,7 +536,7 @@ class FlightServiceImpl : public FlightService::Service { grpc::Status ListActions(ServerContext* context, const pb::Empty* request, ServerWriter* writer) { - GrpcServerCallContext flight_context; + GrpcServerCallContext flight_context(context); GRPC_RETURN_NOT_GRPC_OK( CheckAuth(FlightMethod::ListActions, context, flight_context)); // Retrieve the listing from the implementation @@ -543,7 +547,7 @@ class FlightServiceImpl : public FlightService::Service { grpc::Status DoAction(ServerContext* context, const pb::Action* request, ServerWriter* writer) { - GrpcServerCallContext flight_context; + GrpcServerCallContext flight_context(context); GRPC_RETURN_NOT_GRPC_OK(CheckAuth(FlightMethod::DoAction, context, flight_context)); CHECK_ARG_NOT_NULL(flight_context, request, "Action cannot be null"); Action action; diff --git a/cpp/src/arrow/flight/test_util.cc b/cpp/src/arrow/flight/test_util.cc index 1c1bda205676..64e8bd2d807d 100644 --- a/cpp/src/arrow/flight/test_util.cc +++ b/cpp/src/arrow/flight/test_util.cc @@ -169,6 +169,12 @@ class FlightTestServer : public FlightServerBase { Status GetFlightInfo(const ServerCallContext& context, const FlightDescriptor& request, std::unique_ptr* out) override { + // Test that Arrow-C++ status codes can make it through gRPC + if (request.type == FlightDescriptor::DescriptorType::CMD && + request.cmd == "status-outofmemory") { + return Status::OutOfMemory("Sentinel"); + } + std::vector flights = ExampleFlightInfo(); for (const auto& info : flights) { diff --git a/python/pyarrow/tests/test_flight.py b/python/pyarrow/tests/test_flight.py index c5cb865cee60..322702a2f6c3 100644 --- a/python/pyarrow/tests/test_flight.py +++ b/python/pyarrow/tests/test_flight.py @@ -597,7 +597,7 @@ def test_list_actions(): client = FlightClient(('localhost', server.port)) with pytest.raises( flight.FlightServerError, - match=("TypeError: Results of list_actions must be " + match=("Results of list_actions must be " "ActionType or tuple") ): list(client.list_actions()) @@ -623,6 +623,8 @@ def do_action(self, context, action): return iter([action.body]) elif action.type == 'bad-action': return iter(['foo']) + elif action.type == 'arrow-exception': + raise pa.ArrowMemoryError() def test_do_action_result_convenience(): @@ -643,8 +645,15 @@ def test_nicer_server_exceptions(): with ConvenienceServer() as server: client = FlightClient(('localhost', server.port)) with pytest.raises(flight.FlightServerError, - match="TypeError: a bytes-like object is required"): + match="a bytes-like object is required"): list(client.do_action('bad-action')) + # While Flight/C++ sends across the original status code, it + # doesn't get mapped to the equivalent code here, since we + # want to be able to distinguish between client- and server- + # side errors. + with pytest.raises(flight.FlightServerError, + match="ArrowMemoryError"): + list(client.do_action('arrow-exception')) def test_get_port():