Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
119 changes: 59 additions & 60 deletions src/tir/transforms/flatten_buffer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -87,79 +87,78 @@ class BufferFlattener : public arith::IRMutatorWithAnalyzer {
}

// Rewrite an Allocate node so that its extents describe flat memory.
//
// The flattened extents are computed from the original (unmutated)
// node `op`, and only afterwards are the children visited.  The
// extents are chosen as follows:
//   1. A single-extent allocation is kept as-is.
//   2. If the body is a DeclBuffer whose buffer uses this allocation's
//      buffer_var and has the same dtype and shape as the allocation,
//      the shape of GetFlattenedBuffer(buffer) is used.
//   3. Otherwise, the product of all extents is used as the single
//      flat extent.
//
// In addition, boolean allocations are rewritten to Int(8).
Stmt VisitStmt_(const AllocateNode* op) final {
  // Determine the flattened extents first, before stripping of
  // DeclBuffer.
  auto new_extents = [&]() -> Array<PrimExpr> {
    if (op->extents.size() == 1) {
      // No flattening required for buffers that are already flat
      return op->extents;
    }

    if (auto* decl_buffer = op->body.as<DeclBufferNode>()) {
      // N-d buffer, use the DeclBuffer inside to determine how it
      // should be flattened.
      auto& buffer = decl_buffer->buffer;

      // The DeclBuffer only describes this allocation if it refers to
      // the same backing variable, dtype, and shape.
      bool matching_buffer = [&]() {
        if (!decl_buffer->buffer->data.same_as(op->buffer_var)) {
          return false;
        }
        if (op->dtype != buffer->dtype) {
          return false;
        }
        if (op->extents.size() != buffer->shape.size()) {
          return false;
        }
        ExprDeepEqual expr_equal;
        for (size_t i = 0; i < op->extents.size(); i++) {
          if (!expr_equal(op->extents[i], buffer->shape[i])) {
            return false;
          }
        }
        return true;
      }();

      if (matching_buffer) {
        Buffer flattened = GetFlattenedBuffer(buffer);
        return flattened->shape;
      } else {
        // A non-matching DeclBuffer falls through to the 1-d fallback
        // below, which is only valid without axis separators.
        ICHECK(decl_buffer->buffer->axis_separators.empty())
            << "DeclBuffer node doesn't match Allocate extents, but also shouldn't be "
               "flattened to 1-d physical memory";
      }
    }

    // Fallback, this is an allocation without a matching DeclBuffer
    PrimExpr flat_extent = 1;
    for (const auto& dim : op->extents) {
      flat_extent *= dim;
    }
    return {flat_extent};
  }();

  Allocate alloc = Downcast<Allocate>(StmtExprMutator::VisitStmt_(op));

  // TODO(Lunderberg): Move the handling of boolean into a
  // dedicated pass.
  if (alloc->dtype == DataType::Bool()) {
    alloc.CopyOnWrite()->dtype = DataType::Int(8);
  }

  // Only copy the node when the extents actually changed.
  if (!new_extents.same_as(alloc->extents)) {
    alloc.CopyOnWrite()->extents = new_extents;
  }

  return std::move(alloc);
}

// Replace a DeclBuffer node by the result of visiting its body.
//
// TODO(rfc-70): Rewrite the DeclBuffer in place rather than dropping
// it.  It is dropped for now because some lowering passes do not yet
// support DeclBuffer.
Stmt VisitStmt_(const DeclBufferNode* op) final {
  const Stmt& inner_body = op->body;
  return VisitStmt(inner_body);
}

