From 19b38f3bb8f7d6d946c18c7d82f8221ca612bed8 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Fri, 5 May 2023 11:38:38 -0500 Subject: [PATCH 1/2] [TIR] Output DeclBuffer in LowerTVMBuiltin For the `stack_shape` and `stack_tcode` buffers, generate a `DeclBuffer`. This is a subset of the changes made in https://github.com/apache/tvm/pull/14778, broken out for ease of testing and review. --- src/tir/transforms/lower_tvm_builtin.cc | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/tir/transforms/lower_tvm_builtin.cc b/src/tir/transforms/lower_tvm_builtin.cc index 837a3e6d3587..df7a88598532 100644 --- a/src/tir/transforms/lower_tvm_builtin.cc +++ b/src/tir/transforms/lower_tvm_builtin.cc @@ -145,6 +145,7 @@ class BuiltinLower : public StmtExprMutator { if (scope.max_sizes.shape_stack != -1) { scope.stack_shape = decl_buffer({IntImm(DataType::Int(64), scope.max_sizes.shape_stack)}, DataType::Int(64), "stack_shape"); + stmt = DeclBuffer(scope.stack_shape, stmt); stmt = LetStmt(scope.stack_shape->data, StackAlloca("shape", scope.max_sizes.shape_stack), stmt); } @@ -159,6 +160,7 @@ class BuiltinLower : public StmtExprMutator { stmt = LetStmt(scope.stack_value, StackAlloca("arg_value", scope.max_sizes.arg_stack), stmt); + stmt = DeclBuffer(scope.stack_tcode, stmt); stmt = LetStmt(scope.stack_tcode->data, StackAlloca("arg_tcode", scope.max_sizes.arg_stack), stmt); } From c020de2aa7d33ca7d782da1597b20325438e21c7 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Tue, 9 May 2023 09:48:55 -0500 Subject: [PATCH 2/2] Updated LowerTVMBuiltin tests for DeclBuffer --- .../python/unittest/test_tir_transform_lower_tvm_builtin.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/python/unittest/test_tir_transform_lower_tvm_builtin.py b/tests/python/unittest/test_tir_transform_lower_tvm_builtin.py index 6eac5e90b553..cf2e3f045b63 100644 --- a/tests/python/unittest/test_tir_transform_lower_tvm_builtin.py +++ b/tests/python/unittest/test_tir_transform_lower_tvm_builtin.py @@ -71,7 +71,9 @@ def check_packed_func(target="llvm"): # Recursively visit PrimFunc until we meet the for-loop: while True: - if isinstance(node, (tvm.tir.AssertStmt, tvm.tir.LetStmt, tvm.tir.AttrStmt)): + if isinstance( + node, (tvm.tir.AssertStmt, tvm.tir.LetStmt, tvm.tir.AttrStmt, tvm.tir.DeclBuffer) + ): node = node.body elif isinstance(node, tvm.tir.SeqStmt): node = node[0] @@ -98,7 +100,7 @@ def check_packed_func(target="llvm"): # # let stack_value = tir.tvm_stack_alloca("arg_value", 4) # - alloca_value = alloca_tcode.body + alloca_value = alloca_tcode.body.body assert isinstance(alloca_value, tvm.tir.LetStmt) expected_value = tvm.tir.call_intrin(