diff --git a/src/relax/transform/static_plan_block_memory.cc b/src/relax/transform/static_plan_block_memory.cc
index cebbaa4ce5ac..2a41687983a1 100644
--- a/src/relax/transform/static_plan_block_memory.cc
+++ b/src/relax/transform/static_plan_block_memory.cc
@@ -437,15 +437,18 @@ void SetTIRVarRangeConstraints(Function func, arith::Analyzer* ana,
     auto it_upper = var_upper_bound_attr.find(tir_var->name_hint);
     auto it_lower = var_lower_bound_attr.find(tir_var->name_hint);
-    if (it_upper != var_upper_bound_attr.end() || it_lower != var_lower_bound_attr.end()) {
+    // Only bind the variable to a range if an upper bound is explicitly provided.
+    // Without an upper bound, memory planning cannot determine the required storage size,
+    // so we skip binding and let the variable remain unbounded.
+    if (it_upper != var_upper_bound_attr.end()) {
       int64_t lower = (it_lower != var_lower_bound_attr.end()) ? it_lower->second->value : 0;
-      int64_t upper = (it_upper != var_upper_bound_attr.end())
-                          ? it_upper->second->value
-                          : std::numeric_limits<int64_t>::max();
+      int64_t upper = it_upper->second->value;
       tvm::Range range = tvm::Range::FromMinExtent(
           tvm::IntImm(DataType::Int(64), lower),
           tvm::IntImm(DataType::Int(64), upper - lower + 1));
       ana->Bind(tir_var, range);
       dom_map->Set(tir_var, arith::IntSet::FromRange(range));
+    } else if (it_lower != var_lower_bound_attr.end() && it_lower->second->value >= 0) {
+      ana->MarkGlobalNonNegValue(tir_var);
     } else if (non_negative_var_attr.count(tir_var->name_hint)) {
       ana->MarkGlobalNonNegValue(tir_var);
     }
diff --git a/tests/python/relax/test_transform_static_plan_block_memory.py b/tests/python/relax/test_transform_static_plan_block_memory.py
index 06e4ea142e95..87c6f12f53d9 100644
--- a/tests/python/relax/test_transform_static_plan_block_memory.py
+++ b/tests/python/relax/test_transform_static_plan_block_memory.py
@@ -1018,6 +1018,245 @@ def main(x: R.Tensor((2, "n"), dtype="float32")) -> R.Tensor(("2 * n + 2",), dty
     tvm.ir.assert_structural_equal(mod, Expected)
 
 
+def test_lower_bound_only():
+    # fmt: off
+    @tvm.script.ir_module
+    class Module:
+        @T.prim_func
+        def add(rxplaceholder: T.handle, rxplaceholder_1: T.handle, T_add: T.handle):
+            T.evaluate(0)
+
+        @T.prim_func
+        def reshape(rxplaceholder: T.handle, T_reshape: T.handle):
+            T.evaluate(0)
+
+        @T.prim_func
+        def relu(rxplaceholder: T.handle, compute: T.handle):
+            T.evaluate(0)
+
+        @T.prim_func
+        def log(rxplaceholder: T.handle, compute: T.handle):
+            T.evaluate(0)
+
+        @T.prim_func
+        def exp(rxplaceholder: T.handle, compute: T.handle):
+            T.evaluate(0)
+
+        @T.prim_func
+        def pad(rxplaceholder: T.handle, PadInput: T.handle):
+            T.evaluate(0)
+
+        @R.function
+        def main(x: R.Tensor((2, "n"), dtype="float32")) -> R.Tensor(("2 * n + 2",), dtype="float32"):
+            R.func_attr({"tir_var_lower_bound": {"n": 2}, "relax.force_pure": True})
+            n = T.int64()
+            cls = Module
+            alloc: R.Tensor((2, n), dtype="float32") = R.builtin.alloc_tensor(R.shape([2, n]), dtype="float32", runtime_device_index=0)
+            _: R.Tuple() = cls.exp(x, alloc)
+            lv: R.Tensor((2, n), dtype="float32") = alloc
+            lv1: R.Tensor((2 * n,), dtype="float32") = R.reshape(lv, (2 * n,))
+            alloc1: R.Tensor((2 * n,), dtype="float32") = R.builtin.alloc_tensor(R.shape([2 * n]), dtype="float32", runtime_device_index=0)
+            _1: R.Tuple() = cls.relu(lv1, alloc1)
+            lv2: R.Tensor((2 * n,), dtype="float32") = alloc1
+            alloc2: R.Tensor((2 * n,), dtype="float32") = R.builtin.alloc_tensor(R.shape([2 * n]), dtype="float32", runtime_device_index=0)
+            _2: R.Tuple() = cls.add(lv2, R.const(1, "float32"), alloc2)
+            lv3: R.Tensor((2 * n,), dtype="float32") = alloc2
+            alloc3: R.Tensor((2 * n + 2,), dtype="float32") = R.builtin.alloc_tensor(R.shape([2 * n + 2]), dtype="float32", runtime_device_index=0)
+            _3: R.Tuple() = cls.pad(lv3, alloc3)
+            lv4: R.Tensor((2 * n + 2,), dtype="float32") = alloc3
+            alloc4: R.Tensor((2 * n + 2,), dtype="float32") = R.builtin.alloc_tensor(R.shape([10]), dtype="float32", runtime_device_index=0)
+            _4: R.Tuple() = cls.log(lv4, alloc4)
+            gv: R.Tensor((2 * n + 2,), dtype="float32") = alloc4
+            return gv
+
+    @I.ir_module
+    class Expected:
+        @T.prim_func
+        def add(rxplaceholder: T.handle, rxplaceholder_1: T.handle, T_add: T.handle):
+            T.evaluate(0)
+
+        @T.prim_func
+        def exp(rxplaceholder: T.handle, compute: T.handle):
+            T.evaluate(0)
+
+        @T.prim_func
+        def log(rxplaceholder: T.handle, compute: T.handle):
+            T.evaluate(0)
+
+        @T.prim_func
+        def pad(rxplaceholder: T.handle, PadInput: T.handle):
+            T.evaluate(0)
+
+        @T.prim_func
+        def relu(rxplaceholder: T.handle, compute: T.handle):
+            T.evaluate(0)
+
+        @T.prim_func
+        def reshape(rxplaceholder: T.handle, T_reshape: T.handle):
+            T.evaluate(0)
+
+        @R.function
+        def main(x: R.Tensor((2, "n"), dtype="float32")) -> R.Tensor(("2 * n + 2",), dtype="float32"):
+            n = T.int64()
+            R.func_attr({"tir_var_lower_bound": {"n": 2}, "relax.force_pure": True})
+            cls = Expected
+            storage: R.Object = R.memory.alloc_storage(R.shape([8 * n]), R.prim_value(0), R.str("global"), R.dtype("float32"))
+            alloc: R.Tensor((2, n), dtype="float32") = R.memory.alloc_tensor(storage, R.prim_value(0), R.shape([2, n]), R.dtype("float32"))
+            _: R.Tuple = cls.exp(x, alloc)
+            lv: R.Tensor((2, n), dtype="float32") = alloc
+            lv1: R.Tensor((2 * n,), dtype="float32") = R.reshape(lv, R.shape([2 * n]))
+            storage1: R.Object = R.memory.alloc_storage(R.shape([4 * (2 * n)]), R.prim_value(0), R.str("global"), R.dtype("float32"))
+            alloc1: R.Tensor((2 * n,), dtype="float32") = R.memory.alloc_tensor(storage1, R.prim_value(0), R.shape([2 * n]), R.dtype("float32"))
+            _1: R.Tuple = cls.relu(lv1, alloc1)
+            lv2: R.Tensor((2 * n,), dtype="float32") = alloc1
+            alloc2: R.Tensor((2 * n,), dtype="float32") = R.memory.alloc_tensor(storage, R.prim_value(0), R.shape([2 * n]), R.dtype("float32"))
+            _2: R.Tuple = cls.add(lv2, R.const(1, "float32"), alloc2)
+            lv3: R.Tensor((2 * n,), dtype="float32") = alloc2
+            storage2: R.Object = R.memory.alloc_storage(R.shape([4 * (2 * n + 2)]), R.prim_value(0), R.str("global"), R.dtype("float32"))
+            alloc3: R.Tensor((2 * n + 2,), dtype="float32") = R.memory.alloc_tensor(storage2, R.prim_value(0), R.shape([2 * n + 2]), R.dtype("float32"))
+            _3: R.Tuple = cls.pad(lv3, alloc3)
+            lv4: R.Tensor((2 * n + 2,), dtype="float32") = alloc3
+            alloc4: R.Tensor((2 * n + 2,), dtype="float32") = R.builtin.alloc_tensor(R.shape([10]), R.dtype("float32"), R.prim_value(0))
+            _4: R.Tuple = cls.log(lv4, alloc4)
+            gv: R.Tensor((2 * n + 2,), dtype="float32") = alloc4
+            return gv
+    # fmt: on
+
+    mod = relax.transform.StaticPlanBlockMemory()(Module)
+    tvm.ir.assert_structural_equal(mod, Expected)
+
+
+def test_upper_and_lower_bounds():
+    # fmt: off
+    @tvm.script.ir_module
+    class Module:
+        @T.prim_func
+        def add(rxplaceholder: T.handle, rxplaceholder_1: T.handle, T_add: T.handle):
+            T.evaluate(0)
+
+        @T.prim_func
+        def reshape(rxplaceholder: T.handle, T_reshape: T.handle):
+            T.evaluate(0)
+
+        @T.prim_func
+        def relu(rxplaceholder: T.handle, compute: T.handle):
+            T.evaluate(0)
+
+        @T.prim_func
+        def log(rxplaceholder: T.handle, compute: T.handle):
+            T.evaluate(0)
+
+        @T.prim_func
+        def exp(rxplaceholder: T.handle, compute: T.handle):
+            T.evaluate(0)
+
+        @T.prim_func
+        def pad(rxplaceholder: T.handle, PadInput: T.handle):
+            T.evaluate(0)
+
+        @R.function
+        def main(x: R.Tensor((2, "n"), dtype="float32")) -> R.Tensor(("2 * n + 2",), dtype="float32"):
+            R.func_attr({"tir_var_upper_bound": {"n": 4}, "tir_var_lower_bound": {"n": 2}, "relax.force_pure": True})
+            n = T.int64()
+            cls = Module
+            alloc: R.Tensor((2, n), dtype="float32") = R.builtin.alloc_tensor(R.shape([2, n]), dtype="float32", runtime_device_index=0)
+            _: R.Tuple() = cls.exp(x, alloc)
+            lv: R.Tensor((2, n), dtype="float32") = alloc
+            lv1: R.Tensor((2 * n,), dtype="float32") = R.reshape(lv, (2 * n,))
+            alloc1: R.Tensor((2 * n,), dtype="float32") = R.builtin.alloc_tensor(R.shape([2 * n]), dtype="float32", runtime_device_index=0)
+            _1: R.Tuple() = cls.relu(lv1, alloc1)
+            lv2: R.Tensor((2 * n,), dtype="float32") = alloc1
+            alloc2: R.Tensor((2 * n,), dtype="float32") = R.builtin.alloc_tensor(R.shape([2 * n]), dtype="float32", runtime_device_index=0)
+            _2: R.Tuple() = cls.add(lv2, R.const(1, "float32"), alloc2)
+            lv3: R.Tensor((2 * n,), dtype="float32") = alloc2
+            alloc3: R.Tensor((2 * n + 2,), dtype="float32") = R.builtin.alloc_tensor(R.shape([2 * n + 2]), dtype="float32", runtime_device_index=0)
+            _3: R.Tuple() = cls.pad(lv3, alloc3)
+            lv4: R.Tensor((2 * n + 2,), dtype="float32") = alloc3
+            alloc4: R.Tensor((2 * n + 2,), dtype="float32") = R.builtin.alloc_tensor(R.shape([10]), dtype="float32", runtime_device_index=0)
+            _4: R.Tuple() = cls.log(lv4, alloc4)
+            gv: R.Tensor((2 * n + 2,), dtype="float32") = alloc4
+            return gv
+
+    @I.ir_module
+    class Expected:
+        @T.prim_func
+        def add(rxplaceholder: T.handle, rxplaceholder_1: T.handle, T_add: T.handle):
+            T.evaluate(0)
+
+        @T.prim_func
+        def exp(rxplaceholder: T.handle, compute: T.handle):
+            T.evaluate(0)
+
+        @T.prim_func
+        def log(rxplaceholder: T.handle, compute: T.handle):
+            T.evaluate(0)
+
+        @T.prim_func
+        def pad(rxplaceholder: T.handle, PadInput: T.handle):
+            T.evaluate(0)
+
+        @T.prim_func
+        def relu(rxplaceholder: T.handle, compute: T.handle):
+            T.evaluate(0)
+
+        @T.prim_func
+        def reshape(rxplaceholder: T.handle, T_reshape: T.handle):
+            T.evaluate(0)
+
+        @R.function
+        def main(x: R.Tensor((2, "n"), dtype="float32")) -> R.Tensor(("2 * n + 2",), dtype="float32"):
+            n = T.int64()
+            R.func_attr({"tir_var_upper_bound": {"n": 4}, "tir_var_lower_bound": {"n": 2}, "relax.force_pure": True})
+            cls = Expected
+            storage: R.Object = R.memory.alloc_storage(R.shape([32]), R.prim_value(0), R.str("global"), R.dtype("float32"))
+            alloc: R.Tensor((2, n), dtype="float32") = R.memory.alloc_tensor(storage, R.prim_value(0), R.shape([2, n]), R.dtype("float32"))
+            _: R.Tuple = cls.exp(x, alloc)
+            lv: R.Tensor((2, n), dtype="float32") = alloc
+            lv1: R.Tensor((2 * n,), dtype="float32") = R.reshape(lv, R.shape([2 * n]))
+            storage1: R.Object = R.memory.alloc_storage(R.shape([40]), R.prim_value(0), R.str("global"), R.dtype("float32"))
+            alloc1: R.Tensor((2 * n,), dtype="float32") = R.memory.alloc_tensor(storage1, R.prim_value(0), R.shape([2 * n]), R.dtype("float32"))
+            _1: R.Tuple = cls.relu(lv1, alloc1)
+            lv2: R.Tensor((2 * n,), dtype="float32") = alloc1
+            alloc2: R.Tensor((2 * n,), dtype="float32") = R.memory.alloc_tensor(storage, R.prim_value(0), R.shape([2 * n]), R.dtype("float32"))
+            _2: R.Tuple = cls.add(lv2, R.const(1, "float32"), alloc2)
+            lv3: R.Tensor((2 * n,), dtype="float32") = alloc2
+            alloc3: R.Tensor((2 * n + 2,), dtype="float32") = R.memory.alloc_tensor(storage1, R.prim_value(0), R.shape([2 * n + 2]), R.dtype("float32"))
+            _3: R.Tuple = cls.pad(lv3, alloc3)
+            lv4: R.Tensor((2 * n + 2,), dtype="float32") = alloc3
+            alloc4: R.Tensor((2 * n + 2,), dtype="float32") = R.builtin.alloc_tensor(R.shape([10]), R.dtype("float32"), R.prim_value(0))
+            _4: R.Tuple = cls.log(lv4, alloc4)
+            gv: R.Tensor((2 * n + 2,), dtype="float32") = alloc4
+            return gv
+    # fmt: on
+
+    mod = relax.transform.StaticPlanBlockMemory()(Module)
+    tvm.ir.assert_structural_equal(mod, Expected)
+
+
+def test_invalid_tir_var_upper_bound():
+    @tvm.script.ir_module
+    class Module:
+        @R.function
+        def main(x: R.Tensor((2, "n"), dtype="float32")):
+            R.func_attr({"tir_var_upper_bound": {"n": [4]}, "relax.force_pure": True})
+            return x
+
+    with pytest.raises((TVMError, TypeError)):
+        relax.transform.StaticPlanBlockMemory()(Module)
+
+
+def test_invalid_tir_var_lower_bound():
+    @tvm.script.ir_module
+    class Module:
+        @R.function
+        def main(x: R.Tensor((2, "n"), dtype="float32")):
+            R.func_attr({"tir_var_lower_bound": {"n": [4]}, "relax.force_pure": True})
+            return x
+
+    with pytest.raises((TVMError, TypeError)):
+        relax.transform.StaticPlanBlockMemory()(Module)
+
+
 def test_tir_var_decreasing_monotone():
     # fmt: off
     @I.ir_module
@@ -1335,30 +1574,6 @@ def func2(x: R.Tensor((10,), dtype="float32")) -> R.Tensor((10,), dtype="float32
     tvm.ir.assert_structural_equal(mod, Expected)
 
 
-def test_invalid_tir_var_upper_bound():
-    @tvm.script.ir_module
-    class Module:
-        @R.function
-        def main(x: R.Tensor((2, "n"), dtype="float32")):
-            R.func_attr({"tir_var_upper_bound": {"n": [4]}, "relax.force_pure": True})
-            return x
-
-    with pytest.raises((TVMError, TypeError)):
-        relax.transform.StaticPlanBlockMemory()(Module)
-
-
-def test_invalid_tir_var_lower_bound():
-    @tvm.script.ir_module
-    class Module:
-        @R.function
-        def main(x: R.Tensor((2, "n"), dtype="float32")):
-            R.func_attr({"tir_var_lower_bound": {"n": [4]}, "relax.force_pure": True})
-            return x
-
-    with pytest.raises((TVMError, TypeError)):
-        relax.transform.StaticPlanBlockMemory()(Module)
-
-
 def test_add():
     @I.ir_module
     class Module:
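Note (not part of the patch): the following is a minimal sketch, using only the public tvm.arith Python API, of the bound arithmetic this change relies on. The names lo and hi are illustrative stand-ins for the values carried by the "tir_var_lower_bound"/"tir_var_upper_bound" function attributes; the binding mirrors Range::FromMinExtent(lower, upper - lower + 1) from the C++ hunk and shows why an explicit upper bound yields the static 40-byte storage seen in test_upper_and_lower_bounds, whereas a lower bound alone cannot produce a finite size.

import tvm
from tvm import arith, tir

ana = arith.Analyzer()
n = tir.Var("n", "int64")

# Assumed attribute values, as in the test: lower bound 2, upper bound 4.
lo, hi = 2, 4
# Bind n to the range [lo, lo + (hi - lo + 1)), i.e. n in {2, 3, 4},
# mirroring Range::FromMinExtent(lower, upper - lower + 1) in the C++ change.
ana.bind(n, tvm.ir.Range.from_min_extent(tir.IntImm("int64", lo), tir.IntImm("int64", hi - lo + 1)))

# Byte size of a (2 * n + 2,) float32 tensor; with the range bound its maximum is finite.
print(ana.const_int_bound(4 * (2 * n + 2)).max_value)  # 40, matching R.shape([40]) in the expected module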