27 changes: 20 additions & 7 deletions python/tvm/driver/build_module.py
@@ -243,20 +243,33 @@ def build(

if not isinstance(inputs, (dict, container.Map)):
target = Target.current() if target is None else target
target = target if target else "llvm"
target_input_mod = {target: input_mod}
if target is None and isinstance(input_mod, tvm.IRModule):
target_mod = {}
for gvar, func in input_mod.functions.items():
tgt = func.attrs["target"] if func.attrs and "target" in func.attrs else "llvm"
if tgt not in target_mod:
target_mod[tgt] = {}
target_mod[tgt][gvar] = func

target_input_mod = {}
for tgt in target_mod.keys():
tir_mod = tvm.IRModule(target_mod[tgt])
tir_mod.with_attrs(input_mod.attrs)
target_input_mod[tgt] = tir_mod
else:
target_input_mod = {target: input_mod}
else:
target_input_mod = inputs
target_input_mod = {tgt: lower(mod) for tgt, mod in inputs.items()}

# Because modules can be created from a variety of sources, we annotate them
# with the relevant attributes here to ensure they propagate
annotated_mods = {}
for tar, mod in target_input_mod.items():
if not isinstance(tar, (str, Target)):
for tgt, mod in target_input_mod.items():
if not isinstance(tgt, (str, Target)):
raise ValueError("The key of inputs must be str or " "Target when inputs is dict.")
if not isinstance(mod, tvm.IRModule):
raise ValueError("inputs must be Schedule, IRModule," "or dict of str to IRModule.")
annotated_mods[tar] = mod.with_attr("runtime", runtime)
raise ValueError("inputs must be Schedule, IRModule, " "or dict of str to IRModule.")
annotated_mods[tgt] = mod.with_attr("runtime", runtime)

# TODO(mbs): Both CompilationConfig and TIRToRuntime implement the same host target
# defaulting logic, but there's currently no way to get back the decided host.
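The new branch above groups the PrimFuncs of one IRModule by their per-function "target" attribute and builds each group for its own target, falling back to "llvm" when the attribute is absent. A minimal sketch of the resulting call pattern, with the attribute attached by hand (in the PR it is attached by LegalizeOps, see the last file below); the module and names here are illustrative only:

```python
import tvm
from tvm import te

# Hedged sketch, not part of the PR: a one-function TIR module whose PrimFunc
# carries a per-function "target" attribute.
n = 128
A = te.placeholder((n,), name="A")
B = te.compute((n,), lambda i: A[i] + 1.0, name="B")
prim_func = te.create_prim_func([A, B]).with_attr("target", tvm.target.Target("llvm"))
mod = tvm.IRModule({"add_one": prim_func})

# With this change, tvm.build splits the module by each function's "target"
# attribute when no explicit target is given; functions without the attribute
# are grouped under the default "llvm" target.
rt_mod = tvm.build(mod)
```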
26 changes: 24 additions & 2 deletions python/tvm/relax/utils.py
@@ -28,7 +28,7 @@
from .expr import Tuple as rx_Tuple
from .expr import Expr, ShapeExpr, Function, PrimValue, StringImm, te_tensor
from ..te import Tensor as te_Tensor, create_prim_func
from ..ir import Array, Attrs, Type, Map
from ..ir import Array, Attrs, Type, Map, VDevice
from .struct_info import PrimStructInfo, ShapeStructInfo, TensorStructInfo


@@ -418,6 +418,24 @@ def _populate_used_vars(expr):
diff = used_vars - bound_vars
return list(diff)

def _get_vdevice(arg: Any) -> Optional[VDevice]:
"""get the virtual device from arguments."""
vdevice = None
if isinstance(arg, Expr): # type: ignore
if isinstance(arg.struct_info, TensorStructInfo):
vdevice = arg.struct_info.vdevice
elif isinstance(arg, (list, Array, tuple)):
for x in arg:
vdevice = _get_vdevice(x)
if vdevice is not None:
return vdevice
elif isinstance(arg, (dict, Map)):
for k in arg:
vdevice = _get_vdevice(arg[k])
if vdevice is not None:
return vdevice
return vdevice

def _shape_with_old_tir_var(
shape_values: List[tir.PrimExpr], tir_var_inverse_map: Dict[tir.Var, tir.PrimExpr]
):
@@ -456,7 +474,11 @@ def _shape_with_old_tir_var(
tir_var_inverse_map = {v: k for k, v in tir_var_map.items()}

output_sinfo = [
TensorStructInfo(_shape_with_old_tir_var(out.shape, tir_var_inverse_map), out.dtype)
TensorStructInfo(
_shape_with_old_tir_var(out.shape, tir_var_inverse_map),
out.dtype,
_get_vdevice(args),
)
for out in outs
]

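Together, the new `_get_vdevice` helper and the extra `TensorStructInfo` argument make the output struct info produced by the TE-lowering helper in this file inherit the virtual device found on the call's arguments: the helper does a depth-first walk over expressions, lists, and maps and returns the first `VDevice` it finds. An illustrative sketch (the helper is private, so this is for exposition only, and the `VDevice` constructor arguments shown are an assumption):

```python
import tvm
from tvm import relax
from tvm.ir import VDevice
from tvm.relax.utils import _get_vdevice  # private helper shown above

# Assumed VDevice construction: a target plus a device index.
vdev = VDevice(tvm.target.Target("cuda"), 0)
x = relax.Var("x", relax.TensorStructInfo((2, 3), "float32", vdev))
y = relax.Var("y", relax.TensorStructInfo((2, 3), "float32"))  # no vdevice

# The search is depth-first and stops at the first VDevice found.
assert _get_vdevice([y, (x,)]).same_as(vdev)
assert _get_vdevice({"weight": x}).same_as(vdev)
assert _get_vdevice([y]) is None
```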
32 changes: 20 additions & 12 deletions python/tvm/relax/vm_build.py
@@ -79,7 +79,7 @@ def jit(self, fcompile=None, addons=None, **kwargs) -> tvm.runtime.Module:
vm = tvm.relax.VirtualMachine(rt_mod, tvm.cuda())
"""

# TODO(tvm-team): Update runtime.Module interfac
# TODO(tvm-team): Update runtime.Module interface
# to query these properties as bitmask.
def _not_runnable(x):
return x.type_key in ("c", "static_library")
@@ -179,13 +179,17 @@ def _vmcodegen(
raise ValueError(f"Unknown exec_mode {exec_mode}")


def _autodetect_system_lib_req(target: tvm.target.Target, system_lib):
def _autodetect_system_lib_req(
target: Optional[tvm.target.Target] = None, system_lib: Optional[bool] = None
):
"""Automatically detect system lib requirement"""
host = target if target.host is None else target.host
if system_lib is None:
system_lib = False
if "wasm" in host.attrs.get("mtriple", ""):
system_lib = True
if target is not None:
host = target if target.host is None else target.host
if system_lib is None:
system_lib = False
if "wasm" in host.attrs.get("mtriple", ""):
system_lib = True

if system_lib:
# use packed-func to avoid relay dep.
return tvm.get_global_func("relay.backend.CreateRuntime")("cpp", {"system-lib": system_lib})
@@ -194,7 +198,7 @@ def _autodetect_system_lib_req(target: tvm.target.Target, system_lib):

def _vmlink(
builder: "relax.ExecBuilder",
target: Union[str, tvm.target.Target],
target: Optional[Union[str, tvm.target.Target]],
tir_mod: Optional[tvm.IRModule] = None,
ext_libs: List[tvm.runtime.Module] = None,
params: Optional[Dict[str, list]] = None,
@@ -213,8 +217,10 @@ def _vmlink(
builder: relax.ExecBuilder
Builder used to collect executables.

target : Union[str, tvm.target.Target]
target : Optional[Union[str, tvm.target.Target]]
A build target which can have optional host side compilation target.
If the target is not specified, the target in the vdevice list will be used.
For multi-target compilation, the vdevice should be annotated.

tir_mod: IRModule
The input TIR IRModule to be linked together.
@@ -239,14 +245,16 @@ def _vmlink(
lib = None
if tir_mod is not None:
lib = tvm.build(
tir_mod, target=target, runtime=_autodetect_system_lib_req(target, system_lib)
tir_mod,
target=target,
runtime=_autodetect_system_lib_req(target, system_lib),
)
return Executable(_ffi_api.VMLink(builder, target, lib, ext_libs, params)) # type: ignore


def build(
mod: tvm.IRModule,
target: Union[str, tvm.target.Target],
target: Optional[Union[str, tvm.target.Target]] = None,
params: Optional[Dict[str, list]] = None,
pipeline: Union[None, str, tvm.transform.Pass] = "default_build",
exec_mode: str = "bytecode",
@@ -261,7 +269,7 @@
mod: IRModule
The input IRModule to be built.

target : Union[str, tvm.target.Target]
target : Optional[Union[str, tvm.target.Target]]
A build target which can have optional host side compilation target.

When TVM compiles device specific program such as CUDA,
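On the build side, `relax.build`, `_vmlink`, and `_autodetect_system_lib_req` now accept `target=None`; the target is then taken from the targets recorded in the module's vdevice annotations, which is what enables multi-target compilation. A hedged usage sketch (the TVMScript spelling may differ slightly between TVM versions):

```python
import tvm
from tvm import relax
from tvm.script import ir_module
from tvm.script import relax as R

@ir_module
class AddSelf:
    @R.function
    def main(x: R.Tensor((8,), "float32")) -> R.Tensor((8,), "float32"):
        with R.dataflow():
            y = R.add(x, x)
            R.output(y)
        return y

# Explicit target: behaviour unchanged.
ex = relax.build(AddSelf, target="llvm")

# New: the target may be omitted when the module's tensors are annotated with
# VDevices whose Target is defined (the multi-target case), for example:
#   ex = relax.build(annotated_mod)   # `annotated_mod` is hypothetical here
```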
7 changes: 1 addition & 6 deletions python/tvm/runtime/relax_vm.py
@@ -54,7 +54,7 @@ def __init__(

Parameters
----------
mod: Union[tvm.runtime.Module, tvm.relax.Executable]
rt_mod: Union[tvm.runtime.Module, tvm.relax.Executable]
Runtime module exported by the result of build.

device : Union[Device, List[Device]]
@@ -107,11 +107,6 @@ def _setup_device(self, dev: Device, memory_cfg: Union[str, Dict[Device, str]])
)
devs = [dev]

if any(dev.device_type % RPC_SESS_MASK == tvm.cpu().device_type for dev in devs[:-1]):
raise RuntimeError(
"CPU host is required to be the last element of the device list if provided."
)

# CPU is required for executing shape functions
if devs[-1].device_type % RPC_SESS_MASK != tvm.cpu().device_type:
devs.append(tvm.cpu())
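Dropping the "CPU host must be last" check lets the VM's device list follow the module's vdevice ordering for multi-device execution; a CPU device is still appended automatically when absent, since shape functions run on the CPU. A hedged sketch, where `ex` stands for an assumed executable built for more than one device:

```python
import tvm
from tvm import relax

# `ex` is assumed to be a relax executable (e.g. from relax.build) containing
# kernels for more than one device; it is not constructed here.
vm = relax.VirtualMachine(ex, [tvm.cuda(0), tvm.cpu(0)])

# Previously this ordering was rejected because the CPU was not last; it is now
# accepted, and a CPU is still appended automatically if the list has none.
vm = relax.VirtualMachine(ex, [tvm.cpu(0), tvm.cuda(0)])
```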
20 changes: 20 additions & 0 deletions python/tvm/testing/utils.py
@@ -832,6 +832,16 @@ def _any_gpu_exists():
)


def _multi_gpu_exists():
return (
(tvm.cuda(0).exist and tvm.cuda(1).exist)
or (tvm.rocm(0).exist and tvm.rocm(1).exist)
or (tvm.opencl(0).exist and tvm.opencl(1).exist)
or (tvm.metal(0).exist and tvm.metal(1).exist)
or (tvm.vulkan(0).exist and tvm.vulkan(1).exist)
)


# Mark a test as requiring llvm to run
requires_llvm = Feature(
"llvm", "LLVM", cmake_flag="USE_LLVM", target_kind_enabled="llvm", target_kind_hardware="llvm"
@@ -847,6 +857,16 @@ def _any_gpu_exists():
# :py:func:`tvm.testing.requires_gpu`.
uses_gpu = requires_gpu(support_required="optional")

# Mark a test as requiring multiple GPUs to run.
requires_multi_gpu = Feature("multi_gpu", run_time_check=_multi_gpu_exists)

# Mark to differentiate tests that use multiple GPUs in some capacity.
#
# These tests will be run on test nodes with multiple GPUs.
# To mark a test that must have multiple GPUs present to run, use
# :py:func:`tvm.testing.requires_multi_gpu`.
uses_multi_gpu = requires_multi_gpu(support_required="optional")

# Mark a test as requiring the x86 Architecture to run.
requires_x86 = Feature(
"x86", "x86 Architecture", run_time_check=lambda: platform.machine() == "x86_64"
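The new markers are meant to be used like the existing requires_gpu/uses_gpu pair, assuming they are re-exported under tvm.testing in the same way:

```python
import tvm.testing

@tvm.testing.requires_multi_gpu
def test_needs_two_gpus():
    # Skipped unless at least two devices of the same GPU kind are detected.
    ...

@tvm.testing.uses_multi_gpu
def test_prefers_two_gpus():
    # Runs everywhere, but is routed to multi-GPU test nodes when available.
    ...
```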
2 changes: 1 addition & 1 deletion src/ir/module.cc
@@ -324,7 +324,7 @@ void IRModuleNode::Update(const IRModule& mod) {

IRModule IRModuleNode::ShallowCopy() {
return IRModule(this->functions, this->type_definitions, this->Imports(), this->source_map,
this->attrs);
this->attrs, this->global_infos);
}

std::pair<IRModule, GlobalVar> IRModule::FromExprInContext(
39 changes: 31 additions & 8 deletions src/relax/transform/call_tir_rewrite.cc
@@ -18,7 +18,8 @@
*/
/*!
* \file src/relax/transform/call_tir_rewrite.cc
* \brief Perform explicit tensor allocation for call_tir.
* \brief Perform explicit tensor allocation for call_tir,
* call_tir_inplace, and call_dps_packed.
*/
#include <tvm/relax/attrs/op.h>
#include <tvm/relax/expr_functor.h>
@@ -28,6 +29,7 @@
#include <tvm/tir/op.h>

#include "../../relay/transforms/pattern_utils.h"
#include "utils.h"

namespace tvm {
namespace relax {
@@ -43,6 +45,19 @@ namespace relax {

class CallTIRMutator : public ExprMutator {
public:
explicit CallTIRMutator(const IRModule& mod) : ExprMutator(mod), mod_(std::move(mod)) {}

IRModule Run() {
for (const auto& [gv, func] : mod_->functions) {
if (func->IsInstance<FunctionNode>()) {
auto updated_func = Downcast<Function>(this->VisitExpr(func));
builder_->UpdateFunction(gv, Downcast<BaseFunc>(updated_func));
}
}
return builder_->GetContextIRModule();
}

private:
using ExprMutator::VisitExpr_;
Expr VisitExpr_(const CallNode* call) override {
// post-order mutation
@@ -65,11 +80,15 @@ class CallTIRMutator : public ExprMutator {
const TensorStructInfo& tensor_sinfo = _tensor_sinfo.value();
ICHECK(tensor_sinfo->shape.defined())
<< "the TensorStructInfo shape of call_tir has not populated";
int dev_index = 0;
if (tensor_sinfo->vdevice.defined()) {
dev_index = GetDeviceIndex(mod_, tensor_sinfo->vdevice.value());
}
if (!is_inplace) {
outs.push_back(
builder_->Emit(Call(alloc_tensor_op, //
builder_->Emit(Call(alloc_tensor_op,
{Downcast<ShapeExpr>(tensor_sinfo->shape.value()),
DataTypeImm(tensor_sinfo->dtype), PrimValue::Int64(0)}, //
DataTypeImm(tensor_sinfo->dtype), PrimValue::Int64(dev_index)},
Attrs()),
"alloc"));
} else {
@@ -150,16 +169,20 @@ class CallTIRMutator : public ExprMutator {

return GetRef<Expr>(call);
}
};

Expr CallTIRRewrite(const Expr& e) { return CallTIRMutator().VisitExpr(e); }
/*! \brief The context IRModule. */
IRModule mod_;
};

namespace transform {

Pass CallTIRRewrite() {
runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
[=](Function f, IRModule m, PassContext pc) { return Downcast<Function>(CallTIRRewrite(f)); };
return CreateFunctionPass(pass_func, 0, "CallTIRRewrite", {});
runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func =
[=](IRModule mod, PassContext pc) { return CallTIRMutator(mod).Run(); };
return CreateModulePass(/*pass_function=*/pass_func,
/*opt_level=*/0,
/*pass_name=*/"CallTIRRewrite",
/*required=*/{});
}

TVM_REGISTER_GLOBAL("relax.transform.CallTIRRewrite").set_body_typed(CallTIRRewrite);
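CallTIRRewrite becomes a module pass rather than a function pass because resolving a tensor's VDevice to a runtime device index (via GetDeviceIndex) needs the module-level vdevice list. The pass normally runs inside the VM build pipeline, but it can also be invoked directly; a hedged sketch, with `mod` an assumed Relax IRModule:

```python
import tvm

# The pass rewrites call_tir into explicit alloc_tensor calls; with this change
# the device index passed to alloc_tensor comes from the output's VDevice
# position in the module's vdevice list instead of the previous constant 0.
rewritten = tvm.relax.transform.CallTIRRewrite()(mod)  # `mod` is assumed, not built here
```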
42 changes: 42 additions & 0 deletions src/relax/transform/legalize_ops.cc
@@ -26,6 +26,7 @@
#include <tvm/relax/analysis.h>
#include <tvm/relax/expr_functor.h>
#include <tvm/relax/op_attr_types.h>
#include <tvm/relax/struct_info.h>
#include <tvm/relax/transform.h>

namespace tvm {
@@ -72,6 +73,14 @@ class LegalizeMutator : public ExprMutator {
builder_->UpdateFunction(gv, Downcast<BaseFunc>(updated_func));
}
}
// Fill the "kTarget" attribute of PrimFunc
for (const auto& [gv, func] : builder_->GetContextIRModule()->functions) {
const tir::PrimFuncNode* prim_func;
if (tmap_.count(gv) && (prim_func = func.as<tir::PrimFuncNode>())) {
auto f = WithAttr(GetRef<tir::PrimFunc>(prim_func), tvm::attr::kTarget, tmap_[gv]);
builder_->UpdateFunction(gv, f);
}
}
return builder_->GetContextIRModule();
}

@@ -109,6 +118,33 @@ class LegalizeMutator : public ExprMutator {
return Call(call_pure_packed_op, ret_args, ret->attrs, ret->sinfo_args);
}

Target GetTarget(const Array<StructInfo>& sinfos) {
for (auto sinfo : sinfos) {
if (const auto* tinfo = sinfo.as<TensorStructInfoNode>()) {
if (tinfo->vdevice.defined()) {
auto vdevice = tinfo->vdevice.value();
if (vdevice->target.defined()) {
return vdevice->target;
}
}
} else if (const auto* tup_sinfo = sinfo.as<TupleStructInfoNode>()) {
return GetTarget(tup_sinfo->fields);
}
}
return Target();
}

void SaveTarget(const Expr& expr) {
if (expr->IsInstance<CallNode>()) {
auto call = Downcast<Call>(expr);
auto target = GetTarget(call->sinfo_args);
const GlobalVarNode* gvar_node;
if (target.defined() && (gvar_node = call->args[0].as<GlobalVarNode>())) {
this->tmap_.Set(GetRef<GlobalVar>(gvar_node), target);
}
}
}

Expr VisitExpr_(const CallNode* call) final {
Call visited_call = Downcast<Call>(this->VisitExprPostOrder_(call));
static const auto& legalize_map = Op::GetAttrMap<FLegalize>("FLegalize");
@@ -164,6 +200,10 @@ class LegalizeMutator : public ExprMutator {
builder_->BeginBindingBlock();
}
Expr legalized = legalization_func(builder_, visited_call);

// Save the expected target info. into tmap_
SaveTarget(legalized);

legalized = builder_->Normalize(legalized);

BindingBlock prologue = builder_->EndBlock();
@@ -196,6 +236,8 @@ class LegalizeMutator : public ExprMutator {
IRModule mod_;
/*! \brief The customized legalization function map. */
Map<String, PackedFunc> cmap_;
/*! \brief The map from GlobalVar of PrimFunc to compilation Target. */
Map<GlobalVar, Target> tmap_;
/*!
* \brief A boolean value indicating if to print warnings for CallNode whose op's
* legalization function is not registered.
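The net effect of GetTarget/SaveTarget and the final pass over tmap_ is that every PrimFunc generated during legalization for a call whose result carries a VDevice with a defined Target ends up annotated with that Target under tvm::attr::kTarget, which is exactly the "target" attribute the tvm.build change at the top of this PR consumes. A hedged inspection sketch, with `mod` an assumed vdevice-annotated Relax IRModule:

```python
import tvm
from tvm import relax

lowered = relax.transform.LegalizeOps()(mod)  # `mod` is assumed, not built here
for gvar, func in lowered.functions.items():
    # Mirrors the attribute lookup used in the build_module.py change above.
    if isinstance(func, tvm.tir.PrimFunc) and func.attrs and "target" in func.attrs:
        print(gvar.name_hint, "->", func.attrs["target"])
```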