@@ -128,7 +128,7 @@ def visualize_tree(graph, show_instructions=False, save_path=""):
128128 plt .show ()
129129
130130
131- def build_tree_recursive (graph , parent_id , node , start_task_id = 2 ):
131+ def build_tree_recursive (graph , parent_id , node , node_order , start_task_id = 2 ):
132132 """
133133 Recursively builds the entire tree starting from the root node.
134134 Adds nodes and edges to the NetworkX graph.
@@ -143,9 +143,10 @@ def build_tree_recursive(graph, parent_id, node, start_task_id=2):
143143 # Add the current node with attributes to the graph
144144 dev_score = node .raw_reward .get ("dev_score" , 0 ) * 100
145145 avg_score = node .avg_value () * 100
146+ order = node_order .index (node .id ) if node .id in node_order else ""
146147 graph .add_node (
147148 parent_id ,
148- label = f"{ node .id } \n Avg: { avg_score :.1f} \n Score: { dev_score :.1f} \n Visits: { node .visited } " ,
149+ label = f"{ node .id } \n Avg: { avg_score :.1f} \n Score: { dev_score :.1f} \n Visits: { node .visited } \n Order: { order } " ,
149150 avg_value = node .avg_value (),
150151 dev_score = dev_score ,
151152 visits = node .visited ,
@@ -159,4 +160,4 @@ def build_tree_recursive(graph, parent_id, node, start_task_id=2):
159160 for i , child in enumerate (node .children ):
160161 child_id = f"{ parent_id } -{ i } "
161162 graph .add_edge (parent_id , child_id )
162- build_tree_recursive (graph , child_id , child )
163+ build_tree_recursive (graph , child_id , child , node_order )
0 commit comments