@@ -96,8 +96,8 @@ def _get_attr(attr_name, default_value=None):
         return None

     default_transformer_cls_names_to_wrap = getattr(module, "_no_split_modules", None)
-    fsdp_transformer_layer_cls_to_wrap = _get_attr(
-        "transformer_layer_cls_to_wrap", default_transformer_cls_names_to_wrap
+    fsdp_transformer_layer_cls_to_wrap = _normalize_transformer_layer_cls_to_wrap(
+        _get_attr("transformer_layer_cls_to_wrap", default_transformer_cls_names_to_wrap)
     )
     min_num_params = _get_attr("min_num_params", 0)
     auto_wrap_policy = None
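
For context, the tuple produced by the normalizer is what the policy construction that follows in `get_fsdp_wrap_policy` consumes. A rough sketch of how a tuple of class names typically becomes a `transformer_auto_wrap_policy` (the name resolver shown, `get_module_class_from_name`, is the one transformers' Trainer uses; whether this file resolves names the same way is an assumption):

    import functools

    from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
    from transformers.trainer_pt_utils import get_module_class_from_name

    def build_transformer_policy(module, layer_cls_names):
        # Resolve each class name (e.g. "LlamaDecoderLayer") to the actual
        # nn.Module subclass defined somewhere inside `module`.
        layer_classes = set()
        for name in layer_cls_names:
            layer_cls = get_module_class_from_name(module, name)
            assert layer_cls is not None, f"could not resolve {name} to a module class"
            layer_classes.add(layer_cls)
        # FSDP invokes the policy once per submodule; partial() freezes the
        # class set so matching layers get wrapped as their own FSDP units.
        return functools.partial(transformer_auto_wrap_policy, transformer_layer_cls=layer_classes)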
@@ -143,6 +143,22 @@ def lambda_policy_fn(module):
     return auto_wrap_policy


+def _normalize_transformer_layer_cls_to_wrap(transformer_layer_cls_to_wrap):
+    """Normalize transformer layer class names to a reusable tuple."""
+    if transformer_layer_cls_to_wrap is None:
+        return None
+
+    if isinstance(transformer_layer_cls_to_wrap, str):
+        transformer_layer_cls_to_wrap = (transformer_layer_cls_to_wrap,)
+    else:
+        transformer_layer_cls_to_wrap = tuple(transformer_layer_cls_to_wrap)
+
+    if not transformer_layer_cls_to_wrap or any(layer_class is None for layer_class in transformer_layer_cls_to_wrap):
+        raise AssertionError("transformer_layer_cls_to_wrap must contain at least one non-None class name")
+
+    return transformer_layer_cls_to_wrap
+
+
 @torch.no_grad()
 def offload_fsdp_model_to_cpu(model: FSDP, empty_cache: bool = True):
     if fsdp_version(model) == 2 or fsdp_version(model) == 0:
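
The helper's contract, illustrated (behavior read straight off the function above; the class names are made up):

    # None passes through untouched; callers treat it as "no explicit wrap list".
    assert _normalize_transformer_layer_cls_to_wrap(None) is None

    # A bare string becomes a one-element tuple.
    assert _normalize_transformer_layer_cls_to_wrap("LlamaDecoderLayer") == ("LlamaDecoderLayer",)

    # Any other iterable is materialized as a tuple, preserving order.
    assert _normalize_transformer_layer_cls_to_wrap(["Qwen2DecoderLayer", "Qwen2MLP"]) == (
        "Qwen2DecoderLayer",
        "Qwen2MLP",
    )

    # Empty sequences and None entries fail fast instead of silently wrapping nothing.
    try:
        _normalize_transformer_layer_cls_to_wrap([None])
    except AssertionError as err:
        print(err)  # transformer_layer_cls_to_wrap must contain at least one non-None class name

Returning a tuple rather than a list also makes the result safe to share across call sites, since it cannot be mutated downstream.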
@@ -536,15 +552,11 @@ def apply_fsdp2(model, fsdp_kwargs, config):
     assert CPUOffloadPolicy is not None, "PyTorch version >= 2.4 is required for using fully_shard API (FSDP2)"

     default_transformer_cls_names_to_wrap = getattr(model, "_no_split_modules", None)
-    fsdp_transformer_layer_cls_to_wrap = config.get("wrap_policy", {}).get(
-        "transformer_layer_cls_to_wrap", default_transformer_cls_names_to_wrap
+    wrap_policy = config.get("wrap_policy", {}) or {}
+    fsdp_transformer_layer_cls_to_wrap = _normalize_transformer_layer_cls_to_wrap(
+        wrap_policy.get("transformer_layer_cls_to_wrap", default_transformer_cls_names_to_wrap)
     )

-    if isinstance(fsdp_transformer_layer_cls_to_wrap, str):
-        fsdp_transformer_layer_cls_to_wrap = [fsdp_transformer_layer_cls_to_wrap]
-
-    assert len(fsdp_transformer_layer_cls_to_wrap) > 0 and fsdp_transformer_layer_cls_to_wrap[0] is not None
-
     modules = _select_fsdp2_wrap_targets(model, fsdp_transformer_layer_cls_to_wrap)

     for idx, module in enumerate(modules):
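
One behavior change worth noting in `apply_fsdp2`: `config.get("wrap_policy", {})` only falls back to `{}` when the key is absent, so a config that explicitly carries `wrap_policy: null` (easy to produce from YAML) used to hand `None` to the chained `.get()` and raise `AttributeError`. The new `or {}` guard covers both cases. A made-up config pair illustrating the difference:

    config_missing = {}                   # no wrap_policy key at all
    config_null = {"wrap_policy": None}   # e.g. parsed from `wrap_policy: null`

    for config in (config_missing, config_null):
        wrap_policy = config.get("wrap_policy", {}) or {}  # {} in both cases
        # Without `or {}`, config_null leaves wrap_policy = None and the
        # following .get() raises AttributeError.
        print(wrap_policy.get("transformer_layer_cls_to_wrap", ["LlamaDecoderLayer"]))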