From 7845d6393075b1f18f6e0f37067554b61929e783 Mon Sep 17 00:00:00 2001
From: Wennie396
Date: Mon, 8 Sep 2025 15:01:29 +0800
Subject: [PATCH 1/4] add chunk offload optimizer

---
 paddlenlp/trainer/training_args.py           | 14 ++++++++++++++
 paddlenlp/trainer/utils/offload_optimizer.py | 12 +++++++++---
 2 files changed, 23 insertions(+), 3 deletions(-)

diff --git a/paddlenlp/trainer/training_args.py b/paddlenlp/trainer/training_args.py
index ec005732ac55..58af44a2f907 100644
--- a/paddlenlp/trainer/training_args.py
+++ b/paddlenlp/trainer/training_args.py
@@ -614,6 +614,16 @@ class TrainingArguments:
             )
         },
     )
+    sharding_offload_opt_buffersize_GB: int = field(
+        default=-1,
+        metadata={
+            "help": (
+                "Set the size of the optimizer offload buffer when need_hack_offload_optimizer() is True. This option only takes effect when "
+                "using DygraphShardingOptimizerV2. The default value is -1, which means that all of the optimizer states will be offloaded. Only "
+                "works when HACK_OFFLOAD_OPTIMIZER=1 is exported. "
+            )
+        },
+    )
 
     save_sharded_model: bool = field(
         default=False,
@@ -1531,6 +1541,10 @@ def is_context_parallel_supported():
                         self.sharding_comm_buffer_size_MB
                     )
 
+                strategy.hybrid_configs["sharding_configs"].offload_opt_buffer_size = int(
+                    self.sharding_offload_opt_buffersize_GB
+                )
+
                 if "split_param" in sharding_parallel_config:
                     strategy.hybrid_configs["sharding_configs"].split_param = True
                     assert self.amp_master_grad, "Currently sharding stage1 v2 only support amp_master_grad"
diff --git a/paddlenlp/trainer/utils/offload_optimizer.py b/paddlenlp/trainer/utils/offload_optimizer.py
index 65f5b77e2e5d..f20066f1e29b 100644
--- a/paddlenlp/trainer/utils/offload_optimizer.py
+++ b/paddlenlp/trainer/utils/offload_optimizer.py
@@ -58,9 +58,11 @@ def new_opt_op(*args):
                     reload(arg)
 
             ret = origin_op(*args)
-
+            is_offload_opt = getattr(args[0], "is_offload_opt", False)
             for i, arg in enumerate(args):
-                if i >= 2 and isinstance(arg, paddle.Tensor):  # do not offload parameter and gradient
+                if (
+                    i >= 2 and isinstance(arg, paddle.Tensor) and is_offload_opt
+                ):  # do not offload parameter and gradient
                     offload(arg)
             return ret
 
@@ -74,7 +76,11 @@ def new_insert_sync(self, sync_var, *args, **kwargs):
         origin_place = sync_var.place
         reload(sync_var)
         ret = origin_insert_sync(self, sync_var, *args, **kwargs)
-        new_sync_var = to_device(sync_var, origin_place)
+        is_offload_opt = getattr(sync_var, "is_offload_opt", False)
+        if is_offload_opt:
+            new_sync_var = to_device(sync_var, origin_place)
+        else:
+            new_sync_var = sync_var
         assert new_sync_var is sync_var, "to_device must be inplace operation"
         return ret
 

From d7155c53a2708d869d0e3900dc4a6834c943872e Mon Sep 17 00:00:00 2001
From: Wennie396
Date: Thu, 11 Sep 2025 11:44:51 +0800
Subject: [PATCH 2/4] fix get offload_opt_buffer_size arg

---
 paddlenlp/trainer/training_args.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/paddlenlp/trainer/training_args.py b/paddlenlp/trainer/training_args.py
index 58af44a2f907..c30fb268664e 100644
--- a/paddlenlp/trainer/training_args.py
+++ b/paddlenlp/trainer/training_args.py
@@ -1541,9 +1541,10 @@ def is_context_parallel_supported():
                         self.sharding_comm_buffer_size_MB
                     )
 
-                strategy.hybrid_configs["sharding_configs"].offload_opt_buffer_size = int(
-                    self.sharding_offload_opt_buffersize_GB
-                )
+                if getattr(strategy.hybrid_configs["sharding_configs"], "offload_opt_buffer_size", None):
+                    strategy.hybrid_configs["sharding_configs"].offload_opt_buffer_size = int(
+                        self.sharding_offload_opt_buffersize_GB
+                    )
 
                 if "split_param" in sharding_parallel_config:
                     strategy.hybrid_configs["sharding_configs"].split_param = True

From 1bc5e8b973c0802de08a485d1d0ef0b8ae397ced Mon Sep 17 00:00:00 2001
From: Wennie396
Date: Thu, 11 Sep 2025 14:18:01 +0800
Subject: [PATCH 3/4] fix

---
 paddlenlp/trainer/training_args.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/paddlenlp/trainer/training_args.py b/paddlenlp/trainer/training_args.py
index c30fb268664e..496b92926084 100644
--- a/paddlenlp/trainer/training_args.py
+++ b/paddlenlp/trainer/training_args.py
@@ -1541,7 +1541,7 @@ def is_context_parallel_supported():
                         self.sharding_comm_buffer_size_MB
                     )
 
-                if getattr(strategy.hybrid_configs["sharding_configs"], "offload_opt_buffer_size", None):
+                if hasattr(strategy.hybrid_configs["sharding_configs"], "offload_opt_buffer_size"):
                     strategy.hybrid_configs["sharding_configs"].offload_opt_buffer_size = int(
                         self.sharding_offload_opt_buffersize_GB
                     )

From 0b2a0dc856f96ec1368e57db811cca715102202e Mon Sep 17 00:00:00 2001
From: Wennie396
Date: Thu, 11 Sep 2025 19:49:13 +0800
Subject: [PATCH 4/4] fix

---
 paddlenlp/trainer/training_args.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/paddlenlp/trainer/training_args.py b/paddlenlp/trainer/training_args.py
index 496b92926084..88da91adced6 100644
--- a/paddlenlp/trainer/training_args.py
+++ b/paddlenlp/trainer/training_args.py
@@ -1646,6 +1646,7 @@ def is_context_parallel_supported():
                 self.sharding_parallel_degree
                 * self.tensor_parallel_degree
                 * self.sep_parallel_degree
+                * self.context_parallel_degree
                 * self.pipeline_parallel_degree
             )
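
The gating introduced in offload_optimizer.py reads an is_offload_opt flag from the parameter (args[0]) and offloads only the optimizer states (args[2:]) when that flag is set; parameters and gradients always stay on device, and _insert_sync likewise only moves the synced variable back to its original place when the flag is present. Below is a minimal, self-contained sketch of that pattern, not part of the patches: offload() is a simplified stand-in for the real helper (which copies the tensor to pinned host memory in place), and setting the flag by hand stands in for whatever DygraphShardingOptimizerV2 does for states covered by the offload buffer.

import paddle


def offload(tensor):
    # Simplified stand-in for paddlenlp.trainer.utils.offload_optimizer.offload,
    # which moves the tensor to pinned host memory in place.
    print("offloading tensor of shape", tensor.shape)


def gated_offload_after_opt(*args):
    # Mirrors the change to new_opt_op(): read the flag from the parameter (args[0])
    # and offload only the optimizer states (index >= 2) when it is set.
    is_offload_opt = getattr(args[0], "is_offload_opt", False)
    for i, arg in enumerate(args):
        if i >= 2 and isinstance(arg, paddle.Tensor) and is_offload_opt:
            offload(arg)


param = paddle.zeros([8])    # args[0]: parameter, never offloaded
grad = paddle.zeros([8])     # args[1]: gradient, never offloaded
moment = paddle.zeros([8])   # args[2:]: optimizer state
param.is_offload_opt = True  # set manually here; assumed to be set by the sharding optimizer in real runs
gated_offload_after_opt(param, grad, moment)  # offloads `moment` only

The value of sharding_offload_opt_buffersize_GB is forwarded to sharding_configs.offload_opt_buffer_size only when the installed Paddle exposes that attribute (the hasattr check from PATCH 3/4); per the help text, the default of -1 offloads all optimizer states, and the feature requires HACK_OFFLOAD_OPTIMIZER=1 to be exported.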