Source code for lhotse.shar.readers.lazy

import random
from pathlib import Path
from typing import (
    Callable,
    Dict,
    Generator,
    List,
    Literal,
    Optional,
    Sequence,
    Tuple,
    Union,
)

from lhotse.cut import Cut
from lhotse.dataset.dataloading import resolve_seed
from lhotse.lazy import (
    Dillable,
    LazyIteratorChain,
    LazyJsonlIterator,
    LazyManifestIterator,
    count_newlines_fast,
)
from lhotse.serialization import extension_contains
from lhotse.shar.readers.tar import TarIterator
from lhotse.utils import Pathlike, exactly_one_not_null, ifnone



class LazySharIterator(Dillable):
    """
    LazySharIterator reads cuts and their corresponding data from multiple shards,
    also recognized as the Lhotse Shar format.
    Each shard is numbered and represented as a collection of one text manifest and
    one or more binary tarfiles.
    Each tarfile contains a single type of data, e.g., recordings, features, or custom fields.

    Given an example directory named ``some_dir``, its expected layout is
    ``some_dir/cuts.000000.jsonl.gz``, ``some_dir/recording.000000.tar``,
    ``some_dir/features.000000.tar``, and then the same names but numbered with
    ``000001``, etc.
    There may also be other files if the cuts have custom data attached to them.

    The main idea behind Lhotse Shar format is to optimize dataloading with sequential reads,
    while keeping the data composition more flexible than e.g. WebDataset tar archives do.
    To achieve this, Lhotse Shar keeps each data type in a separate archive,
    along a single CutSet JSONL manifest.
    This way, the metadata can be investigated without iterating through the binary data.
    The format also allows iteration over a subset of fields, or extension of existing data
    with new fields.

    As you iterate over cuts from ``LazySharIterator``, it keeps a file handle open for the
    JSONL manifest and all of the tar files that correspond to the current shard.
    The tar files are read item by item together, and their binary data is attached to the cuts.
    It can be normally accessed using methods such as ``cut.load_audio()``.

    We can simply load a directory created by :class:`~lhotse.shar.writers.shar.SharWriter`.
    Example::

        >>> cuts = LazySharIterator(in_dir="some_dir")
        ... for cut in cuts:
        ...     print("Cut", cut.id, "has duration of", cut.duration)
        ...     audio = cut.load_audio()
        ...     fbank = cut.load_features()

    :class:`.LazySharIterator` can also be initialized from a dict, where the keys
    indicate fields to be read, and the values point to actual shard locations.
    This is useful when only a subset of data is needed, or it is stored in different
    directories. Example::

        >>> cuts = LazySharIterator({
        ...     "cuts": ["some_dir/cuts.000000.jsonl.gz"],
        ...     "recording": ["another_dir/recording.000000.tar"],
        ...     "features": ["yet_another_dir/features.000000.tar"],
        ... })
        ... for cut in cuts:
        ...     print("Cut", cut.id, "has duration of", cut.duration)
        ...     audio = cut.load_audio()
        ...     fbank = cut.load_features()

    We also support providing shell commands as shard sources, inspired by WebDataset.
    Example::

        >>> cuts = LazySharIterator({
        ...     "cuts": ["pipe:curl https://my.page/cuts.000000.jsonl.gz"],
        ...     "recording": ["pipe:curl https://my.page/recording.000000.tar"],
        ... })
        ... for cut in cuts:
        ...     print("Cut", cut.id, "has duration of", cut.duration)
        ...     audio = cut.load_audio()

    Finally, we allow specifying URLs or cloud storage URIs for the shard sources.
    We defer to the ``smart_open`` library to handle those. Example::

        >>> cuts = LazySharIterator({
        ...     "cuts": ["s3://my-bucket/cuts.000000.jsonl.gz"],
        ...     "recording": ["s3://my-bucket/recording.000000.tar"],
        ... })
        ... for cut in cuts:
        ...     print("Cut", cut.id, "has duration of", cut.duration)
        ...     audio = cut.load_audio()

    :param fields: a dict whose keys specify which fields to load,
        and values are lists of shards (either paths or shell commands).
        The field "cuts" pointing to CutSet shards always has to be present.
    :param in_dir: path to a directory created with ``SharWriter`` with
        all the shards in a single place. Can be used instead of ``fields``.
    :param split_for_dataloading: bool, by default ``False`` which does nothing.
        Setting it to ``True`` is intended for PyTorch training with multiple
        dataloader workers and possibly multiple DDP nodes.
        It results in each node+worker combination receiving a unique subset
        of shards from which to read data, to avoid data duplication.
        This is mutually exclusive with ``seed='randomized'``.
    :param shuffle_shards: bool, by default ``False``. When ``True``, the shards
        are shuffled (in case of multi-node training, the shuffling is the same
        on each node given the same seed).
    :param seed: When ``shuffle_shards`` is ``True``, we use this number to seed the RNG.
        Seed can be set to ``'randomized'``, in which case we expect that the user provided
        :func:`lhotse.dataset.dataloading.worker_init_fn` as DataLoader's ``worker_init_fn``
        argument. It will cause the iterator to shuffle shards differently on each node
        and dataloading worker in PyTorch training. This is mutually exclusive with
        ``split_for_dataloading=True``.
        Seed can be set to ``'trng'`` which, like ``'randomized'``, shuffles the shards
        differently on each iteration, but is not possible to control (and is not reproducible).
        ``trng`` mode is mostly useful when the user has limited control over the training loop
        and may not be able to guarantee the internal Shar epoch is being incremented, but needs
        randomness on each iteration (e.g. useful with PyTorch Lightning).
    :param stateful_shuffle: bool, by default ``True`` (matching the constructor default).
        When ``True``, every time this object is fully iterated, it increments an internal
        epoch counter and triggers shard reshuffling with RNG seeded by ``seed`` + ``epoch``.
        Doesn't have any effect when ``shuffle_shards`` is ``False``.
    :param cut_map_fns: optional sequence of callables that accept cuts and return cuts.
        It's expected to have the same length as the number of shards, so each function
        corresponds to a specific shard.
        It can be used to attach shard-specific custom attributes to cuts.

    A short usage sketch combining the shuffling and dataloading-split options can be
    found after this class definition.

    See also: :class:`~lhotse.shar.writers.shar.SharWriter`
    """

    def __init__(
        self,
        fields: Optional[Dict[str, Sequence[Pathlike]]] = None,
        in_dir: Optional[Pathlike] = None,
        split_for_dataloading: bool = False,
        shuffle_shards: bool = False,
        stateful_shuffle: bool = True,
        seed: Union[int, Literal["randomized"], Literal["trng"]] = 42,
        cut_map_fns: Optional[Sequence[Callable[[Cut], Cut]]] = None,
    ) -> None:
        assert exactly_one_not_null(
            fields, in_dir
        ), "To read Lhotse Shar format, provide either 'in_dir' or 'fields' argument."
        if split_for_dataloading:
            assert seed != "randomized", (
                "Error: seed='randomized' and split_for_dataloading=True are mutually exclusive options "
                "as they would result in data loss."
            )

        self.split_for_dataloading = split_for_dataloading
        self.shuffle_shards = shuffle_shards
        self.stateful_shuffle = stateful_shuffle
        self.seed = seed
        self.epoch = 0
        self._len = None

        if in_dir is not None:
            self._init_from_dir(in_dir)
        else:
            self._init_from_inputs(fields)

        self.num_shards = len(self.streams["cuts"])
        for field in self.fields:
            assert (
                len(self.streams[field]) == self.num_shards
            ), f"Expected {self.num_shards} shards available for field '{field}' but found {len(self.streams[field])}: {self.streams[field]}"

        self.shards = [
            {field: self.streams[field][shard_idx] for field in self.streams}
            for shard_idx in range(self.num_shards)
        ]

        self.cut_map_fns = ifnone(cut_map_fns, [None] * self.num_shards)

    def _init_from_inputs(self, fields: Optional[Dict[str, Sequence[str]]] = None):
        assert (
            "cuts" in fields
        ), "To initialize Shar reader, please provide the value for key 'cuts' in 'fields'."
        self.fields = set(fields.keys())
        self.fields.remove("cuts")
        self.streams = fields

    def _init_from_dir(self, in_dir: Pathlike):
        self.in_dir = Path(in_dir)
        all_paths = list(self.in_dir.glob("*"))
        self.fields = set(p.stem.split(".")[0] for p in all_paths)
        assert "cuts" in self.fields
        self.fields.remove("cuts")

        self.streams = {
            "cuts": sorted(
                p
                for p in all_paths
                if p.name.split(".")[0] == "cuts" and extension_contains(".jsonl", p)
            )
        }
        for field in self.fields:
            self.streams[field] = sorted(
                p for p in all_paths if p.name.split(".")[0] == field
            )

    def _maybe_split_for_dataloading(self, shards: List) -> List:
        from .utils import split_by_node, split_by_worker

        if self.split_for_dataloading:
            return split_by_worker(split_by_node(shards))
        else:
            return shards

    def _maybe_shuffle_shards(self, shards: List) -> List:
        if self.shuffle_shards:
            shards = shards.copy()
            seed = resolve_seed(self.seed)
            if self.stateful_shuffle:
                seed += self.epoch
            random.Random(seed).shuffle(shards)
        return shards

    def __iter__(self):
        shards, map_fns = self.shards, self.cut_map_fns
        shards = self._maybe_shuffle_shards(shards)
        shards = self._maybe_split_for_dataloading(shards)
        if map_fns is not None:
            # The functions also need to be shuffled/split, if present.
            map_fns = self._maybe_shuffle_shards(map_fns)
            map_fns = self._maybe_split_for_dataloading(map_fns)
        for shard, cut_map_fn in zip(shards, map_fns):
            # Iterate over cuts for the current shard
            cuts = LazyManifestIterator(shard["cuts"])
            # Iterate over tarfiles/jsonl containing data for specific fields of each cut
            field_paths = {
                field: path for field, path in shard.items() if field != "cuts"
            }
            # Open every tarfile/jsonl so it's ready for streaming
            field_iters = {
                field: TarIterator(path)
                if extension_contains(".tar", path)
                else _jsonl_tar_adaptor(LazyJsonlIterator(path), field=field)
                for field, path in field_paths.items()
            }
            # *field_data contains all fields for a single cut (recording, features, array, etc.)
            for cut, *field_data in zip(cuts, *field_iters.values()):
                for (field, (maybe_manifest, data_path)) in zip(
                    field_iters.keys(),
                    field_data,
                ):
                    if maybe_manifest is None:
                        continue  # No value available for the current field for this cut.
                    assert (
                        str(data_path.parent / data_path.stem) == cut.id
                    ), f"Mismatched IDs: cut ID is '{cut.id}' but found data with name '{data_path}' for field {field}"
                    setattr(cut, field, maybe_manifest)
                cut.shard_origin = shard["cuts"]
                cut.shar_epoch = self.epoch
                if cut_map_fn is not None:
                    cut = cut_map_fn(cut)
                yield cut
        self.epoch += 1

    def __len__(self) -> int:
        if self._len is None:
            self._len = sum(count_newlines_fast(p) for p in self.streams["cuts"])
        return self._len

    def __add__(self, other) -> "LazyIteratorChain":
        return LazyIteratorChain(self, other)
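

# The function below is an illustrative sketch added for documentation purposes; it is
# not part of the library API. It shows one plausible way to combine the
# ``shuffle_shards``, ``stateful_shuffle``, ``seed`` and ``split_for_dataloading``
# options described in the class docstring for multi-worker / multi-node PyTorch
# training. The directory name "some_dir" is a placeholder for any directory
# produced by SharWriter.
def _example_shuffled_split_iteration(in_dir: Pathlike = "some_dir") -> None:
    cuts = LazySharIterator(
        in_dir=in_dir,
        shuffle_shards=True,  # shuffle the shard order before iterating
        stateful_shuffle=True,  # re-shuffle on every full pass, seeded with `seed + epoch`
        seed=42,  # a fixed seed keeps the shuffle identical across DDP nodes
        split_for_dataloading=True,  # each node+worker combination reads a disjoint shard subset
    )
    for cut in cuts:
        # Binary data is streamed from the current shard's tar files.
        audio = cut.load_audio()
        print(cut.id, cut.duration, audio.shape)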


def _jsonl_tar_adaptor(
    jsonl_iter: LazyJsonlIterator, field: str
) -> Generator[Tuple[Optional[dict], Path], None, None]:
    """
    Used to adapt the iteration output of LazyJsonlIterator to mimic that of TarIterator.
    """
    for item in jsonl_iter:
        # Add extension to make sure Path.stem works OK...
        pseudo_path = Path(f"{item['cut_id']}.dummy")
        if field not in item:
            # We got a placeholder
            item = None
        else:
            item = item[field]
        yield item, pseudo_path
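

# A minimal illustration of the adaptor above, added for clarity; it is hypothetical
# and not part of the library. Although the type hint expects a LazyJsonlIterator,
# the function only iterates over its argument, so a plain list of dicts with a
# "cut_id" key demonstrates the same behaviour. The field name "my_field" is made up.
def _example_jsonl_tar_adaptor() -> None:
    items = [
        {"cut_id": "cut-0001", "my_field": {"some": "manifest"}},
        {"cut_id": "cut-0002"},  # no "my_field" key -> yields (None, pseudo_path)
    ]
    for maybe_manifest, pseudo_path in _jsonl_tar_adaptor(items, field="my_field"):
        # pseudo_path.stem recovers the cut ID, mirroring how TarIterator names its items.
        print(pseudo_path.stem, maybe_manifest)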