Source code for lhotse.dataset.sampling.simple

import warnings
from typing import Any, Dict, Optional

from lhotse import CutSet, Seconds
from lhotse.dataset.sampling.base import CutSampler, TimeConstraint
from lhotse.dataset.sampling.data_source import DataSource


class SimpleCutSampler(CutSampler):
    """
    Samples cuts from a CutSet to satisfy the input constraints.
    It behaves like an iterable that yields mini-batches of cuts (``CutSet`` objects).

    When :attr:`max_duration` is specified, the batch size is dynamic.
    At least one of :attr:`max_duration` or :attr:`max_cuts` has to be specified.
    Padding required to collate the batch does not contribute to the max duration.

    Example usage::

        >>> dataset = K2SpeechRecognitionDataset(cuts)
        >>> sampler = SimpleCutSampler(cuts, max_duration=200, shuffle=True)
        >>> loader = DataLoader(dataset, sampler=sampler, batch_size=None)
        >>> for epoch in range(start_epoch, n_epochs):
        ...     sampler.set_epoch(epoch)
        ...     train(loader)
    """
    def __init__(
        self,
        cuts: CutSet,
        max_duration: Optional[Seconds] = None,
        max_cuts: Optional[int] = None,
        shuffle: bool = False,
        drop_last: bool = False,
        world_size: Optional[int] = None,
        rank: Optional[int] = None,
        seed: int = 0,
    ):
        """
        SimpleCutSampler's constructor.

        :param cuts: the ``CutSet`` to sample data from.
        :param max_duration: The maximum total recording duration from ``cuts``
            in a single mini-batch.
        :param max_cuts: The maximum number of cuts sampled to form a mini-batch.
            By default, this constraint is off.
        :param shuffle: When ``True``, the cuts will be shuffled at the start of iteration.
            Convenient when the mini-batch loop is inside an outer epoch-level loop, e.g.:
            `for epoch in range(10): for batch in dataset: ...`, as every epoch will see
            a different order of cuts.
        :param drop_last: When ``True``, the last batch is dropped if it's incomplete.
        :param world_size: Total number of distributed nodes. We will try to infer it by default.
        :param rank: Index of the distributed node. We will try to infer it by default.
        :param seed: Random seed used to consistently shuffle the dataset across different processes.
        """
        super().__init__(
            drop_last=drop_last,
            shuffle=shuffle,
            world_size=world_size,
            rank=rank,
            seed=seed,
        )
        assert any(
            v is not None for v in (max_duration, max_cuts)
        ), "At least one of max_duration or max_cuts has to be set."
        self.data_source = DataSource(cuts)
        self.time_constraint = TimeConstraint(
            max_duration=max_duration,
            max_cuts=max_cuts,
        )
    @property
    def remaining_duration(self) -> Optional[float]:
        """
        Remaining duration of data left in the sampler (may be inexact due to float arithmetic).
        Not available when the CutSet is read in lazy mode (returns None).
        """
        return self.data_source.remaining_duration

    @property
    def remaining_cuts(self) -> Optional[int]:
        """
        Remaining number of cuts in the sampler.
        Not available when the CutSet is read in lazy mode (returns None).
        """
        return self.data_source.remaining_cuts

    @property
    def num_cuts(self) -> Optional[int]:
        """
        Total number of cuts in the sampler.
        Not available when the CutSet is read in lazy mode (returns None).
        """
        if self.data_source.is_lazy:
            return None
        return len(self.data_source)
    def state_dict(self) -> Dict[str, Any]:
        """
        Return the current state of the sampler in a state_dict.
        Together with ``load_state_dict()``, this can be used to restore the
        training loop's state to the one stored in the state_dict.
        """
        state_dict = super().state_dict()
        state_dict.update(
            {
                "time_constraint": self.time_constraint.state_dict(),
            }
        )
        return state_dict
    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        """
        Restore the state of the sampler that is described in a state_dict.
        This will result in the sampler yielding batches from where the previous training left off.

        .. caution::
            The samplers are expected to be initialized with the same CutSets,
            but this is not explicitly checked anywhere.

        .. caution::
            The input ``state_dict`` is being mutated: we remove each consumed key, and expect
            it to be empty at the end of loading. If you don't want this behavior, pass a copy
            into this function (e.g., one made with ``copy.deepcopy``).

        .. note::
            For implementers of sub-classes of CutSampler: the flag ``self._just_restored_state``
            has to be handled in ``__iter__`` to make it avoid resetting the just-restored state
            (only once).
        """
        time_constraint = TimeConstraint(**state_dict.pop("time_constraint"))
        if self.time_constraint != time_constraint:
            warnings.warn(
                "SimpleCutSampler.load_state_dict(): Inconsistent time_constraint:\n"
                f"expected {self.time_constraint}\n"
                f"received {time_constraint}\n"
                f"We will overwrite the settings with the received state_dict."
            )
        self.time_constraint = time_constraint

        super().load_state_dict(state_dict)

        # Restore the data source's state.
        if self.shuffle:
            self.data_source.shuffle(self.seed + self.epoch)
        self.data_source.fast_forward(self.diagnostics.current_epoch_stats.total_cuts)
    def __iter__(self) -> "SimpleCutSampler":
        """
        Prepare the dataset for iterating over a new epoch.
        Will shuffle the data if requested.
        """
        # Restored state with load_state_dict()? Skip resetting only this once.
        if self._just_restored_state:
            return self
        # Why reset the current epoch?
        # Either we are iterating the epoch for the first time, in which case this is a no-op,
        # or we are iterating the same epoch again, in which case setting more steps
        # than are actually available per epoch would have broken the checkpoint restoration.
        self.diagnostics.reset_current_epoch()
        # Reset the state to the beginning of the epoch.
        if self.shuffle:
            self.data_source.shuffle(self.seed + self.epoch)
        iter(self.data_source)
        return self

    def _next_batch(self) -> CutSet:
        # Keep iterating the underlying CutSet until we hit or exceed the constraints
        # provided by the user (the maximum total duration or the maximum number of cuts).
        # Note: no actual data is loaded into memory yet, because the manifests contain
        # all the metadata required to perform this operation.
        self.time_constraint.reset()
        cuts = []
        while True:
            # Check that we have not reached the end of the dataset.
            try:
                # If this doesn't raise (typical case), it's not the end: keep processing.
                next_cut = next(self.data_source)
            except StopIteration:
                # No more cuts to sample from: if we have a partial batch,
                # we may output it, unless the user requested to drop it.
                # We also check if the batch is "almost there" to override drop_last.
                if cuts and (
                    not self.drop_last or self.time_constraint.close_to_exceeding()
                ):
                    # We have a partial batch and we can return it.
                    return CutSet.from_cuts(cuts)
                else:
                    # There is nothing more to return, or it's discarded:
                    # signal the iteration code to stop.
                    self.diagnostics.discard(cuts)
                    raise StopIteration()

            # Check whether the cut we're about to sample satisfies the optional
            # user-requested predicate.
            if not self._filter_fn(next_cut):
                # No - try another one.
                self.diagnostics.discard_single(next_cut)
                continue

            # Track the duration/cuts/etc. constraints.
            self.time_constraint.add(next_cut)

            # Did we exceed the max_duration or max_cuts constraints?
            if not self.time_constraint.exceeded():
                # No - add the next cut to the batch, and keep trying.
                cuts.append(next_cut)
            else:
                # Yes. Do we have at least one cut in the batch?
                if cuts:
                    # Yes. Return the batch, but keep the currently drawn cut for later.
                    self.data_source.take_back(next_cut)
                    break
                else:
                    # No. We'll warn the user that the constraints might be too tight,
                    # and return the cut anyway.
                    warnings.warn(
                        "The first cut drawn in batch collection violates "
                        "the max_frames, max_cuts, or max_duration constraints - "
                        "we'll return it anyway. "
                        "Consider increasing max_frames/max_cuts/max_duration."
                    )
                    cuts.append(next_cut)

        return CutSet.from_cuts(cuts)
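
A minimal usage sketch (not part of the module above), assuming a cut manifest at the
hypothetical path ``data/cuts.jsonl.gz``. With ``max_duration`` set, each batch yielded by
the sampler is a ``CutSet`` whose total duration stays close to the limit, so the number of
cuts per batch is dynamic::

    from lhotse import CutSet
    from lhotse.dataset.sampling.simple import SimpleCutSampler

    cuts = CutSet.from_file("data/cuts.jsonl.gz")  # hypothetical manifest path
    sampler = SimpleCutSampler(cuts, max_duration=200.0, shuffle=True)
    sampler.set_epoch(0)

    # Total number of cuts; None if the CutSet is read lazily.
    print("num_cuts:", sampler.num_cuts)

    for batch in sampler:
        # Each batch is a CutSet; its size varies with the durations of the sampled cuts.
        print(len(batch), "cuts,", sum(c.duration for c in batch), "seconds")
        # Mid-epoch progress (None for lazy CutSets).
        print("remaining:", sampler.remaining_cuts, sampler.remaining_duration)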
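
Since ``state_dict()`` returns a plain dictionary, it can be saved next to the model
checkpoint and passed back to ``load_state_dict()`` to resume sampling from where training
stopped. A sketch continuing the example above (``torch`` is used here only for
serialization; per the caution in ``load_state_dict()``, the resumed sampler must be
constructed with the same CutSet and settings, which is not checked automatically)::

    import torch

    # While training: periodically persist the sampler's progress.
    torch.save({"sampler": sampler.state_dict()}, "sampler_ckpt.pt")  # hypothetical path

    # When resuming: re-create the sampler identically, then restore its state.
    resumed = SimpleCutSampler(cuts, max_duration=200.0, shuffle=True)
    ckpt = torch.load("sampler_ckpt.pt")
    # Note: load_state_dict() mutates its argument (consumed keys are popped),
    # so pass a copy if the dict is needed afterwards.
    resumed.load_state_dict(ckpt["sampler"])

    # Iteration now continues from the restored position; the usual epoch reset
    # in __iter__ is skipped exactly once via the _just_restored_state flag.
    first_batch = next(iter(resumed))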