Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 4 additions & 28 deletions src/tir/ir/stmt_functor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -135,15 +135,8 @@ void StmtVisitor::VisitStmt_(const BlockNode* op) {
VisitArray(op->reads, fvisit_buffer_region);
VisitArray(op->writes, fvisit_buffer_region);
VisitArray(op->match_buffers,
[this, fvisit_buffer_region](const MatchBufferRegion& match_buffer_region) {
[fvisit_buffer_region](const MatchBufferRegion& match_buffer_region) {
fvisit_buffer_region(match_buffer_region->source);
this->VisitExpr(match_buffer_region->buffer->elem_offset);
VisitArray(match_buffer_region->buffer->strides,
[this](const PrimExpr& e) { this->VisitExpr(e); });
VisitArray(match_buffer_region->buffer->shape,
[this](const PrimExpr& e) { this->VisitExpr(e); });
VisitArray(match_buffer_region->buffer->axis_separators,
[this](const IntImm& e) { this->VisitExpr(e); });
});
if (op->init.defined()) {
this->VisitStmt(op->init.value());
Expand Down Expand Up @@ -245,28 +238,11 @@ class StmtMutator::Internal {

static Array<MatchBufferRegion> Mutate(StmtMutator* self, const Array<MatchBufferRegion>& arr) {
auto fmutate = [self](const MatchBufferRegion& match_buffer_region) {
const Buffer& buffer = match_buffer_region->buffer;
Array<Range> region = Mutate(self, match_buffer_region->source->region);
PrimExpr elem_offset = self->VisitExpr(buffer->elem_offset);
Array<PrimExpr> strides = Mutate(self, buffer->strides);
Array<PrimExpr> shape = Mutate(self, buffer->shape);
Array<IntImm> axis_separators =
MutateArray(self, buffer->axis_separators,
[self](const IntImm& e) { return Downcast<IntImm>(self->VisitExpr(e)); });

if (elem_offset.same_as(buffer->elem_offset) && strides.same_as(buffer->strides) &&
shape.same_as(buffer->shape) && axis_separators.same_as(buffer->axis_separators)) {
if (region.same_as(match_buffer_region->source->region)) {
return match_buffer_region;
} else {
return MatchBufferRegion(buffer,
BufferRegion(match_buffer_region->source->buffer, region));
}
if (region.same_as(match_buffer_region->source->region)) {
return match_buffer_region;
} else {
Buffer new_buffer(buffer->data, buffer->dtype, shape, strides, elem_offset, buffer->name,
buffer->data_alignment, buffer->offset_factor, buffer->buffer_type,
axis_separators, buffer->span);
return MatchBufferRegion(new_buffer,
return MatchBufferRegion(match_buffer_region->buffer,
BufferRegion(match_buffer_region->source->buffer, region));
}
};
Expand Down
43 changes: 0 additions & 43 deletions tests/python/unittest/test_tir_transform_unify_thread_binding.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,45 +258,6 @@ def unified_element_wise_implicit_block(a: T.handle, b: T.handle, c: T.handle) -
)


@T.prim_func
def match_buffer_with_elem_offset(
A: T.Buffer((8, 10, 8), "float32"), I: T.Buffer((4,), "int32"), offset: T.int32
) -> None:
for i in T.thread_binding(0, 4, "blockIdx.x"):
for j in range(2):
with T.block():
T.writes(A[I[i], offset, j * 4 : j * 4 + 4])
sub_A = T.match_buffer(
A[I[i], offset, j * 4 : j * 4 + 4],
(4),
elem_offset=I[i] * 80 + offset * 8 + j * 4,
)
for ji in range(0, 4):
sub_A[j * 4 + ji] = 1


@T.prim_func
def unified_match_buffer_with_elem_offset(
A: T.Buffer((8, 10, 8), "float32"), I: T.Buffer((4,), "int32"), offset: T.int32
) -> None:
for blockIdx_x in T.thread_binding(4, thread="blockIdx.x"):
for j in range(2):
with T.block(""):
T.reads(I[blockIdx_x])
T.writes(A[I[blockIdx_x], offset, j * 4 : j * 4 + 4])
sub_A = T.match_buffer(
A[I[blockIdx_x], offset, j * 4 : j * 4 + 4],
(4,),
elem_offset=I[blockIdx_x] * 80 + offset * 8 + j * 4,
)
for ji in range(4):
i = T.int32()
sub_A_1 = T.Buffer(
(4,), data=sub_A.data, elem_offset=I[i] * 80 + offset * 8 + j * 4
)
sub_A_1[j * 4 + ji] = T.float32(1)


def test_thread_x():
_check(element_wise_thread_x, unified_element_wise_thread_x)

Expand Down Expand Up @@ -327,10 +288,6 @@ def test_implicit_block():
_check(element_wise_implicit_block, unified_element_wise_implicit_block)


def test_match_buffer_with_elem_offset():
_check(match_buffer_with_elem_offset, unified_match_buffer_with_elem_offset)


def test_inner_binding_with_annotation():
@T.prim_func
def inner_binding_with_annotation(A: T.Buffer((64,), "float32"), B: T.Buffer((64,), "float32")):
Expand Down