From 62dc7d5d68f6209eb7fd305120bd71f9e247d6f5 Mon Sep 17 00:00:00 2001 From: HydrogenSulfate <490868991@qq.com> Date: Tue, 26 Aug 2025 15:10:59 +0800 Subject: [PATCH 1/4] support gradient accumulation --- deepmd/pd/train/training.py | 36 +++++++++++++++++++------------- source/tests/pd/test_training.py | 20 ++++++++++++++++++ 2 files changed, 41 insertions(+), 15 deletions(-) diff --git a/deepmd/pd/train/training.py b/deepmd/pd/train/training.py index c85e67a362..bf01f7b610 100644 --- a/deepmd/pd/train/training.py +++ b/deepmd/pd/train/training.py @@ -133,6 +133,9 @@ def __init__( # Iteration config self.num_steps = training_params["numb_steps"] + self.acc_freq: int = training_params.get( + "acc_freq", 1 + ) # gradient accumulation steps self.disp_file = training_params.get("disp_file", "lcurve.out") self.disp_freq = training_params.get("disp_freq", 1000) self.save_ckpt = training_params.get("save_ckpt", "model.ckpt") @@ -744,7 +747,6 @@ def step(_step_id, task_key="Default") -> None: _lr = self.lr_exp cur_lr = _lr.value(_step_id) pref_lr = cur_lr - self.optimizer.clear_grad(set_to_zero=False) with nvprof_context(enable_profiling, "Fetching data"): input_dict, label_dict, log_dict = self.get_data( @@ -780,22 +782,26 @@ def step(_step_id, task_key="Default") -> None: with nvprof_context(enable_profiling, "Backward pass"): loss.backward() - # fuse + allreduce manually before optimization if use DDP + no_sync - # details in https://github.com/PaddlePaddle/Paddle/issues/48898#issuecomment-1343838622 - if self.world_size > 1: - hpu.fused_allreduce_gradients(list(self.wrapper.parameters()), None) - - if self.gradient_max_norm > 0.0: - with nvprof_context(enable_profiling, "Gradient clip"): - paddle.nn.utils.clip_grad_norm_( - self.wrapper.parameters(), - self.gradient_max_norm, - error_if_nonfinite=True, + # gradient accumulation + if _step_id % self.acc_freq == 0: + # fuse + allreduce manually before optimization if use DDP + no_sync + # details in https://github.com/PaddlePaddle/Paddle/issues/48898#issuecomment-1343838622 + if self.world_size > 1: + hpu.fused_allreduce_gradients( + list(self.wrapper.parameters()), None ) - with nvprof_context(enable_profiling, "Adam update"): - self.optimizer.step() - self.scheduler.step() + if self.gradient_max_norm > 0.0: + with nvprof_context(enable_profiling, "Gradient clip"): + paddle.nn.utils.clip_grad_norm_( + self.wrapper.parameters(), + self.gradient_max_norm, + error_if_nonfinite=True, + ) + with nvprof_context(enable_profiling, "Adam update"): + self.optimizer.step() + self.optimizer.clear_grad(set_to_zero=False) + self.scheduler.step() else: raise ValueError(f"Not supported optimizer type '{self.opt_type}'") diff --git a/source/tests/pd/test_training.py b/source/tests/pd/test_training.py index 8958dcb165..1f547ffe8b 100644 --- a/source/tests/pd/test_training.py +++ b/source/tests/pd/test_training.py @@ -158,6 +158,26 @@ def tearDown(self) -> None: DPTrainTest.tearDown(self) +class TestEnergyModelGradientAccumulation(unittest.TestCase, DPTrainTest): + def setUp(self) -> None: + input_json = str(Path(__file__).parent / "water/se_atten.json") + with open(input_json) as f: + self.config = json.load(f) + data_file = [str(Path(__file__).parent / "water/data/data_0")] + self.config["training"]["training_data"]["systems"] = data_file + self.config["training"]["validation_data"]["systems"] = data_file + self.config["model"] = deepcopy(model_se_e2_a) + self.config["training"]["numb_steps"] = 1 + self.config["training"]["save_freq"] = 1 + self.config["training"]["acc_freq"] = 4 + # import paddle + enable_prim(True) + # assert paddle.framework.core._is_eager_prim_enabled() + + def tearDown(self) -> None: + DPTrainTest.tearDown(self) + + class TestFparam(unittest.TestCase, DPTrainTest): """Test if `fparam` can be loaded correctly.""" From ed992ad45ebdee91f365bcad0cbbc9a362db1c13 Mon Sep 17 00:00:00 2001 From: HydrogenSulfate <490868991@qq.com> Date: Tue, 26 Aug 2025 15:22:06 +0800 Subject: [PATCH 2/4] fix --- deepmd/pd/train/training.py | 3 ++- source/tests/pd/test_training.py | 4 ---- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/deepmd/pd/train/training.py b/deepmd/pd/train/training.py index bf01f7b610..4e5fea081f 100644 --- a/deepmd/pd/train/training.py +++ b/deepmd/pd/train/training.py @@ -783,7 +783,7 @@ def step(_step_id, task_key="Default") -> None: loss.backward() # gradient accumulation - if _step_id % self.acc_freq == 0: + if (_step_id + 1) % self.acc_freq == 0: # fuse + allreduce manually before optimization if use DDP + no_sync # details in https://github.com/PaddlePaddle/Paddle/issues/48898#issuecomment-1343838622 if self.world_size > 1: @@ -798,6 +798,7 @@ def step(_step_id, task_key="Default") -> None: self.gradient_max_norm, error_if_nonfinite=True, ) + with nvprof_context(enable_profiling, "Adam update"): self.optimizer.step() self.optimizer.clear_grad(set_to_zero=False) diff --git a/source/tests/pd/test_training.py b/source/tests/pd/test_training.py index 1f547ffe8b..0dc36fa314 100644 --- a/source/tests/pd/test_training.py +++ b/source/tests/pd/test_training.py @@ -150,9 +150,7 @@ def setUp(self) -> None: self.config["model"] = deepcopy(model_se_e2_a) self.config["training"]["numb_steps"] = 1 self.config["training"]["save_freq"] = 1 - # import paddle enable_prim(True) - # assert paddle.framework.core._is_eager_prim_enabled() def tearDown(self) -> None: DPTrainTest.tearDown(self) @@ -170,9 +168,7 @@ def setUp(self) -> None: self.config["training"]["numb_steps"] = 1 self.config["training"]["save_freq"] = 1 self.config["training"]["acc_freq"] = 4 - # import paddle enable_prim(True) - # assert paddle.framework.core._is_eager_prim_enabled() def tearDown(self) -> None: DPTrainTest.tearDown(self) From d272b50251802f38f1e1d59982b53ed5ca3b0077 Mon Sep 17 00:00:00 2001 From: HydrogenSulfate <490868991@qq.com> Date: Wed, 27 Aug 2025 10:38:16 +0800 Subject: [PATCH 3/4] add acc_freq doc --- deepmd/utils/argcheck.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/deepmd/utils/argcheck.py b/deepmd/utils/argcheck.py index e446674db7..41a37053a2 100644 --- a/deepmd/utils/argcheck.py +++ b/deepmd/utils/argcheck.py @@ -40,6 +40,7 @@ doc_only_tf_supported = "(Supported Backend: TensorFlow) " doc_only_pt_supported = "(Supported Backend: PyTorch) " +doc_only_pd_supported = "(Supported Backend: Paddle) " # descriptors doc_loc_frame = "Defines a local frame at each atom, and the compute the descriptor as local coordinates under this frame." doc_se_e2_a = "Used by the smooth edition of Deep Potential. The full relative coordinates are used to construct the descriptor." @@ -3167,6 +3168,7 @@ def training_args( doc_kf_blocksize = "The blocksize for the Kalman filter." doc_model_prob = "The visiting probability of each model for each training step in the multi-task mode." doc_data_dict = "The multiple definition of the data, used in the multi-task mode." + doc_acc_freq = "The accumulation steps for the gradients." arg_training_data = training_data_args() arg_validation_data = validation_data_args() @@ -3269,6 +3271,12 @@ def training_args( optional=True, doc=doc_only_pt_supported + doc_gradient_max_norm, ), + Argument( + "acc_freq", + int, + optional=True, + doc=doc_only_pd_supported + doc_acc_freq, + ), ] variants = [ Variant( From 1d77808c0351cb5df714850bc03bfdb63c5567a0 Mon Sep 17 00:00:00 2001 From: HydrogenSulfate <490868991@qq.com> Date: Wed, 27 Aug 2025 18:47:41 +0800 Subject: [PATCH 4/4] refine docstring --- deepmd/utils/argcheck.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/deepmd/utils/argcheck.py b/deepmd/utils/argcheck.py index 41a37053a2..799e806bb2 100644 --- a/deepmd/utils/argcheck.py +++ b/deepmd/utils/argcheck.py @@ -3168,7 +3168,7 @@ def training_args( doc_kf_blocksize = "The blocksize for the Kalman filter." doc_model_prob = "The visiting probability of each model for each training step in the multi-task mode." doc_data_dict = "The multiple definition of the data, used in the multi-task mode." - doc_acc_freq = "The accumulation steps for the gradients." + doc_acc_freq = "Gradient accumulation steps (number of steps to accumulate gradients before performing an update)." arg_training_data = training_data_args() arg_validation_data = validation_data_args() @@ -3275,6 +3275,7 @@ def training_args( "acc_freq", int, optional=True, + default=1, doc=doc_only_pd_supported + doc_acc_freq, ), ]