Skip to content

Commit 80d12a2

Browse files
authored
feat(common): add Node.find_below() methods to exclude the root node from filtering (ibis-project#8861)
1 parent a5de9ed commit 80d12a2

2 files changed

Lines changed: 55 additions & 35 deletions

File tree

ibis/common/graph.py

Lines changed: 41 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from ibis.common.collections import frozendict
1212
from ibis.common.patterns import NoMatch, Pattern
1313
from ibis.common.typing import _ClassInfo
14-
from ibis.util import experimental
14+
from ibis.util import experimental, promote_list
1515

1616
if TYPE_CHECKING:
1717
from typing_extensions import Self
@@ -340,9 +340,39 @@ def find(
340340
determined by a breadth-first search.
341341
342342
"""
343-
nodes = Graph.from_bfs(self, filter=filter, context=context).nodes()
343+
graph = Graph.from_bfs(self, filter=filter, context=context)
344344
finder = _coerce_finder(finder, context)
345-
return [node for node in nodes if finder(node)]
345+
return [node for node in graph.nodes() if finder(node)]
346+
347+
@experimental
348+
def find_below(
349+
self,
350+
finder: FinderLike,
351+
filter: Optional[FinderLike] = None,
352+
context: Optional[dict] = None,
353+
) -> list[Node]:
354+
"""Find all nodes below the current node matching a given pattern in the graph.
355+
356+
A variant of find() that only returns nodes below the current node in the graph.
357+
358+
Parameters
359+
----------
360+
finder
361+
A type, tuple of types, a pattern or a callable to match upon.
362+
filter
363+
A type, tuple of types, a pattern or a callable to filter out nodes
364+
from the traversal. The traversal will only visit nodes that match
365+
the given filter and stop otherwise.
366+
context
367+
Optional context to use if `finder` or `filter` is a pattern.
368+
369+
Returns
370+
-------
371+
The list of nodes matching the given pattern.
372+
"""
373+
graph = Graph.from_bfs(self.__children__, filter=filter, context=context)
374+
finder = _coerce_finder(finder, context)
375+
return [node for node in graph.nodes() if finder(node)]
346376

347377
@experimental
348378
def find_topmost(
@@ -620,10 +650,8 @@ def bfs(root: Node) -> Graph:
620650
"""
621651
# fast path for the default no filter case, according to benchmarks
622652
# this is gives a 10% speedup compared to the filtered version
623-
if not isinstance(root, Node):
624-
raise TypeError("node must be an instance of ibis.common.graph.Node")
625-
626-
queue = deque([root])
653+
nodes = _flatten_collections(promote_list(root))
654+
queue = deque(nodes)
627655
graph = Graph()
628656

629657
while queue:
@@ -651,15 +679,10 @@ def bfs_while(root: Node, filter: Finder) -> Graph:
651679
A graph constructed from the root node.
652680
653681
"""
654-
if not isinstance(root, Node):
655-
raise TypeError("node must be an instance of ibis.common.graph.Node")
656-
657-
queue = deque()
682+
nodes = _flatten_collections(promote_list(root))
683+
queue = deque(node for node in nodes if filter(node))
658684
graph = Graph()
659685

660-
if filter(root):
661-
queue.append(root)
662-
663686
while queue:
664687
if (node := queue.popleft()) not in graph:
665688
children = tuple(child for child in node.__children__ if filter(child))
@@ -684,10 +707,8 @@ def dfs(root: Node) -> Graph:
684707
"""
685708
# fast path for the default no filter case, according to benchmarks
686709
# this is gives a 10% speedup compared to the filtered version
687-
if not isinstance(root, Node):
688-
raise TypeError("node must be an instance of ibis.common.graph.Node")
689-
690-
stack = deque([root])
710+
nodes = _flatten_collections(promote_list(root))
711+
stack = deque(nodes)
691712
graph = {}
692713

693714
while stack:
@@ -715,15 +736,10 @@ def dfs_while(root: Node, filter: Finder) -> Graph:
715736
A graph constructed from the root node.
716737
717738
"""
718-
if not isinstance(root, Node):
719-
raise TypeError("node must be an instance of ibis.common.graph.Node")
720-
721-
stack = deque()
739+
nodes = _flatten_collections(promote_list(root))
740+
stack = deque(node for node in nodes if filter(node))
722741
graph = {}
723742

724-
if filter(root):
725-
stack.append(root)
726-
727743
while stack:
728744
if (node := stack.pop()) not in graph:
729745
children = tuple(child for child in node.__children__ if filter(child))

ibis/common/tests/test_graph.py

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -59,11 +59,8 @@ def copy(self, name=None, children=None):
5959

6060
def test_bfs():
6161
assert list(bfs(A).keys()) == [A, B, C, D, E]
62-
63-
with pytest.raises(
64-
TypeError, match="must be an instance of ibis.common.graph.Node"
65-
):
66-
bfs(1)
62+
assert list(bfs([D, E, B])) == [D, E, B]
63+
assert bfs(1) == {}
6764

6865

6966
def test_construction():
@@ -82,11 +79,8 @@ def test_graph_repr():
8279

8380
def test_dfs():
8481
assert list(dfs(A).keys()) == [D, E, B, C, A]
85-
86-
with pytest.raises(
87-
TypeError, match="must be an instance of ibis.common.graph.Node"
88-
):
89-
dfs(1)
82+
assert list(dfs([D, E, B])) == [D, E, B]
83+
assert dfs(1) == {}
9084

9185

9286
def test_invert():
@@ -393,6 +387,16 @@ def test_node_find_using_pattern():
393387
assert result == [A, B]
394388

395389

390+
def test_node_find_below():
391+
lowercase = MyNode(name="lowercase", children=[])
392+
root = MyNode(name="root", children=[A, B, lowercase])
393+
result = root.find_below(MyNode)
394+
assert result == [A, B, lowercase, C, D, E]
395+
396+
result = root.find_below(lambda x: x.name.islower(), filter=lambda x: x != root)
397+
assert result == [lowercase]
398+
399+
396400
def test_node_find_topmost_using_type():
397401
class FooNode(MyNode):
398402
pass

0 commit comments

Comments
 (0)