Source code for lhotse.dataset.speech_recognition

from typing import Callable, Dict, List, Union

import torch
from torch.utils.data.dataloader import DataLoader, default_collate

from lhotse import validate
from lhotse.cut import CutSet
from lhotse.dataset.input_strategies import BatchIO, PrecomputedFeatures
from lhotse.utils import compute_num_frames, ifnone
from lhotse.workarounds import Hdf5MemoryIssueFix

class K2SpeechRecognitionDataset(torch.utils.data.Dataset):
    """
    The PyTorch Dataset for the speech recognition task using k2 library.

    This dataset expects to be queried with a :class:`CutSet` (a whole mini-batch
    of cuts at once), for which it loads features and automatically
    collates/batches them.

    To use it with a PyTorch DataLoader, set ``batch_size=None``
    and provide a :class:`SimpleCutSampler` sampler.

    Each item in this dataset is a dict of:

    .. code-block::

        {
            'inputs': float tensor with shape determined by :attr:`input_strategy`:
                      - single-channel:
                        - features: (B, T, F)
                        - audio: (B, T)
                      - multi-channel: currently not supported
            'supervisions': {
                'sequence_idx': Tensor[int] of shape (S,)
                'text': List[str] of len S

                # For feature input strategies
                'start_frame': Tensor[int] of shape (S,)
                'num_frames': Tensor[int] of shape (S,)

                # For audio input strategies
                'start_sample': Tensor[int] of shape (S,)
                'num_samples': Tensor[int] of shape (S,)

                # Optionally, when return_cuts=True
                'cut': List[AnyCut] of len S

                # Optionally, when every supervision has a "word" alignment
                'word': List[List[str]] of len S
                'word_start': List[List[int]] of len S
                'word_end': List[List[int]] of len S
            }
        }

    Dimension symbols legend:
    * ``B`` - batch size (number of Cuts)
    * ``S`` - number of supervision segments (greater or equal to B, as each Cut may have multiple supervisions)
    * ``T`` - number of frames of the longest Cut
    * ``F`` - number of features

    The 'sequence_idx' field is the index of the Cut used to create the example in the Dataset.
    """

    def __init__(
        self,
        return_cuts: bool = False,
        cut_transforms: List[Callable[[CutSet], CutSet]] = None,
        input_transforms: List[Callable[[torch.Tensor], torch.Tensor]] = None,
        input_strategy: BatchIO = PrecomputedFeatures(),
    ):
        """
        k2 ASR Dataset constructor.

        :param return_cuts: When ``True``, will additionally return a "cut" field in each batch with the Cut
            objects used to create that batch.
        :param cut_transforms: A list of transforms to be applied on each sampled batch,
            before converting cuts to an input representation (audio/features).
            Examples: cut concatenation, noise cuts mixing, etc.
        :param input_transforms: A list of transforms to be applied on each sampled batch,
            after the cuts are converted to audio/features. Each transform is called as
            ``transform(inputs, supervision_segments=...)``.
            Examples: normalization, SpecAugment, etc.
        :param input_strategy: Converts cuts into a collated batch of audio/features.
            By default, reads pre-computed features from disk.
        """
        super().__init__()
        # Initialize the fields; ``None`` transform lists become empty lists.
        self.return_cuts = return_cuts
        self.cut_transforms = ifnone(cut_transforms, [])
        self.input_transforms = ifnone(input_transforms, [])
        self.input_strategy = input_strategy

        # This attribute is a workaround to constantly growing HDF5 memory
        # throughout the epoch. It regularly closes open file handles to
        # reset the internal HDF5 caches.
        self.hdf5_fix = Hdf5MemoryIssueFix(reset_interval=100)

    def __getitem__(self, cuts: CutSet) -> Dict[str, Union[torch.Tensor, List[str]]]:
        """
        Return one collated batch built from all of the given ``cuts``.

        The cuts are validated, sorted by decreasing duration, passed through the
        cut transforms, converted to inputs by the input strategy, and finally
        passed through the input transforms.

        :param cuts: the cuts that make up this batch.
        :return: a dict with the "inputs" and "supervisions" keys, as described
            in the class docstring.
        :raises ValueError: when word alignments are present but the frame shift
            can be found neither in the cuts nor in the input strategy.
        """
        validate_for_asr(cuts)

        self.hdf5_fix.update()

        # Sort the cuts by duration so that the first one determines the batch time dimensions.
        cuts = cuts.sort_by_duration(ascending=False)

        # Optional CutSet transforms - e.g. padding, or speed perturbation that adjusts
        # the supervision boundaries.
        for tnfm in self.cut_transforms:
            cuts = tnfm(cuts)

        # Sort the cuts again after transforms
        cuts = cuts.sort_by_duration(ascending=False)

        # Get a tensor with batched feature matrices, shape (B, T, F)
        # Collation performs auto-padding, if necessary.
        input_tpl = self.input_strategy(cuts)
        if len(input_tpl) == 3:
            # An input strategy with fault tolerant audio reading mode.
            # "cuts" may be a subset of the original "cuts" variable,
            # that only has cuts for which we successfully read the audio.
            inputs, _, cuts = input_tpl
        else:
            inputs, _ = input_tpl

        # Get a dict of tensors that encode the positional information about supervisions
        # in the batch of feature matrices. The tensors are named "sequence_idx",
        # "start_frame/sample" and "num_frames/samples".
        supervision_intervals = self.input_strategy.supervision_intervals(cuts)

        # Apply all available transforms on the inputs, i.e. either audio or features.
        # This could be feature extraction, global MVN, SpecAugment, etc.
        # The interval tensors are stacked column-wise, one row per supervision.
        segments = torch.stack(list(supervision_intervals.values()), dim=1)
        for tnfm in self.input_transforms:
            inputs = tnfm(inputs, supervision_segments=segments)

        batch = {
            "inputs": inputs,
            "supervisions": default_collate(
                [
                    {"text": supervision.text}
                    for cut in cuts
                    for supervision in cut.supervisions
                ]
            ),
        }
        # Update the 'supervisions' field with sequence_idx and start/num frames/samples
        batch["supervisions"].update(supervision_intervals)
        if self.return_cuts:
            # Repeat each cut once per supervision, so the list lines up with "text".
            batch["supervisions"]["cut"] = [
                cut for cut in cuts for sup in cut.supervisions
            ]

        has_word_alignments = all(
            s.alignment is not None and "word" in s.alignment
            for c in cuts
            for s in c.supervisions
        )
        if has_word_alignments:
            # TODO: might need to refactor BatchIO API to move the following conditional logic
            #       into these objects (e.g. use like: self.input_strategy.convert_timestamp(),
            #       that returns either num_frames or num_samples depending on the strategy).
            words, starts, ends = [], [], []
            frame_shift = cuts[0].frame_shift
            sampling_rate = cuts[0].sampling_rate
            if frame_shift is None:
                # Fall back to the frame shift of the input strategy's extractor, if it has one.
                try:
                    frame_shift = self.input_strategy.extractor.frame_shift
                except AttributeError as err:
                    raise ValueError(
                        "Can't determine the frame_shift -- it is not present either in cuts or the input_strategy. "
                    ) from err
            for c in cuts:
                for s in c.supervisions:
                    word_alignment = s.alignment["word"]
                    words.append([aliword.symbol for aliword in word_alignment])
                    starts.append(
                        [
                            compute_num_frames(
                                aliword.start,
                                frame_shift=frame_shift,
                                sampling_rate=sampling_rate,
                            )
                            for aliword in word_alignment
                        ]
                    )
                    ends.append(
                        [
                            compute_num_frames(
                                aliword.end,
                                frame_shift=frame_shift,
                                sampling_rate=sampling_rate,
                            )
                            for aliword in word_alignment
                        ]
                    )
            batch["supervisions"]["word"] = words
            batch["supervisions"]["word_start"] = starts
            batch["supervisions"]["word_end"] = ends

        return batch
def validate_for_asr(cuts: CutSet) -> None:
    """
    Check that every supervision lies within the bounds of its cut.

    Runs the generic ``validate`` on ``cuts`` first, then checks each
    supervision's start and end against its cut, allowing a small tolerance.

    :param cuts: the cuts to check.
    :raises AssertionError: when a supervision starts before its cut or ends after it.
    """
    validate(cuts)

    # Tolerance for boundary rounding errors: 2e-3, i.e. 2 ms assuming times are in seconds.
    tol = 2e-3
    for cut in cuts:
        for supervision in cut.supervisions:
            assert supervision.start >= -tol, (
                f"Supervisions starting before the cut are not supported for ASR"
                f" (sup id: {supervision.id}, cut id: {cut.id})"
            )

            # Supervision start time is relative to Cut, so
            # 'supervision.end' is end of supervision inside the Cut
            assert supervision.end <= cut.duration + tol, (
                f"Supervisions ending after the cut "
                f"are not supported for ASR"
                f" (sup id: {supervision.id}, cut id: {cut.id})"
            )