Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 6 additions & 7 deletions deepmd/pt/entrypoints/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,13 +123,12 @@ def prepare_trainer_input_single(
if rank != 0:
stat_file_path_single = None
elif stat_file_path_single is not None:
if Path(stat_file_path_single).is_dir():
raise ValueError(
f"stat_file should be a file, not a directory: {stat_file_path_single}"
)
if not Path(stat_file_path_single).is_file():
with h5py.File(stat_file_path_single, "w") as f:
pass
if not Path(stat_file_path_single).exists():
if stat_file_path_single.endswith((".h5", ".hdf5")):
with h5py.File(stat_file_path_single, "w") as f:
pass
else:
Path(stat_file_path_single).mkdir()
stat_file_path_single = DPPath(stat_file_path_single, "a")

# validation and training data
Expand Down
4 changes: 3 additions & 1 deletion deepmd/utils/argcheck.py
Original file line number Diff line number Diff line change
Expand Up @@ -2179,7 +2179,9 @@ def training_args(): # ! modified by Ziyao: data configuration isolated.
doc_stat_file = (
"The file path for saving the data statistics results. "
"If set, the results will be saved and directly loaded during the next training session, "
"avoiding the need to recalculate the statistics"
"avoiding the need to recalculate the statistics. "
"If the file extension is .h5 or .hdf5, an HDF5 file is used to store the statistics; "
"otherwise, a directory containing NumPy binary files are used."
)
doc_opt_type = "The type of optimizer to use."
doc_kf_blocksize = "The blocksize for the Kalman filter."
Expand Down
6 changes: 3 additions & 3 deletions deepmd/utils/env_mat_stat.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ def save_stats(self, path: DPPath) -> None:

Parameters
----------
path : DPH5Path
path : DPPath
The path to save the statistics of the environment matrix.
"""
if len(self.stats) == 0:
Expand All @@ -146,7 +146,7 @@ def load_stats(self, path: DPPath) -> None:

Parameters
----------
path : DPH5Path
path : DPPath
The path to load the statistics of the environment matrix.
"""
if len(self.stats) > 0:
Expand All @@ -166,7 +166,7 @@ def load_or_compute_stats(

Parameters
----------
path : DPH5Path
path : DPPath
The path to load the statistics of the environment matrix.
data : List[Dict[str, np.ndarray]]
The environment matrix.
Expand Down
2 changes: 1 addition & 1 deletion examples/water/dpa2/input_torch.json
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@
"_comment": " that's all"
},
"training": {
"stat_file": "./dpa2",
"stat_file": "./dpa2.hdf5",
"training_data": {
"systems": [
"../data/data_0",
Expand Down
2 changes: 1 addition & 1 deletion examples/water/se_atten/input_torch.json
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@
"_comment": " that's all"
},
"training": {
"stat_file": "./dpa1",
"stat_file": "./dpa1.hdf5",
"training_data": {
"systems": [
"../data/data_0",
Expand Down
2 changes: 1 addition & 1 deletion examples/water/se_e2_a/input_torch.json
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@
"_comment": " that's all"
},
"training": {
"stat_file": "./se_e2_a",
"stat_file": "./se_e2_a.hdf5",
"training_data": {
"systems": [
"../data/data_0",
Expand Down
2 changes: 1 addition & 1 deletion source/tests/pt/model/test_saveload_dpa1.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def create_wrapper(self, read: bool):
model_config = copy.deepcopy(self.config["model"])
model_config["resuming"] = read
model_config["stat_file_dir"] = "stat_files"
model_config["stat_file"] = "stat.npz"
model_config["stat_file"] = "stat.hdf5"
model_config["stat_file_path"] = os.path.join(
model_config["stat_file_dir"], model_config["stat_file"]
)
Expand Down
4 changes: 2 additions & 2 deletions source/tests/pt/model/water/multitask.json
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@
},
"data_dict": {
"model_1": {
"stat_file": "./stat_files/model_1",
"stat_file": "./stat_files/model_1.hdf5",
"training_data": {
"systems": [
"pt/water/data/data_0"
Expand All @@ -112,7 +112,7 @@
}
},
"model_2": {
"stat_file": "./stat_files/model_2",
"stat_file": "./stat_files/model_2.hdf5",
"training_data": {
"systems": [
"pt/water/data/data_0"
Expand Down