diff --git a/include/tvm/runtime/disco/session.h b/include/tvm/runtime/disco/session.h index 3d4c3e4ea1a3..71fcce75b292 100644 --- a/include/tvm/runtime/disco/session.h +++ b/include/tvm/runtime/disco/session.h @@ -197,6 +197,8 @@ class SessionObj : public Object { * The thirtd element is the function to be called. */ TVM_DLL virtual DRef CallWithPacked(const TVMArgs& args) = 0; + /*! \brief Get the number of workers in the session. */ + TVM_DLL virtual int64_t GetNumWorkers() = 0; /*! \brief Get a global functions on workers. */ TVM_DLL virtual DRef GetGlobalFunc(const std::string& name) = 0; /*! diff --git a/python/tvm/runtime/disco/session.py b/python/tvm/runtime/disco/session.py index ee151db7166c..18329eb3f5bd 100644 --- a/python/tvm/runtime/disco/session.py +++ b/python/tvm/runtime/disco/session.py @@ -146,6 +146,11 @@ def shutdown(self): """Shut down the Disco session""" _ffi_api.SessionShutdown(self) # type: ignore # pylint: disable=no-member + @property + def num_workers(self) -> int: + """Return the number of workers in the session""" + return _ffi_api.SessionGetNumWorkers(self) # type: ignore # pylint: disable=no-member + def get_global_func(self, name: str) -> DRef: """Get a global function on workers. diff --git a/src/runtime/disco/process_session.cc b/src/runtime/disco/process_session.cc index 6474db479e94..dfcf36989c00 100644 --- a/src/runtime/disco/process_session.cc +++ b/src/runtime/disco/process_session.cc @@ -153,6 +153,8 @@ class ProcessSessionObj final : public BcastSessionObj { ~ProcessSessionObj() { Kill(); } + int64_t GetNumWorkers() { return workers_.size() + 1; } + TVMRetValue DebugGetFromRemote(int64_t reg_id, int worker_id) { if (worker_id == 0) { this->SyncWorker(worker_id); diff --git a/src/runtime/disco/session.cc b/src/runtime/disco/session.cc index e74d3819fe04..00f28a7b9f6a 100644 --- a/src/runtime/disco/session.cc +++ b/src/runtime/disco/session.cc @@ -37,6 +37,8 @@ TVM_REGISTER_GLOBAL("runtime.disco.DRefDebugGetFromRemote") .set_body_method(&DRefObj::DebugGetFromRemote); TVM_REGISTER_GLOBAL("runtime.disco.DRefDebugCopyFrom") .set_body_method(&DRefObj::DebugCopyFrom); +TVM_REGISTER_GLOBAL("runtime.disco.SessionGetNumWorkers") + .set_body_method(&SessionObj::GetNumWorkers); TVM_REGISTER_GLOBAL("runtime.disco.SessionGetGlobalFunc") .set_body_method(&SessionObj::GetGlobalFunc); TVM_REGISTER_GLOBAL("runtime.disco.SessionCopyFromWorker0") diff --git a/src/runtime/disco/threaded_session.cc b/src/runtime/disco/threaded_session.cc index c1f2f8539337..7a76a45ed539 100644 --- a/src/runtime/disco/threaded_session.cc +++ b/src/runtime/disco/threaded_session.cc @@ -154,6 +154,8 @@ class ThreadedSessionObj final : public BcastSessionObj { workers_.clear(); } + int64_t GetNumWorkers() { return workers_.size(); } + TVMRetValue DebugGetFromRemote(int64_t reg_id, int worker_id) { this->SyncWorker(worker_id); return this->workers_.at(worker_id).worker->register_file.at(reg_id); diff --git a/tests/python/disco/test_session.py b/tests/python/disco/test_session.py index 40dcb04911c9..ef8ea2e70a25 100644 --- a/tests/python/disco/test_session.py +++ b/tests/python/disco/test_session.py @@ -220,6 +220,13 @@ def transpose_2( np.testing.assert_equal(z_nd, x_np) +@pytest.mark.parametrize("session_kind", _all_session_kinds) +@pytest.mark.parametrize("num_workers", [1, 2, 4]) +def test_num_workers(session_kind, num_workers): + sess = session_kind(num_workers=num_workers) + assert sess.num_workers == num_workers + + if __name__ == "__main__": test_int(di.ProcessSession) test_float(di.ProcessSession)