# Source code for lhotse.dataset.dataloading

import platform
from collections import deque
from concurrent.futures import ProcessPoolExecutor
from multiprocessing import get_context
from typing import Any, Dict, List

import torch.utils.data

from lhotse.dataset.sampling.base import CutSampler


class LhotseDataLoader:
    """
    A simplified ``DataLoader`` implementation that relies on
    a ``ProcessPoolExecutor``. The main difference between this and
    ``torch.utils.data.DataLoader`` is that :class:`.LhotseDataLoader`
    allows launching subprocesses inside its workers. This is useful
    for dataset classes that perform dynamic batching and need to
    perform concurrent I/O to read all the necessary data from
    disk/network.

    .. note:: :class:`.LhotseDataLoader` does not support ``num_workers=0``.

    .. warning:: :class:`.LhotseDataLoader` is experimental and not guaranteed
        to work correctly across all possible edge cases related to subprocess
        worker termination. If you experience stability problems, contact us
        or use a standard ``DataLoader`` instead.

    .. warning:: :class:`.LhotseDataLoader` requires Python >= 3.7.
    """
    def __init__(
        self,
        dataset: torch.utils.data.Dataset,
        sampler: CutSampler,
        num_workers: int = 1,
        prefetch_factor: int = 2,
    ) -> None:
        from packaging.version import parse as _version

        if _version(platform.python_version()) < _version("3.7"):
            raise RuntimeError(
                "LhotseDataLoader requires Python version at least 3.7"
            )
        assert num_workers >= 1
        assert prefetch_factor >= 1
        self.dataset = dataset
        self.sampler = sampler
        self.num_workers = num_workers
        self.prefetch_factor = prefetch_factor
        # Mutable state.
        self._iter = None
        self._futures = deque([])
        # Start the worker processes. The initializer receives the dataset
        # object from the main process and caches it globally, so that it can
        # be re-used for subsequent tasks sent to the worker. This helps avoid
        # excessive communication between the processes.
        self.pool = ProcessPoolExecutor(
            num_workers,
            initializer=_init_worker,
            initargs=(dataset,),
            mp_context=get_context("spawn"),
        )
    def __iter__(self) -> "LhotseDataLoader":
        """Prepares the sampler for iteration and schedules the initial tasks
        for the workers."""
        self._iter = iter(self.sampler)
        for _ in range(self.prefetch_factor * self.num_workers):
            self._schedule_one()
        return self

    def _schedule_one(self) -> None:
        """Submits a task and stores the future for result retrieval."""
        if self._iter is not None:
            try:
                self._futures.append(self.pool.submit(_get_item, next(self._iter)))
            except StopIteration:
                self._iter = None

    def _retrieve_one(self) -> Dict[str, Any]:
        """Retrieves the result of the earliest submitted task."""
        if self._futures:
            return self._futures.popleft().result()
        raise StopIteration()

    def __next__(self) -> Dict[str, Any]:
        """Submits a new batch for processing, then retrieves and returns a
        completed batch."""
        self._schedule_one()
        return self._retrieve_one()
def _init_worker(dataset: torch.utils.data.Dataset) -> None:
    """
    Stores the dataset in the global state of the worker process. This is safe
    because each process is initialized only once and serves a single dataset
    during its life span.
    """
    global _GLOBAL_DATASET_CACHE
    _GLOBAL_DATASET_CACHE = dataset


def _get_item(cut_ids: List[str]) -> Dict[str, Any]:
    """
    Queries the globally cached dataset to retrieve a batch. Has to be run
    inside a worker process that was initialized with :meth:`._init_worker`.
    """
    return _GLOBAL_DATASET_CACHE[cut_ids]


_GLOBAL_DATASET_CACHE = None
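

# ---------------------------------------------------------------------------
# Usage sketch (illustration only; an addition, not part of the original
# module). ``_ToyDataset`` is a hypothetical dataset that mimics the intended
# use case: it fans out per-cut "I/O" to a thread pool inside ``__getitem__``
# (threads stand in for the subprocess-based concurrent I/O described in the
# class docstring, for simplicity). The loader only requires that
# ``dataset[cut_ids]`` returns a batch dict and that iterating the sampler
# yields lists of cut IDs (see ``_get_item`` above); a real run would pass a
# lhotse dataset and a ``CutSampler`` built from a ``CutSet`` instead. The
# toy class is defined at module level so the "spawn"-ed worker processes
# can unpickle it.
# ---------------------------------------------------------------------------
from concurrent.futures import ThreadPoolExecutor


class _ToyDataset(torch.utils.data.Dataset):
    def __getitem__(self, cut_ids: List[str]) -> Dict[str, Any]:
        # Stand-in for reading audio/features from disk or network.
        def _load(cut_id: str) -> str:
            return f"data-for-{cut_id}"

        # Load all cuts of the batch concurrently.
        with ThreadPoolExecutor(max_workers=4) as executor:
            return {"cut_ids": cut_ids, "data": list(executor.map(_load, cut_ids))}


if __name__ == "__main__":
    # A plain list of ID batches stands in for a ``CutSampler`` here; the
    # loader only calls ``iter()`` on the sampler object.
    loader = LhotseDataLoader(_ToyDataset(), [["c1", "c2"], ["c3"]], num_workers=1)
    for batch in loader:
        print(batch["data"])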