Skip to content
Prev Previous commit
Next Next commit
fix mutate compute location
  • Loading branch information
merrymercy committed Sep 18, 2020
commit 813dd7406f1ec27beea84b2122ac1e379313bed7
151 changes: 64 additions & 87 deletions src/auto_scheduler/search_policy/sketch_policy_rules.cc
Original file line number Diff line number Diff line change
Expand Up @@ -475,9 +475,8 @@ PopulationGenerationRule::ResultKind InitFillTileSize::Apply(SketchPolicyNode* p
return ResultKind::kValid;
}

PopulationGenerationRule::ResultKind MutateComputeLocationCommon(SketchPolicyNode* policy,
State* state,
bool infer_bound = true) {
PopulationGenerationRule::ResultKind InitChangeComputeLocation::Apply(SketchPolicyNode* policy,
State* state) const {
if (GetIntParam(policy->params, SketchParamKey::disable_change_compute_location)) {
return PopulationGenerationRule::ResultKind::kValid;
}
Expand All @@ -493,81 +492,8 @@ PopulationGenerationRule::ResultKind MutateComputeLocationCommon(SketchPolicyNod
continue;
}

int target_stage_id = GetSingleConsumerId(policy->search_task, *state, stage_id);
if (target_stage_id < 0) {
continue;
}
const Stage& target_stage = (*state)->stages[target_stage_id];

std::vector<std::pair<int, int>> candidates;
bool target_compute_at_other = target_stage->compute_at == ComputeAtKind::kIter;
bool target_is_tiled = IsTiled(target_stage);

bool visited_reduce = false;
// enumerate compute_at location at target_stage
// TODO(merrymercy): More analysis here to make smarter choices
for (size_t i = 0; i < target_stage->iters.size(); ++i) {
const Iterator& target_iter = target_stage->iters[i];
if (target_iter->iter_kind == IteratorKind::kReduction) {
visited_reduce = true;
if (!target_is_tiled) { // Do not go into reduce iter
break;
}
} else if (target_iter->iter_kind == IteratorKind::kSpatial) {
if (visited_reduce) { // Do not go into inner tile
break;
}
}

if (target_iter->annotation == IteratorAnnotation::kUnroll) {
// Do not go into the unroll region of const tensor indices
break;
}

if (GetExtent(target_iter) == 1) {
// Skip iterators with length of 1
continue;
}
if (target_compute_at_other && target_iter->iter_kind == IteratorKind::kSpatial &&
StrEndsWith(target_iter->name, ".0")) {
// Skip the first level iterators if target stage compute_at another stage
// In this case, the lengths of first level iterators are always one
continue;
}
candidates.emplace_back(target_stage_id, i);

if ((*state)->attach_map->iter_to_attached_stages.count(std::make_pair(target_stage_id, i))) {
break;
}
}

// if the target_stage is already compute_at another stage X, try also compute_at X
// We call stage X as `target_target_stage`
if (target_compute_at_other) {
int target_target_stage_id;
target_target_stage_id = (*state)->attach_map->stage_to_attach_iter.at(target_stage_id).first;
const Stage& target_target_stage = (*state)->stages[target_target_stage_id];

for (size_t i = 0; i < target_target_stage->iters.size(); ++i) {
const Iterator& target_target_iter = target_target_stage->iters[i];
if (target_target_iter->iter_kind == IteratorKind::kReduction ||
(*state)->attach_map->iter_to_attached_stages.count(
std::make_pair(target_target_stage_id, i))) {
break;
}

if (target_target_iter->annotation == IteratorAnnotation::kUnroll) {
// Do not go into the unroll region of const tensor indices
break;
}

if (GetExtent(target_target_iter) == 1) { // skip iterators with length of 1
continue;
}

candidates.emplace_back(target_target_stage_id, i);
}
}
std::vector<std::pair<int, int>> candidates
= GetComputeLocationCandidates(policy->search_task, *state, stage_id);

int choice = (policy->rand_gen)() % (candidates.size() + 2);

Expand All @@ -588,17 +514,10 @@ PopulationGenerationRule::ResultKind MutateComputeLocationCommon(SketchPolicyNod
}
}

if (infer_bound) {
*state = policy->search_task->compute_dag.InferBound(*state);
}
*state = policy->search_task->compute_dag.InferBound(*state);
return PopulationGenerationRule::ResultKind::kValid;
}

PopulationGenerationRule::ResultKind InitChangeComputeLocation::Apply(SketchPolicyNode* policy,
State* state) const {
return MutateComputeLocationCommon(policy, state, true);
}

