From d0100c37d2fe63aacfcfaf6fe7368ba6f2357349 Mon Sep 17 00:00:00 2001 From: LeiWang Date: Tue, 13 Feb 2024 09:15:28 -0400 Subject: [PATCH 1/8] support tensorize with simplified and call expr --- src/tir/schedule/ir_comparator.cc | 24 ++++ src/tir/schedule/ir_comparator.h | 1 + .../schedule/primitive/blockize_tensorize.cc | 27 +++- src/tir/transforms/simplify.cc | 23 ++++ .../test_tir_schedule_tensorize.py | 118 ++++++++++++++++++ 5 files changed, 192 insertions(+), 1 deletion(-) diff --git a/src/tir/schedule/ir_comparator.cc b/src/tir/schedule/ir_comparator.cc index 5353a051a60a..00e573eaf6e4 100644 --- a/src/tir/schedule/ir_comparator.cc +++ b/src/tir/schedule/ir_comparator.cc @@ -83,6 +83,30 @@ bool TensorizeComparator::VisitExpr(const PrimExpr& n, const PrimExpr& other) { return equal; } +bool TensorizeComparator::VisitExpr_(const CallNode* op, const PrimExpr& other) { + const auto* rhs = other.as(); + if (!rhs->op.same_as(op->op)) return false; + if (op->dtype.code() != rhs->dtype.code()) { + if (assert_mode_) { + std::ostringstream os; + os << "CallNode data type codes do not match: op->dtype.code()=" << op->dtype.code() + << " vs rhs->dtype.code()=" << rhs->dtype.code(); + EmitError(os.str()); + } + return false; + } + if (!CompareArray(op->args, rhs->args, &TensorizeComparator::VisitExpr)) { + if (assert_mode_) { + std::ostringstream os; + os << "CallNode iter_values do not match: op->iter_values=" << op->args + << " vs rhs->iter_values=" << rhs->args; + EmitError(os.str()); + } + return false; + } + return true; +} + bool TensorizeComparator::VisitStmt_(const ForNode* op, const Stmt& other) { const auto* rhs = other.as(); if (!DefEqual(op->loop_var, rhs->loop_var)) { diff --git a/src/tir/schedule/ir_comparator.h b/src/tir/schedule/ir_comparator.h index debf0f946e28..f86dbd358391 100644 --- a/src/tir/schedule/ir_comparator.h +++ b/src/tir/schedule/ir_comparator.h @@ -46,6 +46,7 @@ class TensorizeComparator : public ExprComparator, public StmtComparator { bool VisitExpr(const PrimExpr& n, const PrimExpr& other) override; bool VisitStmt(const Stmt& n, const Stmt& other) override; + bool VisitExpr_(const CallNode* op, const PrimExpr& other) override; bool VisitStmt_(const ForNode* op, const Stmt& other) override; bool VisitStmt_(const SeqStmtNode* op, const Stmt& other) override; bool VisitStmt_(const BufferStoreNode* op, const Stmt& other) override; diff --git a/src/tir/schedule/primitive/blockize_tensorize.cc b/src/tir/schedule/primitive/blockize_tensorize.cc index e8445a510147..c6fa4c9c7997 100644 --- a/src/tir/schedule/primitive/blockize_tensorize.cc +++ b/src/tir/schedule/primitive/blockize_tensorize.cc @@ -20,6 +20,7 @@ #include +#include "../../transforms/simplify.h" #include "../ir_comparator.h" #include "../utils.h" @@ -738,6 +739,28 @@ StmtSRef Blockize(ScheduleState self, const Array& blocks, bool preser return result; } +class TensorIntrinSimplifier : public arith::IRMutatorWithAnalyzer { + public: + static PrimFunc Apply(PrimFunc func, arith::Analyzer* analyzer) { + TensorIntrinSimplifier simplifier(analyzer); + func.CopyOnWrite()->body = simplifier(func->body); + return func; + } + + private: + explicit TensorIntrinSimplifier(arith::Analyzer* analyzer) : IRMutatorWithAnalyzer(analyzer) {} + + using Parent = IRMutatorWithAnalyzer; + using Parent::VisitExpr_; + using Parent::VisitStmt; + using Parent::VisitStmt_; + + Stmt VisitStmt_(const BlockNode* block) final { + Block sref = GetRef(block); + return tvm::tir::Simplify(sref, analyzer_); + } +}; + void Tensorize(ScheduleState self, const StmtSRef& sref, const TensorIntrin& intrin, bool preserve_unit_iters) { // Step 1: Blockize the subtree rooted at the given loop if needed @@ -755,7 +778,9 @@ void Tensorize(ScheduleState self, const StmtSRef& sref, const TensorIntrin& int << GetRef(sref->stmt); throw; } - PrimFunc intrin_desc = intrin->desc; + + arith::Analyzer analyzer; + PrimFunc intrin_desc = TensorIntrinSimplifier::Apply(intrin->desc, &analyzer); PrimFunc intrin_impl = DeepCopy(intrin->impl); int index_dtype_bits = -1; diff --git a/src/tir/transforms/simplify.cc b/src/tir/transforms/simplify.cc index 44d64df63d9f..df8a19e5ac0a 100644 --- a/src/tir/transforms/simplify.cc +++ b/src/tir/transforms/simplify.cc @@ -31,6 +31,7 @@ #include +#include "simplify.h" #include "../../arith/ir_mutator_with_analyzer.h" #include "../../tir/analysis/control_flow_graph.h" #include "../../tir/analysis/var_use_def_analysis.h" @@ -162,6 +163,23 @@ class StmtSimplifier : public IRMutatorWithAnalyzer { return func; } + static Stmt Apply(Stmt stmt, Analyzer* analyzer, Optional config_opt = NullOpt) { + auto config = config_opt.value_or(AttrsWithDefaultValues()); + analyzer->rewrite_simplify.SetEnabledExtensions(config->GetEnabledExtensions()); + + std::optional touch_pattern = std::nullopt; + if (config->propagate_knowns_to_prove_conditional || + config->propagate_knowns_to_simplify_expressions) { + touch_pattern = ControlFlowGraph(stmt); + } + + std::unordered_set used_in_buffer_def = CollectVarsUsedInBufferDefinition(stmt); + StmtSimplifier simplifier(analyzer, config, std::move(touch_pattern), + std::move(used_in_buffer_def)); + stmt = simplifier.Simplify(std::move(stmt)); + return stmt; + } + private: explicit StmtSimplifier(Analyzer* analyzer, SimplifyConfig config, std::optional touch_pattern, @@ -339,6 +357,11 @@ class StmtSimplifier : public IRMutatorWithAnalyzer { } // namespace arith namespace tir { + +Stmt Simplify(Stmt stmt, arith::Analyzer* analyzer) { + return arith::StmtSimplifier::Apply(std::move(stmt), analyzer); +} + namespace transform { Pass Simplify() { diff --git a/tests/python/tir-schedule/test_tir_schedule_tensorize.py b/tests/python/tir-schedule/test_tir_schedule_tensorize.py index 1891914bc06f..d28674fa803a 100644 --- a/tests/python/tir-schedule/test_tir_schedule_tensorize.py +++ b/tests/python/tir-schedule/test_tir_schedule_tensorize.py @@ -836,6 +836,124 @@ def tensorized_matmul_int64_shape( assert_structural_equal_ignore_global_symbol(s.mod["main"], tensorized_matmul_int64_shape) verify_trace_roundtrip(sch=s, mod=matmul_int64_shape) +def _tir_packed_int_to_int_to_float(storage_nbit: int): + storage_dtype = "int" + str(storage_nbit) + + def f_convert(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype: str): + assert val.dtype == storage_dtype + mask = tir.const((1 << nbit) - 1, "int32") + unextended = (val >> (pos.astype("int32") * tir.const(nbit, "int32"))) & mask + return tir.Cast(dtype, (unextended << tir.const(32 - nbit, "int32")) >> tir.const(32 - nbit, "int32")) + + return f_convert + +@T.prim_func +def decode_i4s_to_f16_desc(compressed: T.handle, decompressed: T.handle) -> None: + Compressed = T.match_buffer( + compressed, + [ + 1, + ], + dtype="int32", + scope="local", + ) + Decompressed = T.match_buffer( + decompressed, + [ + 8, + ], + dtype="float16", + scope="local", + ) + + with T.block("root"): + T.reads(Compressed[0:1]) + T.writes(Decompressed[0:8]) + for i in T.grid(8): + with T.block("decode"): + vi = T.axis.remap("S", [i]) + Decompressed[vi] = _tir_packed_int_to_int_to_float(32)( + 4, + Compressed[vi // 8], + vi % 8, + dtype="float16", + ) + +@T.prim_func +def decode_i4s_to_f16_impl(compressed: T.handle, decompressed: T.handle) -> None: + Compressed = T.match_buffer( + compressed, + [ + 1, + ], + dtype="int32", + scope="local", + ) + Decompressed = T.match_buffer( + decompressed, + [ + 8, + ], + dtype="float16", + scope="local", + ) + + with T.block("root"): + T.reads(Compressed[0:1]) + T.writes(Decompressed[0:8]) + T.call_extern( + "handle", + "test_decode_i4s_to_f16", + Compressed.data, + Decompressed.data, + 8, + ) + +tir.TensorIntrin.register("test_decode_i4s_to_f16_intrin", decode_i4s_to_f16_desc, decode_i4s_to_f16_impl) + +def test_tensorize_arith_simplification(): + # fmt: off + @T.prim_func + def decode_i4s_to_int32_to_f16(): + B_decode_local = T.alloc_buffer((16384, 16384), "float16", scope="local") + B_local = T.alloc_buffer((16384, 2048), "int32", scope="local") + for ax0_0 in T.thread_binding(8192, thread="blockIdx.x"): + for ax0_1 in T.thread_binding(2, thread="threadIdx.y"): + for ax1_0 in range(32): + for ax1_1 in T.thread_binding(64, thread="threadIdx.x"): + for ax0, ax1 in T.grid(1, 8): + with T.block("B_decode_local"): + v0 = T.axis.spatial(16384, ax0_0 * 2 + ax0_1 + ax0) + v1 = T.axis.spatial(16384, ax1_0 * 512 + ax1_1 * 8 + ax1) + T.reads(B_local[v0, v1 // 8]) + T.writes(B_decode_local[v0, v1]) + B_decode_local[v0, v1] = T.Cast("float16", T.shift_right(T.shift_left(T.bitwise_and(T.shift_right(B_local[v0, v1 // 8], v1 % 8 * 4), 15), 28), 28)) + + @T.prim_func + def tensorized_decode_i4s_to_int32_to_f16(): + B_decode_local = T.alloc_buffer((16384, 16384), "float16", scope="local") + B_local = T.alloc_buffer((16384, 2048), "int32", scope="local") + for ax0_0 in T.thread_binding(8192, thread="blockIdx.x"): + for ax0_1 in T.thread_binding(2, thread="threadIdx.y"): + for ax1_0 in range(32): + for ax1_1 in T.thread_binding(64, thread="threadIdx.x"): + for ax0 in range(1): + with T.block("B_decode_local_o"): + v0_o = T.axis.spatial(16384, ax0_0 * 2 + ax0_1 + ax0) + v1_o = T.axis.spatial(2048, ax1_0 * 64 + ax1_1) + T.reads(B_local[v0_o, v1_o]) + T.writes(B_decode_local[v0_o, v1_o * 8:v1_o * 8 + 8]) + Compressed = T.match_buffer(B_local[v0_o, v1_o], (1,), "int32", scope="local") + Decompressed = T.match_buffer(B_decode_local[v0_o, v1_o * 8:v1_o * 8 + 8], (8,), "float16", scope="local") + T.call_extern("handle", "test_decode_i4s_to_f16", Compressed.data, Decompressed.data, 8) + + s = tir.Schedule(decode_i4s_to_int32_to_f16, debug_mask="all") + update = s.get_block("B_decode_local") + ii = s.get_loops(update)[-1] + s.tensorize(ii, "test_decode_i4s_to_f16_intrin") + assert_structural_equal_ignore_global_symbol(s.mod["main"], tensorized_decode_i4s_to_int32_to_f16) + verify_trace_roundtrip(sch=s, mod=decode_i4s_to_int32_to_f16) + if __name__ == "__main__": tvm.testing.main() From 834f20438261ae09a842e24d821454cfb9fe887a Mon Sep 17 00:00:00 2001 From: LeiWang Date: Fri, 16 Feb 2024 00:01:53 -0400 Subject: [PATCH 2/8] replace stmt simplifier with primfunc simplifier --- .../schedule/primitive/blockize_tensorize.cc | 24 +------------------ src/tir/transforms/simplify.cc | 24 ++++--------------- src/tir/transforms/simplify.h | 9 ++++--- 3 files changed, 9 insertions(+), 48 deletions(-) diff --git a/src/tir/schedule/primitive/blockize_tensorize.cc b/src/tir/schedule/primitive/blockize_tensorize.cc index c6fa4c9c7997..c057a3d4fe72 100644 --- a/src/tir/schedule/primitive/blockize_tensorize.cc +++ b/src/tir/schedule/primitive/blockize_tensorize.cc @@ -739,28 +739,6 @@ StmtSRef Blockize(ScheduleState self, const Array& blocks, bool preser return result; } -class TensorIntrinSimplifier : public arith::IRMutatorWithAnalyzer { - public: - static PrimFunc Apply(PrimFunc func, arith::Analyzer* analyzer) { - TensorIntrinSimplifier simplifier(analyzer); - func.CopyOnWrite()->body = simplifier(func->body); - return func; - } - - private: - explicit TensorIntrinSimplifier(arith::Analyzer* analyzer) : IRMutatorWithAnalyzer(analyzer) {} - - using Parent = IRMutatorWithAnalyzer; - using Parent::VisitExpr_; - using Parent::VisitStmt; - using Parent::VisitStmt_; - - Stmt VisitStmt_(const BlockNode* block) final { - Block sref = GetRef(block); - return tvm::tir::Simplify(sref, analyzer_); - } -}; - void Tensorize(ScheduleState self, const StmtSRef& sref, const TensorIntrin& intrin, bool preserve_unit_iters) { // Step 1: Blockize the subtree rooted at the given loop if needed @@ -780,7 +758,7 @@ void Tensorize(ScheduleState self, const StmtSRef& sref, const TensorIntrin& int } arith::Analyzer analyzer; - PrimFunc intrin_desc = TensorIntrinSimplifier::Apply(intrin->desc, &analyzer); + PrimFunc intrin_desc = Simplify(intrin->desc, &analyzer); PrimFunc intrin_impl = DeepCopy(intrin->impl); int index_dtype_bits = -1; diff --git a/src/tir/transforms/simplify.cc b/src/tir/transforms/simplify.cc index df8a19e5ac0a..85d976819db2 100644 --- a/src/tir/transforms/simplify.cc +++ b/src/tir/transforms/simplify.cc @@ -21,6 +21,8 @@ * \file simplify.cc * \brief Statement simplifier based on analyzer */ +#include "simplify.h" + #include #include #include @@ -31,7 +33,6 @@ #include -#include "simplify.h" #include "../../arith/ir_mutator_with_analyzer.h" #include "../../tir/analysis/control_flow_graph.h" #include "../../tir/analysis/var_use_def_analysis.h" @@ -163,23 +164,6 @@ class StmtSimplifier : public IRMutatorWithAnalyzer { return func; } - static Stmt Apply(Stmt stmt, Analyzer* analyzer, Optional config_opt = NullOpt) { - auto config = config_opt.value_or(AttrsWithDefaultValues()); - analyzer->rewrite_simplify.SetEnabledExtensions(config->GetEnabledExtensions()); - - std::optional touch_pattern = std::nullopt; - if (config->propagate_knowns_to_prove_conditional || - config->propagate_knowns_to_simplify_expressions) { - touch_pattern = ControlFlowGraph(stmt); - } - - std::unordered_set used_in_buffer_def = CollectVarsUsedInBufferDefinition(stmt); - StmtSimplifier simplifier(analyzer, config, std::move(touch_pattern), - std::move(used_in_buffer_def)); - stmt = simplifier.Simplify(std::move(stmt)); - return stmt; - } - private: explicit StmtSimplifier(Analyzer* analyzer, SimplifyConfig config, std::optional touch_pattern, @@ -358,8 +342,8 @@ class StmtSimplifier : public IRMutatorWithAnalyzer { namespace tir { -Stmt Simplify(Stmt stmt, arith::Analyzer* analyzer) { - return arith::StmtSimplifier::Apply(std::move(stmt), analyzer); +PrimFunc Simplify(PrimFunc func, arith::Analyzer* analyzer) { + return arith::StmtSimplifier::Apply(std::move(func), analyzer); } namespace transform { diff --git a/src/tir/transforms/simplify.h b/src/tir/transforms/simplify.h index 43afc5e48dcb..25c9dd5791d9 100644 --- a/src/tir/transforms/simplify.h +++ b/src/tir/transforms/simplify.h @@ -25,17 +25,16 @@ #define TVM_TIR_TRANSFORMS_SIMPLIFY_H_ #include -#include +#include namespace tvm { namespace tir { -/* \brief Simplifies the statement +/* \brief Simplifies the prim func * - * Applies the same behavior as the tir.transform.Simplify pass, but - * on a single statement, usable as a subroutine in other passes. + * Applies the same behavior as the tir.transform.Simplify pass. */ -Stmt Simplify(Stmt stmt, arith::Analyzer* analyzer); +PrimFunc Simplify(PrimFunc stmt, arith::Analyzer* analyzer); } // namespace tir } // namespace tvm From cd1c24cfe88828c4a5ecb513fd5a82b4884cec77 Mon Sep 17 00:00:00 2001 From: LeiWang Date: Sun, 18 Feb 2024 09:19:13 -0400 Subject: [PATCH 3/8] lint fix --- tests/python/tir-schedule/test_tir_schedule_tensorize.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/tir-schedule/test_tir_schedule_tensorize.py b/tests/python/tir-schedule/test_tir_schedule_tensorize.py index d28674fa803a..fe8847272952 100644 --- a/tests/python/tir-schedule/test_tir_schedule_tensorize.py +++ b/tests/python/tir-schedule/test_tir_schedule_tensorize.py @@ -928,7 +928,7 @@ def decode_i4s_to_int32_to_f16(): T.reads(B_local[v0, v1 // 8]) T.writes(B_decode_local[v0, v1]) B_decode_local[v0, v1] = T.Cast("float16", T.shift_right(T.shift_left(T.bitwise_and(T.shift_right(B_local[v0, v1 // 8], v1 % 8 * 4), 15), 28), 28)) - + @T.prim_func def tensorized_decode_i4s_to_int32_to_f16(): B_decode_local = T.alloc_buffer((16384, 16384), "float16", scope="local") From 42da897a547ccd1013a0a25d52c9720c9d0aa5d1 Mon Sep 17 00:00:00 2001 From: LeiWang Date: Sun, 18 Feb 2024 11:13:27 -0400 Subject: [PATCH 4/8] lint:remove white space --- tests/python/tir-schedule/test_tir_schedule_tensorize.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/tir-schedule/test_tir_schedule_tensorize.py b/tests/python/tir-schedule/test_tir_schedule_tensorize.py index fe8847272952..14f533f53de2 100644 --- a/tests/python/tir-schedule/test_tir_schedule_tensorize.py +++ b/tests/python/tir-schedule/test_tir_schedule_tensorize.py @@ -928,7 +928,7 @@ def decode_i4s_to_int32_to_f16(): T.reads(B_local[v0, v1 // 8]) T.writes(B_decode_local[v0, v1]) B_decode_local[v0, v1] = T.Cast("float16", T.shift_right(T.shift_left(T.bitwise_and(T.shift_right(B_local[v0, v1 // 8], v1 % 8 * 4), 15), 28), 28)) - + @T.prim_func def tensorized_decode_i4s_to_int32_to_f16(): B_decode_local = T.alloc_buffer((16384, 16384), "float16", scope="local") From 4721d7e6e5272894acb9a1e69687f45bd63fa4a4 Mon Sep 17 00:00:00 2001 From: LeiWang Date: Sun, 18 Feb 2024 21:52:38 -0400 Subject: [PATCH 5/8] lint: remove white space --- tests/python/tir-schedule/test_tir_schedule_tensorize.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/tir-schedule/test_tir_schedule_tensorize.py b/tests/python/tir-schedule/test_tir_schedule_tensorize.py index 14f533f53de2..789d6be3ad0b 100644 --- a/tests/python/tir-schedule/test_tir_schedule_tensorize.py +++ b/tests/python/tir-schedule/test_tir_schedule_tensorize.py @@ -928,7 +928,7 @@ def decode_i4s_to_int32_to_f16(): T.reads(B_local[v0, v1 // 8]) T.writes(B_decode_local[v0, v1]) B_decode_local[v0, v1] = T.Cast("float16", T.shift_right(T.shift_left(T.bitwise_and(T.shift_right(B_local[v0, v1 // 8], v1 % 8 * 4), 15), 28), 28)) - + @T.prim_func def tensorized_decode_i4s_to_int32_to_f16(): B_decode_local = T.alloc_buffer((16384, 16384), "float16", scope="local") From 0cd107c1b8766fa68baaa134c77a0bfafd768a3e Mon Sep 17 00:00:00 2001 From: LeiWang Date: Mon, 19 Feb 2024 09:25:55 -0400 Subject: [PATCH 6/8] cpp lint fix --- src/tir/transforms/simplify.cc | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/tir/transforms/simplify.cc b/src/tir/transforms/simplify.cc index 85d976819db2..990c92d887aa 100644 --- a/src/tir/transforms/simplify.cc +++ b/src/tir/transforms/simplify.cc @@ -21,8 +21,6 @@ * \file simplify.cc * \brief Statement simplifier based on analyzer */ -#include "simplify.h" - #include #include #include @@ -33,6 +31,7 @@ #include +#include "../../tir/transforms/simplify.h" #include "../../arith/ir_mutator_with_analyzer.h" #include "../../tir/analysis/control_flow_graph.h" #include "../../tir/analysis/var_use_def_analysis.h" From e1306a5a02eb7bd035b2053cbed18184e6e34a3b Mon Sep 17 00:00:00 2001 From: LeiWang Date: Mon, 19 Feb 2024 11:12:19 -0400 Subject: [PATCH 7/8] lint: resolve include --- src/tir/transforms/simplify.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/tir/transforms/simplify.cc b/src/tir/transforms/simplify.cc index 990c92d887aa..62936abe5dde 100644 --- a/src/tir/transforms/simplify.cc +++ b/src/tir/transforms/simplify.cc @@ -21,6 +21,7 @@ * \file simplify.cc * \brief Statement simplifier based on analyzer */ + #include #include #include @@ -31,10 +32,10 @@ #include -#include "../../tir/transforms/simplify.h" #include "../../arith/ir_mutator_with_analyzer.h" #include "../../tir/analysis/control_flow_graph.h" #include "../../tir/analysis/var_use_def_analysis.h" +#include "../../tir/transforms/simplify.h" namespace tvm { namespace arith { From 8fc861b75abab16a7eaf2ed5e47ff117ac2b62b6 Mon Sep 17 00:00:00 2001 From: LeiWang Date: Tue, 20 Feb 2024 08:52:31 -0400 Subject: [PATCH 8/8] clang format lint fix --- src/tir/transforms/simplify.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/tir/transforms/simplify.cc b/src/tir/transforms/simplify.cc index 62936abe5dde..f518c61bc676 100644 --- a/src/tir/transforms/simplify.cc +++ b/src/tir/transforms/simplify.cc @@ -22,6 +22,8 @@ * \brief Statement simplifier based on analyzer */ +#include "../../tir/transforms/simplify.h" + #include #include #include @@ -35,7 +37,6 @@ #include "../../arith/ir_mutator_with_analyzer.h" #include "../../tir/analysis/control_flow_graph.h" #include "../../tir/analysis/var_use_def_analysis.h" -#include "../../tir/transforms/simplify.h" namespace tvm { namespace arith {