# Source code for lhotse.dataset.sampling.dynamic_bucketing

import random
import warnings
from bisect import bisect_right
from collections import deque
from itertools import islice
from typing import (
    Any,
    Deque,
    Dict,
    Generator,
    Iterable,
    List,
    Literal,
    Optional,
    Sequence,
    Tuple,
    Union,
)

import numpy as np

from lhotse import CutSet, Seconds
from lhotse.cut import Cut
from lhotse.dataset.dataloading import resolve_seed
from lhotse.dataset.sampling.base import (
    CutSampler,
    EpochDiagnostics,
    SamplingConstraint,
    SamplingDiagnostics,
    TimeConstraint,
)
from lhotse.dataset.sampling.dynamic import DurationBatcher, Filter, check_constraint
from lhotse.utils import ifnone


class DynamicBucketingSampler(CutSampler):
    """
    A dynamic (streaming) variant of :class:`~lhotse.dataset.sampling.bucketing.BucketingSampler`,
    that doesn't require reading the whole cut set into memory.

    The basic idea is to sample N (e.g. ~10k) cuts and estimate the boundary durations for buckets.
    Then, we maintain a buffer of M cuts (stored separately in K buckets) and every time we sample a batch,
    we consume the input cut iterable for the same amount of cuts.
    The memory consumption is limited by M at all times.

    For scenarios such as ASR, VAD, Speaker ID, or TTS training, this class supports single CutSet
    iteration. Example::

        >>> cuts = CutSet(...)
        >>> sampler = DynamicBucketingSampler(cuts, max_duration=100)
        >>> for batch in sampler:
        ...     assert isinstance(batch, CutSet)

    For other scenarios that require pairs (or triplets, etc.) of utterances, this class supports
    zipping multiple CutSets together. Such scenarios could be voice conversion, speech translation,
    contrastive self-supervised training, etc. Example::

        >>> source_cuts = CutSet(...)
        >>> target_cuts = CutSet(...)
        >>> sampler = DynamicBucketingSampler(source_cuts, target_cuts, max_duration=100)
        >>> for batch in sampler:
        ...     assert isinstance(batch, tuple)
        ...     assert len(batch) == 2
        ...     assert isinstance(batch[0], CutSet)
        ...     assert isinstance(batch[1], CutSet)

    .. note:: for cut pairs, triplets, etc. the user is responsible for ensuring that the CutSets
        are all sorted so that when iterated over sequentially, the items are matched.
        We take care of preserving the right ordering internally, e.g., when shuffling.
        By default, we check that the cut IDs are matching, but that can be disabled.

    .. caution:: when using :meth:`DynamicBucketingSampler.filter` to filter some cuts with more than
        one CutSet to sample from, we sample one cut from every CutSet, and expect that all of the cuts
        satisfy the predicate -- otherwise, they are all discarded from being sampled.
    """

    def __init__(
        self,
        *cuts: Iterable,
        max_duration: Optional[Seconds] = None,
        max_cuts: Optional[int] = None,
        constraint: Optional[SamplingConstraint] = None,
        num_buckets: Optional[int] = 10,
        shuffle: bool = False,
        drop_last: bool = False,
        consistent_ids: bool = True,
        duration_bins: Optional[List[Seconds]] = None,
        num_cuts_for_bins_estimate: int = 10000,
        buffer_size: int = 20000,
        quadratic_duration: Optional[Seconds] = None,
        world_size: Optional[int] = None,
        rank: Optional[int] = None,
        seed: Union[int, Literal["randomized", "trng"]] = 0,
        strict=None,
        shuffle_buffer_size=None,
    ) -> None:
        """
        :param cuts: one or more CutSets (when more than one, will yield tuples of CutSets as mini-batches)
        :param max_duration: The maximum total recording duration from ``cuts``.
            Note: with multiple CutSets, ``max_duration`` constraint applies only to the first CutSet.
        :param max_cuts: The maximum total number of ``cuts`` per batch.
            When only ``max_cuts`` is specified (without ``max_duration``),
            this sampler yields static batch sizes.
        :param constraint: An optional custom :class:`SamplingConstraint` used to measure the
            cuts and to decide when a batch is complete. When ``None``, a :class:`TimeConstraint`
            is built from ``max_duration``, ``max_cuts`` and ``quadratic_duration``.
            Note: :meth:`state_dict` is not supported when a custom constraint is used.
        :param num_buckets: how many buckets to create. Must be an integer greater than one
            whenever ``duration_bins`` is not provided.
        :param shuffle: When ``True``, the cuts held in the sampled bucket are visited in a random
            order when a mini-batch is formed.
            Convenient when mini-batch loop is inside an outer epoch-level loop, e.g.:
            `for epoch in range(10): for batch in dataset: ...` as every epoch will see a
            different cuts order.
        :param drop_last: When ``True``, we will drop all incomplete batches.
            A batch is considered incomplete if it depleted a bucket before
            hitting the constraint such as max_duration, max_cuts, etc.
        :param consistent_ids: Only affects processing of multiple CutSets.
            When ``True``, at each sampling step we check cuts from all CutSets have the same ID
            (i.e., the first cut from every CutSet should have the same ID, same for the second, third, etc.).
        :param duration_bins: A list of floats (seconds); when provided, we'll skip the initial
            estimation of bucket duration bins (useful to speed-up the launching of experiments).
        :param num_cuts_for_bins_estimate: We will draw this many cuts to estimate the duration bins
            for creating similar-duration buckets.
            Larger number means a better estimate to the data distribution, possibly at a longer init cost.
        :param buffer_size: How many cuts (or cut pairs, triplets) we hold at any time across all
            of the buckets.
            Increasing ``max_duration`` (batch_size) or ``num_buckets`` might require increasing this number.
            Larger number here will also improve shuffling capabilities.
            It will result in larger memory usage.
        :param quadratic_duration: When set, it adds an extra penalty that's quadratic in size w.r.t.
            a cuts duration. This helps get a more even GPU utilization across different input lengths
            when models have quadratic input complexity. Set between 15 and 40 for transformers.
        :param world_size: Total number of distributed nodes. We will try to infer it by default.
        :param rank: Index of distributed node. We will try to infer it by default.
        :param seed: Random seed used to consistently shuffle the dataset across different processes.
        :param strict: Deprecated and ignored; passing any non-``None`` value only emits
            a ``DeprecationWarning``.
        :param shuffle_buffer_size: Deprecated; when provided, a ``DeprecationWarning`` is emitted
            and the value is added to ``buffer_size``.
        """
        super().__init__(
            drop_last=drop_last, world_size=world_size, rank=rank, seed=seed
        )
        if not all(cs.is_lazy for cs in cuts if isinstance(cs, CutSet)):
            warnings.warn(
                "You are using DynamicBucketingSampler with an eagerly read CutSet. "
                "You won't see any memory/speed benefits with that setup. "
                "Either use 'CutSet.from_jsonl_lazy' to read the CutSet lazily, or use a BucketingSampler instead."
            )
        self.cuts = cuts
        self.max_duration = max_duration
        self.max_cuts = max_cuts
        # Kept as passed by the user (possibly None): state_dict() relies on it
        # to detect that a custom constraint is in use.
        self.constraint = constraint
        self.shuffle = shuffle
        self.consistent_ids = consistent_ids
        self.num_cuts_for_bins_estimate = num_cuts_for_bins_estimate
        self.buffer_size = buffer_size
        self.quadratic_duration = quadratic_duration
        # The RNG is (re-)created at the start of every iteration, see __iter__().
        self.rng = None
        check_constraint(constraint, max_duration, max_cuts)

        if strict is not None:
            warnings.warn(
                "In Lhotse v1.4 all samplers act as if 'strict=True'. "
                "Sampler's argument 'strict' will be removed in a future Lhotse release.",
                category=DeprecationWarning,
            )

        if shuffle_buffer_size is not None:
            _emit_shuffle_buffer_size_warning()
            self.buffer_size += shuffle_buffer_size

        if duration_bins is not None:
            if num_buckets is not None:
                assert len(duration_bins) == num_buckets - 1, (
                    f"num_buckets=={num_buckets} but len(duration_bins)=={len(duration_bins)} "
                    f"(bins are the boundaries, it should be one less than the number of buckets)."
                )
            assert list(duration_bins) == sorted(
                duration_bins
            ), "Duration bins must be sorted ascendingly."
            self.duration_bins = duration_bins
        else:
            # Fail early with a readable message instead of an opaque TypeError
            # raised later from a `None > 1` comparison during the bins estimation.
            assert num_buckets is not None and num_buckets > 1, (
                f"When 'duration_bins' is not provided, 'num_buckets' must be "
                f"an integer greater than 1 (got: {num_buckets})."
            )
            if constraint is None:
                # Local constraint used only to measure the cuts for the estimation;
                # `self.constraint` intentionally stays None.
                constraint = TimeConstraint(
                    max_duration=self.max_duration,
                    max_cuts=self.max_cuts,
                    quadratic_duration=self.quadratic_duration,
                )
            # With multiple CutSets, only the first one is used for the estimation.
            self.duration_bins = estimate_duration_buckets(
                islice(self.cuts[0], num_cuts_for_bins_estimate),
                num_buckets=num_buckets,
                constraint=constraint,
            )

    def state_dict(self) -> Dict[str, Any]:
        """
        Return the sampler's state as a dict: the parent class' state extended with
        this sampler's settings.

        :raises AssertionError: when the sampler was created with a custom ``constraint``.
        """
        assert (
            self.constraint is None
        ), "state_dict() is not supported with samplers that use a custom constraint."
        sd = super().state_dict()
        sd.update(
            {
                "max_duration": self.max_duration,
                "max_cuts": self.max_cuts,
                "consistent_ids": self.consistent_ids,
                "buffer_size": self.buffer_size,
                "num_cuts_for_bins_estimate": self.num_cuts_for_bins_estimate,
                "quadratic_duration": self.quadratic_duration,
            }
        )
        return sd

    def load_state_dict(self, sd: Dict[str, Any]) -> None:
        """
        Restore the sampler's settings from ``sd`` and fast-forward the iteration
        to the restored position.

        .. note:: the keys consumed by this class are popped from ``sd`` (it is modified in place)
            before the remainder is passed to the parent class.
        """
        self.max_duration = sd.pop("max_duration")
        self.max_cuts = sd.pop("max_cuts")
        self.consistent_ids = sd.pop("consistent_ids")
        self.num_cuts_for_bins_estimate = sd.pop("num_cuts_for_bins_estimate")
        self.buffer_size = sd.pop("buffer_size")
        if "shuffle_buffer_size" in sd:
            # Older state dicts carried a separate shuffle buffer; fold it into 'buffer_size'.
            _emit_shuffle_buffer_size_warning()
            shuffle_buffer_size = sd.pop("shuffle_buffer_size")
            self.buffer_size += shuffle_buffer_size
        self.quadratic_duration = sd.pop("quadratic_duration", None)
        sd.pop("strict", None)  # backward compatibility
        super().load_state_dict(sd)
        self._fast_forward()

    def _fast_forward(self):
        """
        Re-iterate the current epoch for as many batches as the restored diagnostics
        report, so that the next sampled batch continues from the restored position.
        """
        current_epoch = self.diagnostics.current_epoch
        num_batches_to_iter = self.diagnostics.current_epoch_stats.total_batches

        # Set the right epoch
        self.set_epoch(current_epoch)
        # Reset diagnostics for this epoch as we're about to re-iterate
        self.diagnostics.stats_per_epoch[current_epoch] = EpochDiagnostics(
            epoch=current_epoch
        )

        # The flag has to be off here, otherwise __iter__() below would be a no-op.
        self._just_restored_state = False
        iter(self)
        for _ in range(num_batches_to_iter):
            next(self)
        self._just_restored_state = True

    def __iter__(self) -> "DynamicBucketingSampler":
        """
        Start a new iteration over the input cuts (unless the state was just restored,
        in which case the already fast-forwarded iteration is continued).
        """
        if self._just_restored_state:
            return self
        seed = resolve_seed(self.seed)
        self.rng = random.Random(seed + self.epoch)
        # Why reset the current epoch?
        # Either we are iterating the epoch for the first time and it's a no-op,
        # or we are iterating the same epoch again, in which case setting more steps
        # than are actually available per epoch would have broken the checkpoint restoration.
        self.diagnostics.reset_current_epoch()
        # Initiate iteration
        cuts_iter = [iter(cs) for cs in self.cuts]
        # Apply filter predicate
        cuts_iter = Filter(
            iterator=zip(*cuts_iter),
            predicate=lambda tpl: all(self._filter_fn(c) for c in tpl),
            diagnostics=self.diagnostics,
        )
        # Convert Iterable[Cut] -> Iterable[CutSet]
        cuts_iter = DynamicBucketer(
            cuts_iter,
            duration_bins=self.duration_bins,
            max_duration=self.max_duration,
            max_cuts=self.max_cuts,
            constraint=self.constraint,
            drop_last=self.drop_last,
            buffer_size=self.buffer_size,
            quadratic_duration=self.quadratic_duration,
            shuffle=self.shuffle,
            rng=self.rng,
            diagnostics=self.diagnostics,
        )
        self.cuts_iter = iter(cuts_iter)
        return self

    def _next_batch(self) -> Union[CutSet, Tuple[CutSet, ...]]:
        """
        Return the next mini-batch; for tuple batches, optionally verify that
        the cuts at the same position in every CutSet share the same ID.
        """
        batch = next(self.cuts_iter)
        if self.consistent_ids and isinstance(batch, tuple):
            for cuts in zip(*batch):
                expected_id = cuts[0].id
                assert all(c.id == expected_id for c in cuts[1:]), (
                    f"The input CutSet are not sorted by cut ID in the same way. "
                    f"We sampled the following mismatched cut IDs: {', '.join(c.id for c in cuts)}. "
                    f"If this is expected, pass the argument 'consistent_ids=False' to DynamicBucketingSampler."
                )
        return batch

    @property
    def remaining_duration(self) -> Optional[float]:
        """Always ``None`` for this sampler."""
        return None

    @property
    def remaining_cuts(self) -> Optional[int]:
        """Always ``None`` for this sampler."""
        return None

    @property
    def num_cuts(self) -> Optional[int]:
        """Always ``None`` for this sampler."""
        return None
