# Source code for lhotse.dataset.sampling.dynamic_bucketing

import random
import warnings
from bisect import bisect_right
from collections import deque
from itertools import islice
from typing import Any, Deque, Dict, Generator, Iterable, List, Optional, Tuple, Union

import numpy as np

from lhotse import CutSet, Seconds
from lhotse.cut import Cut
from lhotse.dataset.sampling.base import (
    CutSampler,
    EpochDiagnostics,
    SamplingDiagnostics,
    TimeConstraint,
)
from lhotse.dataset.sampling.dynamic import DurationBatcher, Filter
from lhotse.utils import ifnone, streaming_shuffle


class DynamicBucketingSampler(CutSampler):
    """
    A dynamic (streaming) variant of :class:`~lhotse.dataset.sampling.bucketing.BucketingSampler`,
    that doesn't require reading the whole cut set into memory.

    The basic idea is to sample N (e.g. ~10k) cuts and estimate the boundary durations for buckets.
    Then, we maintain a buffer of M cuts (stored separately in K buckets) and every time we sample a batch,
    we consume the input cut iterable for the same amount of cuts.
    The memory consumption is limited by M at all times.

    For scenarios such as ASR, VAD, Speaker ID, or TTS training, this class supports single CutSet
    iteration. Example::

        >>> cuts = CutSet(...)
        >>> sampler = DynamicBucketingSampler(cuts, max_duration=100)
        >>> for batch in sampler:
        ...     assert isinstance(batch, CutSet)

    For other scenarios that require pairs (or triplets, etc.) of utterances, this class supports
    zipping multiple CutSets together. Such scenarios could be voice conversion, speech translation,
    contrastive self-supervised training, etc. Example::

        >>> source_cuts = CutSet(...)
        >>> target_cuts = CutSet(...)
        >>> sampler = DynamicBucketingSampler(source_cuts, target_cuts, max_duration=100)
        >>> for batch in sampler:
        ...     assert isinstance(batch, tuple)
        ...     assert len(batch) == 2
        ...     assert isinstance(batch[0], CutSet)
        ...     assert isinstance(batch[1], CutSet)

    .. note:: for cut pairs, triplets, etc. the user is responsible for ensuring that the CutSets
        are all sorted so that when iterated over sequentially, the items are matched.
        We take care of preserving the right ordering internally, e.g., when shuffling.
        By default, we check that the cut IDs are matching, but that can be disabled.

    .. caution:: when using :meth:`DynamicBucketingSampler.filter` to filter some cuts with more than
        one CutSet to sample from, we sample one cut from every CutSet, and expect that all of the cuts
        satisfy the predicate -- otherwise, they are all discarded from being sampled.
    """

    def __init__(
        self,
        *cuts: Iterable[Cut],
        max_duration: Seconds,
        max_cuts: Optional[int] = None,
        num_buckets: int = 10,
        shuffle: bool = False,
        drop_last: bool = False,
        consistent_ids: bool = True,
        duration_bins: Optional[List[Seconds]] = None,
        num_cuts_for_bins_estimate: int = 10000,
        buffer_size: int = 10000,
        shuffle_buffer_size: int = 20000,
        quadratic_duration: Optional[Seconds] = None,
        world_size: Optional[int] = None,
        rank: Optional[int] = None,
        seed: int = 0,
        strict=None,
    ) -> None:
        """
        :param cuts: one or more CutSets (when more than one, will yield tuples of CutSets as mini-batches)
        :param max_duration: The maximum total recording duration from ``cuts``.
            Note: with multiple CutSets, ``max_duration`` constraint applies only to the first CutSet.
        :param max_cuts: The maximum total number of ``cuts`` per batch.
            When only ``max_cuts`` is specified, this sampler yields static batch sizes.
        :param num_buckets: how many buckets to create.
        :param shuffle: When ``True``, the cuts will be shuffled dynamically with
            a reservoir-sampling-based algorithm.
            Convenient when mini-batch loop is inside an outer epoch-level loop, e.g.:
            `for epoch in range(10): for batch in dataset: ...` as every epoch will see a
            different cuts order.
        :param drop_last: When ``True``, we will drop all incomplete batches.
            A batch is considered incomplete if it depleted a bucket before
            hitting the constraint such as max_duration, max_cuts, etc.
        :param consistent_ids: Only affects processing of multiple CutSets.
            When ``True``, at each sampling step we check cuts from all CutSets have the same ID
            (i.e., the first cut from every CutSet should have the same ID, same for the second, third, etc.).
        :param duration_bins: A sorted sequence of floats (seconds) with exactly ``num_buckets - 1``
            elements; when provided, we'll skip the initial estimation of bucket duration bins
            (useful to speed-up the launching of experiments). It is stored internally as a list.
        :param num_cuts_for_bins_estimate: We will draw this many cuts to estimate the duration bins
            for creating similar-duration buckets.
            Larger number means a better estimate to the data distribution, possibly at a longer init cost.
        :param buffer_size: How many cuts (or cut pairs, triplets) we hold at any time across all
            of the buckets. Increasing ``max_duration`` (batch_size) or ``num_buckets`` might require
            increasing this number. It will result in larger memory usage.
        :param shuffle_buffer_size: How many cuts (or cut pairs, triplets) are held in memory
            in the buffer used for streaming shuffling. Larger number means better randomness at the cost
            of higher memory usage.
        :param quadratic_duration: When set, it adds an extra penalty that's quadratic in size w.r.t.
            a cut's duration. This helps get a more even GPU utilization across different input lengths
            when models have quadratic input complexity. Set between 15 and 40 for transformers.
        :param world_size: Total number of distributed nodes. We will try to infer it by default.
        :param rank: Index of distributed node. We will try to infer it by default.
        :param seed: Random seed used to consistently shuffle the dataset across different processes.
        :param strict: Deprecated and ignored; passing any non-``None`` value only emits a
            ``DeprecationWarning``.
        """
        super().__init__(
            drop_last=drop_last, world_size=world_size, rank=rank, seed=seed
        )
        if not all(cs.is_lazy for cs in cuts if isinstance(cs, CutSet)):
            warnings.warn(
                "You are using DynamicBucketingSampler with an eagerly read CutSet. "
                "You won't see any memory/speed benefits with that setup. "
                "Either use 'CutSet.from_jsonl_lazy' to read the CutSet lazily, or use a BucketingSampler instead."
            )
        self.cuts = cuts
        self.max_duration = max_duration
        self.max_cuts = max_cuts
        self.shuffle = shuffle
        self.consistent_ids = consistent_ids
        self.num_cuts_for_bins_estimate = num_cuts_for_bins_estimate
        self.buffer_size = buffer_size
        self.shuffle_buffer_size = shuffle_buffer_size
        self.quadratic_duration = quadratic_duration
        # Re-created at the start of every epoch in __iter__.
        self.rng = None
        assert any(
            v is not None for v in (self.max_duration, self.max_cuts)
        ), "At least one of max_duration or max_cuts has to be set."

        if strict is not None:
            warnings.warn(
                "In Lhotse v1.4 all samplers act as if 'strict=True'. "
                "Sampler's argument 'strict' will be removed in a future Lhotse release.",
                category=DeprecationWarning,
            )

        # Bins are estimated from the first CutSet only (the others are assumed to be
        # aligned with it); shuffle first so the estimate isn't biased by input order.
        if self.shuffle:
            cuts_for_bins_estimate = streaming_shuffle(
                iter(self.cuts[0]),
                rng=random.Random(self.seed),
                bufsize=self.shuffle_buffer_size,
            )
        else:
            cuts_for_bins_estimate = self.cuts[0]

        if duration_bins is not None:
            assert len(duration_bins) == num_buckets - 1, (
                f"num_buckets=={num_buckets} but len(duration_bins)=={len(duration_bins)} "
                f"(bins are the boundaries, it should be one less than the number of buckets)."
            )
            assert list(duration_bins) == sorted(
                duration_bins
            ), "Duration bins must be sorted ascendingly."
            # Normalize to a list: the bucketer compares bins against ``sorted(...)``,
            # which fails for tuples and numpy arrays even when they are sorted.
            self.duration_bins = list(duration_bins)
        else:
            self.duration_bins = estimate_duration_buckets(
                islice(cuts_for_bins_estimate, num_cuts_for_bins_estimate),
                num_buckets=num_buckets,
            )

    def state_dict(self) -> Dict[str, Any]:
        """Return the base sampler state extended with this sampler's settings."""
        sd = super().state_dict()
        sd.update(
            {
                "max_duration": self.max_duration,
                "max_cuts": self.max_cuts,
                "consistent_ids": self.consistent_ids,
                "buffer_size": self.buffer_size,
                "num_cuts_for_bins_estimate": self.num_cuts_for_bins_estimate,
                "shuffle_buffer_size": self.shuffle_buffer_size,
                "quadratic_duration": self.quadratic_duration,
            }
        )
        return sd

    def load_state_dict(self, sd: Dict[str, Any]) -> None:
        """
        Restore settings from ``sd`` (consuming the keys it owns), hand the rest to the
        base class, then replay the current epoch up to the checkpointed batch.
        """
        self.max_duration = sd.pop("max_duration")
        self.max_cuts = sd.pop("max_cuts")
        self.consistent_ids = sd.pop("consistent_ids")
        self.num_cuts_for_bins_estimate = sd.pop("num_cuts_for_bins_estimate")
        self.buffer_size = sd.pop("buffer_size")
        self.shuffle_buffer_size = sd.pop("shuffle_buffer_size")
        self.quadratic_duration = sd.pop("quadratic_duration", None)
        sd.pop("strict", None)  # backward compatibility
        super().load_state_dict(sd)
        self._fast_forward()

    def _fast_forward(self):
        """
        Re-iterate the current epoch and discard as many batches as the restored
        diagnostics report, so the next ``next(self)`` continues where the checkpoint left off.
        """
        current_epoch = self.diagnostics.current_epoch
        num_batches_to_iter = self.diagnostics.current_epoch_stats.total_batches

        # Set the right epoch
        self.set_epoch(current_epoch)
        # Reset diagnostics for this epoch as we're about to re-iterate
        self.diagnostics.stats_per_epoch[current_epoch] = EpochDiagnostics(
            epoch=current_epoch
        )

        self._just_restored_state = False
        iter(self)
        for _ in range(num_batches_to_iter):
            next(self)
        # Make the next __iter__ call a no-op so the fast-forwarded position is kept.
        self._just_restored_state = True

    def __iter__(self) -> "DynamicBucketingSampler":
        if self._just_restored_state:
            return self
        self.rng = random.Random(self.seed + self.epoch)
        # Why reset the current epoch?
        # Either we are iterating the epoch for the first time and it's a no-op,
        # or we are iterating the same epoch again, in which case setting more steps
        # than are actually available per epoch would have broken the checkpoint restoration.
        self.diagnostics.reset_current_epoch()
        # Initiate iteration
        self.cuts_iter = [iter(cs) for cs in self.cuts]
        # Optionally shuffle
        if self.shuffle:
            self.cuts_iter = [
                # Important -- every shuffler has a copy of RNG seeded in the same way,
                # so that they are reproducible.
                streaming_shuffle(
                    cs,
                    rng=random.Random(self.seed + self.epoch),
                    bufsize=self.shuffle_buffer_size,
                )
                for cs in self.cuts_iter
            ]
        # Apply filter predicate; zip() always yields tuples (1-tuples for a single CutSet).
        self.cuts_iter = Filter(
            iterator=zip(*self.cuts_iter),
            predicate=lambda tpl: all(self._filter_fn(c) for c in tpl),
            diagnostics=self.diagnostics,
        )
        # Convert Iterable[Cut] -> Iterable[CutSet]
        self.cuts_iter = DynamicBucketer(
            self.cuts_iter,
            duration_bins=self.duration_bins,
            max_duration=self.max_duration,
            max_cuts=self.max_cuts,
            drop_last=self.drop_last,
            buffer_size=self.buffer_size,
            quadratic_duration=self.quadratic_duration,
            rng=self.rng,
            diagnostics=self.diagnostics,
        )
        self.cuts_iter = iter(self.cuts_iter)
        return self

    def _next_batch(self) -> Union[CutSet, Tuple[CutSet]]:
        """Fetch the next batch, optionally asserting that zipped CutSets have matching cut IDs."""
        batch = next(self.cuts_iter)
        if self.consistent_ids and isinstance(batch, tuple):
            for cuts in zip(*batch):
                expected_id = cuts[0].id
                assert all(c.id == expected_id for c in cuts[1:]), (
                    f"The input CutSet are not sorted by cut ID in the same way. "
                    f"We sampled the following mismatched cut IDs: {', '.join(c.id for c in cuts)}. "
                    f"If this is expected, pass the argument 'consistent_ids=False' to DynamicBucketingSampler."
                )
        return batch

    # The input is streamed, so these totals are unknown.
    @property
    def remaining_duration(self) -> Optional[float]:
        return None

    @property
    def remaining_cuts(self) -> Optional[int]:
        return None

    @property
    def num_cuts(self) -> Optional[int]:
        return None
