Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 15 additions & 1 deletion python/tvm/script/parser/relax/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,21 @@ def post_visit_local_function(self: Parser, node: doc.Expr) -> None:
@dispatch.register(token="relax", type_name="Expr")
def visit_expr_stmt(self: Parser, node: doc.Expr) -> None:
value = self.eval_expr(node.value)
if value is not None:
if isinstance(value, relax.Expr):
var = R.emit(value)
IRBuilder.name("_", var)
is_void_value = (
isinstance(var.struct_info, relax.TupleStructInfo) and len(var.struct_info.fields) == 0
)

if not is_void_value:
self.report_error(
node,
f"Non-void relax expressions must be bound to a variable, "
f"but expression of type {var.struct_info} was used as a statement.",
)
Comment on lines +284 to +289
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder if we should even have this as a rule. Why not let users evaluate expressions without binding them regardless of their return type?

Copy link
Copy Markdown
Contributor Author

@Lunderberg Lunderberg Feb 27, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

At the moment, because I wanted to make the minimal change that would support common cases. I think it would be good to remove the restriction altogether, but for the first step, I wanted to make the restriction be explicit.

There's a couple of concerns I could see with allowing non-void return value to be implicitly ignored.

  • Prevent accidentally unused values. If cls.add1(a,b) does an in-place update of a, but cls.add2(a,b) returns a new value, using cls.add2(a,b) without assigning to a value would likely be an error.
  • Round-trip TVMScript -> Relax -> TVMScript without a pre-processing pass. Checking if a value has void type can done while printing the IR. Checking whether a non-void variable could be omitted would require a pre-processing step to find any downstream users.

I don't think either of those are definitive arguments, but I figured I'd handle the unambiguous beneficial cases first, with a follow-up PR to relax the restriction.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Those are good points. For the first case, we could have a warning (as C and other languages do with the right settings) for ignoring a return value. The second one is an interesting issue. I think it suggests that expecting an exact textual match for the parser roundtripping is too strict of a criterion for this situation, since a "statement" can always be written as _ = ... and it would be a choice as to whether to write it that way or use the friendlier syntax. It would make it harder to write automatic tests, true. For a systematic solution, maybe we could formalize the idea of a desugaring step for testing purposes?


elif value is not None:
self.report_error(node, f"Unsupported Expr stmt type {value}.")


Expand Down
18 changes: 18 additions & 0 deletions src/script/printer/relax/binding.cc
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,24 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
Doc ret = d->AsDoc(n->value, n_p->Attr("value"));
d->cfg->binding_names.pop_back();
return ret;

