From f4a744224da4dea02a45d7f872e3fa2ceedb0b90 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Thu, 6 Jul 2023 08:36:50 -0500 Subject: [PATCH 1/2] [TIR] Call TVMBackendFreeWorkspace inside LetStmt Prior to this commit, the call to `TVMBackendFreeWorkspace` occurred outside the `LetStmt` that defined the workspace pointer. While this works with current codegen, as the code produced for `LetStmt` does not check for out-of-scope access, this access of an out-of-scope variable should be avoided. This commit updates `LowerTVMBuiltin` to produce the call to `TVMBackendFreeWorkspace` at the end of the `LetStmt`'s body, rather than just after the `LetStmt`. --- src/tir/transforms/lower_tvm_builtin.cc | 41 +++++++++---------- .../test_tir_transform_lower_tvm_builtin.py | 15 +++---- 2 files changed, 27 insertions(+), 29 deletions(-) diff --git a/src/tir/transforms/lower_tvm_builtin.cc b/src/tir/transforms/lower_tvm_builtin.cc index 837a3e6d3587..4ceb2b562036 100644 --- a/src/tir/transforms/lower_tvm_builtin.cc +++ b/src/tir/transforms/lower_tvm_builtin.cc @@ -247,22 +247,21 @@ class BuiltinLower : public StmtExprMutator { ICHECK(device_id_) << "Unknown device id in current IR"; Stmt throw_last_error = Evaluate(Call(DataType::Int(32), builtin::tvm_throw_last_error(), {})); - Stmt body = SeqStmt({IfThenElse(Call(DataType::Bool(1), builtin::isnullptr(), {op->buffer_var}), - throw_last_error), - op->body}); - Stmt alloca = LetStmt(op->buffer_var, - Call(op->buffer_var.dtype(), Op::Get("tir.TVMBackendAllocWorkspace"), - {cast(DataType::Int(32), device_type_.value()), - cast(DataType::Int(32), device_id_.value()), total_bytes, - IntImm(DataType::Int(32), op->dtype.code()), - IntImm(DataType::Int(32), op->dtype.bits())}), - body); - + Stmt alloc_nullptr_check = IfThenElse( + Call(DataType::Bool(1), builtin::isnullptr(), {op->buffer_var}), throw_last_error); PrimExpr free_op = Call(DataType::Int(32), Op::Get("tir.TVMBackendFreeWorkspace"), {cast(DataType::Int(32), device_type_.value()), 
cast(DataType::Int(32), device_id_.value()), op->buffer_var}); Stmt free_stmt = IfThenElse(free_op != make_zero(DataType::Int(32)), throw_last_error); - body = SeqStmt({alloca, free_stmt}); + + Stmt body = SeqStmt({alloc_nullptr_check, op->body, free_stmt}); + body = LetStmt(op->buffer_var, + Call(op->buffer_var.dtype(), Op::Get("tir.TVMBackendAllocWorkspace"), + {cast(DataType::Int(32), device_type_.value()), + cast(DataType::Int(32), device_id_.value()), total_bytes, + IntImm(DataType::Int(32), op->dtype.code()), + IntImm(DataType::Int(32), op->dtype.bits())}), + body); body = AttrStmt(op->buffer_var, attr::storage_alignment, make_const(DataType::Int(32), runtime::kTempAllocaAlignment), body); return body; @@ -567,9 +566,15 @@ class BuiltinLower : public StmtExprMutator { ICHECK(device_id_) << "Unknown device id in current IR"; Stmt throw_last_error = Evaluate(Call(DataType::Int(32), builtin::tvm_throw_last_error(), {})); + PrimExpr storage_scope = call->args[0]; + Call free_op = Call(DataType::Int(32), builtin::tvm_call_packed(), + {GetDeviceMethodName("free_nd"), device_type_.value(), device_id_.value(), + storage_scope, let->var}); + Stmt free_stmt = IfThenElse(free_op != make_zero(DataType::Int(32)), throw_last_error); + Stmt body = SeqStmt( {IfThenElse(Call(DataType::Bool(1), builtin::isnullptr(), {let->var}), throw_last_error), - let->body}); + let->body, free_stmt}); DataType dtype = let->var->type_annotation.as()->element_type.as()->dtype; @@ -591,15 +596,7 @@ class BuiltinLower : public StmtExprMutator { Call call_packed = Call(let->var.dtype(), builtin::tvm_call_packed(), args); Stmt alloca = LetStmt(let->var, call_packed, body); - - PrimExpr storage_scope = call->args[0]; - Call free_op = Call(DataType::Int(32), builtin::tvm_call_packed(), - {GetDeviceMethodName("free_nd"), device_type_.value(), device_id_.value(), - storage_scope, let->var}); - - Stmt free_stmt = IfThenElse(free_op != make_zero(DataType::Int(32)), throw_last_error); - body = 
SeqStmt({alloca, free_stmt}); - return body; + return alloca; } private: diff --git a/tests/python/unittest/test_tir_transform_lower_tvm_builtin.py b/tests/python/unittest/test_tir_transform_lower_tvm_builtin.py index 6eac5e90b553..c48a4442535a 100644 --- a/tests/python/unittest/test_tir_transform_lower_tvm_builtin.py +++ b/tests/python/unittest/test_tir_transform_lower_tvm_builtin.py @@ -14,9 +14,13 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. + import tvm +import tvm.testing + from tvm import te from tvm.script import tir as T + import numpy as np @@ -205,9 +209,6 @@ class TestLowerDeviceAllocate(tvm.testing.CompareBeforeAfter): CodeGenLLVM. This fails to match when `map_free_vars=False` (default), because the first occurrence is undefined. - - The call to TVMBackendFreeWorkspace uses the allocated pointer, - but occurs outside the LetStmt. - - TVMScript always produces "handle" dtype for `T.tvm_throw_last_error`, while LowerTVMBuiltin outputs "int32" dtype. 
@@ -230,10 +231,10 @@ def expected(): with T.LetStmt(T.TVMBackendAllocWorkspace(2, 0, T.uint64(64), 2, 32), var=ptr): if T.isnullptr(ptr): T.Call("int32", "tir.tvm_throw_last_error", []) - buf = T.decl_buffer((16,), data=ptr) - buf[0] = T.float32(0) - if T.TVMBackendFreeWorkspace(2, 0, ptr) != 0: - T.Call("int32", "tir.tvm_throw_last_error", []) + with T.decl_buffer((16,), data=ptr) as buf: + buf[0] = T.float32(0) + if T.TVMBackendFreeWorkspace(2, 0, ptr) != 0: + T.Call("int32", "tir.tvm_throw_last_error", []) def test_compare(self, before, expected, transform): after = transform(before) From 1a0fe67e86c031885326a6e0d9f5f69513b6dc9e Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Thu, 6 Jul 2023 09:10:37 -0500 Subject: [PATCH 2/2] [TIR] Output AttrStmt "storage_alignment" inside the var binding Prior to this commit, the `AttrStmt` providing the storage alignment was placed outside the `LetStmt` that defines the variable. As a result, the alignment assumption is never actually used, as `CodeGenLLVM::VisitStmt_(const AttrStmtNode*)` only creates an alignment assumption for in-scope variables. This commit moves the storage alignment `AttrStmt` to be inside the `LetStmt`, rather than outside. 
--- src/tir/transforms/lower_tvm_builtin.cc | 19 +++++++++++++++--- .../test_tir_transform_lower_tvm_builtin.py | 20 +++++++------------ 2 files changed, 23 insertions(+), 16 deletions(-) diff --git a/src/tir/transforms/lower_tvm_builtin.cc b/src/tir/transforms/lower_tvm_builtin.cc index 4ceb2b562036..ef3af6339951 100644 --- a/src/tir/transforms/lower_tvm_builtin.cc +++ b/src/tir/transforms/lower_tvm_builtin.cc @@ -254,7 +254,21 @@ class BuiltinLower : public StmtExprMutator { cast(DataType::Int(32), device_id_.value()), op->buffer_var}); Stmt free_stmt = IfThenElse(free_op != make_zero(DataType::Int(32)), throw_last_error); - Stmt body = SeqStmt({alloc_nullptr_check, op->body, free_stmt}); + Stmt body = op->body; + std::vector nest; + while (auto opt = body.as()) { + auto decl = opt.value(); + body = decl->body; + decl.CopyOnWrite()->body = Evaluate(0); + nest.push_back(decl); + } + + body = SeqStmt::Flatten(body, free_stmt); + body = MergeNest(nest, body); + body = SeqStmt::Flatten(alloc_nullptr_check, body); + + body = AttrStmt(op->buffer_var, attr::storage_alignment, + make_const(DataType::Int(32), runtime::kTempAllocaAlignment), body); body = LetStmt(op->buffer_var, Call(op->buffer_var.dtype(), Op::Get("tir.TVMBackendAllocWorkspace"), {cast(DataType::Int(32), device_type_.value()), @@ -262,8 +276,7 @@ class BuiltinLower : public StmtExprMutator { IntImm(DataType::Int(32), op->dtype.code()), IntImm(DataType::Int(32), op->dtype.bits())}), body); - body = AttrStmt(op->buffer_var, attr::storage_alignment, - make_const(DataType::Int(32), runtime::kTempAllocaAlignment), body); + return body; } diff --git a/tests/python/unittest/test_tir_transform_lower_tvm_builtin.py b/tests/python/unittest/test_tir_transform_lower_tvm_builtin.py index c48a4442535a..ffaac077d384 100644 --- a/tests/python/unittest/test_tir_transform_lower_tvm_builtin.py +++ b/tests/python/unittest/test_tir_transform_lower_tvm_builtin.py @@ -204,11 +204,6 @@ class 
TestLowerDeviceAllocate(tvm.testing.CompareBeforeAfter): This test validates the current behavior of LowerTVMBuiltin. This unit test may be improved in the future by addressing: - - The AttrStmt for "storage_alignment" occurs outside the LetStmt - that defines the pointer, which is currently required by - CodeGenLLVM. This fails to match when `map_free_vars=False` - (default), because the first occurrence is undefined. - - TVMScript always produces "handle" dtype for `T.tvm_throw_last_error`, while LowerTVMBuiltin outputs "int32" dtype. @@ -226,15 +221,14 @@ def before(): def expected(): T.func_attr({"target": T.target("llvm")}) - ptr = T.handle("float32", "global") + ptr: T.handle("float32") = T.TVMBackendAllocWorkspace(2, 0, T.uint64(64), 2, 32) T.attr(ptr, "storage_alignment", 64) - with T.LetStmt(T.TVMBackendAllocWorkspace(2, 0, T.uint64(64), 2, 32), var=ptr): - if T.isnullptr(ptr): - T.Call("int32", "tir.tvm_throw_last_error", []) - with T.decl_buffer((16,), data=ptr) as buf: - buf[0] = T.float32(0) - if T.TVMBackendFreeWorkspace(2, 0, ptr) != 0: - T.Call("int32", "tir.tvm_throw_last_error", []) + if T.isnullptr(ptr): + T.Call("int32", "tir.tvm_throw_last_error", []) + buf = T.decl_buffer((16,), data=ptr) + buf[0] = T.float32(0) + if T.TVMBackendFreeWorkspace(2, 0, ptr) != 0: + T.Call("int32", "tir.tvm_throw_last_error", []) def test_compare(self, before, expected, transform): after = transform(before)