Buffer GetFlattenedBuffer(Buffer buf) {
auto it = buffer_remap_.find(buf);
if (it != buffer_remap_.end()) {
Expand Down
9 changes: 9 additions & 0 deletions src/tir/transforms/merge_dynamic_shared_memory_allocations.cc
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,14 @@ class DynamicSharedMemoryRewriter : public StmtExprMutator {
return StmtExprMutator::VisitStmt_(op);
}

// Visit a DeclBuffer and point it at the buffer returned by
// GetUpdatedBuffer, if that differs from the one it currently declares.
Stmt VisitStmt_(const DeclBufferNode* op) final {
  // Mutate the children first, then patch the declared buffer.
  DeclBuffer decl = Downcast<DeclBuffer>(StmtExprMutator::VisitStmt_(op));

  auto replacement = GetUpdatedBuffer(decl->buffer);
  const bool buffer_changed = !replacement.same_as(decl->buffer);
  if (buffer_changed) {
    // Copy the node only when the buffer was actually replaced.
    decl.CopyOnWrite()->buffer = replacement;
  }

  return std::move(decl);
}

PrimExpr VisitExpr_(const BufferLoadNode* op) final {
auto node = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op));
return VisitBufferAccess(std::move(node));
Expand Down Expand Up @@ -336,6 +344,7 @@ class DynamicSharedMemoryRewriter : public StmtExprMutator {

if (IsDynamicSharedMemory(buffer->data)) {
ICHECK_EQ(buffer->shape.size(), 1)
<< "Buffer " << buffer << " has shape " << buffer->shape << ". "
<< "MergeDynamicSharedMemoryAllocations expects flat memory buffers, "
<< "and is to be run after "
<< "StorageFlatten (TE schedules) or FlattenBuffer (TIR schedules)";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from tvm import te
from tvm.driver.build_module import schedule_to_module
from tvm.topi.math import cast
from tvm.script import tir as T


def run_passes(sch, args):
Expand Down Expand Up @@ -326,8 +327,132 @@ def check_target(target):
check_target(target)


class TestMatmul(tvm.testing.CompareBeforeAfter):
    """Shared allocations should be merged, preserving DeclBuffer if present

    This test uses a matmul PrimFunc adapted from
    test_matmul_dyn_shared, using either `T.Buffer` (Allocate without
    DeclBuffer) or `T.decl_buffer` (Allocate followed by DeclBuffer)
    for the replaced allocations.
    """

    transform = tvm.tir.transform.MergeDynamicSharedMemoryAllocations()

    # Run every case twice: once with `T.Buffer`, once with `T.decl_buffer`.
    use_decl_buffer = tvm.testing.parameter(by_dict={"t_buffer": False, "decl_buffer": True})

    @tvm.testing.fixture
    def buffer_func(self, use_decl_buffer):
        """Return the TVMScript callable used to declare the shared buffers."""
        if use_decl_buffer:
            return T.decl_buffer
        else:
            return T.Buffer

    @tvm.testing.fixture
    def before(self, buffer_func):
        """Input PrimFunc with three separate "shared.dyn" allocations."""

        @T.prim_func
        def func(
            A: T.Buffer((1024, 1024), "float16"),
            B: T.Buffer((1024, 1024), "float16"),
            matmul: T.Buffer((1024, 1024), "float32"),
        ):
            A_flat = T.Buffer(1048576, "float16", data=A.data)
            B_flat = T.Buffer(1048576, "float16", data=B.data)
            matmul_flat = T.Buffer(1048576, data=matmul.data)

            threadIdx_x = T.launch_thread("threadIdx.x", 16)
            C_local_data = T.allocate([1], "float32", "local")
            C_local = T.Buffer(1, data=C_local_data, scope="local")

            A_sh_data = T.allocate([256], "float16", "shared.dyn")
            A_sh = buffer_func(256, "float16", data=A_sh_data, scope="shared.dyn")
            B_sh_data = T.allocate([256], "float16", "shared.dyn")
            B_sh = buffer_func(256, "float16", data=B_sh_data, scope="shared.dyn")
            C_sh_data = T.allocate([256], "float32", "shared.dyn")
            C_sh = buffer_func(256, "float32", data=C_sh_data, scope="shared.dyn")

            threadIdx_y = T.launch_thread("threadIdx.y", 16)
            blockIdx_x = T.launch_thread("blockIdx.x", 64)
            blockIdx_y = T.launch_thread("blockIdx.y", 64)

            C_local[0] = T.float32(0)
            for i in range(64):

                A_sh[threadIdx_y * 16 + threadIdx_x] = A_flat[
                    blockIdx_y * 16384 + threadIdx_y * 1024 + i * 16 + threadIdx_x
                ]

                B_sh[threadIdx_y * 16 + threadIdx_x] = B_flat[
                    i * 16384 + threadIdx_y * 1024 + blockIdx_x * 16 + threadIdx_x
                ]
                T.tvm_storage_sync("shared")
                for k in range(16):
                    C_local[0] = C_local[0] + T.Cast(
                        "float32",
                        A_sh[threadIdx_y * 16 + k] * B_sh[k * 16 + threadIdx_x],
                    )
                T.tvm_storage_sync("shared")

            C_sh[threadIdx_y * 16 + threadIdx_x] = C_local[0]
            T.tvm_storage_sync("shared.dyn")

            matmul_flat[
                blockIdx_y * 16384 + threadIdx_y * 1024 + blockIdx_x * 16 + threadIdx_x
            ] = C_sh[threadIdx_y * 16 + threadIdx_x]

        return func

    @tvm.testing.fixture
    def expected(self, buffer_func):
        """Expected PrimFunc: one 1024-element uint8 "shared.dyn" allocation.

        All three shared buffers are declared on `buf_dyn_shmem`, and the
        accesses to `A_sh` are shifted by 256 elements.
        """

        @T.prim_func
        def func(
            A: T.Buffer((1024, 1024), "float16"),
            B: T.Buffer((1024, 1024), "float16"),
            matmul: T.Buffer((1024, 1024), "float32"),
        ):
            A_flat = T.Buffer(1048576, "float16", data=A.data)
            B_flat = T.Buffer(1048576, "float16", data=B.data)
            matmul_flat = T.Buffer(1048576, data=matmul.data)

            threadIdx_x = T.launch_thread("threadIdx.x", 16)

            buf_dyn_shmem = T.allocate([1024], "uint8", "shared.dyn")

            C_local_data = T.allocate([1], "float32", "local")
            C_local = T.Buffer(1, data=C_local_data, scope="local")

            A_sh = buffer_func(256, "float16", data=buf_dyn_shmem, scope="shared.dyn")
            B_sh = buffer_func(256, "float16", data=buf_dyn_shmem, scope="shared.dyn")
            C_sh = buffer_func(256, "float32", data=buf_dyn_shmem, scope="shared.dyn")

            threadIdx_y = T.launch_thread("threadIdx.y", 16)
            blockIdx_x = T.launch_thread("blockIdx.x", 64)
            blockIdx_y = T.launch_thread("blockIdx.y", 64)

            C_local[0] = T.float32(0)
            for i in range(64):
                A_sh[threadIdx_y * 16 + threadIdx_x + 256] = A_flat[
                    blockIdx_y * 16384 + threadIdx_y * 1024 + i * 16 + threadIdx_x
                ]
                B_sh[threadIdx_y * 16 + threadIdx_x] = B_flat[
                    i * 16384 + threadIdx_y * 1024 + blockIdx_x * 16 + threadIdx_x
                ]
                T.tvm_storage_sync("shared")
                for k in range(16):
                    C_local[0] = C_local[0] + T.Cast(
                        "float32",
                        A_sh[threadIdx_y * 16 + k + 256] * B_sh[k * 16 + threadIdx_x],
                    )
                T.tvm_storage_sync("shared")

            C_sh[threadIdx_y * 16 + threadIdx_x] = C_local[0]
            T.tvm_storage_sync("shared.dyn")

            matmul_flat[
                blockIdx_y * 16384 + threadIdx_y * 1024 + blockIdx_x * 16 + threadIdx_x
            ] = C_sh[threadIdx_y * 16 + threadIdx_x]

        return func


if __name__ == "__main__":
    # Delegate to the TVM test entry point instead of calling individual
    # test functions by hand, so each test in this file runs exactly once.
    tvm.testing.main()