python/tvm/dlight/base/__init__.py (3 changes: 2 additions & 1 deletion)

@@ -18,7 +18,8 @@
 from .analysis import (
     BlockInfo,
     IterInfo,
-    collect_vars_used_in_access_region,
+    collect_block_iter_vars_used_in_access_region,
+    collect_vars_used_in_prim_expr,
     detect_dominant_read,
     is_broadcast_epilogue,
     normalize_prim_func,
python/tvm/dlight/base/analysis.py (28 changes: 20 additions & 8 deletions)

@@ -208,17 +208,27 @@ def get_root_block(sch: Schedule, func_name: str = "main") -> BlockRV:
     return sch.get_block(block.name_hint)


-def collect_vars_used_in_access_region(region: List[ir.Range]) -> Set[tir.Var]:
-    """Collect the variables used in the access region of a buffer region."""
-    tir_vars: Set[tir.Var] = set()
+def collect_block_iter_vars_used_in_access_region(
+    block: tir.Block, region: List[ir.Range]
+) -> Set[tir.Var]:
+    """Collect the block iter variables used in the access region of a buffer region."""
+    tir_vars = set()
+    for expr in region:
+        assert expr.extent == 1
+        tir_vars |= collect_vars_used_in_prim_expr(expr.min)
+    tir_vars &= set(iter_var.var for iter_var in block.iter_vars)
+    return tir_vars
+
+
+def collect_vars_used_in_prim_expr(expr: tir.PrimExpr) -> Set[tir.Var]:
+    """Collect the variables used in the PrimExpr."""
+    tir_vars = set()

     def _collect_tir_var(expr):
         if isinstance(expr, tir.Var):
             tir_vars.add(expr)

-    for expr in region:
-        assert expr.extent == 1
-        tir.stmt_functor.post_order_visit(expr.min, _collect_tir_var)
+    tir.stmt_functor.post_order_visit(expr, _collect_tir_var)
     return tir_vars


@@ -227,7 +237,7 @@ def detect_dominant_read(block: tir.Block) -> tir.PrimExpr:
     dominant_read = None
     num_read_iters = -1
     for buffer_region in block.reads:
-        tir_vars = collect_vars_used_in_access_region(buffer_region.region)
+        tir_vars = collect_block_iter_vars_used_in_access_region(block, buffer_region.region)
         if num_read_iters < len(tir_vars):
             num_read_iters = len(tir_vars)
             dominant_read = buffer_region
@@ -247,7 +257,9 @@ def is_broadcast_epilogue(
     for buffer_region in sch.get(epilogue).reads:
         if buffer_region.buffer not in write_buffers:
             continue
-        tir_vars = collect_vars_used_in_access_region(buffer_region.region)
+        tir_vars = collect_block_iter_vars_used_in_access_region(
+            sch.get(epilogue), buffer_region.region
+        )
        if len(tir_vars) < len(epilogue_iters):
            return True
    return False
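
As a quick illustration of the new contract (a minimal sketch, assuming a TVM build that includes this change; the toy gemv function and all names in it are mine): the helper now intersects the variables found in each access region with the block's own iter vars, so loop vars or external symbols no longer count.

    from tvm import tir
    from tvm.script import tir as T
    from tvm.dlight.base import collect_block_iter_vars_used_in_access_region

    @T.prim_func
    def gemv(
        A: T.Buffer((64, 128), "float32"),
        x: T.Buffer((128,), "float32"),
        y: T.Buffer((64,), "float32"),
    ):
        for i, k in T.grid(64, 128):
            with T.block("gemv"):
                vi, vk = T.axis.remap("SR", [i, k])
                with T.init():
                    y[vi] = T.float32(0)
                y[vi] = y[vi] + A[vi, vk] * x[vk]

    sch = tir.Schedule(gemv)
    block = sch.get(sch.get_block("gemv"))
    for read in block.reads:
        # e.g. A's region uses {vi, vk}, x's only {vk}: a read that uses
        # fewer block iters than the block has is what is_gemv looks for.
        print(read.buffer.name, collect_block_iter_vars_used_in_access_region(block, read.region))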
python/tvm/dlight/gpu/gemv.py (32 changes: 22 additions & 10 deletions)

@@ -24,7 +24,8 @@

 from ..base import (
     BlockInfo,
-    collect_vars_used_in_access_region,
+    collect_block_iter_vars_used_in_access_region,
+    collect_vars_used_in_prim_expr,
     detect_dominant_read,
     is_broadcast_epilogue,
     normalize_prim_func,
@@ -86,15 +87,19 @@ def is_gemv(sch: tir.Schedule, block_info: BlockInfo) -> Optional[List[tir.Buffe
     conditions.append(len(block_stmt.reads) >= 2)
     conditions.append(len(block_stmt.writes) == 1)
     conditions.append(_get_reduction_expr(block_stmt) is not None)
-    conditions.append(len(collect_vars_used_in_access_region(block_stmt.writes[0].region)) > 0)
+    conditions.append(
+        len(collect_block_iter_vars_used_in_access_region(block_stmt, block_stmt.writes[0].region))
+        > 0
+    )
     if not all(conditions):
         return None

     iter_num = len(block_stmt.iter_vars)
     ret = [
         read.buffer
         for read in block_stmt.reads
-        if len(collect_vars_used_in_access_region(read.region)) < iter_num
+        if len(collect_block_iter_vars_used_in_access_region(block_stmt, read.region)) < iter_num
+        and len(collect_block_iter_vars_used_in_access_region(block_stmt, read.region)) > 0
     ]
     return ret if 0 < len(ret) < len(block_stmt.reads) else None

@@ -109,12 +114,19 @@ def normalize(
         detect_dominant_read(block_stmt),
         input_iters={i.var: i.dom for i in block_stmt.iter_vars},
     )
-
-    buffers_use_vars = [collect_vars_used_in_access_region(buf.region) for buf in block_stmt.writes]
+    buffers_use_vars = [
+        collect_block_iter_vars_used_in_access_region(block_stmt, buf.region)
+        for buf in block_stmt.writes
+    ]
     buffers_use_vars.extend(
-        [collect_vars_used_in_access_region(buf.region) for buf in block_stmt.reads]
+        [
+            collect_block_iter_vars_used_in_access_region(block_stmt, buf.region)
+            for buf in block_stmt.reads
+        ]
     )
-    if access.base != 0:
+    if collect_vars_used_in_prim_expr(access.base) & set(
+        iter_var.var for iter_var in block_stmt.iter_vars
+    ):
         return None
     iter_to_info = {i.var: i for i in block_info.iters}
     batch_loops, s_loops, r_loops, c_loops = [], [], [], []
@@ -420,15 +432,15 @@ def apply(
         elif target.kind.name == "metal":
             # Note that the following tile size is tuned on M2 Ultra for 7B
             TAG_S, TAG_R = "threadIdx.x", "threadIdx.y"
-            VEC_C = 4
+            VEC_C = 1
             LOAD_V_SHARED = False
             LOAD_V_VEC = -1
             UNROLL = 256
             if isinstance(len_S, int):
                 if len_S > len_R:
-                    TS, TR = 1, 64
+                    TS, TR = 4, 16
                 else:
-                    TS, TR = 1, 256
+                    TS, TR = 2, 64
         elif target.kind.name == "rocm":
             VEC_C = 4
             LOAD_V_SHARED = True
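
The metal numbers above are tuned constants (per the comment, on M2 Ultra for a 7B model), not derived values. For intuition, a rough sketch of the launch shape they imply, where TS and TR are the thread extents bound to TAG_S/TAG_R (threadIdx.x/threadIdx.y) and VEC_C is the per-thread vector width; the extents below are hypothetical, not dlight code:

    len_S, len_R = 4096, 11008   # assumed spatial / reduction lengths of a GEMV
    if len_S > len_R:
        TS, TR = 4, 16           # 64 threads per threadblock
    else:
        TS, TR = 2, 64           # 128 threads per threadblock
    print(TS * TR)               # -> 128 for these extents, vs 256 before this change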
python/tvm/dlight/gpu/reduction.py (68 changes: 57 additions & 11 deletions)

@@ -85,7 +85,7 @@ def apply(  # pylint: disable=too-many-locals,too-many-branches,too-many-return-
         ):
             return None
         # Step 2. Normalize the block, merge spatial and reduction iters
-        is_inner_reduction, c_factor = self._normalize(
+        is_inner_reduction, c_factor, loop_order, s_split_index = self._normalize(
             sch,
             block_info,
             arith.normalize_to_iter_sum(
@@ -97,9 +97,13 @@ def apply(  # pylint: disable=too-many-locals,too-many-branches,too-many-return-
             return None
         # Step 3. Do the scheduling
         if is_inner_reduction:
-            self._sch_inner_reduction(sch, target, block, c_factor, epilogue)
+            self._sch_inner_reduction(
+                sch, target, block, c_factor, epilogue, loop_order, s_split_index
+            )
         else:
-            self._sch_inner_spatial(sch, target, block, block_info, c_factor, epilogue)
+            self._sch_inner_spatial(
+                sch, target, block, block_info, c_factor, epilogue, loop_order, s_split_index
+            )
         return sch

     def _normalize(  # pylint: disable=too-many-branches
@@ -112,6 +116,7 @@ def _normalize(  # pylint: disable=too-many-branches
             return None, None
         iter_to_info = {i.var: i for i in block_info.iters}
         s_loops, r_loops, c_loops, c_factor = [], [], [], None
+        s_split_loop, s_split_index = None, None
         for split_expr in access.args:
             var = split_expr.source.source
             info = iter_to_info.pop(var)
@@ -120,6 +125,8 @@ def _normalize(  # pylint: disable=too-many-branches
             if split_expr.lower_factor > 1:
                 if c_loops:
                     return None, None
+                s_split_loop = loop
+                s_split_index = len(s_loops)
                 loop, c_loop = sch.split(loop, factors=[None, split_expr.lower_factor])
                 c_loops.append(c_loop)
                 if not is_inner_reduction:
@@ -135,6 +142,22 @@ def _normalize(  # pylint: disable=too-many-branches
                 s_loops.append(info.loop_rv)
             else:
                 return None, None
+
+        loop_order = {}
+        s_block_var_loops = []
+        for i in block_info.iters:
+            if i.loop_rv in s_loops or i.loop_rv == s_split_loop:
+                s_block_var_loops.append(i.loop_rv)
+
+        for i in range(len(s_block_var_loops)):
+            for j in range(len(s_loops)):
+                if s_block_var_loops[i] == s_loops[j]:
+                    loop_order[i] = j
+                    break
+                if s_block_var_loops[i] == s_split_loop:
+                    loop_order[i] = s_split_index
+                    break
+
         assert s_loops
         assert r_loops
         if len(s_loops) != len([i for i in block_info.iters if i.kind == "S"]):
@@ -144,7 +167,7 @@ def _normalize(  # pylint: disable=too-many-branches
         sch.reorder(*s_loops, *r_loops, *c_loops)
         sch.fuse(*s_loops)
         sch.fuse(*r_loops)
-        return is_inner_reduction, c_factor
+        return is_inner_reduction, c_factor, loop_order, s_split_index

     def _sch_inner_reduction(  # pylint: disable=too-many-arguments
         self,
@@ -153,6 +176,8 @@ def _sch_inner_reduction(  # pylint: disable=too-many-arguments
         block: tir.schedule.BlockRV,
         unroll_spatial_factor: Optional[int],
         epilogue_info: Optional[BlockInfo],
+        loop_order,
+        s_split_index,
     ):
         # pylint: disable=invalid-name
         _, r, _ = sch.get_loops(block)
@@ -174,11 +199,20 @@ def _sch_inner_reduction(  # pylint: disable=too-many-arguments
         # Schedule the write back block
         sch.reverse_compute_at(block, bx, preserve_unit_loops=True)
         _, tx, *s = sch.get_loops(block)
-        s = sch.fuse(*s)
-        sch.reorder(s, tx)
+
         if unroll_spatial_factor:
-            s, inner = sch.split(s, factors=[None, unroll_spatial_factor])
-            sch.reorder(s, tx, inner)
+            assert len(s) == len(loop_order)
+            new_order_s = [s[loop_order[i]] for i in range(len(s))]
+            sch.reorder(*new_order_s)
+            new_order_s[s_split_index], c = sch.split(
+                new_order_s[s_split_index], factors=[None, unroll_spatial_factor]
+            )
+            sch.reorder(*new_order_s, c)
+            s = sch.fuse(*new_order_s)
+            sch.reorder(s, tx, c)
+        else:
+            s = sch.fuse(*s)
+            sch.reorder(s, tx)
         sch.bind(tx, "threadIdx.x")
         # Schedule epilogue
         if epilogue_info is not None:
@@ -201,6 +235,8 @@ def _sch_inner_spatial(
         block_info: BlockInfo,
         unroll_spatial_factor: Optional[int],
         epilogue_info: Optional[BlockInfo],
+        loop_order,
+        s_split_index,
     ):
         # pylint: disable=invalid-name
         s, r, _ = sch.get_loops(block)
@@ -226,12 +262,22 @@ def _sch_inner_spatial(
         # Schedule the write back block
         sch.reverse_compute_at(block, bx, preserve_unit_loops=True)
         _, r, *s = sch.get_loops(block)
-        s = sch.fuse(*s)
-        sch.reorder(s, r)
         if unroll_spatial_factor:
-            s, _ = sch.split(s, factors=[None, unroll_spatial_factor])
+            assert len(s) == len(loop_order)
+            new_order_s = [s[loop_order[i]] for i in range(len(s))]
+            sch.reorder(*new_order_s)
+            new_order_s[s_split_index], c = sch.split(
+                new_order_s[s_split_index], factors=[None, unroll_spatial_factor]
+            )
+            sch.reorder(*new_order_s, c)
+            s = sch.fuse(*new_order_s)
+            sch.reorder(s, c, r)
+        else:
+            s = sch.fuse(*s)
+            sch.reorder(s, r)
         sch.bind(s, "threadIdx.x")
         sch.bind(r, "threadIdx.y")
+
         # Schedule epilogue
         if epilogue_info is not None:
             epilogue = epilogue_info.block_rv
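
The loop_order map added in _normalize records, for each spatial loop in block-var order, where that loop landed in s_loops (with the split loop mapped to s_split_index). Reduced to plain Python, with strings standing in for loop RVs and the split-loop case omitted for brevity (a sketch, not dlight code):

    s_loops = ["j", "i"]            # spatial loops, dominant-read order
    s_block_var_loops = ["i", "j"]  # the same loops, block-var order
    loop_order = {i: s_loops.index(loop) for i, loop in enumerate(s_block_var_loops)}
    print(loop_order)               # {0: 1, 1: 0}

_sch_inner_reduction and _sch_inner_spatial then use new_order_s = [s[loop_order[i]] for i in range(len(s))] to permute the write-back block's loops into a consistent order before splitting out unroll_spatial_factor at s_split_index, instead of blindly fusing them as before.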
src/tir/schedule/transform.cc (21 changes: 17 additions & 4 deletions)

@@ -444,10 +444,22 @@ PrimExpr BlockBufferAccessSimplifier::VisitExpr_(const BufferLoadNode* op) {

 /******** PrimFunc-level analysis and transformation ********/

+void GetLeafBlocksHelper(Schedule sch, BlockRV cur_block_rv, Array<BlockRV>* leaf_blocks) {
+  Array<BlockRV> blocks = sch->GetChildBlocks(cur_block_rv);
+  if (blocks.empty()) {
+    leaf_blocks->push_back(cur_block_rv);
+  } else {
+    for (const BlockRV& block : blocks) {
+      GetLeafBlocksHelper(sch, block, leaf_blocks);
+    }
+  }
+}
+
 Optional<ObjectRef> NormalizePrimFunc(Schedule sch) {
   BlockRV root_block = sch->GetBlock("root");
-  Array<BlockRV> blocks = sch->GetChildBlocks(root_block);
-  for (const BlockRV& block : blocks) {
+  Array<BlockRV> leaf_blocks;
+  GetLeafBlocksHelper(sch, root_block, &leaf_blocks);
+  for (const BlockRV& block : leaf_blocks) {
     StmtSRef block_sref = sch->GetSRef(block);
     Array<StmtSRef> loops = GetLoops(block_sref);
     Array<PrimExpr> binds = GetBlockRealize(sch->state(), block_sref)->iter_values;
@@ -465,10 +477,11 @@ Optional<ObjectRef> NormalizePrimFunc(Schedule sch) {
       }
     }
   }
+
   Array<Array<LoopRV>> block_loops;
   Array<Array<IterVar>> block_iters;
   Array<IntImm> block_is_reduction;
-  for (const BlockRV& block : blocks) {
+  for (const BlockRV& block : leaf_blocks) {
     Array<IterVar> iters = sch->Get(block)->iter_vars;
     bool has_spatial_iter = false;
     Array<Var> index_map_inputs;
@@ -498,7 +511,7 @@ Optional<ObjectRef> NormalizePrimFunc(Schedule sch) {
                                          sch->GetSRef(root_block));
     block_is_reduction.push_back(Bool(is_reduction));
   }
-  return Array<ObjectRef>{blocks, block_loops, block_iters, block_is_reduction};
+  return Array<ObjectRef>{leaf_blocks, block_loops, block_iters, block_is_reduction};
 }

 TVM_REGISTER_GLOBAL("tir.schedule.NormalizePrimFunc").set_body_typed(NormalizePrimFunc);
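
GetLeafBlocksHelper makes NormalizePrimFunc recurse down to leaf blocks instead of stopping at the root's direct children, so blocks nested under other blocks are normalized too. A rough Python equivalent via the public schedule API, for intuition only (the real helper is the C++ above):

    from tvm import tir

    def get_leaf_blocks(sch: tir.Schedule, block_rv) -> list:
        """Collect blocks that have no child blocks beneath them."""
        children = sch.get_child_blocks(block_rv)
        if not children:
            return [block_rv]
        leaves = []
        for child in children:
            leaves.extend(get_leaf_blocks(sch, child))
        return leaves

    # Usage sketch: leaves = get_leaf_blocks(sch, sch.get_block("root"))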