Source code for lhotse.dataset.vad

from typing import Callable, Dict, Sequence

import torch

from lhotse import validate
from lhotse.cut import CutSet
from lhotse.dataset.input_strategies import BatchIO, PrecomputedFeatures
from lhotse.utils import ifnone


class VadDataset(torch.utils.data.Dataset):
    """
    PyTorch Dataset for voice activity detection (VAD).

    The dataset is indexed with a whole :class:`CutSet` (a mini-batch of
    cuts) rather than with an integer, and each item is a dict of:

    .. code-block::

        {
            'inputs': (B x T x F) tensor
            'input_lens': (B,) tensor
            'is_voice': (T x 1) tensor
            'cut': List[Cut]
        }
    """

    def __init__(
        self,
        input_strategy: BatchIO = PrecomputedFeatures(),
        cut_transforms: Sequence[Callable[[CutSet], CutSet]] = None,
        input_transforms: Sequence[Callable[[torch.Tensor], torch.Tensor]] = None,
    ) -> None:
        """
        Store the batch loading strategy and the optional transforms.

        :param input_strategy: callable that turns a ``CutSet`` into
            ``(inputs, input_lens)``; it must also provide
            ``supervision_masks(cuts)``.
        :param cut_transforms: callables applied in order to the ``CutSet``
            before the inputs are loaded. ``None`` is replaced via ``ifnone``
            with an empty list.
        :param input_transforms: callables applied in order to the loaded
            inputs. ``None`` is replaced via ``ifnone`` with an empty list.
        """
        super().__init__()
        self.input_strategy = input_strategy
        self.cut_transforms = ifnone(cut_transforms, [])
        self.input_transforms = ifnone(input_transforms, [])

    def __getitem__(self, cuts: CutSet) -> Dict[str, torch.Tensor]:
        """
        Build one mini-batch from ``cuts``.

        The cuts are validated, sorted with ``sort_by_duration()`` and passed
        through every cut transform; the input strategy then loads the inputs,
        which are passed through every input transform. The supervision masks
        are computed from the transformed cuts.

        :param cuts: the cuts that make up this mini-batch.
        :return: a dict with the keys ``"inputs"``, ``"input_lens"``,
            ``"is_voice"`` and ``"cut"``.
        """
        validate(cuts)

        # Cut-level processing happens before anything is loaded.
        batch_cuts = cuts.sort_by_duration()
        for cut_transform in self.cut_transforms:
            batch_cuts = cut_transform(batch_cuts)

        # Load the inputs, then apply the input-level transforms in order.
        strategy = self.input_strategy
        features, feature_lens = strategy(batch_cuts)
        for input_transform in self.input_transforms:
            features = input_transform(features)

        # Masks are requested last, from the already transformed cuts.
        voice_mask = strategy.supervision_masks(batch_cuts)

        return {
            "inputs": features,
            "input_lens": feature_lens,
            "is_voice": voice_mask,
            "cut": batch_cuts,
        }