Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 16 additions & 6 deletions src/runtime/disco/threaded_session.cc
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,9 @@ class DiscoThreadedMessageQueue : public dmlc::Stream {
uint64_t GetObjectBytes(Object* obj) {
if (obj->IsInstance<DRefObj>()) {
return sizeof(uint32_t) + sizeof(int64_t);
} else if (obj->IsInstance<StringObj>()) {
uint64_t size = static_cast<StringObj*>(obj)->size;
return sizeof(uint32_t) + sizeof(uint64_t) + size * sizeof(char);
} else if (obj->IsInstance<ShapeTupleObj>()) {
uint64_t ndim = static_cast<ShapeTupleObj*>(obj)->size;
return sizeof(uint32_t) + sizeof(uint64_t) + ndim * sizeof(ShapeTupleObj::index_type);
Expand All @@ -124,13 +127,16 @@ class DiscoThreadedMessageQueue : public dmlc::Stream {
int64_t reg_id = static_cast<DRefObj*>(obj)->reg_id;
this->Write<uint32_t>(TypeIndex::kRuntimeDiscoDRef);
this->Write<int64_t>(reg_id);
} else if (obj->IsInstance<StringObj>()) {
StringObj* str = static_cast<StringObj*>(obj);
this->Write<uint32_t>(TypeIndex::kRuntimeString);
this->Write<uint64_t>(str->size);
this->WriteArray<char>(str->data, str->size);
} else if (obj->IsInstance<ShapeTupleObj>()) {
ShapeTupleObj* shape = static_cast<ShapeTupleObj*>(obj);
this->Write<uint32_t>(TypeIndex::kRuntimeShapeTuple);
this->Write<uint64_t>(shape->size);
for (uint64_t i = 0; i < shape->size; ++i) {
this->Write<ShapeTupleObj::index_type>(shape->data[i]);
}
this->WriteArray<ShapeTupleObj::index_type>(shape->data, shape->size);
} else {
LOG(FATAL) << "ValueError: Object type is not supported in Disco calling convention: "
<< obj->GetTypeKey() << " (type_index = " << obj->type_index() << ")";
Expand All @@ -146,13 +152,17 @@ class DiscoThreadedMessageQueue : public dmlc::Stream {
this->Read<int64_t>(&dref->reg_id);
dref->session = Session{nullptr};
result = ObjectRef(std::move(dref));
} else if (type_index == TypeIndex::kRuntimeString) {
uint64_t size = 0;
Comment thread
tqchen marked this conversation as resolved.
this->Read<uint64_t>(&size);
std::string data(size, '\0');
this->ReadArray<char>(data.data(), size);
result = String(std::move(data));
} else if (type_index == TypeIndex::kRuntimeShapeTuple) {
uint64_t ndim = 0;
this->Read<uint64_t>(&ndim);
std::vector<ShapeTupleObj::index_type> data(ndim);
for (ShapeTupleObj::index_type& i : data) {
this->Read<ShapeTupleObj::index_type>(&i);
}
this->ReadArray<ShapeTupleObj::index_type>(data.data(), ndim);
result = ShapeTuple(std::move(data));
} else {
LOG(FATAL) << "ValueError: Object type is not supported in Disco calling convention: "
Expand Down
39 changes: 39 additions & 0 deletions tests/python/disco/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import tvm
from tvm import relax as rx
from tvm._ffi import register_func
from tvm.runtime import ShapeTuple, String
from tvm.runtime import disco as di
from tvm.script import ir as I
from tvm.script import relax as R
Expand Down Expand Up @@ -106,6 +107,42 @@ def my_str_func(x: str): # pylint: disable=invalid-name
assert result.debug_get_from_remote(i) == "hello_suffix"


def test_string_obj():
num_workers = 4

@register_func("tests.disco.str_obj", override=True)
def my_str_func(x: String): # pylint: disable=invalid-name
assert isinstance(x, String)
return String(x + "_suffix")

sess = di.ThreadedSession(num_workers=num_workers)
func: di.DPackedFunc = sess.get_global_func("tests.disco.str_obj")
result: di.DRef = func(String("hello"))

for i in range(num_workers):
value = result.debug_get_from_remote(i)
assert isinstance(value, String)
assert value == "hello_suffix"


def test_shape_tuple():
num_workers = 4

@register_func("tests.disco.shape_tuple", override=True)
def my_str_func(x: ShapeTuple): # pylint: disable=invalid-name
assert isinstance(x, ShapeTuple)
return ShapeTuple(list(x) + [4, 5])

sess = di.ThreadedSession(num_workers=num_workers)
func: di.DPackedFunc = sess.get_global_func("tests.disco.shape_tuple")
result: di.DRef = func(ShapeTuple([1, 2, 3]))

for i in range(num_workers):
value = result.debug_get_from_remote(i)
assert isinstance(value, ShapeTuple)
assert list(value) == [1, 2, 3, 4, 5]


def test_vm_module():
num_workers = 4

Expand Down Expand Up @@ -210,6 +247,8 @@ def transpose_2(
test_int()
test_float()
test_string()
test_string_obj()
test_shape_tuple()
test_ndarray()
test_vm_module()
test_vm_multi_func()