Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
227 changes: 221 additions & 6 deletions src/relay/transforms/device_planner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@
* 'result_virtual_device' function attributes we introduce below. This is so the pass is
* idempotent and can be re-run to flow additional memory scope constraints.
*
* We proceed in four phases:
* We proceed in five phases:
*
* Phase 0
* -------
Expand All @@ -77,6 +77,13 @@
*
* Phase 1
* -------
* We iteratively process the programs and find nodes with conflicting virtual devices. If the
* virtual devices ( \p d1 and \p d2 ) are joinable, they are replaced with a joined device \p d. If
* they are unjoinable, a "device_copy" CallNode is inserted to copy the node output to the second
* device.
*
* Phase 2
* -------
* We flow constraints from the "on_device" and "device_copy" calls, PrimFunc buffer memory scopes,
* and some special ops, to all other Relay sub-expressions.
*
Expand Down Expand Up @@ -109,7 +116,7 @@
* devices from their original Relay Function representations. However we know all calls to those
* functions are device-consistent, thus no information is lost.
*
* Phase 2
* Phase 3
* -------
* After flowing constraints we apply some defaulting heuristics (using a global default \p
* VirtualDevice) to fix the device for any as-yet unconstrained sub-expressions.
Expand All @@ -121,7 +128,7 @@
* This requires a formal notion of 'choicepoint' inside the compiler which can integrate with
* automation.
*
* Phase 3
* Phase 4
* -------
* Finally, the result of this analysis is reified into the result as:
* - Additional "param_virtual_devices" (an \p Array<VirtualDevice>) and "result_virtual_device"
Expand Down Expand Up @@ -404,6 +411,201 @@ class RewriteOnDevices : public ExprMutator {

/* =============== Phase 1 =============== */

/*!
* \brief Add "device_copy" calls for nodes that have conflicting virtual devices.
*
* Eg Suppose an IRModule contains the following expr:
* \code
* %0 = add(%a, %b);
* %1 = on_device(%0, virtual_device=d1);
* %2 = add(%b, %c);
* %3 = on_device(%2, virtual_device=d2);
* \endcode
* In the above example, node %b has two possible virtual devices: \p d1 and \p d2.
*
* - If \p d1 and \p d2 are joinable, replace \p d1 and \p d2 with the joined device \p d:
* \code
* %0 = add(%a, %b);
* %1 = on_device(%0, virtual_device=d);
* %2 = add(%b, %c);
* %3 = on_device(%2, virtual_device=d);
* \endcode
*
* - If \p d1 and \p d2 are unjoinable, insert a "device_copy" CallNode to copy \p %b to \p d2:
* \code
* %0 = add(%a, %b);
* %1 = on_device(%0, virtual_device=d1);
* %2 = device_copy(%b, src_virtual_device=d1, dst_virtual_device=d2);
* %3 = add(%2, %c);
* %4 = on_device(%3, virtual_device=d2);
* \endcode
*/
struct DeviceContext {
VirtualDevice VirtualDeviceFor(const ExprNode* expr) {
auto itr = expr_to_device.find(expr);
if (itr != expr_to_device.end()) {
return itr->second;
}
auto default_dev = VirtualDevice::FullyUnconstrained();
expr_to_device.emplace(expr, default_dev);
return default_dev;
}

bool Update(const ExprNode* expr, VirtualDevice dev) {
bool success = true;
auto pair = expr_to_device.emplace(expr, dev);
if (!pair.second) {
auto replaced_item = pair.first;
auto joined_dev = VirtualDevice::Join(replaced_item->second, dev);
if (joined_dev == nullptr) {
success = false;
} else {
replaced_item->second = joined_dev.value();
}
}
return success;
}

bool IsConflicted(const ExprNode* expr) {
auto itr = conflicted_nodes.find(expr);
return itr != conflicted_nodes.end();
}

std::unordered_set<const ExprNode*> conflicted_nodes;
std::unordered_map<const ExprNode*, VirtualDevice> expr_to_device;
};

/*!
* \brief Flow the device constraints over the module and find all the conflicted nodes. The
* conflicted nodes only contain nodes that have no explicit constraints. For example, "on_device"
* nodes are not considered as conflicted.
*/
class ConflictedNodeFinder : ExprVisitor {
public:
/*! \brief Takes ownership of \p mod and starts with a freshly allocated, empty \p DeviceContext. */
explicit ConflictedNodeFinder(IRModule mod)
: mod_(std::move(mod)), dev_ctx_(std::make_unique<DeviceContext>()) {}

std::unique_ptr<DeviceContext> Finder() {
VLOG_CONTEXT << "ConflictedNodeFinder";
for (const auto& kv : mod_->functions) {
if (const auto* function_node = AsOptimizableFunctionNode(kv.second)) {
VisitExpr(GetRef<Function>(function_node));
}
}
for (auto const node : dev_ctx_->conflicted_nodes) {
if (node->IsInstance<CallNode>()) {
auto call = Downcast<Call>(GetRef<Expr>(node));
// "DeviceCapturer" will insert "device_copy" for "on_device" calls.
// Therefore, "on_device" should not be considered as conflicted.
if (call->op == OnDeviceOp()) {
dev_ctx_->conflicted_nodes.erase(node);
}
}
}
return std::move(dev_ctx_);
}

private:
/*!
 * \brief Records the device implied for \p call_node and for each of its arguments, flagging
 * arguments whose already-recorded device cannot be joined with the one implied here.
 *
 * - For "on_device" calls the annotated device constrains the body and/or the result, depending
 *   on the annotation's flags.
 * - For "device_copy" calls the argument is on the source device and the call on the destination.
 * - For any other call the arguments are expected on the same device as the call itself.
 *
 * Only the call's arguments are visited recursively; the callee expression is not.
 */
void VisitExpr_(const CallNode* call_node) final {
  VLOG(2) << "Initial call node: " << std::endl << PrettyPrint(GetRef<Call>(call_node));
  auto call_dev = dev_ctx_->VirtualDeviceFor(call_node);
  auto body_dev = call_dev;

  // Props are only fetched in the branch that actually consumes them.
  if (call_node->op == OnDeviceOp()) {
    auto on_dev_props = GetOnDeviceProps(call_node);
    if (on_dev_props.constrain_body) {
      body_dev = on_dev_props.virtual_device;
    }
    if (on_dev_props.constrain_result) {
      call_dev = on_dev_props.virtual_device;
    }
  } else if (call_node->op == DeviceCopyOp()) {
    auto dev_cp_props = GetDeviceCopyProps(call_node);
    body_dev = dev_cp_props.src_virtual_device;
    call_dev = dev_cp_props.dst_virtual_device;
  }

  // A failed join on an "on_device" call is tolerated; for any other call it is fatal.
  if (!dev_ctx_->Update(call_node, call_dev) && call_node->op != OnDeviceOp()) {
    LOG(FATAL) << "Mismatched device type after iterating args. Implied device: " << std::endl
               << PrettyPrint(call_dev) << "and practical device:" << std::endl
               << PrettyPrint(dev_ctx_->VirtualDeviceFor(call_node)) << std::endl
               << "With CallNode: " << std::endl
               << PrettyPrint(GetRef<Call>(call_node));
  }

  // First constrain every argument, then recurse. The two passes are kept separate so that all
  // of this call's constraints are recorded before any nested call is visited.
  for (const auto& arg : call_node->args) {
    VLOG(3) << "Handle call node arg: " << std::endl << PrettyPrint(arg);
    if (!dev_ctx_->Update(arg.get(), body_dev)) {
      VLOG(2) << "Conflicted node found:" << std::endl
              << PrettyPrint(GetRef<Expr>(arg.get())) << std::endl
              << "With corresponding Callee:" << std::endl
              << PrettyPrint(GetRef<Call>(call_node));
      dev_ctx_->conflicted_nodes.emplace(arg.get());
    }
  }
  for (const auto& expr : call_node->args) {
    VisitExpr(expr);
  }
}

IRModule mod_;
std::unique_ptr<DeviceContext> dev_ctx_;
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why need pointer here?

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The reason is similar to the PlanDevicesCore sub-pass, which uses a pointer for DeviceDomains to prevent unnecessary copying. Since the necessary information is contained in dev_ctx_, which is created in ConflictedNodeFinder and then passed to ConflictedNodeRewriter, we also use a pointer here.

};

/*!
* \brief Insert "device_copy" CallNode for all the conflicted nodes found by \p
* ConflictedNodeFinder.
*/
class ConflictedNodeRewriter : ExprMutator {
 public:
  /*!
   * \brief Constructs a rewriter.
   * \param mod The module to rewrite.
   * \param config Used to canonicalize the source and destination devices of inserted copies.
   * \param dev_ctx Device bookkeeping produced by \p ConflictedNodeFinder; ownership is taken.
   */
  ConflictedNodeRewriter(IRModule mod, CompilationConfig config,
                         std::unique_ptr<DeviceContext> dev_ctx)
      : mod_(std::move(mod)), config_(std::move(config)), dev_ctx_(std::move(dev_ctx)) {}

  /*!
   * \brief Builds a new module in which every optimizable function has been rewritten. All other
   * functions, the type definitions, imports, source map and attrs are carried over unchanged.
   */
  IRModule Rewrite() {
    VLOG_CONTEXT << "ConflictedNodeRewriter";
    IRModule result(/*functions=*/{}, mod_->type_definitions, mod_->Imports(), mod_->source_map,
                    mod_->attrs);
    for (const auto& kv : mod_->functions) {
      if (const auto* function_node = AsOptimizableFunctionNode(kv.second)) {
        auto func = Mutate(GetRef<Function>(function_node));
        result->Add(kv.first, Downcast<Function>(func));
      } else {
        result->Add(kv.first, kv.second);
      }
    }

    return result;
  }

 private:
  /*!
   * \brief Rewrites the sub-expressions of \p call_node, then wraps every argument that was
   * flagged as conflicted in a "device_copy" from the argument's recorded device to the device
   * recorded for \p call_node.
   */
  Expr VisitExpr_(const CallNode* call_node) final {
    VLOG(3) << "Initial call node:" << std::endl << PrettyPrint(GetRef<Call>(call_node));
    auto call = Downcast<Call>(ExprMutator::VisitExpr_(call_node));
    tvm::Array<Expr> call_args;
    call_args.reserve(call->args.size());
    for (size_t i = 0; i < call->args.size(); ++i) {
      // The device context is keyed by the nodes the finder visited, i.e. the arguments as they
      // were before mutation. Look conflicts up by the original node, but emit the rewritten one
      // so nested rewrites are preserved.
      const ExprNode* original_arg = call_node->args[i].get();
      Expr rewritten_arg = call->args[i];
      if (dev_ctx_->IsConflicted(original_arg)) {
        auto src_dev = config_->CanonicalVirtualDevice(dev_ctx_->VirtualDeviceFor(original_arg));
        auto dst_dev = config_->CanonicalVirtualDevice(dev_ctx_->VirtualDeviceFor(call_node));
        call_args.push_back(MaybeDeviceCopy(rewritten_arg, src_dev, dst_dev));
        VLOG(2) << "Adding DeviceCopy Op: " << std::endl << PrettyPrint(call_args.back());
      } else {
        call_args.push_back(rewritten_arg);
      }
    }
    // Rebuild from the already-mutated call so a rewritten callee is not thrown away.
    auto new_call = WithFields(call, call->op, call_args);
    VLOG(3) << "Final call node:" << std::endl << PrettyPrint(new_call);
    return new_call;
  }

  /*! \brief The module being rewritten. */
  IRModule mod_;
  /*! \brief Used to canonicalize virtual devices. */
  CompilationConfig config_;
  /*! \brief Recorded devices and conflicted nodes from the finder. */
  std::unique_ptr<DeviceContext> dev_ctx_;
};

/* =============== Phase 2 =============== */

/*
* \brief Collects the system of device constraints for all sub-expressions in a module.
* It is possible some devices remain free and will need to be defaulted by \p DeviceDefaulter.
Expand Down Expand Up @@ -707,7 +909,7 @@ class DeviceAnalyzer : public MixedModeVisitor {
std::unique_ptr<DeviceDomains> domains_;
};

/* =============== Phase 2 =============== */
/* =============== Phase 3 =============== */

/*!
* \brief Calls to 'free' "on_device" annotations (ie where both constrain_body=false and
Expand Down Expand Up @@ -865,7 +1067,7 @@ class DeviceDefaulter : public ExprVisitor {
std::unique_ptr<DeviceDomains> domains_;
};

/* =============== Phase 3 =============== */
/* =============== Phase 4 =============== */
/*!
* \brief Inserts missing "device_copy" CallNodes, and ensures the device type of every
* sub-expression in a module can be easily recovered by a later transformation using simple
Expand Down Expand Up @@ -1276,6 +1478,17 @@ tvm::transform::Pass Rewrite() {
return tvm::relay::transform::CreateFunctionPass(pass_func, 0, "PlanDevicesRewrite", {});
}

/*! \brief Check the conflicted nodes and add "device_copy" calls. */
tvm::transform::Pass Check(CompilationConfig config) {
  // Find the conflicted nodes first, then hand the collected context to the rewriter.
  auto pass_func = [config = std::move(config)](IRModule mod,
                                                tvm::transform::PassContext ctx) -> IRModule {
    ConflictedNodeFinder finder(mod);
    std::unique_ptr<DeviceContext> found = finder.Finder();
    ConflictedNodeRewriter rewriter(mod, config, std::move(found));
    return rewriter.Rewrite();
  };
  return tvm::transform::CreateModulePass(pass_func, /*opt_level=*/0, "PlanDevicesCheckConflicts",
                                          {});
}

/*! \brief Run the remaining phases. */
tvm::transform::Pass PlanDevicesCore(CompilationConfig config) {
return tvm::transform::CreateModulePass(
Expand Down Expand Up @@ -1308,7 +1521,9 @@ tvm::transform::Pass PlanDevicesCore(CompilationConfig config) {
tvm::transform::Pass PlanDevices(CompilationConfig config) {
std::vector<Pass> passes;
passes.emplace_back(Rewrite());
passes.emplace_back(PlanDevicesCore(std::move(config)));
passes.emplace_back(Check(config));
passes.emplace_back(InferType());
passes.emplace_back(PlanDevicesCore(config));
return tvm::transform::Sequential(passes, "PlanDevices");
}

Expand Down
47 changes: 47 additions & 0 deletions tests/python/relay/test_pass_plan_devices.py
Original file line number Diff line number Diff line change
Expand Up @@ -1830,5 +1830,52 @@ def @main(%data1: Tensor[(1, 32, 40, 40), float32],
print(mod)


def test_conflicated_inputs():
    """%b feeds one add annotated with VirtualDevice[0] and another annotated with VirtualDevice[1].

    The expected program copies %b (and the first add's result) to VirtualDevice[1].
    """
    metatable = {"VirtualDevice": [CPU, GPU]}

    # Program under test: two "on_device" annotations disagree about where %b lives.
    input_text = """
#[version = "0.0.5"]
def @main(%a: Tensor[(5, 7), float32], %b: Tensor[(5, 7), float32],
%c: Tensor[(5, 7), float32]) {
%0 = add(%a, %b);
%1 = on_device(%0, virtual_device=meta[VirtualDevice][0]);
%2 = add(%b, %c);
%3 = on_device(%2, virtual_device=meta[VirtualDevice][1]);
subtract(%1, %3)
}
"""

    # Planned program: explicit "device_copy" calls bridge the two devices.
    expected_text = """
#[version = "0.0.5"]
def @main(%a {virtual_device=meta[VirtualDevice][0]}: Tensor[(5, 7), float32],
%b {virtual_device=meta[VirtualDevice][0]}: Tensor[(5, 7), float32],
%c {virtual_device=meta[VirtualDevice][1]}: Tensor[(5, 7), float32]) {
%0 = add(%a, %b);
%1 = on_device(%0, virtual_device=meta[VirtualDevice][0], constrain_result=True);
%2 = device_copy(%b, src_virtual_device=meta[VirtualDevice][0], dst_virtual_device=meta[VirtualDevice][1]);
%3 = device_copy(%1, src_virtual_device=meta[VirtualDevice][0], dst_virtual_device=meta[VirtualDevice][1]);
%4 = add(%2, %c);
subtract(%3, %4)
}
"""

    def build_input():
        return tvm.relay.parse(input_text, "from_string", None, metatable)

    def build_expected():
        return tvm.relay.parse(expected_text, "from_string", None, metatable)

    def reference(a, b, c):
        # NumPy model of the Relay program: (a + b) - (b + c).
        lhs = np.add(a, b)
        rhs = np.add(b, c)
        return np.subtract(lhs, rhs)

    exercise(build_input(), build_expected(), reference, rands((5, 7), 3))


if __name__ == "__main__":
tvm.testing.main()