Skip to content

Commit 4bed19b

Browse files
authored
Merge pull request FoundationAgents#1526 from garylin2099/sela-lyz
add visit order
2 parents d304fc3 + 3a8fdc6 commit 4bed19b

File tree

2 files changed

+7
-4
lines changed

2 files changed

+7
-4
lines changed

expo/evaluation/visualize_mcts.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,7 @@ def visualize_tree(graph, show_instructions=False, save_path=""):
128128
plt.show()
129129

130130

131-
def build_tree_recursive(graph, parent_id, node, start_task_id=2):
131+
def build_tree_recursive(graph, parent_id, node, node_order, start_task_id=2):
132132
"""
133133
Recursively builds the entire tree starting from the root node.
134134
Adds nodes and edges to the NetworkX graph.
@@ -143,9 +143,10 @@ def build_tree_recursive(graph, parent_id, node, start_task_id=2):
143143
# Add the current node with attributes to the graph
144144
dev_score = node.raw_reward.get("dev_score", 0) * 100
145145
avg_score = node.avg_value() * 100
146+
order = node_order.index(node.id) if node.id in node_order else ""
146147
graph.add_node(
147148
parent_id,
148-
label=f"{node.id}\nAvg: {avg_score:.1f}\nScore: {dev_score:.1f}\nVisits: {node.visited}",
149+
label=f"{node.id}\nAvg: {avg_score:.1f}\nScore: {dev_score:.1f}\nVisits: {node.visited}\nOrder: {order}",
149150
avg_value=node.avg_value(),
150151
dev_score=dev_score,
151152
visits=node.visited,
@@ -159,4 +160,4 @@ def build_tree_recursive(graph, parent_id, node, start_task_id=2):
159160
for i, child in enumerate(node.children):
160161
child_id = f"{parent_id}-{i}"
161162
graph.add_edge(parent_id, child_id)
162-
build_tree_recursive(graph, child_id, child)
163+
build_tree_recursive(graph, child_id, child, node_order)

expo/scripts/visualize_experiment.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,9 @@
1717
)
1818

1919
mcts.load_tree()
20+
mcts.load_node_order()
2021
root = mcts.root_node
22+
node_order = mcts.node_order
2123
G = nx.DiGraph()
22-
build_tree_recursive(G, "0", root)
24+
build_tree_recursive(G, "0", root, node_order)
2325
visualize_tree(G, save_path=f"results/{args.task}-tree.png")

0 commit comments

Comments
 (0)