Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions include/tvm/node/object_path.h
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ class ObjectPathNode : public Object {
class ObjectPath : public ObjectRef {
public:
/*! \brief Create a path that represents the root object itself. */
static ObjectPath Root();
static ObjectPath Root(Optional<String> name = NullOpt);

TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(ObjectPath, ObjectRef, ObjectPathNode);
};
Expand All @@ -135,7 +135,9 @@ class ObjectPath : public ObjectRef {

class RootPathNode final : public ObjectPathNode {
public:
explicit RootPathNode();
Optional<String> name;

explicit RootPathNode(Optional<String> name = NullOpt);

static constexpr const char* _type_key = "RootPath";
TVM_DECLARE_FINAL_OBJECT_INFO(RootPathNode, ObjectPathNode);
Expand Down
6 changes: 4 additions & 2 deletions python/tvm/runtime/object_path.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
via attribute access, array indexing etc.
"""

from typing import Optional

import tvm._ffi
from tvm.runtime import Object
from . import _ffi_node_api
Expand Down Expand Up @@ -52,8 +54,8 @@ def __init__(self) -> None:
)

@staticmethod
def root() -> "ObjectPath":
return _ffi_node_api.ObjectPathRoot()
def root(root_name: Optional[str] = None) -> "ObjectPath":
return _ffi_node_api.ObjectPathRoot(root_name)

def __eq__(self, other):
return _ffi_node_api.ObjectPathEqual(self, other)
Expand Down
20 changes: 16 additions & 4 deletions src/node/object_path.cc
Original file line number Diff line number Diff line change
Expand Up @@ -197,19 +197,31 @@ const ObjectPathNode* ObjectPathNode::ParentNode() const {

// ============== ObjectPath ==============

/* static */ ObjectPath ObjectPath::Root() { return ObjectPath(make_object<RootPathNode>()); }
/* static */ ObjectPath ObjectPath::Root(Optional<String> name) {
return ObjectPath(make_object<RootPathNode>(name));
}

TVM_REGISTER_GLOBAL("node.ObjectPathRoot").set_body_typed(ObjectPath::Root);

// ============== Individual path classes ==============

// ----- Root -----

RootPathNode::RootPathNode() : ObjectPathNode(nullptr) {}
RootPathNode::RootPathNode(Optional<String> name) : ObjectPathNode(nullptr), name(name) {}

bool RootPathNode::LastNodeEqual(const ObjectPathNode* other_path) const {
const auto* other = static_cast<const RootPathNode*>(other_path);

bool RootPathNode::LastNodeEqual(const ObjectPathNode* other) const { return true; }
if (other->name.defined() != name.defined()) {
return false;
} else if (name && other->name) {
return name.value() == other->name.value();
} else {
return true;
}
}

std::string RootPathNode::LastNodeString() const { return "<root>"; }
std::string RootPathNode::LastNodeString() const { return name.value_or("<root>"); }

TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable).set_dispatch<RootPathNode>(PrintObjectPathRepr);

Expand Down
10 changes: 10 additions & 0 deletions tests/python/unittest/test_object_path.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,16 @@ def test_root_path():
assert root.parent is None


def test_named_root_path():
root = ObjectPath.root("base_name")
assert isinstance(root, object_path.RootPath)
assert str(root) == "base_name"
assert len(root) == 1
assert root != ObjectPath.root()
assert root == ObjectPath.root("base_name")
assert root.parent is None


def test_path_attr():
path = ObjectPath.root().attr("foo")
assert isinstance(path, object_path.AttributeAccessPath)
Expand Down