Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions src/tir/schedule/primitive/cache_read_write.cc
Original file line number Diff line number Diff line change
Expand Up @@ -958,9 +958,10 @@ class CacheReadRewriter : public StmtExprMutator {
// Otherwise, update read regions and match_buffers
// Only make this change if the block is one of the specified consumers.
if (is_consumer) {
Array<BufferRegion> reads = update_access_regions(block->reads);
Array<MatchBufferRegion> match_buffers = update_match_buffers(block->match_buffers);
if (!reads.same_as(block->reads) || !match_buffers.same_as(block->match_buffers)) {
// Use the updated block stmt
Array<BufferRegion> reads = update_access_regions(stmt->reads);
Array<MatchBufferRegion> match_buffers = update_match_buffers(stmt->match_buffers);
if (!reads.same_as(stmt->reads) || !match_buffers.same_as(stmt->match_buffers)) {
ObjectPtr<BlockNode> n = make_object<BlockNode>(*stmt.as<BlockNode>());
n->reads = std::move(reads);
n->match_buffers = std::move(match_buffers);
Expand Down
41 changes: 41 additions & 0 deletions tests/python/tir-schedule/test_tir_schedule_cache_read_write.py
Original file line number Diff line number Diff line change
Expand Up @@ -488,6 +488,19 @@ def cache_read_nested_seq_target(
C[vi, vj] = A_global[vi, vj] * T.float32(2)


@T.prim_func
def nested_buffer_access(var_A: T.handle, var_B: T.handle, var_C: T.handle):
    # Input workload for the nested-buffer-access cache_read test.
    # Block "C" gathers a row of A whose row index is itself loaded from B,
    # so B shows up twice in the read set: nested inside the access to A and
    # as a read region of its own.
    A = T.match_buffer(var_A, (T.int64(7), T.int64(512)), dtype="float32")
    B = T.match_buffer(var_B, T.int64(1), dtype="int32")
    C = T.match_buffer(var_C, (T.int64(1), T.int64(512)), dtype="float32")
    for ax0, ax1 in T.grid(T.int64(1), T.int64(512)):
        with T.block("C"):
            v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
            # Read index 0 is A (indexed through B); read index 1 is B itself.
            T.reads(A[B[v_ax0], v_ax1], B[v_ax0])
            T.writes(C[v_ax0, v_ax1])
            C[v_ax0, v_ax1] = A[B[v_ax0], v_ax1]


########## Expected function after cache_read ##########


Expand Down Expand Up @@ -831,6 +844,26 @@ def cache_inplace_buffer(data_io: T.Buffer(64, "int32")) -> None:
data_io[v0] = data_io_global_1[v0]


@T.prim_func
def cache_read_nested_buffer_access(var_A: T.handle, var_B: T.handle, var_C: T.handle):
    # Expected IR after cache_read of B (read index 1 of block "C") into
    # "global" scope: a copy block fills B_global, and every use of B in
    # block "C" is redirected to B_global.
    A = T.match_buffer(var_A, (T.int64(7), T.int64(512)), dtype="float32")
    B = T.match_buffer(var_B, T.int64(1), dtype="int32")
    C = T.match_buffer(var_C, (T.int64(1), T.int64(512)), dtype="float32")
    B_global = T.alloc_buffer((T.int64(1),), "int32")
    # Cache stage: copy B into B_global element by element.
    for ax0 in range(T.int64(1)):
        with T.block("B_global"):
            v0 = T.axis.spatial(T.int64(1), ax0)
            T.reads(B[v0])
            T.writes(B_global[v0])
            B_global[v0] = B[v0]
    for ax0, ax1 in T.grid(T.int64(1), T.int64(512)):
        with T.block("C"):
            v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
            # Both the standalone read region and the index nested inside
            # the access to A must refer to B_global, not B.
            T.reads(A[B_global[v_ax0], v_ax1], B_global[v_ax0])
            T.writes(C[v_ax0, v_ax1])
            C[v_ax0, v_ax1] = A[B_global[v_ax0], v_ax1]


########## Expected function after cache_write ##########


Expand Down Expand Up @@ -1358,6 +1391,14 @@ def test_cache_read_non_int32_shape(use_block_name):
verify_trace_roundtrip(sch=sch, mod=elementwise_shape_int64)


def test_cache_read_nested_buffer_access(use_block_name):
    """Check cache_read on a buffer that is also used as an index into another buffer.

    Applies ``cache_read`` to read index 1 (buffer ``B``) of block "C" in
    ``nested_buffer_access`` and compares the scheduled module against
    ``cache_read_nested_buffer_access``, then verifies the schedule trace
    round-trips.

    ``use_block_name`` selects whether the block is passed by its name or as
    the object returned by ``sch.get_block`` (presumably a pytest
    fixture/parameter defined elsewhere in this test file).
    """
    sch = tir.Schedule(nested_buffer_access, debug_mask="all")
    block_c = "C" if use_block_name else sch.get_block("C")
    # Read buffer index 1 of block "C" is B; cache it in "global" scope.
    sch.cache_read(block_c, 1, "global")
    assert_structural_equal_ignore_global_symbol(cache_read_nested_buffer_access, sch.mod["main"])
    verify_trace_roundtrip(sch=sch, mod=nested_buffer_access)


def test_cache_read_fail_multi_producer(use_block_name):
sch = tir.Schedule(func_multi_producer, debug_mask="all")
block_b = "B" if use_block_name else sch.get_block("B")
Expand Down