From bfed48f004d21aba039feecff9eb38af4ba23793 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Mon, 25 Sep 2023 17:29:50 -0700 Subject: [PATCH 1/2] Revert "[TensorIR][Visitor] Visit buffer members in `match_buffer`'s in block visitor functions (#15153)" This reverts commit 34637d7ee38f2636b1948548a39c15838d7a8db6. --- src/tir/ir/stmt_functor.cc | 32 ++------------ .../unittest/test_tir_lower_match_buffer.py | 11 ++++- ...test_tir_transform_unify_thread_binding.py | 43 ------------------- 3 files changed, 13 insertions(+), 73 deletions(-) diff --git a/src/tir/ir/stmt_functor.cc b/src/tir/ir/stmt_functor.cc index 7d1fe9f8dd7c..1c15f9582686 100644 --- a/src/tir/ir/stmt_functor.cc +++ b/src/tir/ir/stmt_functor.cc @@ -135,15 +135,8 @@ void StmtVisitor::VisitStmt_(const BlockNode* op) { VisitArray(op->reads, fvisit_buffer_region); VisitArray(op->writes, fvisit_buffer_region); VisitArray(op->match_buffers, - [this, fvisit_buffer_region](const MatchBufferRegion& match_buffer_region) { + [fvisit_buffer_region](const MatchBufferRegion& match_buffer_region) { fvisit_buffer_region(match_buffer_region->source); - this->VisitExpr(match_buffer_region->buffer->elem_offset); - VisitArray(match_buffer_region->buffer->strides, - [this](const PrimExpr& e) { this->VisitExpr(e); }); - VisitArray(match_buffer_region->buffer->shape, - [this](const PrimExpr& e) { this->VisitExpr(e); }); - VisitArray(match_buffer_region->buffer->axis_separators, - [this](const IntImm& e) { this->VisitExpr(e); }); }); if (op->init.defined()) { this->VisitStmt(op->init.value()); @@ -245,28 +238,11 @@ class StmtMutator::Internal { static Array Mutate(StmtMutator* self, const Array& arr) { auto fmutate = [self](const MatchBufferRegion& match_buffer_region) { - const Buffer& buffer = match_buffer_region->buffer; Array region = Mutate(self, match_buffer_region->source->region); - PrimExpr elem_offset = self->VisitExpr(buffer->elem_offset); - Array strides = Mutate(self, buffer->strides); - Array shape = Mutate(self, buffer->shape); - Array axis_separators = - MutateArray(self, buffer->axis_separators, - [self](const IntImm& e) { return Downcast(self->VisitExpr(e)); }); - - if (elem_offset.same_as(buffer->elem_offset) && strides.same_as(buffer->strides) && - shape.same_as(buffer->shape) && axis_separators.same_as(buffer->axis_separators)) { - if (region.same_as(match_buffer_region->source->region)) { - return match_buffer_region; - } else { - return MatchBufferRegion(buffer, - BufferRegion(match_buffer_region->source->buffer, region)); - } + if (region.same_as(match_buffer_region->source->region)) { + return match_buffer_region; } else { - Buffer new_buffer(buffer->data, buffer->dtype, shape, strides, elem_offset, buffer->name, - buffer->data_alignment, buffer->offset_factor, buffer->buffer_type, - axis_separators, buffer->span); - return MatchBufferRegion(new_buffer, + return MatchBufferRegion(match_buffer_region->buffer, BufferRegion(match_buffer_region->source->buffer, region)); } }; diff --git a/tests/python/unittest/test_tir_lower_match_buffer.py b/tests/python/unittest/test_tir_lower_match_buffer.py index 7dc164496501..cc5f2c2b8dbf 100644 --- a/tests/python/unittest/test_tir_lower_match_buffer.py +++ b/tests/python/unittest/test_tir_lower_match_buffer.py @@ -18,7 +18,6 @@ import pytest import tvm -import tvm.testing from tvm.script import tir as T @@ -531,4 +530,12 @@ def test_fail_match_func_param(): if __name__ == "__main__": - tvm.testing.main() + test_buffer_load_store() + test_opaque_access() + test_high_dim_opaque_access() + test_recursive_match() + test_symbolic_match() + test_rank0_buffer() + test_fail_load_store() + test_fail_buffer_bind() + test_fail_match_func_param() diff --git a/tests/python/unittest/test_tir_transform_unify_thread_binding.py b/tests/python/unittest/test_tir_transform_unify_thread_binding.py index d42adfcee4bb..9ee86433128d 100644 --- a/tests/python/unittest/test_tir_transform_unify_thread_binding.py +++ b/tests/python/unittest/test_tir_transform_unify_thread_binding.py @@ -258,45 +258,6 @@ def unified_element_wise_implicit_block(a: T.handle, b: T.handle, c: T.handle) - ) -@T.prim_func -def match_buffer_with_elem_offset( - A: T.Buffer((8, 10, 8), "float32"), I: T.Buffer((4,), "int32"), offset: T.int32 -) -> None: - for i in T.thread_binding(0, 4, "blockIdx.x"): - for j in range(2): - with T.block(): - T.writes(A[I[i], offset, j * 4 : j * 4 + 4]) - sub_A = T.match_buffer( - A[I[i], offset, j * 4 : j * 4 + 4], - (4), - elem_offset=I[i] * 80 + offset * 8 + j * 4, - ) - for ji in range(0, 4): - sub_A[j * 4 + ji] = 1 - - -@T.prim_func -def unified_match_buffer_with_elem_offset( - A: T.Buffer((8, 10, 8), "float32"), I: T.Buffer((4,), "int32"), offset: T.int32 -) -> None: - for blockIdx_x in T.thread_binding(4, thread="blockIdx.x"): - for j in range(2): - with T.block(""): - T.reads(I[blockIdx_x]) - T.writes(A[I[blockIdx_x], offset, j * 4 : j * 4 + 4]) - sub_A = T.match_buffer( - A[I[blockIdx_x], offset, j * 4 : j * 4 + 4], - (4,), - elem_offset=I[blockIdx_x] * 80 + offset * 8 + j * 4, - ) - for ji in range(4): - i = T.int32() - sub_A_1 = T.Buffer( - (4,), data=sub_A.data, elem_offset=I[i] * 80 + offset * 8 + j * 4 - ) - sub_A_1[j * 4 + ji] = T.float32(1) - - def test_thread_x(): _check(element_wise_thread_x, unified_element_wise_thread_x) @@ -327,10 +288,6 @@ def test_implicit_block(): _check(element_wise_implicit_block, unified_element_wise_implicit_block) -def test_match_buffer_with_elem_offset(): - _check(match_buffer_with_elem_offset, unified_match_buffer_with_elem_offset) - - def test_inner_binding_with_annotation(): @T.prim_func def inner_binding_with_annotation(A: T.Buffer((64,), "float32"), B: T.Buffer((64,), "float32")): From e36730cf2187b17fef175e625f4f4ad6083398b3 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Mon, 25 Sep 2023 17:37:49 -0700 Subject: [PATCH 2/2] update --- tests/python/unittest/test_tir_lower_match_buffer.py | 11 ++--------- 1 file changed, 2 insertions(+), 9 deletions(-) diff --git a/tests/python/unittest/test_tir_lower_match_buffer.py b/tests/python/unittest/test_tir_lower_match_buffer.py index cc5f2c2b8dbf..7dc164496501 100644 --- a/tests/python/unittest/test_tir_lower_match_buffer.py +++ b/tests/python/unittest/test_tir_lower_match_buffer.py @@ -18,6 +18,7 @@ import pytest import tvm +import tvm.testing from tvm.script import tir as T @@ -530,12 +531,4 @@ def test_fail_match_func_param(): if __name__ == "__main__": - test_buffer_load_store() - test_opaque_access() - test_high_dim_opaque_access() - test_recursive_match() - test_symbolic_match() - test_rank0_buffer() - test_fail_load_store() - test_fail_buffer_bind() - test_fail_match_func_param() + tvm.testing.main()