When using MetaSchedule to tune a conv3d NCDHW workload, the tuning result cannot pass the well-formedness check, which causes the following error.
import tvm
from tvm.script import tir as T, ir as I
from tvm import meta_schedule as ms
from tvm.tir.tensor_intrin import *
@T.prim_func(private=True)
def func(silu33: T.Buffer((T.int64(2), T.int64(1280), T.int64(25), T.int64(18), T.int64(32)), "float16"), down_blocks_2_resnets_0_temporal_res_block_conv1_weight: T.Buffer((T.int64(1280), T.int64(1280), T.int64(3), T.int64(1), T.int64(1)), "float16"), lv113: T.Buffer((T.int64(1), T.int64(1280), T.int64(1), T.int64(1), T.int64(1)), "float16"), permute_dims152: T.Buffer((T.int64(2), T.int64(1280), T.int64(25), T.int64(1), T.int64(1)), "float16"), T_add_intermediate_1: T.Buffer((T.int64(2), T.int64(1280), T.int64(25), T.int64(18), T.int64(32)), "float16")):
T.func_attr({"tir.noalias": T.bool(True)})
# with T.block("root"):
pad_temp = T.alloc_buffer((T.int64(2), T.int64(1280), T.int64(27), T.int64(18), T.int64(32)), "float16")
conv3d_ncdhw_intermediate = T.alloc_buffer((T.int64(2), T.int64(1280), T.int64(25), T.int64(18), T.int64(32)), "float16")
T_add_intermediate = T.alloc_buffer((T.int64(2), T.int64(1280), T.int64(25), T.int64(18), T.int64(32)), "float16")
for i0, i1, i2, i3, i4 in T.grid(T.int64(2), T.int64(1280), T.int64(27), T.int64(18), T.int64(32)):
with T.block("pad_temp"):
v_i0, v_i1, v_i2, v_i3, v_i4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4])
T.reads(silu33[v_i0, v_i1, v_i2 - T.int64(1), v_i3, v_i4])
T.writes(pad_temp[v_i0, v_i1, v_i2, v_i3, v_i4])
pad_temp[v_i0, v_i1, v_i2, v_i3, v_i4] = T.if_then_else(T.int64(1) <= v_i2 and v_i2 < T.int64(26), silu33[v_i0, v_i1, v_i2 - T.int64(1), v_i3, v_i4], T.float16(0))
for nn, ff, yy, xx, zz, rc, ry, rx, rz in T.grid(T.int64(2), T.int64(1280), T.int64(25), T.int64(18), T.int64(32), T.int64(1280), T.int64(3), T.int64(1), T.int64(1)):
with T.block("conv3d_ncdhw"):
v_nn, v_ff, v_yy, v_xx, v_zz, v_rc, v_ry, v_rx, v_rz = T.axis.remap("SSSSSRRRR", [nn, ff, yy, xx, zz, rc, ry, rx, rz])
T.reads(pad_temp[v_nn, v_rc, v_yy + v_ry, v_xx + v_rx, v_zz + v_rz], down_blocks_2_resnets_0_temporal_res_block_conv1_weight[v_ff, v_rc, v_ry, v_rx, v_rz])
T.writes(conv3d_ncdhw_intermediate[v_nn, v_ff, v_yy, v_xx, v_zz])
with T.init():
conv3d_ncdhw_intermediate[v_nn, v_ff, v_yy, v_xx, v_zz] = T.float16(0)
conv3d_ncdhw_intermediate[v_nn, v_ff, v_yy, v_xx, v_zz] = conv3d_ncdhw_intermediate[v_nn, v_ff, v_yy, v_xx, v_zz] + pad_temp[v_nn, v_rc, v_yy + v_ry, v_xx + v_rx, v_zz + v_rz] * down_blocks_2_resnets_0_temporal_res_block_conv1_weight[v_ff, v_rc, v_ry, v_rx, v_rz]
for ax0, ax1, ax2, ax3, ax4 in T.grid(T.int64(2), T.int64(1280), T.int64(25), T.int64(18), T.int64(32)):
with T.block("T_add"):
v_ax0, v_ax1, v_ax2, v_ax3, v_ax4 = T.axis.remap("SSSSS", [ax0, ax1, ax2, ax3, ax4])
T.reads(conv3d_ncdhw_intermediate[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4], lv113[T.int64(0), v_ax1, T.int64(0), T.int64(0), T.int64(0)])
T.writes(T_add_intermediate[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4])
T_add_intermediate[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] = conv3d_ncdhw_intermediate[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] + lv113[T.int64(0), v_ax1, T.int64(0), T.int64(0), T.int64(0)]
for ax0, ax1, ax2, ax3, ax4 in T.grid(T.int64(2), T.int64(1280), T.int64(25), T.int64(18), T.int64(32)):
with T.block("T_add_1"):
v_ax0, v_ax1, v_ax2, v_ax3, v_ax4 = T.axis.remap("SSSSS", [ax0, ax1, ax2, ax3, ax4])
T.reads(T_add_intermediate[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4], permute_dims152[v_ax0, v_ax1, v_ax2, T.int64(0), T.int64(0)])
T.writes(T_add_intermediate_1[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4])
T_add_intermediate_1[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] = T_add_intermediate[v_ax0, v_ax1, v_ax2, v_ax3, v_ax4] + permute_dims152[v_ax0, v_ax1, v_ax2, T.int64(0), T.int64(0)]
if __name__ == "__main__":
func.show()
target = tvm.target.Target("nvidia/nvidia-a10g")
tune = False
if tune:
db = ms.tune_tir(func, target=target, work_dir="./temp", max_trials_global=500)
else:
db = ms.database.JSONDatabase(work_dir="./temp")
mod = tvm.ir.IRModule({"main": func.with_attrs({"global_symbol": "main"})})
tuned_mod = db.query_ir_module(mod=mod, target=target, workload_name="main")
tuned_mod.show()
tvm.build(tuned_mod, target=target)
tvm.tir.analysis.verify_well_formed(tuned_mod)
When using MetaSchedule to tune a conv3d NCDHW workload, the tuning result cannot pass the well-formedness check, which causes the error above.
To reproduce, first run the tuning with the code above with `tune` set to `True`, and then apply the resulting database to trigger the error. To directly obtain the tuned workload TIR and run the test, you can also try this script. Thanks to @jwfromm for reporting this issue.