diff --git a/include/tvm/node/script_printer.h b/include/tvm/node/script_printer.h index c65394f7b7e1..f25ffa1caa1c 100644 --- a/include/tvm/node/script_printer.h +++ b/include/tvm/node/script_printer.h @@ -68,6 +68,8 @@ class PrinterConfigNode : public Object { int num_context_lines = -1; /*! \brief Whether to output with syntax sugar, set false for complete printing. */ bool syntax_sugar = true; + /*! \brief Whether variable names should include the object's address */ + bool show_object_address = false; /* \brief Object path to be underlined */ Array path_to_underline = Array(); /*! \brief Object path to be annotated. */ @@ -91,6 +93,7 @@ class PrinterConfigNode : public Object { v->Visit("print_line_numbers", &print_line_numbers); v->Visit("num_context_lines", &num_context_lines); v->Visit("syntax_sugar", &syntax_sugar); + v->Visit("show_object_address", &show_object_address); v->Visit("path_to_underline", &path_to_underline); v->Visit("path_to_annotate", &path_to_annotate); v->Visit("obj_to_underline", &obj_to_underline); diff --git a/python/tvm/runtime/script_printer.py b/python/tvm/runtime/script_printer.py index 269cab8e5d4d..2ed2b8ddd4bc 100644 --- a/python/tvm/runtime/script_printer.py +++ b/python/tvm/runtime/script_printer.py @@ -40,6 +40,7 @@ class PrinterConfig(Object): print_line_numbers: bool num_context_lines: int syntax_sugar: bool + show_object_address: bool path_to_underline: Optional[List[ObjectPath]] path_to_annotate: Optional[Dict[ObjectPath, str]] obj_to_underline: Optional[List[Object]] @@ -60,6 +61,7 @@ def __init__( print_line_numbers: bool = False, num_context_lines: Optional[int] = None, syntax_sugar: bool = True, + show_object_address: bool = True, path_to_underline: Optional[List[ObjectPath]] = None, path_to_annotate: Optional[Dict[ObjectPath, str]] = None, obj_to_underline: Optional[List[Object]] = None, @@ -79,6 +81,7 @@ def __init__( "print_line_numbers": print_line_numbers, "num_context_lines": num_context_lines, "syntax_sugar": syntax_sugar, + "show_object_address": show_object_address, "path_to_underline": path_to_underline, "path_to_annotate": path_to_annotate, "obj_to_underline": obj_to_underline, @@ -119,6 +122,7 @@ def script( print_line_numbers: bool = False, num_context_lines: int = -1, syntax_sugar: bool = True, + show_object_address: bool = False, path_to_underline: Optional[List[ObjectPath]] = None, path_to_annotate: Optional[Dict[ObjectPath, str]] = None, obj_to_underline: Optional[List[Object]] = None, @@ -153,6 +157,8 @@ def script( The number of lines of context to print before and after the line to underline. syntax_sugar: bool = True Whether to output with syntax sugar, set false for complete printing. + show_object_address: bool = False + Whether to include the object's adddress as part of the TVMScript name path_to_underline : Optional[List[ObjectPath]] = None Object path to be underlined path_to_annotate : Optional[Dict[ObjectPath, str]] = None @@ -182,6 +188,7 @@ def script( print_line_numbers=print_line_numbers, num_context_lines=num_context_lines, syntax_sugar=syntax_sugar, + show_object_address=show_object_address, path_to_underline=path_to_underline, path_to_annotate=path_to_annotate, obj_to_underline=obj_to_underline, @@ -206,6 +213,7 @@ def show( print_line_numbers: bool = False, num_context_lines: int = -1, syntax_sugar: bool = True, + show_object_address: bool = True, path_to_underline: Optional[List[ObjectPath]] = None, path_to_annotate: Optional[Dict[ObjectPath, str]] = None, obj_to_underline: Optional[List[Object]] = None, @@ -245,6 +253,8 @@ def show( The number of lines of context to print before and after the line to underline. syntax_sugar: bool = True Whether to output with syntax sugar, set false for complete printing. + show_object_address: bool = False + Whether to include the object's adddress as part of the TVMScript name path_to_underline : Optional[List[ObjectPath]] = None Object path to be underlined path_to_annotate : Optional[Dict[ObjectPath, str]] = None @@ -272,6 +282,7 @@ def show( print_line_numbers=print_line_numbers, num_context_lines=num_context_lines, syntax_sugar=syntax_sugar, + show_object_address=show_object_address, path_to_underline=path_to_underline, path_to_annotate=path_to_annotate, obj_to_underline=obj_to_underline, diff --git a/src/node/script_printer.cc b/src/node/script_printer.cc index 8293af402ed9..28e72be78945 100644 --- a/src/node/script_printer.cc +++ b/src/node/script_printer.cc @@ -88,6 +88,10 @@ PrinterConfig::PrinterConfig(Map config_dict) { if (auto v = config_dict.Get("syntax_sugar")) { n->syntax_sugar = Downcast(v)->value; } + if (auto v = config_dict.Get("show_object_address")) { + n->show_object_address = Downcast(v)->value; + } + this->data_ = std::move(n); } diff --git a/src/script/printer/ir_docsifier.cc b/src/script/printer/ir_docsifier.cc index 7cd27057f4d4..7dc971bd2c35 100644 --- a/src/script/printer/ir_docsifier.cc +++ b/src/script/printer/ir_docsifier.cc @@ -21,6 +21,8 @@ #include #include +#include + #include "./utils.h" namespace tvm { @@ -29,7 +31,13 @@ namespace printer { IdDoc IRDocsifierNode::Define(const ObjectRef& obj, const Frame& frame, const String& name_hint) { ICHECK(obj2info.find(obj) == obj2info.end()) << "Duplicated object: " << obj; - String name = GenerateUniqueName(name_hint, this->defined_names); + String name = name_hint; + if (cfg->show_object_address) { + std::stringstream stream; + stream << name << "_" << obj.get(); + name = stream.str(); + } + name = GenerateUniqueName(name, this->defined_names); this->defined_names.insert(name); DocCreator doc_factory = [name]() { return IdDoc(name); }; obj2info.insert({obj, VariableInfo{std::move(doc_factory), name}}); diff --git a/tests/python/unittest/test_tvmscript_printer_tir.py b/tests/python/unittest/test_tvmscript_printer_tir.py index e6334553d64f..572c22e80e51 100644 --- a/tests/python/unittest/test_tvmscript_printer_tir.py +++ b/tests/python/unittest/test_tvmscript_printer_tir.py @@ -15,6 +15,9 @@ # specific language governing permissions and limitations # under the License. # pylint: disable=missing-docstring + +import re + import tvm.testing from tvm import ir, tir from tvm.ir import Range @@ -798,5 +801,43 @@ def main(): _assert_print(root_block_explicitly, expected_output) +def test_variable_with_cpp_address(): + """The show_object_address option displays the C++ addressess + + Because the C++ address may vary with each execution, the output + produced with this option cannot be compared to a fixed string. + Instead, this test uses the normal script output to generate a + regular expression against with the test output must match. The + regular expression validates that all names have been appended + with "_0x" followed by a hexadecimal number, and that the address + is the same for each variable. + """ + from tvm.script import tir as T + + # The test function has all named objects suffixed with "_name", + # to avoid spurious replacement when generating the expected + # regex. + @T.prim_func + def func(a_name: T.handle): + N_name = T.int64() + A_name = T.match_buffer(a_name, N_name, "float32") + for i_name in range(N_name): + A_name[i_name] = A_name[i_name] + 1.0 + + without_address = func.script(show_object_address=False) + script = func.script(show_object_address=True) + + expected_regex = re.escape(without_address) + for name in ["a_name", "A_name", "N_name", "i_name"]: + # Replace all occurrences with a backref to an earlier match + expected_regex = expected_regex.replace(name, rf"(?P={name})") + # Then replace the first such backref with a capturing group. + expected_regex = expected_regex.replace( + rf"(?P={name})", rf"(?P<{name}>{name}_0x[A-Fa-f0-9]+)", 1 + ) + + assert re.match(expected_regex, script) + + if __name__ == "__main__": tvm.testing.main()