diff --git a/src/tir/analysis/block_access_region_detector.cc b/src/tir/analysis/block_access_region_detector.cc index a15cecabddf9..d8fcee859f03 100644 --- a/src/tir/analysis/block_access_region_detector.cc +++ b/src/tir/analysis/block_access_region_detector.cc @@ -26,6 +26,8 @@ #include #include +#include + #include "../transforms/ir_utils.h" namespace tvm { namespace tir { @@ -78,6 +80,8 @@ class BlockReadWriteDetector : public StmtExprVisitor { Map buffer_var_map_; /*! \brief The target buffer var mapping to its matching */ std::unordered_map match_buffers_; + /*! \brief let bindings inside the block */ + std::unordered_map let_bindings_; /*!\ brief Internal analyzer. */ arith::Analyzer ana_; @@ -111,6 +115,7 @@ class BlockReadWriteDetector : public StmtExprVisitor { void VisitStmt_(const IfThenElseNode* op) override; void VisitStmt_(const BlockRealizeNode* op) override; void VisitStmt_(const BufferStoreNode* op) override; + void VisitStmt_(const LetStmtNode* op) override; void VisitExpr_(const BufferLoadNode* op) override; void VisitExpr_(const VarNode* op) override; void VisitExpr_(const CallNode* op) override; @@ -149,7 +154,8 @@ void BlockReadWriteDetector::VisitExpr_(const VarNode* op) { UpdateOpaque(GetRef void BlockReadWriteDetector::VisitExpr_(const BufferLoadNode* op) { std::vector relaxed_region; for (const PrimExpr& index : op->indices) { - relaxed_region.push_back(arith::EvalSet(arith::IntSet::Vector(index), dom_map_)); + PrimExpr remapped_index = Substitute(index, let_bindings_); + relaxed_region.push_back(arith::EvalSet(arith::IntSet::Vector(remapped_index), dom_map_)); } Update(&read_buffers_, &read_regions_, op->buffer, relaxed_region); ExprVisitor::VisitExpr_(op); @@ -176,6 +182,12 @@ void BlockReadWriteDetector::VisitStmt_(const IfThenElseNode* op) { } } +void BlockReadWriteDetector::VisitStmt_(const LetStmtNode* op) { + let_bindings_[op->var.get()] = op->value; + StmtVisitor::VisitStmt_(op); + let_bindings_.erase(op->var.get()); +} + void BlockReadWriteDetector::VisitExpr_(const CallNode* op) { if (op->op.same_as(builtin::tvm_access_ptr())) { const VarNode* buffer_var = op->args[1].as(); @@ -225,7 +237,8 @@ void BlockReadWriteDetector::VisitExpr_(const CallNode* op) { void BlockReadWriteDetector::VisitStmt_(const BufferStoreNode* op) { std::vector relaxed_region; for (const PrimExpr& index : op->indices) { - relaxed_region.push_back(arith::EvalSet(arith::IntSet::Vector(index), dom_map_)); + PrimExpr remapped_index = Substitute(index, let_bindings_); + relaxed_region.push_back(arith::EvalSet(arith::IntSet::Vector(remapped_index), dom_map_)); } Update(&writes_buffers_, &write_regions_, op->buffer, relaxed_region); StmtVisitor::VisitStmt_(op); diff --git a/tests/python/tir-analysis/test_tir_analysis_get_block_access_region.py b/tests/python/tir-analysis/test_tir_analysis_get_block_access_region.py index 21d832848e83..a65277df612d 100644 --- a/tests/python/tir-analysis/test_tir_analysis_get_block_access_region.py +++ b/tests/python/tir-analysis/test_tir_analysis_get_block_access_region.py @@ -15,10 +15,12 @@ # specific language governing permissions and limitations # under the License. import pytest + import tvm +import tvm.testing from tvm import tir -from tvm.script import tir as T from tvm.ir import Range +from tvm.script import tir as T @T.prim_func @@ -355,14 +357,33 @@ def test_access_of_decompose_reduction(): tvm.ir.assert_structural_equal(block.writes, ret[1]) +def test_buffer_access_with_let_binding(): + @T.prim_func + def func( + storage: T.Buffer((16, 16, 16), "float32"), + seq_slot_ids: T.Buffer((16,), "int32"), + history_slot_ids: T.Buffer((16,), "int32"), + output: T.Buffer((16, 16), "float32"), + ): + for i, s in T.grid(16, 16): + with T.block("copy"): + vi, vs = T.axis.remap("SS", [i, s]) + T.reads( + seq_slot_ids[vi], + history_slot_ids[vi], + storage[seq_slot_ids[vi], history_slot_ids[vi], vs], + ) + T.writes(output[vi, vs]) + seq_id: T.int32 = seq_slot_ids[vi] + history_id: T.int32 = history_slot_ids[vi] + output[vi, vs] = storage[seq_id, history_id, vs] + + block = func.body.block.body.body.body.block + buffer_var_map = {buf.data: buf for buf in func.buffer_map.values()} + ret = tir.analysis.get_block_access_region(block, buffer_var_map) + tvm.ir.assert_structural_equal(block.reads, ret[0]) + tvm.ir.assert_structural_equal(block.writes, ret[1]) + + if __name__ == "__main__": - test_block_access_region_detector() - test_opaque_block() - test_opaque_access() - test_opaque_access_with_tvm_access_ptr() - test_match_buffer() - test_access_in_if_then_else_func() - test_access_in_branch_func() - test_access_of_padding_pattern() - test_access_of_reduction() - test_access_of_decompose_reduction() + tvm.testing.main()