diff --git a/deepmd/entrypoints/train.py b/deepmd/entrypoints/train.py index 42e21d2ba4..4a67eee9d4 100755 --- a/deepmd/entrypoints/train.py +++ b/deepmd/entrypoints/train.py @@ -154,7 +154,7 @@ def _do_work(jdata: Dict[str, Any], run_opt: RunOptions, is_compress: bool = Fal train_data = get_data(jdata["training"]["training_data"], rcut, ipt_type_map, modifier) train_data.print_summary("training") if jdata["training"].get("validation_data", None) is not None: - valid_data = get_data(jdata["training"]["validation_data"], rcut, ipt_type_map, modifier) + valid_data = get_data(jdata["training"]["validation_data"], rcut, train_data.type_map, modifier) valid_data.print_summary("validation") # get training info