diff --git a/deepmd/entrypoints/train.py b/deepmd/entrypoints/train.py index a99ce6eae8..1ff0cff02e 100755 --- a/deepmd/entrypoints/train.py +++ b/deepmd/entrypoints/train.py @@ -6,6 +6,7 @@ import json import logging import time +import os from typing import Dict, TYPE_CHECKING, List, Optional, Any import numpy as np @@ -271,6 +272,22 @@ def get_data(jdata: Dict[str, Any], rcut, type_map, modifier): systems = j_must_have(jdata, "systems") if isinstance(systems, str): systems = expand_sys_str(systems) + help_msg = 'Please check your setting for data systems' + # check length of systems + if len(systems) == 0: + msg = 'cannot find valid a data system' + log.fatal(msg) + raise IOError(msg, help_msg) + # rougly check all items in systems are valid + for ii in systems: + if (not os.path.isdir(ii)): + msg = f'dir {ii} is not a valid dir' + log.fatal(msg) + raise IOError(msg, help_msg) + if (not os.path.isfile(os.path.join(ii, 'type.raw'))): + msg = f'dir {ii} is not a valid data system dir' + log.fatal(msg) + raise IOError(msg, help_msg) batch_size = j_must_have(jdata, "batch_size") sys_probs = jdata.get("sys_probs", None)