diff --git a/deepmd/pt/train/training.py b/deepmd/pt/train/training.py index 2aa672bd60..9d5c9ea51e 100644 --- a/deepmd/pt/train/training.py +++ b/deepmd/pt/train/training.py @@ -699,6 +699,8 @@ def warm_up_linear(step, warmup_steps): self.tensorboard_log_dir = training_params.get("tensorboard_log_dir", "log") self.tensorboard_freq = training_params.get("tensorboard_freq", 1) self.enable_profiler = training_params.get("enable_profiler", False) + self.profiling = training_params.get("profiling", False) + self.profiling_file = training_params.get("profiling_file", "timeline.json") def run(self): fout = ( @@ -716,12 +718,14 @@ def run(self): ) writer = SummaryWriter(log_dir=self.tensorboard_log_dir) - if self.enable_profiler: + if self.enable_profiler or self.profiling: prof = torch.profiler.profile( schedule=torch.profiler.schedule(wait=1, warmup=1, active=3, repeat=1), on_trace_ready=torch.profiler.tensorboard_trace_handler( self.tensorboard_log_dir - ), + ) + if self.enable_profiler + else None, record_shapes=True, with_stack=True, ) @@ -729,7 +733,7 @@ def run(self): def step(_step_id, task_key="Default"): # PyTorch Profiler - if self.enable_profiler: + if self.enable_profiler or self.profiling: prof.step() self.wrapper.train() if isinstance(self.lr_exp, dict): @@ -1061,8 +1065,13 @@ def log_loss_valid(_task_key="Default"): fout1.close() if self.enable_tensorboard: writer.close() - if self.enable_profiler: + if self.enable_profiler or self.profiling: prof.stop() + if self.profiling: + prof.export_chrome_trace(self.profiling_file) + log.info( + f"The profiling trace has been saved to: {self.profiling_file}" + ) def save_model(self, save_path, lr=0.0, step=0): module = ( diff --git a/deepmd/utils/argcheck.py b/deepmd/utils/argcheck.py index bbb203eea9..cc91e003c9 100644 --- a/deepmd/utils/argcheck.py +++ b/deepmd/utils/argcheck.py @@ -2347,9 +2347,9 @@ def training_args(): # ! modified by Ziyao: data configuration isolated. 
) doc_disp_training = "Displaying verbose information during training." doc_time_training = "Timing durining training." - doc_profiling = "Profiling during training." + doc_profiling = "Export the profiling results to the Chrome JSON file for performance analysis, driven by the legacy TensorFlow profiling API or PyTorch Profiler. The output file will be saved to `profiling_file`." doc_profiling_file = "Output file for profiling." - doc_enable_profiler = "Enable TensorFlow Profiler (available in TensorFlow 2.3) or PyTorch Profiler to analyze performance. The log will be saved to `tensorboard_log_dir`." + doc_enable_profiler = "Export the profiling results to the TensorBoard log for performance analysis, driven by TensorFlow Profiler (available in TensorFlow 2.3) or PyTorch Profiler. The log will be saved to `tensorboard_log_dir`." doc_tensorboard = "Enable tensorboard" doc_tensorboard_log_dir = "The log directory of tensorboard outputs" doc_tensorboard_freq = "The frequency of writing tensorboard events." @@ -2405,14 +2405,14 @@ def training_args(): # ! modified by Ziyao: data configuration isolated. bool, optional=True, default=False, - doc=doc_only_tf_supported + doc_profiling, + doc=doc_profiling, ), Argument( "profiling_file", str, optional=True, default="timeline.json", - doc=doc_only_tf_supported + doc_profiling_file, + doc=doc_profiling_file, ), Argument( "enable_profiler",