78 changes: 25 additions & 53 deletions python/tvm/contrib/cutlass/build.py
@@ -94,15 +94,17 @@ def visit_call(self, call):
 
 
 def select_gemm_kernel(
-    cutlass_profiler, MM, KK, NN, out_dtype, batched, profile_all, use_multiprocessing
+    cutlass_profiler, op_type, MM, KK, NN, out_dtype, batched, profile_all, use_multiprocessing
 ):
     """Run CUTLASS profiler to select the best kernel, or return the default one for dynamic
     workloads."""
     if any(isinstance(s, tvm.tir.Any) for s in [MM, KK, NN]):
-        out = cutlass_profiler.get_default(out_dtype, batched=batched)
-        logger.info("Picked the default kernel %s", out["name"])
+        out = cutlass_profiler.get_default(op_type, out_dtype, batched=batched)
+        name, cutlass_op_def = out["name"], out["opdef"]
+        logger.info("Picked the default kernel %s", name)
     else:
-        out = cutlass_profiler.profile(
+        name, cutlass_op_def, _ = cutlass_profiler.profile(
+            op_type,
             MM,
             NN,
             KK,
@@ -112,10 +114,11 @@ def select_gemm_kernel(
             use_multiprocessing=use_multiprocessing,
         )
         if profile_all:
-            logger.info("The best kernel is %s", out["name"])
+            logger.info("The best kernel is %s", name)
         else:
-            logger.info("Picked the first kernel found %s", out["name"])
-    return out
+            logger.info("Picked the first kernel found %s", name)
+
+    return name, cutlass_op_def
 
 
 def handle_batch_matmul(
@@ -126,24 +129,17 @@ def handle_batch_matmul(
     KK = arg0_shape[2]
     NN = arg1_shape[1]
 
-    out = select_gemm_kernel(
-        cutlass_profiler, MM, KK, NN, out_dtype, True, profile_all, use_multiprocessing
+    name, cutlass_op_def = select_gemm_kernel(
+        cutlass_profiler, op_type, MM, KK, NN, out_dtype, True, profile_all, use_multiprocessing
     )
 
-    if op_type == "cutlass.batch_matmul":
-        cutlass_op_def = out["opdef"]
-    else:
-        raise ValueError("%s pattern is not implemented." % op_type)
-
-    assert "tn_align" in out["name"], "Only supports (row_major, col_major) input layout for now."
-
     return {
         "batch": arg0_shape[0],
         "batch_stride_A": arg0_shape[1] * arg0_shape[2],
         "batch_stride_B": arg1_shape[1] * arg1_shape[2],
         "batch_stride_C": arg0_shape[1] * arg1_shape[1],
         "cutlass_op_def": cutlass_op_def,
-        "cutlass_op_name": out["name"],
+        "cutlass_op_name": name,
         "lda": "K",
         "ldb": "K",
         "ldc": "N",
@@ -158,26 +154,15 @@ def handle_dense(
     KK = arg0_shape[1]
     NN = arg1_shape[0]
 
-    out = select_gemm_kernel(
-        cutlass_profiler, MM, KK, NN, out_dtype, False, profile_all, use_multiprocessing
+    name, cutlass_op_def = select_gemm_kernel(
+        cutlass_profiler, op_type, MM, KK, NN, out_dtype, False, profile_all, use_multiprocessing
     )
 
-    if op_type == "cutlass.dense":
-        cutlass_op_def = out["opdef"]
-    elif op_type == "cutlass.dense_bias":
-        cutlass_op_def = out["opdef_bias"]
-    elif op_type == "cutlass.dense_bias_relu":
-        cutlass_op_def = out["opdef_bias_relu"]
-    elif "cutlass.dense_bias_gelu" in op_type:
-        cutlass_op_def = out["opdef_bias_gelu"]
-    else:
-        raise ValueError("%s pattern is not implemented." % op_type)
-
-    assert "tn_align" in out["name"], "Only supports (row_major, col_major) input layout for now."
+    assert "tn_align" in name, "Only supports (row_major, col_major) input layout for now."
 
     return {
         "cutlass_op_def": cutlass_op_def,
-        "cutlass_op_name": out["name"],
+        "cutlass_op_name": name,
         "lda": "K",
         "ldb": "K",
         "ldc": "N",
@@ -198,10 +183,12 @@ def handle_conv2d(
 ):
     """Profile and select a kernel for conv2d op workload."""
     if any(isinstance(s, tvm.tir.Any) for s in d_shape):
-        out = cutlass_profiler.get_default(out_dtype)
-        logger.info("Picked the default kernel %s", out["name"])
+        out = cutlass_profiler.get_default(op_type, out_dtype)
+        name, cutlass_op_def = out["name"], out["opdef"]
+        logger.info("Picked the default kernel %s", name)
     else:
-        out = cutlass_profiler.profile(
+        name, cutlass_op_def, _ = cutlass_profiler.profile(
+            op_type,
             d_shape,
             w_shape,
             padding,
@@ -212,28 +199,13 @@ def handle_conv2d(
             use_multiprocessing=use_multiprocessing,
         )
         if profile_all:
-            logger.info("The best kernel is %s", out["name"])
+            logger.info("The best kernel is %s", name)
         else:
-            logger.info("Picked the first kernel found %s", out["name"])
-
-    if op_type == "cutlass.conv2d":
-        cutlass_op_def = out["opdef"]
-    elif op_type == "cutlass.conv2d_bias":
-        cutlass_op_def = out["opdef_bias"]
-    elif op_type == "cutlass.conv2d_bias_relu":
-        cutlass_op_def = out["opdef_bias_relu"]
-    elif op_type == "cutlass.conv2d_bias_sigmoid":
-        cutlass_op_def = out["opdef_bias_sigmoid"]
-    elif op_type == "cutlass.conv2d_bias_silu":
-        cutlass_op_def = out["opdef_bias_silu"]
-    elif op_type == "cutlass.conv2d_bias_hardswish":
-        cutlass_op_def = out["opdef_bias_hardswish"]
-    else:
-        raise ValueError("%s pattern is not implemented." % op_type)
+            logger.info("Picked the first kernel found %s", name)
 
     return {
         "cutlass_op_def": cutlass_op_def,
-        "cutlass_op_name": out["name"],
+        "cutlass_op_name": name,
     }
 
 
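Note on the build.py changes above: select_gemm_kernel now receives op_type and returns a (name, opdef) pair instead of a dict keyed by epilogue variant, so the per-pattern if/elif dispatch in the handlers goes away. Below is a minimal sketch of the new caller-side contract; FakeProfiler, the kernel strings, and select_gemm_kernel_sketch are made up for illustration and are not TVM code — only the tuple shapes follow the diff.

class FakeProfiler:
    def profile(self, op_type, M, N, K, out_dtype,
                profile_all=True, use_multiprocessing=False):
        # The real CutlassGemmProfiler benchmarks candidate kernels; per this PR
        # it is assumed to return (name, opdef, runtime) rather than a dict.
        return "cutlass_tensorop_h1688gemm_tn_align4", "// kernel definition", 0.0


def select_gemm_kernel_sketch(profiler, op_type, MM, KK, NN, out_dtype):
    # Mirrors the new contract: op_type is forwarded so the profiler can pick the
    # matching epilogue (bias, relu, gelu, ...), and the caller unpacks a tuple.
    name, cutlass_op_def, _ = profiler.profile(op_type, MM, NN, KK, out_dtype)
    return name, cutlass_op_def


name, opdef = select_gemm_kernel_sketch(
    FakeProfiler(), "cutlass.dense_bias_relu", 1024, 768, 512, "float16"
)
assert "tn_align" in name  # same layout check the dense handler keeps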
2 changes: 1 addition & 1 deletion python/tvm/contrib/cutlass/conv2d_operation.py
@@ -186,7 +186,7 @@ def __init__(self):
 >::Kernel;
 """
 
-    def emit(self, operation, no_beta_scaling=True):
+    def emit(self, operation, no_beta_scaling=False):
         """Instantiate a Conv2d kernel from given `operation`."""
         warp_shape = [
             int(
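Note on the conv2d_operation.py change above: the default for no_beta_scaling in EmitConv2dInstance.emit flips from True to False, so call sites that relied on the old default must now request it explicitly. A minimal sketch of the call-site effect of flipping a keyword default; emit_conv2d is a hypothetical stand-in, not the TVM emitter.

def emit_conv2d(operation, no_beta_scaling=False):
    # With the new default, the no-beta-scaling epilogue is emitted only when
    # the caller asks for it explicitly.
    return f"{operation} (no_beta_scaling={no_beta_scaling})"


print(emit_conv2d("conv2d_bias"))                        # new default: False
print(emit_conv2d("conv2d_bias", no_beta_scaling=True))  # old behaviour, now opt-in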