// Uncommenting this section hides the variable binding
// when the StructInfo is void. For example, printing
// `R.assert_op(expr)` instead of `_ = R.assert_op(expr)`.
// However, Relax represents void values as an empty
// tuple, and a void-type variable may still be used later
// in the function. Hiding bindings of these void-type
// variables would result in use of an undefined variable.
//
// TODO(Lunderberg): Inline void-type variable to use
// `R.tuple()` during normalization. This will avoid the
// cases that trigger the undefined variables, and allow
// this syntax sugar to be enabled.
//
// } else if (d->cfg->syntax_sugar && relax::HasVoidStructInfo(n->value) &&
// relax::HasVoidStructInfo(n->var)) {
// ExprDoc rhs = d->AsDoc<ExprDoc>(n->value, n_p->Attr("value"));
// return ExprStmtDoc(rhs);
} else {
ExprDoc rhs = d->AsDoc<ExprDoc>(n->value, n_p->Attr("value"));
Optional<ExprDoc> ann = StructInfoAsAnn(n->var, n_p->Attr("var"), d, n->value);
Expand Down
71 changes: 71 additions & 0 deletions tests/python/relax/test_tvmscript_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -1824,6 +1824,77 @@ def main(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"):
_check(Mixture)


def test_function_with_void_return_type_may_be_used_as_statements():
"""Void return of calls do not need to be assigned"""

@I.ir_module
class Unsugared:
@R.function(pure=False)
def print(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"):
y = R.print(x, format="x: {}")
return x

@R.function(pure=False)
def assert_func(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"):
y = R.assert_op(R.const(False, dtype="bool"), x, format="x: {}")
return x

@I.ir_module
class Sugared:
@R.function(pure=False)
def print(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"):
R.print(x, format="x: {}")
return x

@R.function(pure=False)
def assert_func(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"):
R.assert_op(R.const(False, dtype="bool"), x, format="x: {}")
return x

tvm.ir.assert_structural_equal(Unsugared, Sugared)


def test_function_with_non_void_return_type_must_be_assigned():
"""Non-void results must be assigned to a variable"""

with pytest.raises(tvm.error.DiagnosticError):

@R.function(pure=False)
def func(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"):
R.add(x, x)
return x


def test_function_with_void_return_type_in_if_else():
"""Last statement in if/else may be a void return"""

@I.ir_module
class Unsugared:
@R.function(pure=False)
def conditional(
x: R.Tensor((), "int32"), condition: R.Tensor((), "bool")
) -> R.Tensor((), "int32"):
if condition:
y = R.print(x, format="True condition: {}")
else:
y = R.print(x, format="False condition: {}")
return x

@I.ir_module
class Sugared:
@R.function(pure=False)
def conditional(
x: R.Tensor((), "int32"), condition: R.Tensor((), "bool")
) -> R.Tensor((), "int32"):
if condition:
R.print(x, format="True condition: {}")
else:
R.print(x, format="False condition: {}")
return x

_check(Sugared, Unsugared)


def test_call_pure_packed():
@R.function
def foo(x: R.Tensor((32, 32), "float32")) -> R.Tensor:
Expand Down
14 changes: 10 additions & 4 deletions tests/python/relax/test_tvmscript_printer_relax.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@
# specific language governing permissions and limitations
# under the License.
# pylint: disable=missing-docstring

import pytest

import tvm
import tvm.testing
from tvm import IRModule, relax, tir
Expand Down Expand Up @@ -633,6 +636,7 @@ def foo(x: R.Tensor((128,), dtype="float32")) -> R.Tensor((128,), dtype="float32
)


@pytest.mark.xfail(reason="Eliding void variable bindings currently disabled")
def test_assert_op():
@I.ir_module
class AssertOpMod:
Expand All @@ -651,12 +655,13 @@ def main(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"):
class Module:
@R.function(pure=False)
def main(x: R.Tensor((), dtype="int32")) -> R.Tensor((), dtype="int32"):
y: R.Tuple = R.assert_op(R.const(False, "bool"), x, format=R.str("x: {}"))
R.assert_op(R.const(False, "bool"), x, format=R.str("x: {}"))
return x
""",
)


@pytest.mark.xfail(reason="Eliding void variable bindings currently disabled")
def test_print():
@I.ir_module
class PrintMod:
Expand All @@ -675,7 +680,7 @@ def main(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"):
class Module:
@R.function(pure=False)
def main(x: R.Tensor((), dtype="int32")) -> R.Tensor((), dtype="int32"):
y: R.Tuple = R.print(x, format=R.str("x: {}"))
R.print(x, format=R.str("x: {}"))
return x
""",
)
Expand Down Expand Up @@ -705,6 +710,7 @@ def main(x: R.Tensor((), dtype="int32")) -> R.Tensor((), dtype="int32"):
)


@pytest.mark.xfail(reason="Eliding void variable bindings currently disabled")
def test_directly_construct_private_funcs():
# public
@R.function
Expand Down Expand Up @@ -758,7 +764,7 @@ def bar(x: R.Tensor((), dtype="int32")) -> R.Tensor((), dtype="int32"):
@R.function
def baz(x: R.Tensor((), dtype="int32")) -> R.Tensor((), dtype="int32"):
R.func_attr({"relax.force_pure": 1})
y: R.Tuple = R.print(format=R.str("Hi there!"))
R.print(format=R.str("Hi there!"))
z: R.Tensor((), dtype="int32") = R.add(x, x)
return z

Expand All @@ -770,7 +776,7 @@ def foo(x: R.Tensor((), dtype="int32")) -> R.Tensor((), dtype="int32"):
@R.function(private=True)
def quux(x: R.Tensor((), dtype="int32")) -> R.Tensor((), dtype="int32"):
R.func_attr({"relax.force_pure": 1})
y: R.Tuple = R.print(format=R.str("Lol"))
R.print(format=R.str("Lol"))
z: R.Tensor((), dtype="int32") = R.multiply(x, x)
return z
""",
Expand Down