Patch summary (free text before the first "diff --git" header):

  * src/tir/schedule/primitive/cache_read_write.cc
    In CacheReadRewriter, the consumer-block branch now feeds `stmt->reads`
    and `stmt->match_buffers` (the block obtained after mutation) into
    update_access_regions / update_match_buffers, instead of the fields of
    the original `block`.

  * tests/python/tir-schedule/test_tir_schedule_cache_read_write.py
    Adds the `nested_buffer_access` input, the expected
    `cache_read_nested_buffer_access` output, and a test that calls
    cache_read on read-buffer index 1 (B) of block "C", where B is also used
    as an index into A (`A[B[v_ax0], v_ax1]`). The expected output reads
    `A[B_global[v_ax0], v_ax1]` and `B_global[v_ax0]`.

NOTE(review): this patch was recovered from a copy whose line breaks, leading
whitespace and C++ template arguments had been stripped. The template
arguments below follow the BlockNode fields being assigned (`reads`,
`match_buffers`); the indentation is reconstructed. Confirm both against the
upstream tree, or apply with whitespace-insensitive matching.

diff --git a/src/tir/schedule/primitive/cache_read_write.cc b/src/tir/schedule/primitive/cache_read_write.cc
index a687624bacd4..eac5500a19b3 100644
--- a/src/tir/schedule/primitive/cache_read_write.cc
+++ b/src/tir/schedule/primitive/cache_read_write.cc
@@ -958,9 +958,10 @@ class CacheReadRewriter : public StmtExprMutator {
       // Otherwise, update read regions and match_buffers
       // Only make this change if the block is one of the specified consumers.
       if (is_consumer) {
-        Array<BufferRegion> reads = update_access_regions(block->reads);
-        Array<MatchBufferRegion> match_buffers = update_match_buffers(block->match_buffers);
-        if (!reads.same_as(block->reads) || !match_buffers.same_as(block->match_buffers)) {
+        // Use the updated block stmt
+        Array<BufferRegion> reads = update_access_regions(stmt->reads);
+        Array<MatchBufferRegion> match_buffers = update_match_buffers(stmt->match_buffers);
+        if (!reads.same_as(stmt->reads) || !match_buffers.same_as(stmt->match_buffers)) {
           ObjectPtr<BlockNode> n = make_object<BlockNode>(*stmt.as<BlockNode>());
           n->reads = std::move(reads);
           n->match_buffers = std::move(match_buffers);
diff --git a/tests/python/tir-schedule/test_tir_schedule_cache_read_write.py b/tests/python/tir-schedule/test_tir_schedule_cache_read_write.py
index 345c7368ce91..1fda0f432108 100644
--- a/tests/python/tir-schedule/test_tir_schedule_cache_read_write.py
+++ b/tests/python/tir-schedule/test_tir_schedule_cache_read_write.py
@@ -488,6 +488,19 @@ def cache_read_nested_seq_target(
             C[vi, vj] = A_global[vi, vj] * T.float32(2)
 
 
+@T.prim_func
+def nested_buffer_access(var_A: T.handle, var_B: T.handle, var_C: T.handle):
+    A = T.match_buffer(var_A, (T.int64(7), T.int64(512)), dtype="float32")
+    B = T.match_buffer(var_B, T.int64(1), dtype="int32")
+    C = T.match_buffer(var_C, (T.int64(1), T.int64(512)), dtype="float32")
+    for ax0, ax1 in T.grid(T.int64(1), T.int64(512)):
+        with T.block("C"):
+            v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+            T.reads(A[B[v_ax0], v_ax1], B[v_ax0])
+            T.writes(C[v_ax0, v_ax1])
+            C[v_ax0, v_ax1] = A[B[v_ax0], v_ax1]
+
+
 ########## Expected function after cache_read ##########
 
 
@@ -831,6 +844,26 @@ def cache_inplace_buffer(data_io: T.Buffer(64, "int32")) -> None:
             data_io[v0] = data_io_global_1[v0]
 
 
+@T.prim_func
+def cache_read_nested_buffer_access(var_A: T.handle, var_B: T.handle, var_C: T.handle):
+    A = T.match_buffer(var_A, (T.int64(7), T.int64(512)), dtype="float32")
+    B = T.match_buffer(var_B, T.int64(1), dtype="int32")
+    C = T.match_buffer(var_C, (T.int64(1), T.int64(512)), dtype="float32")
+    B_global = T.alloc_buffer((T.int64(1),), "int32")
+    for ax0 in range(T.int64(1)):
+        with T.block("B_global"):
+            v0 = T.axis.spatial(T.int64(1), ax0)
+            T.reads(B[v0])
+            T.writes(B_global[v0])
+            B_global[v0] = B[v0]
+    for ax0, ax1 in T.grid(T.int64(1), T.int64(512)):
+        with T.block("C"):
+            v_ax0, v_ax1 = T.axis.remap("SS", [ax0, ax1])
+            T.reads(A[B_global[v_ax0], v_ax1], B_global[v_ax0])
+            T.writes(C[v_ax0, v_ax1])
+            C[v_ax0, v_ax1] = A[B_global[v_ax0], v_ax1]
+
+
 ########## Expected function after cache_write ##########
 
 
@@ -1358,6 +1391,14 @@ def test_cache_read_non_int32_shape(use_block_name):
     verify_trace_roundtrip(sch=sch, mod=elementwise_shape_int64)
 
 
+def test_cache_read_nested_buffer_access(use_block_name):
+    sch = tir.Schedule(nested_buffer_access, debug_mask="all")
+    block_c = "C" if use_block_name else sch.get_block("C")
+    sch.cache_read(block_c, 1, "global")
+    assert_structural_equal_ignore_global_symbol(cache_read_nested_buffer_access, sch.mod["main"])
+    verify_trace_roundtrip(sch=sch, mod=nested_buffer_access)
+
+
 def test_cache_read_fail_multi_producer(use_block_name):
     sch = tir.Schedule(func_multi_producer, debug_mask="all")
     block_b = "B" if use_block_name else sch.get_block("B")