def estimate_duration_buckets(
    cuts: Iterable[Cut],
    num_buckets: int,
    constraint: Optional[SamplingConstraint] = None,
) -> List[float]:
    """
    Given an iterable of cuts and a desired number of buckets, select duration values
    that should start each bucket.

    The returned list, ``bins``, has at most ``num_buckets - 1`` elements (it can have fewer,
    because a boundary is only emitted once the accumulated size strictly exceeds
    the per-bucket share of the total size).
    The first bucket should contain cuts with duration ``0 <= d < bins[0]``;
    the last bucket should contain cuts with duration ``bins[-1] <= d < float("inf")``,
    ``i``-th bucket should contain cuts with duration ``bins[i - 1] <= d < bins[i]``.

    :param cuts: an iterable of :class:`lhotse.cut.Cut`.
    :param num_buckets: desired number of buckets (must be greater than one and
        not greater than the number of cuts).
    :param constraint: object with ``.measure_length()`` method that's used to determine
        the size of each sample. If ``None``, we'll use ``TimeConstraint``.
    :return: a list of boundary duration values (floats).
    """
    assert num_buckets > 1

    if constraint is None:
        constraint = TimeConstraint()

    sizes = np.array([constraint.measure_length(c) for c in cuts])
    sizes.sort()
    assert num_buckets <= sizes.shape[0], (
        f"The number of buckets ({num_buckets}) must be smaller than "
        f"or equal to the number of cuts ({sizes.shape[0]})."
    )
    # Every bucket aims at holding an equal share of the total measured size.
    size_per_bucket = sizes.sum() / num_buckets

    bins = []
    tot = 0.0
    for size in sizes:
        if tot > size_per_bucket:
            # The current bucket is "full": this size starts the next one.
            bins.append(size)
            tot = 0.0
        tot += size

    return bins


