From 057cce924852611d3ff1dbeaa667be3b81394203 Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Mon, 28 Aug 2023 07:30:27 +0000 Subject: [PATCH] [Disco] Support ShapeTuple in Disco Protocol ShapeTuple is an essential item used in Relax stack as the runtime representation of shapes. It is also used as a current workaround to represent integers (1-d shape) given standalone non-constant integers are currently absent in Relax. This PR introduces formal support for ShapeTuple in Disco's communication protocol based on the recent enhancement of TVM RPC system to support TVM Objects: https://github.com/apache/tvm/pull/15631. --- src/runtime/disco/threaded_session.cc | 22 ++++++++++----- tests/python/disco/test_session.py | 39 +++++++++++++++++++++++++++ 2 files changed, 55 insertions(+), 6 deletions(-) diff --git a/src/runtime/disco/threaded_session.cc b/src/runtime/disco/threaded_session.cc index 04860ef7129e..cb84918d2dc8 100644 --- a/src/runtime/disco/threaded_session.cc +++ b/src/runtime/disco/threaded_session.cc @@ -110,6 +110,9 @@ class DiscoThreadedMessageQueue : public dmlc::Stream { uint64_t GetObjectBytes(Object* obj) { if (obj->IsInstance()) { return sizeof(uint32_t) + sizeof(int64_t); + } else if (obj->IsInstance()) { + uint64_t size = static_cast(obj)->size; + return sizeof(uint32_t) + sizeof(uint64_t) + size * sizeof(char); } else if (obj->IsInstance()) { uint64_t ndim = static_cast(obj)->size; return sizeof(uint32_t) + sizeof(uint64_t) + ndim * sizeof(ShapeTupleObj::index_type); @@ -124,13 +127,16 @@ class DiscoThreadedMessageQueue : public dmlc::Stream { int64_t reg_id = static_cast(obj)->reg_id; this->Write(TypeIndex::kRuntimeDiscoDRef); this->Write(reg_id); + } else if (obj->IsInstance()) { + StringObj* str = static_cast(obj); + this->Write(TypeIndex::kRuntimeString); + this->Write(str->size); + this->WriteArray(str->data, str->size); } else if (obj->IsInstance()) { ShapeTupleObj* shape = static_cast(obj); this->Write(TypeIndex::kRuntimeShapeTuple); this->Write(shape->size); - for (uint64_t i = 0; i < shape->size; ++i) { - this->Write(shape->data[i]); - } + this->WriteArray(shape->data, shape->size); } else { LOG(FATAL) << "ValueError: Object type is not supported in Disco calling convention: " << obj->GetTypeKey() << " (type_index = " << obj->type_index() << ")"; @@ -146,13 +152,17 @@ class DiscoThreadedMessageQueue : public dmlc::Stream { this->Read(&dref->reg_id); dref->session = Session{nullptr}; result = ObjectRef(std::move(dref)); + } else if (type_index == TypeIndex::kRuntimeString) { + uint64_t size = 0; + this->Read(&size); + std::string data(size, '\0'); + this->ReadArray(data.data(), size); + result = String(std::move(data)); } else if (type_index == TypeIndex::kRuntimeShapeTuple) { uint64_t ndim = 0; this->Read(&ndim); std::vector data(ndim); - for (ShapeTupleObj::index_type& i : data) { - this->Read(&i); - } + this->ReadArray(data.data(), ndim); result = ShapeTuple(std::move(data)); } else { LOG(FATAL) << "ValueError: Object type is not supported in Disco calling convention: " diff --git a/tests/python/disco/test_session.py b/tests/python/disco/test_session.py index 2e5afe35f7bc..a2c0906f227e 100644 --- a/tests/python/disco/test_session.py +++ b/tests/python/disco/test_session.py @@ -23,6 +23,7 @@ import tvm from tvm import relax as rx from tvm._ffi import register_func +from tvm.runtime import ShapeTuple, String from tvm.runtime import disco as di from tvm.script import ir as I from tvm.script import relax as R @@ -106,6 +107,42 @@ def my_str_func(x: str): # pylint: disable=invalid-name assert result.debug_get_from_remote(i) == "hello_suffix" +def test_string_obj(): + num_workers = 4 + + @register_func("tests.disco.str_obj", override=True) + def my_str_func(x: String): # pylint: disable=invalid-name + assert isinstance(x, String) + return String(x + "_suffix") + + sess = di.ThreadedSession(num_workers=num_workers) + func: di.DPackedFunc = sess.get_global_func("tests.disco.str_obj") + result: di.DRef = func(String("hello")) + + for i in range(num_workers): + value = result.debug_get_from_remote(i) + assert isinstance(value, String) + assert value == "hello_suffix" + + +def test_shape_tuple(): + num_workers = 4 + + @register_func("tests.disco.shape_tuple", override=True) + def my_str_func(x: ShapeTuple): # pylint: disable=invalid-name + assert isinstance(x, ShapeTuple) + return ShapeTuple(list(x) + [4, 5]) + + sess = di.ThreadedSession(num_workers=num_workers) + func: di.DPackedFunc = sess.get_global_func("tests.disco.shape_tuple") + result: di.DRef = func(ShapeTuple([1, 2, 3])) + + for i in range(num_workers): + value = result.debug_get_from_remote(i) + assert isinstance(value, ShapeTuple) + assert list(value) == [1, 2, 3, 4, 5] + + def test_vm_module(): num_workers = 4 @@ -210,6 +247,8 @@ def transpose_2( test_int() test_float() test_string() + test_string_obj() + test_shape_tuple() test_ndarray() test_vm_module() test_vm_multi_func()