Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 32 additions & 22 deletions src/tir/transforms/lower_tvm_builtin.cc
Original file line number Diff line number Diff line change
Expand Up @@ -247,24 +247,36 @@ class BuiltinLower : public StmtExprMutator {
ICHECK(device_id_) << "Unknown device id in current IR";
Stmt throw_last_error = Evaluate(Call(DataType::Int(32), builtin::tvm_throw_last_error(), {}));

Stmt body = SeqStmt({IfThenElse(Call(DataType::Bool(1), builtin::isnullptr(), {op->buffer_var}),
throw_last_error),
op->body});
Stmt alloca = LetStmt(op->buffer_var,
Call(op->buffer_var.dtype(), Op::Get("tir.TVMBackendAllocWorkspace"),
{cast(DataType::Int(32), device_type_.value()),
cast(DataType::Int(32), device_id_.value()), total_bytes,
IntImm(DataType::Int(32), op->dtype.code()),
IntImm(DataType::Int(32), op->dtype.bits())}),
body);

Stmt alloc_nullptr_check = IfThenElse(
Call(DataType::Bool(1), builtin::isnullptr(), {op->buffer_var}), throw_last_error);
PrimExpr free_op = Call(DataType::Int(32), Op::Get("tir.TVMBackendFreeWorkspace"),
{cast(DataType::Int(32), device_type_.value()),
cast(DataType::Int(32), device_id_.value()), op->buffer_var});
Stmt free_stmt = IfThenElse(free_op != make_zero(DataType::Int(32)), throw_last_error);
body = SeqStmt({alloca, free_stmt});

Stmt body = op->body;
std::vector<Stmt> nest;
while (auto opt = body.as<DeclBuffer>()) {
auto decl = opt.value();
body = decl->body;
decl.CopyOnWrite()->body = Evaluate(0);
nest.push_back(decl);
}

body = SeqStmt::Flatten(body, free_stmt);
body = MergeNest(nest, body);
body = SeqStmt::Flatten(alloc_nullptr_check, body);

body = AttrStmt(op->buffer_var, attr::storage_alignment,
make_const(DataType::Int(32), runtime::kTempAllocaAlignment), body);
body = LetStmt(op->buffer_var,
Call(op->buffer_var.dtype(), Op::Get("tir.TVMBackendAllocWorkspace"),
{cast(DataType::Int(32), device_type_.value()),
cast(DataType::Int(32), device_id_.value()), total_bytes,
IntImm(DataType::Int(32), op->dtype.code()),
IntImm(DataType::Int(32), op->dtype.bits())}),
body);

return body;
}

Expand Down Expand Up @@ -567,9 +579,15 @@ class BuiltinLower : public StmtExprMutator {
ICHECK(device_id_) << "Unknown device id in current IR";
Stmt throw_last_error = Evaluate(Call(DataType::Int(32), builtin::tvm_throw_last_error(), {}));

PrimExpr storage_scope = call->args[0];
Call free_op = Call(DataType::Int(32), builtin::tvm_call_packed(),
{GetDeviceMethodName("free_nd"), device_type_.value(), device_id_.value(),
storage_scope, let->var});
Stmt free_stmt = IfThenElse(free_op != make_zero(DataType::Int(32)), throw_last_error);

Stmt body = SeqStmt(
{IfThenElse(Call(DataType::Bool(1), builtin::isnullptr(), {let->var}), throw_last_error),
let->body});
let->body, free_stmt});

DataType dtype =
let->var->type_annotation.as<PointerTypeNode>()->element_type.as<PrimTypeNode>()->dtype;
Expand All @@ -591,15 +609,7 @@ class BuiltinLower : public StmtExprMutator {

Call call_packed = Call(let->var.dtype(), builtin::tvm_call_packed(), args);
Stmt alloca = LetStmt(let->var, call_packed, body);

PrimExpr storage_scope = call->args[0];
Call free_op = Call(DataType::Int(32), builtin::tvm_call_packed(),
{GetDeviceMethodName("free_nd"), device_type_.value(), device_id_.value(),
storage_scope, let->var});

Stmt free_stmt = IfThenElse(free_op != make_zero(DataType::Int(32)), throw_last_error);
body = SeqStmt({alloca, free_stmt});
return body;
return alloca;
}

private:
Expand Down
23 changes: 9 additions & 14 deletions tests/python/unittest/test_tir_transform_lower_tvm_builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,13 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import tvm
import tvm.testing

from tvm import te
from tvm.script import tir as T

import numpy as np


Expand Down Expand Up @@ -200,14 +204,6 @@ class TestLowerDeviceAllocate(tvm.testing.CompareBeforeAfter):
This test validates the current behavior of LowerTVMBuiltin. This
unit test may be improved in the future by addressing:

- The AttrStmt for "storage_alignment" occurs outside the LetStmt
that defines the pointer, which is currently required by
CodeGenLLVM. This fails to match when `map_free_vars=False`
(default), because the first occurrence is undefined.

- The call to TVMBackendFreeWorkspace uses the allocated pointer,
but occurs outside the LetStmt.

- TVMScript always produces "handle" dtype for
`T.tvm_throw_last_error`, while LowerTVMBuiltin outputs "int32"
dtype.
Expand All @@ -225,13 +221,12 @@ def before():

def expected():
T.func_attr({"target": T.target("llvm")})
ptr = T.handle("float32", "global")
ptr: T.handle("float32") = T.TVMBackendAllocWorkspace(2, 0, T.uint64(64), 2, 32)
T.attr(ptr, "storage_alignment", 64)
with T.LetStmt(T.TVMBackendAllocWorkspace(2, 0, T.uint64(64), 2, 32), var=ptr):
if T.isnullptr(ptr):
T.Call("int32", "tir.tvm_throw_last_error", [])
buf = T.decl_buffer((16,), data=ptr)
buf[0] = T.float32(0)
if T.isnullptr(ptr):
T.Call("int32", "tir.tvm_throw_last_error", [])
buf = T.decl_buffer((16,), data=ptr)
buf[0] = T.float32(0)
if T.TVMBackendFreeWorkspace(2, 0, ptr) != 0:
T.Call("int32", "tir.tvm_throw_last_error", [])

Expand Down