Skip to content

Commit 06ccf5c

Browse files
committed
doc updated
1 parent 8b5acb2 commit 06ccf5c

26 files changed

Lines changed: 422 additions & 399 deletions

dlrover/python/rl/trainer/example/realhf/__init__.py

Lines changed: 0 additions & 12 deletions
This file was deleted.

dlrover/python/rl/trainer/example/verl/base/decorator.py

Lines changed: 59 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -113,61 +113,61 @@ def _split_args_kwargs_data_proto_with_auto_padding(chunks, *args, **kwargs):
113113
return splitted_args, splitted_kwargs
114114

115115

116-
def dispatch_one_to_all(worker_group, *args, **kwargs):
117-
args = tuple([arg] * worker_group.world_size for arg in args)
118-
kwargs = {k: [v] * worker_group.world_size for k, v in kwargs.items()}
116+
def dispatch_one_to_all(role_group, *args, **kwargs):
117+
args = tuple([arg] * role_group.world_size for arg in args)
118+
kwargs = {k: [v] * role_group.world_size for k, v in kwargs.items()}
119119
return args, kwargs
120120

121121

122-
def dummy_direct_rollout_call(worker_group, *args, **kwargs):
122+
def dummy_direct_rollout_call(role_group, *args, **kwargs):
123123
raise NotImplementedError("Direct rollout call is forbidden.")
124124

125125

126-
def dispatch_all_to_all(worker_group, *args, **kwargs):
126+
def dispatch_all_to_all(role_group, *args, **kwargs):
127127
return args, kwargs
128128

129129

130-
def collect_all_to_all(worker_group, output):
130+
def collect_all_to_all(role_group, output):
131131
return output
132132

133133

134-
def dispatch_megatron_compute(worker_group, *args, **kwargs):
134+
def dispatch_megatron_compute(role_group, *args, **kwargs):
135135
"""
136136
User passes in dp data. The data is dispatched to all tp/pp ranks that share the same dp rank.
137137
"""
138138

139139
all_args = []
140140
for arg in args:
141141
assert (
142-
isinstance(arg, (Tuple, List)) and len(arg) == worker_group.dp_size
142+
isinstance(arg, (Tuple, List)) and len(arg) == role_group.dp_size()
143143
)
144144
transformed_args = []
145-
for i in range(worker_group.world_size):
146-
local_dp_rank = worker_group.get_megatron_rank_info(rank=i).dp_rank
145+
for i in range(role_group.world_size):
146+
local_dp_rank = role_group.get_megatron_rank_info()[i].dp_rank
147147
transformed_args.append(arg[local_dp_rank])
148148
all_args.append(transformed_args)
149149
all_args = tuple(all_args)
150150

151151
all_kwargs = {}
152152
for k, v in kwargs.items():
153-
assert isinstance(v, (Tuple, List)) and len(v) == worker_group.dp_size
153+
assert isinstance(v, (Tuple, List)) and len(v) == role_group.dp_size()
154154
transformed_v = []
155-
for i in range(worker_group.world_size):
156-
local_dp_rank = worker_group.get_megatron_rank_info(rank=i).dp_rank
155+
for i in range(role_group.world_size):
156+
local_dp_rank = role_group.get_megatron_rank_info()[i].dp_rank
157157
transformed_v.append(v[local_dp_rank])
158158
all_kwargs[k] = transformed_v
159159
return all_args, all_kwargs
160160

161161

162-
def collect_megatron_compute(worker_group, output):
162+
def collect_megatron_compute(role_group, output):
163163
"""
164164
Only collect the data from the ranks with tp=0 and pp=last, for every dp rank
165165
"""
166166