PopulationGenerationRule::ResultKind InitParallel::Apply(SketchPolicyNode* policy,
State* state) const {
std::function<void(const SketchPolicyNode&, State*, int stage_id, int iter_offset)>
Expand Down Expand Up @@ -1066,7 +985,65 @@ PopulationGenerationRule::ResultKind MutateAutoUnroll::Apply(SketchPolicyNode* p

PopulationGenerationRule::ResultKind MutateComputeLocation::Apply(SketchPolicyNode* policy,
                                                                  State* state) const {
  // Mutation rule: pick one compute_at step from the state's transform history at random and
  // replay the whole history with a different, randomly chosen compute location for that step.
  if (GetIntParam(policy->params, SketchParamKey::disable_change_compute_location)) {
    return PopulationGenerationRule::ResultKind::kInvalid;
  }

  // Extract all compute_at steps whose location is eligible for mutation.
  std::vector<int> compute_at_steps;
  for (size_t s = 0; s < (*state)->transform_steps.size(); ++s) {
    if (auto ps = (*state)->transform_steps[s].as<ComputeAtStepNode>()) {
      // The stage_id recorded in the history may be stale: later cache_read/cache_write/rfactor
      // steps insert stages, so compute the offset mapping it into the current state.
      int stage_inc = GetTargetStageIDInState(*state, s) - ps->stage_id;

      // Do not mutate the location of tiled stages.
      if (IsTiled((*state)->stages[ps->stage_id + stage_inc])) {
        continue;
      }
      // Do not mutate stages that require multi-level tiling.
      if (NeedsMultilevelTiling(policy->search_task, *state, ps->stage_id + stage_inc)) {
        continue;
      }
      compute_at_steps.push_back(s);
    }
  }
  if (compute_at_steps.empty()) {
    return PopulationGenerationRule::ResultKind::kValid;
  }

  // Randomly pick one step.
  size_t step_id = compute_at_steps[(policy->rand_gen)() % compute_at_steps.size()];
  auto ps = (*state)->transform_steps[step_id].as<ComputeAtStepNode>();
  CHECK(ps != nullptr);  // validate the cast before any dereference of `ps`
  int stage_inc = GetTargetStageIDInState(*state, step_id) - ps->stage_id;

  // Enumerate the legal compute locations and randomly pick a new one.
  std::vector<std::pair<int, int>> candidates
      = GetComputeLocationCandidates(policy->search_task, *state, ps->stage_id + stage_inc);
  if (candidates.empty()) {
    return PopulationGenerationRule::ResultKind::kInvalid;
  }

  int choice = (policy->rand_gen)() % (candidates.size());
  int new_compute_at_stage_id = candidates[choice].first;
  int new_compute_at_iter_id = candidates[choice].second;

  // Replay a new state: copy every history step, substituting the mutated compute_at step.
  State tmp_s = policy->search_task->compute_dag->init_state;
  for (size_t s = 0; s < (*state)->transform_steps.size(); ++s) {
    if (s == step_id) {
      // Translate the new stage id back into the history's (pre-insertion) numbering.
      tmp_s.CopyOnWrite()->transform_steps.push_back(
          ComputeAtStep(ps->stage_id, new_compute_at_stage_id - stage_inc, new_compute_at_iter_id));
    } else {
      tmp_s.CopyOnWrite()->transform_steps.push_back((*state)->transform_steps[s]);
    }
    try {
      StepApplyToState(tmp_s->transform_steps.back(), &tmp_s, policy->search_task->compute_dag);
    } catch (const dmlc::Error&) {
      // The mutated location cannot be applied; reject this mutation.
      return PopulationGenerationRule::ResultKind::kInvalid;
    }
  }

  *state = tmp_s;
  return PopulationGenerationRule::ResultKind::kValid;
}

PopulationGenerationRule::ResultKind MutateParallel::Apply(SketchPolicyNode* policy,
Expand Down
82 changes: 82 additions & 0 deletions src/auto_scheduler/search_policy/utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,88 @@ Array<Integer> GetSpatialSplitStepIds(const State& s, int stage_id) {
return spatial_split_step_ids;
}


// Get all legal (stage_id, iterator_index) pairs at which `stage_id` may be compute_at'ed.
// Candidates are enumerated on the stage's single consumer and, when that consumer is itself
// compute_at another stage, on that stage as well. Returns an empty vector when the stage
// has no single consumer.
std::vector<std::pair<int, int>> GetComputeLocationCandidates(
    const SearchTask& task, const State& state, int stage_id) {
  // Only a unique consumer is considered; a negative id means there is none.
  int target_stage_id = GetSingleConsumerId(task, state, stage_id);
  if (target_stage_id < 0) {
    return {};
  }
  const Stage& target_stage = state->stages[target_stage_id];

  std::vector<std::pair<int, int>> candidates;
  // Whether the consumer stage is itself attached (compute_at) inside another stage.
  bool target_compute_at_other = target_stage->compute_at == ComputeAtKind::kIter;
  bool target_is_tiled = IsTiled(target_stage);

  bool visited_reduce = false;
  // Enumerate compute_at location at target_stage
  // TODO(merrymercy): More analysis here to make smarter choices
  for (size_t i = 0; i < target_stage->iters.size(); ++i) {
    const Iterator& target_iter = target_stage->iters[i];
    if (target_iter->iter_kind == IteratorKind::kReduction) {
      visited_reduce = true;
      if (!target_is_tiled) {  // Do not go into reduce iter
        break;
      }
    } else if (target_iter->iter_kind == IteratorKind::kSpatial) {
      if (visited_reduce) {  // Do not go into inner tile
        break;
      }
    }

    if (target_iter->annotation == IteratorAnnotation::kUnroll) {
      // Do not go into the unroll region of const tensor indices
      break;
    }

    if (GetExtent(target_iter) == 1) {
      // Skip iterators with length of 1
      continue;
    }
    if (target_compute_at_other && target_iter->iter_kind == IteratorKind::kSpatial &&
        StrEndsWith(target_iter->name, ".0")) {
      // Skip the first level iterators if target stage compute_at another stage
      // In this case, the lengths of first level iterators are always one
      continue;
    }
    candidates.emplace_back(target_stage_id, i);

    // Stop once another stage is already attached at this iterator: going deeper would
    // cross that existing attachment.
    if (state->attach_map->iter_to_attached_stages.count(std::make_pair(target_stage_id, i))) {
      break;
    }
  }

  // if the target_stage is already compute_at another stage X, try also compute_at X
  // We call stage X `target_target_stage`
  if (target_compute_at_other) {
    int target_target_stage_id;
    target_target_stage_id = state->attach_map->stage_to_attach_iter.at(target_stage_id).first;
    const Stage& target_target_stage = state->stages[target_target_stage_id];

    for (size_t i = 0; i < target_target_stage->iters.size(); ++i) {
      const Iterator& target_target_iter = target_target_stage->iters[i];
      // Stop at reduction iterators and at iterators that already have stages attached.
      if (target_target_iter->iter_kind == IteratorKind::kReduction ||
          state->attach_map->iter_to_attached_stages.count(
              std::make_pair(target_target_stage_id, i))) {
        break;
      }

      if (target_target_iter->annotation == IteratorAnnotation::kUnroll) {
        // Do not go into the unroll region of const tensor indices
        break;
      }

      if (GetExtent(target_target_iter) == 1) {  // skip iterators with length of 1
        continue;
      }

      candidates.emplace_back(target_target_stage_id, i);
    }
  }

  return candidates;
}

State DoMultiLevelTiling(const State& state, int stage_id, const std::string& format,
std::vector<int>* spatial_split_step_ids) {
// Temporal object to be used if the input pointer is nullptr
Expand Down
20 changes: 20 additions & 0 deletions src/auto_scheduler/search_policy/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -536,6 +536,22 @@ inline Iterator GetLastReduceIteratorInOutermostReduceTile(const Stage& stage) {
return stage->iters[0];
}

/*! \brief Get the target stage id of a history step in the new state.
 * We need this because the stage_id recorded in the history may be stale: steps applied
 * afterwards can insert new stages and shift the numbering. */
inline int GetTargetStageIDInState(const State& s, int step_id) {
  // Start from the id recorded in the history and shift it up once for every later step
  // that inserted a new stage at or below the (already shifted) position.
  int adjusted_id = s->transform_steps[step_id]->stage_id;

  for (size_t i = step_id + 1; i < s->transform_steps.size(); ++i) {
    const auto& step = s->transform_steps[i];
    // Only these step kinds insert a new stage into the state.
    bool inserts_stage = step->IsInstance<CacheWriteStepNode>() ||
                         step->IsInstance<CacheReadStepNode>() ||
                         step->IsInstance<RfactorStepNode>();
    if (inserts_stage && step->stage_id <= adjusted_id) {
      adjusted_id++;
    }
  }
  return adjusted_id;
}

/*! \brief Get all split steps for one stage. */
inline void GetSplitStepIds(const State& s, int stage_id, std::vector<int>* split_step_ids) {
for (int i = static_cast<int>(s->transform_steps.size()) - 1; i >= 0; --i) {
Expand Down Expand Up @@ -675,6 +691,10 @@ class SplitFactorizationMemo {
/*! \brief Get the indexes of SplitStep that processes on spatial iterator. */
Array<Integer> GetSpatialSplitStepIds(const State& s, int stage_id);

/*! \brief Get the possible compute locations for a stage. */
std::vector<std::pair<int, int>> GetComputeLocationCandidates(
const SearchTask& task, const State& state, int stage_id);

// Apply multi-level tiling structure according to a string format,
// where "S" stands a space level, "R" stands for a reduction level.
// For example, if the format is "SSRSRS", then we will
Expand Down