From 13588f0ac709165467472f261f351bd86a17dfe1 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Tue, 10 Mar 2020 08:17:59 -0700 Subject: [PATCH] Revert "Tighten split's extent (#4931)" This reverts commit 585f9ce6e7bef7d0e8902b1c1e55dcb3bbe84eed. --- src/te/schedule/message_passing.cc | 76 +------------------ .../unittest/test_schedule_bound_inference.py | 26 ------- 2 files changed, 3 insertions(+), 99 deletions(-) diff --git a/src/te/schedule/message_passing.cc b/src/te/schedule/message_passing.cc index a7b248285c4d..5b6fa861895a 100644 --- a/src/te/schedule/message_passing.cc +++ b/src/te/schedule/message_passing.cc @@ -51,66 +51,17 @@ void Update(std::unordered_map* p_state, } } -/*! - * \param Upward propagating whether an IterVar derives at least one leaf IterVar that binds to - * a thread. - * - * \param stage The stage to operate on. - * \param p_state The propagation result of each IterVar. - */ -void PassUpThreadBinding(const Stage& stage, std::unordered_map* p_state) { - auto bound_to_thread = [&stage](const IterVar& iv) { - bool bound = false; - auto it = stage->iter_var_attrs.find(iv); - if (it != stage->iter_var_attrs.end()) { - bound = (*it).second->bind_thread.defined(); - } - return bound; - }; - - auto& state = *p_state; - // Fill p_state with leaf itervars - for (const IterVar& iv : stage->leaf_iter_vars) { - state[iv] = bound_to_thread(iv); - } - // Traverse the graph bottom-up to propagate thread binding information - for (size_t i = stage->relations.size(); i != 0; --i) { - IterVarRelation rel = stage->relations[i - 1]; - if (const SplitNode* s = rel.as()) { - state[s->parent] = state[s->inner] || state[s->outer]; - } else if (const FuseNode* s = rel.as()) { - state[s->inner] = state[s->fused]; - state[s->outer] = state[s->fused]; - } else if (const RebaseNode* s = rel.as()) { - state[s->parent] = state[s->rebased]; - } else if (rel.as()) { - } else { - LOG(FATAL) << "unknown relation type"; - } - } -} - void PassDownDomain(const Stage& stage, std::unordered_map* p_state, arith::Analyzer* actx, bool allow_missing) { - auto ceil_div = [actx](const PrimExpr& a, const PrimExpr& b) { + auto ceil_div = [actx](PrimExpr a, PrimExpr b) { if (actx->CanProve(indexmod(a, b) == 0)) { return actx->Simplify(indexdiv(a, b)); } return actx->Simplify(indexdiv(a + (b - 1), b)); }; - auto minimum_or_later = [actx](const PrimExpr& a, const PrimExpr& b) { - if (actx->CanProve(a < b)) { - return actx->Simplify(a); - } - return actx->Simplify(b); - }; - - std::unordered_map dominating_thread; - PassUpThreadBinding(stage, &dominating_thread); - auto& state = *p_state; // forwar iteration on relations for (IterVarRelation rel : stage->relations) { @@ -121,35 +72,14 @@ void PassDownDomain(const Stage& stage, } CHECK(!state.count(r->inner)); const Range& range_parent = state.at(r->parent); - // Tighten iv's extent to min(parent_extent, factor_or_nparts), only if all of the - // following conditions are met: - // 1. No leaf IterVar derived from iv binds to any thread. People may use split - // to force an IterVar extent to match the number of allocated threads to fuse stages - // that require different number of threads. We don't want to change these extents. - // 2. allow_missing is false, i.e. that PassDownDomain is called by the final InferBound, - // rather than by an early compiler phase, such as rfactor(). We don't want to tighten an - // IterVar in an early phase allowing missing IterVars, because it may bind to a thread later. - // 3. range_parent's extent is not 0. At lest one Topi test has a case where a tensor has one - // zero-sized dimension. Split creates iv with a positive extent to avoid zero-extent - // IterVar. We don't touch it. - auto resolve_min_extent_for_split = [&](const IterVar& iv, const PrimExpr& factor_or_nparts) { - return dominating_thread[iv] || allow_missing || is_zero(range_parent->extent) - ? factor_or_nparts - : minimum_or_later(range_parent->extent, factor_or_nparts); - }; if (r->factor.defined()) { Update(p_state, r->inner, - Range::make_by_min_extent( - 0, resolve_min_extent_for_split(r->inner, r->factor)), - actx); + Range::make_by_min_extent(0, r->factor), actx); Update(p_state, r->outer, Range::make_by_min_extent( 0, ceil_div(range_parent->extent, r->factor)), actx); } else { - Update(p_state, r->outer, - Range::make_by_min_extent( - 0, resolve_min_extent_for_split(r->outer, r->nparts)), - actx); + Update(p_state, r->outer, Range::make_by_min_extent(0, r->nparts), actx); Update(p_state, r->inner, Range::make_by_min_extent( 0, ceil_div(range_parent->extent, r->nparts)), actx); diff --git a/tests/python/unittest/test_schedule_bound_inference.py b/tests/python/unittest/test_schedule_bound_inference.py index edae527c0183..484aa503e066 100644 --- a/tests/python/unittest/test_schedule_bound_inference.py +++ b/tests/python/unittest/test_schedule_bound_inference.py @@ -70,32 +70,6 @@ def test_bound3(): assert(bounds[A1.op.axis[0]].extent.value==32) assert(bounds[A1.op.axis[1]].extent.value==16) -def test_bound_split_ext_less_than_factor(): - m = 8 - I = te.placeholder((m,), name='I') - EF = te.compute((m,), lambda i: I[i] * 2, name = "EF") - E = te.compute((m,), lambda i: EF[i] * 2, name = "E") - s = te.create_schedule([E.op]) - xo, xi = s[E].split(s[E].op.axis[0], factor = 32) - s[EF].compute_at(s[E], xo) - - bounds = tvm.te.schedule.InferBound(s) - assert isinstance(bounds, tvm.container.Map) - assert bounds[xi].extent.value == m - -def test_bound_split_ext_less_than_naprts(): - m = 8 - I = te.placeholder((m,), name='I') - EF = te.compute((m,), lambda i: I[i] * 2, name = "EF") - E = te.compute((m,), lambda i: EF[i] * 2, name = "E") - s = te.create_schedule([E.op]) - xo, xi = s[E].split(s[E].op.axis[0], nparts = 32) - s[EF].compute_at(s[E], xo) - - bounds = tvm.te.schedule.InferBound(s) - assert isinstance(bounds, tvm.container.Map) - assert bounds[xo].extent.value == m - def test_bound_split_divisible(): m = te.var('m') l = te.var('l')