Source code for lhotse.dataset.sampling.dynamic

import random
import warnings
from collections import deque
from typing import (
    Any,
    Callable,
    Dict,
    Generator,
    Iterable,
    List,
    Literal,
    Optional,
    Tuple,
    Union,
)

from lhotse import CutSet, Seconds
from lhotse.cut import Cut
from lhotse.dataset.dataloading import resolve_seed
from lhotse.dataset.sampling.base import (
    CutSampler,
    EpochDiagnostics,
    SamplingConstraint,
    SamplingDiagnostics,
    TimeConstraint,
)
from lhotse.utils import ifnone, streaming_shuffle


class DynamicCutSampler(CutSampler):
    """
    A dynamic (streaming) variant of sampler that doesn't stratify the sampled cuts in any way.
    It is a generalization of :class:`~lhotse.dataset.sampling.SimpleCutSampler` and
    :class:`~lhotse.dataset.sampling.CutPairsSampler` in that it allows to jointly iterate
    an arbitrary number of CutSets.

    When input CutSets are opened in lazy mode, this sampler doesn't require reading
    the whole cut set into memory.

    For scenarios such as ASR, VAD, Speaker ID, or TTS training, this class supports single CutSet
    iteration. Example::

        >>> cuts = CutSet(...)
        >>> sampler = DynamicCutSampler(cuts, max_duration=100)
        >>> for batch in sampler:
        ...     assert isinstance(batch, CutSet)

    For other scenarios that require pairs (or triplets, etc.) of utterances, this class supports
    zipping multiple CutSets together. Such scenarios could be voice conversion, speech translation,
    contrastive self-supervised training, etc. Example::

        >>> source_cuts = CutSet(...)
        >>> target_cuts = CutSet(...)
        >>> sampler = DynamicCutSampler(source_cuts, target_cuts, max_duration=100)
        >>> for batch in sampler:
        ...     assert isinstance(batch, tuple)
        ...     assert len(batch) == 2
        ...     assert isinstance(batch[0], CutSet)
        ...     assert isinstance(batch[1], CutSet)

    .. note:: for cut pairs, triplets, etc. the user is responsible for ensuring that the CutSets
        are all sorted so that when iterated over sequentially, the items are matched.
        We take care of preserving the right ordering internally, e.g., when shuffling.
        By default, we check that the cut IDs are matching, but that can be disabled.

    .. caution:: when using :meth:`DynamicCutSampler.filter` to filter some cuts with more than
        one CutSet to sample from, we sample one cut from every CutSet, and expect that all of the cuts
        satisfy the predicate -- otherwise, they are all discarded from being sampled.
    """

    def __init__(
        self,
        *cuts: Iterable,
        max_duration: Optional[Seconds] = None,
        max_cuts: Optional[int] = None,
        constraint: Optional[SamplingConstraint] = None,
        shuffle: bool = False,
        drop_last: bool = False,
        consistent_ids: bool = True,
        shuffle_buffer_size: int = 20000,
        quadratic_duration: Optional[Seconds] = None,
        world_size: Optional[int] = None,
        rank: Optional[int] = None,
        seed: Union[int, Literal["trng", "randomized"]] = 0,
        strict=None,
    ) -> None:
        """
        :param cuts: one or more CutSets (when more than one, will yield tuples of CutSets as mini-batches)
        :param max_duration: The maximum total recording duration from ``cuts``.
            Note: with multiple CutSets, ``max_duration`` constraint applies only to the first CutSet.
        :param max_cuts: The maximum total number of ``cuts`` per batch.
            When only ``max_cuts`` is specified, this sampler yields static batch sizes.
        :param constraint: Provide a :class:`~lhotse.dataset.sampling.base.SamplingConstraint` object
            defining how the sampler decides when a mini-batch is complete. It also affects which
            attribute of the input examples decides the "size" of the example (by default it's ``.duration``).
            Before this parameter was introduced, Lhotse samplers used
            :class:`~lhotse.dataset.sampling.base.TimeConstraint` implicitly.
            Introduced in Lhotse v1.22.0.
        :param shuffle: When ``True``, the cuts will be shuffled dynamically with
            a reservoir-sampling-based algorithm.
            Convenient when mini-batch loop is inside an outer epoch-level loop, e.g.:
            `for epoch in range(10): for batch in dataset: ...` as every epoch will see a
            different cuts order.
        :param drop_last: When ``True``, we will drop all incomplete batches.
            A batch is considered incomplete if the input data was depleted before
            hitting the constraint such as max_duration, max_cuts, etc.
        :param consistent_ids: Only affects processing of multiple CutSets.
            When ``True``, at each sampling step we check cuts from all CutSets have the same ID
            (i.e., the first cut from every CutSet should have the same ID, same for the second, third, etc.).
        :param shuffle_buffer_size: How many cuts (or cut pairs, triplets) are being held in memory
            in a buffer used for streaming shuffling. Larger number means better randomness at the cost
            of higher memory usage.
        :param quadratic_duration: When set, it adds an extra penalty that's quadratic in size w.r.t.
            a cuts duration. This helps get a more even GPU utilization across different input lengths
            when models have quadratic input complexity. Set between 15 and 40 for transformers.
        :param world_size: Total number of distributed nodes. We will try to infer it by default.
        :param rank: Index of distributed node. We will try to infer it by default.
        :param seed: Random seed used to consistently shuffle the dataset across different processes.
        :param strict: Deprecated and ignored; passing any non-``None`` value only emits
            a ``DeprecationWarning``.
        """
        super().__init__(
            drop_last=drop_last, world_size=world_size, rank=rank, seed=seed
        )

        # Only CutSet inputs can be checked for laziness; other iterables are accepted as-is.
        if not all(cs.is_lazy for cs in cuts if isinstance(cs, CutSet)):
            warnings.warn(
                "You are using DynamicCutSampler with an eagerly read CutSet. "
                "You won't see any memory/speed benefits with that setup. "
                "Use e.g. 'CutSet.from_jsonl_lazy' to read the CutSet lazily."
            )

        self.cuts = cuts
        self.max_duration = max_duration
        self.max_cuts = max_cuts
        self.constraint = constraint
        self.shuffle = shuffle
        self.consistent_ids = consistent_ids
        self.shuffle_buffer_size = shuffle_buffer_size
        self.quadratic_duration = quadratic_duration
        # Created in __iter__(), once the seed and the epoch are known.
        self.rng = None

        if strict is not None:
            warnings.warn(
                "In Lhotse v1.4 all samplers act as if 'strict=True'. "
                "Sampler's argument 'strict' will be removed in a future Lhotse release.",
                category=DeprecationWarning,
            )

    def state_dict(self) -> Dict[str, Any]:
        """
        Return the parent sampler's state extended with this sampler's batching settings.

        :raises AssertionError: when a custom ``constraint`` object is in use.
        """
        assert (
            self.constraint is None
        ), "state_dict() is not supported with samplers that use a custom constraint."
        sd = super().state_dict()
        sd.update(
            {
                "max_duration": self.max_duration,
                "max_cuts": self.max_cuts,
                "consistent_ids": self.consistent_ids,
                "shuffle_buffer_size": self.shuffle_buffer_size,
                "quadratic_duration": self.quadratic_duration,
            }
        )
        return sd

    def load_state_dict(self, sd: Dict[str, Any]) -> None:
        """
        Restore the settings saved by :meth:`state_dict` and fast-forward the sampler
        to the batch position recorded in the restored diagnostics.

        .. note:: the keys owned by this class are popped from ``sd`` (i.e., ``sd`` is
            modified in place) before the remainder is handed to the parent class.
        """
        self.max_duration = sd.pop("max_duration")
        self.max_cuts = sd.pop("max_cuts")
        self.consistent_ids = sd.pop("consistent_ids")
        self.shuffle_buffer_size = sd.pop("shuffle_buffer_size")
        self.quadratic_duration = sd.pop("quadratic_duration")
        sd.pop("strict", None)  # backward compatibility
        super().load_state_dict(sd)
        self._fast_forward()

    def _fast_forward(self):
        """Re-iterate the current epoch up to the number of batches recorded in the diagnostics."""
        current_epoch = self.diagnostics.current_epoch
        num_batches_to_iter = self.diagnostics.current_epoch_stats.total_batches

        # Set the right epoch
        self.set_epoch(current_epoch)
        # Reset diagnostics for this epoch as we're about to re-iterate
        self.diagnostics.stats_per_epoch[current_epoch] = EpochDiagnostics(
            epoch=current_epoch
        )

        # The flag has to be off here, otherwise __iter__() would return early
        # without (re-)building the iteration pipeline.
        self._just_restored_state = False
        iter(self)
        for _ in range(num_batches_to_iter):
            next(self)
        # With the flag on, the next __iter__() call keeps the fast-forwarded position.
        self._just_restored_state = True

    def __iter__(self) -> "DynamicCutSampler":
        """Build the shuffle -> filter -> batch pipeline for the current epoch and return ``self``."""
        if self._just_restored_state:
            return self
        # Why reset the current epoch?
        # Either we are iterating the epoch for the first time and it's a no-op,
        # or we are iterating the same epoch again, in which case setting more steps
        # than are actually available per epoch would have broken the checkpoint restoration.
        self.diagnostics.reset_current_epoch()
        seed = resolve_seed(self.seed)
        self.rng = random.Random(seed + self.epoch)
        # Initiate iteration
        self.cuts_iter = [iter(cs) for cs in self.cuts]
        # Optionally shuffle
        if self.shuffle:
            self.cuts_iter = [
                # Important -- every shuffler has a copy of RNG seeded in the same way,
                # so that they are reproducible.
                streaming_shuffle(
                    cs,
                    rng=random.Random(seed + self.epoch),
                    bufsize=self.shuffle_buffer_size,
                )
                for cs in self.cuts_iter
            ]
        # Apply filter predicate
        # (zip() yields a tuple with one cut per input CutSet, even for a single CutSet).
        self.cuts_iter = Filter(
            iterator=zip(*self.cuts_iter),
            predicate=lambda tpl: all(self._filter_fn(c) for c in tpl),
            diagnostics=self.diagnostics,
        )
        # Convert Iterable[Cut] -> Iterable[CutSet]
        self.cuts_iter = DurationBatcher(
            self.cuts_iter,
            max_duration=self.max_duration,
            max_cuts=self.max_cuts,
            constraint=self.constraint,
            drop_last=self.drop_last,
            quadratic_duration=self.quadratic_duration,
            diagnostics=self.diagnostics,
        )
        self.cuts_iter = iter(self.cuts_iter)
        return self

    def _next_batch(self) -> Union[CutSet, Tuple[CutSet, ...]]:
        """
        Fetch the next mini-batch from the pipeline built in :meth:`__iter__`.

        :raises StopIteration: when the underlying pipeline is exhausted.
        :raises AssertionError: when ``consistent_ids`` is set and the cuts sampled
            at the same position from different CutSets have different IDs.
        """
        batch = next(self.cuts_iter)
        if self.consistent_ids and isinstance(batch, tuple):
            # Walk the CutSets in lockstep; each step holds one cut from every CutSet.
            for cuts in zip(*batch):
                expected_id = cuts[0].id
                assert all(c.id == expected_id for c in cuts[1:]), (
                    f"The input CutSet are not sorted by cut ID in the same way. "
                    f"We sampled the following mismatched cut IDs: {', '.join(c.id for c in cuts)}. "
                    f"If this is expected, pass the argument 'consistent_ids=False' to DynamicCutSampler."
                )
        return batch

    @property
    def remaining_duration(self) -> Optional[float]:
        """Always ``None`` for this sampler."""
        return None

    @property
    def remaining_cuts(self) -> Optional[int]:
        """Always ``None`` for this sampler."""
        return None

    @property
    def num_cuts(self) -> Optional[int]:
        """Always ``None`` for this sampler."""
        return None
