Source code for lhotse.dataset.signal_transforms

import bisect
import math
import random
from typing import Any, Dict, Optional, Sequence, Tuple, TypeVar, Union

import numpy as np
import torch

from lhotse import CutSet, FeatureExtractor
from lhotse.augmentation import dereverb_wpe_torch
from lhotse.utils import Pathlike

# Names exported by ``from lhotse.dataset.signal_transforms import *``.
__all__ = ["GlobalMVN", "SpecAugment", "RandomizedSmoothing", "DereverbWPE"]


class GlobalMVN(torch.nn.Module):
    """
    Apply global mean and variance normalization.

    The statistics live in two buffers, ``norm_means`` and ``norm_stds``, each of
    shape ``(feature_dim,)``. A freshly constructed module holds zeros and ones,
    i.e. it is an identity transform until statistics are loaded.
    """

    def __init__(self, feature_dim: int):
        """
        GlobalMVN's constructor.

        :param feature_dim: the size of the ``norm_means`` / ``norm_stds`` buffers.
        """
        super().__init__()
        self.feature_dim = feature_dim
        self.register_buffer("norm_means", torch.zeros(feature_dim))
        self.register_buffer("norm_stds", torch.ones(feature_dim))

    @classmethod
    def from_cuts(
        cls,
        cuts: CutSet,
        max_cuts: Optional[int] = None,
        extractor: Optional[FeatureExtractor] = None,
    ) -> "GlobalMVN":
        """
        Build a ``GlobalMVN`` from statistics computed over a ``CutSet``.

        :param cuts: the cuts to compute the statistics from.
        :param max_cuts: forwarded to ``CutSet.compute_global_feature_stats``.
        :param extractor: forwarded to ``CutSet.compute_global_feature_stats``.
        :return: a module whose buffers hold the computed statistics.
        """
        stats = cuts.compute_global_feature_stats(
            max_cuts=max_cuts, extractor=extractor
        )
        # The stats mapping is loaded as a state dict, so its values must be tensors.
        stats = {name: torch.as_tensor(value) for name, value in stats.items()}
        (feature_dim,) = stats["norm_means"].shape
        global_mvn = cls(feature_dim)
        global_mvn.load_state_dict(stats)
        return global_mvn

    @classmethod
    def from_file(cls, stats_file: Pathlike) -> "GlobalMVN":
        """
        Restore a ``GlobalMVN`` from a file written by :meth:`to_file`.

        :param stats_file: path to the serialized state dict.
        :return: a module whose buffers hold the stored statistics.
        """
        # Map to CPU so that statistics saved from a module living on an accelerator
        # can be restored on a machine without that device. ``load_state_dict``
        # copies the values into this module's own buffers anyway.
        stats = torch.load(stats_file, map_location="cpu")
        (feature_dim,) = stats["norm_means"].shape
        global_mvn = cls(feature_dim)
        global_mvn.load_state_dict(stats)
        return global_mvn

    def to_file(self, stats_file: Pathlike):
        """
        Serialize this module's state dict with ``torch.save``.

        :param stats_file: destination path.
        """
        torch.save(self.state_dict(), stats_file)

    def forward(
        self,
        features: torch.Tensor,
        supervision_segments: Optional[torch.IntTensor] = None,
    ) -> torch.Tensor:
        """
        Normalize ``features`` with the stored statistics.

        :param features: a tensor that broadcasts against the ``(feature_dim,)`` buffers.
        :param supervision_segments: unused by this transform.
        :return: ``(features - norm_means) / norm_stds``.
        """
        return (features - self.norm_means) / self.norm_stds

    def inverse(self, features: torch.Tensor) -> torch.Tensor:
        """
        Undo :meth:`forward`.

        :param features: a normalized tensor.
        :return: ``features * norm_stds + norm_means``.
        """
        return features * self.norm_stds + self.norm_means
