diff --git a/deepmd/common.py b/deepmd/common.py index c9873a6d94..3ab936db67 100644 --- a/deepmd/common.py +++ b/deepmd/common.py @@ -207,6 +207,30 @@ def expand_sys_str(root_dir: Union[str, Path]) -> list[str]: return matches +def rglob_sys_str(root_dir: str, patterns: list[str]) -> list[str]: + """Recursively iterate over directories taking those that contain `type.raw` file. + + Parameters + ---------- + root_dir : str, Path + starting directory + patterns : list[str] + list of glob patterns to match directories + + Returns + ------- + list[str] + list of string pointing to system directories + """ + root_dir = Path(root_dir) + matches = [] + for pattern in patterns: + matches.extend( + [str(d) for d in root_dir.rglob(pattern) if (d / "type.raw").is_file()] + ) + return list(set(matches)) # remove duplicates + + def get_np_precision(precision: "_PRECISION") -> np.dtype: """Get numpy precision constant from string. diff --git a/deepmd/pt/entrypoints/main.py b/deepmd/pt/entrypoints/main.py index 3fe507ecc2..0e248583ec 100644 --- a/deepmd/pt/entrypoints/main.py +++ b/deepmd/pt/entrypoints/main.py @@ -114,9 +114,11 @@ def prepare_trainer_input_single( validation_dataset_params["systems"] if validation_dataset_params else None ) training_systems = training_dataset_params["systems"] - training_systems = process_systems(training_systems) + trn_patterns = training_dataset_params.get("rglob_patterns", None) + training_systems = process_systems(training_systems, patterns=trn_patterns) if validation_systems is not None: - validation_systems = process_systems(validation_systems) + val_patterns = validation_dataset_params.get("rglob_patterns", None) + validation_systems = process_systems(validation_systems, val_patterns) # stat files stat_file_path_single = data_dict_single.get("stat_file", None) diff --git a/deepmd/utils/argcheck.py b/deepmd/utils/argcheck.py index 42ae982f89..fb911550dd 100644 --- a/deepmd/utils/argcheck.py +++ b/deepmd/utils/argcheck.py @@ -2926,6 +2926,9 @@ def training_data_args(): # ! added by Ziyao: new specification style for data "This key can be provided with a list that specifies the systems, or be provided with a string " "by which the prefix of all systems are given and the list of the systems is automatically generated." ) + doc_patterns = ( + "The customized patterns used in `rglob` to collect all training systems. " + ) doc_batch_size = f'This key can be \n\n\ - list: the length of which is the same as the {link_sys}. The batch size of each system is given by the elements of the list.\n\n\ - int: all {link_sys} use the same batch size.\n\n\ @@ -2949,6 +2952,13 @@ def training_data_args(): # ! added by Ziyao: new specification style for data Argument( "systems", [list[str], str], optional=False, default=".", doc=doc_systems ), + Argument( + "rglob_patterns", + [list[str]], + optional=True, + default=None, + doc=doc_patterns + doc_only_pt_supported, + ), Argument( "batch_size", [list[int], int, str], @@ -2995,6 +3005,9 @@ def validation_data_args(): # ! added by Ziyao: new specification style for dat "This key can be provided with a list that specifies the systems, or be provided with a string " "by which the prefix of all systems are given and the list of the systems is automatically generated." ) + doc_patterns = ( + "The customized patterns used in `rglob` to collect all validation systems. " + ) doc_batch_size = f'This key can be \n\n\ - list: the length of which is the same as the {link_sys}. The batch size of each system is given by the elements of the list.\n\n\ - int: all {link_sys} use the same batch size.\n\n\ @@ -3015,6 +3028,13 @@ def validation_data_args(): # ! added by Ziyao: new specification style for dat Argument( "systems", [list[str], str], optional=False, default=".", doc=doc_systems ), + Argument( + "rglob_patterns", + [list[str]], + optional=True, + default=None, + doc=doc_patterns + doc_only_pt_supported, + ), Argument( "batch_size", [list[int], int, str], diff --git a/deepmd/utils/data_system.py b/deepmd/utils/data_system.py index fbc9c1a684..07dab35a90 100644 --- a/deepmd/utils/data_system.py +++ b/deepmd/utils/data_system.py @@ -17,6 +17,7 @@ from deepmd.common import ( expand_sys_str, make_default_mesh, + rglob_sys_str, ) from deepmd.env import ( GLOBAL_NP_FLOAT_PRECISION, @@ -730,7 +731,9 @@ def prob_sys_size_ext(keywords, nsystems, nbatch): return sys_probs -def process_systems(systems: Union[str, list[str]]) -> list[str]: +def process_systems( + systems: Union[str, list[str]], patterns: Optional[list[str]] = None +) -> list[str]: """Process the user-input systems. If it is a single directory, search for all the systems in the directory. @@ -740,6 +743,8 @@ def process_systems(systems: Union[str, list[str]]) -> list[str]: ---------- systems : str or list of str The user-input systems + patterns : list of str, optional + The patterns to match the systems, by default None Returns ------- @@ -747,7 +752,10 @@ def process_systems(systems: Union[str, list[str]]) -> list[str]: The valid systems """ if isinstance(systems, str): - systems = expand_sys_str(systems) + if patterns is None: + systems = expand_sys_str(systems) + else: + systems = rglob_sys_str(systems, patterns) elif isinstance(systems, list): systems = systems.copy() return systems @@ -777,7 +785,8 @@ def get_data( The data system """ systems = jdata["systems"] - systems = process_systems(systems) + rglob_patterns = jdata.get("rglob_patterns", None) + systems = process_systems(systems, patterns=rglob_patterns) batch_size = jdata["batch_size"] sys_probs = jdata.get("sys_probs", None) diff --git a/source/tests/pt/test_training.py b/source/tests/pt/test_training.py index ad52c5db16..3df95e4b14 100644 --- a/source/tests/pt/test_training.py +++ b/source/tests/pt/test_training.py @@ -516,5 +516,28 @@ def tearDown(self) -> None: shutil.rmtree(f) +class TestCustomizedRGLOB(unittest.TestCase, DPTrainTest): + def setUp(self) -> None: + input_json = str(Path(__file__).parent / "water/se_atten.json") + with open(input_json) as f: + self.config = json.load(f) + self.config["training"]["training_data"]["rglob_patterns"] = [ + "water/data/data_*" + ] + self.config["training"]["training_data"]["systems"] = str(Path(__file__).parent) + self.config["training"]["validation_data"]["rglob_patterns"] = [ + "water/*/data_0" + ] + self.config["training"]["validation_data"]["systems"] = str( + Path(__file__).parent + ) + self.config["model"] = deepcopy(model_dpa1) + self.config["training"]["numb_steps"] = 1 + self.config["training"]["save_freq"] = 1 + + def tearDown(self) -> None: + DPTrainTest.tearDown(self) + + if __name__ == "__main__": unittest.main()