Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions paddlenlp/trainer/training_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -614,6 +614,16 @@ class TrainingArguments:
)
},
)
sharding_offload_opt_buffersize_GB: int = field(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个参数以 GB 为单位,但同时看到其他一些 size 参数是以 MB 为单位的(例如 sharding_comm_buffer_size_MB)。

目前框架中支持以 GB 为单位吗?请确认是否有必要支持以 MB 为单位的配置;后续也可以将这个参数改为 float 类型,以支持更精细的配置。

default=-1,
metadata={
"help": (
"Set the size of the optimizer offload buffer when need_hack_offload_optimizer() is True. This option only takes effect when "
"use DygraphShardingOptimizerV2. The default value is -1, which means that all of the optimizer states will be offloaded. Only "
"works when export HACK_OFFLOAD_OPTIMIZER=1. "
)
},
)

save_sharded_model: bool = field(
default=False,
Expand Down Expand Up @@ -1531,6 +1541,11 @@ def is_context_parallel_supported():
self.sharding_comm_buffer_size_MB
)

if hasattr(strategy.hybrid_configs["sharding_configs"], "offload_opt_buffer_size"):
strategy.hybrid_configs["sharding_configs"].offload_opt_buffer_size = int(
self.sharding_offload_opt_buffersize_GB
)

if "split_param" in sharding_parallel_config:
strategy.hybrid_configs["sharding_configs"].split_param = True
assert self.amp_master_grad, "Currently sharding stage1 v2 only support amp_master_grad"
Expand Down Expand Up @@ -1631,6 +1646,7 @@ def is_context_parallel_supported():
self.sharding_parallel_degree
* self.tensor_parallel_degree
* self.sep_parallel_degree
* self.context_parallel_degree
* self.pipeline_parallel_degree
)

Expand Down
12 changes: 9 additions & 3 deletions paddlenlp/trainer/utils/offload_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,11 @@ def new_opt_op(*args):
reload(arg)

ret = origin_op(*args)

is_offload_opt = getattr(args[0], "is_offload_opt", False)
for i, arg in enumerate(args):
if i >= 2 and isinstance(arg, paddle.Tensor): # do not offload parameter and gradient
if (
i >= 2 and isinstance(arg, paddle.Tensor) and is_offload_opt
): # do not offload parameter and gradient
offload(arg)
return ret

Expand All @@ -74,7 +76,11 @@ def new_insert_sync(self, sync_var, *args, **kwargs):
origin_place = sync_var.place
reload(sync_var)
ret = origin_insert_sync(self, sync_var, *args, **kwargs)
new_sync_var = to_device(sync_var, origin_place)
is_offload_opt = getattr(sync_var, "is_offload_opt", False)
if is_offload_opt:
new_sync_var = to_device(sync_var, origin_place)
else:
new_sync_var = sync_var
assert new_sync_var is sync_var, "to_device must be inplace operation"
return ret

Expand Down
Loading