From 55f7ef66486a6122f710496b23267a9f4d5d33ad Mon Sep 17 00:00:00 2001 From: HydrogenSulfate <490868991@qq.com> Date: Fri, 27 Dec 2024 12:05:21 +0800 Subject: [PATCH 1/5] support CINN compiler for DPA2 example --- deepmd/pd/train/training.py | 47 +++++++++++++++++++++------------- deepmd/pd/utils/env.py | 51 +++++++++++++++++++++++++++++++------ 2 files changed, 73 insertions(+), 25 deletions(-) diff --git a/deepmd/pd/train/training.py b/deepmd/pd/train/training.py index 0f3c7a9732..64dc0f9ed9 100644 --- a/deepmd/pd/train/training.py +++ b/deepmd/pd/train/training.py @@ -53,6 +53,8 @@ get_sampler_from_params, ) from deepmd.pd.utils.env import ( + CINN, + DEFAULT_PRECISION, DEVICE, JIT, NUM_WORKERS, @@ -397,11 +399,11 @@ def get_lr(lr_params): self.lr_exp = get_lr(config["learning_rate"]) # JIT - if JIT: - raise NotImplementedError( - "JIT is not supported yet when training with Paddle" - ) - self.model = paddle.jit.to_static(self.model) + # if JIT: + # raise NotImplementedError( + # "JIT is not supported yet when training with Paddle" + # ) + # self.model = paddle.jit.to_static(self.model) # Model Wrapper self.wrapper = ModelWrapper(self.model, self.loss, model_params=model_params) @@ -631,6 +633,19 @@ def warm_up_linear(step, warmup_steps): self.profiling_file = training_params.get("profiling_file", "timeline.json") def run(self): + if JIT: + from paddle import ( + jit, + static, + ) + + build_strategy = static.BuildStrategy() + build_strategy.build_cinn_pass: bool = CINN + self.wrapper.forward = jit.to_static( + full_graph=True, build_strategy=build_strategy + )(self.wrapper.forward) + log.info(f"{'*' * 20} Using Jit {'*' * 20}") + fout = ( open( self.disp_file, @@ -670,9 +685,11 @@ def step(_step_id, task_key="Default") -> None: cur_lr = _lr.value(_step_id) pref_lr = cur_lr self.optimizer.clear_grad(set_to_zero=False) - input_dict, label_dict, log_dict = self.get_data( - is_train=True, task_key=task_key - ) + + with nvprof_context(enable_profiling, "Fetching data"): + input_dict, label_dict, log_dict = self.get_data( + is_train=True, task_key=task_key + ) if SAMPLER_RECORD: print_str = f"Step {_step_id}: sample system{log_dict['sid']} frame{log_dict['fid']}\n" fout1.write(print_str) @@ -686,7 +703,7 @@ def step(_step_id, task_key="Default") -> None: with nvprof_context(enable_profiling, "Forward pass"): model_pred, loss, more_loss = self.wrapper( **input_dict, - cur_lr=pref_lr, + cur_lr=paddle.full([], pref_lr, DEFAULT_PRECISION), label=label_dict, task_key=task_key, ) @@ -745,7 +762,7 @@ def log_loss_valid(_task_key="Default"): return {} _, loss, more_loss = self.wrapper( **input_dict, - cur_lr=pref_lr, + cur_lr=paddle.full([], pref_lr, DEFAULT_PRECISION), label=label_dict, task_key=_task_key, ) @@ -795,7 +812,7 @@ def log_loss_valid(_task_key="Default"): ) _, loss, more_loss = self.wrapper( **input_dict, - cur_lr=pref_lr, + cur_lr=paddle.full([], pref_lr, DEFAULT_PRECISION), label=label_dict, task_key=_key, ) @@ -905,8 +922,8 @@ def log_loss_valid(_task_key="Default"): else: model_key = "Default" step(step_id, model_key) - if JIT: - break + # if JIT: + # break if self.change_bias_after_training and (self.rank == 0 or dist.get_rank() == 0): if not self.multi_task: @@ -961,10 +978,6 @@ def log_loss_valid(_task_key="Default"): / (elapsed_batch // self.disp_freq * self.disp_freq), ) - if JIT: - raise NotImplementedError( - "Paddle JIT saving during training is not supported yet." - ) log.info(f"Trained model has been saved to: {self.save_ckpt}") if fout: diff --git a/deepmd/pd/utils/env.py b/deepmd/pd/utils/env.py index e2abe9a6e5..27f5b2a479 100644 --- a/deepmd/pd/utils/env.py +++ b/deepmd/pd/utils/env.py @@ -32,7 +32,33 @@ paddle.device.set_device(DEVICE) -JIT = False + +def to_bool(flag: int | bool | str) -> bool: + if isinstance(flag, int): + if flag not in [0, 1]: + raise ValueError(f"flag must be either 0 or 1, but received {flag}") + return bool(flag) + + elif isinstance(flag, str): + flag = flag.lower() + if flag not in ["1", "0", "true", "false"]: + raise ValueError( + "flag must be either '0', '1', 'true', 'false', " + f"but received '{flag}'" + ) + return flag in ["1", "true"] + + elif isinstance(flag, bool): + return flag + + else: + raise ValueError( + f"flag must be either int, bool, or str, but received {type(flag).__name__}" + ) + + +JIT = to_bool(os.environ.get("JIT", False)) +CINN = to_bool(os.environ.get("CINN", False)) CACHE_PER_SYS = 5 # keep at most so many sets per sys in memory ENERGY_BIAS_TRAINABLE = True @@ -138,14 +164,23 @@ def enable_prim(enable: bool = True): ] EAGER_COMP_OP_BLACK_LIST = list(set(EAGER_COMP_OP_BLACK_LIST)) - """Enable running program in primitive C++ API in eager/static mode.""" - from paddle.framework import ( - core, - ) + """Enable running program with primitive operators in eager/static mode.""" + if JIT: + # jit mode + paddle.framework.core._set_prim_all_enabled(enable) + if enable: + # No need to set a blacklist for now in JIT mode. + pass + else: + # eager mode + paddle.framework.core.set_prim_eager_enabled(enable) + if enable: + # Set a blacklist (i.e., disable several composite operators) in eager mode + # to enhance computational performance. + paddle.framework.core._set_prim_backward_blacklist( + *EAGER_COMP_OP_BLACK_LIST + ) - core.set_prim_eager_enabled(enable) - if enable: - paddle.framework.core._set_prim_backward_blacklist(*EAGER_COMP_OP_BLACK_LIST) log = logging.getLogger(__name__) log.info(f"{'Enable' if enable else 'Disable'} prim in eager and static mode.") From 7ca2a9ecd203b4c6088dd3d97c92e1cef348441a Mon Sep 17 00:00:00 2001 From: HydrogenSulfate <490868991@qq.com> Date: Fri, 27 Dec 2024 13:17:34 +0800 Subject: [PATCH 2/5] refine CINN flag --- deepmd/pd/train/training.py | 25 ++++++++++++++++--------- deepmd/pd/utils/env.py | 9 ++++++++- 2 files changed, 24 insertions(+), 10 deletions(-) diff --git a/deepmd/pd/train/training.py b/deepmd/pd/train/training.py index 64dc0f9ed9..a0328942e4 100644 --- a/deepmd/pd/train/training.py +++ b/deepmd/pd/train/training.py @@ -399,11 +399,11 @@ def get_lr(lr_params): self.lr_exp = get_lr(config["learning_rate"]) # JIT - # if JIT: - # raise NotImplementedError( - # "JIT is not supported yet when training with Paddle" - # ) - # self.model = paddle.jit.to_static(self.model) + if JIT: + raise NotImplementedError( + "JIT is not supported yet when training with Paddle" + ) + self.model = paddle.jit.to_static(self.model) # Model Wrapper self.wrapper = ModelWrapper(self.model, self.loss, model_params=model_params) @@ -633,7 +633,7 @@ def warm_up_linear(step, warmup_steps): self.profiling_file = training_params.get("profiling_file", "timeline.json") def run(self): - if JIT: + if CINN: from paddle import ( jit, static, @@ -644,7 +644,10 @@ def run(self): self.wrapper.forward = jit.to_static( full_graph=True, build_strategy=build_strategy )(self.wrapper.forward) - log.info(f"{'*' * 20} Using Jit {'*' * 20}") + log.info( + "Enable CINN during training, there may be some additional " + "compilation time in the first traning step." + ) fout = ( open( @@ -922,8 +925,8 @@ def log_loss_valid(_task_key="Default"): else: model_key = "Default" step(step_id, model_key) - # if JIT: - # break + if JIT: + break if self.change_bias_after_training and (self.rank == 0 or dist.get_rank() == 0): if not self.multi_task: @@ -978,6 +981,10 @@ def log_loss_valid(_task_key="Default"): / (elapsed_batch // self.disp_freq * self.disp_freq), ) + if JIT: + raise NotImplementedError( + "Paddle JIT saving during training is not supported yet." + ) log.info(f"Trained model has been saved to: {self.save_ckpt}") if fout: diff --git a/deepmd/pd/utils/env.py b/deepmd/pd/utils/env.py index 27f5b2a479..87b69c5676 100644 --- a/deepmd/pd/utils/env.py +++ b/deepmd/pd/utils/env.py @@ -59,6 +59,13 @@ def to_bool(flag: int | bool | str) -> bool: JIT = to_bool(os.environ.get("JIT", False)) CINN = to_bool(os.environ.get("CINN", False)) +if CINN: + assert paddle.device.is_compiled_with_cinn(), ( + "CINN is set to True, but PaddlePaddle is not compiled with CINN support. " + "Ensure that your PaddlePaddle installation supports CINN by checking your " + "installation or recompiling with CINN enabled." + ) + CACHE_PER_SYS = 5 # keep at most so many sets per sys in memory ENERGY_BIAS_TRAINABLE = True @@ -165,7 +172,7 @@ def enable_prim(enable: bool = True): EAGER_COMP_OP_BLACK_LIST = list(set(EAGER_COMP_OP_BLACK_LIST)) """Enable running program with primitive operators in eager/static mode.""" - if JIT: + if JIT or CINN: # jit mode paddle.framework.core._set_prim_all_enabled(enable) if enable: From 77c34e73ffe01a38cd54084ad7eed88b684943a5 Mon Sep 17 00:00:00 2001 From: HydrogenSulfate <490868991@qq.com> Date: Sun, 29 Dec 2024 18:59:27 +0800 Subject: [PATCH 3/5] import annotations for compatibility with python<=3.9 --- .pre-commit-config.yaml | 52 ++++++++++++++++++++--------------------- deepmd/pd/utils/env.py | 4 ++++ 2 files changed, 30 insertions(+), 26 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index bd36fd6e63..f47839650a 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -65,13 +65,13 @@ repos: - id: clang-format exclude: ^(source/3rdparty|source/lib/src/gpu/cudart/.+\.inc|.+\.ipynb$) # markdown, yaml, CSS, javascript - - repo: https://github.com/pre-commit/mirrors-prettier - rev: v4.0.0-alpha.8 - hooks: - - id: prettier - types_or: [markdown, yaml, css] - # workflow files cannot be modified by pre-commit.ci - exclude: ^(source/3rdparty|\.github/workflows|\.clang-format) + # - repo: https://github.com/pre-commit/mirrors-prettier + # rev: v4.0.0-alpha.8 + # hooks: + # - id: prettier + # types_or: [markdown, yaml, css] + # # workflow files cannot be modified by pre-commit.ci + # exclude: ^(source/3rdparty|\.github/workflows|\.clang-format) # Shell - repo: https://github.com/scop/pre-commit-shfmt rev: v3.10.0-2 @@ -83,25 +83,25 @@ repos: hooks: - id: cmake-format #- id: cmake-lint - - repo: https://github.com/njzjz/mirrors-bibtex-tidy - rev: v1.13.0 - hooks: - - id: bibtex-tidy - args: - - --curly - - --numeric - - --align=13 - - --blank-lines - # disable sort: the order of keys and fields has explict meanings - #- --sort=key - - --duplicates=key,doi,citation,abstract - - --merge=combine - #- --sort-fields - #- --strip-comments - - --trailing-commas - - --encode-urls - - --remove-empty-fields - - --wrap=80 + # - repo: https://github.com/njzjz/mirrors-bibtex-tidy + # rev: v1.13.0 + # hooks: + # - id: bibtex-tidy + # args: + # - --curly + # - --numeric + # - --align=13 + # - --blank-lines + # # disable sort: the order of keys and fields has explict meanings + # #- --sort=key + # - --duplicates=key,doi,citation,abstract + # - --merge=combine + # #- --sort-fields + # #- --strip-comments + # - --trailing-commas + # - --encode-urls + # - --remove-empty-fields + # - --wrap=80 # license header - repo: https://github.com/Lucas-C/pre-commit-hooks rev: v1.5.5 diff --git a/deepmd/pd/utils/env.py b/deepmd/pd/utils/env.py index 87b69c5676..a21a1244ff 100644 --- a/deepmd/pd/utils/env.py +++ b/deepmd/pd/utils/env.py @@ -1,4 +1,8 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +from __future__ import ( + annotations, +) + import logging import os From 87d8f355ed65f5126159893601c3cb644f7dd526 Mon Sep 17 00:00:00 2001 From: HydrogenSulfate <490868991@qq.com> Date: Fri, 7 Feb 2025 16:22:19 +0800 Subject: [PATCH 4/5] revert pre-commit change --- .pre-commit-config.yaml | 52 ++++++++++++++++++++--------------------- 1 file changed, 26 insertions(+), 26 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index defd3aaed9..edae9c69e3 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -65,13 +65,13 @@ repos: - id: clang-format exclude: ^(source/3rdparty|source/lib/src/gpu/cudart/.+\.inc|.+\.ipynb$|.+\.json$) # markdown, yaml, CSS, javascript - # - repo: https://github.com/pre-commit/mirrors-prettier - # rev: v4.0.0-alpha.8 - # hooks: - # - id: prettier - # types_or: [markdown, yaml, css] - # # workflow files cannot be modified by pre-commit.ci - # exclude: ^(source/3rdparty|\.github/workflows|\.clang-format) + - repo: https://github.com/pre-commit/mirrors-prettier + rev: v4.0.0-alpha.8 + hooks: + - id: prettier + types_or: [markdown, yaml, css] + # workflow files cannot be modified by pre-commit.ci + exclude: ^(source/3rdparty|\.github/workflows|\.clang-format) # Shell - repo: https://github.com/scop/pre-commit-shfmt rev: v3.10.0-2 @@ -83,25 +83,25 @@ repos: hooks: - id: cmake-format #- id: cmake-lint - # - repo: https://github.com/njzjz/mirrors-bibtex-tidy - # rev: v1.13.0 - # hooks: - # - id: bibtex-tidy - # args: - # - --curly - # - --numeric - # - --align=13 - # - --blank-lines - # # disable sort: the order of keys and fields has explict meanings - # #- --sort=key - # - --duplicates=key,doi,citation,abstract - # - --merge=combine - # #- --sort-fields - # #- --strip-comments - # - --trailing-commas - # - --encode-urls - # - --remove-empty-fields - # - --wrap=80 + - repo: https://github.com/njzjz/mirrors-bibtex-tidy + rev: v1.13.0 + hooks: + - id: bibtex-tidy + args: + - --curly + - --numeric + - --align=13 + - --blank-lines + # disable sort: the order of keys and fields has explict meanings + #- --sort=key + - --duplicates=key,doi,citation,abstract + - --merge=combine + #- --sort-fields + #- --strip-comments + - --trailing-commas + - --encode-urls + - --remove-empty-fields + - --wrap=80 # license header - repo: https://github.com/Lucas-C/pre-commit-hooks rev: v1.5.5 From 2cf1c659dc7d7b6c3f6a33debc6d737bce915ebb Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 7 Feb 2025 08:24:15 +0000 Subject: [PATCH 5/5] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- deepmd/pd/utils/env.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/deepmd/pd/utils/env.py b/deepmd/pd/utils/env.py index a21a1244ff..cb41609b27 100644 --- a/deepmd/pd/utils/env.py +++ b/deepmd/pd/utils/env.py @@ -47,8 +47,7 @@ def to_bool(flag: int | bool | str) -> bool: flag = flag.lower() if flag not in ["1", "0", "true", "false"]: raise ValueError( - "flag must be either '0', '1', 'true', 'false', " - f"but received '{flag}'" + f"flag must be either '0', '1', 'true', 'false', but received '{flag}'" ) return flag in ["1", "true"]