class DynamicBucketer:
    """
    Iterable that consumes a stream of cuts (or tuples of cuts), keeps up to ``buffer_size``
    of them in buckets selected by ``duration_bins``, and yields mini-batches sampled
    from a single bucket at a time.

    For tuples of cuts, only the first cut of each tuple is measured.
    """

    def __init__(
        self,
        cuts: Iterable[Union[Cut, Tuple[Cut, ...]]],
        duration_bins: Sequence[Seconds],
        max_duration: Optional[Seconds] = None,
        max_cuts: Optional[int] = None,
        constraint: Optional[SamplingConstraint] = None,
        drop_last: bool = False,
        buffer_size: int = 10000,
        quadratic_duration: Optional[Seconds] = None,
        shuffle: bool = False,
        rng: Optional[random.Random] = None,
        diagnostics: Optional[SamplingDiagnostics] = None,
    ) -> None:
        """
        :param cuts: an iterable of cuts or of tuples of cuts.
        :param duration_bins: ascendingly sorted bucket boundaries;
            there are ``len(duration_bins) + 1`` buckets.
        :param max_duration: used to build a ``TimeConstraint`` when ``constraint`` is ``None``.
        :param max_cuts: used to build a ``TimeConstraint`` when ``constraint`` is ``None``.
        :param constraint: optional custom constraint used instead of a ``TimeConstraint``.
        :param drop_last: when ``True``, stop as soon as no bucket can provide a full batch.
        :param buffer_size: how many items are read into the buckets before sampling starts.
        :param quadratic_duration: used to build a ``TimeConstraint`` when ``constraint`` is ``None``.
        :param shuffle: when ``True``, items of the sampled bucket are visited in a random order.
        :param rng: random number generator; a fresh ``random.Random()`` is used when ``None``.
        :param diagnostics: optional diagnostics object passed on to the batcher.
        """
        self.cuts = cuts
        self.duration_bins = duration_bins
        self.max_duration = max_duration
        self.max_cuts = max_cuts
        self.constraint = constraint
        self.drop_last = drop_last
        self.buffer_size = buffer_size
        self.quadratic_duration = quadratic_duration
        self.diagnostics = ifnone(diagnostics, SamplingDiagnostics())
        if rng is None:
            rng = random.Random()
        self.rng = rng
        self.shuffle = shuffle

        # Compare as lists so that sorted tuples/arrays are accepted too,
        # consistently with the check performed by DynamicBucketingSampler.
        assert list(duration_bins) == sorted(duration_bins), (
            f"Argument list for 'duration_bins' is expected to be in "
            f"sorted order (got: {duration_bins})."
        )
        check_constraint(constraint, max_duration, max_cuts)

        if self.constraint is None:
            self.constraint = TimeConstraint(
                max_duration=self.max_duration,
                max_cuts=self.max_cuts,
                quadratic_duration=self.quadratic_duration,
            )

        # A heuristic diagnostic first, for finding the right settings.
        # Skipped without any bin boundaries (single bucket): the mean of an empty
        # sequence is NaN and numpy would emit a spurious RuntimeWarning.
        if max_duration is not None and len(duration_bins) > 0:
            mean_duration = np.mean(duration_bins)
            expected_buffer_duration = buffer_size * mean_duration
            expected_bucket_duration = expected_buffer_duration / (
                len(duration_bins) + 1
            )
            if expected_bucket_duration < max_duration:
                warnings.warn(
                    f"Your 'buffer_size' setting of {buffer_size} might be too low to satisfy "
                    f"a 'max_duration' of {max_duration} (given our best guess)."
                )

        # Init: create empty buckets (note: `num_buckets = len(duration_bins) + 1`).
        self.buckets: List[Deque[Union[Cut, Tuple[Cut, ...]]]] = [
            deque() for _ in range(len(duration_bins) + 1)
        ]

    def __iter__(self) -> Generator[CutSet, None, None]:
        """Yield mini-batches until the input and the buckets are exhausted."""
        # Init: sample `buffer_size` cuts and assign them to the right buckets.
        self.cuts_iter = iter(self.cuts)
        self._collect_cuts_in_buckets(self.buffer_size)

        # Init: determine which buckets are "ready"
        def is_ready(bucket: Deque[Cut]):
            tot = self.constraint.copy()
            for c in bucket:
                tot.add(c[0] if isinstance(c, tuple) else c)
                if tot.close_to_exceeding():
                    return True
            return False

        # The iteration code starts here.
        # On each step we're sampling a new batch.
        try:
            while True:
                ready_buckets = [b for b in self.buckets if is_ready(b)]
                if not ready_buckets:
                    # No bucket has enough data to yield for the last full batch.
                    non_empty_buckets = [b for b in self.buckets if b]
                    if self.drop_last or len(non_empty_buckets) == 0:
                        # Either the user requested only full batches, or we have nothing left.
                        break
                    else:
                        # Sample from partial batches that are left.
                        ready_buckets = non_empty_buckets
                # Choose a bucket to sample from.
                # We'll only select from the buckets that have a full batch available.
                sampling_bucket = self.rng.choice(ready_buckets)
                # Apply random shuffling if requested: we'll shuffle the items present within the bucket.
                maybe_shuffled = sampling_bucket
                indexes_used = []
                if self.shuffle:
                    maybe_shuffled = pick_at_random(
                        maybe_shuffled, rng=self.rng, out_indexes_used=indexes_used
                    )
                # Sample one batch from that bucket and yield it to the caller.
                batcher = DurationBatcher(
                    maybe_shuffled,
                    constraint=self.constraint.copy(),
                    diagnostics=self.diagnostics,
                )
                batch = next(iter(batcher))
                if isinstance(batch, tuple):
                    batch_size = len(batch[0])
                else:
                    batch_size = len(batch)
                yield batch
                # Remove sampled cuts from the bucket.
                if indexes_used:
                    # Shuffling, sort indexes of yielded elements largest -> smallest and remove them
                    indexes_used.sort(reverse=True)
                    for idx in indexes_used:
                        del sampling_bucket[idx]
                else:
                    # No shuffling, remove first N
                    for _ in range(batch_size):
                        sampling_bucket.popleft()
                # Fetch new cuts and add them to appropriate buckets.
                self._collect_cuts_in_buckets(batch_size)
        except StopIteration:
            # Raised when the batcher has no batch to provide; treated as the end of iteration.
            pass

        # Cleanup.
        self.cuts_iter = None

    def _collect_cuts_in_buckets(self, n_cuts: int):
        """Read up to ``n_cuts`` items from the input and put each one into its bucket."""
        try:
            for _ in range(n_cuts):
                cuts = next(self.cuts_iter)
                duration = self.constraint.measure_length(
                    cuts[0] if isinstance(cuts, tuple) else cuts
                )
                bucket_idx = bisect_right(self.duration_bins, duration)
                self.buckets[bucket_idx].append(cuts)
        except StopIteration:
            # The input is exhausted; keep whatever was collected so far.
            pass


