Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 24 additions & 41 deletions deepmd/pt/utils/dataloader.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import logging
import os
import queue
import time
from multiprocessing.dummy import (
Pool,
)
from queue import (
Queue,
)
from threading import (
Thread,
)
Expand Down Expand Up @@ -202,70 +204,51 @@ def print_summary(
)


# Unique marker object: the producer thread enqueues it after its source is
# drained, and the consumer recognises it with an identity (`is`) check.
_sentinel = object()
# Capacity handed to the Queue that buffers items between the background
# producer thread and the consuming iterator; `put` blocks once it is full.
QUEUESIZE = 32


class BackgroundConsumer(Thread):
    """Daemon thread that drains an iterable into a queue.

    Every item produced by ``source`` is pushed onto ``queue``. When the
    source is exhausted a ``StopIteration`` instance is enqueued as the
    end-of-stream marker. If iterating the source raises, the exception
    object itself is enqueued instead, so the reader of the queue can
    re-raise it rather than wait forever on a dead producer.

    Parameters
    ----------
    queue
        Queue-like object exposing a blocking ``put``.
    source
        Iterable to drain (the main data-loader iterator).
    max_len
        Ignored. Accepted only so that callers still using the older
        three-argument form keep working.
    """

    def __init__(self, queue, source, max_len=None) -> None:
        super().__init__()
        # Daemon thread: it must never keep the interpreter alive by itself.
        self.daemon = True
        self._queue = queue
        self._source = source  # Main DL iterator
        self._max_len = max_len  # unused, kept for backward compatibility

    def run(self) -> None:
        """Feed the queue until the source ends or fails."""
        try:
            for item in self._source:
                self._queue.put(item)  # Blocking if the queue is full
        except Exception as err:
            # Hand the failure over to the consumer; otherwise it would block
            # on an empty queue that nobody is going to fill any more.
            self._queue.put(err)
            return

        # Signal the consumer we are done; this should not happen for DataLoader
        self._queue.put(StopIteration())

Comment thread
caic99 marked this conversation as resolved.

QUEUESIZE = 32


class BufferedIterator:
    """Iterator that prefetches items from ``iterable`` on a background thread.

    A ``BackgroundConsumer`` thread is started lazily on the first
    ``next()`` call and fills a bounded queue of ``QUEUESIZE`` entries;
    ``__next__`` pops from that queue. An ``Exception`` instance found in
    the queue (including the ``StopIteration`` end marker) is raised to the
    caller.

    Parameters
    ----------
    iterable
        Source of items. It must support ``len()`` if ``len()`` is called
        on this wrapper.
    """

    def __init__(self, iterable) -> None:
        self._queue = Queue(QUEUESIZE)
        self._iterable = iterable
        self._consumer = None
        # Set once the producer has reported its end (or a failure); after
        # that nothing will ever be enqueued again.
        self._exhausted = False
        # Reference point for throttling the "slow loading" warning; reset
        # when the producer thread is actually started.
        self.last_warning_time = time.time()

    def _create_consumer(self) -> None:
        """Start the background producer thread."""
        self._consumer = BackgroundConsumer(self._queue, self._iterable)
        self._consumer.start()
        self.last_warning_time = time.time()

    def __iter__(self):
        return self

    def __len__(self) -> int:
        return len(self._iterable)

    def __next__(self):
        """Return the next buffered item.

        Raises
        ------
        StopIteration
            When the source is exhausted, and on every later call.
        Exception
            Any exception instance the producer placed in the queue.
        """
        # The producer is gone: keep honouring the iterator protocol instead
        # of blocking forever on an empty queue.
        if self._exhausted:
            raise StopIteration
        # Create consumer if not created yet
        if self._consumer is None:
            self._create_consumer()
        # Get next example
        start_wait = time.time()
        item = self._queue.get()
        wait_time = time.time() - start_wait
        if (
            wait_time > 1.0 and start_wait - self.last_warning_time > 15 * 60
        ):  # Even for Multi-Task training, each step usually takes < 1s
            log.warning(
                f"Data loading is slow, waited {wait_time:.2f} seconds. Ignoring this warning for 15 minutes."
            )
            self.last_warning_time = start_wait
        if isinstance(item, Exception):
            self._exhausted = True
            raise item
        return item


Expand Down
10 changes: 10 additions & 0 deletions deepmd/utils/path.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,10 @@ def is_file(self) -> bool:
def is_dir(self) -> bool:
"""Check if self is directory."""

# Pickle hook: declared abstract, so each concrete path class has to state
# which constructor-style arguments rebuild it when it is unpickled.
@abstractmethod
def __getnewargs__(self):
"""Return the arguments to be passed to __new__ when unpickling an instance."""

# Abstract hook behind `path / key`; per the annotations it takes a string
# key and yields another DPPath.
@abstractmethod
def __truediv__(self, key: str) -> "DPPath":
"""Used for / operator."""
Expand Down Expand Up @@ -171,6 +175,9 @@ def __init__(self, path: str, mode: str = "r") -> None:
else:
self.path = Path(path)

# Pickle hook: returns the (path, mode) pair handed to __new__ on unpickling.
# `self.path` is the value stored by __init__ (a pathlib.Path in the branch
# visible above); `self.mode` is presumably the `mode` constructor argument
# -- TODO confirm, its assignment is outside this hunk.
def __getnewargs__(self):
return (self.path, self.mode)

def load_numpy(self) -> np.ndarray:
"""Load NumPy array.

Expand Down Expand Up @@ -304,6 +311,9 @@ def __init__(self, path: str, mode: str = "r") -> None:
# h5 path: default is the root path
self._name = s[1] if len(s) > 1 else "/"

# Pickle hook: returns the (root_path, mode) pair handed to __new__ on
# unpickling.
# NOTE(review): the in-file location kept in `self._name` (set in __init__
# just above) is not part of these arguments -- confirm it is restored through
# the pickled instance state rather than silently falling back to "/".
def __getnewargs__(self):
return (self.root_path, self.mode)

@classmethod
@lru_cache(None)
def _load_h5py(cls, path: str, mode: str = "r") -> h5py.File:
Expand Down