def estimate_duration_buckets(cuts: Iterable[Cut], num_buckets: int) -> List[Seconds]:
    """
    Given an iterable of cuts and a desired number of buckets, select duration values
    that should start each bucket.

    Durations are sorted and greedily accumulated; whenever the running total exceeds
    ``sum(durations) / num_buckets``, the next duration becomes a boundary.
    The returned list, ``bins``, therefore has *at most* ``num_buckets - 1`` elements
    (it can be shorter, e.g. when a few long cuts dominate the total, and boundary values
    may repeat when many cuts share the same duration).

    The first bucket should contain cuts with duration ``0 <= d < bins[0]``;
    the last bucket should contain cuts with duration ``bins[-1] <= d < float("inf")``,
    ``i``-th bucket should contain cuts with duration ``bins[i - 1] <= d < bins[i]``.

    :param cuts: an iterable of :class:`lhotse.cut.Cut`.
    :param num_buckets: desired number of buckets.
    :return: a list of boundary duration values (floats).
    """
    assert num_buckets > 1

    durs = np.array([c.duration for c in cuts])
    durs.sort()
    assert num_buckets <= durs.shape[0], (
        f"The number of buckets ({num_buckets}) must be smaller than "
        f"or equal to the number of cuts ({durs.shape[0]})."
    )
    bucket_duration = durs.sum() / num_buckets

    bins = []
    tot = 0.0
    for dur in durs:
        if tot > bucket_duration:
            # Plain Python floats (equal in value to the numpy scalars) keep the bins
            # easy to print and to pass back as ``duration_bins``.
            bins.append(float(dur))
            tot = 0.0
        tot += dur

    return bins


