2121from dlrover .python .rl .common .constant import RLMasterConstant
2222from dlrover .python .rl .common .enums import RLRoleType
2323from dlrover .python .rl .common .rl_context import RLContext
24- from dlrover .python .rl .master .graph import RLExecutionGraph
24+ from dlrover .python .rl .master .graph import (
25+ FunctionInfo ,
26+ RLExecutionEdge ,
27+ RLExecutionGraph ,
28+ VertexInvocationMeta ,
29+ )
2530from dlrover .python .rl .master .scheduler import GroupOrderedScheduler
2631from dlrover .python .rl .tests .test_class import TestActor , TestRollout
2732from dlrover .python .rl .tests .test_data import TestData
@@ -53,6 +58,8 @@ def test_basic(self):
5358 self .assertEqual (len (graph .get_all_vertices ()), 1 + 1 + 1 + 1 )
5459 self .assertEqual (len (graph .name_vertex_mapping ), 1 + 1 + 1 + 1 )
5560 self .assertEqual (len (graph .name_actor_mapping ), 0 )
61+ # not used for now
62+ self .assertEqual (len (graph .execution_edges ), 0 )
5663
5764 actor_vertices = graph .get_vertices_by_role_type (RLRoleType .ACTOR )
5865 self .assertEqual (len (actor_vertices ), 1 )
@@ -69,6 +76,8 @@ def test_basic(self):
6976 self .assertEqual (rollout_vertex_0 .rank , 0 )
7077 self .assertEqual (rollout_vertex_0 .world_size , 1 )
7178
79+ self .assertIsNotNone (graph .get_unit_resource_by_role (RLRoleType .ACTOR ))
80+
7281 now = int (time .time ())
7382 rollout_vertex_0 .update_runtime_info (
7483 create_time = now , hostname = "test.com" , restart_count = 2
@@ -199,3 +208,28 @@ def test_serialization(self):
199208 vertex .pg_bundle_index ,
200209 graph .name_vertex_mapping [name ].pg_bundle_index ,
201210 )
211+
212+ def test_vertex_invocation_meta (self ):
213+ def test_input ():
214+ pass
215+
216+ function_info = FunctionInfo ("test" , test_input )
217+ self .assertIsNotNone (function_info )
218+ self .assertEqual (function_info .name , "test" )
219+
220+ vertex_invocation_meta = VertexInvocationMeta (
221+ {function_info .name : function_info }
222+ )
223+ self .assertIsNotNone (vertex_invocation_meta )
224+ self .assertEqual (
225+ vertex_invocation_meta .get_func ("test" ), function_info
226+ )
227+
228+ def test_edge_basic (self ):
229+ edge = RLExecutionEdge (0 , RLRoleType .ACTOR , RLRoleType .ROLLOUT , "test" )
230+ self .assertIsNotNone (edge )
231+ self .assertEqual (edge .index , 0 )
232+ self .assertEqual (edge .from_role , RLRoleType .ACTOR )
233+ self .assertEqual (edge .to_role , RLRoleType .ROLLOUT )
234+ self .assertEqual (edge .invocation_name , "test" )
235+ self .assertIsNone (edge .async_group )
0 commit comments