From 5502502930cffc71ae742a1fb67b9e974c30a046 Mon Sep 17 00:00:00 2001 From: Krzysztof Parzyszek Date: Thu, 3 Aug 2023 11:55:48 -0700 Subject: [PATCH 1/2] [Script] Be more careful when generating ast.ExtSlice for Subscript The ast.ExtSlice expects a non-empty list, otherwise evaluation fails with "error: empty dims on ExtSlice". Also, each element in "dims" list of ExtSlice must be either Slice or Index. In python3.8 an expression A[()] is parsed (by ast) as Subscript with slice being Index(value=Tuple(elts=[])). When we translate a subscript from doc.AST to ast, we unconditionally convert every tuple to ast.ExtSlice, which in this case is incorrect. The fix is to map empty tuple back to the Index(Tuple[])) instead of ExtSlice. In other cases, ensure that members of ExtSlice are of correct types. --- python/tvm/script/parser/core/doc.py | 20 ++++++++++++------- .../unittest/test_tvmscript_parser_tir.py | 16 +++++++++++++++ 2 files changed, 29 insertions(+), 7 deletions(-) diff --git a/python/tvm/script/parser/core/doc.py b/python/tvm/script/parser/core/doc.py index 5ea83749eadf..7d1ff67035a3 100644 --- a/python/tvm/script/parser/core/doc.py +++ b/python/tvm/script/parser/core/doc.py @@ -414,13 +414,19 @@ def subscript_from_doc(x: doc.Subscript) -> ast.Subscript: ctx=from_doc(x.ctx), ) elif isinstance(x.slice, doc.Tuple): - result = ast.Subscript( - value=from_doc(x.value), - slice=ast.ExtSlice( - dims=[from_doc(i) for i in x.slice.elts], - ), - ctx=from_doc(x.ctx), - ) + def remap_dim(doc_item: doc.Expr) -> ast.Expr: + ast_item = from_doc(doc_item) + if isinstance(ast_item, (ast.Index, ast.Slice)): + return ast_item + return ast.Index(value=ast_item) + + # ast.ExtSlice requires a non-empty list of dims, and each dim must be either + # a Slice or an Index. + if x.slice.elts: + ast_slice = ast.ExtSlice(dims=[*map(remap_dim, x.slice.elts)]) + else: + ast_slice = ast.Index(value=ast.Tuple(elts=[], ctx=from_doc(x.ctx))) + result = ast.Subscript(value=from_doc(x.value), slice=ast_slice, ctx=from_doc(x.ctx)) else: result = ast.Subscript( value=from_doc(x.value), diff --git a/tests/python/unittest/test_tvmscript_parser_tir.py b/tests/python/unittest/test_tvmscript_parser_tir.py index 210c173141c5..ef02df497b7b 100644 --- a/tests/python/unittest/test_tvmscript_parser_tir.py +++ b/tests/python/unittest/test_tvmscript_parser_tir.py @@ -292,5 +292,21 @@ def non_starred(a: T.handle, b: T.handle): tvm.ir.assert_structural_equal(starred, non_starred) +def test_tir_empty_tuple_index(): + @T.macro + def bar(val): + T.evaluate(val) + + @T.prim_func(private=True) + def func_with_empty_tuple(A: T.Buffer((), "int32"), B: T.Buffer((), "int32")): + bar(val=A[()]) + + @T.prim_func(private=True) + def expected(A: T.Buffer((), "int32"), B: T.Buffer((), "int32")): + T.evaluate(A[()]) + + tvm.ir.assert_structural_equal(func_with_empty_tuple, expected) + + if __name__ == "__main__": tvm.testing.main() From fd6b7ac320006cb4bb57c4c6beced248b18fe970 Mon Sep 17 00:00:00 2001 From: Krzysztof Parzyszek Date: Fri, 4 Aug 2023 06:31:17 -0700 Subject: [PATCH 2/2] Fix lint #1 --- python/tvm/script/parser/core/doc.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/tvm/script/parser/core/doc.py b/python/tvm/script/parser/core/doc.py index 7d1ff67035a3..1c5241dc8d90 100644 --- a/python/tvm/script/parser/core/doc.py +++ b/python/tvm/script/parser/core/doc.py @@ -414,6 +414,7 @@ def subscript_from_doc(x: doc.Subscript) -> ast.Subscript: ctx=from_doc(x.ctx), ) elif isinstance(x.slice, doc.Tuple): + def remap_dim(doc_item: doc.Expr) -> ast.Expr: ast_item = from_doc(doc_item) if isinstance(ast_item, (ast.Index, ast.Slice)):