From 9e9b3ac834e7ec2d266d03637e673100d44c20e3 Mon Sep 17 00:00:00 2001 From: wrongtest Date: Mon, 4 Sep 2023 18:52:38 +0800 Subject: [PATCH 1/2] fix detect non-divisible iteration form like (x % 255) // 16 --- src/arith/iter_affine_map.cc | 16 ++++++++++++---- .../unittest/test_arith_iter_affine_map.py | 2 ++ 2 files changed, 14 insertions(+), 4 deletions(-) diff --git a/src/arith/iter_affine_map.cc b/src/arith/iter_affine_map.cc index 89a803d058e4..af46b2d86564 100644 --- a/src/arith/iter_affine_map.cc +++ b/src/arith/iter_affine_map.cc @@ -1897,10 +1897,18 @@ PrimExpr IterMapRewriter::SplitFloorDivConst(IterSplitExpr lhs, PrimExpr base, P // = floormod(sc2+t, c2) // = floormod(floordiv(y, c1), c2) // = floormod(floordiv(iter, lower_factor*c1), c2), where c1=rhs, c2=extent/rhs - IterSplitExpr new_split(padded->source, - /* lower_factor = */ padded->lower_factor * rhs, - /* extent = */ analyzer_->Simplify(floordiv(padded->extent, rhs)), - /* scale = */ padded->scale); + IterSplitExpr new_split; + if (CanProveDivisible(padded->extent, rhs)) { + new_split = IterSplitExpr(padded->source, + /* lower_factor = */ padded->lower_factor * rhs, + /* extent = */ analyzer_->Simplify(floordiv(padded->extent, rhs)), + /* scale = */ padded->scale); + } else { + new_split = IterSplitExpr(IterMark(padded, padded->extent), + /* lower_factor = */ rhs, + /* extent = */ analyzer_->Simplify(ceildiv(padded->extent, rhs)), + /* scale = */ make_const(rhs->dtype, 1)); + } auto new_base = analyzer_->Simplify(floordiv(base - left_pad, rhs), 6); if (is_zero(new_base)) { diff --git a/tests/python/unittest/test_arith_iter_affine_map.py b/tests/python/unittest/test_arith_iter_affine_map.py index 63bb79d2b223..903c953f3d13 100644 --- a/tests/python/unittest/test_arith_iter_affine_map.py +++ b/tests/python/unittest/test_arith_iter_affine_map.py @@ -1051,6 +1051,8 @@ class TestPadding: # original extent is smaller than the divident # it is not surjective wrt to the region [0, 16) ({x: 3}, {flm(x, 
16)}), + # (x % c1) // c2 is not proved as surjective if c1 % c2 != 0 + ({x: 255}, {fld(flm(x, 255), 16)}), ) def test_padding(self, positive_test_case): From 6c47e0e3c1a3e2f721a7b55a0c84016a9959a98a Mon Sep 17 00:00:00 2001 From: wrongtest Date: Thu, 7 Sep 2023 19:02:05 +0800 Subject: [PATCH 2/2] add required rule to prove divisibility of dynamic shape --- src/arith/rewrite_simplify.cc | 10 ++++++++++ .../unittest/test_arith_rewrite_simplify.py | 17 +++++++++++++++++ .../test_tir_transform_compact_buffer_region.py | 2 +- 3 files changed, 28 insertions(+), 1 deletion(-) diff --git a/src/arith/rewrite_simplify.cc b/src/arith/rewrite_simplify.cc index 63becf8eb77f..d5f946fca02a 100644 --- a/src/arith/rewrite_simplify.cc +++ b/src/arith/rewrite_simplify.cc @@ -1053,6 +1053,16 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorModNode* op) { TVM_TRY_REWRITE(matches_one_of(floormod(x * y, y), floormod(y * x, y)), ZeroWithTypeLike(y)); + // x = ay + b, then (ay + b + (ny - ay - b) % y) % y -> (b + (-b) % y) % y -> 0 + TVM_TRY_REWRITE_IF( + matches_one_of(floormod(x + floormod(z, y), y), floormod(floormod(z, y) + x, y)), + ZeroWithTypeLike(x), CanProveEqual(floormod(x.Eval() + z.Eval(), y.Eval()), 0)); + // x = ay + b, then (ay + b - (ay + b) % +-y) % y -> (b - b % +-y) % y -> 0 + TVM_TRY_REWRITE_IF( + matches_one_of(floormod(x - floormod(x, z), y), floormod(floormod(x, z) - x, y)), + ZeroWithTypeLike(x), + CanProveEqual(y.Eval() - z.Eval(), 0) || CanProveEqual(y.Eval() + z.Eval(), 0)); + if (floormod(x, c1).Match(ret)) { int64_t c1val = c1.Eval()->value; if (c1val > 0) { diff --git a/tests/python/unittest/test_arith_rewrite_simplify.py b/tests/python/unittest/test_arith_rewrite_simplify.py index 0b0a43a7d3d3..5b0627542204 100644 --- a/tests/python/unittest/test_arith_rewrite_simplify.py +++ b/tests/python/unittest/test_arith_rewrite_simplify.py @@ -605,6 +605,23 @@ class TestFloorModTwo(BaseCompare): ) +class TestFloorModPadded(BaseCompare): + """Special-case 
simplifications used in divisibility proofs, + e.g. (x - x % k) must be divisible by k + """ + + x, y = te.var("x"), te.var("y") + test_case = tvm.testing.parameter( + TestCase(flm(x - flm(x, 9), 9), 0), + TestCase(flm(x - flm(x, -9), 9), 0), + TestCase(flm(x + flm(-x, 9), 9), 0), + TestCase(flm(x + flm(8 * x, 9), 9), 0), + TestCase(flm(x - flm(x, y), y), 0), + TestCase(flm(x - flm(x, -y), y), 0), + TestCase(flm(x + flm(-x, y), y), 0), + ) + + class TestMinIndex(BaseCompare): x, y, z = te.var("x"), te.var("y"), te.var("z") test_case = tvm.testing.parameter( diff --git a/tests/python/unittest/test_tir_transform_compact_buffer_region.py b/tests/python/unittest/test_tir_transform_compact_buffer_region.py index d268403c1be4..d5d5e0634ef6 100644 --- a/tests/python/unittest/test_tir_transform_compact_buffer_region.py +++ b/tests/python/unittest/test_tir_transform_compact_buffer_region.py @@ -999,7 +999,7 @@ def expected( ) -> None: for i0, i1 in T.grid(4, 1): with T.block(): - C_local2 = T.alloc_buffer([1, 1, 15, 1000, 16], dtype="float32", scope="local") + C_local2 = T.alloc_buffer([1, 1, 16, 1000, 16], dtype="float32", scope="local") C_local1 = T.alloc_buffer([255, 1000], dtype="float32", scope="local") for ax0, ax1, ax2 in T.grid(255, 1000, 64): with T.block("matmul"):