PyTorch Datasets

Lhotse supports PyTorch’s dataset API, providing implementations for the Dataset and Sampler concepts. They can be used together with the standard DataLoader class for efficient mini-batch collection with multiple parallel readers and pre-fetching.

A quick re-cap of PyTorch’s data API

PyTorch defines the Dataset class that is responsible for reading the data from disk/memory/Internet/database/etc., and converting it to tensors that can be used for network training or inference. These Datasets are typically “map-style” datasets, which are given an index (or a list of indices) and return the corresponding data samples.

The selection of indices is performed by the Sampler class. A Sampler, knowing the length (number of items) of a Dataset, can use various strategies to determine the order in which elements are read (e.g. sequential reads, or random reads).

More details about the data pipeline API in PyTorch can be found in the PyTorch documentation for torch.utils.data.

About Lhotse’s Datasets and Samplers

Lhotse provides a number of utilities that make it simpler to define Datasets for speech processing tasks. CutSet is the base data structure used to initialize the Dataset class. This makes it possible to manipulate the speech data in convenient ways: pad, mix, concatenate, augment, compute features, look up the supervision information, etc.

Lhotse’s Datasets perform batching by themselves, because auto-collation in DataLoader is too limiting for speech data handling. These Datasets expect to be handed lists of element indices, so that they can collate the data before it is passed to the DataLoader (which must use batch_size=None). This allows for interesting collation methods, e.g. padding the speech with noise recordings or actual acoustic context rather than artificial zeroes, or dynamic batch sizes.

The items for mini-batch creation are selected by the Sampler. Lhotse defines Sampler classes that are initialized with CutSets, so that they can look up specific properties of an utterance to stratify the sampling. For example, SingleCutSampler has a max_frames attribute, and it will keep adding cuts to a batch as long as their total number of frames does not exceed the specified limit. Another strategy, used in BucketingSampler, first groups cuts of similar durations into buckets, and then randomly selects a bucket to draw the whole batch from.

For tasks where both the input and the output of the model are speech utterances, we can use the CutPairsSampler, which accepts two CutSets and matches the cuts in them by their IDs.

A typical usage of Lhotse’s dataset API might look like this:

from torch.utils.data import DataLoader
from lhotse import CutSet
from lhotse.dataset import K2SpeechRecognitionDataset, SingleCutSampler

cuts = CutSet(...)
dset = K2SpeechRecognitionDataset(cuts)
sampler = SingleCutSampler(cuts, max_frames=50000)
# The Dataset performs batching by itself, so we have to indicate that
# to the DataLoader with batch_size=None
dloader = DataLoader(dset, sampler=sampler, batch_size=None, num_workers=1)
for batch in dloader:
    ...  # process data

Pre-computed vs. on-the-fly: input strategies

Depending on the experimental setup and infrastructure, it might be more convenient to either pre-compute and store features like filter-bank energies for later use (as traditionally done in Kaldi/ESPnet/Espresso toolkits), or compute them dynamically during training (“on-the-fly”). Lhotse supports both modes of computation by introducing a class called InputStrategy. It is accepted as an argument in most dataset classes, and defaults to PrecomputedFeatures. Other available choices are AudioSamples for working with waveforms directly, and OnTheFlyFeatures, which wraps a FeatureExtractor and applies it to a batch of recordings. These strategies automatically pad and collate the inputs, and provide information about the original signal lengths: as a number of frames/samples, binary mask, or start-end frame/sample pairs.
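The padding-and-lengths bookkeeping described above can be sketched in plain Python. This is only an illustration under our own assumptions: collate_with_lengths and lengths_to_mask are hypothetical helpers, not Lhotse functions, and the real input strategies return PyTorch tensors rather than nested lists.

```python
from typing import List, Tuple

Matrix = List[List[float]]  # one (T x F) feature matrix as nested lists


def collate_with_lengths(
    seqs: List[Matrix], pad_value: float = 0.0
) -> Tuple[List[Matrix], List[int]]:
    """Pad each (T_i x F) matrix at the end to (T_max x F) and keep the original lengths."""
    lens = [len(s) for s in seqs]
    max_len = max(lens)
    num_feats = len(seqs[0][0])
    batch = [
        s + [[pad_value] * num_feats for _ in range(max_len - len(s))]
        for s in seqs
    ]
    return batch, lens


def lengths_to_mask(lens: List[int]) -> List[List[bool]]:
    """Binary mask: True for real frames, False for padding."""
    max_len = max(lens)
    return [[t < n for t in range(max_len)] for n in lens]
```

The lengths and the mask carry the same information; start-end pairs, the third representation mentioned above, would simply be (0, length) for whole cuts.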

Which strategy to choose?

In general, pre-computed features can be greatly compressed (we achieve a 70% size reduction relative to uncompressed features), so the I/O load on your computing infrastructure will be much smaller than if you read the recordings directly. This is especially valuable when working with network file systems (NFS), which are typically used for storage in computational grids. When your experiment is I/O-bound, it is best to use pre-computed features.

When I/O is not the issue, it might be preferable to use on-the-fly computation, as it does not require any prior steps before network training. It also makes it simpler to apply a vast range of data augmentation methods in a fully randomized way (e.g. reverberation), although Lhotse provides support for approximate feature-domain signal mixing (e.g. for additive noise augmentation) to alleviate this limitation of pre-computed features to some extent.

Datasets list

class lhotse.dataset.diarization.DiarizationDataset(cuts, uem=None, min_speaker_dim=None, global_speaker_ids=False)

A PyTorch Dataset for the speaker diarization task. Our assumptions about speaker diarization are the following:

  • we assume a single channel input (for now), which could be either a true mono signal or a beamforming result from a microphone array.

  • we assume that the supervision used for model training is a speech activity matrix, with one row dedicated to each speaker (either in the current cut or the whole dataset, depending on the settings). The columns correspond to feature frames. Each row is effectively a Voice Activity Detection supervision for a single speaker. This setup is somewhat inspired by the TS-VAD paper: https://arxiv.org/abs/2005.07272

Each item in this dataset is a dict of:

{
    'features': (B x T x F) tensor
    'features_lens': (B, ) tensor
    'speaker_activity': (B x num_speaker x T) tensor
}
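To make the speech activity matrix concrete, here is a minimal sketch that builds it from speaker segments. The (speaker, start, duration) tuple format, the speaker_activity function, the 0.01 s default frame shift and the round-to-nearest-frame convention are all assumptions for illustration; Lhotse derives the matrix from the cut’s supervisions and its own frame shift. The min_speaker_dim argument mirrors the constructor option documented below by padding the matrix with all-zero rows.

```python
from typing import List, Optional, Tuple

# Hypothetical segment format: (speaker_id, start_sec, duration_sec).
Segment = Tuple[str, float, float]


def speaker_activity(
    segments: List[Segment],
    num_frames: int,
    frame_shift: float = 0.01,
    min_speaker_dim: Optional[int] = None,
) -> Tuple[List[str], List[List[int]]]:
    """Return (row speaker order, num_speakers x num_frames 0/1 activity matrix)."""
    speakers = sorted({spk for spk, _, _ in segments})
    num_rows = len(speakers)
    if min_speaker_dim is not None:
        num_rows = max(num_rows, min_speaker_dim)  # extra rows stay all-zero
    matrix = [[0] * num_frames for _ in range(num_rows)]
    for spk, start, duration in segments:
        row = speakers.index(spk)
        first = max(0, round(start / frame_shift))
        last = min(num_frames, round((start + duration) / frame_shift))
        for t in range(first, last):
            matrix[row][t] = 1  # one VAD-style row per speaker
    return speakers, matrix
```

Overlapping speech simply shows up as several 1s in the same column.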

Constructor arguments:

Parameters
  • cuts (CutSet) – a CutSet used to create the dataset object.

  • uem (Optional[SupervisionSet]) – a SupervisionSet used to set regions for diarization.

  • min_speaker_dim (Optional[int]) – optional int; when specified, it will enforce that the speaker dimension of the matrix is at least that value (useful for datasets like CHiME 6 where the number of speakers is always 4, but some cuts might have fewer speakers than that).

  • global_speaker_ids (bool) – a bool that indicates whether the same speaker should always retain the same row index in the speaker activity matrix (useful for speaker-dependent systems).

  • root_dir – a prefix path to be attached to the feature files paths.


class lhotse.dataset.unsupervised.UnsupervisedDataset(cuts)

Dataset that contains no supervision; it only provides the features extracted from recordings. Each item in this dataset is a dict of:

{
    'features': (B x T x F) tensor
    'features_lens': (B, ) tensor
}

class lhotse.dataset.unsupervised.UnsupervisedWaveformDataset(cuts)

A variant of UnsupervisedDataset that provides waveform samples instead of features. The output is a tensor of shape (C, T), with C being the number of channels and T the number of audio samples. In this implementation, there will always be a single channel.

Returns:

{
    'audio': (B x NumSamples) float tensor
    'audio_lens': (B, ) int tensor
}

class lhotse.dataset.unsupervised.DynamicUnsupervisedDataset(feature_extractor, cuts, augment_fn=None)

An example dataset that shows how to use on-the-fly feature extraction in Lhotse. It accepts two additional inputs: a FeatureExtractor and an optional WavAugmenter for time-domain data augmentation. The output is approximately the same as that of the UnsupervisedDataset; there might be slight differences for MixedCuts, because this dataset mixes them in the time domain, whereas UnsupervisedDataset does that in the feature domain. Cuts that are not mixed will yield identical results in both dataset classes.


class lhotse.dataset.speech_recognition.K2SpeechRecognitionDataset(cuts, return_cuts=False, cut_transforms=None, input_transforms=None, input_strategy=<lhotse.dataset.input_strategies.PrecomputedFeatures object>, check_inputs=True)

The PyTorch Dataset for the speech recognition task using the k2 library.

This dataset expects to be queried with lists of cut IDs, for which it loads features and automatically collates/batches them.

To use it with a PyTorch DataLoader, set batch_size=None and provide a SingleCutSampler sampler.

Each item in this dataset is a dict of:

{
    'inputs': float tensor with shape determined by :attr:`input_strategy`:
              - single-channel:
                - features: (B, T, F)
                - audio: (B, T)
              - multi-channel: currently not supported
    'supervisions': [
        {
            'sequence_idx': Tensor[int] of shape (S,)
            'text': List[str] of len S

            # For feature input strategies
            'start_frame': Tensor[int] of shape (S,)
            'num_frames': Tensor[int] of shape (S,)

            # For audio input strategies
            'start_sample': Tensor[int] of shape (S,)
            'num_samples': Tensor[int] of shape (S,)

            # Optionally, when return_cuts=True
            'cut': List[AnyCut] of len S
        }
    ]
}

Dimension symbols legend:

  • B - batch size (number of Cuts)

  • S - number of supervision segments (greater than or equal to B, as each Cut may have multiple supervisions)

  • T - number of frames of the longest Cut

  • F - number of features

The ‘sequence_idx’ field is the index of the Cut used to create the example in the Dataset.
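How sequence_idx ties each of the S supervisions back to one of the B cuts can be illustrated with a short sketch. The nested-list input format and the frame_intervals function are hypothetical, and rounding seconds to the nearest frame is our assumption rather than Lhotse’s exact convention.

```python
from typing import Dict, List, Tuple


def frame_intervals(
    cuts: List[List[Tuple[float, float]]], frame_shift: float = 0.01
) -> Dict[str, List[int]]:
    """Flatten per-cut (start_sec, duration_sec) supervisions into S-long index lists."""
    out: Dict[str, List[int]] = {"sequence_idx": [], "start_frame": [], "num_frames": []}
    for seq_idx, supervisions in enumerate(cuts):
        for start, duration in supervisions:
            out["sequence_idx"].append(seq_idx)  # which cut (row of 'inputs') it belongs to
            out["start_frame"].append(round(start / frame_shift))
            out["num_frames"].append(round(duration / frame_shift))
    return out
```

With two cuts, where the first has two supervisions, the result has S = 3 entries and sequence_idx = [0, 0, 1], which is why S can be larger than B.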

__init__(cuts, return_cuts=False, cut_transforms=None, input_transforms=None, input_strategy=<lhotse.dataset.input_strategies.PrecomputedFeatures object>, check_inputs=True)

K2 ASR Dataset constructor.

Parameters
  • cuts (CutSet) – the CutSet to sample data from.

  • return_cuts (bool) – When True, will additionally return a “cut” field in each batch with the Cut objects used to create that batch.

  • cut_transforms (Optional[List[Callable[[CutSet], CutSet]]]) – A list of transforms to be applied on each sampled batch, before converting cuts to an input representation (audio/features). Examples: cut concatenation, noise cuts mixing, etc.

  • input_transforms (Optional[List[Callable[[Tensor], Tensor]]]) – A list of transforms to be applied on each sampled batch, after the cuts are converted to audio/features. Examples: normalization, SpecAugment, etc.

  • input_strategy (InputStrategy) – Converts cuts into a collated batch of audio/features. By default, reads pre-computed features from disk.

  • check_inputs (bool) – Whether to iterate over the cuts to validate them. You might want to disable it when using “lazy” CutSets to avoid a very long start-up time.

lhotse.dataset.speech_recognition.validate_for_asr(cuts)
Return type

None

lhotse.dataset.speech_synthesis

alias of lhotse.dataset.speech_synthesis

class lhotse.dataset.source_separation.DynamicallyMixedSourceSeparationDataset(sources_set, mixtures_set, nonsources_set=None)

A PyTorch Dataset for the source separation task. It’s created from a number of CutSets:

  • sources_set: provides the audio cuts for the sources (the targets of source separation),

  • mixtures_set: provides the audio cuts for the signal mix (the input of source separation),

  • nonsources_set: (optional) provides the audio cuts for other signals that are in the mix, but are not the targets of source separation. Useful for adding noise.

When queried for data samples, it returns a dict of:

{
    'sources': (N x T x F) tensor,
    'mixture': (T x F) tensor,
    'real_mask': (N x T x F) tensor,
    'binary_mask': (T x F) tensor
}

This Dataset performs on-the-fly feature-domain mixing of the sources. It expects the mixtures_set to contain MixedCuts, so that it knows which Cuts should be mixed together.


validate()

class lhotse.dataset.source_separation.PreMixedSourceSeparationDataset(sources_set, mixtures_set)

A PyTorch Dataset for the source separation task. It’s created from two CutSets - one provides the audio cuts for the sources, and the other one the audio cuts for the signal mix. When queried for data samples, it returns a dict of:

{
    'sources': (N x T x F) tensor,
    'mixture': (T x F) tensor,
    'real_mask': (N x T x F) tensor,
    'binary_mask': (T x F) tensor
}

It expects both CutSets to contain regular Cuts, meaning that the signals were mixed in the time domain. In contrast to DynamicallyMixedSourceSeparationDataset, no on-the-fly feature-domain mixing is performed.


class lhotse.dataset.vad.VadDataset(cuts, input_strategy=<lhotse.dataset.input_strategies.PrecomputedFeatures object>, cut_transforms=None, input_transforms=None)

The PyTorch Dataset for the voice activity detection task. Each item in this dataset is a dict of:

{
    'inputs': (B x T x F) tensor
    'input_lens': (B,) tensor
    'is_voice': (T x 1) tensor
    'cut': List[Cut]
}

Samplers list

class lhotse.dataset.sampling.DataSource(items)

shuffle(seed)

class lhotse.dataset.sampling.CutSampler(cut_ids, shuffle=False, world_size=None, rank=None, seed=0, provide_len=True)

CutSampler is responsible for collecting batches of cuts, given specified criteria. It implements correct handling of distributed sampling in DataLoader, so that the cuts are not duplicated across workers.

Sampling in a CutSampler is intended to be very quick: it only uses the metadata in the CutSet manifest to select the cuts, and is not intended to perform any I/O.

CutSampler works similarly to PyTorch’s DistributedSampler: when shuffle=True, you should call sampler.set_epoch(epoch) at each new epoch to have a different ordering of returned elements. However, its actual behaviour is different from that of DistributedSampler: instead of partitioning the underlying cuts into equally sized chunks, it will return every N-th batch and skip the other batches (where N == world_size). The formula used to determine which batches are returned is: (batch_idx + (world_size - rank)) % world_size == 0. This ensures that we can return an equal number of batches in all distributed workers in spite of using a dynamic batch size, at the cost of skipping at most world_size - 1 batches.
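The batch selection rule quoted above can be checked with a few lines of Python. batches_for_rank is a hypothetical helper that only applies the documented formula to batch indices; it does not model how the sampler builds the batches themselves or how it equalizes their count across ranks.

```python
from typing import List


def batches_for_rank(num_batches: int, world_size: int, rank: int) -> List[int]:
    """Indices of the batches kept by `rank` under the documented selection formula."""
    return [
        batch_idx
        for batch_idx in range(num_batches)
        if (batch_idx + (world_size - rank)) % world_size == 0
    ]


# With 10 batches and 4 workers, each rank keeps every 4th batch, offset by its rank.
for rank in range(4):
    print(rank, batches_for_rank(10, world_size=4, rank=rank))
```

Rank r keeps the batches whose index is congruent to r modulo world_size, so no two ranks ever receive the same batch. With 10 batches and 4 ranks, ranks 0 and 1 get three batches while ranks 2 and 3 get two; this leftover is the imbalance that skipping at most world_size - 1 batches removes.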

Example usage:

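>>> from torch.utils.data import DataLoader
>>> from lhotse.dataset import K2SpeechRecognitionDataset, SingleCutSampler
>>> dataset = K2SpeechRecognitionDataset(cuts)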
>>> sampler = SingleCutSampler(cuts, shuffle=True)
>>> loader = DataLoader(dataset, sampler=sampler, batch_size=None)
>>> for epoch in range(start_epoch, n_epochs):
...     sampler.set_epoch(epoch)
...     train(loader)

Note

For implementers of new samplers: subclasses of CutSampler are expected to implement self._next_batch() to introduce specific sampling logic (e.g. based on filters such as max number of frames/tokens/etc.). CutSampler defines __iter__(), which optionally shuffles the cut IDs and resets self.cut_idx to zero (to be used and incremented inside of _next_batch()).

__init__(cut_ids, shuffle=False, world_size=None, rank=None, seed=0, provide_len=True)
Parameters
  • cut_ids (Iterable[str]) – An iterable of cut IDs for the full dataset. CutSampler will take care of partitioning that into distributed workers (if needed).

  • shuffle (bool) – When True, the cuts will be shuffled at the start of iteration. Convenient when the mini-batch loop is inside an outer epoch-level loop, e.g. for epoch in range(10): for batch in dataset: …, as every epoch will see a different cut order.

  • world_size (Optional[int]) – Total number of distributed nodes. We will try to infer it by default.

  • rank (Optional[int]) – Index of distributed node. We will try to infer it by default.

  • seed (int) – Random seed used to consistently shuffle the dataset across different processes.

  • provide_len (bool) – Whether to expose the __len__ attribute in this class. It makes sense to turn it off when iterating the sampler is costly for any reason, e.g. because the underlying manifest is lazily loaded from the filesystem or elsewhere.

set_epoch(epoch)

Sets the epoch for this sampler. When shuffle=True, this ensures all replicas use a different random ordering for each epoch. Otherwise, the next iteration of this sampler will yield the same ordering.

Parameters

epoch (int) – Epoch number.

Return type

None

filter(predicate)

Add a constraint on individual cuts that has to be satisfied for them to be considered.

Can be useful when handling large, lazy manifests where it is not feasible to pre-filter them before instantiating the sampler.

When set, we will remove the __len__ attribute on the sampler, as it is now determined dynamically.

Example:
>>> cuts = CutSet(...)
>>> sampler = SingleCutSampler(cuts, max_duration=100.0)
>>> # Retain only the cuts that have at least 1s and at most 20s duration.
>>> sampler.filter(lambda cut: 1.0 <= cut.duration <= 20.0)

Return type

None

class lhotse.dataset.sampling.TimeConstraint(max_duration: Optional[float] = None, max_samples: Optional[int] = None, max_frames: Optional[int] = None, current: Union[int, float] = 0)

Represents a time-based constraint for sampler classes. It can be defined either as maximum total batch duration (in seconds), number of frames, or number of samples. These options are mutually exclusive and this class checks for that.

TimeConstraint can be used for tracking whether the criterion has been exceeded via the add(cut), exceeded() and reset() methods. It will automatically track the right criterion (i.e. select frames/samples/duration from the cut). It can also be a null constraint (never exceeded).

max_duration: Optional[float] = None
max_samples: Optional[int] = None
max_frames: Optional[int] = None
current: Union[int, float] = 0
is_active()

Is it an actual constraint, or a dummy one (i.e. never exceeded).

Return type

bool

add(cut)

Increment the internal counter for the time constraint, selecting the right property from the input cut object.

Return type

None

exceeded()

Is the constraint exceeded or not.

Return type

bool

reset()

Reset the internal counter (to be used after a batch was created, to start collecting a new one).

Return type

None
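The add/exceeded/reset cycle can be mimicked with a small stand-alone class. SimpleTimeConstraint and FakeCut are hypothetical stand-ins written for illustration, not Lhotse’s implementation; in particular, treating the constraint as exceeded only when the running total is strictly greater than the limit is our assumption.

```python
from dataclasses import dataclass
from typing import Optional, Union


@dataclass
class FakeCut:
    """Hypothetical stand-in exposing only the fields the constraint reads."""
    duration: float
    num_frames: int
    num_samples: int


@dataclass
class SimpleTimeConstraint:
    """Hypothetical re-implementation of the TimeConstraint bookkeeping."""
    max_duration: Optional[float] = None
    max_samples: Optional[int] = None
    max_frames: Optional[int] = None
    current: Union[int, float] = 0

    def __post_init__(self) -> None:
        limits = [self.max_duration, self.max_samples, self.max_frames]
        if sum(x is not None for x in limits) > 1:
            raise ValueError("max_duration, max_samples and max_frames are mutually exclusive")

    def _limit(self) -> Optional[Union[int, float]]:
        for x in (self.max_duration, self.max_samples, self.max_frames):
            if x is not None:
                return x
        return None

    def is_active(self) -> bool:
        return self._limit() is not None

    def add(self, cut: FakeCut) -> None:
        # Track whichever quantity the active limit refers to.
        if self.max_duration is not None:
            self.current += cut.duration
        elif self.max_samples is not None:
            self.current += cut.num_samples
        elif self.max_frames is not None:
            self.current += cut.num_frames

    def exceeded(self) -> bool:
        limit = self._limit()
        # Assumption: "exceeded" means strictly greater than the limit.
        return limit is not None and self.current > limit

    def reset(self) -> None:
        self.current = 0
```

A constraint created without any limit is the “null constraint”: it is inactive and never reports being exceeded.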


class lhotse.dataset.sampling.SingleCutSampler(cuts, max_frames=None, max_samples=None, max_duration=None, max_cuts=None, **kwargs)

Samples cuts from a CutSet to satisfy the input constraints. It behaves like an iterable that yields lists of strings (cut IDs).

When one of max_frames, max_samples, or max_duration is specified, the batch size is dynamic. At most one of those constraints can be specified. Padding required to collate the batch does not contribute to max frames/samples/duration.
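The dynamic batch size follows from greedy accumulation, which can be sketched as below. The dynamic_batches function and its (cut_id, num_frames) input are hypothetical; the real sampler reads the same information from CutSet metadata and also handles shuffling, filtering and distributed ranks, which this sketch omits. Only the actual frame counts are summed, so padding does not count toward the limit.

```python
from typing import Iterable, Iterator, List, Optional, Tuple


def dynamic_batches(
    cuts: Iterable[Tuple[str, int]],
    max_frames: int,
    max_cuts: Optional[int] = None,
) -> Iterator[List[str]]:
    """Greedily group hypothetical (cut_id, num_frames) pairs into batches of cut IDs."""
    batch: List[str] = []
    total = 0
    for cut_id, num_frames in cuts:
        would_exceed = total + num_frames > max_frames
        too_many = max_cuts is not None and len(batch) >= max_cuts
        # Close the current batch before it would break either constraint.
        if batch and (would_exceed or too_many):
            yield batch
            batch, total = [], 0
        batch.append(cut_id)
        total += num_frames  # only real frames count; padding is ignored
    if batch:
        yield batch
```

For example, cuts of 400, 300, 500 and 200 frames with max_frames=800 are grouped into two batches of two cuts each, even though the cuts differ in length.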

Example usage:

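>>> from torch.utils.data import DataLoader
>>> from lhotse.dataset import K2SpeechRecognitionDataset, SingleCutSampler
>>> dataset = K2SpeechRecognitionDataset(cuts)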
>>> sampler = SingleCutSampler(cuts, shuffle=True)
>>> loader = DataLoader(dataset, sampler=sampler, batch_size=None)
>>> for epoch in range(start_epoch, n_epochs):
...     sampler.set_epoch(epoch)
...     train(loader)

__init__(cuts, max_frames=None, max_samples=None, max_duration=None, max_cuts=None, **kwargs)

SingleCutSampler’s constructor.

Parameters
  • cuts (CutSet) – the CutSet to sample data from.

  • max_frames (Optional[int]) – The maximum total number of feature frames from cuts.

  • max_samples (Optional[int]) – The maximum total number of audio samples from cuts.

  • max_duration (Optional[float]) – The maximum total recording duration from cuts.

  • max_cuts (Optional[int]) – The maximum number of cuts sampled to form a mini-batch. By default, this constraint is off.

  • kwargs – Arguments to be passed into CutSampler.

class lhotse.dataset.sampling.CutPairsSampler(source_cuts, target_cuts, max_source_frames=None, max_source_samples=None, max_source_duration=None, max_target_frames=None, max_target_samples=None, max_target_duration=None, max_cuts=None, **kwargs)

Samples pairs of cuts from a “source” and “target” CutSet. It expects that both CutSets strictly consist of Cuts with corresponding IDs. It behaves like an iterable that yields lists of strings (cut IDs).

When one of max_frames, max_samples, or max_duration is specified, the batch size is dynamic. At most one of those constraints can be specified. Padding required to collate the batch does not contribute to max frames/samples/duration.

__init__(source_cuts, target_cuts, max_source_frames=None, max_source_samples=None, max_source_duration=None, max_target_frames=None, max_target_samples=None, max_target_duration=None, max_cuts=None, **kwargs)

CutPairsSampler’s constructor.

Parameters
  • source_cuts (CutSet) – the first CutSet to sample data from.

  • target_cuts (CutSet) – the second CutSet to sample data from.

  • max_source_frames (Optional[int]) – The maximum total number of feature frames from source_cuts.

  • max_source_samples (Optional[int]) – The maximum total number of audio samples from source_cuts.

  • max_source_duration (Optional[float]) – The maximum total recording duration from source_cuts.

  • max_target_frames (Optional[int]) – The maximum total number of feature frames from target_cuts.

  • max_target_samples (Optional[int]) – The maximum total number of audio samples from target_cuts.

  • max_target_duration (Optional[int]) – The maximum total recording duration from target_cuts.

  • max_cuts (Optional[int]) – The maximum number of cuts sampled to form a mini-batch. By default, this constraint is off.

class lhotse.dataset.sampling.BucketingSampler(*cuts, sampler_type=<class 'lhotse.dataset.sampling.SingleCutSampler'>, num_buckets=10, seed=0, **kwargs)

Sorts the cuts in a CutSet by their duration and puts them into similar duration buckets. For each bucket, it instantiates a simpler sampler instance, e.g. SingleCutSampler.

It behaves like an iterable that yields lists of strings (cut IDs). During iteration, it randomly selects one of the buckets to yield the batch from, until all the underlying samplers are depleted (which means it’s the end of an epoch).
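The bucketing strategy can be sketched in plain Python as follows. Splitting the duration-sorted cut IDs into buckets with an equal number of cuts, and using a fixed number of cuts per batch inside each bucket, are simplifying assumptions: make_buckets and bucketed_batches are hypothetical helpers, whereas BucketingSampler delegates the per-bucket batching to a full sampler such as SingleCutSampler.

```python
import random
from typing import Dict, List


def make_buckets(durations: Dict[str, float], num_buckets: int) -> List[List[str]]:
    """Sort cut IDs by duration and split them into buckets of (nearly) equal size."""
    ordered = sorted(durations, key=durations.get)
    size = -(-len(ordered) // num_buckets)  # ceiling division
    return [ordered[i : i + size] for i in range(0, len(ordered), size)]


def bucketed_batches(
    buckets: List[List[str]], batch_size: int, seed: int = 0
) -> List[List[str]]:
    """Draw each batch entirely from one randomly chosen, non-depleted bucket."""
    rng = random.Random(seed)
    # Pre-split every bucket into fixed-size batches (stand-in for a per-bucket sampler).
    pending = [
        [bucket[i : i + batch_size] for i in range(0, len(bucket), batch_size)]
        for bucket in buckets
    ]
    batches = []
    while any(pending):  # stop once every bucket is depleted (end of epoch)
        chosen = rng.choice([b for b in pending if b])
        batches.append(chosen.pop(0))
    return batches
```

Because every batch comes from a single bucket, its cuts have similar durations, which keeps the amount of padding low.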

Examples:

Bucketing sampler with 20 buckets, sampling single cuts:

>>> sampler = BucketingSampler(
...    cuts,
...    # BucketingSampler specific args
...    sampler_type=SingleCutSampler, num_buckets=20,
...    # Args passed into SingleCutSampler
...    max_frames=20000
... )

Bucketing sampler with 20 buckets, sampling pairs of source-target cuts:

>>> sampler = BucketingSampler(
...    cuts, target_cuts,
...    # BucketingSampler specific args
...    sampler_type=CutPairsSampler, num_buckets=20,
...    # Args passed into CutPairsSampler
...    max_source_frames=20000, max_target_frames=15000
... )

__init__(*cuts, sampler_type=<class 'lhotse.dataset.sampling.SingleCutSampler'>, num_buckets=10, seed=0, **kwargs)

BucketingSampler’s constructor.

Parameters
  • cuts (CutSet) – one or more CutSet objects. The first one will be used to determine the buckets for all of them. Then, all of them will be used to instantiate the per-bucket samplers.

  • sampler_type (Type) – a sampler type that will be created for each underlying bucket.

  • num_buckets (int) – how many buckets to create.

  • seed (int) – random seed for bucket selection

  • kwargs – Arguments used to create the underlying sampler for each bucket.

set_epoch(epoch)

Sets the epoch for this sampler. When shuffle=True, this ensures all replicas use a different random ordering for each epoch. Otherwise, the next iteration of this sampler will yield the same ordering.

Parameters

epoch (int) – Epoch number.

Return type

None

filter(predicate)

Add a constraint on individual cuts that has to be satisfied for them to be considered.

Can be useful when handling large, lazy manifests where it is not feasible to pre-filter them before instantiating the sampler.

When set, we will remove the __len__ attribute on the sampler, as it is now determined dynamically.

Example:
>>> cuts = CutSet(...)
>>> sampler = SingleCutSampler(cuts, max_duration=100.0)
>>> # Retain only the cuts that have at least 1s and at most 20s duration.
>>> sampler.filter(lambda cut: 1.0 <= cut.duration <= 20.0)

Return type

None

property is_depleted
Return type

bool

Input strategies’ list

class lhotse.dataset.input_strategies.InputStrategy

Converts a CutSet into a collated batch of audio representations. These representations can be e.g. audio samples or features. They might also be single or multi channel.

This is a base class that only defines the interface.

__call__(cuts)

Returns a tensor with collated input signals, and a tensor with the length of each signal before padding.

Return type

Tuple[Tensor, IntTensor]

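supervision_intervals(cuts)

Returns a dict that specifies the start and end bounds for each supervision, as a 1-D int tensor.

Depending on the strategy, the dict should look like:

{
    'sequence_idx': Tensor[int] of shape (S,)
    'start_frame': Tensor[int] of shape (S,)
    'num_frames': Tensor[int] of shape (S,)
}

or

{
    'sequence_idx': Tensor[int] of shape (S,)
    'start_sample': Tensor[int] of shape (S,)
    'num_samples': Tensor[int] of shape (S,)
}

Where S is the total number of supervisions encountered in the CutSet. Note that S might be different from the number of cuts (B). sequence_idx means the index of the corresponding feature matrix (or cut) in a batch.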

Return type

Dict[str, Tensor]

supervision_masks(cuts)

Returns a collated batch of masks, marking the supervised regions in cuts. They are zero-padded to the longest cut.

Depending on the strategy implementation, it is expected to be a tensor of shape (B, NF) or (B, NS), where B denotes the number of cuts, NF the number of frames and NS the total number of samples. NF and NS are determined by the longest cut in a batch.

Return type

Tensor
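A minimal sketch of such a zero-padded mask, assuming the supervisions are already expressed as (start_frame, num_frames) pairs; the input format and the supervision_mask_rows function are hypothetical, and the real method returns a tensor rather than nested lists.

```python
from typing import List, Tuple

# Hypothetical per-cut input: (num_frames, [(start_frame, num_frames), ...]).
CutSpans = Tuple[int, List[Tuple[int, int]]]


def supervision_mask_rows(cuts: List[CutSpans]) -> List[List[int]]:
    """Build a (B x max_frames) 0/1 mask of supervised frames, zero-padded to the longest cut."""
    max_len = max(num_frames for num_frames, _ in cuts)
    masks = []
    for num_frames, spans in cuts:
        row = [0] * max_len
        for start, length in spans:
            # Clip each supervision to the cut's own length.
            for t in range(max(0, start), min(start + length, num_frames)):
                row[t] = 1
        masks.append(row)
    return masks
```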

class lhotse.dataset.input_strategies.PrecomputedFeatures

InputStrategy that reads pre-computed features, whose manifests are attached to cuts, from disk.

It pads the feature matrices, if needed.

__call__(cuts)

Reads the pre-computed features from disk/other storage. The returned shape is (B, T, F) => (batch_size, num_frames, num_features).

Return type

Tuple[Tensor, IntTensor]

Returns

a tensor with collated features, and a tensor of num_frames of each cut before padding.

supervision_intervals(cuts)

Returns a dict that specifies the start and end bounds for each supervision, as a 1-D int tensor, in terms of frames:

{
    'sequence_idx': Tensor[int] of shape (S,)
    'start_frame': Tensor[int] of shape (S,)
    'num_frames': Tensor[int] of shape (S,)
}

Where S is the total number of supervisions encountered in the CutSet. Note that S might be different from the number of cuts (B). sequence_idx means the index of the corresponding feature matrix (or cut) in a batch.

Return type

Dict[str, Tensor]

supervision_masks(cuts, use_alignment_if_exists=None)

Returns the mask for supervised frames.

Parameters
  • use_alignment_if_exists (Optional[str]) – optional key for the alignment type to use when generating the mask. If no such alignment exists, fall back on supervision time spans.

Return type

Tensor

class lhotse.dataset.input_strategies.AudioSamples

InputStrategy that reads single-channel recordings, whose manifests are attached to cuts, from disk (or other audio source).

It pads the recordings, if needed.

__call__(cuts)

Reads the audio samples from recordings on disk/other storage. The returned shape is (B, T) => (batch_size, num_samples).

Return type

Tuple[Tensor, IntTensor]

Returns

a tensor with collated audio samples, and a tensor of num_samples of each cut before padding.

supervision_intervals(cuts)

Returns a dict that specifies the start and end bounds for each supervision, as a 1-D int tensor, in terms of samples:

{
    'sequence_idx': Tensor[int] of shape (S,)
    'start_sample': Tensor[int] of shape (S,)
    'num_samples': Tensor[int] of shape (S,)
}

Where S is the total number of supervisions encountered in the CutSet. Note that S might be different from the number of cuts (B). sequence_idx means the index of the corresponding audio signal (or cut) in a batch.

Return type

Dict[str, Tensor]

supervision_masks(cuts, use_alignment_if_exists=None)

Returns the mask for supervised samples.

Parameters
  • use_alignment_if_exists (Optional[str]) – optional key for the alignment type to use when generating the mask. If no such alignment exists, fall back on supervision time spans.

Return type

Tensor

class lhotse.dataset.input_strategies.OnTheFlyFeatures(extractor, wave_transforms=None)

InputStrategy that reads single-channel recordings, whose manifests are attached to cuts, from disk (or other audio source). Then, it uses a FeatureExtractor to compute their features on-the-fly.

It pads the recordings, if needed.

__call__(cuts)

Reads the audio samples from recordings on disk/other storage and computes their features. The returned shape is (B, T, F) => (batch_size, num_frames, num_features).

Return type

Tuple[Tensor, IntTensor]

Returns

a tensor with collated features, and a tensor of num_frames of each cut before padding.

__init__(extractor, wave_transforms=None)

OnTheFlyFeatures’ constructor.

Parameters
  • extractor (FeatureExtractor) – the feature extractor used on-the-fly (individually on each waveform).

  • wave_transforms (Optional[List[Callable[[Tensor], Tensor]]]) – an optional list of transforms applied on the batch of audio waveforms collated into a single tensor, right before the feature extraction.

supervision_intervals(cuts)

Returns a dict that specifies the start and end bounds for each supervision, as a 1-D int tensor, in terms of frames:

{
    'sequence_idx': Tensor[int] of shape (S,)
    'start_frame': Tensor[int] of shape (S,)
    'num_frames': Tensor[int] of shape (S,)
}

Where S is the total number of supervisions encountered in the CutSet. Note that S might be different from the number of cuts (B). sequence_idx means the index of the corresponding feature matrix (or cut) in a batch.

Return type

Dict[str, Tensor]

supervision_masks(cuts, use_alignment_if_exists=None)

Returns the mask for supervised samples.

Parameters
  • use_alignment_if_exists (Optional[str]) – optional key for the alignment type to use when generating the mask. If no such alignment exists, fall back on supervision time spans.

Return type

Tensor

Augmentation - transforms on cuts

Some transforms have to be performed on cuts (or CutSets), so that we keep accurate information about the start and end times of the signal and its supervisions.

class lhotse.dataset.cut_transforms.CutConcatenate(gap=1.0, duration_factor=1.0)

A transform on a batch of cuts (CutSet) that concatenates the cuts to minimize the total amount of padding; e.g. instead of creating a batch with 40 examples, we will merge some of the examples together, adding some silence between them, to avoid a large number of padding frames that waste computation.

__init__(gap=1.0, duration_factor=1.0)

CutConcatenate’s constructor.

Parameters
  • gap (float) – The duration of silence in seconds that is inserted between the cuts; its goal is to let the model “know” that there are separate utterances in a single example.

  • duration_factor (float) – Determines the maximum duration of the concatenated cuts; by default it’s 1, setting the limit at the duration of the longest cut in the batch.
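One way to realize this kind of packing is first-fit decreasing, sketched below on plain durations. We do not claim this is CutConcatenate’s exact algorithm; concatenate_durations is a hypothetical helper that only illustrates how gap and duration_factor interact.

```python
from typing import List


def concatenate_durations(
    durations: List[float], gap: float = 1.0, duration_factor: float = 1.0
) -> List[List[float]]:
    """First-fit decreasing: pack durations into groups no longer than longest x factor."""
    ordered = sorted(durations, reverse=True)
    max_duration = ordered[0] * duration_factor
    groups: List[List[float]] = []
    totals: List[float] = []
    for d in ordered:
        for i, total in enumerate(totals):
            # Joining an existing group costs the gap plus the cut itself.
            if total + gap + d <= max_duration:
                groups[i].append(d)
                totals[i] = total + gap + d
                break
        else:
            groups.append([d])  # no group had room: start a new example
            totals.append(d)
    return groups
```

With durations of 10, 3, 4, 2 and 1 seconds and the default gap of 1 s, five examples become three, and none of them is longer than the original 10 s cut; raising duration_factor to 2.0 packs them into two.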

class lhotse.dataset.cut_transforms.CutMix(cuts, snr=(10, 20), prob=0.5, pad_to_longest=True)

A transform for batches of cuts (CutSets) that stochastically performs noise augmentation with a constant or varying SNR.

__init__(cuts, snr=(10, 20), prob=0.5, pad_to_longest=True)

CutMix’s constructor.

Parameters
  • cuts (CutSet) – a CutSet containing augmentation data, e.g. noise, music, babble.

  • snr (Union[float, Tuple[float, float], None]) – either a float, a pair (range) of floats, or None. It determines the SNR of the speech signal vs the noise signal that’s mixed into it. When a range is specified, we will uniformly sample SNR in that range. When it’s None, the noise will be mixed as-is – i.e. without any level adjustment. Note that it’s different from snr=0, which will adjust the noise level so that the SNR is 0.

  • prob (float) – a probability in the range [0, 1] that specifies how likely the cuts are to be augmented by mixing.

  • pad_to_longest (bool) – when True, each processed CutSet will be padded with noise to match the duration of the longest Cut in a batch.
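The SNR semantics above can be checked with a small sketch that scales the noise so that the power ratio between the speech and the scaled noise, in dB, equals the requested SNR. Samples are plain Python lists; the helper mix_with_snr_sketch is hypothetical, and the level computation in CutMix may differ (for example, in how it handles noise shorter than the speech).

```python
import math
import random
from typing import List, Tuple, Union


def mix_with_snr_sketch(
    speech: List[float],
    noise: List[float],
    snr: Union[float, Tuple[float, float], None],
    prob: float,
    rng: random.Random,
) -> List[float]:
    """Hypothetical helper: add noise to speech at a target SNR in dB."""
    if rng.random() >= prob:
        return list(speech)  # this example is left unaugmented
    if isinstance(snr, tuple):
        snr = rng.uniform(*snr)  # a range: draw the SNR uniformly
    if snr is None:
        gain = 1.0  # mix the noise as-is, with no level adjustment
    else:
        speech_power = sum(x * x for x in speech) / len(speech)
        noise_power = sum(x * x for x in noise) / len(noise)
        # Solve 10 * log10(speech_power / (gain**2 * noise_power)) == snr for gain.
        gain = math.sqrt(speech_power / (noise_power * 10 ** (snr / 10)))
    return [s + gain * n for s, n in zip(speech, noise)]


speech = [1.0, -1.0, 1.0, -1.0]
noise = [0.5, 0.5, -0.5, -0.5]
```

Note how snr=0.0 rescales the noise (here by a gain of 2) while snr=None leaves it untouched, matching the distinction drawn in the parameter description.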

class lhotse.dataset.cut_transforms.ExtraPadding(extra_frames=None, extra_samples=None, extra_seconds=None, pad_feat_value=-23.025850929940457, randomized=False)

A transform on a batch of cuts (CutSet) that adds a number of extra context frames/samples/seconds on both sides of the cut. Exactly one type of duration has to be specified in the constructor.

It is intended mainly for training frame-synchronous ASR models with convolutional layers, to avoid padding inside the hidden layers by giving the model a larger context in the input. Another useful application is to shift the input slightly, so that the data seen after frame subsampling is a bit different, which makes this a data augmentation technique.

This is best used as the first transform in a dataset’s transform list: it ensures that each individual cut gets extra context before it is concatenated, padded with noise, etc.

__init__(extra_frames=None, extra_samples=None, extra_seconds=None, pad_feat_value=-23.025850929940457, randomized=False)

ExtraPadding’s constructor.

Parameters
  • extra_frames (Optional[int]) – The total number of frames to add to each cut. We will add half that number on each side of the cut (“both” directions padding).

  • extra_samples (Optional[int]) – The total number of samples to add to each cut. We will add half that number on each side of the cut (“both” directions padding).

  • extra_seconds (Optional[float]) – The total duration in seconds to add to each cut. We will add half that number on each side of the cut (“both” directions padding).

  • pad_feat_value (float) – When padding a cut with precomputed features, what value should be used for padding (the default is a very low log-energy).

  • randomized (bool) – When True, we will sample a value from a uniform distribution of [0, extra_X] for each cut (for samples/frames – sample an int, for duration – sample a float).
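The “both” directions padding and the randomized option can be sketched for the extra_frames case as follows. The helper extra_padding_sketch is hypothetical, and sending an odd remainder to the right side is an assumption, since the description above only says “half that number on each side”.

```python
import random
from typing import Optional, Tuple


def extra_padding_sketch(
    extra_frames: int,
    randomized: bool = False,
    rng: Optional[random.Random] = None,
) -> Tuple[int, int]:
    """Hypothetical helper: (left, right) frame counts to add around one cut."""
    total = extra_frames
    if randomized:
        rng = rng or random.Random()
        # randint is inclusive, so this samples an int from [0, extra_frames].
        total = rng.randint(0, extra_frames)
    # Half on each side; sending an odd remainder to the right is an assumption.
    left = total // 2
    return left, total - left
```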

class lhotse.dataset.cut_transforms.PerturbSpeed(factors, p, randgen=None)

A transform on a batch of cuts (CutSet) that perturbs the speed of the recordings with a given probability p.

If the effect is applied, then one of the perturbation factors from the constructor’s factors parameter is sampled with uniform probability.

__init__(factors, p, randgen=None)

Initialize self. See help(type(self)) for accurate signature.
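The selection logic described for PerturbSpeed (apply the effect with probability p and, if applied, pick one factor uniformly) can be sketched as below. The helper perturb_speed_sketch is hypothetical and does not touch audio; it only reports the chosen factor and the resulting duration, assuming that a factor f > 1 speeds the audio up so the new duration is the old one divided by f.

```python
import random
from typing import List, Sequence, Tuple


def perturb_speed_sketch(
    durations: List[float], factors: Sequence[float], p: float, rng: random.Random
) -> List[Tuple[float, float]]:
    """Hypothetical helper: return (factor, new_duration) per cut; 1.0 means unchanged."""
    out = []
    for duration in durations:
        # Apply the effect with probability p, then pick one factor uniformly.
        factor = rng.choice(factors) if rng.random() < p else 1.0
        # Speeding up by a factor f shortens the audio to duration / f.
        out.append((factor, duration / factor))
    return out
```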

Augmentation - transforms on signals

These transforms work directly on batches of collated feature matrices (or possibly raw waveforms, if applicable).

class lhotse.dataset.signal_transforms.GlobalMVN(feature_dim)

Apply global mean and variance normalization.

__init__(feature_dim)

Initializes internal Module state, shared by both nn.Module and ScriptModule.

classmethod from_cuts(cuts, max_cuts=None)

Return type

GlobalMVN

classmethod from_file(stats_file)

Return type

GlobalMVN

to_file(stats_file)

forward(features, *args, **kwargs)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

Return type

Tensor

inverse(features)

Return type

Tensor
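A minimal sketch of what global mean and variance normalization computes: per-dimension statistics pooled over every frame of every matrix, then (x - mean) / std, with an inverse that undoes it. The helper names, the use of population variance, and the small floor on the standard deviation are assumptions; from_cuts, from_file and to_file are not modelled, and nested lists stand in for tensors.

```python
import math
from typing import List, Tuple

Matrix = List[List[float]]  # (num_frames, feature_dim)


def mvn_stats_sketch(
    matrices: List[Matrix], floor: float = 1e-8
) -> Tuple[List[float], List[float]]:
    """Hypothetical helper: per-dimension mean and std pooled over all frames."""
    frames = [row for matrix in matrices for row in matrix]
    num_frames, dim = len(frames), len(frames[0])
    mean = [sum(f[d] for f in frames) / num_frames for d in range(dim)]
    var = [sum((f[d] - mean[d]) ** 2 for f in frames) / num_frames for d in range(dim)]
    # Flooring the std keeps constant dimensions from causing a division by zero.
    std = [max(math.sqrt(v), floor) for v in var]
    return mean, std


def normalize_sketch(matrix: Matrix, mean: List[float], std: List[float]) -> Matrix:
    return [[(x - m) / s for x, m, s in zip(row, mean, std)] for row in matrix]


def inverse_sketch(matrix: Matrix, mean: List[float], std: List[float]) -> Matrix:
    return [[x * s + m for x, m, s in zip(row, mean, std)] for row in matrix]


batch = [[[1.0, 10.0], [3.0, 10.0]], [[5.0, 10.0]]]
mean, std = mvn_stats_sketch(batch)
```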

class lhotse.dataset.signal_transforms.SpecAugment(time_warp_factor=80, num_feature_masks=1, features_mask_size=13, num_frame_masks=1, frames_mask_size=70, max_frames_mask_fraction=0.2, p=0.5)

SpecAugment performs three augmentations:
  • time warping of the feature matrix;
  • masking of ranges of features (frequency bands);
  • masking of ranges of frames (time).

The current implementation works with batches, but processes each example separately in a loop rather than simultaneously to achieve different augmentation parameters for each example.

__init__(time_warp_factor=80, num_feature_masks=1, features_mask_size=13, num_frame_masks=1, frames_mask_size=70, max_frames_mask_fraction=0.2, p=0.5)

SpecAugment’s constructor.

Parameters
  • time_warp_factor (Optional[int]) – parameter for the time warping; larger values mean more warping. Set to None, or less than 1, to disable.

  • num_feature_masks (int) – how many feature masks should be applied. Set to 0 to disable.

  • features_mask_size (int) – the width of the feature mask (expressed in the number of masked feature bins). This is the F parameter from the SpecAugment paper.

  • num_frame_masks (int) – how many frame (temporal) masks should be applied. Set to 0 to disable.

  • frames_mask_size (int) – the width of the frame (temporal) masks (expressed in the number of masked frames). This is the T parameter from the SpecAugment paper.

  • max_frames_mask_fraction (float) – limits the size of the frame (temporal) mask to this value times the length of the utterance (or supervision segment). This is the parameter denoted by p in the SpecAugment paper.

  • p – the probability of applying this transform. It is different from p in the SpecAugment paper!
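To make the masking parameters concrete, here is a simplified sketch of the two masking steps on a single (T, F) matrix. Time warping and the probability p are omitted, mask widths are drawn uniformly from [0, size] as in the SpecAugment paper, and filling masked regions with the mean of the matrix is an assumption about the fill value. The helper spec_mask_sketch is hypothetical and uses nested lists instead of tensors.

```python
import random
from typing import List, Optional

Matrix = List[List[float]]  # (num_frames, num_features)


def spec_mask_sketch(
    features: Matrix,
    num_feature_masks: int = 1,
    features_mask_size: int = 13,
    num_frame_masks: int = 1,
    frames_mask_size: int = 70,
    max_frames_mask_fraction: float = 0.2,
    rng: Optional[random.Random] = None,
) -> Matrix:
    """Hypothetical helper: frequency and time masking without time warping."""
    rng = rng or random.Random()
    out = [row[:] for row in features]  # never modify the input in place
    num_frames, num_features = len(out), len(out[0])
    fill = sum(sum(row) for row in out) / (num_frames * num_features)
    for _ in range(num_feature_masks):
        # Mask a band of consecutive feature bins (frequency masking).
        width = rng.randint(0, min(features_mask_size, num_features))
        start = rng.randint(0, num_features - width)
        for row in out:
            row[start:start + width] = [fill] * width
    # The frame mask width is also capped by a fraction of the utterance length.
    max_width = min(frames_mask_size, int(max_frames_mask_fraction * num_frames))
    for _ in range(num_frame_masks):
        width = rng.randint(0, max_width)
        start = rng.randint(0, num_frames - width)
        for t in range(start, start + width):
            out[t] = [fill] * num_features
    return out


features = [[float(t * 20 + f) for f in range(20)] for t in range(100)]
masked = spec_mask_sketch(features, rng=random.Random(0))
```

With 100 frames, max_frames_mask_fraction=0.2 caps the time mask at 20 frames even though frames_mask_size is 70.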

forward(features, supervision_segments=None, *args, **kwargs)

Computes SpecAugment for a batch of feature matrices.

Since the batch will usually already be padded, the user can optionally provide a supervision_segments tensor that will be used to apply SpecAugment only to selected areas of the input. The format of this input is described below.

Parameters
  • features (Tensor) – a batch of feature matrices with shape (B, T, F).

  • supervision_segments (Optional[IntTensor]) – an int tensor of shape (S, 3). S is the number of supervision segments that exist in features; there may be fewer or more of them than the batch size. The second dimension encodes three kinds of information: the sequence index of the corresponding feature matrix in features, the start frame index, and the number of frames for each segment.

Return type

Tensor

Returns

a tensor of shape (T, F), or a batch of them with shape (B, T, F)
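The supervision_segments format can be illustrated with a hedged sketch that applies one frame mask inside each segment, so that padded frames outside the segments are never touched. Rows are [sequence_idx, start_frame, num_frames] as described above; the helper name mask_within_segments_sketch, the fixed mask width, and the fill value are assumptions, and this is not how SpecAugment.forward is implemented.

```python
import random
from typing import List, Optional

Batch = List[List[List[float]]]  # (B, T, F)


def mask_within_segments_sketch(
    features: Batch,
    supervision_segments: List[List[int]],
    width: int,
    fill: float = 0.0,
    rng: Optional[random.Random] = None,
) -> Batch:
    """Hypothetical helper: one frame mask per [sequence_idx, start_frame, num_frames] row."""
    rng = rng or random.Random()
    out = [[row[:] for row in example] for example in features]
    for sequence_idx, start, num_frames in supervision_segments:
        w = min(width, num_frames)
        # The mask stays inside [start, start + num_frames), so padding is untouched.
        offset = start + rng.randint(0, num_frames - w)
        for t in range(offset, offset + w):
            out[sequence_idx][t] = [fill] * len(out[sequence_idx][t])
    return out


features = [[[1.0, 1.0] for _ in range(6)] for _ in range(2)]  # B = 2, T = 6, F = 2
# S = 3 segments for B = 2 examples: example 0 holds two of them.
segments = [[0, 0, 3], [0, 4, 2], [1, 1, 2]]
masked = mask_within_segments_sketch(features, segments, width=2, rng=random.Random(0))
```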


Collation utilities for building custom Datasets

class lhotse.dataset.collation.TokenCollater(cuts, add_eos=True, add_bos=True, pad_symbol='<pad>', bos_symbol='<bos>', eos_symbol='<eos>', unk_symbol='<unk>')

Collate a list of tokens.

Maps sentences to integers. Sentences are padded to equal length, and beginning- and end-of-sequence symbols can be added. Call .inverse(tokens_batch, tokens_lens) to reconstruct the batch as string sentences.

Example:
>>> token_collater = TokenCollater(cuts)
>>> tokens_batch, tokens_lens = token_collater(cuts.subset(first=32))
>>> original_sentences = token_collater.inverse(tokens_batch, tokens_lens)
Returns:
  • tokens_batch: IntTensor of shape (B, L), where B is the batch dimension (the number of input sentences) and L is the length of the longest sentence.

  • tokens_lens: IntTensor of shape (B,) with the length of each sentence after adding <eos> and <bos> but before padding.

__init__(cuts, add_eos=True, add_bos=True, pad_symbol='<pad>', bos_symbol='<bos>', eos_symbol='<eos>', unk_symbol='<unk>')

Initialize self. See help(type(self)) for accurate signature.

inverse(tokens_batch, tokens_lens)

Return type

List[str]
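A hedged sketch of the collation described above, assuming character-level tokens, special symbols placed first in the vocabulary (so <pad> maps to 0), both <bos> and <eos> always added, and plain strings instead of cuts. The class CharTokenCollaterSketch is hypothetical and returns nested lists rather than IntTensors.

```python
from typing import List, Tuple


class CharTokenCollaterSketch:
    """Hypothetical character-level stand-in for the collation described above."""

    def __init__(self, texts: List[str], pad: str = "<pad>", bos: str = "<bos>",
                 eos: str = "<eos>", unk: str = "<unk>") -> None:
        self.pad, self.bos, self.eos, self.unk = pad, bos, eos, unk
        chars = sorted({c for text in texts for c in text})
        # Special symbols come first, so <pad> gets index 0.
        self.idx2token = [pad, bos, eos, unk] + chars
        self.token2idx = {tok: i for i, tok in enumerate(self.idx2token)}

    def __call__(self, texts: List[str]) -> Tuple[List[List[int]], List[int]]:
        # Always add <bos> and <eos>, as with add_bos=True and add_eos=True.
        seqs = [[self.bos] + list(text) + [self.eos] for text in texts]
        lens = [len(seq) for seq in seqs]
        max_len = max(lens)
        unk_id, pad_id = self.token2idx[self.unk], self.token2idx[self.pad]
        batch = [
            [self.token2idx.get(tok, unk_id) for tok in seq] + [pad_id] * (max_len - len(seq))
            for seq in seqs
        ]
        return batch, lens

    def inverse(self, batch: List[List[int]], lens: List[int]) -> List[str]:
        # Skip <bos> and <eos>, and ignore padding beyond each length.
        return ["".join(self.idx2token[i] for i in row[1:n - 1]) for row, n in zip(batch, lens)]
```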

lhotse.dataset.collation.collate_features(cuts, pad_direction='right')

Load features for all the cuts and return them as a batch in a torch tensor. The output shape is (batch, time, features). The cuts will be padded with silence if necessary.

Parameters
  • cuts (CutSet) – a CutSet used to load the features.

  • pad_direction (str) – where to apply the padding (right, left, or both).

Return type

Tuple[Tensor, IntTensor]

Returns

a tuple of tensors (features, features_lens).
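The three pad_direction values can be illustrated with a sketch that pads (T, F) matrices to the longest one and returns their original lengths. Padding with a constant pad_value is a simplification: the text above says collate_features pads with silence, which for log-energy features is not necessarily 0. The helper pad_matrices_sketch is hypothetical, and putting the odd remainder on the right for "both" is an assumption.

```python
from typing import List, Tuple

Matrix = List[List[float]]  # (num_frames, num_features)


def pad_matrices_sketch(
    matrices: List[Matrix], pad_direction: str = "right", pad_value: float = 0.0
) -> Tuple[List[Matrix], List[int]]:
    """Hypothetical helper: pad (T, F) matrices to the longest T; return (batch, lens)."""
    lens = [len(m) for m in matrices]
    max_len = max(lens)
    num_features = len(matrices[0][0])

    def pad_rows(n: int) -> Matrix:
        return [[pad_value] * num_features for _ in range(n)]

    batch = []
    for matrix in matrices:
        missing = max_len - len(matrix)
        if pad_direction == "right":
            left = 0
        elif pad_direction == "left":
            left = missing
        elif pad_direction == "both":
            left = missing // 2  # an odd remainder goes to the right
        else:
            raise ValueError(f"Unknown pad_direction: {pad_direction!r}")
        batch.append(pad_rows(left) + [row[:] for row in matrix] + pad_rows(missing - left))
    return batch, lens
```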

lhotse.dataset.collation.collate_audio(cuts, pad_direction='right')

Load audio samples for all the cuts and return them as a batch in a torch tensor. The output shape is (batch, time). The cuts will be padded with silence if necessary.

Parameters
  • cuts (CutSet) – a CutSet used to load the audio samples.

  • pad_direction (str) – where to apply the padding (right, left, or both).

Return type

Tuple[Tensor, IntTensor]

Returns

a tuple of tensors (audio, audio_lens).

lhotse.dataset.collation.collate_multi_channel_features(cuts)

Load features for all the cuts and return them as a batch in a torch tensor. The cuts have to be of type MixedCut and their tracks will be interpreted as individual channels. The output shape is (batch, channel, time, features). The cuts will be padded with silence if necessary.

Return type

Tensor

lhotse.dataset.collation.collate_multi_channel_audio(cuts)

Load audio samples for all the cuts and return them as a batch in a torch tensor. The cuts have to be of type MixedCut and their tracks will be interpreted as individual channels. The output shape is (batch, channel, time). The cuts will be padded with silence if necessary.

Return type

Tensor
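A hedged sketch of the resulting (batch, channel, time) layout, assuming each cut's tracks have already been loaded as lists of samples aligned to the start of the cut and that every cut has the same number of tracks (real MixedCut tracks carry offsets, which this ignores). The helper collate_multi_channel_sketch is hypothetical and right-pads with a constant.

```python
from typing import List

Track = List[float]


def collate_multi_channel_sketch(
    cuts: List[List[Track]], pad_value: float = 0.0
) -> List[List[Track]]:
    """Hypothetical helper: stack each cut's tracks as channels, giving (B, C, T)."""
    channel_counts = {len(tracks) for tracks in cuts}
    if len(channel_counts) != 1:
        raise ValueError("this sketch assumes every cut has the same number of tracks")
    # Pad every track of every cut on the right to the longest track in the batch.
    max_len = max(len(track) for tracks in cuts for track in tracks)
    return [[track + [pad_value] * (max_len - len(track)) for track in tracks] for tracks in cuts]
```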

lhotse.dataset.collation.collate_vectors(tensors, padding_value=-100, matching_shapes=False)

Convert an iterable of 1-D tensors (of possibly various lengths) into a single stacked tensor.

Parameters
  • tensors (Iterable[Union[Tensor, ndarray]]) – an iterable of 1-D tensors.

  • padding_value (Union[int, float]) – the padding value inserted to make all tensors have the same length.

  • matching_shapes (bool) – when True, will fail when input tensors have different shapes.

Return type

Tensor

Returns

a tensor with shape (B, L) where B is the number of input tensors and L is the number of items in the longest tensor.
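A minimal sketch of this behaviour on plain lists, assuming right-padding. The default of -100 coincides with the default ignore_index of PyTorch's CrossEntropyLoss, so padded label positions can be skipped by that loss; whether this motivated the default is an assumption. The helper collate_vectors_sketch is hypothetical.

```python
from typing import Iterable, List, Sequence, Union

Number = Union[int, float]


def collate_vectors_sketch(
    vectors: Iterable[Sequence[Number]],
    padding_value: Number = -100,
    matching_shapes: bool = False,
) -> List[List[Number]]:
    """Hypothetical helper: right-pad 1-D sequences into a (B, L) nested list."""
    rows = [list(v) for v in vectors]
    lengths = {len(row) for row in rows}
    if matching_shapes and len(lengths) > 1:
        raise ValueError(f"Expected equal lengths, got {sorted(lengths)}")
    max_len = max(lengths)
    return [row + [padding_value] * (max_len - len(row)) for row in rows]
```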

lhotse.dataset.collation.collate_matrices(tensors, padding_value=0, matching_shapes=False)

Convert an iterable of 2-D tensors (possibly varying in the first dimension, but with a consistent second dimension) into a single stacked tensor.

Parameters
  • tensors (Iterable[Union[Tensor, ndarray]]) – an iterable of 2-D tensors.

  • padding_value (Union[int, float]) – the padding value inserted to make all tensors have the same length.

  • matching_shapes (bool) – when True, will fail when input tensors have different shapes.

Return type

Tensor

Returns

a tensor with shape (B, L, F) where B is the number of input tensors, L is the largest found shape[0], and F is equal to shape[1].

lhotse.dataset.collation.maybe_pad(cuts, duration=None, num_frames=None, num_samples=None, direction='right')

Check whether all cuts have equal durations and, if they do not, pad them to match the longest cut.

Return type

CutSet