11 changes: 7 additions & 4 deletions src/relax/transform/static_plan_block_memory.cc
@@ -437,15 +437,18 @@ void SetTIRVarRangeConstraints(Function func, arith::Analyzer* ana,
     auto it_upper = var_upper_bound_attr.find(tir_var->name_hint);
     auto it_lower = var_lower_bound_attr.find(tir_var->name_hint);
 
-    if (it_upper != var_upper_bound_attr.end() || it_lower != var_lower_bound_attr.end()) {
+    // Only bind the variable to a range if an upper bound is explicitly provided.
+    // Without an upper bound, memory planning cannot determine the required storage size,
+    // so we skip binding and let the variable remain unbounded.
+    if (it_upper != var_upper_bound_attr.end()) {
       int64_t lower = (it_lower != var_lower_bound_attr.end()) ? it_lower->second->value : 0;
-      int64_t upper = (it_upper != var_upper_bound_attr.end())
-                          ? it_upper->second->value
-                          : std::numeric_limits<int64_t>::max();
+      int64_t upper = it_upper->second->value;
       tvm::Range range = tvm::Range::FromMinExtent(
           tvm::IntImm(DataType::Int(64), lower), tvm::IntImm(DataType::Int(64), upper - lower + 1));
       ana->Bind(tir_var, range);
       dom_map->Set(tir_var, arith::IntSet::FromRange(range));
+    } else if (it_lower != var_lower_bound_attr.end() && it_lower->second->value >= 0) {
+      ana->MarkGlobalNonNegValue(tir_var);
     } else if (non_negative_var_attr.count(tir_var->name_hint)) {
       ana->MarkGlobalNonNegValue(tir_var);
     }
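To make the new branch concrete: once an upper bound is available, the planner can bind `n` to a Range and prove constant byte-size bounds; with only a lower bound nothing is bound, so no constant bound is provable and storage sizes must stay symbolic. The sketch below is an editor's illustration using the public `tvm.arith.Analyzer` API, not code from this PR:

import tvm
from tvm import tir

ana = tvm.arith.Analyzer()
n = tir.Var("n", "int64")

# Upper bound present: bind n to [2, 4], mirroring
# Range::FromMinExtent(lower, upper - lower + 1) in the C++ change above.
lower, upper = 2, 4
ana.bind(n, tvm.ir.Range.from_min_extent(
    tir.IntImm("int64", lower), tir.IntImm("int64", upper - lower + 1)))
assert ana.can_prove(2 * n + 2 <= 10)  # output sizes gain a static bound

# Lower bound only: the variable stays unbounded, so no constant bound
# on 2 * m + 2 can be proven.
m = tir.Var("m", "int64")
assert not tvm.arith.Analyzer().can_prove(2 * m + 2 <= 10)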
263 changes: 239 additions & 24 deletions tests/python/relax/test_transform_static_plan_block_memory.py
@@ -1018,6 +1018,245 @@ def main(x: R.Tensor((2, "n"), dtype="float32")) -> R.Tensor(("2 * n + 2",), dty
    tvm.ir.assert_structural_equal(mod, Expected)


def test_lower_bound_only():
    # fmt: off
    @tvm.script.ir_module
    class Module:
        @T.prim_func
        def add(rxplaceholder: T.handle, rxplaceholder_1: T.handle, T_add: T.handle):
            T.evaluate(0)

        @T.prim_func
        def reshape(rxplaceholder: T.handle, T_reshape: T.handle):
            T.evaluate(0)

        @T.prim_func
        def relu(rxplaceholder: T.handle, compute: T.handle):
            T.evaluate(0)

        @T.prim_func
        def log(rxplaceholder: T.handle, compute: T.handle):
            T.evaluate(0)

        @T.prim_func
        def exp(rxplaceholder: T.handle, compute: T.handle):
            T.evaluate(0)

        @T.prim_func
        def pad(rxplaceholder: T.handle, PadInput: T.handle):
            T.evaluate(0)

        @R.function
        def main(x: R.Tensor((2, "n"), dtype="float32")) -> R.Tensor(("2 * n + 2",), dtype="float32"):
            R.func_attr({"tir_var_lower_bound": {"n": 2}, "relax.force_pure": True})
            n = T.int64()
            cls = Module
            alloc: R.Tensor((2, n), dtype="float32") = R.builtin.alloc_tensor(R.shape([2, n]), dtype="float32", runtime_device_index=0)
            _: R.Tuple() = cls.exp(x, alloc)
            lv: R.Tensor((2, n), dtype="float32") = alloc
            lv1: R.Tensor((2 * n,), dtype="float32") = R.reshape(lv, (2 * n,))
            alloc1: R.Tensor((2 * n,), dtype="float32") = R.builtin.alloc_tensor(R.shape([2 * n]), dtype="float32", runtime_device_index=0)
            _1: R.Tuple() = cls.relu(lv1, alloc1)
            lv2: R.Tensor((2 * n,), dtype="float32") = alloc1
            alloc2: R.Tensor((2 * n,), dtype="float32") = R.builtin.alloc_tensor(R.shape([2 * n]), dtype="float32", runtime_device_index=0)
            _2: R.Tuple() = cls.add(lv2, R.const(1, "float32"), alloc2)
            lv3: R.Tensor((2 * n,), dtype="float32") = alloc2
            alloc3: R.Tensor((2 * n + 2,), dtype="float32") = R.builtin.alloc_tensor(R.shape([2 * n + 2]), dtype="float32", runtime_device_index=0)
            _3: R.Tuple() = cls.pad(lv3, alloc3)
            lv4: R.Tensor((2 * n + 2,), dtype="float32") = alloc3
            alloc4: R.Tensor((2 * n + 2,), dtype="float32") = R.builtin.alloc_tensor(R.shape([2 * n + 2]), dtype="float32", runtime_device_index=0)
            _4: R.Tuple() = cls.log(lv4, alloc4)
            gv: R.Tensor((2 * n + 2,), dtype="float32") = alloc4
            return gv

    @I.ir_module
    class Expected:
        @T.prim_func
        def add(rxplaceholder: T.handle, rxplaceholder_1: T.handle, T_add: T.handle):
            T.evaluate(0)

        @T.prim_func
        def exp(rxplaceholder: T.handle, compute: T.handle):
            T.evaluate(0)

        @T.prim_func
        def log(rxplaceholder: T.handle, compute: T.handle):
            T.evaluate(0)

        @T.prim_func
        def pad(rxplaceholder: T.handle, PadInput: T.handle):
            T.evaluate(0)

        @T.prim_func
        def relu(rxplaceholder: T.handle, compute: T.handle):
            T.evaluate(0)

        @T.prim_func
        def reshape(rxplaceholder: T.handle, T_reshape: T.handle):
            T.evaluate(0)

        @R.function
        def main(x: R.Tensor((2, "n"), dtype="float32")) -> R.Tensor(("2 * n + 2",), dtype="float32"):
            n = T.int64()
            R.func_attr({"tir_var_lower_bound": {"n": 2}, "relax.force_pure": True})
            cls = Expected
            storage: R.Object = R.memory.alloc_storage(R.shape([8 * n]), R.prim_value(0), R.str("global"), R.dtype("float32"))
            alloc: R.Tensor((2, n), dtype="float32") = R.memory.alloc_tensor(storage, R.prim_value(0), R.shape([2, n]), R.dtype("float32"))
            _: R.Tuple = cls.exp(x, alloc)
            lv: R.Tensor((2, n), dtype="float32") = alloc
            lv1: R.Tensor((2 * n,), dtype="float32") = R.reshape(lv, R.shape([2 * n]))
            storage1: R.Object = R.memory.alloc_storage(R.shape([4 * (2 * n)]), R.prim_value(0), R.str("global"), R.dtype("float32"))
            alloc1: R.Tensor((2 * n,), dtype="float32") = R.memory.alloc_tensor(storage1, R.prim_value(0), R.shape([2 * n]), R.dtype("float32"))
            _1: R.Tuple = cls.relu(lv1, alloc1)
            lv2: R.Tensor((2 * n,), dtype="float32") = alloc1
            alloc2: R.Tensor((2 * n,), dtype="float32") = R.memory.alloc_tensor(storage, R.prim_value(0), R.shape([2 * n]), R.dtype("float32"))
            _2: R.Tuple = cls.add(lv2, R.const(1, "float32"), alloc2)
            lv3: R.Tensor((2 * n,), dtype="float32") = alloc2
            storage2: R.Object = R.memory.alloc_storage(R.shape([4 * (2 * n + 2)]), R.prim_value(0), R.str("global"), R.dtype("float32"))
            alloc3: R.Tensor((2 * n + 2,), dtype="float32") = R.memory.alloc_tensor(storage2, R.prim_value(0), R.shape([2 * n + 2]), R.dtype("float32"))
            _3: R.Tuple = cls.pad(lv3, alloc3)
            lv4: R.Tensor((2 * n + 2,), dtype="float32") = alloc3
            alloc4: R.Tensor((2 * n + 2,), dtype="float32") = R.builtin.alloc_tensor(R.shape([2 * n + 2]), R.dtype("float32"), R.prim_value(0))
            _4: R.Tuple = cls.log(lv4, alloc4)
            gv: R.Tensor((2 * n + 2,), dtype="float32") = alloc4
            return gv
    # fmt: on

    mod = relax.transform.StaticPlanBlockMemory()(Module)
    tvm.ir.assert_structural_equal(mod, Expected)
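# Editor's note (not part of this PR): with only a lower bound on n, every
# planned storage above keeps a symbolic byte size, and a buffer is reused
# only where the new size provably fits the old one. The sketch below is a
# hedged illustration of the two size comparisons at play; the pass itself
# uses its own storage-token bookkeeping, not these literal checks.
def _sketch_lower_bound_reuse():
    ana = tvm.arith.Analyzer()
    n = tvm.tir.Var("n", "int64")
    # add's output (4 * (2 * n) bytes) provably fits exp's buffer
    # (8 * n bytes), so alloc2 above reuses `storage`...
    assert ana.can_prove(4 * (2 * n) <= 8 * n)
    # ...but pad's output (4 * (2 * n + 2) bytes) never fits, so alloc3
    # gets the fresh `storage2`.
    assert not ana.can_prove(4 * (2 * n + 2) <= 8 * n)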


def test_upper_and_lower_bounds():
    # fmt: off
    @tvm.script.ir_module
    class Module:
        @T.prim_func
        def add(rxplaceholder: T.handle, rxplaceholder_1: T.handle, T_add: T.handle):
            T.evaluate(0)

        @T.prim_func
        def reshape(rxplaceholder: T.handle, T_reshape: T.handle):
            T.evaluate(0)

        @T.prim_func
        def relu(rxplaceholder: T.handle, compute: T.handle):
            T.evaluate(0)

        @T.prim_func
        def log(rxplaceholder: T.handle, compute: T.handle):
            T.evaluate(0)

        @T.prim_func
        def exp(rxplaceholder: T.handle, compute: T.handle):
            T.evaluate(0)

        @T.prim_func
        def pad(rxplaceholder: T.handle, PadInput: T.handle):
            T.evaluate(0)

        @R.function
        def main(x: R.Tensor((2, "n"), dtype="float32")) -> R.Tensor(("2 * n + 2",), dtype="float32"):
            R.func_attr({"tir_var_upper_bound": {"n": 4}, "tir_var_lower_bound": {"n": 2}, "relax.force_pure": True})
            n = T.int64()
            cls = Module
            alloc: R.Tensor((2, n), dtype="float32") = R.builtin.alloc_tensor(R.shape([2, n]), dtype="float32", runtime_device_index=0)
            _: R.Tuple() = cls.exp(x, alloc)
            lv: R.Tensor((2, n), dtype="float32") = alloc
            lv1: R.Tensor((2 * n,), dtype="float32") = R.reshape(lv, (2 * n,))
            alloc1: R.Tensor((2 * n,), dtype="float32") = R.builtin.alloc_tensor(R.shape([2 * n]), dtype="float32", runtime_device_index=0)
            _1: R.Tuple() = cls.relu(lv1, alloc1)
            lv2: R.Tensor((2 * n,), dtype="float32") = alloc1
            alloc2: R.Tensor((2 * n,), dtype="float32") = R.builtin.alloc_tensor(R.shape([2 * n]), dtype="float32", runtime_device_index=0)
            _2: R.Tuple() = cls.add(lv2, R.const(1, "float32"), alloc2)
            lv3: R.Tensor((2 * n,), dtype="float32") = alloc2
            alloc3: R.Tensor((2 * n + 2,), dtype="float32") = R.builtin.alloc_tensor(R.shape([2 * n + 2]), dtype="float32", runtime_device_index=0)
            _3: R.Tuple() = cls.pad(lv3, alloc3)
            lv4: R.Tensor((2 * n + 2,), dtype="float32") = alloc3
            alloc4: R.Tensor((2 * n + 2,), dtype="float32") = R.builtin.alloc_tensor(R.shape([2 * n + 2]), dtype="float32", runtime_device_index=0)
            _4: R.Tuple() = cls.log(lv4, alloc4)
            gv: R.Tensor((2 * n + 2,), dtype="float32") = alloc4
            return gv

    @I.ir_module
    class Expected:
        @T.prim_func
        def add(rxplaceholder: T.handle, rxplaceholder_1: T.handle, T_add: T.handle):
            T.evaluate(0)

        @T.prim_func
        def exp(rxplaceholder: T.handle, compute: T.handle):
            T.evaluate(0)

        @T.prim_func
        def log(rxplaceholder: T.handle, compute: T.handle):
            T.evaluate(0)

        @T.prim_func
        def pad(rxplaceholder: T.handle, PadInput: T.handle):
            T.evaluate(0)

        @T.prim_func
        def relu(rxplaceholder: T.handle, compute: T.handle):
            T.evaluate(0)

        @T.prim_func
        def reshape(rxplaceholder: T.handle, T_reshape: T.handle):
            T.evaluate(0)

        @R.function
        def main(x: R.Tensor((2, "n"), dtype="float32")) -> R.Tensor(("2 * n + 2",), dtype="float32"):
            n = T.int64()
            R.func_attr({"tir_var_upper_bound": {"n": 4}, "tir_var_lower_bound": {"n": 2}, "relax.force_pure": True})
            cls = Expected
            storage: R.Object = R.memory.alloc_storage(R.shape([32]), R.prim_value(0), R.str("global"), R.dtype("float32"))
            alloc: R.Tensor((2, n), dtype="float32") = R.memory.alloc_tensor(storage, R.prim_value(0), R.shape([2, n]), R.dtype("float32"))
            _: R.Tuple = cls.exp(x, alloc)
            lv: R.Tensor((2, n), dtype="float32") = alloc
            lv1: R.Tensor((2 * n,), dtype="float32") = R.reshape(lv, R.shape([2 * n]))
            storage1: R.Object = R.memory.alloc_storage(R.shape([40]), R.prim_value(0), R.str("global"), R.dtype("float32"))
            alloc1: R.Tensor((2 * n,), dtype="float32") = R.memory.alloc_tensor(storage1, R.prim_value(0), R.shape([2 * n]), R.dtype("float32"))
            _1: R.Tuple = cls.relu(lv1, alloc1)
            lv2: R.Tensor((2 * n,), dtype="float32") = alloc1
            alloc2: R.Tensor((2 * n,), dtype="float32") = R.memory.alloc_tensor(storage, R.prim_value(0), R.shape([2 * n]), R.dtype("float32"))
            _2: R.Tuple = cls.add(lv2, R.const(1, "float32"), alloc2)
            lv3: R.Tensor((2 * n,), dtype="float32") = alloc2
            alloc3: R.Tensor((2 * n + 2,), dtype="float32") = R.memory.alloc_tensor(storage1, R.prim_value(0), R.shape([2 * n + 2]), R.dtype("float32"))
            _3: R.Tuple = cls.pad(lv3, alloc3)
            lv4: R.Tensor((2 * n + 2,), dtype="float32") = alloc3
            alloc4: R.Tensor((2 * n + 2,), dtype="float32") = R.builtin.alloc_tensor(R.shape([2 * n + 2]), R.dtype("float32"), R.prim_value(0))
            _4: R.Tuple = cls.log(lv4, alloc4)
            gv: R.Tensor((2 * n + 2,), dtype="float32") = alloc4
            return gv
    # fmt: on

    mod = relax.transform.StaticPlanBlockMemory()(Module)
    tvm.ir.assert_structural_equal(mod, Expected)
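# Editor's sanity check (not from the PR): the constant sizes in Expected
# follow directly from the upper bound n <= 4 with 4-byte float32 elements.
def _sketch_static_sizes():
    n_max = 4
    # `storage` holds the (2, n) and (2 * n,) tensors: at most 8 elements.
    assert 2 * n_max * 4 == 32
    # `storage1` holds (2 * n,) and the larger (2 * n + 2,): at most 10 elements.
    assert (2 * n_max + 2) * 4 == 40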


def test_invalid_tir_var_upper_bound():
    @tvm.script.ir_module
    class Module:
        @R.function
        def main(x: R.Tensor((2, "n"), dtype="float32")):
            R.func_attr({"tir_var_upper_bound": {"n": [4]}, "relax.force_pure": True})
            return x

    with pytest.raises((TVMError, TypeError)):
        relax.transform.StaticPlanBlockMemory()(Module)


def test_invalid_tir_var_lower_bound():
    @tvm.script.ir_module
    class Module:
        @R.function
        def main(x: R.Tensor((2, "n"), dtype="float32")):
            R.func_attr({"tir_var_lower_bound": {"n": [4]}, "relax.force_pure": True})
            return x

    with pytest.raises((TVMError, TypeError)):
        relax.transform.StaticPlanBlockMemory()(Module)


def test_tir_var_decreasing_monotone():
    # fmt: off
    @I.ir_module
@@ -1335,30 +1574,6 @@ def func2(x: R.Tensor((10,), dtype="float32")) -> R.Tensor((10,), dtype="float32
    tvm.ir.assert_structural_equal(mod, Expected)


def test_invalid_tir_var_upper_bound():
    @tvm.script.ir_module
    class Module:
        @R.function
        def main(x: R.Tensor((2, "n"), dtype="float32")):
            R.func_attr({"tir_var_upper_bound": {"n": [4]}, "relax.force_pure": True})
            return x

    with pytest.raises((TVMError, TypeError)):
        relax.transform.StaticPlanBlockMemory()(Module)


def test_invalid_tir_var_lower_bound():
    @tvm.script.ir_module
    class Module:
        @R.function
        def main(x: R.Tensor((2, "n"), dtype="float32")):
            R.func_attr({"tir_var_lower_bound": {"n": [4]}, "relax.force_pure": True})
            return x

    with pytest.raises((TVMError, TypeError)):
        relax.transform.StaticPlanBlockMemory()(Module)


def test_add():
    @I.ir_module
    class Module: