From bb0f7cf101d1e1e497a7a379d20063a4f43fc847 Mon Sep 17 00:00:00 2001 From: Siyuan Feng Date: Sat, 10 Feb 2024 23:50:49 +0800 Subject: [PATCH 1/2] [TIR] Fix get_block_access_region for let bindings The current implementation of `block_access_region_detector` does not consider the let bindings inside the block. To be more specific: - The let bindings inside the block can be the index of buffer access indices - The let bindings var is defined inside the block, so the block annotation cannot use those vars. - We need to substitute the let bindings inside the block to the block annotation. This PR fixes this problem and can create legal IRs. --- .../analysis/block_access_region_detector.cc | 16 ++++++- ...st_tir_analysis_get_block_access_region.py | 43 ++++++++++++++----- 2 files changed, 46 insertions(+), 13 deletions(-) diff --git a/src/tir/analysis/block_access_region_detector.cc b/src/tir/analysis/block_access_region_detector.cc index a15cecabddf9..577639b7982e 100644 --- a/src/tir/analysis/block_access_region_detector.cc +++ b/src/tir/analysis/block_access_region_detector.cc @@ -25,6 +25,7 @@ #include #include #include +#include #include "../transforms/ir_utils.h" namespace tvm { @@ -78,6 +79,8 @@ class BlockReadWriteDetector : public StmtExprVisitor { Map buffer_var_map_; /*! \brief The target buffer var mapping to its matching */ std::unordered_map match_buffers_; + /*! \brief let bindings inside the block */ + std::unordered_map let_bindings_; /*!\ brief Internal analyzer. */ arith::Analyzer ana_; @@ -111,6 +114,7 @@ class BlockReadWriteDetector : public StmtExprVisitor { void VisitStmt_(const IfThenElseNode* op) override; void VisitStmt_(const BlockRealizeNode* op) override; void VisitStmt_(const BufferStoreNode* op) override; + void VisitStmt_(const LetStmtNode* op) override; void VisitExpr_(const BufferLoadNode* op) override; void VisitExpr_(const VarNode* op) override; void VisitExpr_(const CallNode* op) override; @@ -149,7 +153,8 @@ void BlockReadWriteDetector::VisitExpr_(const VarNode* op) { UpdateOpaque(GetRef void BlockReadWriteDetector::VisitExpr_(const BufferLoadNode* op) { std::vector relaxed_region; for (const PrimExpr& index : op->indices) { - relaxed_region.push_back(arith::EvalSet(arith::IntSet::Vector(index), dom_map_)); + PrimExpr remapped_index = Substitute(index, let_bindings_); + relaxed_region.push_back(arith::EvalSet(arith::IntSet::Vector(remapped_index), dom_map_)); } Update(&read_buffers_, &read_regions_, op->buffer, relaxed_region); ExprVisitor::VisitExpr_(op); @@ -176,6 +181,12 @@ void BlockReadWriteDetector::VisitStmt_(const IfThenElseNode* op) { } } +void BlockReadWriteDetector::VisitStmt_(const LetStmtNode* op) { + let_bindings_[op->var.get()] = op->value; + StmtVisitor::VisitStmt_(op); + let_bindings_.erase(op->var.get()); +} + void BlockReadWriteDetector::VisitExpr_(const CallNode* op) { if (op->op.same_as(builtin::tvm_access_ptr())) { const VarNode* buffer_var = op->args[1].as(); @@ -225,7 +236,8 @@ void BlockReadWriteDetector::VisitExpr_(const CallNode* op) { void BlockReadWriteDetector::VisitStmt_(const BufferStoreNode* op) { std::vector relaxed_region; for (const PrimExpr& index : op->indices) { - relaxed_region.push_back(arith::EvalSet(arith::IntSet::Vector(index), dom_map_)); + PrimExpr remapped_index = Substitute(index, let_bindings_); + relaxed_region.push_back(arith::EvalSet(arith::IntSet::Vector(remapped_index), dom_map_)); } Update(&writes_buffers_, &write_regions_, op->buffer, relaxed_region); StmtVisitor::VisitStmt_(op); diff --git a/tests/python/tir-analysis/test_tir_analysis_get_block_access_region.py b/tests/python/tir-analysis/test_tir_analysis_get_block_access_region.py index 21d832848e83..a65277df612d 100644 --- a/tests/python/tir-analysis/test_tir_analysis_get_block_access_region.py +++ b/tests/python/tir-analysis/test_tir_analysis_get_block_access_region.py @@ -15,10 +15,12 @@ # specific language governing permissions and limitations # under the License. import pytest + import tvm +import tvm.testing from tvm import tir -from tvm.script import tir as T from tvm.ir import Range +from tvm.script import tir as T @T.prim_func @@ -355,14 +357,33 @@ def test_access_of_decompose_reduction(): tvm.ir.assert_structural_equal(block.writes, ret[1]) +def test_buffer_access_with_let_binding(): + @T.prim_func + def func( + storage: T.Buffer((16, 16, 16), "float32"), + seq_slot_ids: T.Buffer((16,), "int32"), + history_slot_ids: T.Buffer((16,), "int32"), + output: T.Buffer((16, 16), "float32"), + ): + for i, s in T.grid(16, 16): + with T.block("copy"): + vi, vs = T.axis.remap("SS", [i, s]) + T.reads( + seq_slot_ids[vi], + history_slot_ids[vi], + storage[seq_slot_ids[vi], history_slot_ids[vi], vs], + ) + T.writes(output[vi, vs]) + seq_id: T.int32 = seq_slot_ids[vi] + history_id: T.int32 = history_slot_ids[vi] + output[vi, vs] = storage[seq_id, history_id, vs] + + block = func.body.block.body.body.body.block + buffer_var_map = {buf.data: buf for buf in func.buffer_map.values()} + ret = tir.analysis.get_block_access_region(block, buffer_var_map) + tvm.ir.assert_structural_equal(block.reads, ret[0]) + tvm.ir.assert_structural_equal(block.writes, ret[1]) + + if __name__ == "__main__": - test_block_access_region_detector() - test_opaque_block() - test_opaque_access() - test_opaque_access_with_tvm_access_ptr() - test_match_buffer() - test_access_in_if_then_else_func() - test_access_in_branch_func() - test_access_of_padding_pattern() - test_access_of_reduction() - test_access_of_decompose_reduction() + tvm.testing.main() From c5b4cbedea6e9576a34fb045834d499cac007831 Mon Sep 17 00:00:00 2001 From: Siyuan Feng Date: Sun, 11 Feb 2024 12:05:12 +0800 Subject: [PATCH 2/2] lint --- src/tir/analysis/block_access_region_detector.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/src/tir/analysis/block_access_region_detector.cc b/src/tir/analysis/block_access_region_detector.cc index 577639b7982e..d8fcee859f03 100644 --- a/src/tir/analysis/block_access_region_detector.cc +++ b/src/tir/analysis/block_access_region_detector.cc @@ -25,6 +25,7 @@ #include #include #include + #include #include "../transforms/ir_utils.h"