# SPDX-License-Identifier: LGPL-3.0-or-later
import logging
import time
from queue import (
    Queue,
)
from threading import (
    Thread,
)

# NOTE(review): the logger definition is not part of the visible patch;
# a module-level logger is assumed here -- confirm against the full file.
log = logging.getLogger(__name__)

# Maximum number of prefetched items held between producer and consumer.
QUEUESIZE = 32


class BackgroundConsumer(Thread):
    """Daemon thread that drains an iterable into a queue.

    Every item produced by ``source`` is put on ``queue``. When the source
    is exhausted, or fails, a ``StopIteration`` instance is enqueued as the
    end-of-stream marker so that a reader blocked on ``queue.get()`` is
    always released.

    Parameters
    ----------
    queue
        Queue-like object exposing a blocking ``put``.
    source
        Iterable to consume in the background.
    """

    def __init__(self, queue, source) -> None:
        super().__init__()
        # Daemon thread: it must not keep the interpreter alive at exit.
        self.daemon = True
        self._queue = queue
        self._source = source  # Main DL iterator

    def run(self) -> None:
        """Feed the queue until the source is exhausted or raises."""
        try:
            for item in self._source:
                self._queue.put(item)  # Blocking if the queue is full
        except Exception as err:
            # Forward the failure to the reader instead of dying silently,
            # which would leave the reader blocked on an empty queue forever.
            self._queue.put(err)
        finally:
            # Signal the consumer we are done; this should not happen for
            # DataLoader
            self._queue.put(StopIteration())


class BufferedIterator:
    """Iterator that prefetches items from ``iterable`` in a background thread.

    The background thread is started on construction and buffers up to
    ``QUEUESIZE`` items. A warning is logged when a single fetch waits longer
    than one second, at most once every 15 minutes; because the timer starts
    at construction, no warning is emitted during the first 15 minutes.

    Parameters
    ----------
    iterable
        Source of items. It only needs to support ``len()`` if ``len()`` is
        called on this iterator.
    """

    def __init__(self, iterable) -> None:
        self._queue = Queue(QUEUESIZE)
        self._iterable = iterable
        # Set once the end-of-stream marker has been seen, so that later
        # calls to ``__next__`` never block on a queue nobody feeds.
        self._exhausted = False
        self.last_warning_time = time.time()
        self._consumer = BackgroundConsumer(self._queue, self._iterable)
        self._consumer.start()

    def __iter__(self):
        return self

    def __len__(self) -> int:
        return len(self._iterable)

    def __next__(self):
        """Return the next buffered item.

        Raises
        ------
        StopIteration
            When the source is exhausted, and on every call after that.
        Exception
            Any exception instance taken from the queue is re-raised here;
            this includes errors raised by the source iterable.
        """
        if self._exhausted:
            raise StopIteration

        start_wait = time.time()
        item = self._queue.get()
        wait_time = time.time() - start_wait
        if (
            wait_time > 1.0 and start_wait - self.last_warning_time > 15 * 60
        ):  # Even for Multi-Task training, each step usually takes < 1s
            log.warning(
                "Data loading is slow, waited %.2f seconds. "
                "Ignoring this warning for 15 minutes.",
                wait_time,
            )
            self.last_warning_time = start_wait

        if isinstance(item, StopIteration):
            # Remember the end of the stream: the producer thread has
            # finished, so another ``get`` would block indefinitely.
            self._exhausted = True
            raise item
        if isinstance(item, Exception):
            raise item
        return item