Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
move ctx back to vm
  • Loading branch information
zhiics committed Oct 16, 2019
commit e9c22386e66e00b6d2861ade8036bd3d680fddfc
30 changes: 14 additions & 16 deletions include/tvm/runtime/vm.h
Original file line number Diff line number Diff line change
Expand Up @@ -484,17 +484,6 @@ class Executable : public ModuleNode {
*/
runtime::Module GetLib() const { return lib; }

/*!
* \brief Set the execution context for the executable.
*
* \param ctxs The list of TVMContext.
*/
void SetContext(const std::vector<TVMContext>& ctxs);

/*! \brief Get device context for params.
*/
TVMContext GetParamsContext() const;

virtual ~Executable() {}

const char* type_key() const final {
Expand All @@ -514,9 +503,6 @@ class Executable : public ModuleNode {
std::unordered_map<std::string, Index> primitive_map;
/*! \brief The virtual machine's function table. */
std::vector<VMFunction> functions;

/*! \brief The set of TVM contexts the VM is currently executing on. */
std::vector<TVMContext> ctxs;
};

/*! \brief The virtual machine.
Expand Down Expand Up @@ -591,6 +577,9 @@ class VirtualMachine : public runtime::ModuleNode {
/*! \brief The executable the VM will operate on. */
const Executable* exec;
Comment thread
zhiics marked this conversation as resolved.

/*! \brief The set of TVM contexts the VM is currently executing on. */
std::vector<TVMContext> ctxs;

/*! \brief Push a call frame on to the call stack. */
void PushFrame(Index arg_count, Index ret_pc, const VMFunction& vm_func);
/*! \brief Pop a frame off the call stack.
Expand Down Expand Up @@ -634,15 +623,24 @@ class VirtualMachine : public runtime::ModuleNode {

VirtualMachine() : frames(), func_index(0), code(nullptr), pc(0), exec(nullptr) {}

/*! \brief Initialize the virtual machine using an executable.
/*! \brief load the executable for the virtual machine.
* \param exec The executable.
*/
void Init(const Executable* exec);
void LoadExecutable(const Executable* exec);

/*! \brief Initialize the virtual machine for a set of contexts.
* \param contexts The set of TVM contexts.
*/
void Init(const std::vector<TVMContext>& contexts);

/*! \brief Run VM dispatch loop.
*/
void RunLoop();

/*! \brief Get device context for params.
*/
TVMContext GetParamsContext() const;

private:
/*! \brief Invoke a global setting up the VM state to execute.
*
Expand Down
1 change: 1 addition & 0 deletions python/tvm/relay/backend/profiler_vm.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ def __init__(self, mod):
super().__init__(mod)
m = mod.module if isinstance(mod, vm.Executable) else mod
self.mod = _vm._VirtualMachineDebug(m)
self._init = self.mod["init"]
self._invoke = self.mod["invoke"]
self._get_stat = self.mod["get_stat"]

Expand Down
3 changes: 1 addition & 2 deletions python/tvm/relay/backend/serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,6 @@ def serialize(self):
ctx = tvm.cpu()
target = "llvm"
executable = relay.vm..compile(mod, target)
executable.set_context(ctx)

# serialize.
ser = relay.serializer.Serializer(executable)
Expand All @@ -117,7 +116,7 @@ def serialize(self):
des_exec = deser.deserialize()

# execute the deserialized executable.
des_exec.set_context(ctx)
des_vm.init(ctx)
x_data = np.random.rand(10, 10).astype('float32')
des_vm = relay.vm.VirtualMachine(des_exec)
res = des_vm.run(x_data)
Expand Down
42 changes: 13 additions & 29 deletions python/tvm/relay/backend/vm.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@

import tvm
from tvm import autotvm
from tvm import TVMContext
from tvm.relay import expr as _expr
from . import _vm
from . import vmobj as _obj
Expand Down Expand Up @@ -57,37 +56,10 @@ class Executable(object):
"""Relay VM executable"""
def __init__(self, mod):
self.mod = mod
self._set_context = self.mod["set_context"]
self._get_lib = self.mod["get_lib"]
self._get_bytecode = self.mod["get_bytecode"]
self._get_stats = self.mod["get_stats"]

def set_context(self, ctx):
"""Initialize the context of the VM executable.

Parameters
----------
ctx : Union[:py:class:`tvm.TVMContext`, List[py:class:`tvm.TVMContext`]]
The runtime context to run the code on.
"""

if isinstance(ctx, TVMContext):
ctx = [ctx]
elif not isinstance(ctx, (list, tuple)):
raise ValueError("ctx has to be the type of TVMContext or a list of "
"TVMContext")
# args[0], args[1] are used as the primary/fallback context type and id
# for heterogeneous execution.
args = []
for cur_ctx in ctx:
if not isinstance(cur_ctx, TVMContext):
raise ValueError("ctx has to be the type of TVMContext or a list "
"of TVMContext")
args.append(cur_ctx.device_type)
args.append(cur_ctx.device_id)

self._set_context(*args)

@property
def lib(self):
"""Get the library that contains hardware dependent code.
Expand Down Expand Up @@ -182,8 +154,20 @@ def __init__(self, mod):
"tvm.Module, but received {}".format(type(mod)))
m = mod.module if isinstance(mod, Executable) else mod
self.mod = _vm._VirtualMachine(m)
self._init = self.mod["init"]
self._invoke = self.mod["invoke"]

def init(self, ctx):
"""Initialize the context in the VM.

Parameters
----------
ctx : :py:class:`TVMContext`
The runtime context to run the code on.
"""
args = [ctx.device_type, ctx.device_id]
self._init(*args)

def invoke(self, func_name, *args):
"""Invoke a function.

Expand Down Expand Up @@ -344,8 +328,8 @@ def __init__(self, mod, ctx, target):
self.ctx = ctx
self.target = target
self.executable = compile(mod, target)
self.executable.set_context(ctx)
self.vm = VirtualMachine(self.executable)
self.vm.init(ctx)

def _make_executor(self, expr=None):
main = self.mod["main"]
Expand Down
15 changes: 0 additions & 15 deletions src/runtime/vm/deserializer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -82,9 +82,6 @@ void Deserializer::Deserialize() {

// Code section.
DeserializeCodeSection();

// Context section.
DeserializeContextSection();
}

void Deserializer::DeserializeGlobalSection() {
Expand Down Expand Up @@ -313,18 +310,6 @@ void Deserializer::DeserializeCodeSection() {
}
}

void Deserializer::DeserializeContextSection() {
std::vector<uint64_t> ctxs;
STREAM_CHECK(strm_->Read(&ctxs), "context");
CHECK_EQ(ctxs.size() % 2, 0U);
for (size_t i = 0; i < ctxs.size(); i += 2) {
TVMContext ctx;
ctx.device_type = DLDeviceType(ctxs[i]);
ctx.device_id = static_cast<int>(ctxs[i + 1]);
exec_->ctxs.push_back(ctx);
}
}

runtime::Module CreateDeserializer(const std::string& code, const runtime::Module lib) {
std::shared_ptr<Deserializer> exec = std::make_shared<Deserializer>();
exec->Init(code, lib);
Expand Down
34 changes: 1 addition & 33 deletions src/runtime/vm/executable.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,20 +39,7 @@ namespace vm {

PackedFunc Executable::GetFunction(const std::string& name,
const std::shared_ptr<ModuleNode>& sptr_to_self) {
if (name == "set_context") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
CHECK_EQ(args.size() % 2, 0);
std::vector<TVMContext> contexts;
for (int i = 0; i < args.size() / 2; ++i) {
TVMContext ctx;
int device_type = args[i * 2];
ctx.device_type = DLDeviceType(device_type);
ctx.device_id = args[i * 2 + 1];
contexts.push_back(ctx);
}
this->SetContext(contexts);
});
} else if (name == "get_lib") {
if (name == "get_lib") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
*rv = this->GetLib();
});
Expand All @@ -70,10 +57,6 @@ PackedFunc Executable::GetFunction(const std::string& name,
}
}

inline void Executable::SetContext(const std::vector<TVMContext>& ctxs) {
this->ctxs = ctxs;
}

std::string Executable::GetBytecode() const {
std::ostringstream oss;

Expand Down Expand Up @@ -166,20 +149,6 @@ std::string Executable::Stats() const {
return oss.str();
}

TVMContext Executable::GetParamsContext() const {
CHECK(!ctxs.empty()) << "context has not been set yet.";

// Use the fallback device if no device index is available.
int fallback_device_type = static_cast<int>(ctxs[0].device_type);
// TODO(wweic): For heterogeneous execution, get device information from byte

const auto& cit =
std::find_if(ctxs.begin(), ctxs.end(), [&fallback_device_type](const TVMContext& c) {
return fallback_device_type == static_cast<int>(c.device_type);
});
return (cit == ctxs.end() ? ctxs[0] : *cit);
}

TVM_REGISTER_GLOBAL("relay._vm.GetNumOfGlobals")
.set_body([](TVMArgs args, TVMRetValue* rv) {
runtime::Module mod = args[0];
Expand All @@ -188,7 +157,6 @@ TVM_REGISTER_GLOBAL("relay._vm.GetNumOfGlobals")
*rv = static_cast<int>(exec->global_map.size());
});


TVM_REGISTER_GLOBAL("relay._vm.GetGlobalFields")
.set_body([](TVMArgs args, TVMRetValue* rv) {
runtime::Module mod = args[0];
Expand Down
25 changes: 21 additions & 4 deletions src/runtime/vm/profiler/vm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -67,26 +67,43 @@ PackedFunc VirtualMachineDebug::GetFunction(
os << "Total Duration " << total_duration << " us" << std::endl;
*rv = os.str();
});
} else if (name == "init") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
CHECK_EQ(args.size() % 2, 0);
std::vector<TVMContext> contexts;
for (int i = 0; i < args.size() / 2; ++i) {
TVMContext ctx;
int device_type = args[i * 2];
ctx.device_type = DLDeviceType(device_type);
ctx.device_id = args[i * 2 + 1];
contexts.push_back(ctx);
}
this->Init(contexts);
});
} else {
return VirtualMachine::GetFunction(name, sptr_to_self);
}
}

void VirtualMachineDebug::Init(const Executable* exec) {
VirtualMachine::Init(exec);
void VirtualMachineDebug::LoadExecutable(const Executable* exec) {
VirtualMachine::LoadExecutable(exec);
CHECK(this->exec);
for (auto kv : this->exec->primitive_map) {
packed_index_map[kv.second] = kv.first;
op_invokes[kv.second] = 0;
}
}

void VirtualMachineDebug::Init(const std::vector<TVMContext>& ctxs) {
VirtualMachine::Init(ctxs);
}

void VirtualMachineDebug::InvokePacked(Index packed_index,
const PackedFunc& func, Index arg_count,
Index output_size,
const std::vector<ObjectRef>& args) {
CHECK(this->exec);
auto ctx = this->exec->GetParamsContext();
auto ctx = this->GetParamsContext();
// warmup
VirtualMachine::InvokePacked(packed_index, func, arg_count, output_size,
args);
Expand All @@ -108,7 +125,7 @@ void VirtualMachineDebug::InvokePacked(Index packed_index,

runtime::Module CreateVirtualMachineDebug(const Executable* exec) {
std::shared_ptr<VirtualMachineDebug> vm = std::make_shared<VirtualMachineDebug>();
vm->Init(exec);
vm->LoadExecutable(exec);
return runtime::Module(vm);
}

Expand Down
4 changes: 3 additions & 1 deletion src/runtime/vm/profiler/vm.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,13 @@ class VirtualMachineDebug : public VirtualMachine {
void InvokePacked(Index packed_index, const PackedFunc& func, Index arg_count,
Index output_size, const std::vector<ObjectRef>& args) final;

void Init(const Executable* exec);
void LoadExecutable(const Executable* exec);

~VirtualMachineDebug() {}

private:
void Init(const std::vector<TVMContext>& ctxs);

std::unordered_map<Index, std::string> packed_index_map;
std::unordered_map<Index, std::vector<double>> op_durations;
std::unordered_map<Index, int> op_invokes;
Expand Down
13 changes: 0 additions & 13 deletions src/runtime/vm/serializer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -81,9 +81,6 @@ TVMByteArray Serializer::Serialize() {
// Code section.
SerializeCodeSection();

// Context section.
SerializeContextSection();

TVMByteArray arr;
arr.data = code_.c_str();
arr.size = code_.length();
Expand Down Expand Up @@ -300,16 +297,6 @@ void Serializer::SerializeCodeSection() {
}
}

void Serializer::SerializeContextSection() {
CHECK(!exec_->ctxs.empty());
std::vector<uint64_t> serialized_ctx;
for (const auto& ctx : exec_->ctxs) {
serialized_ctx.push_back(static_cast<uint64_t>(ctx.device_type));
serialized_ctx.push_back(static_cast<uint64_t>(ctx.device_id));
}
strm_->Write(serialized_ctx);
}

runtime::Module CreateSerializer(const Executable* exec) {
std::shared_ptr<Serializer> serializer = std::make_shared<Serializer>();
serializer->Init(exec);
Expand Down
6 changes: 0 additions & 6 deletions src/runtime/vm/serializer.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,6 @@
* - The `primitive_map` that contains the name of individual primitive operators.
* - The `functions`, e.g., the `VMFunction`. Each `VMFunction` is composed of
* a list of instructions/bytecode.
* - The `ctxs` that contains the device context used to execute the hardware
* dependent code.
*
* Note that only the library is returned as a separate module. All othere parts
* are stored in a single serialized code that is organized with the following
Expand All @@ -43,7 +41,6 @@
* - Primitive name section, containing the function name of the primitive ops
* used by the virtual machine.
* - Code section, handling the VM functions and bytecode.
* - Context section, saving the context information.
*
* The code section is again organized as follows for each VM function:
* func_name, register_file_size, num_instructions (N)
Expand Down Expand Up @@ -136,9 +133,6 @@ class Serializer : public runtime::ModuleNode {
/*! \brief Serialize the vm functions in exec_. */
void SerializeCodeSection();

/*! \brief Serialize the context in exec_. */
void SerializeContextSection();

/*! \brief The Relay virtual machine executable to be serialized. */
const Executable* exec_;

Expand Down
Loading