def pick_at_random(
    bucket: Sequence[Union[Cut, Tuple[Cut, ...]]],
    rng: random.Random,
    out_indexes_used: list,
) -> Generator[Union[Cut, Tuple[Cut, ...]], None, None]:
    """
    Generator which will yield items in a sequence in a random order.
    It will append the indexes of items yielded during iteration via ``out_indexes_used``.

    :param bucket: the sequence to draw the items from.
    :param rng: random number generator used to shuffle the visiting order.
    :param out_indexes_used: output list; the index of each item is appended
        right before the item is yielded.
    """
    indexes = list(range(len(bucket)))
    rng.shuffle(indexes)
    for idx in indexes:
        out_indexes_used.append(idx)
        yield bucket[idx]


def _emit_shuffle_buffer_size_warning():
    """Emit a ``DeprecationWarning`` about the deprecated 'shuffle_buffer_size' argument."""
    warnings.warn(
        "Since Lhotse v1.20 'shuffle_buffer_size' is deprecated, because DynamicBucketingSampler "
        "does not require a separate shuffling buffer anymore. "
        "To improve both shuffling and sampling randomness, increase 'buffer_size' instead. "
        "To maintain backward compatibility, we will increase 'buffer_size' "
        "by 'shuffle_buffer_size' for you. "
        "This argument will be removed in a future Lhotse version.",
        category=DeprecationWarning,
    )