From 329b558d8bb9ed5aea171670c86f21e683e51410 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Mon, 8 Apr 2024 18:51:25 -0500 Subject: [PATCH] [QoL][Relax] Infer StructInfo for relax::Tuple on construction Prior to this commit, the `relax::Tuple` constructor left the `struct_info_` field undefined. This is inconsistent with other Relax leaf nodes, such as `relax::PrimValue`, `relax::Constant`, and `relax::ExternFunc`, which initialize their struct info on construction. This commit updates the `relax::Tuple` constructor to define `struct_info_` as `TupleStructInfo`, if all fields have a known struct info. If any field does not have a known struct info, the current behavior is kept, where `struct_info_` is constructed as `NullOpt`, and is later populated by the `relax::BlockBuilder`. --- src/relax/ir/expr.cc | 16 ++++++++++++++++ tests/python/relax/test_expr.py | 19 +++++++++++++++++++ 2 files changed, 35 insertions(+) diff --git a/src/relax/ir/expr.cc b/src/relax/ir/expr.cc index 1b5551e5097b..f6b04e178223 100644 --- a/src/relax/ir/expr.cc +++ b/src/relax/ir/expr.cc @@ -137,9 +137,25 @@ TVM_REGISTER_GLOBAL("relax.If") }); Tuple::Tuple(tvm::Array fields, Span span) { + Optional tuple_sinfo = [&]() -> Optional { + Array field_sinfo; + for (const auto& field : fields) { + if (field->struct_info_.defined()) { + field_sinfo.push_back(GetStructInfo(field)); + } else { + return NullOpt; + } + } + return TupleStructInfo(field_sinfo); + }(); + ObjectPtr n = make_object(); n->fields = std::move(fields); n->span = std::move(span); + if (tuple_sinfo) { + n->checked_type_ = GetStaticType(tuple_sinfo.value()); + } + n->struct_info_ = tuple_sinfo; data_ = std::move(n); } diff --git a/tests/python/relax/test_expr.py b/tests/python/relax/test_expr.py index af1bc851be99..b20c9ef2d982 100644 --- a/tests/python/relax/test_expr.py +++ b/tests/python/relax/test_expr.py @@ -86,6 +86,25 @@ def test_tuple() -> None: t[-3] +def test_tuple_sinfo_inferred_on_construction(): + v0 = rx.Var("v0", rx.ObjectStructInfo()) + v1 = rx.Var("v1", rx.ObjectStructInfo()) + tup = rx.Tuple((v0, v1)) + + assert tup.struct_info_ is not None + tvm.ir.assert_structural_equal( + tup.struct_info, rx.TupleStructInfo([rx.ObjectStructInfo(), rx.ObjectStructInfo()]) + ) + + +def test_tuple_sinfo_requires_fields_with_known_sinfo(): + v0 = rx.Var("v0", rx.ObjectStructInfo()) + v1 = rx.Var("v1") + tup = rx.Tuple((v0, v1)) + + assert tup.struct_info_ is None + + def test_match_cast() -> None: # match_cast([16, 8], [m, n]) m = tir.Var("m", dtype="int64")