diff --git a/src/tir/transforms/lower_tvm_builtin.cc b/src/tir/transforms/lower_tvm_builtin.cc index 837a3e6d3587..ef3af6339951 100644 --- a/src/tir/transforms/lower_tvm_builtin.cc +++ b/src/tir/transforms/lower_tvm_builtin.cc @@ -247,24 +247,36 @@ class BuiltinLower : public StmtExprMutator { ICHECK(device_id_) << "Unknown device id in current IR"; Stmt throw_last_error = Evaluate(Call(DataType::Int(32), builtin::tvm_throw_last_error(), {})); - Stmt body = SeqStmt({IfThenElse(Call(DataType::Bool(1), builtin::isnullptr(), {op->buffer_var}), - throw_last_error), - op->body}); - Stmt alloca = LetStmt(op->buffer_var, - Call(op->buffer_var.dtype(), Op::Get("tir.TVMBackendAllocWorkspace"), - {cast(DataType::Int(32), device_type_.value()), - cast(DataType::Int(32), device_id_.value()), total_bytes, - IntImm(DataType::Int(32), op->dtype.code()), - IntImm(DataType::Int(32), op->dtype.bits())}), - body); - + Stmt alloc_nullptr_check = IfThenElse( + Call(DataType::Bool(1), builtin::isnullptr(), {op->buffer_var}), throw_last_error); PrimExpr free_op = Call(DataType::Int(32), Op::Get("tir.TVMBackendFreeWorkspace"), {cast(DataType::Int(32), device_type_.value()), cast(DataType::Int(32), device_id_.value()), op->buffer_var}); Stmt free_stmt = IfThenElse(free_op != make_zero(DataType::Int(32)), throw_last_error); - body = SeqStmt({alloca, free_stmt}); + + Stmt body = op->body; + std::vector nest; + while (auto opt = body.as()) { + auto decl = opt.value(); + body = decl->body; + decl.CopyOnWrite()->body = Evaluate(0); + nest.push_back(decl); + } + + body = SeqStmt::Flatten(body, free_stmt); + body = MergeNest(nest, body); + body = SeqStmt::Flatten(alloc_nullptr_check, body); + body = AttrStmt(op->buffer_var, attr::storage_alignment, make_const(DataType::Int(32), runtime::kTempAllocaAlignment), body); + body = LetStmt(op->buffer_var, + Call(op->buffer_var.dtype(), Op::Get("tir.TVMBackendAllocWorkspace"), + {cast(DataType::Int(32), device_type_.value()), + cast(DataType::Int(32), device_id_.value()), total_bytes, + IntImm(DataType::Int(32), op->dtype.code()), + IntImm(DataType::Int(32), op->dtype.bits())}), + body); + return body; } @@ -567,9 +579,15 @@ class BuiltinLower : public StmtExprMutator { ICHECK(device_id_) << "Unknown device id in current IR"; Stmt throw_last_error = Evaluate(Call(DataType::Int(32), builtin::tvm_throw_last_error(), {})); + PrimExpr storage_scope = call->args[0]; + Call free_op = Call(DataType::Int(32), builtin::tvm_call_packed(), + {GetDeviceMethodName("free_nd"), device_type_.value(), device_id_.value(), + storage_scope, let->var}); + Stmt free_stmt = IfThenElse(free_op != make_zero(DataType::Int(32)), throw_last_error); + Stmt body = SeqStmt( {IfThenElse(Call(DataType::Bool(1), builtin::isnullptr(), {let->var}), throw_last_error), - let->body}); + let->body, free_stmt}); DataType dtype = let->var->type_annotation.as()->element_type.as()->dtype; @@ -591,15 +609,7 @@ class BuiltinLower : public StmtExprMutator { Call call_packed = Call(let->var.dtype(), builtin::tvm_call_packed(), args); Stmt alloca = LetStmt(let->var, call_packed, body); - - PrimExpr storage_scope = call->args[0]; - Call free_op = Call(DataType::Int(32), builtin::tvm_call_packed(), - {GetDeviceMethodName("free_nd"), device_type_.value(), device_id_.value(), - storage_scope, let->var}); - - Stmt free_stmt = IfThenElse(free_op != make_zero(DataType::Int(32)), throw_last_error); - body = SeqStmt({alloca, free_stmt}); - return body; + return alloca; } private: diff --git a/tests/python/unittest/test_tir_transform_lower_tvm_builtin.py b/tests/python/unittest/test_tir_transform_lower_tvm_builtin.py index 6eac5e90b553..ffaac077d384 100644 --- a/tests/python/unittest/test_tir_transform_lower_tvm_builtin.py +++ b/tests/python/unittest/test_tir_transform_lower_tvm_builtin.py @@ -14,9 +14,13 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. + import tvm +import tvm.testing + from tvm import te from tvm.script import tir as T + import numpy as np @@ -200,14 +204,6 @@ class TestLowerDeviceAllocate(tvm.testing.CompareBeforeAfter): This test validates the current behavior of LowerTVMBuiltin. This unit test may be improved in the future by addressing: - - The AttrStmt for "storage_alignment" occurs outside the LetStmt - that defines the pointer, which is currently required by - CodeGenLLVM. This fails to match when `map_free_vars=False` - (default), because the first occurrence is undefined. - - - The call to TVMBackendFreeWorkspace uses the allocated pointer, - but occurs outside the LetStmt. - - TVMScript always produces "handle" dtype for `T.tvm_throw_last_error`, while LowerTVMBuiltin outputs "int32" dtype. @@ -225,13 +221,12 @@ def before(): def expected(): T.func_attr({"target": T.target("llvm")}) - ptr = T.handle("float32", "global") + ptr: T.handle("float32") = T.TVMBackendAllocWorkspace(2, 0, T.uint64(64), 2, 32) T.attr(ptr, "storage_alignment", 64) - with T.LetStmt(T.TVMBackendAllocWorkspace(2, 0, T.uint64(64), 2, 32), var=ptr): - if T.isnullptr(ptr): - T.Call("int32", "tir.tvm_throw_last_error", []) - buf = T.decl_buffer((16,), data=ptr) - buf[0] = T.float32(0) + if T.isnullptr(ptr): + T.Call("int32", "tir.tvm_throw_last_error", []) + buf = T.decl_buffer((16,), data=ptr) + buf[0] = T.float32(0) if T.TVMBackendFreeWorkspace(2, 0, ptr) != 0: T.Call("int32", "tir.tvm_throw_last_error", [])