Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 34 additions & 8 deletions src/relax/transform/fuse_tir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -357,17 +357,43 @@ class BlockNameDeduplicator : public tir::StmtMutator {
}

ffi::String GetUniqueName(const ffi::String& prefix) {
ffi::String unique_prefix = prefix;
auto it = name_count_.find(prefix);
while (name_count_.count(unique_prefix)) {
unique_prefix = prefix + "_" + std::to_string(++it->second);
std::string str_prefix = std::string(prefix);

// Find where the trailing digits start
size_t base_len = str_prefix.length();
while (base_len > 0 && std::isdigit(str_prefix[base_len - 1])) {
--base_len;
}

std::string base_name;
int start_num = 0;

if (base_len < str_prefix.length()) {
base_name = str_prefix.substr(0, base_len);
start_num = std::stoi(str_prefix.substr(base_len));
} else {
base_name = str_prefix;
}

// Check if the original name is available
ffi::String candidate = prefix;
if (!name_count_.count(candidate)) {
name_count_[candidate] = 0;
return candidate;
}

// Generate unique name by incrementing the numeric suffix
int counter = (start_num > 0) ? start_num + 1 : 1;
while (true) {
candidate = ffi::String(base_name + std::to_string(counter));
if (!name_count_.count(candidate)) {
name_count_[candidate] = 0;
return candidate;
}
++counter;
}
name_count_[unique_prefix] = 0;
return unique_prefix;
}
Comment on lines 359 to 404
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The new implementation is a great improvement and correctly handles numeric suffixes. However, there are a few areas where it could be made more robust and readable:

  1. Integer Overflow: int counter and start_num could overflow if a block name has a very large numeric suffix, leading to undefined behavior. Using long long and std::stoll would be safer.
  2. Unhandled Exception: std::stoi can throw std::out_of_range if the numeric suffix is too large to fit in an int, which would crash the program. This should be handled, for example with a try-catch block.
  3. Confusing Logic: The logic for initializing the counter, int counter = (start_num > 0) ? start_num + 1 : 1;, is a bit subtle. While correct, its intent is not immediately obvious. A more explicit check would improve readability and maintainability.

Here is a suggested refactoring that addresses these points:

  ffi::String GetUniqueName(const ffi::String& prefix) {
    std::string str_prefix = std::string(prefix);

    // Find where the trailing digits start
    size_t base_len = str_prefix.length();
    while (base_len > 0 && std::isdigit(str_prefix[base_len - 1])) {
      --base_len;
    }

    std::string base_name;
    long long start_num = 0;
    bool has_suffix = base_len < str_prefix.length();

    if (has_suffix) {
      base_name = str_prefix.substr(0, base_len);
      try {
        start_num = std::stoll(str_prefix.substr(base_len));
      } catch (const std::out_of_range&) {
        // Fallback: if the number is too large, treat the whole string as a base name.
        has_suffix = false;
        base_name = str_prefix;
      }
    } else {
      base_name = str_prefix;
    }

    // Check if the original name is available
    ffi::String candidate = prefix;
    if (!name_count_.count(candidate)) {
      name_count_[candidate] = 0;
      return candidate;
    }

    // Generate unique name by incrementing the numeric suffix
    long long counter = has_suffix ? start_num + 1 : 1;
    while (true) {
      candidate = ffi::String(base_name + std::to_string(counter));
      if (!name_count_.count(candidate)) {
        name_count_[candidate] = 0;
        return candidate;
      }
      ++counter;
      ICHECK_GT(counter, 0) << "Counter overflow when generating unique block name for prefix: "
                            << prefix;
    }
  }

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this review makes sense especially the second. Could you modify the code like that?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Make sense to me as well, I've updated it.


// TODO(relax-team): It should detects the number suffix and do renaming properly
// e.g. GetUniqueName("name1") should return "name2" instead of "name10".
/*! \brief The count map to make block name unique. */
std::unordered_map<ffi::String, int> name_count_;
};
Expand Down
72 changes: 72 additions & 0 deletions tests/python/relax/test_transform_fuse_tir.py
Original file line number Diff line number Diff line change
Expand Up @@ -2444,5 +2444,77 @@ def main(
relax.transform.FuseTIR()(Before)


def test_block_name_numeric_suffix_deduplication():
@I.ir_module
class Before:
@T.prim_func(private=True)
def add1(x: T.Buffer((10,), "float32"), y: T.Buffer((10,), "float32")):
T.func_attr({"tir.noalias": True})
for i in range(10):
with T.block("compute1"):
vi = T.axis.spatial(10, i)
y[vi] = x[vi] + T.float32(1.0)

@T.prim_func(private=True)
def mul1(x: T.Buffer((10,), "float32"), y: T.Buffer((10,), "float32")):
T.func_attr({"tir.noalias": True})
for i in range(10):
with T.block("compute1"):
vi = T.axis.spatial(10, i)
y[vi] = x[vi] * T.float32(2.0)

@R.function(private=True)
def fused_add_mul(x: R.Tensor((10,), "float32")) -> R.Tensor((10,), dtype="float32"):
R.func_attr({"Primitive": True})
cls = Before
with R.dataflow():
lv1 = R.call_tir(cls.add1, (x,), out_sinfo=R.Tensor((10,), dtype="float32"))
lv2 = R.call_tir(cls.mul1, (lv1,), out_sinfo=R.Tensor((10,), dtype="float32"))
R.output(lv2)
return lv2

@R.function
def main(x: R.Tensor((10,), dtype="float32")) -> R.Tensor((10,), dtype="float32"):
cls = Before
with R.dataflow():
gv = cls.fused_add_mul(x)
R.output(gv)
return gv

@I.ir_module
class Expected:
@T.prim_func(private=True)
def fused_add_mul(p_x: T.handle, p_output0: T.handle):
T.func_attr({"tir.noalias": True})
x = T.match_buffer(p_x, (T.int64(10),))
y_intermediate_1 = T.match_buffer(p_output0, (T.int64(10),), elem_offset=T.int32(0))
with T.block("root"):
T.reads()
T.writes()
y_intermediate = T.alloc_buffer((T.int64(10),), elem_offset=T.int32(0))
for i in range(10):
with T.block("compute1"):
vi = T.axis.spatial(10, i)
T.reads(x[vi])
T.writes(y_intermediate[vi])
y_intermediate[vi] = x[vi] + T.float32(1.0)
for i in range(10):
with T.block("compute2"):
vi = T.axis.spatial(10, i)
T.reads(y_intermediate[vi])
T.writes(y_intermediate_1[vi])
y_intermediate_1[vi] = y_intermediate[vi] * T.float32(2.0)

@R.function
def main(x: R.Tensor((10,), dtype="float32")) -> R.Tensor((10,), dtype="float32"):
cls = Expected
with R.dataflow():
gv = R.call_tir(cls.fused_add_mul, (x,), out_sinfo=R.Tensor((10,), dtype="float32"))
R.output(gv)
return gv

_check(Before, Expected)


if __name__ == "__main__":
tvm.testing.main()
Loading