@@ -113,61 +113,61 @@ def _split_args_kwargs_data_proto_with_auto_padding(chunks, *args, **kwargs):
113113 return splitted_args , splitted_kwargs
114114
115115
116- def dispatch_one_to_all (worker_group , * args , ** kwargs ):
117- args = tuple ([arg ] * worker_group .world_size for arg in args )
118- kwargs = {k : [v ] * worker_group .world_size for k , v in kwargs .items ()}
116+ def dispatch_one_to_all (role_group , * args , ** kwargs ):
117+ args = tuple ([arg ] * role_group .world_size for arg in args )
118+ kwargs = {k : [v ] * role_group .world_size for k , v in kwargs .items ()}
119119 return args , kwargs
120120
121121
122- def dummy_direct_rollout_call (worker_group , * args , ** kwargs ):
122+ def dummy_direct_rollout_call (role_group , * args , ** kwargs ):
123123 raise NotImplementedError ("Direct rollout call is forbidden." )
124124
125125
126- def dispatch_all_to_all (worker_group , * args , ** kwargs ):
126+ def dispatch_all_to_all (role_group , * args , ** kwargs ):
127127 return args , kwargs
128128
129129
130- def collect_all_to_all (worker_group , output ):
130+ def collect_all_to_all (role_group , output ):
131131 return output
132132
133133
134- def dispatch_megatron_compute (worker_group , * args , ** kwargs ):
134+ def dispatch_megatron_compute (role_group , * args , ** kwargs ):
135135 """
136136 User passes in dp data. The data is dispatched to all tp/pp ranks with the same dp
137137 """
138138
139139 all_args = []
140140 for arg in args :
141141 assert (
142- isinstance (arg , (Tuple , List )) and len (arg ) == worker_group .dp_size
142+ isinstance (arg , (Tuple , List )) and len (arg ) == role_group .dp_size ()
143143 )
144144 transformed_args = []
145- for i in range (worker_group .world_size ):
146- local_dp_rank = worker_group .get_megatron_rank_info (rank = i ) .dp_rank
145+ for i in range (role_group .world_size ):
146+ local_dp_rank = role_group .get_megatron_rank_info ()[ i ] .dp_rank
147147 transformed_args .append (arg [local_dp_rank ])
148148 all_args .append (transformed_args )
149149 all_args = tuple (all_args )
150150
151151 all_kwargs = {}
152152 for k , v in kwargs .items ():
153- assert isinstance (v , (Tuple , List )) and len (v ) == worker_group .dp_size
153+ assert isinstance (v , (Tuple , List )) and len (v ) == role_group .dp_size ()
154154 transformed_v = []
155- for i in range (worker_group .world_size ):
156- local_dp_rank = worker_group .get_megatron_rank_info (rank = i ) .dp_rank
155+ for i in range (role_group .world_size ):
156+ local_dp_rank = role_group .get_megatron_rank_info ()[ i ] .dp_rank
157157 transformed_v .append (v [local_dp_rank ])
158158 all_kwargs [k ] = transformed_v
159159 return all_args , all_kwargs
160160
161161
162- def collect_megatron_compute (worker_group , output ):
162+ def collect_megatron_compute (role_group , output ):
163163 """
164164 Only collect the data from the tp=0 and pp=last and every dp ranks
165165 """
166166
167167 output_in_dp = []
168- pp_size = worker_group .get_megatron_global_info ().pp_size
169- for global_rank in range (worker_group .world_size ):
170- local_rank_info = worker_group .get_megatron_rank_info (rank = global_rank )
168+ pp_size = role_group .get_megatron_global_info ().pp_size
169+ for global_rank in range (role_group .world_size ):
170+ local_rank_info = role_group .get_megatron_rank_info ()[ global_rank ]
171171 if (
172172 local_rank_info .tp_rank == 0
173173 and local_rank_info .pp_rank == pp_size - 1
@@ -177,16 +177,16 @@ def collect_megatron_compute(worker_group, output):
177177 return output_in_dp
178178
179179
180- def dispatch_megatron_compute_data_proto (worker_group , * args , ** kwargs ):
180+ def dispatch_megatron_compute_data_proto (role_group , * args , ** kwargs ):
181181 """
182182 All the args and kwargs must be DataProto. The batch will be chunked by dp_size and passed to each rank
183183 """
184184
185185 splitted_args , splitted_kwargs = _split_args_kwargs_data_proto (
186- worker_group .dp_size , * args , ** kwargs
186+ role_group .dp_size () , * args , ** kwargs
187187 )
188188 return dispatch_megatron_compute (
189- worker_group , * splitted_args , ** splitted_kwargs
189+ role_group , * splitted_args , ** splitted_kwargs
190190 )
191191
192192
@@ -208,14 +208,14 @@ def _concat_data_proto_or_future(output: List):
208208 raise NotImplementedError
209209
210210
211- def collect_megatron_compute_data_proto (worker_group , output ):
211+ def collect_megatron_compute_data_proto (role_group , output ):
212212 """
213213 Each output must be a DataProto. We concat the dim=0 of output
214214 """
215215 import ray
216216 from verl .protocol import DataProto
217217
218- output = collect_megatron_compute (worker_group , output )
218+ output = collect_megatron_compute (role_group , output )
219219 for o in output :
220220 assert isinstance (
221221 o , (DataProto , ray .ObjectRef )
@@ -224,24 +224,25 @@ def collect_megatron_compute_data_proto(worker_group, output):
224224 return _concat_data_proto_or_future (output )
225225
226226
227- def dispatch_megatron_pp_as_dp (worker_group , * args , ** kwargs ):
227+ def dispatch_megatron_pp_as_dp (role_group , * args , ** kwargs ):
228228 """
229229 treat pp as dp.
230230 """
231231
232- pp_size = worker_group .pp_size
233- dp_size = worker_group .dp_size
234- cp_size = worker_group .cp_size
232+ pp_size = role_group .pp_size ()
233+ dp_size = role_group .dp_size ()
234+ cp_size = role_group .cp_size ()
235235 pp_dp_cp_size = pp_size * dp_size * cp_size
236236
237237 all_args = []
238238 for arg in args :
239239 assert isinstance (arg , (List , Tuple )) and len (arg ) == pp_dp_cp_size
240240 transformed_args = []
241- for i in range (worker_group .world_size ):
242- local_dp_rank = worker_group .get_megatron_rank_info (rank = i ).dp_rank
243- local_pp_rank = worker_group .get_megatron_rank_info (rank = i ).pp_rank
244- local_cp_rank = worker_group .get_megatron_rank_info (rank = i ).cp_rank
241+ for i in range (role_group .world_size ):
242+ rank_info = role_group .get_megatron_rank_info ()[i ]
243+ local_dp_rank = rank_info .dp_rank
244+ local_pp_rank = rank_info .pp_rank
245+ local_cp_rank = rank_info .cp_rank
245246 # compute the rank in arg. Note that the order is dp then cp then pp
246247 # Also note that the outputs within a pp group will be firstly allgathered, then only the output of pp0 will be collected.
247248 # For pp=2 dp=4, a batch of data "ABCDEFGH" should be dispatched and collected in below order:
@@ -264,10 +265,11 @@ def dispatch_megatron_pp_as_dp(worker_group, *args, **kwargs):
264265 isinstance (v , (List , Tuple )) and len (v ) == pp_dp_cp_size
265266 ), f"expect len(v)=={ pp_dp_cp_size } , got { len (v )} "
266267 transformed_v = []
267- for i in range (worker_group .world_size ):
268- local_dp_rank = worker_group .get_megatron_rank_info (rank = i ).dp_rank
269- local_pp_rank = worker_group .get_megatron_rank_info (rank = i ).pp_rank
270- local_cp_rank = worker_group .get_megatron_rank_info (rank = i ).cp_rank
268+ for i in range (role_group .world_size ):
269+ rank_info = role_group .get_megatron_rank_info ()[i ]
270+ local_dp_rank = rank_info .dp_rank
271+ local_pp_rank = rank_info .pp_rank
272+ local_cp_rank = rank_info .cp_rank
271273 # compute the rank in arg. Note that the order is dp then cp then pp
272274 dp_cp_rank = local_cp_rank * dp_size + local_dp_rank
273275 arg_rank = dp_cp_rank * pp_size + local_pp_rank
@@ -276,94 +278,92 @@ def dispatch_megatron_pp_as_dp(worker_group, *args, **kwargs):
276278 return all_args , all_kwargs
277279
278280
279- def collect_megatron_pp_as_dp (worker_group , output ):
281+ def collect_megatron_pp_as_dp (role_group , output ):
280282 """
281283 treat pp as dp. Only collect data on tp=0
282284 """
283285 output_in_dp = []
284- for global_rank in range (worker_group .world_size ):
285- local_rank_info = worker_group .get_megatron_rank_info (rank = global_rank )
286+ for global_rank in range (role_group .world_size ):
287+ local_rank_info = role_group .get_megatron_rank_info ()[ global_rank ]
286288 if local_rank_info .tp_rank == 0 :
287289 output_in_dp .append (output [global_rank ])
288290 return output_in_dp
289291
290292
291- def collect_megatron_pp_only (worker_group , output ):
293+ def collect_megatron_pp_only (role_group , output ):
292294 """
293295 Only collect output of megatron pp. This is useful when examine weight names as they are identical in tp/dp
294296 """
295297 output_in_pp = []
296- for global_rank in range (worker_group .world_size ):
297- local_rank_info = worker_group .get_megatron_rank_info (rank = global_rank )
298+ for global_rank in range (role_group .world_size ):
299+ local_rank_info = role_group .get_megatron_rank_info ()[ global_rank ]
298300 if local_rank_info .tp_rank == 0 and local_rank_info .dp_rank == 0 :
299301 output_in_pp .append (output [global_rank ])
300302 return output_in_pp
301303
302304
303- def dispatch_megatron_pp_as_dp_data_proto (worker_group , * args , ** kwargs ):
305+ def dispatch_megatron_pp_as_dp_data_proto (role_group , * args , ** kwargs ):
304306 pp_dp_cp_size = (
305- worker_group .dp_size * worker_group .pp_size * worker_group .cp_size
307+ role_group .dp_size () * role_group .pp_size () * role_group .cp_size ()
306308 )
307309 splitted_args , splitted_kwargs = _split_args_kwargs_data_proto (
308310 pp_dp_cp_size , * args , ** kwargs
309311 )
310312 ret = dispatch_megatron_pp_as_dp (
311- worker_group , * splitted_args , ** splitted_kwargs
313+ role_group , * splitted_args , ** splitted_kwargs
312314 )
313315 return ret
314316
315317
316- def collect_megatron_pp_as_dp_data_proto (worker_group , output ):
317- output = collect_megatron_pp_as_dp (worker_group , output )
318+ def collect_megatron_pp_as_dp_data_proto (role_group , output ):
319+ output = collect_megatron_pp_as_dp (role_group , output )
318320 return _concat_data_proto_or_future (output )
319321
320322
321- def dispatch_dp_compute (worker_group , * args , ** kwargs ):
323+ def dispatch_dp_compute (role_group , * args , ** kwargs ):
322324 for arg in args :
323325 assert (
324326 isinstance (arg , (Tuple , List ))
325- and len (arg ) == worker_group .world_size
327+ and len (arg ) == role_group .world_size
326328 )
327329 for k , v in kwargs .items ():
328- assert (
329- isinstance (v , (Tuple , List )) and len (v ) == worker_group .world_size
330- )
330+ assert isinstance (v , (Tuple , List )) and len (v ) == role_group .world_size
331331 return args , kwargs
332332
333333
334- def collect_dp_compute (worker_group , output ):
335- assert len (output ) == worker_group .world_size
334+ def collect_dp_compute (role_group , output ):
335+ assert len (output ) == role_group .world_size
336336 return output
337337
338338
339- def dispatch_dp_compute_data_proto (worker_group , * args , ** kwargs ):
339+ def dispatch_dp_compute_data_proto (role_group , * args , ** kwargs ):
340340 # Note: enable auto padding for dp compute DatapProto
341341 (
342342 splitted_args ,
343343 splitted_kwargs ,
344344 ) = _split_args_kwargs_data_proto_with_auto_padding (
345- worker_group .world_size ,
345+ role_group .world_size ,
346346 * args ,
347347 ** kwargs ,
348348 )
349349 return splitted_args , splitted_kwargs
350350
351351
352- def dispatch_dp_compute_data_proto_with_func (worker_group , * args , ** kwargs ):
352+ def dispatch_dp_compute_data_proto_with_func (role_group , * args , ** kwargs ):
353353 assert isinstance (
354354 args [0 ], FunctionType
355355 ) # NOTE: The first one args is a function!
356356
357357 splitted_args , splitted_kwargs = _split_args_kwargs_data_proto (
358- worker_group .world_size , * args [1 :], ** kwargs
358+ role_group .world_size , * args [1 :], ** kwargs
359359 )
360360 splitted_args_with_func = [
361- [args [0 ]] * worker_group .world_size
361+ [args [0 ]] * role_group .world_size
362362 ] + splitted_args
363363 return splitted_args_with_func , splitted_kwargs
364364
365365
366- def collect_dp_compute_data_proto (worker_group , output ):
366+ def collect_dp_compute_data_proto (role_group , output ):
367367 import ray
368368 from verl .protocol import DataProto
369369
@@ -372,7 +372,7 @@ def collect_dp_compute_data_proto(worker_group, output):
372372 o , (DataProto , ray .ObjectRef )
373373 ), f"expecting { o } to be DataProto, but got { type (o )} "
374374
375- output = collect_dp_compute (worker_group , output )
375+ output = collect_dp_compute (role_group , output )
376376 return _concat_data_proto_or_future (output )
377377
378378
0 commit comments