Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions src/arith/int_set.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1053,11 +1053,17 @@ Map<Var, arith::IntSet> AsIntSet(const Map<Var, Range>& var_dom) {
/*! \brief Helper function to convert IterSumExpr to the actual touched range. */
static Optional<IntSet> EvalIterSum(const IterSumExpr& iter_min, const PrimExpr& extent,
Analyzer* analyzer) {
if (analyzer->CanProve(extent == 0)) {
return IntSet::Nothing();
}
if (iter_min->args.empty()) {
return IntSet::FromMinExtent(iter_min->base, extent);
}
ICHECK_EQ(iter_min->args.size(), 1) << "The `EvalIterSum` expects fused iter sum expr";
const IterSplitExpr& split = iter_min->args[0];
if (analyzer->CanProve(split->extent == 0)) {
return IntSet::Nothing();
}
if (!analyzer->CanProve(extent >= split->scale)) {
return NullOpt;
}
Expand Down
3 changes: 3 additions & 0 deletions src/tir/schedule/state.cc
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,9 @@ bool ProducerCoversConsumer(const Array<PrimExpr>& buffer_shape,
if (produced_region[i].IsNothing()) {
return false;
}
if (consumed_region[i].IsNothing()) {
continue;
}
arith::IntSet produced =
arith::IntSet::Interval(analyzer->canonical_simplify(produced_region[i].min()),
analyzer->canonical_simplify(produced_region[i].max()));
Expand Down
28 changes: 28 additions & 0 deletions tests/python/unittest/test_tir_schedule_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -417,5 +417,33 @@ def two_elementwise(a: T.handle, c: T.handle) -> None:
assert is_output_block(sch, block_rv)


def test_empty_grid():
@T.prim_func
def foo(out: T.Buffer((T.int64(1), T.int64(8), T.int64(8)), "int32")):
act = T.alloc_buffer((1, 8, 8), "int32")
for z2, y2, x2 in T.grid(1, 8, 8):
with T.block("b0"):
az, ay, ax = T.axis.remap("SSS", [z2, y2, x2])
T.writes(act[az, ay, ax])
act[az, ay, az] = T.int32(0)
# Empty grid:
for z1, y1, x1 in T.grid(0, 8, 8):
with T.block("b1"):
az, ay, ax = T.axis.remap("SSS", [z1, y1, x1])
T.reads(act[az + 1, ay, ax])
T.writes(out[az, ay, ax])
out[az, ay, ax] = act[az + 1, ay, ax]
# The block below is not needed to show the bug, but the 'out'
# buffer would be undefined without it.
for z2, y2, x2 in T.grid(1, 8, 8):
with T.block("b2"):
az, ay, ax = T.axis.remap("SSS", [z2, y2, x2])
T.writes(out[az, ay, ax])
out[az, ay, az] = T.int32(0)

# This caused a crash before.
sch = tvm.tir.Schedule(foo)


if __name__ == "__main__":
tvm.testing.main()