# Source code for lhotse.dataset.speech_recognition

from typing import Callable, Dict, List, Union

import torch
from torch.utils.data.dataloader import DataLoader, default_collate

from lhotse import validate
from lhotse.cut import CutSet
from lhotse.dataset.input_strategies import InputStrategy, PrecomputedFeatures
from lhotse.utils import ifnone

class K2SpeechRecognitionDataset(torch.utils.data.Dataset):
    """
    The PyTorch Dataset for the speech recognition task using K2 library.

    This dataset expects to be queried with lists of cut IDs,
    for which it loads features and automatically collates/batches them.

    To use it with a PyTorch DataLoader, set ``batch_size=None``
    and provide a :class:`SingleCutSampler` sampler.

    Each item in this dataset is a dict of:

    .. code-block::

        {
            'inputs': float tensor with shape determined by :attr:`input_strategy`:
                      - single-channel:
                        - features: (B, T, F)
                        - audio: (B, T)
                      - multi-channel: currently not supported
            'supervisions': [
                {
                    'sequence_idx': Tensor[int] of shape (S,)
                    'text': List[str] of len S

                    # For feature input strategies
                    'start_frame': Tensor[int] of shape (S,)
                    'num_frames': Tensor[int] of shape (S,)

                    # For audio input strategies
                    'start_sample': Tensor[int] of shape (S,)
                    'num_samples': Tensor[int] of shape (S,)

                    # Optionally, when return_cuts=True
                    'cut': List[AnyCut] of len S
                }
            ]
        }

    Dimension symbols legend:

    * ``B`` - batch size (number of Cuts)
    * ``S`` - number of supervision segments (greater or equal to B,
      as each Cut may have multiple supervisions)
    * ``T`` - number of frames of the longest Cut
    * ``F`` - number of features

    The 'sequence_idx' field is the index of the Cut used to create the example in the Dataset.
    """

    def __init__(
            self,
            cuts: CutSet,
            return_cuts: bool = False,
            cut_transforms: List[Callable[[CutSet], CutSet]] = None,
            input_transforms: List[Callable[[torch.Tensor], torch.Tensor]] = None,
            input_strategy: InputStrategy = PrecomputedFeatures(),
            check_inputs: bool = True
    ):
        """
        K2 ASR Dataset constructor.

        :param cuts: the ``CutSet`` to sample data from.
        :param return_cuts: When ``True``, will additionally return a "cut" field in each batch
            with the Cut objects used to create that batch.
        :param cut_transforms: A list of transforms to be applied on each sampled batch,
            before converting cuts to an input representation (audio/features).
            Examples: cut concatenation, noise cuts mixing, etc.
            ``None`` is treated as an empty list.
        :param input_transforms: A list of transforms to be applied on each sampled batch,
            after the cuts are converted to audio/features.
            Each transform is called as ``tnfm(inputs, supervision_segments=...)``.
            Examples: normalization, SpecAugment, etc.
            ``None`` is treated as an empty list.
        :param input_strategy: Converts cuts into a collated batch of audio/features.
            By default, reads pre-computed features from disk.
        :param check_inputs: Should we iterate over ``cuts`` to validate them.
            You might want to disable it when using "lazy" CutSets to avoid
            a very long start up time.
        """
        super().__init__()
        # Initialize the fields
        self.cuts = cuts
        self.return_cuts = return_cuts
        self.cut_transforms = ifnone(cut_transforms, [])
        self.input_transforms = ifnone(input_transforms, [])
        self.input_strategy = input_strategy
        if check_inputs:
            validate_for_asr(self.cuts)

    def __len__(self) -> int:
        """Return the number of cuts in the underlying ``CutSet``."""
        return len(self.cuts)

    def __getitem__(self, cut_ids: List[str]) -> Dict[str, Union[torch.Tensor, List[str]]]:
        """
        Return a new batch built from the cuts identified by ``cut_ids``.

        The batch composition (and therefore its size) is decided by the caller
        that supplies ``cut_ids``; this method only loads, transforms and collates.

        :param cut_ids: IDs of the cuts that should form the batch.
        :return: a dict with the keys ``'inputs'`` and ``'supervisions'``,
            laid out as described in the class docstring.
        """
        # Select the cuts that will form the batch.
        # The returned object is a CutSet that we can keep on modifying (e.g. padding, mixing, etc.)
        cuts = self.cuts.subset(cut_ids=cut_ids)

        # Sort the cuts by duration so that the first one determines the batch time dimensions.
        cuts = cuts.sort_by_duration(ascending=False)

        # Optional CutSet transforms - e.g. padding, or speed perturbation that adjusts
        # the supervision boundaries.
        for tnfm in self.cut_transforms:
            cuts = tnfm(cuts)

        # Get a tensor with batched feature matrices, shape (B, T, F)
        # Collation performs auto-padding, if necessary.
        inputs, _ = self.input_strategy(cuts)

        # Get a dict of tensors that encode the positional information about supervisions
        # in the batch of feature matrices. The tensors are named "sequence_idx",
        # "start_frame/sample" and "num_frames/samples".
        supervision_intervals = self.input_strategy.supervision_intervals(cuts)

        # Apply all available transforms on the inputs, i.e. either audio or features.
        # This could be feature extraction, global MVN, SpecAugment, etc.
        # The per-supervision interval tensors are stacked column-wise so that each
        # transform receives one row per supervision.
        segments = torch.stack(list(supervision_intervals.values()), dim=1)
        for tnfm in self.input_transforms:
            inputs = tnfm(inputs, supervision_segments=segments)

        batch = {
            'inputs': inputs,
            'supervisions': default_collate([
                {'text': supervision.text}
                for cut in cuts
                for supervision in cut.supervisions
            ])
        }
        # Update the 'supervisions' field with sequence_idx and start/num frames/samples
        batch['supervisions'].update(supervision_intervals)
        if self.return_cuts:
            # One entry per supervision, so the list stays aligned with 'text'.
            batch['supervisions']['cut'] = [cut for cut in cuts for _ in cut.supervisions]

        return batch
def validate_for_asr(cuts: CutSet) -> None:
    """
    Check that every supervision in ``cuts`` lies within the bounds of its cut.

    Runs the generic ``validate`` check on ``cuts`` first, then verifies for each
    supervision of each cut (with a 1ms tolerance) that it does not start before
    the cut and does not end after it.

    :param cuts: the ``CutSet`` to check; it is iterated over in full.
    :raises AssertionError: when a supervision starts before its cut or ends after it.
    """
    validate(cuts)
    tol = 1e-3  # 1ms
    for cut in cuts:
        for supervision in cut.supervisions:
            # Raised explicitly (rather than via ``assert``) so that the checks
            # are not stripped when Python runs with ``-O``.
            if not (supervision.start >= -tol):
                raise AssertionError(
                    f"Supervisions starting before the cut are not supported for ASR"
                    f" (sup id: {supervision.id}, cut id: {cut.id})"
                )
            # Compare the end of the supervision (start + duration), not just its
            # duration: a supervision that starts late can overrun the cut even
            # though it is shorter than the cut.
            supervision_end = supervision.start + supervision.duration
            if not (supervision_end <= cut.duration + tol):
                raise AssertionError(
                    f"Supervisions ending after the cut "
                    f"are not supported for ASR"
                    f" (sup id: {supervision.id}, cut id: {cut.id})"
                )