
Commit a323007

fix ci

Signed-off-by: Hollow Man <hollowman@opensuse.org>
1 parent ebf84f8 commit a323007

File tree

1 file changed: +21 −9 lines changed

verl/utils/fsdp_utils.py (21 additions & 9 deletions)
@@ -96,8 +96,8 @@ def _get_attr(attr_name, default_value=None):
         return None
 
     default_transformer_cls_names_to_wrap = getattr(module, "_no_split_modules", None)
-    fsdp_transformer_layer_cls_to_wrap = _get_attr(
-        "transformer_layer_cls_to_wrap", default_transformer_cls_names_to_wrap
+    fsdp_transformer_layer_cls_to_wrap = _normalize_transformer_layer_cls_to_wrap(
+        _get_attr("transformer_layer_cls_to_wrap", default_transformer_cls_names_to_wrap)
     )
     min_num_params = _get_attr("min_num_params", 0)
     auto_wrap_policy = None
@@ -143,6 +143,22 @@ def lambda_policy_fn(module):
     return auto_wrap_policy
 
 
+def _normalize_transformer_layer_cls_to_wrap(transformer_layer_cls_to_wrap):
+    """Normalize transformer layer class names to a reusable tuple."""
+    if transformer_layer_cls_to_wrap is None:
+        return None
+
+    if isinstance(transformer_layer_cls_to_wrap, str):
+        transformer_layer_cls_to_wrap = (transformer_layer_cls_to_wrap,)
+    else:
+        transformer_layer_cls_to_wrap = tuple(transformer_layer_cls_to_wrap)
+
+    if not transformer_layer_cls_to_wrap or any(layer_class is None for layer_class in transformer_layer_cls_to_wrap):
+        raise AssertionError("transformer_layer_cls_to_wrap must contain at least one non-None class name")
+
+    return transformer_layer_cls_to_wrap
+
+
 @torch.no_grad()
 def offload_fsdp_model_to_cpu(model: FSDP, empty_cache: bool = True):
     if fsdp_version(model) == 2 or fsdp_version(model) == 0:
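For reference, a minimal sketch of how the new helper behaves on each input shape it accepts; the layer class names below are made-up placeholders, not identifiers from this repo:

    # Hypothetical calls to the helper introduced in this commit
    _normalize_transformer_layer_cls_to_wrap(None)                  # -> None; caller falls through to other wrap policies
    _normalize_transformer_layer_cls_to_wrap("DecoderLayer")        # -> ("DecoderLayer",)
    _normalize_transformer_layer_cls_to_wrap(["BlockA", "BlockB"])  # -> ("BlockA", "BlockB")
    _normalize_transformer_layer_cls_to_wrap([])                    # raises AssertionError
    _normalize_transformer_layer_cls_to_wrap([None])                # raises AssertionError

Returning a tuple rather than a list also makes the result hashable and safe to reuse across call sites, which is what both changed callers rely on.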
@@ -536,15 +552,11 @@ def apply_fsdp2(model, fsdp_kwargs, config):
     assert CPUOffloadPolicy is not None, "PyTorch version >= 2.4 is required for using fully_shard API (FSDP2)"
 
     default_transformer_cls_names_to_wrap = getattr(model, "_no_split_modules", None)
-    fsdp_transformer_layer_cls_to_wrap = config.get("wrap_policy", {}).get(
-        "transformer_layer_cls_to_wrap", default_transformer_cls_names_to_wrap
+    wrap_policy = config.get("wrap_policy", {}) or {}
+    fsdp_transformer_layer_cls_to_wrap = _normalize_transformer_layer_cls_to_wrap(
+        wrap_policy.get("transformer_layer_cls_to_wrap", default_transformer_cls_names_to_wrap)
     )
 
-    if isinstance(fsdp_transformer_layer_cls_to_wrap, str):
-        fsdp_transformer_layer_cls_to_wrap = [fsdp_transformer_layer_cls_to_wrap]
-
-    assert len(fsdp_transformer_layer_cls_to_wrap) > 0 and fsdp_transformer_layer_cls_to_wrap[0] is not None
-
     modules = _select_fsdp2_wrap_targets(model, fsdp_transformer_layer_cls_to_wrap)
 
     for idx, module in enumerate(modules):
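The `or {}` guard matters when the config explicitly sets wrap_policy to null instead of omitting it. A small sketch, assuming a plain-dict config as this function receives:

    # Hypothetical config illustrating why the guard is needed
    config = {"wrap_policy": None}
    config.get("wrap_policy", {})        # -> None; the default only applies to a *missing* key
    config.get("wrap_policy", {}) or {}  # -> {}; the chained .get("transformer_layer_cls_to_wrap", ...) is now safe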