class RandomizedSmoothing(torch.nn.Module):
    """
    Randomized smoothing - gaussian noise added to an input waveform, or a batch of waveforms.
    The summed audio is clipped to ``[-1.0, 1.0]`` before returning.
    """

    def __init__(
        self,
        sigma: Union[float, Sequence[Tuple[int, float]]] = 0.1,
        sample_sigma: bool = True,
        p: float = 0.3,
    ):
        """
        RandomizedSmoothing's constructor.

        :param sigma: standard deviation of the gaussian noise. Either a constant number, or a schedule,
            i.e. a list of tuples that specify which value to use from which step.
            For example, ``[(0, 0.01), (1000, 0.1)]`` means that from steps 0-999, the sigma value
            will be 0.01, and from step 1000 onwards, it will be 0.1.
        :param sample_sigma: when ``False``, then sigma is used as the standard deviation in each forward step.
            When ``True``, the standard deviation is sampled from a uniform distribution of
            ``[-sigma, sigma]`` for each forward step.
        :param p: the probability of applying this transform.
        """
        super().__init__()
        self.sigma = sigma
        self.sample_sigma = sample_sigma
        self.p = p
        self.step = 0

    def forward(self, audio: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        """
        Add gaussian noise to a random subset of the examples in ``audio``.

        :param audio: a tensor whose first dimension is the batch dimension.
        :return: a tensor shaped like ``audio``, clipped to ``[-1.0, 1.0]``.
        """
        # Determine the stddev value.
        if isinstance(self.sigma, (int, float)):
            # A constant stddev; plain integers (e.g. ``sigma=1``) are constants too,
            # not schedules.
            sigma = self.sigma
        else:
            # Look the stddev up in the schedule; the step counter only advances
            # in schedule mode.
            sigma = schedule_value_for_step(self.sigma, self.step)
            self.step += 1

        # One value per example, broadcastable over all remaining dimensions.
        per_example_shape = (audio.shape[0],) + (1,) * (audio.dim() - 1)

        if self.sample_sigma:
            # In this mode stddev is stochastic itself and is sampled from a uniform
            # distribution bounded by [-sigma, sigma], separately for each example.
            # It is created on the input's device so that non-CPU inputs work.
            sigma = sigma * (2 * torch.rand(per_example_shape, device=audio.device) - 1)

        noise = sigma * torch.randn_like(audio)

        # Apply the transform with probability p: an example keeps its noise when
        # its uniform draw exceeds 1 - p. The mask is built here, on the input's
        # device, rather than through a helper that allocates on the CPU.
        keep_noise = (
            torch.rand(per_example_shape, device=audio.device) > 1.0 - self.p
        ).to(torch.float32)
        noise = noise * keep_noise

        return torch.clip(audio + noise, min=-1.0, max=1.0)
class SpecAugment(torch.nn.Module):
    """
    SpecAugment performs three augmentations:
    - time warping of the feature matrix
    - masking of ranges of features (frequency bands)
    - masking of ranges of frames (time)

    The current implementation works with batches, but processes each example separately
    in a loop rather than simultaneously to achieve different augmentation parameters for
    each example.
    """

    def __init__(
        self,
        time_warp_factor: Optional[int] = 80,
        num_feature_masks: int = 2,
        features_mask_size: int = 27,
        num_frame_masks: int = 10,
        frames_mask_size: int = 100,
        max_frames_mask_fraction: float = 0.15,
        p=0.9,
    ):
        """
        SpecAugment's constructor.

        :param time_warp_factor: parameter for the time warping; larger values mean more warping.
            Set to ``None``, or less than ``1``, to disable.
        :param num_feature_masks: how many feature masks should be applied. Set to ``0`` to disable.
        :param features_mask_size: the width of the feature mask (expressed in the number of masked feature bins).
            This is the ``F`` parameter from the SpecAugment paper.
        :param num_frame_masks: the number of masking regions for utterances. Set to ``0`` to disable.
        :param frames_mask_size: the width of the frame (temporal) masks (expressed in the number of masked frames).
            This is the ``T`` parameter from the SpecAugment paper.
        :param max_frames_mask_fraction: limits the size of the frame (temporal) mask to this value times the length
            of the utterance (or supervision segment).
            This is the parameter denoted by ``p`` in the SpecAugment paper.
        :param p: the probability of applying this transform.
            It is different from ``p`` in the SpecAugment paper!
        """
        super().__init__()
        assert 0 <= p <= 1
        assert num_feature_masks >= 0
        assert num_frame_masks >= 0
        assert features_mask_size > 0
        assert frames_mask_size > 0
        self.time_warp_factor = time_warp_factor
        self.num_feature_masks = num_feature_masks
        self.features_mask_size = features_mask_size
        self.num_frame_masks = num_frame_masks
        self.frames_mask_size = frames_mask_size
        self.max_frames_mask_fraction = max_frames_mask_fraction
        self.p = p

    def forward(
        self,
        features: torch.Tensor,
        supervision_segments: Optional[torch.IntTensor] = None,
        *args,
        **kwargs,
    ) -> torch.Tensor:
        """
        Computes SpecAugment for a batch of feature matrices.

        Since the batch will usually already be padded, the user can optionally
        provide a ``supervision_segments`` tensor that will be used to apply SpecAugment
        only to selected areas of the input. The format of this input is described below.

        :param features: a batch of feature matrices with shape ``(B, T, F)``.
        :param supervision_segments: an int tensor of shape ``(S, 3)``. ``S`` is the number of
            supervision segments that exist in ``features`` -- there may be either
            less or more than the batch size.
            The second dimension encodes three kinds of information:
            the sequence index of the corresponding feature matrix in ``features``,
            the start frame index, and the number of frames for each segment.
        :return: an augmented tensor of shape ``(B, T, F)``.
        """
        assert len(features.shape) == 3, (
            "SpecAugment only supports batches of " "single-channel feature matrices."
        )
        # Work on a copy: the masking helpers write into the tensor they are given.
        features = features.clone()
        if supervision_segments is None:
            # No supervisions - apply spec augment to full feature matrices.
            for sequence_idx in range(features.size(0)):
                features[sequence_idx] = self._forward_single(features[sequence_idx])
        else:
            # Supervisions provided - we will apply time warping only on the supervised areas.
            for sequence_idx, start_frame, num_frames in supervision_segments:
                end_frame = start_frame + num_frames
                features[sequence_idx, start_frame:end_frame] = self._forward_single(
                    features[sequence_idx, start_frame:end_frame], warp=True, mask=False
                )
            # ... and then time-mask the full feature matrices. Note that in this mode,
            # it might happen that masks are applied to different sequences/examples
            # than the time warping.
            for sequence_idx in range(features.size(0)):
                features[sequence_idx] = self._forward_single(
                    features[sequence_idx], warp=False, mask=True
                )
        return features

    def _forward_single(
        self, features: torch.Tensor, warp: bool = True, mask: bool = True
    ) -> torch.Tensor:
        """
        Apply SpecAugment to a single feature matrix of shape (T, F).

        :param features: the feature matrix to augment.
        :param warp: whether time warping may be applied.
        :param mask: whether frequency and time masking may be applied.
        :return: the augmented matrix, or ``features`` itself when the transform
            is randomly skipped.
        """
        if random.random() > self.p:
            # Randomly choose whether this transform is applied
            return features

        if (
            warp
            and self.time_warp_factor is not None
            and self.time_warp_factor >= 1
        ):
            features = time_warp(features, factor=self.time_warp_factor)

        if not mask:
            return features

        # Masked regions are filled with the mean of the (possibly warped) matrix.
        mean = features.mean()

        # Frequency masking; ``num_feature_masks == 0`` disables it.
        if self.num_feature_masks > 0:
            features = mask_along_axis_optimized(
                features,
                mask_size=self.features_mask_size,
                mask_times=self.num_feature_masks,
                mask_value=mean,
                axis=2,
            )

        # Time masking: the total masked span is capped at a fraction of the length.
        max_tot_mask_frames = self.max_frames_mask_fraction * features.size(0)
        num_frame_masks = min(
            self.num_frame_masks,
            math.ceil(max_tot_mask_frames / self.frames_mask_size),
        )
        if num_frame_masks <= 0:
            # Time masking is disabled (``num_frame_masks == 0``) or the matrix has
            # no frames; dividing by the mask count below would raise.
            return features

        max_mask_frames = min(
            self.frames_mask_size, max_tot_mask_frames // num_frame_masks
        )
        if max_mask_frames < 1:
            # The utterance is too short for even a one-frame mask within the
            # allowed fraction; a zero-width mask cannot be sampled.
            return features

        return mask_along_axis_optimized(
            features,
            mask_size=max_mask_frames,
            mask_times=num_frame_masks,
            mask_value=mean,
            axis=1,
        )

    def state_dict(self, **kwargs) -> Dict[str, Any]:
        """
        Return the augmentation hyper-parameters as a plain dictionary.
        """
        return dict(
            time_warp_factor=self.time_warp_factor,
            num_feature_masks=self.num_feature_masks,
            features_mask_size=self.features_mask_size,
            num_frame_masks=self.num_frame_masks,
            frames_mask_size=self.frames_mask_size,
            max_frames_mask_fraction=self.max_frames_mask_fraction,
            p=self.p,
        )

    def load_state_dict(self, state_dict: Dict[str, Any]):
        """
        Restore hyper-parameters from ``state_dict``; keys that are missing leave
        the current value untouched.
        """
        self.time_warp_factor = state_dict.get(
            "time_warp_factor", self.time_warp_factor
        )
        self.num_feature_masks = state_dict.get(
            "num_feature_masks", self.num_feature_masks
        )
        self.features_mask_size = state_dict.get(
            "features_mask_size", self.features_mask_size
        )
        self.num_frame_masks = state_dict.get("num_frame_masks", self.num_frame_masks)
        self.frames_mask_size = state_dict.get(
            "frames_mask_size", self.frames_mask_size
        )
        self.max_frames_mask_fraction = state_dict.get(
            "max_frames_mask_fraction", self.max_frames_mask_fraction
        )
        self.p = state_dict.get("p", self.p)
def mask_along_axis_optimized(
    features: torch.Tensor,
    mask_size: int,
    mask_times: int,
    mask_value: float,
    axis: int,
) -> torch.Tensor:
    """
    Apply Frequency and Time masking along axis.
    Frequency and Time masking as described in the SpecAugment paper.

    The masks are written into ``features`` itself; pass a copy if the input
    must stay untouched.

    :param features: input tensor of shape ``(T, F)``
    :param mask_size: the width size for masking; each mask width is drawn from
        ``[0, mask_size)``. Values below ``1`` leave the input unmasked.
    :param mask_times: the number of masking regions. ``0`` leaves the input unmasked.
    :param mask_value: Value to assign to the masked regions.
    :param axis: Axis to apply masking on (1 -> time, 2 -> frequency)
    :return: the masked tensor.
    :raises ValueError: when ``axis`` is neither ``1`` nor ``2``.
    """
    if axis not in [1, 2]:
        raise ValueError("Only Frequency and Time masking are supported!")

    # Add a leading dimension so that ``axis`` 1/2 address time/frequency.
    features = features.unsqueeze(0)
    features = features.reshape([-1] + list(features.size()[-2:]))

    # A width below one cannot be sampled (the sampling range would be empty),
    # and zero regions means there is nothing to do.
    if mask_times > 0 and int(mask_size) >= 1:
        widths = torch.randint(int(0), int(mask_size), (1, mask_times))
        offsets = torch.rand(1, mask_times) * (features.size(axis) - widths)
        # Flattening to Python ints makes one loop serve any number of regions,
        # including a single one.
        mask_starts = offsets.long().flatten().tolist()
        mask_ends = (offsets.long() + widths.long()).flatten().tolist()
        for mask_start, mask_end in zip(mask_starts, mask_ends):
            if axis == 1:
                features[:, mask_start:mask_end] = mask_value
            else:
                features[:, :, mask_start:mask_end] = mask_value

    return features.squeeze(0)


def time_warp(features: torch.Tensor, factor: int) -> torch.Tensor:
    """
    Time warping as described in the SpecAugment paper.
    Implementation based on Espresso:
    https://github.com/freewym/espresso/blob/master/espresso/tools/specaug_interpolate.py#L51

    :param features: input tensor of shape ``(T, F)``
    :param factor: time warping parameter.
    :return: a warped tensor of shape ``(T, F)``
    """
    t = features.size(0)
    if t - factor <= factor + 1:
        # Too short to pick a center with ``factor`` frames of room on both sides.
        return features
    center = np.random.randint(factor + 1, t - factor)
    warped = np.random.randint(center - factor, center + factor + 1)
    if warped == center:
        # The center frame stays in place, so there is nothing to warp.
        return features
    # ``interpolate`` in bicubic mode expects a 4-D input.
    features = features.unsqueeze(0).unsqueeze(0)
    # Resize the part before ``center`` to ``warped`` frames and the part after
    # it to the remaining ``t - warped`` frames, keeping the total length.
    left = torch.nn.functional.interpolate(
        features[:, :, :center, :],
        size=(warped, features.size(3)),
        mode="bicubic",
        align_corners=False,
    )
    right = torch.nn.functional.interpolate(
        features[:, :, center:, :],
        size=(t - warped, features.size(3)),
        mode="bicubic",
        align_corners=False,
    )
    return torch.cat((left, right), dim=2).squeeze(0).squeeze(0)


T = TypeVar("T")


def schedule_value_for_step(schedule: Sequence[Tuple[int, T]], step: int) -> T:
    """
    Look up the value that a step-based schedule assigns to ``step``.

    :param schedule: ``(milestone, value)`` pairs; a value applies from its milestone
        up to (but excluding) the next one. The lookup bisects the milestones, so
        they need to be in ascending order.
    :param step: the step to look up; must not precede the first milestone.
    :return: the value of the last milestone that is ``<= step``.
    """
    milestones, values = zip(*schedule)
    assert milestones[0] <= step, (
        f"Cannot determine the scheduled value for step {step} with schedule: {schedule}. "
        f"Did you forget to add the first part of the schedule "
        f"for steps below {milestones[0]}?"
    )
    idx = bisect.bisect_right(milestones, step) - 1
    return values[idx]


def random_mask_along_batch_axis(tensor: torch.Tensor, p: float = 0.5) -> torch.Tensor:
    """
    For a given tensor with shape (N, d1, d2, d3, ...), returns a mask with shape (N, 1, 1, 1, ...),
    that randomly masks the samples in a batch.

    E.g. for a 2D input matrix it looks like:

        >>> [[0., 0., 0., ...],
        ...  [1., 1., 1., ...],
        ...  [0., 0., 0., ...]]

    :param tensor: the input tensor.
    :param p: the probability of masking an element.
    :return: a ``float32`` mask on the same device as ``tensor``; ``0`` marks a masked sample.
    """
    mask_shape = (tensor.shape[0],) + tuple(1 for _ in tensor.shape[1:])
    # Draw on the input's device so the mask can be multiplied with it directly.
    mask = (torch.rand(mask_shape, device=tensor.device) > p).to(torch.float32)
    return mask
class DereverbWPE(torch.nn.Module):
    """
    Dereverberation with Weighted Prediction Error (WPE).
    The implementation and default values are borrowed from `nara_wpe` package:
    https://github.com/fgnt/nara_wpe

    The method and library are described in the following paper:
    https://groups.uni-paderborn.de/nt/pubs/2018/ITG_2018_Drude_Paper.pdf
    """

    def __init__(self, n_fft: int = 512, hop_length: int = 128):
        """
        DereverbWPE's constructor.

        :param n_fft: forwarded to ``dereverb_wpe_torch`` as ``n_fft``.
        :param hop_length: forwarded to ``dereverb_wpe_torch`` as ``hop_length``.
        """
        super().__init__()
        self.n_fft = n_fft
        self.hop_length = hop_length

    def forward(self, audio: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        """
        Expects audio to be 2D or 3D tensor.
        2D means a batch of single-channel audio, shape (B, T).
        3D means a batch of multi-channel audio, shape (B, D, T).
        B => batch size; D => number of channels; T => number of audio samples.
        """
        wpe_options = dict(n_fft=self.n_fft, hop_length=self.hop_length)

        if audio.ndim == 2:
            # Single-channel batch: give every example a channel dimension,
            # process it on its own, and join the results along the batch axis.
            processed = [
                dereverb_wpe_torch(example.unsqueeze(0), **wpe_options)
                for example in audio
            ]
            return torch.cat(processed, dim=0)

        # Otherwise the batch must be multi-channel: each example has D channels.
        assert audio.ndim == 3
        processed = [
            dereverb_wpe_torch(example, **wpe_options) for example in audio
        ]
        return torch.stack(processed, dim=0)