1111from ibis .common .collections import frozendict
1212from ibis .common .patterns import NoMatch , Pattern
1313from ibis .common .typing import _ClassInfo
14- from ibis .util import experimental
14+ from ibis .util import experimental , promote_list
1515
1616if TYPE_CHECKING :
1717 from typing_extensions import Self
@@ -340,9 +340,39 @@ def find(
340340 determined by a breadth-first search.
341341
342342 """
343- nodes = Graph .from_bfs (self , filter = filter , context = context ). nodes ( )
343+ graph = Graph .from_bfs (self , filter = filter , context = context )
344344 finder = _coerce_finder (finder , context )
345- return [node for node in nodes if finder (node )]
345+ return [node for node in graph .nodes () if finder (node )]
346+
347+ @experimental
348+ def find_below (
349+ self ,
350+ finder : FinderLike ,
351+ filter : Optional [FinderLike ] = None ,
352+ context : Optional [dict ] = None ,
353+ ) -> list [Node ]:
354+ """Find all nodes below the current node matching a given pattern in the graph.
355+
356+ A variant of find() that only returns nodes below the current node in the graph.
357+
358+ Parameters
359+ ----------
360+ finder
361+ A type, tuple of types, a pattern or a callable to match upon.
362+ filter
363+ A type, tuple of types, a pattern or a callable to filter out nodes
364+ from the traversal. The traversal will only visit nodes that match
365+ the given filter and stop otherwise.
366+ context
367+ Optional context to use if `finder` or `filter` is a pattern.
368+
369+ Returns
370+ -------
371+ The list of nodes matching the given pattern.
372+ """
373+ graph = Graph .from_bfs (self .__children__ , filter = filter , context = context )
374+ finder = _coerce_finder (finder , context )
375+ return [node for node in graph .nodes () if finder (node )]
346376
347377 @experimental
348378 def find_topmost (
@@ -620,10 +650,8 @@ def bfs(root: Node) -> Graph:
620650 """
621651 # fast path for the default no filter case, according to benchmarks
622652 # this is gives a 10% speedup compared to the filtered version
623- if not isinstance (root , Node ):
624- raise TypeError ("node must be an instance of ibis.common.graph.Node" )
625-
626- queue = deque ([root ])
653+ nodes = _flatten_collections (promote_list (root ))
654+ queue = deque (nodes )
627655 graph = Graph ()
628656
629657 while queue :
@@ -651,15 +679,10 @@ def bfs_while(root: Node, filter: Finder) -> Graph:
651679 A graph constructed from the root node.
652680
653681 """
654- if not isinstance (root , Node ):
655- raise TypeError ("node must be an instance of ibis.common.graph.Node" )
656-
657- queue = deque ()
682+ nodes = _flatten_collections (promote_list (root ))
683+ queue = deque (node for node in nodes if filter (node ))
658684 graph = Graph ()
659685
660- if filter (root ):
661- queue .append (root )
662-
663686 while queue :
664687 if (node := queue .popleft ()) not in graph :
665688 children = tuple (child for child in node .__children__ if filter (child ))
@@ -684,10 +707,8 @@ def dfs(root: Node) -> Graph:
684707 """
685708 # fast path for the default no filter case, according to benchmarks
686709 # this is gives a 10% speedup compared to the filtered version
687- if not isinstance (root , Node ):
688- raise TypeError ("node must be an instance of ibis.common.graph.Node" )
689-
690- stack = deque ([root ])
710+ nodes = _flatten_collections (promote_list (root ))
711+ stack = deque (nodes )
691712 graph = {}
692713
693714 while stack :
@@ -715,15 +736,10 @@ def dfs_while(root: Node, filter: Finder) -> Graph:
715736 A graph constructed from the root node.
716737
717738 """
718- if not isinstance (root , Node ):
719- raise TypeError ("node must be an instance of ibis.common.graph.Node" )
720-
721- stack = deque ()
739+ nodes = _flatten_collections (promote_list (root ))
740+ stack = deque (node for node in nodes if filter (node ))
722741 graph = {}
723742
724- if filter (root ):
725- stack .append (root )
726-
727743 while stack :
728744 if (node := stack .pop ()) not in graph :
729745 children = tuple (child for child in node .__children__ if filter (child ))
0 commit comments