167167
output_in_dp = []
168-
pp_size = worker_group.get_megatron_global_info().pp_size
169-
for global_rank in range(worker_group.world_size):
170-
local_rank_info = worker_group.get_megatron_rank_info(rank=global_rank)
168+
pp_size = role_group.get_megatron_global_info().pp_size
169+
for global_rank in range(role_group.world_size):
170+
local_rank_info = role_group.get_megatron_rank_info()[global_rank]
171171
if (
172172
local_rank_info.tp_rank == 0
173173
and local_rank_info.pp_rank == pp_size - 1
@@ -177,16 +177,16 @@ def collect_megatron_compute(worker_group, output):
177177
return output_in_dp
178178

179179

180-
def dispatch_megatron_compute_data_proto(worker_group, *args, **kwargs):
180+
def dispatch_megatron_compute_data_proto(role_group, *args, **kwargs):
181181
"""
182182
All the args and kwargs must be DataProto. The batch will be chunked by dp_size and passed to each rank
183183
"""
184184

185185
splitted_args, splitted_kwargs = _split_args_kwargs_data_proto(
186-
worker_group.dp_size, *args, **kwargs
186+
role_group.dp_size(), *args, **kwargs
187187
)
188188
return dispatch_megatron_compute(
189-
worker_group, *splitted_args, **splitted_kwargs
189+
role_group, *splitted_args, **splitted_kwargs
190190
)
191191

192192

@@ -208,14 +208,14 @@ def _concat_data_proto_or_future(output: List):
208208
raise NotImplementedError
209209

210210

211-
def collect_megatron_compute_data_proto(worker_group, output):
211+
def collect_megatron_compute_data_proto(role_group, output):
212212
"""
213213
Each output must be a DataProto. We concat the dim=0 of output
214214
"""
215215
import ray
216216
from verl.protocol import DataProto
217217

218-
output = collect_megatron_compute(worker_group, output)
218+
output = collect_megatron_compute(role_group, output)
219219
for o in output:
220220
assert isinstance(
221221
o, (DataProto, ray.ObjectRef)
@@ -224,24 +224,25 @@ def collect_megatron_compute_data_proto(worker_group, output):
224224
return _concat_data_proto_or_future(output)
225225

226226

227-
def dispatch_megatron_pp_as_dp(worker_group, *args, **kwargs):
227+
def dispatch_megatron_pp_as_dp(role_group, *args, **kwargs):
228228
"""
229229
treat pp as dp.
230230
"""
231231

232-
pp_size = worker_group.pp_size
233-
dp_size = worker_group.dp_size
234-
cp_size = worker_group.cp_size
232+
pp_size = role_group.pp_size()
233+
dp_size = role_group.dp_size()
234+
cp_size = role_group.cp_size()
235235
pp_dp_cp_size = pp_size * dp_size * cp_size
236236

237237
all_args = []
238238
for arg in args:
239239
assert isinstance(arg, (List, Tuple)) and len(arg) == pp_dp_cp_size
240240
transformed_args = []
241-
for i in range(worker_group.world_size):
242-
local_dp_rank = worker_group.get_megatron_rank_info(rank=i).dp_rank
243-
local_pp_rank = worker_group.get_megatron_rank_info(rank=i).pp_rank
244-
local_cp_rank = worker_group.get_megatron_rank_info(rank=i).cp_rank
241+
for i in range(role_group.world_size):
242+
rank_info = role_group.get_megatron_rank_info()[i]
243+
local_dp_rank = rank_info.dp_rank
244+
local_pp_rank = rank_info.pp_rank
245+
local_cp_rank = rank_info.cp_rank
245246
# compute the rank in arg. Note that the order is dp then cp then pp
246247
# Also note that the outputs within a pp group are first all-gathered, and then only the output of pp0 is collected.
247248
# For pp=2 dp=4, a batch of data "ABCDEFGH" should be dispatched and collected in below order:
@@ -264,10 +265,11 @@ def dispatch_megatron_pp_as_dp(worker_group, *args, **kwargs):
264265
isinstance(v, (List, Tuple)) and len(v) == pp_dp_cp_size
265266
), f"expect len(v)=={pp_dp_cp_size}, got {len(v)}"
266267
transformed_v = []
267-
for i in range(worker_group.world_size):
268-
local_dp_rank = worker_group.get_megatron_rank_info(rank=i).dp_rank
269-
local_pp_rank = worker_group.get_megatron_rank_info(rank=i).pp_rank
270-
local_cp_rank = worker_group.get_megatron_rank_info(rank=i).cp_rank
268+
for i in range(role_group.world_size):
269+
rank_info = role_group.get_megatron_rank_info()[i]
270+
local_dp_rank = rank_info.dp_rank
271+
local_pp_rank = rank_info.pp_rank
272+
local_cp_rank = rank_info.cp_rank
271273
# compute the rank in arg. Note that the order is dp then cp then pp
272274
dp_cp_rank = local_cp_rank * dp_size + local_dp_rank
273275
arg_rank = dp_cp_rank * pp_size + local_pp_rank
@@ -276,94 +278,92 @@ def dispatch_megatron_pp_as_dp(worker_group, *args, **kwargs):
276278
return all_args, all_kwargs
277279

278280

279-
def collect_megatron_pp_as_dp(worker_group, output):
281+
def collect_megatron_pp_as_dp(role_group, output):
280282
"""
281283
treat pp as dp. Only collect data on tp=0
282284
"""
283285
output_in_dp = []
284-
for global_rank in range(worker_group.world_size):
285-
local_rank_info = worker_group.get_megatron_rank_info(rank=global_rank)
286+
for global_rank in range(role_group.world_size):
287+
local_rank_info = role_group.get_megatron_rank_info()[global_rank]
286288
if local_rank_info.tp_rank == 0:
287289
output_in_dp.append(output[global_rank])
288290
return output_in_dp
289291

290292

291-
def collect_megatron_pp_only(worker_group, output):
293+
def collect_megatron_pp_only(role_group, output):
292294
"""
293295
Only collect the output of megatron pp. This is useful when examining weight names, as they are identical in tp/dp
294296
"""
295297
output_in_pp = []
296-
for global_rank in range(worker_group.world_size):
297-
local_rank_info = worker_group.get_megatron_rank_info(rank=global_rank)
298+
for global_rank in range(role_group.world_size):
299+
local_rank_info = role_group.get_megatron_rank_info()[global_rank]
298300
if local_rank_info.tp_rank == 0 and local_rank_info.dp_rank == 0:
299301
output_in_pp.append(output[global_rank])
300302
return output_in_pp
301303

302304

303-
def dispatch_megatron_pp_as_dp_data_proto(worker_group, *args, **kwargs):
305+
def dispatch_megatron_pp_as_dp_data_proto(role_group, *args, **kwargs):
304306
pp_dp_cp_size = (
305-
worker_group.dp_size * worker_group.pp_size * worker_group.cp_size
307+
role_group.dp_size() * role_group.pp_size() * role_group.cp_size()
306308
)
307309
splitted_args, splitted_kwargs = _split_args_kwargs_data_proto(
308310
pp_dp_cp_size, *args, **kwargs
309311
)
310312
ret = dispatch_megatron_pp_as_dp(
311-
worker_group, *splitted_args, **splitted_kwargs
313+
role_group, *splitted_args, **splitted_kwargs
312314
)
313315
return ret
314316

315317

316-
def collect_megatron_pp_as_dp_data_proto(worker_group, output):
317-
output = collect_megatron_pp_as_dp(worker_group, output)
318+
def collect_megatron_pp_as_dp_data_proto(role_group, output):
319+
output = collect_megatron_pp_as_dp(role_group, output)
318320
return _concat_data_proto_or_future(output)
319321

320322

321-
def dispatch_dp_compute(worker_group, *args, **kwargs):
323+
def dispatch_dp_compute(role_group, *args, **kwargs):
322324
for arg in args:
323325
assert (
324326
isinstance(arg, (Tuple, List))
325-
and len(arg) == worker_group.world_size
327+
and len(arg) == role_group.world_size
326328
)
327329
for k, v in kwargs.items():
328-
assert (
329-
isinstance(v, (Tuple, List)) and len(v) == worker_group.world_size
330-
)
330+
assert isinstance(v, (Tuple, List)) and len(v) == role_group.world_size
331331
return args, kwargs
332332

333333

334-
def collect_dp_compute(worker_group, output):
335-
assert len(output) == worker_group.world_size
334+
def collect_dp_compute(role_group, output):
335+
assert len(output) == role_group.world_size
336336
return output
337337

338338

339-
def dispatch_dp_compute_data_proto(worker_group, *args, **kwargs):
339+
def dispatch_dp_compute_data_proto(role_group, *args, **kwargs):
340340
# Note: enable auto padding for dp compute DataProto
341341
(
342342
splitted_args,
343343
splitted_kwargs,
344344
) = _split_args_kwargs_data_proto_with_auto_padding(
345-
worker_group.world_size,
345+
role_group.world_size,
346346
*args,
347347
**kwargs,
348348
)
349349
return splitted_args, splitted_kwargs
350350

351351

352-
def dispatch_dp_compute_data_proto_with_func(worker_group, *args, **kwargs):
352+
def dispatch_dp_compute_data_proto_with_func(role_group, *args, **kwargs):
353353
assert isinstance(
354354
args[0], FunctionType
355355
) # NOTE: The first arg is a function!
356356

357357
splitted_args, splitted_kwargs = _split_args_kwargs_data_proto(
358-
worker_group.world_size, *args[1:], **kwargs
358+
role_group.world_size, *args[1:], **kwargs
359359
)
360360
splitted_args_with_func = [
361-
[args[0]] * worker_group.world_size
361+
[args[0]] * role_group.world_size
362362
] + splitted_args
363363
return splitted_args_with_func, splitted_kwargs
364364

365365

366-
def collect_dp_compute_data_proto(worker_group, output):
366+
def collect_dp_compute_data_proto(role_group, output):
367367
import ray
368368
from verl.protocol import DataProto
369369

@@ -372,7 +372,7 @@ def collect_dp_compute_data_proto(worker_group, output):
372372
o, (DataProto, ray.ObjectRef)
373373
), f"expecting {o} to be DataProto, but got {type(o)}"
374374

375-
output = collect_dp_compute(worker_group, output)
375+
output = collect_dp_compute(role_group, output)
376376
return _concat_data_proto_or_future(output)
377377

378378

dlrover/python/rl/trainer/example/verl/base/worker.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,7 @@ def get_megatron_global_info(self):
183183
)
184184
return info
185185

186+
@trainer_invocation()
186187
def get_megatron_rank_info(self):
187188
from megatron.core import parallel_state as mpu
188189

0 commit comments

Comments
 (0)