diff --git a/src/tir/schedule/primitive/layout_transformation.cc b/src/tir/schedule/primitive/layout_transformation.cc
index 6c6427a90649..f1e9106a635b 100644
--- a/src/tir/schedule/primitive/layout_transformation.cc
+++ b/src/tir/schedule/primitive/layout_transformation.cc
@@ -1207,17 +1207,20 @@ void TransformLayout(ScheduleState self, const StmtSRef& block_sref, int buffer_
   // Step 4: Rewrite buffer_map of the PrimFunc if necessary.
   if (!defining_site_sref.defined()) {
     GlobalVar g_var;
-    GetRootPrimFunc(self->mod, scope_block, &g_var);
+    const auto* old_func = GetRootPrimFunc(self->mod, scope_block, &g_var);
     IRModuleNode* new_mod = self->mod.CopyOnWrite();
     MapNode* new_map = new_mod->functions.CopyOnWrite();
-    PrimFunc ref_new_func = Downcast<PrimFunc>(std::move(new_map->at(g_var)));
-    PrimFuncNode* new_func = ref_new_func.CopyOnWrite();
-    MapNode* new_buffer_map = new_func->buffer_map.CopyOnWrite();
-    for (auto it = new_buffer_map->begin(); it != new_buffer_map->end(); ++it) {
-      if ((*it).second.same_as(old_buffer)) {
-        (*it).second = new_buffer;
+
+    // Build a fresh buffer_map instead of mutating the old PrimFunc in
+    // place: constructing a new PrimFunc re-derives its FuncStructInfo,
+    // keeping the struct info consistent with the updated signature.
+    Map<Var, Buffer> new_buffer_map;
+    for (auto [var, buffer] : old_func->buffer_map) {
+      if (buffer.same_as(old_buffer)) {
+        buffer = new_buffer;
       }
+      new_buffer_map.Set(var, buffer);
     }
+
+    PrimFunc ref_new_func(old_func->params, old_func->body, old_func->ret_type, new_buffer_map,
+                          old_func->attrs, old_func->span);
     new_map->at(g_var) = std::move(ref_new_func);
   }

diff --git a/tests/python/relax/test_transform_legalize_ops_manipulate.py b/tests/python/relax/test_transform_legalize_ops_manipulate.py
index 9b7a8f23c91b..dd0208f5db07 100644
--- a/tests/python/relax/test_transform_legalize_ops_manipulate.py
+++ b/tests/python/relax/test_transform_legalize_ops_manipulate.py
@@ -1666,5 +1666,69 @@ def main(x: R.Tensor((10, 20, 30), dtype="float32")) -> R.Tensor((10, 30, 7, 3),
     tvm.ir.assert_structural_equal(mod, Expected)
 
 
+def test_func_struct_info_of_legalized_layout_transform():
+    """PrimFunc shape information must be correct
+
+    This is a regression test. Previously, the legalization of
+    `R.layout_transform` produced a PrimFunc with `FuncStructInfo`
+    different than its actual signature. This resulted in errors
+    when later passes attempted to infer the StructInfo.
+    """
+
+    @I.ir_module
+    class Before:
+        @R.function
+        def main(
+            x: R.Tensor((16,), dtype="float32"), y: R.Tensor((16,), dtype="float32")
+        ) -> R.Tensor((16,), dtype="float32"):
+            R.func_attr({"relax.force_pure": True})
+            with R.dataflow():
+                lv: R.Tensor((4, 4), dtype="float32") = R.layout_transform(
+                    x, index_map=lambda i: (i // 4, i % 4), pad_value=None
+                )
+                gv: R.Tensor((4, 4), dtype="float32") = lv
+                R.output(gv)
+            return gv
+
+    After = tvm.ir.transform.Sequential(
+        [
+            relax.transform.LegalizeOps(),
+            relax.transform.ToNonDataflow(),
+            relax.transform.RemovePurityChecking(),
+            relax.transform.CallTIRRewrite(),
+        ]
+    )(Before)
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            x: R.Tensor((16,), dtype="float32"),
+            y: R.Tensor((16,), dtype="float32"),
+        ):
+            R.func_attr({"relax.force_pure": True})
+            cls = Expected
+            alloc: R.Tensor((4, 4), dtype="float32") = R.builtin.alloc_tensor(
+                R.shape([4, 4]), R.dtype("float32"), R.prim_value(0), R.str("global")
+            )
+            cls.te_layout_transform(x, alloc)
+            lv = alloc
+            gv = lv
+            return gv
+
+    @T.prim_func(private=True)
+    def te_layout_transform(
+        A: T.Buffer((T.int64(16),), "float32"),
+        te_layout_transform: T.Buffer((T.int64(4), T.int64(4)), "float32"),
+    ):
+        T.func_attr({"tir.noalias": T.bool(True)})
+        for i in range(T.int64(16)):
+            with T.block("te_layout_transform"):
+                vi = T.axis.spatial(T.int64(16), i)
+                te_layout_transform[vi // T.int64(4), vi % T.int64(4)] = A[vi]
+
+    tvm.ir.assert_structural_equal(Expected, After)
+
+
 if __name__ == "__main__":
     tvm.testing.main()