Source code for lhotse.dataset.speech_recognition

from typing import Callable, Dict, List, Optional, Union

import torch
from torch.utils.data.dataloader import DataLoader, default_collate

from lhotse import validate
from lhotse.cut import CutSet
from lhotse.dataset.input_strategies import InputStrategy, PrecomputedFeatures
from lhotse.utils import ifnone


class K2SpeechRecognitionDataset(torch.utils.data.Dataset):
    """
    The PyTorch Dataset for the speech recognition task using the k2 library.

    This dataset expects to be queried with lists of cut IDs,
    for which it loads features and automatically collates/batches them.

    To use it with a PyTorch DataLoader, set ``batch_size=None``
    and provide a :class:`SingleCutSampler` sampler.

    Each item in this dataset is a dict of:

    .. code-block::

        {
            'inputs': float tensor with shape determined by :attr:`input_strategy`:
                      - single-channel:
                        - features: (B, T, F)
                        - audio: (B, T)
                      - multi-channel: currently not supported,
            'supervisions': [
                {
                    'sequence_idx': Tensor[int] of shape (S,)
                    'text': List[str] of len S

                    # For feature input strategies
                    'start_frame': Tensor[int] of shape (S,)
                    'num_frames': Tensor[int] of shape (S,)

                    # For audio input strategies
                    'start_sample': Tensor[int] of shape (S,)
                    'num_samples': Tensor[int] of shape (S,)

                    # Optionally, when return_cuts=True
                    'cut': List[AnyCut] of len S
                }
            ]
        }

    Dimension symbols legend:

    * ``B`` - batch size (number of Cuts)
    * ``S`` - number of supervision segments (greater than or equal to B,
      as each Cut may have multiple supervisions)
    * ``T`` - number of frames of the longest Cut
    * ``F`` - number of features

    The 'sequence_idx' field is the index of the Cut used to create the example in the Dataset.
    """
    def __init__(
            self,
            cuts: CutSet,
            return_cuts: bool = False,
            cut_transforms: Optional[List[Callable[[CutSet], CutSet]]] = None,
            input_transforms: Optional[List[Callable[[torch.Tensor], torch.Tensor]]] = None,
            input_strategy: InputStrategy = PrecomputedFeatures(),
            check_inputs: bool = True
    ):
        """
        k2 ASR dataset constructor.

        :param cuts: the ``CutSet`` to sample data from.
        :param return_cuts: When ``True``, will additionally return a "cut" field in each batch
            with the Cut objects used to create that batch.
        :param cut_transforms: A list of transforms to be applied on each sampled batch,
            before converting cuts to an input representation (audio/features).
            Examples: cut concatenation, noise cuts mixing, etc.
        :param input_transforms: A list of transforms to be applied on each sampled batch,
            after the cuts are converted to audio/features.
            Examples: normalization, SpecAugment, etc.
        :param input_strategy: Converts cuts into a collated batch of audio/features.
            By default, reads pre-computed features from disk.
        :param check_inputs: Should we iterate over ``cuts`` to validate them?
            You might want to disable it when using "lazy" CutSets to avoid a very long start-up time.
        """
        super().__init__()
        # Initialize the fields
        self.cuts = cuts
        self.return_cuts = return_cuts
        self.cut_transforms = ifnone(cut_transforms, [])
        self.input_transforms = ifnone(input_transforms, [])
        self.input_strategy = input_strategy
        if check_inputs:
            validate_for_asr(self.cuts)
    def __len__(self) -> int:
        return len(self.cuts)

    def __getitem__(self, cut_ids: List[str]) -> Dict[str, Union[torch.Tensor, List[str]]]:
        """
        Return a new batch for the given list of cut IDs. The batch size has already been
        determined by the sampler (e.g. through its max_frames and max_cuts constraints).
        """
        # Collect the cuts that will form this batch.
        # The returned object is a CutSet that we can keep on modifying (e.g. padding, mixing, etc.)
        cuts = self.cuts.subset(cut_ids=cut_ids)

        # Sort the cuts by duration so that the first one determines the batch time dimensions.
        cuts = cuts.sort_by_duration(ascending=False)

        # Optional CutSet transforms - e.g. padding, or speed perturbation that adjusts
        # the supervision boundaries.
        for tnfm in self.cut_transforms:
            cuts = tnfm(cuts)

        # Get a tensor with batched feature matrices, shape (B, T, F).
        # Collation performs auto-padding, if necessary.
        inputs, _ = self.input_strategy(cuts)

        # Get a dict of tensors that encode the positional information about supervisions
        # in the batch of feature matrices. The tensors are named "sequence_idx",
        # "start_frame"/"start_sample" and "num_frames"/"num_samples".
        supervision_intervals = self.input_strategy.supervision_intervals(cuts)

        # Stack the interval tensors into a single (S, 3) tensor, as expected by
        # transforms such as SpecAugment.
        segments = torch.stack(list(supervision_intervals.values()), dim=1)

        # Apply all available transforms on the inputs, i.e. either audio or features.
        # This could be feature extraction, global MVN, SpecAugment, etc.
        for tnfm in self.input_transforms:
            inputs = tnfm(inputs, supervision_segments=segments)

        batch = {
            'inputs': inputs,
            'supervisions': default_collate([
                {'text': supervision.text}
                for cut in cuts
                for supervision in cut.supervisions
            ])
        }
        # Update the 'supervisions' field with sequence_idx and start/num frames/samples.
        batch['supervisions'].update(supervision_intervals)
        if self.return_cuts:
            batch['supervisions']['cut'] = [cut for cut in cuts for sup in cut.supervisions]
        return batch
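
# Note: each entry in ``input_transforms`` is invoked in ``__getitem__`` above as
# ``tnfm(inputs, supervision_segments=segments)``, so a custom transform has to
# accept that keyword even if it ignores it. Below is a minimal sketch of such a
# transform; the class name ``BatchMVN`` is hypothetical and not part of lhotse's API.
class BatchMVN:
    """Example input transform: per-batch mean/variance normalization."""

    def __call__(
            self,
            inputs: torch.Tensor,
            supervision_segments: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        # ``supervision_segments`` is accepted only to match the expected call
        # signature; this particular transform does not use positional information.
        return (inputs - inputs.mean()) / (inputs.std() + 1e-10)

# Usage: K2SpeechRecognitionDataset(cuts, input_transforms=[BatchMVN()])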
def validate_for_asr(cuts: CutSet) -> None:
    validate(cuts)
    tol = 1e-3  # 1ms
    for cut in cuts:
        for supervision in cut.supervisions:
            assert supervision.start >= -tol, (
                "Supervisions starting before the cut are not supported for ASR"
                f" (sup id: {supervision.id}, cut id: {cut.id})"
            )
            assert supervision.duration <= cut.duration + tol, (
                "Supervisions ending after the cut are not supported for ASR"
                f" (sup id: {supervision.id}, cut id: {cut.id})"
            )
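
# A minimal end-to-end usage sketch, kept in a helper for illustration only
# (the function name ``_example_usage`` is hypothetical, not part of lhotse's API).
# ``SingleCutSampler`` is the sampler referenced in the class docstring; its exact
# arguments, such as ``max_frames`` below, may differ between lhotse versions.
def _example_usage(cuts: CutSet) -> None:
    from torch.utils.data import DataLoader

    from lhotse.dataset.sampling import SingleCutSampler

    dataset = K2SpeechRecognitionDataset(cuts)  # runs validate_for_asr(cuts) by default
    sampler = SingleCutSampler(cuts, max_frames=100000)
    # batch_size=None, because the sampler already yields whole batches of cut IDs.
    loader = DataLoader(dataset, sampler=sampler, batch_size=None)
    for batch in loader:
        inputs = batch['inputs']               # (B, T, F) with PrecomputedFeatures
        texts = batch['supervisions']['text']  # List[str] of length S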