Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions src/tir/schedule/ir_comparator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,30 @@ bool TensorizeComparator::VisitExpr(const PrimExpr& n, const PrimExpr& other) {
return equal;
}

/*!
 * \brief Compare a call expression against another expression.
 *
 * The two calls match when they reference the same operator object, their
 * dtype codes are equal, and their argument arrays match element-wise under
 * TensorizeComparator::VisitExpr. In assert mode every failing check records
 * a diagnostic through EmitError before returning.
 *
 * \param op The call node on the left-hand side of the comparison.
 * \param other The expression compared against; it is viewed as a CallNode.
 * \return true when all three checks pass, false otherwise.
 */
bool TensorizeComparator::VisitExpr_(const CallNode* op, const PrimExpr& other) {
  // NOTE(review): rhs is dereferenced without a null check, like the sibling
  // visitors in this file; presumably the dispatching VisitExpr has already
  // verified that both nodes have the same type -- confirm.
  const auto* rhs = other.as<CallNode>();
  if (!rhs->op.same_as(op->op)) {
    // Report the operator mismatch too, so assert mode never fails silently.
    if (assert_mode_) {
      std::ostringstream os;
      os << "CallNode operators do not match: op->op=" << op->op << " vs rhs->op=" << rhs->op;
      EmitError(os.str());
    }
    return false;
  }
  if (op->dtype.code() != rhs->dtype.code()) {
    if (assert_mode_) {
      std::ostringstream os;
      os << "CallNode data type codes do not match: op->dtype.code()=" << op->dtype.code()
         << " vs rhs->dtype.code()=" << rhs->dtype.code();
      EmitError(os.str());
    }
    return false;
  }
  if (!CompareArray(op->args, rhs->args, &TensorizeComparator::VisitExpr)) {
    if (assert_mode_) {
      std::ostringstream os;
      // The compared field is `args`; name it correctly in the diagnostic.
      os << "CallNode args do not match: op->args=" << op->args << " vs rhs->args=" << rhs->args;
      EmitError(os.str());
    }
    return false;
  }
  return true;
}

