Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
convert constant back to original before trace apply
  • Loading branch information
masahi committed Oct 7, 2022
commit 4c9d91757afc0228e1ed0c9f9a0d07f8629c1948
24 changes: 19 additions & 5 deletions src/relay/backend/te_compiler_cache.cc
Original file line number Diff line number Diff line change
Expand Up @@ -334,20 +334,20 @@ class LayoutFreeConstantCollector : public StmtVisitor {
};

// HACK
using NDArrayMap =
std::unordered_map<runtime::NDArray, runtime::NDArray, ObjectPtrHash, ObjectPtrEqual>;

class AllocateConstReplaceConstant : public StmtExprMutator {
public:
static PrimFunc Rewrite(PrimFunc f,
const std::unordered_map<runtime::NDArray, runtime::NDArray,
ObjectPtrHash, ObjectPtrEqual>& constant_map) {
static PrimFunc Rewrite(PrimFunc f, const NDArrayMap& constant_map) {
AllocateConstReplaceConstant rewriter;
rewriter.constant_map_ = constant_map;
PrimFuncNode* n = f.CopyOnWrite();
n->body = rewriter(std::move(n->body));
return f;
}

std::unordered_map<runtime::NDArray, runtime::NDArray, ObjectPtrHash, ObjectPtrEqual>
constant_map_;
NDArrayMap constant_map_;

private:
Stmt VisitStmt_(const AllocateConstNode* op) final {
Expand Down Expand Up @@ -441,6 +441,7 @@ class ScheduleBuilder : public ExprVisitor {

static InstructionKind kind_transform_layout = InstructionKind::Get("TransformLayout");
TuningRecord record = opt_record.value();
NDArrayMap constant_map;
for (const Instruction& inst : record->trace->insts) {
if (inst->kind.same_as(kind_transform_layout)) {
ICHECK_EQ(inst->attrs.size(), 4);
Expand All @@ -456,18 +457,31 @@ class ScheduleBuilder : public ExprVisitor {
TuningRecord new_rec(record->trace, workload, record->run_secs, record->target,
record->args_info);
database_.value()->CommitTuningRecord(new_rec);
} else {
ICHECK(index_map->inverse_index_map);
auto inverse_map = Downcast<IndexMap>(index_map->inverse_index_map.value());
ICHECK(constant.Shape().size() == inverse_map->initial_indices.size());
runtime::NDArray orig_constant = inverse_map->MapNDArray(constant);
auto f_ = AllocateConstReplaceConstant().Rewrite(f.value(),
{{constant, orig_constant}});
constant_map[orig_constant] = constant;
query_mod = backend::PrimFuncToIRModule(f_);
}
}
MetaScheduleLayoutRewriter::LayoutQueuePush(Downcast<IndexMap>(inst->attrs[2]));
}
}

Schedule sch = Schedule::Traced(query_mod, /*seed=*/-1, /*debug_mask=*/0,
tir::ScheduleErrorRenderLevel::kDetail);
record->trace->ApplyToSchedule(sch, /*remove_postproc=*/false);
IRModule mod = sch->mod();
ICHECK_EQ(mod->functions.size(), 1);
mod = tir::transform::RemoveWeightLayoutRewriteBlock()(std::move(mod));
prim_func = Downcast<PrimFunc>(mod->Lookup("main"));
if (!constant_map.empty()) {
prim_func = AllocateConstReplaceConstant().Rewrite(prim_func, constant_map);
}
} else {
int dispatch = backend::UseMetaScheduleDispatch();
// (dispatch & 2): controls whether to print TVMScript for missing TIR
Expand Down