# Note: this class is a subset of SimpleCutSampler and is "datapipes" ready.
class DurationBatcher:
    """
    Groups a stream of cuts (or tuples of cuts) into mini-batches.

    Items are collected until the sampling constraint reports that the batch is
    close to exceeding its limit. When the items are tuples, only the first element
    of each tuple is passed to the constraint.
    """

    def __init__(
        self,
        datapipe: Iterable[Union[Cut, Tuple[Cut]]],
        max_duration: Optional[Seconds] = None,
        max_cuts: Optional[int] = None,
        constraint: Optional[SamplingConstraint] = None,
        drop_last: bool = False,
        quadratic_duration: Optional[Seconds] = None,
        diagnostics: Optional[SamplingDiagnostics] = None,
    ) -> None:
        """
        :param datapipe: an iterable of cuts or of tuples of cuts.
        :param max_duration: used to build a ``TimeConstraint`` when ``constraint`` is not given.
        :param max_cuts: used to build a ``TimeConstraint`` when ``constraint`` is not given.
        :param constraint: a ready constraint object; mutually exclusive with
            ``max_duration`` and ``max_cuts``.
        :param drop_last: when ``True``, a partial batch left at the end of the data is discarded
            (unless the constraint reports it is close to exceeding).
        :param quadratic_duration: passed to ``TimeConstraint`` when ``constraint`` is not given.
        :param diagnostics: statistics tracker notified about discarded cuts;
            a new ``SamplingDiagnostics`` is created when omitted.
        :raises AssertionError: when the constraint arguments are inconsistent
            (see :func:`check_constraint`).
        """
        self.datapipe = datapipe
        self.reuse_cuts_buffer = deque()
        self.drop_last = drop_last
        self.diagnostics = ifnone(diagnostics, SamplingDiagnostics())
        check_constraint(constraint, max_duration, max_cuts)
        if constraint is not None:
            self.constraint = constraint
        else:
            self.constraint = TimeConstraint(
                max_duration=max_duration,
                max_cuts=max_cuts,
                quadratic_duration=quadratic_duration,
            )

    def __iter__(self) -> Generator[Union[CutSet, Tuple[CutSet]], None, None]:
        """Yield mini-batches until the underlying datapipe is exhausted."""
        self.cuts_iter = iter(self.datapipe)
        try:
            while True:
                yield self._collect_batch()
        except StopIteration:
            # _collect_batch() signals the end of data with StopIteration; it has to be
            # caught here so that it does not propagate out of the generator body.
            pass
        self.cuts_iter = None

    @staticmethod
    def _detuplify(
        cuts: List[Union[Cut, Tuple[Cut]]]
    ) -> Union[CutSet, Tuple[CutSet]]:
        """Helper to do the right thing whether we sampled single cuts or cut tuples."""
        if not isinstance(cuts[0], tuple):
            return CutSet.from_cuts(cuts)
        if len(cuts[0]) == 1:
            # 1-tuples are unwrapped so that the caller gets a plain CutSet.
            return CutSet.from_cuts(cs[0] for cs in cuts)
        # Transpose: a list of N-tuples becomes a tuple of N CutSets.
        return tuple(CutSet.from_cuts(cs) for cs in zip(*cuts))

    def _collect_batch(self) -> Union[CutSet, Tuple[CutSet]]:
        """
        Collect a single mini-batch from ``self.cuts_iter``.

        :return: a CutSet, or a tuple of CutSets when the sampled items are tuples
            with more than one element.
        :raises StopIteration: when there is no more data and no partial batch to return.
        """
        self.constraint.reset()
        cuts = []
        while True:
            # Check that we have not reached the end of the dataset.
            try:
                # If this doesn't raise (typical case), it's not the end: keep processing.
                next_cut_or_tpl = next(self.cuts_iter)
            except StopIteration:
                # No more cuts to sample from: if we have a partial batch,
                # we may output it, unless the user requested to drop it.
                # We also check if the batch is "almost there" to override drop_last.
                if cuts and (
                    not self.drop_last or self.constraint.close_to_exceeding()
                ):
                    # We have a partial batch and we can return it.
                    return self._detuplify(cuts)
                else:
                    # There is nothing more to return or it's discarded:
                    # signal the iteration code to stop.
                    try:
                        self.diagnostics.discard(cuts)
                    except AttributeError:  # accounts for cuts being a tuple
                        # NOTE(review): cuts[0] is only the first sampled tuple, so the
                        # remaining tuples of the dropped batch appear not to be counted
                        # -- confirm against SamplingDiagnostics.discard().
                        self.diagnostics.discard(cuts[0])
                    raise StopIteration()

            # Track the duration/frames/etc. constraints.
            cuts.append(next_cut_or_tpl)
            self.constraint.add(
                next_cut_or_tpl[0]
                if isinstance(next_cut_or_tpl, tuple)
                else next_cut_or_tpl
            )

            # Did we reach the limits set by the constraint (e.g. max_duration, max_cuts)?
            if self.constraint.close_to_exceeding():
                # Yes. Finish sampling this batch.
                if self.constraint.exceeded() and len(cuts) == 1:
                    warnings.warn(
                        "We have exceeded the max_duration constraint during sampling but have only 1 cut. "
                        "This is likely because max_duration was set to a very low value ~10s, "
                        "or you're using a CutSet with very long cuts (e.g. 100s of seconds long)."
                    )
                break

        return self._detuplify(cuts)


