Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
rebase upstream
  • Loading branch information
wrongtest-intellif committed May 31, 2022
commit 4d1239a0d5332cec1df577f2598fcea1d5459f42
7 changes: 4 additions & 3 deletions src/tir/schedule/primitive/layout_transformation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -392,8 +392,9 @@ void TransformBlockLayout(ScheduleState self, const StmtSRef& block_sref,

auto iter_map = arith::DetectIterMap(
/*indices=*/transformed_block_iters, /*input_iters=*/block_iter_dom, /*predicate=*/Bool(true),
/*require_bijective=*/true, &analyzer, /*simplify_trivial_iterators=*/true);
if (iter_map.empty()) {
/*check_level=*/arith::IterMapLevel::Bijective, &analyzer,
/*simplify_trivial_iterators=*/true);
if (iter_map->indices.empty()) {
throw NotBijectiveAffineIndexMapError(self->mod, index_map);
}

Expand All @@ -417,7 +418,7 @@ void TransformBlockLayout(ScheduleState self, const StmtSRef& block_sref,
// Step 5.2: Update the block body. Use the inverse map f^{-1} to replace the original block iters
// in the body.

auto inverse_map = arith::InverseAffineIterMap(iter_map, new_block_vars);
auto inverse_map = arith::InverseAffineIterMap(iter_map->indices, new_block_vars);
// Trivial block iters will be simplified in DetectIterMap, they should be mapped to constant
// zero.
for (const auto& iter_var : block_ptr->iter_vars) {
Expand Down
11 changes: 4 additions & 7 deletions tests/python/unittest/test_tir_schedule_compute_at.py
Original file line number Diff line number Diff line change
Expand Up @@ -1281,13 +1281,10 @@ def grouped_channel_bias_non_perfect_tiled(
cc = T.axis.spatial(720, c_o * 360 + c_i)
Y[cc, hh, ww] = X[cc, hh, ww] + B[cc // 16]

def check_sched(debug_mask):
sch = tir.Schedule(grouped_channel_bias, debug_mask=debug_mask)
loop = sch.get_loops(sch.get_block("compute"))[0]
sch.compute_at(sch.get_block("init"), loop)
tvm.ir.assert_structural_equal(sch.mod["main"], grouped_channel_bias_non_perfect_tiled)

check_sched("all")
sch = tir.Schedule(grouped_channel_bias, debug_mask="all")
loop = sch.get_loops(sch.get_block("compute"))[0]
sch.compute_at(sch.get_block("init"), loop)
tvm.ir.assert_structural_equal(sch.mod["main"], grouped_channel_bias_non_perfect_tiled)


def test_fail_subtree_complete_block():
Expand Down