class DynamicBucketer:
    """
    Distributes a stream of cuts (or tuples of cuts) into duration buckets and yields
    mini-batches, each drawn from a bucket picked at random among the "ready" ones.
    """

    def __init__(
        self,
        cuts: Iterable[Union[Cut, Tuple[Cut]]],
        duration_bins: List[Seconds],
        max_duration: Optional[Seconds],
        max_cuts: Optional[int] = None,
        drop_last: bool = False,
        buffer_size: int = 10000,
        quadratic_duration: Optional[Seconds] = None,
        rng: random.Random = None,
        diagnostics: Optional[SamplingDiagnostics] = None,
    ) -> None:
        """
        :param cuts: iterable of cuts, or of tuples of cuts; for tuples, the duration of the
            first element decides the bucket.
        :param duration_bins: sorted bucket boundaries; there are ``len(duration_bins) + 1`` buckets.
        :param max_duration: batch duration limit forwarded to ``TimeConstraint`` and
            ``DurationBatcher``; may be ``None`` when ``max_cuts`` is used instead.
        :param max_cuts: batch size limit forwarded to ``TimeConstraint`` and ``DurationBatcher``.
        :param drop_last: when ``True``, iteration stops as soon as no bucket is ready,
            instead of yielding partial batches from the remaining non-empty buckets.
        :param buffer_size: number of items read into the buckets before the first batch;
            afterwards, as many items are read as each yielded batch consumed.
        :param quadratic_duration: forwarded to ``TimeConstraint`` and ``DurationBatcher``.
        :param rng: used to choose the bucket for each batch; defaults to an unseeded ``random.Random()``.
        :param diagnostics: forwarded to ``DurationBatcher``; a fresh ``SamplingDiagnostics`` by default.
        """
        self.cuts = cuts
        # Accept any sequence (list, tuple, numpy array) of boundaries.
        self.duration_bins = list(duration_bins)
        self.max_duration = max_duration
        self.max_cuts = max_cuts
        self.drop_last = drop_last
        self.buffer_size = buffer_size
        self.quadratic_duration = quadratic_duration
        self.diagnostics = ifnone(diagnostics, SamplingDiagnostics())
        if rng is None:
            rng = random.Random()
        self.rng = rng

        assert self.duration_bins == sorted(self.duration_bins), (
            f"Argument list for 'duration_bins' is expected to be in "
            f"sorted order (got: {duration_bins})."
        )

        # A heuristic diagnostic first, for finding the right settings.
        # It is only meaningful with a duration limit (comparing against None would raise
        # TypeError for max_cuts-only setups) and with at least one boundary
        # (the mean of no bins is undefined).
        if max_duration is not None and self.duration_bins:
            mean_duration = np.mean(self.duration_bins)
            expected_buffer_duration = buffer_size * mean_duration
            expected_bucket_duration = expected_buffer_duration / (
                len(self.duration_bins) + 1
            )
            if expected_bucket_duration < max_duration:
                warnings.warn(
                    f"Your 'buffer_size' setting of {buffer_size} might be too low to satisfy "
                    f"a 'max_duration' of {max_duration} (given our best guess)."
                )

        # Init: create empty buckets (note: `num_buckets = len(duration_bins) + 1`).
        self.buckets: List[Deque[Union[Cut, Tuple[Cut]]]] = [
            deque() for _ in range(len(self.duration_bins) + 1)
        ]

    def __iter__(self) -> Generator[Union[CutSet, Tuple[CutSet]], None, None]:
        # Start every pass from empty buckets; otherwise cuts left over from a previous
        # pass (e.g. held back by ``drop_last=True``) would be yielded again next to the
        # same cuts re-read from ``self.cuts``.
        self.buckets = [deque() for _ in range(len(self.duration_bins) + 1)]

        # Init: sample `buffer_size` cuts and assign them to the right buckets.
        self.cuts_iter = iter(self.cuts)
        self._collect_cuts_in_buckets(self.buffer_size)

        def is_ready(bucket: Deque[Union[Cut, Tuple[Cut]]]) -> bool:
            """True once the bucket's cuts make ``TimeConstraint.close_to_exceeding()`` true."""
            tot = TimeConstraint(
                max_duration=self.max_duration,
                max_cuts=self.max_cuts,
                quadratic_duration=self.quadratic_duration,
            )
            for c in bucket:
                tot.add(c[0] if isinstance(c, tuple) else c)
                if tot.close_to_exceeding():
                    return True
            return False

        # The iteration code starts here.
        # On each step we're sampling a new batch.
        try:
            while True:
                ready_buckets = [b for b in self.buckets if is_ready(b)]
                if not ready_buckets:
                    # No bucket has enough data to yield for the last full batch.
                    non_empty_buckets = [b for b in self.buckets if b]
                    if self.drop_last or not non_empty_buckets:
                        # Either the user requested only full batches, or we have nothing left.
                        return
                    # Sample from partial batches that are left.
                    ready_buckets = non_empty_buckets
                # Choose a bucket to sample from.
                sampling_bucket = self.rng.choice(ready_buckets)
                # Sample one batch from that bucket and yield it to the caller.
                batcher = DurationBatcher(
                    sampling_bucket,
                    max_duration=self.max_duration,
                    max_cuts=self.max_cuts,
                    quadratic_duration=self.quadratic_duration,
                    diagnostics=self.diagnostics,
                )
                batch = next(iter(batcher))
                batch_size = len(batch[0]) if isinstance(batch, tuple) else len(batch)
                yield batch
                # Remove sampled cuts from the bucket.
                for _ in range(batch_size):
                    sampling_bucket.popleft()
                # Fetch new cuts and add them to appropriate buckets.
                self._collect_cuts_in_buckets(batch_size)
        except StopIteration:
            # The batcher could not produce a batch: treat it as the end of data.
            pass
        finally:
            # Cleanup -- also runs when the consumer stops iterating early.
            self.cuts_iter = None

    def _collect_cuts_in_buckets(self, n_cuts: int) -> None:
        """Move up to ``n_cuts`` items from ``self.cuts_iter`` into their buckets; stops early when exhausted."""
        for cuts in islice(self.cuts_iter, n_cuts):
            duration = cuts[0].duration if isinstance(cuts, tuple) else cuts.duration
            bucket_idx = bisect_right(self.duration_bins, duration)
            self.buckets[bucket_idx].append(cuts)