diff --git a/python/tvm/script/parser/relax/parser.py b/python/tvm/script/parser/relax/parser.py index 1a0c3cea8e0b..8ee51136009e 100644 --- a/python/tvm/script/parser/relax/parser.py +++ b/python/tvm/script/parser/relax/parser.py @@ -274,7 +274,21 @@ def post_visit_local_function(self: Parser, node: doc.Expr) -> None: @dispatch.register(token="relax", type_name="Expr") def visit_expr_stmt(self: Parser, node: doc.Expr) -> None: value = self.eval_expr(node.value) - if value is not None: + if isinstance(value, relax.Expr): + var = R.emit(value) + IRBuilder.name("_", var) + is_void_value = ( + isinstance(var.struct_info, relax.TupleStructInfo) and len(var.struct_info.fields) == 0 + ) + + if not is_void_value: + self.report_error( + node, + f"Non-void relax expressions must be bound to a variable, " + f"but expression of type {var.struct_info} was used as a statement.", + ) + + elif value is not None: self.report_error(node, f"Unsupported Expr stmt type {value}.") diff --git a/src/script/printer/relax/binding.cc b/src/script/printer/relax/binding.cc index acf0072c0f45..5aa99878f951 100644 --- a/src/script/printer/relax/binding.cc +++ b/src/script/printer/relax/binding.cc @@ -69,6 +69,24 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) Doc ret = d->AsDoc(n->value, n_p->Attr("value")); d->cfg->binding_names.pop_back(); return ret; + + // Uncommenting this section hides the variable binding + // when the StructInfo is void. For example, printing + // `R.assert_op(expr)` instead of `_ = R.assert_op(expr)`. + // However, Relax represents void values as an empty + // tuple, and a void-type variable may still be used later + // in the function. Hiding bindings of these void-type + // variables would result in use of an undefined variable. + // + // TODO(Lunderberg): Inline void-type variable to use + // `R.tuple()` during normalization. This will avoid the + // cases that trigger the undefined variables, and allow + // this syntax sugar to be enabled. + // + // } else if (d->cfg->syntax_sugar && relax::HasVoidStructInfo(n->value) && + // relax::HasVoidStructInfo(n->var)) { + // ExprDoc rhs = d->AsDoc(n->value, n_p->Attr("value")); + // return ExprStmtDoc(rhs); } else { ExprDoc rhs = d->AsDoc(n->value, n_p->Attr("value")); Optional ann = StructInfoAsAnn(n->var, n_p->Attr("var"), d, n->value); diff --git a/tests/python/relax/test_tvmscript_parser.py b/tests/python/relax/test_tvmscript_parser.py index 75aeb6831c1c..48d087c18a20 100644 --- a/tests/python/relax/test_tvmscript_parser.py +++ b/tests/python/relax/test_tvmscript_parser.py @@ -1824,6 +1824,77 @@ def main(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"): _check(Mixture) +def test_function_with_void_return_type_may_be_used_as_statements(): + """Void return of calls do not need to be assigned""" + + @I.ir_module + class Unsugared: + @R.function(pure=False) + def print(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"): + y = R.print(x, format="x: {}") + return x + + @R.function(pure=False) + def assert_func(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"): + y = R.assert_op(R.const(False, dtype="bool"), x, format="x: {}") + return x + + @I.ir_module + class Sugared: + @R.function(pure=False) + def print(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"): + R.print(x, format="x: {}") + return x + + @R.function(pure=False) + def assert_func(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"): + R.assert_op(R.const(False, dtype="bool"), x, format="x: {}") + return x + + tvm.ir.assert_structural_equal(Unsugared, Sugared) + + +def test_function_with_non_void_return_type_must_be_assigned(): + """Non-void results must be assigned to a variable""" + + with pytest.raises(tvm.error.DiagnosticError): + + @R.function(pure=False) + def func(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"): + R.add(x, x) + return x + + +def test_function_with_void_return_type_in_if_else(): + """Last statement in if/else may be a void return""" + + @I.ir_module + class Unsugared: + @R.function(pure=False) + def conditional( + x: R.Tensor((), "int32"), condition: R.Tensor((), "bool") + ) -> R.Tensor((), "int32"): + if condition: + y = R.print(x, format="True condition: {}") + else: + y = R.print(x, format="False condition: {}") + return x + + @I.ir_module + class Sugared: + @R.function(pure=False) + def conditional( + x: R.Tensor((), "int32"), condition: R.Tensor((), "bool") + ) -> R.Tensor((), "int32"): + if condition: + R.print(x, format="True condition: {}") + else: + R.print(x, format="False condition: {}") + return x + + _check(Sugared, Unsugared) + + def test_call_pure_packed(): @R.function def foo(x: R.Tensor((32, 32), "float32")) -> R.Tensor: diff --git a/tests/python/relax/test_tvmscript_printer_relax.py b/tests/python/relax/test_tvmscript_printer_relax.py index a75977ff9910..667fb0a132b6 100644 --- a/tests/python/relax/test_tvmscript_printer_relax.py +++ b/tests/python/relax/test_tvmscript_printer_relax.py @@ -15,6 +15,9 @@ # specific language governing permissions and limitations # under the License. # pylint: disable=missing-docstring + +import pytest + import tvm import tvm.testing from tvm import IRModule, relax, tir @@ -633,6 +636,7 @@ def foo(x: R.Tensor((128,), dtype="float32")) -> R.Tensor((128,), dtype="float32 ) +@pytest.mark.xfail(reason="Eliding void variable bindings currently disabled") def test_assert_op(): @I.ir_module class AssertOpMod: @@ -651,12 +655,13 @@ def main(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"): class Module: @R.function(pure=False) def main(x: R.Tensor((), dtype="int32")) -> R.Tensor((), dtype="int32"): - y: R.Tuple = R.assert_op(R.const(False, "bool"), x, format=R.str("x: {}")) + R.assert_op(R.const(False, "bool"), x, format=R.str("x: {}")) return x """, ) +@pytest.mark.xfail(reason="Eliding void variable bindings currently disabled") def test_print(): @I.ir_module class PrintMod: @@ -675,7 +680,7 @@ def main(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"): class Module: @R.function(pure=False) def main(x: R.Tensor((), dtype="int32")) -> R.Tensor((), dtype="int32"): - y: R.Tuple = R.print(x, format=R.str("x: {}")) + R.print(x, format=R.str("x: {}")) return x """, ) @@ -705,6 +710,7 @@ def main(x: R.Tensor((), dtype="int32")) -> R.Tensor((), dtype="int32"): ) +@pytest.mark.xfail(reason="Eliding void variable bindings currently disabled") def test_directly_construct_private_funcs(): # public @R.function @@ -758,7 +764,7 @@ def bar(x: R.Tensor((), dtype="int32")) -> R.Tensor((), dtype="int32"): @R.function def baz(x: R.Tensor((), dtype="int32")) -> R.Tensor((), dtype="int32"): R.func_attr({"relax.force_pure": 1}) - y: R.Tuple = R.print(format=R.str("Hi there!")) + R.print(format=R.str("Hi there!")) z: R.Tensor((), dtype="int32") = R.add(x, x) return z @@ -770,7 +776,7 @@ def foo(x: R.Tensor((), dtype="int32")) -> R.Tensor((), dtype="int32"): @R.function(private=True) def quux(x: R.Tensor((), dtype="int32")) -> R.Tensor((), dtype="int32"): R.func_attr({"relax.force_pure": 1}) - y: R.Tuple = R.print(format=R.str("Lol")) + R.print(format=R.str("Lol")) z: R.Tensor((), dtype="int32") = R.multiply(x, x) return z """,