bool TensorizeComparator::VisitStmt_(const ForNode* op, const Stmt& other) {
const auto* rhs = other.as<ForNode>();
if (!DefEqual(op->loop_var, rhs->loop_var)) {
Expand Down
1 change: 1 addition & 0 deletions src/tir/schedule/ir_comparator.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ class TensorizeComparator : public ExprComparator, public StmtComparator {
bool VisitExpr(const PrimExpr& n, const PrimExpr& other) override;
bool VisitStmt(const Stmt& n, const Stmt& other) override;

// Compares two call expressions: operator identity, dtype code, and arguments.
bool VisitExpr_(const CallNode* op, const PrimExpr& other) override;
bool VisitStmt_(const ForNode* op, const Stmt& other) override;
bool VisitStmt_(const SeqStmtNode* op, const Stmt& other) override;
bool VisitStmt_(const BufferStoreNode* op, const Stmt& other) override;
Expand Down
5 changes: 4 additions & 1 deletion src/tir/schedule/primitive/blockize_tensorize.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

#include <functional>

#include "../../transforms/simplify.h"
#include "../ir_comparator.h"
#include "../utils.h"

Expand Down Expand Up @@ -755,7 +756,9 @@ void Tensorize(ScheduleState self, const StmtSRef& sref, const TensorIntrin& int
<< GetRef<Stmt>(sref->stmt);
throw;
}
PrimFunc intrin_desc = intrin->desc;

arith::Analyzer analyzer;
PrimFunc intrin_desc = Simplify(intrin->desc, &analyzer);
PrimFunc intrin_impl = DeepCopy(intrin->impl);

int index_dtype_bits = -1;
Expand Down
8 changes: 8 additions & 0 deletions src/tir/transforms/simplify.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@
* \file simplify.cc
* \brief Statement simplifier based on analyzer
*/

#include "../../tir/transforms/simplify.h"

#include <tvm/arith/analyzer.h>
#include <tvm/runtime/registry.h>
#include <tvm/tir/analysis.h>
Expand Down Expand Up @@ -339,6 +342,11 @@ class StmtSimplifier : public IRMutatorWithAnalyzer {
} // namespace arith

namespace tir {

/*!
 * \brief Simplify a PrimFunc by delegating to arith::StmtSimplifier::Apply.
 * \param func The function to simplify; taken by value and moved into the simplifier.
 * \param analyzer The analyzer handed to the simplifier.
 * \return The PrimFunc produced by the simplifier.
 */
PrimFunc Simplify(PrimFunc func, arith::Analyzer* analyzer) {
  PrimFunc simplified = arith::StmtSimplifier::Apply(std::move(func), analyzer);
  return simplified;
}

namespace transform {

Pass Simplify() {
Expand Down
9 changes: 4 additions & 5 deletions src/tir/transforms/simplify.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,17 +25,16 @@
#define TVM_TIR_TRANSFORMS_SIMPLIFY_H_

#include <tvm/arith/analyzer.h>
#include <tvm/tir/stmt.h>
#include <tvm/tir/function.h>

namespace tvm {
namespace tir {

/* \brief Simplifies the statement
/* \brief Simplifies the prim func
*
* Applies the same behavior as the tir.transform.Simplify pass, but
* on a single statement, usable as a subroutine in other passes.
* Applies the same behavior as the tir.transform.Simplify pass.
*/
Stmt Simplify(Stmt stmt, arith::Analyzer* analyzer);
// The argument is a whole PrimFunc (not a statement), so it is named `func`
// to match the definition in simplify.cc.
PrimFunc Simplify(PrimFunc func, arith::Analyzer* analyzer);

} // namespace tir
} // namespace tvm
Expand Down
118 changes: 118 additions & 0 deletions tests/python/tir-schedule/test_tir_schedule_tensorize.py
Original file line number Diff line number Diff line change
Expand Up @@ -836,6 +836,124 @@ def tensorized_matmul_int64_shape(
assert_structural_equal_ignore_global_symbol(s.mod["main"], tensorized_matmul_int64_shape)
verify_trace_roundtrip(sch=s, mod=matmul_int64_shape)

def _tir_packed_int_to_int_to_float(storage_nbit: int):
    """Return a converter that extracts one ``nbit``-wide field from a packed integer.

    The returned ``f_convert(nbit, val, pos, dtype)`` asserts that ``val`` has dtype
    ``"int<storage_nbit>"``, shifts ``val`` right by ``pos * nbit``, masks the low
    ``nbit`` bits, applies a left shift followed by a right shift of ``32 - nbit``
    (presumably to sign-extend the field -- confirm), and casts the result to ``dtype``.
    """
    storage_dtype = f"int{storage_nbit}"

    def f_convert(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype: str):
        assert val.dtype == storage_dtype
        field_mask = tir.const((1 << nbit) - 1, "int32")
        bit_offset = pos.astype("int32") * tir.const(nbit, "int32")
        raw_field = (val >> bit_offset) & field_mask
        # Move the field to the top of a 32-bit lane, then shift it back down.
        widened = raw_field << tir.const(32 - nbit, "int32")
        return tir.Cast(dtype, widened >> tir.const(32 - nbit, "int32"))

    return f_convert

# Description half of the test intrinsic: one int32 element of `Compressed` is
# expanded into eight float16 elements of `Decompressed`, four bits per element.
@T.prim_func
def decode_i4s_to_f16_desc(compressed: T.handle, decompressed: T.handle) -> None:
    Compressed = T.match_buffer(compressed, (1,), "int32", scope="local")
    Decompressed = T.match_buffer(decompressed, (8,), "float16", scope="local")

    with T.block("root"):
        T.reads(Compressed[0:1])
        T.writes(Decompressed[0:8])
        for i in T.grid(8):
            with T.block("decode"):
                vi = T.axis.remap("S", [i])
                Decompressed[vi] = _tir_packed_int_to_int_to_float(32)(
                    4, Compressed[vi // 8], vi % 8, dtype="float16"
                )

# Implementation half of the test intrinsic: the whole decode is replaced by a
# single extern call that receives both buffer data pointers and the count 8.
@T.prim_func
def decode_i4s_to_f16_impl(compressed: T.handle, decompressed: T.handle) -> None:
    Compressed = T.match_buffer(compressed, (1,), "int32", scope="local")
    Decompressed = T.match_buffer(decompressed, (8,), "float16", scope="local")

    with T.block("root"):
        T.reads(Compressed[0:1])
        T.writes(Decompressed[0:8])
        T.call_extern(
            "handle", "test_decode_i4s_to_f16", Compressed.data, Decompressed.data, 8
        )

# Register the description/implementation pair under the name used by
# test_tensorize_arith_simplification.
tir.TensorIntrin.register(
    "test_decode_i4s_to_f16_intrin",
    decode_i4s_to_f16_desc,
    decode_i4s_to_f16_impl,
)

def test_tensorize_arith_simplification():
    """Tensorize the innermost decode loop with ``test_decode_i4s_to_f16_intrin``.

    The workload spells the decode expression with explicit ``T.shift_right`` /
    ``T.shift_left`` / ``T.bitwise_and`` calls, while the intrinsic description
    builds it through ``_tir_packed_int_to_int_to_float``. The test checks that
    tensorization succeeds and yields the expected extern-call block, and that
    the schedule trace round-trips.
    """
    # fmt: off
    @T.prim_func
    def decode_i4s_to_int32_to_f16():
        B_decode_local = T.alloc_buffer((16384, 16384), "float16", scope="local")
        B_local = T.alloc_buffer((16384, 2048), "int32", scope="local")
        for ax0_0 in T.thread_binding(8192, thread="blockIdx.x"):
            for ax0_1 in T.thread_binding(2, thread="threadIdx.y"):
                for ax1_0 in range(32):
                    for ax1_1 in T.thread_binding(64, thread="threadIdx.x"):
                        for ax0, ax1 in T.grid(1, 8):
                            with T.block("B_decode_local"):
                                v0 = T.axis.spatial(16384, ax0_0 * 2 + ax0_1 + ax0)
                                v1 = T.axis.spatial(16384, ax1_0 * 512 + ax1_1 * 8 + ax1)
                                T.reads(B_local[v0, v1 // 8])
                                T.writes(B_decode_local[v0, v1])
                                B_decode_local[v0, v1] = T.Cast("float16", T.shift_right(T.shift_left(T.bitwise_and(T.shift_right(B_local[v0, v1 // 8], v1 % 8 * 4), 15), 28), 28))

    @T.prim_func
    def tensorized_decode_i4s_to_int32_to_f16():
        B_decode_local = T.alloc_buffer((16384, 16384), "float16", scope="local")
        B_local = T.alloc_buffer((16384, 2048), "int32", scope="local")
        for ax0_0 in T.thread_binding(8192, thread="blockIdx.x"):
            for ax0_1 in T.thread_binding(2, thread="threadIdx.y"):
                for ax1_0 in range(32):
                    for ax1_1 in T.thread_binding(64, thread="threadIdx.x"):
                        for ax0 in range(1):
                            with T.block("B_decode_local_o"):
                                v0_o = T.axis.spatial(16384, ax0_0 * 2 + ax0_1 + ax0)
                                v1_o = T.axis.spatial(2048, ax1_0 * 64 + ax1_1)
                                T.reads(B_local[v0_o, v1_o])
                                T.writes(B_decode_local[v0_o, v1_o * 8:v1_o * 8 + 8])
                                Compressed = T.match_buffer(B_local[v0_o, v1_o], (1,), "int32", scope="local")
                                Decompressed = T.match_buffer(B_decode_local[v0_o, v1_o * 8:v1_o * 8 + 8], (8,), "float16", scope="local")
                                T.call_extern("handle", "test_decode_i4s_to_f16", Compressed.data, Decompressed.data, 8)
    # fmt: on

    sch = tir.Schedule(decode_i4s_to_int32_to_f16, debug_mask="all")
    decode_block = sch.get_block("B_decode_local")
    # The last loop around the block is the 8-iteration one the intrinsic covers.
    innermost_loop = sch.get_loops(decode_block)[-1]
    sch.tensorize(innermost_loop, "test_decode_i4s_to_f16_intrin")
    assert_structural_equal_ignore_global_symbol(
        sch.mod["main"], tensorized_decode_i4s_to_int32_to_f16
    )
    verify_trace_roundtrip(sch=sch, mod=decode_i4s_to_int32_to_f16)


# When executed as a script, hand control to tvm.testing.main().
if __name__ == "__main__":
    tvm.testing.main()