From 5ce51a08f93a8a0c3bdcf9f69ac8f82b1fad67ee Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Sun, 23 Jun 2024 19:48:51 -0400 Subject: [PATCH 1/4] feat(pt): support `training/profiling` in PT Signed-off-by: Jinzhe Zeng --- deepmd/pt/train/training.py | 14 ++++++++++---- deepmd/utils/argcheck.py | 4 ++-- 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/deepmd/pt/train/training.py b/deepmd/pt/train/training.py index 2aa672bd60..14334063e5 100644 --- a/deepmd/pt/train/training.py +++ b/deepmd/pt/train/training.py @@ -699,6 +699,8 @@ def warm_up_linear(step, warmup_steps): self.tensorboard_log_dir = training_params.get("tensorboard_log_dir", "log") self.tensorboard_freq = training_params.get("tensorboard_freq", 1) self.enable_profiler = training_params.get("enable_profiler", False) + self.profiling = training_params["profiling"] + self.profiling_file = training_params["profiling_file"] def run(self): fout = ( @@ -716,12 +718,14 @@ def run(self): ) writer = SummaryWriter(log_dir=self.tensorboard_log_dir) - if self.enable_profiler: + if self.enable_profiler or self.profiling: prof = torch.profiler.profile( schedule=torch.profiler.schedule(wait=1, warmup=1, active=3, repeat=1), on_trace_ready=torch.profiler.tensorboard_trace_handler( self.tensorboard_log_dir - ), + ) + if self.enable_profiler + else None, record_shapes=True, with_stack=True, ) @@ -729,7 +733,7 @@ def run(self): def step(_step_id, task_key="Default"): # PyTorch Profiler - if self.enable_profiler: + if self.enable_profiler or self.profiling: prof.step() self.wrapper.train() if isinstance(self.lr_exp, dict): @@ -1061,8 +1065,10 @@ def log_loss_valid(_task_key="Default"): fout1.close() if self.enable_tensorboard: writer.close() - if self.enable_profiler: + if self.enable_profiler or self.profiling: prof.stop() + if self.profiling: + prof.export_chrome_trace(self.profiling_file) def save_model(self, save_path, lr=0.0, step=0): module = ( diff --git a/deepmd/utils/argcheck.py b/deepmd/utils/argcheck.py index bbb203eea9..1cf3d95ceb 100644 --- a/deepmd/utils/argcheck.py +++ b/deepmd/utils/argcheck.py @@ -2405,14 +2405,14 @@ def training_args(): # ! modified by Ziyao: data configuration isolated. bool, optional=True, default=False, - doc=doc_only_tf_supported + doc_profiling, + doc=doc_profiling, ), Argument( "profiling_file", str, optional=True, default="timeline.json", - doc=doc_only_tf_supported + doc_profiling_file, + doc=doc_profiling_file, ), Argument( "enable_profiler", From 44f672102ccd533f28b2bdc465b7950d90efd87d Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Sun, 23 Jun 2024 19:54:51 -0400 Subject: [PATCH 2/4] add a log Signed-off-by: Jinzhe Zeng --- deepmd/pt/train/training.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/deepmd/pt/train/training.py b/deepmd/pt/train/training.py index 14334063e5..a22d2af52b 100644 --- a/deepmd/pt/train/training.py +++ b/deepmd/pt/train/training.py @@ -1069,6 +1069,9 @@ def log_loss_valid(_task_key="Default"): prof.stop() if self.profiling: prof.export_chrome_trace(self.profiling_file) + log.info( + f"The profiling trace have been saved to: {self.profiling_file}" + ) def save_model(self, save_path, lr=0.0, step=0): module = ( From 7643cfbc445bf72856f97c3ec693b31223302bee Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Sun, 23 Jun 2024 20:07:55 -0400 Subject: [PATCH 3/4] set default values as some tests don't use argcheck Signed-off-by: Jinzhe Zeng --- deepmd/pt/train/training.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/deepmd/pt/train/training.py b/deepmd/pt/train/training.py index a22d2af52b..9d5c9ea51e 100644 --- a/deepmd/pt/train/training.py +++ b/deepmd/pt/train/training.py @@ -699,8 +699,8 @@ def warm_up_linear(step, warmup_steps): self.tensorboard_log_dir = training_params.get("tensorboard_log_dir", "log") self.tensorboard_freq = training_params.get("tensorboard_freq", 1) self.enable_profiler = training_params.get("enable_profiler", False) - self.profiling = training_params["profiling"] - self.profiling_file = training_params["profiling_file"] + self.profiling = training_params.get("profiling", False) + self.profiling_file = training_params.get("profiling_file", "timeline.json") def run(self): fout = ( From b60f5343cb62f06f865a8f643e3497d05208c643 Mon Sep 17 00:00:00 2001 From: Jinzhe Zeng Date: Sun, 23 Jun 2024 22:38:26 -0400 Subject: [PATCH 4/4] update documentation Signed-off-by: Jinzhe Zeng --- deepmd/utils/argcheck.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/deepmd/utils/argcheck.py b/deepmd/utils/argcheck.py index 1cf3d95ceb..cc91e003c9 100644 --- a/deepmd/utils/argcheck.py +++ b/deepmd/utils/argcheck.py @@ -2347,9 +2347,9 @@ def training_args(): # ! modified by Ziyao: data configuration isolated. ) doc_disp_training = "Displaying verbose information during training." doc_time_training = "Timing durining training." - doc_profiling = "Profiling during training." + doc_profiling = "Export the profiling results to the Chrome JSON file for performance analysis, driven by the legacy TensorFlow profiling API or PyTorch Profiler. The output file will be saved to `profiling_file`." doc_profiling_file = "Output file for profiling." - doc_enable_profiler = "Enable TensorFlow Profiler (available in TensorFlow 2.3) or PyTorch Profiler to analyze performance. The log will be saved to `tensorboard_log_dir`." + doc_enable_profiler = "Export the profiling results to the TensorBoard log for performance analysis, driven by TensorFlow Profiler (available in TensorFlow 2.3) or PyTorch Profiler. The log will be saved to `tensorboard_log_dir`." doc_tensorboard = "Enable tensorboard" doc_tensorboard_log_dir = "The log directory of tensorboard outputs" doc_tensorboard_freq = "The frequency of writing tensorboard events."