From 25e6c1228650e422793b489ce30de8d4b61959c7 Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Sun, 27 Aug 2023 17:54:13 +0000 Subject: [PATCH] [RPC] Enhance RPC Protocol to support TVM Object This PR introduces object support in TVM RPC protocol by introducing three new interfaces in `rpc_reference.h`: - `uint64_t GetObjectBytes(Object* obj)`, which is a required implementation that returns the length of the object during serialization; - `void WriteObject(Object* obj)` used to serialize an object to a writable channel; - `void ReadObject(int* type_code, TVMValue* value)`, which deserializes a TVM Object from a channel. To serialize an object, a recommended paradigm is to write its `type_index` first, and then its content. For example, `ShapeTuple` can be serialized as: ```C++ // pseudocode void WriteObject(Object* obj) { if (obj is ShapeTuple) { this->Write(type_index of ShapeTuple); this->Write(obj->ndim); this->WriteArray(obj->shape); } else { throw Unsupported; } } uint64_t GetObjectBytes(Object* obj) { uint64_t result = 0; if (obj is ShapeTuple) { result += sizeof(uint32_t); # for `type_index` result += sizeof(int32_t); # for `ndim` result += sizeof(int64_t) * obj->ndim; # for content of the shape } else { throw Unsupported; } return result; } ``` To deserialize an object, similar to serialization, the recommended approach paradigm is to read `type_index` and disptch based on it. Caveat on deserialization: RPC Reference itself does not own or allocate any memory to store objects, meaning extra logic is usually required in `ReadObject` to keep their liveness. --- include/tvm/runtime/object.h | 2 ++ src/runtime/minrpc/minrpc_server.h | 10 ++++++++++ src/runtime/minrpc/minrpc_server_logging.h | 4 ++++ src/runtime/minrpc/rpc_reference.h | 13 +++++++++++++ src/runtime/rpc/rpc_endpoint.cc | 10 ++++++++++ 5 files changed, 39 insertions(+) diff --git a/include/tvm/runtime/object.h b/include/tvm/runtime/object.h index b10aff96a116..94644d797c1a 100644 --- a/include/tvm/runtime/object.h +++ b/include/tvm/runtime/object.h @@ -72,6 +72,8 @@ struct TypeIndex { kRuntimeShapeTuple = 6, /*! \brief runtime::PackedFunc. */ kRuntimePackedFunc = 7, + /*! \brief runtime::DRef */ + kRuntimeDiscoDRef = 8, // static assignments that may subject to change. kRuntimeClosure, kRuntimeADT, diff --git a/src/runtime/minrpc/minrpc_server.h b/src/runtime/minrpc/minrpc_server.h index 4684aa0e1616..cca47f80b9df 100644 --- a/src/runtime/minrpc/minrpc_server.h +++ b/src/runtime/minrpc/minrpc_server.h @@ -131,6 +131,12 @@ class MinRPCReturns : public MinRPCReturnInterface { io_->Exit(static_cast(code)); } + void WriteObject(void* obj) { this->ThrowError(RPCServerStatus::kUnknownTypeCode); } + uint64_t GetObjectBytes(void* obj) { + this->ThrowError(RPCServerStatus::kUnknownTypeCode); + return 0; + } + template void Write(const T& data) { static_assert(std::is_trivial::value && std::is_standard_layout::value, @@ -748,6 +754,10 @@ class MinRPCServer { return ReadRawBytes(data, sizeof(T) * count); } + void ReadObject(int* tcode, TVMValue* value) { + this->ThrowError(RPCServerStatus::kUnknownTypeCode); + } + private: void RecvPackedSeq(TVMValue** out_values, int** out_tcodes, int* out_num_args) { RPCReference::RecvPackedSeq(out_values, out_tcodes, out_num_args, this); diff --git a/src/runtime/minrpc/minrpc_server_logging.h b/src/runtime/minrpc/minrpc_server_logging.h index deca2156ce62..89650efe9a1b 100644 --- a/src/runtime/minrpc/minrpc_server_logging.h +++ b/src/runtime/minrpc/minrpc_server_logging.h @@ -135,6 +135,10 @@ class MinRPCSniffer { return ReadRawBytes(data, sizeof(T) * count); } + void ReadObject(int* tcode, TVMValue* value) { + this->ThrowError(RPCServerStatus::kUnknownTypeCode); + } + private: bool ReadRawBytes(void* data, size_t size) { uint8_t* buf = reinterpret_cast(data); diff --git a/src/runtime/minrpc/rpc_reference.h b/src/runtime/minrpc/rpc_reference.h index 20d89ad52a1d..e16f09cb9dee 100644 --- a/src/runtime/minrpc/rpc_reference.h +++ b/src/runtime/minrpc/rpc_reference.h @@ -27,6 +27,9 @@ namespace tvm { namespace runtime { +// Forward declare TVM Object to use `Object*` in RPC protocol. +class Object; + /*! \brief The current RPC procotol version. */ constexpr const char* kRPCProtocolVer = "0.8.0"; @@ -194,6 +197,8 @@ struct RPCReference { num_bytes_ += sizeof(T) * num; } + void WriteObject(Object* obj) { num_bytes_ += channel_->GetObjectBytes(obj); } + void ThrowError(RPCServerStatus status) { channel_->ThrowError(status); } uint64_t num_bytes() const { return num_bytes_; } @@ -364,6 +369,10 @@ struct RPCReference { channel->WriteArray(bytes->data, len); break; } + case kTVMObjectHandle: { + channel->WriteObject(static_cast(value.v_handle)); + break; + } default: { channel->ThrowError(RPCServerStatus::kUnknownTypeCode); break; @@ -461,6 +470,10 @@ struct RPCReference { value.v_handle = ReceiveDLTensor(channel); break; } + case kTVMObjectHandle: { + channel->ReadObject(&tcodes[i], &value); + break; + } default: { channel->ThrowError(RPCServerStatus::kUnknownTypeCode); break; diff --git a/src/runtime/rpc/rpc_endpoint.cc b/src/runtime/rpc/rpc_endpoint.cc index 30606adf1b6f..f2c09132fc70 100644 --- a/src/runtime/rpc/rpc_endpoint.cc +++ b/src/runtime/rpc/rpc_endpoint.cc @@ -219,6 +219,16 @@ class RPCEndpoint::EventHandler : public dmlc::Stream { this->Write(cdata); } + void WriteObject(void* obj) { this->ThrowError(RPCServerStatus::kUnknownTypeCode); } + uint64_t GetObjectBytes(void* obj) { + this->ThrowError(RPCServerStatus::kUnknownTypeCode); + return 0; + } + + void ReadObject(int* tcode, TVMValue* value) { + this->ThrowError(RPCServerStatus::kUnknownTypeCode); + } + void MessageDone() { // Unused here, implemented for microTVM framing layer. }