class Filter(Iterable):
    """
    A wrapper over an iterable that enables lazy filtering.
    It works like Python's `filter` built-in by applying the filter predicate
    to each element and yielding it further if predicate returned ``True``.

    This variant additionally tracks the number of discarded items and updates
    the sampling statistics.
    """

    def __init__(
        self,
        iterator: Iterable,
        predicate: Callable[[Cut], bool],
        diagnostics: Optional[SamplingDiagnostics] = None,
    ) -> None:
        """
        :param iterator: the iterable to be filtered.
        :param predicate: called with each item; items for which it returns
            a falsy value are discarded.
        :param diagnostics: statistics tracker notified about every discarded item;
            a new ``SamplingDiagnostics`` is created when omitted.
        :raises AssertionError: when ``predicate`` is not callable.
        """
        self.iterator = iterator
        self.predicate = predicate
        self.diagnostics = ifnone(diagnostics, SamplingDiagnostics())

        assert callable(
            self.predicate
        ), f"Filter: 'predicate' arg must be callable (got {predicate})."

    def __iter__(self) -> Generator[Any, None, None]:
        """Yield the items accepted by the predicate; report the others as discarded."""
        for item in self.iterator:
            if self.predicate(item):
                yield item
            else:
                self.diagnostics.discard(item)


def check_constraint(
    constraint: Optional[SamplingConstraint],
    max_duration: Optional[Seconds],
    max_cuts: Optional[int],
) -> None:
    """
    Validate that the batch size limit is specified in exactly one way: either with
    a ``constraint`` object, or with ``max_duration`` and/or ``max_cuts``.

    :raises AssertionError: when ``constraint`` is combined with ``max_duration``/``max_cuts``,
        or when none of the three arguments is provided.
    """
    if constraint is not None:
        assert (
            max_duration is None and max_cuts is None
        ), "Cannot specify both constraint= and max_duration=/max_cuts="
    else:
        assert (
            max_duration is not None or max_cuts is not None
        ), "At least one of max_duration= or max_cuts= has to be defined (or provide constraint=)."