diff --git a/src/relax/ir/expr.cc b/src/relax/ir/expr.cc index 1b5551e5097b..f6b04e178223 100644 --- a/src/relax/ir/expr.cc +++ b/src/relax/ir/expr.cc @@ -137,9 +137,25 @@ TVM_REGISTER_GLOBAL("relax.If") }); Tuple::Tuple(tvm::Array fields, Span span) { + Optional tuple_sinfo = [&]() -> Optional { + Array field_sinfo; + for (const auto& field : fields) { + if (field->struct_info_.defined()) { + field_sinfo.push_back(GetStructInfo(field)); + } else { + return NullOpt; + } + } + return TupleStructInfo(field_sinfo); + }(); + ObjectPtr n = make_object(); n->fields = std::move(fields); n->span = std::move(span); + if (tuple_sinfo) { + n->checked_type_ = GetStaticType(tuple_sinfo.value()); + } + n->struct_info_ = tuple_sinfo; data_ = std::move(n); } diff --git a/tests/python/relax/test_expr.py b/tests/python/relax/test_expr.py index af1bc851be99..b20c9ef2d982 100644 --- a/tests/python/relax/test_expr.py +++ b/tests/python/relax/test_expr.py @@ -86,6 +86,25 @@ def test_tuple() -> None: t[-3] +def test_tuple_sinfo_inferred_on_construction(): + v0 = rx.Var("v0", rx.ObjectStructInfo()) + v1 = rx.Var("v1", rx.ObjectStructInfo()) + tup = rx.Tuple((v0, v1)) + + assert tup.struct_info_ is not None + tvm.ir.assert_structural_equal( + tup.struct_info, rx.TupleStructInfo([rx.ObjectStructInfo(), rx.ObjectStructInfo()]) + ) + + +def test_tuple_sinfo_requires_fields_with_known_sinfo(): + v0 = rx.Var("v0", rx.ObjectStructInfo()) + v1 = rx.Var("v1") + tup = rx.Tuple((v0, v1)) + + assert tup.struct_info_ is None + + def test_match_cast() -> None: # match_cast([16, 8], [m, n]) m = tir.Var("m", dtype="int64")