PyTorch Datasets

Lhotse supports PyTorch’s dataset API, providing implementations for the Dataset and Sampler concepts. They can be used together with the standard DataLoader class for efficient mini-batch collection with multiple parallel readers and pre-fetching.

A quick recap of PyTorch’s data API

PyTorch defines the Dataset class that is responsible for reading the data from disk/memory/Internet/database/etc., and converting it to tensors that can be used for network training or inference. These Dataset’s are typically "map-style" datasets: they are given an index (or a list of indices) and return the corresponding data samples.

The selection of indices is performed by the Sampler class. Sampler, knowing the length (number of items) in a Dataset, can use various strategies to determine the order of elements to read (e.g. sequential reads, or random reads).
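To make this contract concrete, below is a minimal, self-contained sketch (with toy data and made-up names) of a map-style Dataset wired to a Sampler through a DataLoader:

import torch
from torch.utils.data import DataLoader, Dataset, SequentialSampler

class ToyDataset(Dataset):
    # A map-style dataset: maps an integer index to a data sample.
    def __init__(self, values):
        self.values = values

    def __len__(self):
        return len(self.values)

    def __getitem__(self, index):
        return torch.tensor(self.values[index])

dataset = ToyDataset([0.1, 0.2, 0.3, 0.4])
# The sampler determines the order of indices; the DataLoader fetches
# the corresponding samples and collates them into mini-batches.
loader = DataLoader(dataset, sampler=SequentialSampler(dataset), batch_size=2)
for batch in loader:
    print(batch)  # tensor([0.1000, 0.2000]), then tensor([0.3000, 0.4000])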

More details about the data pipeline API in PyTorch can be found in the PyTorch documentation: https://pytorch.org/docs/stable/data.html

About Lhotse’s Datasets and Samplers

Lhotse provides a number of utilities that make it simpler to define Dataset’s for speech processing tasks. CutSet is the base data structure that is used to initialize the Dataset class. This makes it possible to manipulate the speech data in convenient ways - pad, mix, concatenate, augment, compute features, look up the supervision information, etc.

Lhotse’s Dataset’s will perform batching by themselves, because auto-collation in DataLoader is too limiting for speech data handling. These Dataset’s expect to be handed lists of element indices, so that they can collate the data before it is passed to the DataLoader (which must use batch_size=None). It allows for interesting collation methods - e.g. padding the speech with noise recordings, or actual acoustic context, rather than artificial zeroes; or dynamic batch sizes.

The items for mini-batch creation are selected by the Sampler. Lhotse defines Sampler classes that are initialized with CutSet’s, so that they can look up specific properties of an utterance to stratify the sampling. For example, SimpleCutSampler has a defined max_frames attribute, and it will keep sampling cuts for a batch as long as they do not exceed the specified number of frames. Another strategy — used in BucketingSampler — will first group the cuts of similar durations into buckets, and then randomly select a bucket to draw the whole batch from.

For tasks where both input and output of the model are speech utterances, we can use the CutPairsSampler, which accepts two CutSet’s and will match the cuts in them by their IDs.
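As an illustration, a sketch of a CutPairsSampler set up for a speech enhancement task; the manifest paths are hypothetical, and we assume both manifests contain cuts with matching IDs:

from lhotse import CutSet
from lhotse.dataset import CutPairsSampler

# Hypothetical manifests: each noisy cut has a clean counterpart with the same ID.
noisy_cuts = CutSet.from_file("noisy_cuts.jsonl.gz")
clean_cuts = CutSet.from_file("clean_cuts.jsonl.gz")

sampler = CutPairsSampler(
    source_cuts=noisy_cuts,
    target_cuts=clean_cuts,
    max_source_duration=200.0,  # seconds of source audio per mini-batch
)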

Typical usage of Lhotse’s dataset API might look like this:

from torch.utils.data import DataLoader
from lhotse import CutSet
from lhotse.dataset import SpeechRecognitionDataset, SimpleCutSampler

cuts = CutSet(...)
dset = SpeechRecognitionDataset(cuts)
sampler = SimpleCutSampler(cuts, max_frames=50000)
# Dataset performs batching by itself, so we have to indicate that
# to the DataLoader with batch_size=None
dloader = DataLoader(dset, sampler=sampler, batch_size=None, num_workers=1)
for batch in dloader:
    ...  # process data

Restoring sampler’s state: continuing the training

All CutSampler types can save their progress and pick up from that checkpoint. For consistency with PyTorch tensors, the relevant methods are called .state_dict() and .load_state_dict(). The following example illustrates how to save the sampler’s state (pay attention to the last bit):

import torch
from lhotse.dataset import BucketingSampler

dataset = ...  # Some task-specific dataset initialization
sampler = BucketingSampler(cuts, max_duration=200, shuffle=True, num_buckets=30)
dloader = DataLoader(dataset, batch_size=None, sampler=sampler, num_workers=4)
global_step = 0
for epoch in range(30):
    dloader.sampler.set_epoch(epoch)
    for batch in dloader:
        # ... processing forward, backward, etc.
        global_step += 1

        if global_step % 5000 == 0:
            state = dloader.sampler.state_dict()
            torch.save(state, f'sampler-ckpt-ep{epoch}-step{global_step}.pt')

If the training ends abruptly and the epochs are very long (10k+ steps, which is not uncommon with large datasets these days), we can resume the training from where it left off, as follows:

# Creating a vanilla sampler, we will read the previous progress into it.
sampler = BucketingSampler(cuts, max_duration=200, shuffle=True, num_buckets=30)

# Restore the sampler's state.
state = torch.load('sampler-ckpt-ep5-step75000.pt')
sampler.load_state_dict(state)

dloader = DataLoader(dataset, batch_size=None, sampler=sampler, num_workers=4)

global_step = sampler.diagnostics.total_batches  # <-- Restore the global step idx.
for epoch in range(sampler.epoch, 30):  # <-- Skip previous epochs that are already processed.

    dloader.sampler.set_epoch(epoch)
    for batch in dloader:
        # Note: the first batch is going to be from step 75009.
        # With DataLoader num_workers==0, it would have been 75001, but we get
        # +8 because of num_workers==4 * prefetching_factor==2

        # ... processing forward, backward, etc.
        global_step += 1

Note

In general, the sampler arguments may be different – loading a state_dict will overwrite the arguments and emit a warning so the user is aware of what happened. BucketingSampler is an exception – the num_buckets and bucket_method must be consistent, otherwise we couldn’t guarantee identical outcomes after training resumption.

Note

The DataLoader’s num_workers can be different after resuming.

Batch I/O: pre-computed vs. on-the-fly features

Depending on the experimental setup and infrastructure, it might be more convenient to either pre-compute and store features like filter-bank energies for later use (as traditionally done in Kaldi/ESPnet/Espresso toolkits), or compute them dynamically during training (“on-the-fly”). Lhotse supports both modes of computation by introducing a class called BatchIO. It is accepted as an argument in most dataset classes, and defaults to PrecomputedFeatures. Other available choices are AudioSamples for working with waveforms directly, and OnTheFlyFeatures, which wraps a FeatureExtractor and applies it to a batch of recordings. These strategies automatically pad and collate the inputs, and provide information about the original signal lengths: as a number of frames/samples, binary mask, or start-end frame/sample pairs.
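For instance, switching between these strategies in an ASR dataset is a matter of passing a different input_strategy; a sketch (assuming the cuts reference audio recordings, and that features were pre-computed for the first variant):

from lhotse import Fbank
from lhotse.dataset import K2SpeechRecognitionDataset
from lhotse.dataset.input_strategies import (
    AudioSamples,
    OnTheFlyFeatures,
    PrecomputedFeatures,
)

# Default: read pre-computed features from disk.
dataset = K2SpeechRecognitionDataset(input_strategy=PrecomputedFeatures())

# On-the-fly: extract filter-bank features from audio during training.
dataset = K2SpeechRecognitionDataset(input_strategy=OnTheFlyFeatures(Fbank()))

# Raw waveforms: hand the collated audio samples directly to the model.
dataset = K2SpeechRecognitionDataset(input_strategy=AudioSamples())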

Which strategy to choose?

In general, pre-computed features can be greatly compressed (we achieve a 70% size reduction relative to uncompressed features), so the I/O load on your computing infrastructure will be much smaller than if you read the recordings directly. This is especially valuable when working with network file systems (NFS) that are typically used in computational grids for storage. When your experiment is I/O bound, it is best to use pre-computed features.

When I/O is not the issue, it might be preferable to use on-the-fly computation, as it doesn’t require any preparation steps before network training. It is also simpler to apply a vast range of data augmentation methods in a fully randomized way (e.g. reverberation), although Lhotse provides support for approximate feature-domain signal mixing (e.g. for additive noise augmentation) to alleviate that to some extent.

Handling random seeds

Lhotse provides several mechanisms for controlling randomness. At a basic level, there is a function lhotse.utils.fix_random_seed() which seeds Python’s, numpy’s and torch’s RNGs with the provided number.
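For example, calling it once at the start of a training script:

from lhotse.utils import fix_random_seed

fix_random_seed(42)  # seeds Python's random, numpy, and torch RNGs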

However, many functions and classes in Lhotse accept either a random seed or an RNG instance to provide a finer control over randomness. Whenever random seed is accepted, it can be either an integer, or one of two strings: "randomized" or "trng".

  • "randomized” seed is resolved lazily at the moment it’s needed and is intended as a mechanism to provide a different seed to each dataloading worker. In order for "randomized" to work, you have to first invoke lhotse.dataset.dataloading.worker_init_fn() in a given subprocess which sets the right environment variables. With a PyTorch DataLoader you can pass the keyword argument worker_init_fn==make_worker_init_fn(seed=int_seed, rank=..., world_size=...) using lhotse.dataset.dataloading.make_worker_init_fn() which will set the right seeds for you in multiprocessing and multi-node training. Note that if you resume training, you should change the seed passed to make_worker_init_fn on each resumed run to make the model train on different data.

  • "trng" seed is also resolved lazily at runtime, but it uses a true RNG (if available on your OS; consult Python’s secrets module documentation). It’s an easy way to ensure that every time you iterate data it’s done in different order, but may cause debugging data issues to be more difficult.

Note

The lazy seed resolution is done by calling lhotse.dataset.dataloading.resolve_seed().

Customizing sampling constraints

Since version 1.22.0, Lhotse provides a mechanism to customize how samplers measure the “length” of each example for the purpose of determining dynamic batch size. To leverage this option, use the keyword argument constraint in DynamicCutSampler or DynamicBucketingSampler. The sampling criteria are defined by implementing a subclass of SamplingConstraint:

class lhotse.dataset.sampling.base.SamplingConstraint[source]

Defines the interface for sampling constraints. A sampling constraint keeps track of the sampled examples and lets the sampler know when it should yield a mini-batch.

abstract add(example)[source]

Update the sampling constraint with the information about the sampled example (e.g. current batch size, total duration).

Return type:

None

abstract exceeded()[source]

Inform if the sampling constraint has been exceeded.

Return type:

bool

abstract close_to_exceeding()[source]

Inform if we’re going to exceed the sampling constraint after adding one more example.

Return type:

bool

abstract reset()[source]

Resets the internal state (called after yielding a mini-batch).

Return type:

None

abstract measure_length(example)[source]

Returns the “size” of an example, used to create bucket distribution for bucketing samplers (e.g., for audio it may be duration; for text it may be number of tokens; etc.).

Return type:

float

copy()[source]

Return a shallow copy of this constraint.

Return type:

SamplingConstraint

The default constraint is TimeConstraint, which is created from the max_duration, max_cuts, and quadratic_duration arguments passed to the sampler’s constructor.
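As an illustration, here is a minimal sketch of a custom constraint that caps the total number of supervision segments per mini-batch; the class name and threshold are made up, and error handling is omitted:

from dataclasses import dataclass

from lhotse import CutSet
from lhotse.dataset import DynamicCutSampler
from lhotse.dataset.sampling.base import SamplingConstraint

@dataclass
class SupervisionCountConstraint(SamplingConstraint):
    # Hypothetical: yield a mini-batch once it holds `max_supervisions` segments.
    max_supervisions: int = 100
    current: int = 0

    def add(self, example) -> None:
        self.current += self.measure_length(example)

    def exceeded(self) -> bool:
        return self.current > self.max_supervisions

    def close_to_exceeding(self) -> bool:
        return self.current >= self.max_supervisions

    def reset(self) -> None:
        self.current = 0

    def measure_length(self, example) -> float:
        return len(example.supervisions)

cuts = CutSet(...)  # assumed to be prepared elsewhere
sampler = DynamicCutSampler(cuts, constraint=SupervisionCountConstraint(max_supervisions=50))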

Sampling non-audio data

Because SamplingConstraint defines the method measure_length, it’s possible to use a different attribute than duration (or a different formula) for computing the effective batch size. This enables re-using Lhotse’s sampling algorithms for other data than speech, and passing around other objects than Cut.

To showcase this, we added an experimental support for text-only dataloading. We introduced a few classes specifically for this purpose:

class lhotse.cut.text.TextExample(text, tokens=None, custom=None)[source]

Represents a single text example. Useful e.g. for language modeling.

text: str
tokens: Optional[ndarray] = None
custom: Optional[Dict[str, Any]] = None
property num_tokens: int | None
__init__(text, tokens=None, custom=None)
class lhotse.cut.text.TextPairExample(source, target, custom=None)[source]

Represents a pair of text examples. Useful e.g. for sequence-to-sequence tasks.

source: TextExample
target: TextExample
custom: Optional[Dict[str, Any]] = None
property num_tokens: int | None
__init__(source, target, custom=None)
class lhotse.lazy.LazyTxtIterator(path, as_text_example=True)[source]

LazyTxtIterator is a thin wrapper over the builtin open function to iterate over lines in a (possibly compressed) text file. It can also provide the number of lines via __len__ using fast newline counting.

__init__(path, as_text_example=True)[source]
class lhotse.dataset.sampling.base.TokenConstraint(max_tokens=None, max_examples=None, current=0, num_examples=0, longest_seen=0, quadratic_length=None)[source]

Represents a token-based constraint for sampler classes that sample text data. It is defined as maximum total number of tokens in a mini-batch and/or max batch size.

Similarly to TimeConstraint, we support quadratic_length for quadratic token penalty when sampling longer texts.

max_tokens: int = None
max_examples: int = None
current: int = 0
num_examples: int = 0
longest_seen: int = 0
quadratic_length: Optional[int] = None
add(example)[source]

Increment the internal token counter for the constraint, selecting the right property from the input object.

Return type:

None

exceeded()[source]

Is the constraint exceeded or not.

Return type:

bool

close_to_exceeding()[source]

Check if the batch is close to satisfying the constraints. We define “closeness” as: if we added one more cut that has duration/num_frames/num_samples equal to the longest seen cut in the current batch, then the batch would have exceeded the constraints.

Return type:

bool

reset()[source]

Reset the internal counter (to be used after a batch was created, to start collecting a new one).

Return type:

None

measure_length(example)[source]

Returns the “size” of an example, used to create bucket distribution for bucketing samplers (e.g., for audio it may be duration; for text it may be number of tokens; etc.).

Return type:

float

__init__(max_tokens=None, max_examples=None, current=0, num_examples=0, longest_seen=0, quadratic_length=None)

A minimal example of how to perform text-only dataloading is available below (note that any of these classes may be replaced by your own implementation if that is more suitable to your work):

import torch
import numpy as np
from lhotse import CutSet
from lhotse.lazy import LazyTxtIterator
from lhotse.cut.text import TextPairExample
from lhotse.dataset import DynamicBucketingSampler, TokenConstraint
from lhotse.dataset.collation import collate_vectors

examples = CutSet(LazyTxtIterator("data.txt"))

def tokenize(example):
    # tokenize as individual bytes; BPE or another technique may be used here instead
    example.tokens = np.frombuffer(example.text.encode("utf-8"), np.int8)
    return example

examples = examples.map(tokenize, apply_fn=None)

sampler = DynamicBucketingSampler(
    examples,
    constraint=TokenConstraint(max_tokens=1024, quadratic_length=128),
    num_buckets=2,
)

class ExampleTextDataset(torch.utils.data.Dataset):
    def __getitem__(self, examples: CutSet):
        tokens = [ex.tokens for ex in examples]
        token_lens = torch.tensor([len(t) for t in tokens])
        tokens = collate_vectors(tokens, padding_value=-1)
        return tokens, token_lens

dloader = torch.utils.data.DataLoader(ExampleTextDataset(), sampler=sampler, batch_size=None)

for batch in dloader:
    print(batch)

Note

Support for this kind of dataloading is experimental in Lhotse. If you run into any rough edges, please let us know.

Dataset’s list

class lhotse.dataset.diarization.DiarizationDataset(cuts, uem=None, min_speaker_dim=None, global_speaker_ids=False)[source]

A PyTorch Dataset for the speaker diarization task. Our assumptions about speaker diarization are the following:

  • we assume a single channel input (for now), which could be either a true mono signal or a beamforming result from a microphone array.

  • we assume that the supervision used for model training is a speech activity matrix, with one row dedicated to each speaker (either in the current cut or the whole dataset, depending on the settings). The columns correspond to feature frames. Each row is effectively a Voice Activity Detection supervision for a single speaker. This setup is somewhat inspired by the TS-VAD paper: https://arxiv.org/abs/2005.07272

Each item in this dataset is a dict of:

{
    'features': (B x T x F) tensor
    'features_lens': (B, ) tensor
    'speaker_activity': (B x num_speaker x T) tensor
}

Constructor arguments:

Parameters:
  • cuts (CutSet) – a CutSet used to create the dataset object.

  • uem (Optional[SupervisionSet]) – a SupervisionSet used to set regions for diarization.

  • min_speaker_dim (Optional[int]) – optional int, when specified it will enforce that the matrix shape is at least that value (useful for datasets like CHiME 6 where the number of speakers is always 4, but some cuts might have fewer speakers than that).

  • global_speaker_ids (bool) – a bool, indicates whether the same speaker should always retain the same row index in the speaker activity matrix (useful for speaker-dependent systems).

  • root_dir – a prefix path to be attached to the feature file paths.

__init__(cuts, uem=None, min_speaker_dim=None, global_speaker_ids=False)[source]
class lhotse.dataset.unsupervised.UnsupervisedDataset[source]

Dataset that contains no supervision - it only provides the features extracted from recordings.

{
    'features': (B x T x F) tensor
    'features_lens': (B, ) tensor
}
__init__()[source]
class lhotse.dataset.unsupervised.UnsupervisedWaveformDataset(collate=True)[source]

A variant of UnsupervisedDataset that provides waveform samples instead of features. The output is a tensor of shape (C, T), with C being the number of channels and T the number of audio samples. In this implementation, there will always be a single channel.

Returns:

{
    'audio': (B x NumSamples) float tensor
    'audio_lens': (B, ) int tensor
}
__init__(collate=True)[source]
class lhotse.dataset.unsupervised.DynamicUnsupervisedDataset(feature_extractor, augment_fn=None)[source]

An example dataset that shows how to use on-the-fly feature extraction in Lhotse. It accepts two additional inputs: a FeatureExtractor and an optional WavAugmenter for time-domain data augmentation. The output is approximately the same as that of the UnsupervisedDataset - there might be slight differences for MixedCuts, because this dataset mixes them in the time domain, while UnsupervisedDataset does that in the feature domain. Cuts that are not mixed will yield identical results in both dataset classes.

__init__(feature_extractor, augment_fn=None)[source]
class lhotse.dataset.unsupervised.RecordingChunkIterableDataset(recordings, chunk_size, chunk_shift)[source]

This dataset iterates over chunks of a recording, for each recording provided. It supports setting a chunk_shift < chunk_size to run model predictions on overlapping audio chunks.

The format of yielded items is the following:

{
    "recording_id": str
    "begin_time": tensor with dtype=float32 shape=(1,)
    "end_time": tensor with dtype=float32 shape=(1,)
    "audio": tensor with dtype=float32 shape=(chunk_size_in_samples,)
}

Unlike most other datasets in Lhotse, this dataset does not yield batched items, and should be used like the following:

>>> recordings = RecordingSet.from_file("my-recordings.jsonl.gz")
... dataset = RecordingChunkIterableDataset(recordings, chunk_size=30.0, chunk_shift=25.0)
... dloader = torch.utils.data.DataLoader(
...     dataset,
...     batch_size=32,
...     collate_fn=audio_chunk_collate,
...     worker_init_fn=audio_chunk_worker_init_fn,
... )
__init__(recordings, chunk_size, chunk_shift)[source]
validate()[source]
Return type:

None

class lhotse.dataset.unsupervised.ShiftingBuffer(chunk_size, chunk_shift)[source]

Utility for iterating over streaming audio chunks that supports chunk_shift < chunk_size. It is useful when running model predictions on overlapping chunks of audio data.

__init__(chunk_size, chunk_shift)[source]
push(audio)[source]

Add new chunk of audio to the buffer. Expects shape (num_samples, ).

Return type:

None

get_chunks()[source]

Retrieve chunks accumulated so far, adjusted for chunk_shift. For chunk_shift < chunk_size, there will typically be more chunks returned from this function than were pushed into the buffer because of overlap. The returned shape is (num_chunks, chunk_size).

Return type:

Tensor

flush()[source]

Flush out the remainder chunk from the buffer. Typically it will be shorter than chunk_size. The returned shape is (remainder_size, ).

Return type:

Tensor
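A sketch of the push/get_chunks/flush cycle for streaming inference; we assume here that chunk_size and chunk_shift are expressed in samples, and the packet size is arbitrary:

import torch
from lhotse.dataset.unsupervised import ShiftingBuffer

# 2-second chunks with a 1.5-second shift at 16 kHz, i.e. 0.5 s of overlap.
buf = ShiftingBuffer(chunk_size=32000, chunk_shift=24000)

for _ in range(10):  # simulate a stream of 0.5-second audio packets
    buf.push(torch.randn(8000))
    chunks = buf.get_chunks()  # chunks completed so far: (num_chunks, chunk_size)
    # ... run model predictions on `chunks` ...

remainder = buf.flush()  # final partial chunk, shape (remainder_size,)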

lhotse.dataset.unsupervised.audio_chunk_collate(batch)[source]
lhotse.dataset.unsupervised.audio_chunk_worker_init_fn(worker_id)[source]
class lhotse.dataset.speech_recognition.K2SpeechRecognitionDataset(return_cuts=False, cut_transforms=None, input_transforms=None, input_strategy=<lhotse.dataset.input_strategies.PrecomputedFeatures object>)[source]

The PyTorch Dataset for the speech recognition task using the k2 library.

This dataset expects to be queried with lists of cut IDs, for which it loads features and automatically collates/batches them.

To use it with a PyTorch DataLoader, set batch_size=None and provide a SimpleCutSampler sampler.
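A minimal sketch of that wiring (the manifest path is hypothetical, and the cuts are assumed to have pre-computed features):

from torch.utils.data import DataLoader
from lhotse import CutSet
from lhotse.dataset import K2SpeechRecognitionDataset, SimpleCutSampler

cuts = CutSet.from_file("cuts.jsonl.gz")
dataset = K2SpeechRecognitionDataset()
sampler = SimpleCutSampler(cuts, max_duration=200.0)
dloader = DataLoader(dataset, sampler=sampler, batch_size=None)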

Each item in this dataset is a dict of:

{
    'inputs': float tensor with shape determined by :attr:`input_strategy`:
              - single-channel:
                - features: (B, T, F)
                - audio: (B, T)
              - multi-channel: currently not supported
    'supervisions': [
        {
            'sequence_idx': Tensor[int] of shape (S,)
            'text': List[str] of len S

            # For feature input strategies
            'start_frame': Tensor[int] of shape (S,)
            'num_frames': Tensor[int] of shape (S,)

            # For audio input strategies
            'start_sample': Tensor[int] of shape (S,)
            'num_samples': Tensor[int] of shape (S,)

            # Optionally, when return_cuts=True
            'cut': List[AnyCut] of len S
        }
    ]
}

Dimension symbols legend:

  • B - batch size (number of Cuts)

  • S - number of supervision segments (greater or equal to B, as each Cut may have multiple supervisions)

  • T - number of frames of the longest Cut

  • F - number of features

The ‘sequence_idx’ field is the index of the Cut used to create the example in the Dataset.

__init__(return_cuts=False, cut_transforms=None, input_transforms=None, input_strategy=<lhotse.dataset.input_strategies.PrecomputedFeatures object>)[source]

k2 ASR dataset constructor.

Parameters:
  • return_cuts (bool) – When True, will additionally return a “cut” field in each batch with the Cut objects used to create that batch.

  • cut_transforms (Optional[List[Callable[[CutSet], CutSet]]]) – A list of transforms to be applied on each sampled batch, before converting cuts to an input representation (audio/features). Examples: cut concatenation, noise cuts mixing, etc.

  • input_transforms (Optional[List[Callable[[Tensor], Tensor]]]) – A list of transforms to be applied on each sampled batch, after the cuts are converted to audio/features. Examples: normalization, SpecAugment, etc.

  • input_strategy (BatchIO) – Converts cuts into a collated batch of audio/features. By default, reads pre-computed features from disk.

lhotse.dataset.speech_recognition.validate_for_asr(cuts)[source]
Return type:

None

lhotse.dataset.speech_synthesis

alias of the lhotse.dataset.speech_synthesis module

class lhotse.dataset.source_separation.DynamicallyMixedSourceSeparationDataset(sources_set, mixtures_set, nonsources_set=None)[source]

A PyTorch Dataset for the source separation task. It’s created from a number of CutSets:

  • sources_set: provides the audio cuts for the sources (the targets of source separation),

  • mixtures_set: provides the audio cuts for the signal mix (the input of source separation),

  • nonsources_set: (optional) provides the audio cuts for other signals that are in the mix, but are not the targets of source separation. Useful for adding noise.

When queried for data samples, it returns a dict of:

{
    'sources': (N x T x F) tensor,
    'mixture': (T x F) tensor,
    'real_mask': (N x T x F) tensor,
    'binary_mask': (T x F) tensor
}

This Dataset performs on-the-fly feature-domain mixing of the sources. It expects the mixtures_set to contain MixedCuts, so that it knows which Cuts should be mixed together.

__init__(sources_set, mixtures_set, nonsources_set=None)[source]
validate()[source]
class lhotse.dataset.source_separation.PreMixedSourceSeparationDataset(sources_set, mixtures_set)[source]

A PyTorch Dataset for the source separation task. It’s created from two CutSets - one provides the audio cuts for the sources, and the other one the audio cuts for the signal mix. When queried for data samples, it returns a dict of:

{
    'sources': (N x T x F) tensor,
    'mixture': (T x F) tensor,
    'real_mask': (N x T x F) tensor,
    'binary_mask': (T x F) tensor
}

It expects both CutSets to return regular Cuts, meaning that the signals were mixed in the time domain. In contrast to DynamicallyMixedSourceSeparationDataset, no on-the-fly feature-domain mixing is performed.

__init__(sources_set, mixtures_set)[source]
class lhotse.dataset.vad.VadDataset(input_strategy=<lhotse.dataset.input_strategies.PrecomputedFeatures object>, cut_transforms=None, input_transforms=None)[source]

The PyTorch Dataset for the voice activity detection task. Each item in this dataset is a dict of:

{
    'inputs': (B x T x F) tensor
    'input_lens': (B,) tensor
    'is_voice': (T x 1) tensor
    'cut': List[Cut]
}
__init__(input_strategy=<lhotse.dataset.input_strategies.PrecomputedFeatures object>, cut_transforms=None, input_transforms=None)[source]

Sampler’s list

class lhotse.dataset.sampling.TokenConstraint(max_tokens=None, max_examples=None, current=0, num_examples=0, longest_seen=0, quadratic_length=None)[source]

Represents a token-based constraint for sampler classes that sample text data. It is defined as maximum total number of tokens in a mini-batch and/or max batch size.

Similarly to TimeConstraint, we support quadratic_length for quadratic token penalty when sampling longer texts.

max_tokens: int = None
max_examples: int = None
current: int = 0
num_examples: int = 0
longest_seen: int = 0
quadratic_length: Optional[int] = None
add(example)[source]

Increment the internal token counter for the constraint, selecting the right property from the input object.

Return type:

None

exceeded()[source]

Is the constraint exceeded or not.

Return type:

bool

close_to_exceeding()[source]

Check if the batch is close to satisfying the constraints. We define “closeness” as: if we added one more cut that has duration/num_frames/num_samples equal to the longest seen cut in the current batch, then the batch would have exceeded the constraints.

Return type:

bool

reset()[source]

Reset the internal counter (to be used after a batch was created, to start collecting a new one).

Return type:

None

measure_length(example)[source]

Returns the “size” of an example, used to create bucket distribution for bucketing samplers (e.g., for audio it may be duration; for text it may be number of tokens; etc.).

Return type:

float

__init__(max_tokens=None, max_examples=None, current=0, num_examples=0, longest_seen=0, quadratic_length=None)
class lhotse.dataset.sampling.TimeConstraint(max_duration=None, max_cuts=None, current=0, num_cuts=0, longest_seen=0, quadratic_duration=None)[source]

Represents a time-based constraint for sampler classes. It is defined as maximum total batch duration (in seconds) and/or the total number of cuts.

TimeConstraint can be used for tracking whether the criterion has been exceeded via the add(cut), exceeded() and reset() methods. It will automatically track the right criterion (i.e. select duration from the cut). It can also be a null constraint (never exceeded).

When quadratic_duration is set, we will try to compensate for models that have a quadratic complexity w.r.t. the input sequence length. We use the following formula to determine the effective duration for each cut:

effective_duration = duration + (duration ** 2) / quadratic_duration

We recommend setting quadratic_duration to something between 15 and 40 for transformer architectures.
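For example, with quadratic_duration=30, a 10-second cut contributes an effective duration of 10 + 100/30 ≈ 13.3 seconds, while a 30-second cut contributes 30 + 900/30 = 60 seconds, so batches of long utterances are kept proportionally smaller.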

max_duration: Optional[float] = None
max_cuts: Optional[int] = None
current: Union[int, float] = 0
num_cuts: int = 0
longest_seen: Union[int, float] = 0
quadratic_duration: Optional[float] = None
is_active()[source]

Is it an actual constraint, or a dummy one (i.e. never exceeded).

Return type:

bool

add(example)[source]

Increment the internal counter for the time constraint, selecting the right property from the input cut object.

Return type:

None

exceeded()[source]

Is the constraint exceeded or not.

Return type:

bool

close_to_exceeding()[source]

Check if the batch is close to satisfying the constraints. We define “closeness” as: if we added one more cut that has duration/num_frames/num_samples equal to the longest seen cut in the current batch, then the batch would have exceeded the constraints.

Return type:

bool

reset()[source]

Reset the internal counter (to be used after a batch was created, to start collecting a new one).

Return type:

None

measure_length(example)[source]

Returns the “size” of an example, used to create bucket distribution for bucketing samplers (e.g., for audio it may be duration; for text it may be number of tokens; etc.).

Return type:

float

state_dict()[source]
Return type:

Dict[str, Any]

load_state_dict(state_dict)[source]
Return type:

None

__init__(max_duration=None, max_cuts=None, current=0, num_cuts=0, longest_seen=0, quadratic_duration=None)
class lhotse.dataset.sampling.SamplingDiagnostics(current_epoch=0, stats_per_epoch=None)[source]

Utility for collecting diagnostics about the sampling process: how many cuts/batches were discarded.

current_epoch: int = 0
stats_per_epoch: Dict[int, EpochDiagnostics] = None
reset_current_epoch()[source]
Return type:

None

set_epoch(epoch)[source]
Return type:

None

advance_epoch()[source]
Return type:

None

property current_epoch_stats: EpochDiagnostics
keep(cuts)[source]
Return type:

None

discard(cuts)[source]
Return type:

None

discard_single(cut)[source]
Return type:

None

property kept_cuts: int
property discarded_cuts: int
property kept_batches: int
property discarded_batches: int
property total_cuts: int
property total_batches: int
get_report(per_epoch=False)[source]

Returns a string describing the statistics of the sampling process so far.

Return type:

str

state_dict()[source]
Return type:

Dict[str, Any]

load_state_dict(state_dict)[source]
Return type:

SamplingDiagnostics

__init__(current_epoch=0, stats_per_epoch=None)
class lhotse.dataset.sampling.SamplingConstraint[source]

Defines the interface for sampling constraints. A sampling constraint keeps track of the sampled examples and lets the sampler know when it should yield a mini-batch.

abstract add(example)[source]

Update the sampling constraint with the information about the sampled example (e.g. current batch size, total duration).

Return type:

None

abstract exceeded()[source]

Inform if the sampling constraint has been exceeded.

Return type:

bool

abstract close_to_exceeding()[source]

Inform if we’re going to exceed the sampling constraint after adding one more example.

Return type:

bool

abstract reset()[source]

Resets the internal state (called after yielding a mini-batch).

Return type:

None

abstract measure_length(example)[source]

Returns the “size” of an example, used to create bucket distribution for bucketing samplers (e.g., for audio it may be duration; for text it may be number of tokens; etc.).

Return type:

float

copy()[source]

Return a shallow copy of this constraint.

Return type:

SamplingConstraint

class lhotse.dataset.sampling.BucketingSampler(*cuts, sampler_type=<class 'lhotse.dataset.sampling.simple.SimpleCutSampler'>, num_buckets=10, drop_last=False, seed=0, **kwargs)[source]

Sorts the cuts in a CutSet by their duration and puts them into similar duration buckets. For each bucket, it instantiates a simpler sampler instance, e.g. SimpleCutSampler.

It behaves like an iterable that yields lists of strings (cut IDs). During iteration, it randomly selects one of the buckets to yield the batch from, until all the underlying samplers are depleted (which means it’s the end of an epoch).

Examples:

Bucketing sampler with 20 buckets, sampling single cuts:

>>> sampler = BucketingSampler(
...    cuts,
...    # BucketingSampler specific args
...    sampler_type=SimpleCutSampler, num_buckets=20,
...    # Args passed into SimpleCutSampler
...    max_frames=20000
... )

Bucketing sampler with 20 buckets, sampling pairs of source-target cuts:

>>> sampler = BucketingSampler(
...    cuts, target_cuts,
...    # BucketingSampler specific args
...    sampler_type=CutPairsSampler, num_buckets=20,
...    # Args passed into CutPairsSampler
...    max_source_frames=20000, max_target_frames=15000
... )
__init__(*cuts, sampler_type=<class 'lhotse.dataset.sampling.simple.SimpleCutSampler'>, num_buckets=10, drop_last=False, seed=0, **kwargs)[source]

BucketingSampler’s constructor.

Parameters:
  • cuts (CutSet) – one or more CutSet objects. The first one will be used to determine the buckets for all of them. Then, all of them will be used to instantiate the per-bucket samplers.

  • sampler_type (Type) – a sampler type that will be created for each underlying bucket.

  • num_buckets (int) – how many buckets to create.

  • drop_last (bool) – When True, we will drop all incomplete batches. A batch is considered incomplete if it depleted a bucket before hitting the constraint such as max_duration, max_cuts, etc.

  • seed (int) – random seed for bucket selection

  • kwargs (Any) – Arguments used to create the underlying sampler for each bucket.

property remaining_duration: float | None

Remaining duration of data left in the sampler (may be inexact due to float arithmetic). Not available when the CutSet is read in lazy mode (returns None).

property remaining_cuts: int | None

Remaining number of cuts in the sampler. Not available when the CutSet is read in lazy mode (returns None).

property num_cuts: int | None

Total number of cuts in the sampler. Not available when the CutSet is read in lazy mode (returns None).

set_epoch(epoch)[source]

Sets the epoch for this sampler. When shuffle=True, this ensures all replicas use a different random ordering for each epoch. Otherwise, the next iteration of this sampler will yield the same ordering.

Parameters:

epoch (int) – Epoch number.

Return type:

None

filter(predicate)[source]

Add a constraint on individual cuts that has to be satisfied to consider them.

Can be useful when handling large, lazy manifests where it is not feasible to pre-filter them before instantiating the sampler.

Return type:

None

Example:
>>> cuts = CutSet(...)
... sampler = SimpleCutSampler(cuts, max_duration=100.0)
... # Retain only the cuts that have at least 1s and at most 20s duration.
... sampler.filter(lambda cut: 1.0 <= cut.duration <= 20.0)
allow_iter_to_reset_state()[source]

Enables re-setting to the start of an epoch when iter() is called. This is only needed in one specific scenario: when we restored previous sampler state via sampler.load_state_dict() but want to discard the progress in the current epoch and start from the beginning.

state_dict()[source]

Return the current state of the sampler in a state_dict. Together with load_state_dict(), this can be used to restore the training loop’s state to the one stored in the state_dict.

Return type:

Dict[str, Any]

load_state_dict(state_dict)[source]

Restore the state of the sampler that is described in a state_dict. This will result in the sampler yielding batches from where the previous training left it off.

Return type:

None

Caution

The samplers are expected to be initialized with the same CutSets, but this is not explicitly checked anywhere.

Caution

The input state_dict is being mutated: we remove each consumed key and expect it to be empty at the end of loading. If you don’t want this behavior, pass a copy into this function (e.g., made with copy.deepcopy).

Note

For implementers of sub-classes of CutSampler: the flag self._just_restored_state has to be handled in __iter__ so that it avoids resetting the just-restored state (only once).

property is_depleted: bool
property diagnostics: SamplingDiagnostics

Info on how many cuts / batches were returned or rejected during iteration.

This property can be overridden by child classes, e.g. to merge diagnostics of composite samplers.

get_report()[source]

Returns a string describing the statistics of the sampling process so far.

Return type:

str

class lhotse.dataset.sampling.CutPairsSampler(source_cuts, target_cuts, max_source_duration=None, max_target_duration=None, max_cuts=None, shuffle=False, drop_last=False, world_size=None, rank=None, seed=0)[source]

Samples pairs of cuts from a “source” and “target” CutSet. It expects that both CutSet’s strictly consist of Cuts with corresponding IDs. It behaves like an iterable that yields lists of strings (cut IDs).

When one of max_frames, max_samples, or max_duration is specified, the batch size is dynamic. Exactly zero or one of those constraints can be specified. Padding required to collate the batch does not contribute to max frames/samples/duration.

__init__(source_cuts, target_cuts, max_source_duration=None, max_target_duration=None, max_cuts=None, shuffle=False, drop_last=False, world_size=None, rank=None, seed=0)[source]

CutPairsSampler’s constructor.

Parameters:
  • source_cuts (CutSet) – the first CutSet to sample data from.

  • target_cuts (CutSet) – the second CutSet to sample data from.

  • max_source_duration (Optional[float]) – The maximum total recording duration from source_cuts.

  • max_target_duration (Optional[float]) – The maximum total recording duration from target_cuts.

  • max_cuts (Optional[int]) – The maximum number of cuts sampled to form a mini-batch. By default, this constraint is off.

  • shuffle (bool) – When True, the cuts will be shuffled at the start of iteration. Convenient when the mini-batch loop is inside an outer epoch-level loop, e.g.: for epoch in range(10): for batch in dataset: … as every epoch will see a different order of cuts.

  • drop_last (bool) – When True, the last batch is dropped if it’s incomplete.

  • world_size (Optional[int]) – Total number of distributed nodes. We will try to infer it by default.

  • rank (Optional[int]) – Index of distributed node. We will try to infer it by default.

  • seed (int) – Random seed used to consistently shuffle the dataset across different processes.

property remaining_duration: float | None

Remaining duration of data left in the sampler (may be inexact due to float arithmetic). Not available when the CutSet is read in lazy mode (returns None).

property remaining_cuts: int | None

Remaining number of cuts in the sampler. Not available when the CutSet is read in lazy mode (returns None).

property num_cuts: int | None

Total number of cuts in the sampler. Not available when the CutSet is read in lazy mode (returns None).

state_dict()[source]

Return the current state of the sampler in a state_dict. Together with load_state_dict(), this can be used to restore the training loop’s state to the one stored in the state_dict.

Return type:

Dict[str, Any]

load_state_dict(state_dict)[source]

Restore the state of the sampler that is described in a state_dict. This will result in the sampler yielding batches from where the previous training left it off.

Return type:

None

Caution

The samplers are expected to be initialized with the same CutSets, but this is not explicitly checked anywhere.

Caution

The input state_dict is being mutated: we remove each consumed key and expect it to be empty at the end of loading. If you don’t want this behavior, pass a copy into this function (e.g., made with copy.deepcopy).

Note

For implementers of sub-classes of CutSampler: the flag self._just_restored_state has to be handled in __iter__ so that it avoids resetting the just-restored state (only once).

class lhotse.dataset.sampling.DynamicCutSampler(*cuts, max_duration=None, max_cuts=None, constraint=None, shuffle=False, drop_last=False, consistent_ids=True, shuffle_buffer_size=20000, quadratic_duration=None, world_size=None, rank=None, seed=0, strict=None)[source]

A dynamic (streaming) variant of sampler that doesn’t stratify the sampled cuts in any way. It is a generalization of SimpleCutSampler and CutPairsSampler in that it allows jointly iterating over an arbitrary number of CutSets.

When input CutSets are opened in lazy mode, this sampler doesn’t require reading the whole cut set into memory.

For scenarios such as ASR, VAD, Speaker ID, or TTS training, this class supports single CutSet iteration. Example:

>>> cuts = CutSet(...)
>>> sampler = DynamicCutSampler(cuts, max_duration=100)
>>> for batch in sampler:
...     assert isinstance(batch, CutSet)

For other scenarios that require pairs (or triplets, etc.) of utterances, this class supports zipping multiple CutSets together. Such scenarios could be voice conversion, speech translation, contrastive self-supervised training, etc. Example:

>>> source_cuts = CutSet(...)
>>> target_cuts = CutSet(...)
>>> sampler = DynamicCutSampler(source_cuts, target_cuts, max_duration=100)
>>> for batch in sampler:
...     assert isinstance(batch, tuple)
...     assert len(batch) == 2
...     assert isinstance(batch[0], CutSet)
...     assert isinstance(batch[1], CutSet)

Note

For cut pairs, triplets, etc. the user is responsible for ensuring that the CutSets are all sorted so that when iterated over sequentially, the items are matched. We take care of preserving the right ordering internally, e.g., when shuffling. By default, we check that the cut IDs are matching, but that can be disabled.

Caution

when using DynamicCutSampler.filter() with more than one CutSet to sample from, we sample one cut from every CutSet and expect that all of the cuts satisfy the predicate – otherwise, they are all discarded from being sampled.

__init__(*cuts, max_duration=None, max_cuts=None, constraint=None, shuffle=False, drop_last=False, consistent_ids=True, shuffle_buffer_size=20000, quadratic_duration=None, world_size=None, rank=None, seed=0, strict=None)[source]
Parameters:
  • cuts (Iterable) – one or more CutSets (when more than one, will yield tuples of CutSets as mini-batches)

  • max_duration (Optional[float]) – The maximum total recording duration from cuts. Note: with multiple CutSets, max_duration constraint applies only to the first CutSet.

  • max_cuts (Optional[int]) – The maximum total number of cuts per batch. When only max_duration is specified, this sampler yields static batch sizes.

  • constraint (Optional[SamplingConstraint]) – Provide a SamplingConstraint object defining how the sampler decides when a mini-batch is complete. It also affects which attribute of the input examples decides the “size” of the example (by default it’s .duration). Before this parameter was introduced, Lhotse samplers used TimeConstraint implicitly. Introduced in Lhotse v1.22.0.

  • shuffle (bool) – When True, the cuts will be shuffled dynamically with a reservoir-sampling-based algorithm. Convenient when the mini-batch loop is inside an outer epoch-level loop, e.g.: for epoch in range(10): for batch in dataset: … as every epoch will see a different order of cuts.

  • drop_last (bool) – When True, we will drop all incomplete batches. A batch is considered incomplete if it depleted a bucket before hitting the constraint such as max_duration, max_cuts, etc.

  • consistent_ids (bool) – Only affects processing of multiple CutSets. When True, at each sampling step we check that cuts from all CutSets have the same ID (i.e., the first cut from every CutSet should have the same ID, same for the second, third, etc.).

  • shuffle_buffer_size (int) – How many cuts (or cut pairs, triplets) are held in memory in a buffer used for streaming shuffling. A larger number means better randomness at the cost of higher memory usage.

  • quadratic_duration (Optional[float]) – When set, it adds an extra penalty that’s quadratic in size w.r.t. a cut’s duration. This helps get a more even GPU utilization across different input lengths when models have quadratic input complexity. Set between 15 and 40 for transformers.

  • world_size (Optional[int]) – Total number of distributed nodes. We will try to infer it by default.

  • rank (Optional[int]) – Index of distributed node. We will try to infer it by default.

  • seed (Union[int, Literal['trng', 'randomized']]) – Random seed used to consistently shuffle the dataset across different processes.

state_dict()[source]

Return the current state of the sampler in a state_dict. Together with load_state_dict(), this can be used to restore the training loop’s state to the one stored in the state_dict.

Return type:

Dict[str, Any]

load_state_dict(sd)[source]

Restore the state of the sampler that is described in a state_dict. This will result in the sampler yielding batches from where the previous training left it off.

Return type:

None

Caution

The samplers are expected to be initialized with the same CutSets, but this is not explicitly checked anywhere.

Caution

The input state_dict is being mutated: we remove each consumed key and expect it to be empty at the end of loading. If you don’t want this behavior, pass a copy into this function (e.g., made with copy.deepcopy).

Note

For implementers of sub-classes of CutSampler: the flag self._just_restored_state has to be handled in __iter__ so that it avoids resetting the just-restored state (only once).

property remaining_duration: float | None

Remaining duration of data left in the sampler (may be inexact due to float arithmetic). Not available when the CutSet is read in lazy mode (returns None).

property remaining_cuts: int | None

Remaining number of cuts in the sampler. Not available when the CutSet is read in lazy mode (returns None).

property num_cuts: int | None

Total number of cuts in the sampler. Not available when the CutSet is read in lazy mode (returns None).

class lhotse.dataset.sampling.DynamicBucketingSampler(*cuts, max_duration=None, max_cuts=None, constraint=None, num_buckets=10, shuffle=False, drop_last=False, consistent_ids=True, duration_bins=None, num_cuts_for_bins_estimate=10000, buffer_size=20000, quadratic_duration=None, world_size=None, rank=None, seed=0, strict=None, shuffle_buffer_size=None)[source]

A dynamic (streaming) variant of BucketingSampler, that doesn’t require reading the whole cut set into memory.

The basic idea is to sample N (e.g. ~10k) cuts and estimate the boundary durations for buckets. Then, we maintain a buffer of M cuts (stored separately in K buckets), and every time we sample a batch, we read the same number of cuts from the input cut iterable to replenish the buffer. The memory consumption is limited by M at all times.

For scenarios such as ASR, VAD, Speaker ID, or TTS training, this class supports single CutSet iteration. Example:

>>> cuts = CutSet(...)
>>> sampler = DynamicBucketingSampler(cuts, max_duration=100)
>>> for batch in sampler:
...     assert isinstance(batch, CutSet)

For other scenarios that require pairs (or triplets, etc.) of utterances, this class supports zipping multiple CutSets together. Such scenarios could be voice conversion, speech translation, contrastive self-supervised training, etc. Example:

>>> source_cuts = CutSet(...)
>>> target_cuts = CutSet(...)
>>> sampler = DynamicBucketingSampler(source_cuts, target_cuts, max_duration=100)
>>> for batch in sampler:
...     assert isinstance(batch, tuple)
...     assert len(batch) == 2
...     assert isinstance(batch[0], CutSet)
...     assert isinstance(batch[1], CutSet)

Note

For cut pairs, triplets, etc. the user is responsible for ensuring that the CutSets are all sorted so that when iterated over sequentially, the items are matched. We take care of preserving the right ordering internally, e.g., when shuffling. By default, we check that the cut IDs are matching, but that can be disabled.

Caution

when using DynamicBucketingSampler.filter() with more than one CutSet to sample from, we sample one cut from every CutSet and expect that all of the cuts satisfy the predicate – otherwise, they are all discarded from being sampled.

__init__(*cuts, max_duration=None, max_cuts=None, constraint=None, num_buckets=10, shuffle=False, drop_last=False, consistent_ids=True, duration_bins=None, num_cuts_for_bins_estimate=10000, buffer_size=20000, quadratic_duration=None, world_size=None, rank=None, seed=0, strict=None, shuffle_buffer_size=None)[source]
Parameters:
  • cuts (Iterable) – one or more CutSets (when more than one, will yield tuples of CutSets as mini-batches)

  • max_duration (Optional[float]) – The maximum total recording duration from cuts. Note: with multiple CutSets, max_duration constraint applies only to the first CutSet.

  • max_cuts (Optional[int]) – The maximum total number of cuts per batch. When only max_duration is specified, this sampler yields static batch sizes.

  • num_buckets (Optional[int]) – how many buckets to create.

  • shuffle (bool) – When True, the cuts will be shuffled dynamically with a reservoir-sampling-based algorithm. Convenient when the mini-batch loop is inside an outer epoch-level loop, e.g.: for epoch in range(10): for batch in dataset: … as every epoch will see a different order of cuts.

  • drop_last (bool) – When True, we will drop all incomplete batches. A batch is considered incomplete if it depleted a bucket before hitting the constraint such as max_duration, max_cuts, etc.

  • consistent_ids (bool) – Only affects processing of multiple CutSets. When True, at each sampling step we check that cuts from all CutSets have the same ID (i.e., the first cut from every CutSet should have the same ID, same for the second, third, etc.).

  • duration_bins (Optional[List[float]]) – A list of floats (seconds); when provided, we’ll skip the initial estimation of bucket duration bins (useful to speed-up the launching of experiments).

  • num_cuts_for_bins_estimate (int) – We will draw this many cuts to estimate the duration bins for creating similar-duration buckets. A larger number means a better estimate of the data distribution, possibly at a longer init cost.

  • buffer_size (int) – How many cuts (or cut pairs, triplets) we hold at any time across all of the buckets. Increasing max_duration (batch_size) or num_buckets might require increasing this number. Larger number here will also improve shuffling capabilities. It will result in larger memory usage.

  • quadratic_duration (Optional[float]) – When set, it adds an extra penalty that’s quadratic in size w.r.t. a cut’s duration. This helps get a more even GPU utilization across different input lengths when models have quadratic input complexity. Set between 15 and 40 for transformers.

  • world_size (Optional[int]) – Total number of distributed nodes. We will try to infer it by default.

  • rank (Optional[int]) – Index of distributed node. We will try to infer it by default.

  • seed (Union[int, Literal['trng', 'randomized']]) – Random seed used to consistently shuffle the dataset across different processes.

state_dict()[source]

Return the current state of the sampler in a state_dict. Together with load_state_dict(), this can be used to restore the training loop’s state to the one stored in the state_dict.

Return type:

Dict[str, Any]

load_state_dict(sd)[source]

Restore the state of the sampler that is described in a state_dict. This will result in the sampler yielding batches from where the previous training left it off.

Return type:

None

Caution

The samplers are expected to be initialized with the same CutSets, but this is not explicitly checked anywhere.

Caution

The input state_dict is being mutated: we remove each consumed key and expect it to be empty at the end of loading. If you don’t want this behavior, pass a copy into this function (e.g., made with copy.deepcopy).

Note

For implementers of sub-classes of CutSampler: the flag self._just_restored_state has to be handled in __iter__ so that it avoids resetting the just-restored state (only once).

property remaining_duration: float | None

Remaining duration of data left in the sampler (may be inexact due to float arithmetic). Not available when the CutSet is read in lazy mode (returns None).

property remaining_cuts: int | None

Remaining number of cuts in the sampler. Not available when the CutSet is read in lazy mode (returns None).

property num_cuts: int | None

Total number of cuts in the sampler. Not available when the CutSet is read in lazy mode (returns None).

class lhotse.dataset.sampling.RoundRobinSampler(*samplers, stop_early=False, randomize=False, seed=0)[source]

RoundRobinSampler takes several samplers as input, and yields a mini-batch of cuts from each of those samplers in turn. E.g., with two samplers, the first mini-batch is from sampler0, the second from sampler1, the third from sampler0, and so on. It is helpful for alternating mini-batches from multiple datasets or manually creating batches of different sizes.

The input samplers do not have to provide the same number of batches – when any of the samplers becomes depleted, we continue to iterate the non-depleted samplers, until all of them are exhausted.

Example:

>>> sampler = RoundRobinSampler(
...     SimpleCutSampler(cuts_corpusA, max_cuts=32, shuffle=True),
...     SimpleCutSampler(cuts_corpusB, max_cuts=64, shuffle=True),
... )
>>> for cut in sampler:
...     pass  # profit
__init__(*samplers, stop_early=False, randomize=False, seed=0)[source]

RoundRobinSampler’s constructor.

Parameters:
  • samplers (CutSampler) – The list of samplers from which we sample batches in turns.

  • stop_early (bool) – Should we finish the epoch once any of the samplers becomes depleted. By default, we will keep iterating until all the samplers are exhausted. This setting can be used to balance datasets of different sizes.

  • randomize (Union[bool, List[float]]) – Select the next sampler according to a distribution, instead of in order. If a list of floats is provided, it must contain the same number of elements as the number of samplers, and the values will be used as probabilities. If True is provided, the probabilities will be uniform. If False is provided, the samplers will be selected in order.

  • seed (int) – Random seed used to select the next sampler (only used if randomize is True)

property remaining_duration: float | None

Remaining duration of data left in the sampler (may be inexact due to float arithmetic).

property remaining_cuts: int | None

Remaining number of cuts in the sampler. Not available when the CutSet is read in lazy mode (returns None).

property num_cuts: int | None

Total number of cuts in the sampler. Not available when the CutSet is read in lazy mode (returns None).

allow_iter_to_reset_state()[source]

Enables re-setting to the start of an epoch when iter() is called. This is only needed in one specific scenario: when we restored previous sampler state via sampler.load_state_dict() but want to discard the progress in the current epoch and start from the beginning.

state_dict()[source]

Return the current state of the sampler in a state_dict. Together with load_state_dict(), this can be used to restore the training loop’s state to the one stored in the state_dict.

Return type:

Dict[str, Any]

load_state_dict(state_dict)[source]

Restore the state of the sampler that is described in a state_dict. This will result in the sampler yielding batches from where the previous training left it off.

Return type:

None

Caution

The samplers are expected to be initialized with the same CutSets, but this is not explicitly checked anywhere.

Caution

The input state_dict is being mutated: we remove each consumed key and expect it to be empty at the end of loading. If you don’t want this behavior, pass a copy into this function (e.g., made with copy.deepcopy).

Note

For implementers of sub-classes of CutSampler: the flag self._just_restored_state has to be handled in __iter__ so that it avoids resetting the just-restored state (only once).

set_epoch(epoch)[source]

Sets the epoch for this sampler. When shuffle=True, this ensures all replicas use a different random ordering for each epoch. Otherwise, the next iteration of this sampler will yield the same ordering.

Parameters:

epoch (int) – Epoch number.

Return type:

None

filter(predicate)[source]

Add a constraint on individual cuts that has to be satisfied to consider them.

Can be useful when handling large, lazy manifests where it is not feasible to pre-filter them before instantiating the sampler.

Return type:

None

Example:
>>> cuts = CutSet(...)
>>> sampler = SimpleCutSampler(cuts, max_duration=100.0)
>>> # Retain only the cuts that have at least 1s and at most 20s duration.
>>> sampler.filter(lambda cut: 1.0 <= cut.duration <= 20.0)
property diagnostics: SamplingDiagnostics

Info on how many cuts / batches were returned or rejected during iteration.

This property can be overridden by child classes e.g. to merge diagnostics of composite samplers.

get_report()[source]

Returns a string describing the statistics of the sampling process so far.

Return type:

str

class lhotse.dataset.sampling.SimpleCutSampler(cuts, max_duration=None, max_cuts=None, shuffle=False, drop_last=False, world_size=None, rank=None, seed=0)[source]

Samples cuts from a CutSet to satisfy the input constraints. It behaves like an iterable that yields lists of strings (cut IDs).

When one of max_frames, max_samples, or max_duration is specified, the batch size is dynamic. Exactly zero or one of those constraints can be specified. Padding required to collate the batch does not contribute to max frames/samples/duration.

Example usage:

>>> dataset = K2SpeechRecognitionDataset(cuts)
>>> sampler = SimpleCutSampler(cuts, shuffle=True)
>>> loader = DataLoader(dataset, sampler=sampler, batch_size=None)
>>> for epoch in range(start_epoch, n_epochs):
...     sampler.set_epoch(epoch)
...     train(loader)
__init__(cuts, max_duration=None, max_cuts=None, shuffle=False, drop_last=False, world_size=None, rank=None, seed=0)[source]

SimpleCutSampler’s constructor.

Parameters:
  • cuts (CutSet) – the CutSet to sample data from.

  • max_duration (Optional[float]) – The maximum total recording duration from cuts.

  • max_cuts (Optional[int]) – The maximum number of cuts sampled to form a mini-batch. By default, this constraint is off.

  • shuffle (bool) – When True, the cuts will be shuffled at the start of iteration. Convenient when the mini-batch loop is inside an outer epoch-level loop, e.g. for epoch in range(10): for batch in dataset: ..., as every epoch will see a different order of cuts.

  • drop_last (bool) – When True, the last batch is dropped if it’s incomplete.

  • world_size (Optional[int]) – Total number of distributed nodes. We will try to infer it by default.

  • rank (Optional[int]) – Index of distributed node. We will try to infer it by default.

  • seed (int) – Random seed used to consistently shuffle the dataset across different processes.

property remaining_duration: float | None

Remaining duration of data left in the sampler (may be inexact due to float arithmetic). Not available when the CutSet is read in lazy mode (returns None).

property remaining_cuts: int | None

Remaining number of cuts in the sampler. Not available when the CutSet is read in lazy mode (returns None).

property num_cuts: int | None

Total number of cuts in the sampler. Not available when the CutSet is read in lazy mode (returns None).

state_dict()[source]

Return the current state of the sampler in a state_dict. Together with load_state_dict(), this can be used to restore the training loop’s state to the one stored in the state_dict.

Return type:

Dict[str, Any]

load_state_dict(state_dict)[source]

Restore the state of the sampler that is described in a state_dict. This will result in the sampler yielding batches from where the previous training left off.

Return type:

None

Caution

The samplers are expected to be initialized with the same CutSets, but this is not explicitly checked anywhere.

Caution

The input state_dict is being mutated: we remove each consumed key, and expect it to be empty at the end of loading. If you don’t want this behavior, pass a copy into this function (e.g., made with copy.deepcopy).

Note

For implementers of sub-classes of CutSampler: the flag self._just_restored_state has to be handled in __iter__ so that it avoids resetting the just-restored state (this should happen only once).

class lhotse.dataset.sampling.StatelessSampler(cuts_paths, index_path, base_seed, max_duration=None, max_cuts=None, num_buckets=None, quadratic_duration=None)[source]

An infinite and stateless cut sampler that selects data at random from one or more cut manifests. The main idea is to make training resumption easy while guaranteeing the data seen each time by the model is shuffled differently. It discards the notion of an “epoch” and it never finishes iteration. It makes no strong guarantees about avoiding data duplication, but in practice you would rarely see duplicated data.

The recommended way to use this sampler is by placing it into a dataloader worker subprocess with Lhotse’s IterableDatasetWrapper (lhotse.dataset.iterable_dataset.IterableDatasetWrapper), so that each worker has its own sampler replica that uses a slightly different random seed:

>>> import torch
>>> import lhotse
>>> dloader = torch.utils.data.DataLoader(
...     lhotse.dataset.iterable_dataset.IterableDatasetWrapper(
...         dataset=lhotse.dataset.K2SpeechRecognitionDataset(...),
...         sampler=StatelessSampler(...),
...     ),
...     batch_size=None,
...     num_workers=4,
... )

This sampler’s design was originally proposed by Dan Povey. For details see: https://github.com/lhotse-speech/lhotse/issues/1096

Example 1: Get a non-bucketing StatelessSampler:

>>> sampler = StatelessSampler(
...     cuts_paths=["data/cuts_a.jsonl", "data/cuts_b.jsonl"],
...     index_path="data/files.idx",
...     max_duration=600.0,
... )

Example 2: Get a bucketing StatelessSampler:

>>> sampler = StatelessSampler(
...     cuts_paths=["data/cuts_a.jsonl", "data/cuts_b.jsonl"],
...     index_path="data/files.idx",
...     max_duration=600.0,
...     num_buckets=50,
...     quadratic_duration=30.0,
... )

Example 3: Get a bucketing StatelessSampler with scaled weights for each cutset:

>>> sampler = StatelessSampler(
...     cuts_paths=[
...         ("data/cuts_a.jsonl", 2.0),
...         ("data/cuts_b.jsonl", 1.0),
...     ],
...     index_path="data/files.idx",
...     max_duration=600.0,
...     num_buckets=50,
...     quadratic_duration=30.0,
... )

Note

This sampler works only with uncompressed jsonl manifests, as it creates extra index files with line byte offsets to quickly find and sample JSON lines. This means this sampler will not work with the WebDataset and Lhotse Shar data formats.

Parameters:
  • cuts_paths (Union[Path, str, Iterable[Union[Path, str]], Iterable[Tuple[Union[Path, str], float]]]) – Path, or list of paths, or list of tuples of (path, scale) to cutset files.

  • index_path (Union[Path, str]) – Path to a file that contains the index of all cutsets and their line count (will be auto-created the first time this object is initialized).

  • base_seed (int) – Int, user-provided part of the seed used to initialize the RNG for sampling (each node and worker are still going to produce different results). When continuing the training it should be a function of the number of training steps to ensure the model doesn’t see identical mini-batches again.

  • max_duration (Optional[float]) – Maximum total number of audio seconds in a mini-batch (dynamic batch size).

  • max_cuts (Optional[int]) – Maximum number of examples in a mini-batch (static batch size).

  • num_buckets (Optional[int]) – If set, enables bucketing (each mini-batch has examples of a similar duration).

  • quadratic_duration (Optional[float]) – If set, adds a penalty term for longer duration cuts. Works well with models that have quadratic time complexity to keep GPU utilization similar when using bucketing. Suggested values are between 30 and 45.
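
Since base_seed should change when training is resumed, a minimal sketch is to derive it from a step counter tracked by the training loop (global_step here is a hypothetical variable restored from a checkpoint):

>>> global_step = 15000  # hypothetical counter restored from a checkpoint
>>> sampler = StatelessSampler(
...     cuts_paths=["data/cuts_a.jsonl"],
...     index_path="data/files.idx",
...     base_seed=global_step,
...     max_duration=600.0,
... )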

__init__(cuts_paths, index_path, base_seed, max_duration=None, max_cuts=None, num_buckets=None, quadratic_duration=None)[source]
map(fn)[source]

Apply fn to each mini-batch of CutSet before yielding it.

Return type:

StatelessSampler

state_dict()[source]

Stub state_dict method that returns an empty dict - this sampler is stateless.

Return type:

Dict

load_state_dict(state_dict)[source]

Stub load_state_dict method that does nothing - this sampler is stateless.

Return type:

None

get_report()[source]

Returns a string describing the statistics of the sampling process so far.

Return type:

str

class lhotse.dataset.sampling.ZipSampler(*samplers, merge_batches=True)[source]

ZipSampler takes several samplers as input and concatenates their sampled mini-batch cuts together into a single CutSet, or returns a tuple of the mini-batch CutSets. It is helpful for ensuring that each batch consists of some proportion of cuts coming from different sources.

The input samplers do not have to provide the same number of batches – when any of the samplers becomes depleted, the iteration will stop (like with Python’s zip() function).

Example:

>>> sampler = ZipSampler(
...     SimpleCutSampler(cuts_corpusA, max_duration=250, shuffle=True),
...     SimpleCutSampler(cuts_corpusB, max_duration=100, shuffle=True),
... )
>>> for batch in sampler:
...     pass  # profit
__init__(*samplers, merge_batches=True)[source]

ZipSampler’s constructor.

Parameters:
  • samplers (CutSampler) – The list of samplers from which we sample batches together.

  • merge_batches (bool) – Should we merge the batches from each sampler into a single CutSet, or return a tuple of CutSets. Setting this to False makes ZipSampler behave more like Python’s zip function.
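
A sketch of the tuple-returning variant with merge_batches=False; each iteration step yields one CutSet per input sampler:

>>> sampler = ZipSampler(
...     SimpleCutSampler(cuts_corpusA, max_duration=250, shuffle=True),
...     SimpleCutSampler(cuts_corpusB, max_duration=100, shuffle=True),
...     merge_batches=False,
... )
>>> for batch_a, batch_b in sampler:
...     pass  # each item is a separate CutSet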

property remaining_duration: float | None

Remaining duration of data left in the sampler (may be inexact due to float arithmetic).

property remaining_cuts: int | None

Remaining number of cuts in the sampler. Not available when the CutSet is read in lazy mode (returns None).

property num_cuts: int | None

Total number of cuts in the sampler. Not available when the CutSet is read in lazy mode (returns None).

allow_iter_to_reset_state()[source]

Enables re-setting to the start of an epoch when iter() is called. This is only needed in one specific scenario: when we restored previous sampler state via sampler.load_state_dict() but want to discard the progress in the current epoch and start from the beginning.

state_dict()[source]

Return the current state of the sampler in a state_dict. Together with load_state_dict(), this can be used to restore the training loop’s state to the one stored in the state_dict.

Return type:

Dict[str, Any]

load_state_dict(state_dict)[source]

Restore the state of the sampler that is described in a state_dict. This will result in the sampler yielding batches from where the previous training left off.

Return type:

None

Caution

The samplers are expected to be initialized with the same CutSets, but this is not explicitly checked anywhere.

Caution

The input state_dict is being mutated: we remove each consumed key, and expect it to be empty at the end of loading. If you don’t want this behavior, pass a copy into this function (e.g., made with copy.deepcopy).

Note

For implementers of sub-classes of CutSampler: the flag self._just_restored_state has to be handled in __iter__ so that it avoids resetting the just-restored state (this should happen only once).

set_epoch(epoch)[source]

Sets the epoch for this sampler. When shuffle=True, this ensures all replicas use a different random ordering for each epoch. Otherwise, the next iteration of this sampler will yield the same ordering.

Parameters:

epoch (int) – Epoch number.

Return type:

None

filter(predicate)[source]

Add a constraint that individual cuts have to satisfy to be considered for sampling.

Can be useful when handling large, lazy manifests where it is not feasible to pre-filter them before instantiating the sampler.

Return type:

None

Example:
>>> cuts = CutSet(...)
>>> sampler = SimpleCutSampler(cuts, max_duration=100.0)
>>> # Retain only the cuts that have at least 1s and at most 20s duration.
>>> sampler.filter(lambda cut: 1.0 <= cut.duration <= 20.0)
property diagnostics: SamplingDiagnostics

Info on how many cuts / batches were returned or rejected during iteration.

This property can be overridden by child classes e.g. to merge diagnostics of composite samplers.

get_report()[source]

Returns a string describing the statistics of the sampling process so far.

Return type:

str

lhotse.dataset.sampling.find_pessimistic_batches(sampler, batch_tuple_index=0)[source]

Function for finding ‘pessimistic’ batches, i.e. batches that have the highest potential to blow up the GPU memory during training. We will fully iterate the sampler and record the most risky batches under several criteria:

  • single longest cut

  • single longest supervision

  • largest batch cuts duration

  • largest batch supervisions duration

  • max num cuts

  • max num supervisions

Example of how this function can be used with a PyTorch model and a K2SpeechRecognitionDataset:

sampler = SimpleCutSampler(cuts, max_duration=300)
dataset = K2SpeechRecognitionDataset()
batches, scores = find_pessimistic_batches(sampler)
for reason, cuts in batches.items():
    try:
        # Index the dataset with the sampled batch to trigger collation.
        batch = dataset[cuts]
        outputs = model(batch)
        loss = loss_fn(outputs)
        loss.backward()
    except Exception:
        print(f"Exception caught when evaluating pessimistic batch for: {reason}={scores[reason]}")
        raise
Parameters:
  • sampler (CutSampler) – An instance of a Lhotse CutSampler.

  • batch_tuple_index (int) – Applicable to samplers that return tuples of CutSet. Indicates which position in the tuple we should look up for the CutSet.

Return type:

Tuple[Dict[str, CutSet], Dict[str, float]]

Returns:

A tuple of dicts: the first with batches (as CutSets) and the other with criteria values, i.e.: ({"<criterion>": <CutSet>, ...}, {"<criterion>": <value>, ...})

lhotse.dataset.sampling.report_padding_ratio_estimate(sampler, n_samples=1000)[source]

Returns a human-readable string message with diagnostics about the estimated amount of padding in mini-batches. Assumes that padding corresponds to segments without any supervision within cuts.

Return type:

str

Input strategies’ list

class lhotse.dataset.input_strategies.BatchIO(num_workers=0, executor_type=<class 'concurrent.futures.thread.ThreadPoolExecutor'>)[source]

Converts a CutSet into a collated batch of audio representations. These representations can be e.g. audio samples or features. They might also be single- or multi-channel.

All InputStrategies support the executor_type parameter in the constructor. It allows passing a ThreadPoolExecutor or a ProcessPoolExecutor type to parallelize reading audio/features from wherever they are stored. Note that this approach is incompatible with specifying num_workers in torch.utils.data.DataLoader, but in some instances may be faster.
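
For illustration, a sketch of enabling parallel reads in a concrete strategy (here PrecomputedFeatures, described below; a thread pool is chosen since it is safe to combine with DataLoader):

>>> from concurrent.futures import ThreadPoolExecutor
>>> strategy = PrecomputedFeatures(num_workers=4, executor_type=ThreadPoolExecutor)
>>> feats, feat_lens = strategy(cuts)  # reads feature matrices with 4 threads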

Note

This is a base class that only defines the interface.

__call__(cuts)[source]

Returns a tensor with collated input signals, and a tensor of length of each signal before padding.

Return type:

Tuple[Tensor, IntTensor]

__init__(num_workers=0, executor_type=<class 'concurrent.futures.thread.ThreadPoolExecutor'>)[source]
supervision_intervals(cuts)[source]

Returns a dict that specifies the start and end bounds for each supervision, as a 1-D int tensor.

Depending on the strategy, the dict should look like:

{
    "sequence_idx": tensor(shape=(S,)),
    "start_frame": tensor(shape=(S,)),
    "num_frames": tensor(shape=(S,)),
}

or

{
    "sequence_idx": tensor(shape=(S,)),
    "start_sample": tensor(shape=(S,)),
    "num_samples": tensor(shape=(S,))
}

Where S is the total number of supervisions encountered in the CutSet. Note that S might be different than the number of cuts (B). sequence_idx means the index of the corresponding feature matrix (or cut) in a batch.

Return type:

Dict[str, Tensor]

supervision_masks(cuts)[source]

Returns a collated batch of masks, marking the supervised regions in cuts. They are zero-padded to the longest cut.

Depending on the strategy implementation, it is expected to be a tensor of shape (B, NF) or (B, NS), where B denotes the number of cuts, NF the number of frames and NS the total number of samples. NF and NS are determined by the longest cut in a batch.

Return type:

Tensor

class lhotse.dataset.input_strategies.PrecomputedFeatures(num_workers=0, executor_type=<class 'concurrent.futures.thread.ThreadPoolExecutor'>)[source]

InputStrategy that reads pre-computed features, whose manifests are attached to cuts, from disk.

It automatically pads the feature matrices so that every example has the same number of frames as the longest cut in a mini-batch. This is needed to put all examples into a single tensor. The padding value is a low log-energy, around log(1e-10).
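
A minimal usage sketch, assuming cuts is a CutSet with precomputed features attached:

>>> strategy = PrecomputedFeatures()
>>> feats, feat_lens = strategy(cuts)  # feats: (B, T, F); feat_lens: num_frames per cut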

__call__(cuts)[source]

Reads the pre-computed features from disk/other storage. The returned shape is (B, T, F) => (batch_size, num_frames, num_features).

Return type:

Tuple[Tensor, Tensor]

Returns:

a tensor with collated features, and a tensor of num_frames of each cut before padding.

supervision_intervals(cuts)[source]

Returns a dict that specifies the start and end bounds for each supervision, as a 1-D int tensor, in terms of frames:

{
    "sequence_idx": tensor(shape=(S,)),
    "start_frame": tensor(shape=(S,)),
    "num_frames": tensor(shape=(S,))
}

Where S is the total number of supervisions encountered in the CutSet. Note that S might be different than the number of cuts (B). sequence_idx means the index of the corresponding feature matrix (or cut) in a batch.

Return type:

Dict[str, Tensor]

supervision_masks(cuts, use_alignment_if_exists=None)[source]

Returns the mask for supervised frames.

Parameters:

use_alignment_if_exists (Optional[str]) – optional str, key for the alignment type to use for generating the mask. If it does not exist, falls back on the supervision time spans.

Return type:

Tensor

class lhotse.dataset.input_strategies.AudioSamples(num_workers=0, fault_tolerant=False, executor_type=<class 'concurrent.futures.thread.ThreadPoolExecutor'>)[source]

InputStrategy that reads single-channel recordings, whose manifests are attached to cuts, from disk (or other audio source).

It automatically zero-pads the recordings so that every example has the same number of audio samples as the longest cut in a mini-batch. This is needed to put all examples into a single tensor.
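
A minimal usage sketch; with fault_tolerant=True the call returns a third item (see __init__ below):

>>> strategy = AudioSamples(fault_tolerant=True)
>>> audio, audio_lens, ok_cuts = strategy(cuts)  # ok_cuts may be a subset of cuts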

__call__(cuts, recording_field=None)[source]

Reads the audio samples from recordings on disk/other storage. The returned shape is (B, T) => (batch_size, num_samples).

Return type:

Union[Tuple[Tensor, Tensor], Tuple[Tensor, Tensor, CutSet]]

Returns:

a tensor with collated audio samples, and a tensor of num_samples of each cut before padding.

Parameters:

recording_field (Optional[str]) – when specified, we will try to load recordings from a custom field with this name (i.e., cut.load_<recording_field>() instead of default cut.load_audio()).

__init__(num_workers=0, fault_tolerant=False, executor_type=<class 'concurrent.futures.thread.ThreadPoolExecutor'>)[source]

AudioSamples constructor.

Parameters:
  • num_workers (int) – when larger than 0, we will spawn an executor (of type specified by executor_type) to read the audio data in parallel. Thread executor can be used with PyTorch’s DataLoader, whereas Process executor would fail (but could be faster for other applications).

  • fault_tolerant (bool) – when True, the cuts for which audio loading failed will be skipped. It will make __call__ return an additional item, which is the CutSet for which we successfully read the audio. It may be a subset of the input CutSet.

  • executor_type (Type[TypeVar(ExecutorType, bound= Executor)]) – the type of executor used for parallel audio reads (only relevant when num_workers>0).

supervision_intervals(cuts)[source]

Returns a dict that specifies the start and end bounds for each supervision, as a 1-D int tensor, in terms of samples:

{
    "sequence_idx": tensor(shape=(S,)),
    "start_sample": tensor(shape=(S,)),
    "num_samples": tensor(shape=(S,))
}

Where S is the total number of supervisions encountered in the CutSet. Note that S might be different than the number of cuts (B). sequence_idx means the index of the corresponding feature matrix (or cut) in a batch.

Return type:

Dict[str, Tensor]

supervision_masks(cuts, use_alignment_if_exists=None)[source]

Returns the mask for supervised samples.

Parameters:

use_alignment_if_exists (Optional[str]) – optional str, key for the alignment type to use for generating the mask. If it does not exist, falls back on the supervision time spans.

Return type:

Tensor

class lhotse.dataset.input_strategies.OnTheFlyFeatures(extractor, wave_transforms=None, num_workers=0, use_batch_extract=True, fault_tolerant=False, return_audio=False, executor_type=<class 'concurrent.futures.thread.ThreadPoolExecutor'>)[source]

InputStrategy that reads single-channel recordings, whose manifests are attached to cuts, from disk (or other audio source). Then, it uses a FeatureExtractor to compute their features on-the-fly.

It automatically pads the feature matrices so that every example has the same number of frames as the longest cut in a mini-batch. This is needed to put all examples into a single tensor. The padding value is a low log-energy, around log(1e-10).
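
A minimal usage sketch with Lhotse’s Fbank extractor (any FeatureExtractor should work the same way):

>>> from lhotse import Fbank
>>> strategy = OnTheFlyFeatures(extractor=Fbank())
>>> feats, feat_lens = strategy(cuts)  # audio is read and features computed on-the-fly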

__call__(cuts, recording_field=None)[source]

Reads the audio samples from recordings on disk/other storage and computes their features. The returned shape is (B, T, F) => (batch_size, num_frames, num_features).

Parameters:

recording_field (Optional[str]) – when specified, we will try to load recordings from a custom field with this name (i.e., cut.load_<recording_field>() instead of default cut.load_audio()).

Return type:

Union[Tuple[Tensor, Tensor], Tuple[Tensor, Tensor, CutSet]]

Returns:

a tuple of objects: (feats, feat_lens, [audios, audio_lens], [cuts]). Tensors audios and audio_lens are returned when return_audio=True. CutSet cuts is returned when fault_tolerant=True.

__init__(extractor, wave_transforms=None, num_workers=0, use_batch_extract=True, fault_tolerant=False, return_audio=False, executor_type=<class 'concurrent.futures.thread.ThreadPoolExecutor'>)[source]

OnTheFlyFeatures’ constructor.

Parameters:
  • extractor (FeatureExtractor) – the feature extractor used on-the-fly (individually on each waveform).

  • wave_transforms (Optional[List[Callable[[Tensor], Tensor]]]) – an optional list of transforms applied on the batch of audio waveforms collated into a single tensor, right before the feature extraction.

  • num_workers (int) – when larger than 0, we will spawn an executor (of type specified by executor_type) to read the audio data in parallel. Thread executor can be used with PyTorch’s DataLoader, whereas Process executor would fail (but could be faster for other applications).

  • use_batch_extract (bool) – when True, we will call extract_batch() to compute the features as it is possibly faster. It has a restriction that all cuts must have the same sampling rate. If that is not the case, set this to False.

  • fault_tolerant (bool) – when True, the cuts for which audio loading failed will be skipped. It will make __call__ return an additional item, which is the CutSet for which we successfully read the audio. It may be a subset of the input CutSet.

  • return_audio (bool) – When True, calling this object will additionally return collated audio tensor and audio lengths tensor.

  • executor_type (Type[TypeVar(ExecutorType, bound= Executor)]) – the type of executor used for parallel audio reads (only relevant when num_workers>0).

supervision_intervals(cuts)[source]

Returns a dict that specifies the start and end bounds for each supervision, as a 1-D int tensor, in terms of frames:

{
    "sequence_idx": tensor(shape=(S,)),
    "start_frame": tensor(shape=(S,)),
    "num_frames": tensor(shape=(S,))
}

Where S is the total number of supervisions encountered in the CutSet. Note that S might be different than the number of cuts (B). sequence_idx means the index of the corresponding feature matrix (or cut) in a batch.

Return type:

Dict[str, Tensor]

supervision_masks(cuts, use_alignment_if_exists=None)[source]

Returns the mask for supervised samples.

Parameters:

use_alignment_if_exists (Optional[str]) – optional str, key for the alignment type to use for generating the mask. If it does not exist, falls back on the supervision time spans.

Return type:

Tensor

Augmentation - transforms on cuts

Some transforms, in order for us to have accurate information about the start and end times of the signal and its supervisions, have to be performed on cuts (or CutSets).

class lhotse.dataset.cut_transforms.CutConcatenate(gap=1.0, duration_factor=1.0, max_duration=None)[source]

A transform on batch of cuts (CutSet) that concatenates the cuts to minimize the total amount of padding; e.g. instead of creating a batch with 40 examples, we will merge some of the examples together adding some silence between them to avoid a large number of padding frames that waste the computation.
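
A minimal sketch of applying the transform to a mini-batch CutSet before feeding it to an input strategy (the values shown are the defaults):

>>> transform = CutConcatenate(gap=1.0, duration_factor=1.0)
>>> cuts = transform(cuts)  # fewer, longer cuts with ~1s of silence in between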

__init__(gap=1.0, duration_factor=1.0, max_duration=None)[source]

CutConcatenate’s constructor.

Parameters:
  • gap (float) – The duration of silence in seconds that is inserted between the cuts; its goal is to let the model “know” that there are separate utterances in a single example.

  • duration_factor (float) – Determines the maximum duration of the concatenated cuts; by default it’s 1, setting the limit at the duration of the longest cut in the batch.

  • max_duration (Optional[float]) – If a value is given (in seconds), the maximum duration of concatenated cuts is fixed to the value while duration_factor is ignored.

class lhotse.dataset.cut_transforms.CutMix(cuts, snr=(10, 20), p=0.5, pad_to_longest=True, preserve_id=False, seed=42, random_mix_offset=False)[source]

A transform for batches of cuts (CutSet’s) that stochastically performs noise augmentation with a constant or varying SNR.

__init__(cuts, snr=(10, 20), p=0.5, pad_to_longest=True, preserve_id=False, seed=42, random_mix_offset=False)[source]

CutMix’s constructor.

Parameters:
  • cuts (CutSet) – a CutSet containing augmentation data, e.g. noise, music, babble.

  • snr (Union[float, Tuple[float, float], None]) – either a float, a pair (range) of floats, or None. It determines the SNR of the speech signal vs the noise signal that’s mixed into it. When a range is specified, we will uniformly sample SNR in that range. When it’s None, the noise will be mixed as-is – i.e. without any level adjustment. Note that it’s different from snr=0, which will adjust the noise level so that the SNR is 0.

  • pad_to_longest (bool) – when True, each processed CutSet will be padded with noise to match the duration of the longest Cut in a batch.

  • preserve_id (bool) – When True, preserves the IDs the cuts had before augmentation. Otherwise, new random IDs are generated for the augmented cuts (default).

  • seed (Union[int, Literal['trng', 'randomized'], Random]) – an optional int, “trng”, or “randomized”. Random seed for choosing the cuts to mix and the SNR. If “trng” is provided, we’ll use the secrets module for non-deterministic results on each iteration. You can also directly pass a random.Random instance here.

  • random_mix_offset (bool) – an optional bool. When True and the duration of the cut to be mixed in is longer than the original cut, select a random sub-region from the cut to be mixed in.
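
A minimal usage sketch, assuming a separate manifest with noise cuts (the path is hypothetical):

>>> noise_cuts = CutSet.from_file("data/musan_cuts.jsonl.gz")  # hypothetical path
>>> transform = CutMix(noise_cuts, snr=(10, 20), p=0.5)
>>> cuts = transform(cuts)  # ~half the cuts get noise mixed in at 10-20 dB SNR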

class lhotse.dataset.cut_transforms.ExtraPadding(extra_frames=None, extra_samples=None, extra_seconds=None, pad_feat_value=-23.025850929940457, randomized=False, preserve_id=False, direction='both')[source]

A transform on batch of cuts (CutSet) that adds a number of extra context frames/samples/seconds on both sides of the cut. Exactly one type of duration has to be specified in the constructor.

It is intended mainly for training frame-synchronous ASR models with convolutional layers to avoid using padding inside of the hidden layers, by giving the model larger context in the input. Another useful application is to shift the input by a little, so that the data seen after frame subsampling is a bit different, which makes this a data augmentation technique.

This is best used as the first transform in the transform list for a dataset: it will ensure that each individual cut gets extra context before concatenation, or before it is padded with noise, etc. (see the sketch below).
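
A sketch of the suggested ordering, assuming a dataset class that accepts a list of cut transforms (as e.g. K2SpeechRecognitionDataset does via its cut_transforms argument):

>>> transforms = [
...     ExtraPadding(extra_seconds=0.4, randomized=True),  # runs first, per cut
...     CutConcatenate(),
... ]
>>> dataset = K2SpeechRecognitionDataset(cut_transforms=transforms)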

__init__(extra_frames=None, extra_samples=None, extra_seconds=None, pad_feat_value=-23.025850929940457, randomized=False, preserve_id=False, direction='both')[source]

ExtraPadding’s constructor.

Parameters:
  • extra_frames (Optional[int]) – The total number of frames to add to each cut. We will add half that number on each side of the cut (“both” directions padding).

  • extra_samples (Optional[int]) – The total number of samples to add to each cut. We will add half that number on each side of the cut (“both” directions padding).

  • extra_seconds (Optional[float]) – The total duration in seconds to add to each cut. We will add half that number on each side of the cut (“both” directions padding).

  • pad_feat_value (float) – When padding a cut with precomputed features, what value should be used for padding (the default is a very low log-energy).

  • randomized (bool) – When True, we will sample a value from a uniform distribution of [0, extra_X] for each cut (for samples/frames – sample an int, for duration – sample a float).

  • preserve_id (bool) – When True, preserves the IDs the cuts had before augmentation. Otherwise, new random IDs are generated for the augmented cuts (default).

  • direction (str) – The padding direction.

class lhotse.dataset.cut_transforms.PerturbSpeed(factors, p, randgen=None, preserve_id=False)[source]

A transform on batch of cuts (CutSet) that perturbs the speed of the recordings with a given probability p.

If the effect is applied, then one of the perturbation factors from the constructor’s factors parameter is sampled with uniform probability.
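
A minimal usage sketch (the factor values are just common choices, not defaults):

>>> transform = PerturbSpeed(factors=[0.9, 1.1], p=2 / 3)
>>> cuts = transform(cuts)  # ~2/3 of cuts are sped up or slowed down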

__init__(factors, p, randgen=None, preserve_id=False)[source]
class lhotse.dataset.cut_transforms.PerturbTempo(factors, p, randgen=None, preserve_id=False)[source]

A transform on batch of cuts (CutSet) that perturbs the tempo of the recordings with a given probability p.

If the effect is applied, then one of the perturbation factors from the constructor’s factors parameter is sampled with uniform probability.

__init__(factors, p, randgen=None, preserve_id=False)[source]
class lhotse.dataset.cut_transforms.PerturbVolume(p, scale_low=0.125, scale_high=2.0, randgen=None, preserve_id=False)[source]

A transform on batch of cuts (CutSet) that perturbs the volume of the recordings with a given probability p.

If the effect is applied, the volume scaling factor is sampled uniformly from the range [scale_low, scale_high].

__init__(p, scale_low=0.125, scale_high=2.0, randgen=None, preserve_id=False)[source]
class lhotse.dataset.cut_transforms.ReverbWithImpulseResponse(rir_recordings=None, p=0.5, normalize_output=True, randgen=None, preserve_id=False, early_only=False, rir_channels=[0])[source]

A transform on batch of cuts (CutSet) that convolves each cut with an impulse response with some probability p. The impulse response is chosen randomly from a specified collection of recordings rir_recordings. If no RIRs are specified, we will generate them using a fast random generator (https://arxiv.org/abs/2208.04101). If early_only is set to True, convolution is performed only with the first 50ms of the impulse response.

__init__(rir_recordings=None, p=0.5, normalize_output=True, randgen=None, preserve_id=False, early_only=False, rir_channels=[0])[source]

Augmentation - transforms on signals

These transforms work directly on batches of collated feature matrices (or possibly raw waveforms, if applicable).

class lhotse.dataset.signal_transforms.GlobalMVN(feature_dim)[source]

Apply global mean and variance normalization
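
A minimal sketch of estimating the statistics from data and applying the module to a feature batch (the stats file path is hypothetical):

>>> mvn = GlobalMVN.from_cuts(cuts)  # estimate mean/variance from the cuts' features
>>> normalized = mvn(features)  # features: a (B, T, F) tensor
>>> mvn.to_file("global_mvn_stats.pt")  # hypothetical path; reload with from_file()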

__init__(feature_dim)[source]

Initialize internal Module state, shared by both nn.Module and ScriptModule.

classmethod from_cuts(cuts, max_cuts=None, extractor=None)[source]
Return type:

GlobalMVN

classmethod from_file(stats_file)[source]
Return type:

GlobalMVN

to_file(stats_file)[source]
forward(features, supervision_segments=None)[source]

Define the computation performed at every call.

Should be overridden by all subclasses.

Return type:

Tensor

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

inverse(features)[source]
Return type:

Tensor

class lhotse.dataset.signal_transforms.SpecAugment(time_warp_factor=80, num_feature_masks=2, features_mask_size=27, num_frame_masks=10, frames_mask_size=100, max_frames_mask_fraction=0.15, p=0.9)[source]

SpecAugment performs three augmentations:

  • time warping of the feature matrix

  • masking of ranges of features (frequency bands)

  • masking of ranges of frames (time)

The current implementation works with batches, but processes each example separately in a loop rather than simultaneously to achieve different augmentation parameters for each example.
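
A minimal usage sketch with the default masking configuration (feats is assumed to be a padded (B, T, F) feature batch):

>>> specaug = SpecAugment()
>>> feats = specaug(feats)  # each example gets its own random masks/warp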

__init__(time_warp_factor=80, num_feature_masks=2, features_mask_size=27, num_frame_masks=10, frames_mask_size=100, max_frames_mask_fraction=0.15, p=0.9)[source]

SpecAugment’s constructor.

Parameters:
  • time_warp_factor (Optional[int]) – parameter for the time warping; larger values mean more warping. Set to None, or less than 1, to disable.

  • num_feature_masks (int) – how many feature masks should be applied. Set to 0 to disable.

  • features_mask_size (int) – the width of the feature mask (expressed in the number of masked feature bins). This is the F parameter from the SpecAugment paper.

  • num_frame_masks (int) – the number of masking regions for utterances. Set to 0 to disable.

  • frames_mask_size (int) – the width of the frame (temporal) masks (expressed in the number of masked frames). This is the T parameter from the SpecAugment paper.

  • max_frames_mask_fraction (float) – limits the size of the frame (temporal) mask to this value times the length of the utterance (or supervision segment). This is the parameter denoted by p in the SpecAugment paper.

  • p – the probability of applying this transform. It is different from p in the SpecAugment paper!

forward(features, supervision_segments=None, *args, **kwargs)[source]

Computes SpecAugment for a batch of feature matrices.

Since the batch will usually already be padded, the user can optionally provide a supervision_segments tensor that will be used to apply SpecAugment only to selected areas of the input. The format of this input is described below.

Parameters:
  • features (Tensor) – a batch of feature matrices with shape (B, T, F).

  • supervision_segments (Optional[IntTensor]) – an int tensor of shape (S, 3). S is the number of supervision segments that exist in features – there may be either less or more than the batch size. The second dimension encodes three kinds of information: the sequence index of the corresponding feature matrix in features, the start frame index, and the number of frames for each segment.

Return type:

Tensor

Returns:

an augmented tensor of shape (B, T, F).

state_dict(**kwargs)[source]

Return a dictionary containing references to the whole state of the module.

Both parameters and persistent buffers (e.g. running averages) are included. Keys are corresponding parameter and buffer names. Parameters and buffers set to None are not included.

Return type:

Dict[str, Any]

Note

The returned object is a shallow copy. It contains references to the module’s parameters and buffers.

Warning

Currently state_dict() also accepts positional arguments for destination, prefix and keep_vars in order. However, this is being deprecated and keyword arguments will be enforced in future releases.

Warning

Please avoid the use of argument destination as it is not designed for end-users.

Args:

  destination (dict, optional): If provided, the state of module will be updated into the dict and the same object is returned. Otherwise, an OrderedDict will be created and returned. Default: None.

  prefix (str, optional): a prefix added to parameter and buffer names to compose the keys in state_dict. Default: ''.

  keep_vars (bool, optional): by default the Tensors returned in the state dict are detached from autograd. If it’s set to True, detaching will not be performed. Default: False.

Returns:

  dict: a dictionary containing a whole state of the module

Example:

>>> # xdoctest: +SKIP("undefined vars")
>>> module.state_dict().keys()
['bias', 'weight']
load_state_dict(state_dict)[source]

Copy parameters and buffers from state_dict into this module and its descendants.

If strict is True, then the keys of state_dict must exactly match the keys returned by this module’s state_dict() function.

Warning

If assign is True the optimizer must be created after the call to load_state_dict.

Args:

  state_dict (dict): a dict containing parameters and persistent buffers.

  strict (bool, optional): whether to strictly enforce that the keys in state_dict match the keys returned by this module’s state_dict() function. Default: True

  assign (bool, optional): whether to assign items in the state dictionary to their corresponding keys in the module instead of copying them inplace into the module’s current parameters and buffers. When False, the properties of the tensors in the current module are preserved, while when True, the properties of the Tensors in the state dict are preserved. Default: False

Returns:
NamedTuple with missing_keys and unexpected_keys fields:
  • missing_keys is a list of str containing the missing keys

  • unexpected_keys is a list of str containing the unexpected keys

Note:

If a parameter or buffer is registered as None and its corresponding key exists in state_dict, load_state_dict() will raise a RuntimeError.

class lhotse.dataset.signal_transforms.RandomizedSmoothing(sigma=0.1, sample_sigma=True, p=0.3)[source]

Randomized smoothing - Gaussian noise added to an input waveform, or a batch of waveforms. The summed audio is clipped to [-1.0, 1.0] before returning.

__init__(sigma=0.1, sample_sigma=True, p=0.3)[source]

RandomizedSmoothing’s constructor.

Parameters:
  • sigma (Union[float, Sequence[Tuple[int, float]]]) – standard deviation of the gaussian noise. Either a constant float, or a schedule, i.e. a list of tuples that specify which value to use from which step. For example, [(0, 0.01), (1000, 0.1)] means that from steps 0-999, the sigma value will be 0.01, and from step 1000 onwards, it will be 0.1.

  • sample_sigma (bool) – when False, then sigma is used as the standard deviation in each forward step. When True, the standard deviation is sampled from a uniform distribution of [-sigma, sigma] for each forward step.

  • p (float) – the probability of applying this transform.
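
A minimal sketch using a sigma schedule (the values mirror the example from the sigma description above; audio is assumed to be a batch of waveforms):

>>> smoothing = RandomizedSmoothing(sigma=[(0, 0.01), (1000, 0.1)], p=0.3)
>>> audio = smoothing(audio)  # output stays clipped to [-1.0, 1.0]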

forward(audio, *args, **kwargs)[source]

Define the computation performed at every call.

Should be overridden by all subclasses.

Return type:

Tensor

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class lhotse.dataset.signal_transforms.DereverbWPE(n_fft=512, hop_length=128)[source]

Dereverberation with Weighted Prediction Error (WPE). The implementation and default values are borrowed from nara_wpe package: https://github.com/fgnt/nara_wpe

The method and library are described in the following paper: https://groups.uni-paderborn.de/nt/pubs/2018/ITG_2018_Drude_Paper.pdf

__init__(n_fft=512, hop_length=128)[source]

Initialize internal Module state, shared by both nn.Module and ScriptModule.

forward(audio, *args, **kwargs)[source]

Expects audio to be 2D or 3D tensor. 2D means a batch of single-channel audio, shape (B, T). 3D means a batch of multi-channel audio, shape (B, D, T). B => batch size; D => number of channels; T => number of audio samples.

Return type:

Tensor
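
A minimal usage sketch for a batch of single-channel waveforms:

>>> wpe = DereverbWPE()
>>> dereverbed = wpe(audio)  # audio: (B, T); use (B, D, T) for multi-channel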

Collation utilities for building custom Datasets

class lhotse.dataset.collation.TokenCollater(cuts, add_eos=True, add_bos=True, pad_symbol='<pad>', bos_symbol='<bos>', eos_symbol='<eos>', unk_symbol='<unk>')[source]

Collate list of tokens

Map sentences to integers. Sentences are padded to equal length. Beginning and end-of-sequence symbols can be added. Call .inverse(tokens_batch, tokens_lens) to reconstruct batch as string sentences.

Example:
>>> token_collater = TokenCollater(cuts)
>>> tokens_batch, tokens_lens = token_collater(cuts.subset(first=32))
>>> original_sentences = token_collater.inverse(tokens_batch, tokens_lens)
Returns:

  tokens_batch: IntTensor of shape (B, L)
    B: batch dimension, number of input sentences
    L: length of the longest sentence

  tokens_lens: IntTensor of shape (B,)
    Length of each sentence after adding <eos> and <bos> but before padding.

__init__(cuts, add_eos=True, add_bos=True, pad_symbol='<pad>', bos_symbol='<bos>', eos_symbol='<eos>', unk_symbol='<unk>')[source]
inverse(tokens_batch, tokens_lens)[source]
Return type:

List[str]

lhotse.dataset.collation.collate_features(cuts, pad_direction='right', executor=None)[source]

Load features for all the cuts and return them as a batch in a torch tensor. The output shape is (batch, time, features). The cuts will be padded with silence if necessary.

Parameters:
  • cuts (CutSet) – a CutSet used to load the features.

  • pad_direction (str) – where to apply the padding (right, left, or both).

  • executor (Optional[Executor]) – an instance of ThreadPoolExecutor or ProcessPoolExecutor; when provided, we will use it to read the features concurrently.

Return type:

Tuple[Tensor, Tensor]

Returns:

a tuple of tensors (features, features_lens).

lhotse.dataset.collation.collate_audio(cuts, pad_direction='right', executor=None, fault_tolerant=False, recording_field=None)[source]

Load audio samples for all the cuts and return them as a batch in a torch tensor. The output shape is (batch, time). The cuts will be padded with silence if necessary.

Parameters:
  • cuts (CutSet) – a CutSet used to load the audio samples.

  • pad_direction (str) – where to apply the padding (right, left, or both).

  • executor (Optional[Executor]) – an instance of ThreadPoolExecutor or ProcessPoolExecutor; when provided, we will use it to read audio concurrently.

  • fault_tolerant (bool) – when True, the cuts for which audio loading failed will be skipped. Setting this parameter will cause the function to return a 3-tuple, where the third element is a CutSet for which the audio data were successfully read.

  • recording_field (Optional[str]) – when specified, we will try to load recordings from a custom field with this name (i.e., cut.load_<recording_field>() instead of default cut.load_audio()).

Return type:

Union[Tuple[Tensor, Tensor], Tuple[Tensor, Tensor, CutSet]]

Returns:

a tuple of tensors (audio, audio_lens), or (audio, audio_lens, cuts).
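
A usage sketch, assuming an existing CutSet; note that fault_tolerant changes the arity of the returned tuple:

>>> from lhotse.dataset.collation import collate_audio
>>> audio, audio_lens = collate_audio(cuts)
>>> audio, audio_lens, ok_cuts = collate_audio(cuts, fault_tolerant=True)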

lhotse.dataset.collation.collate_multi_channel_audio(cuts, pad_direction='right', executor=None, fault_tolerant=False, recording_field=None)

Load multi-channel audio samples for all the cuts and return them as a batch in a torch tensor. The output shape is (batch, channel, time). The cuts will be padded with silence if necessary.

Parameters:
  • cuts (CutSet) – a CutSet used to load the audio samples.

  • pad_direction (str) – where to apply the padding (right, left, or both).

  • executor (Optional[Executor]) – an instance of ThreadPoolExecutor or ProcessPoolExecutor; when provided, we will use it to read audio concurrently.

  • fault_tolerant (bool) – when True, the cuts for which audio loading failed will be skipped. Setting this parameter will cause the function to return a 3-tuple, where the third element is a CutSet for which the audio data were successfully read.

  • recording_field (Optional[str]) – when specified, we will try to load recordings from a custom field with this name (i.e., cut.load_<recording_field>() instead of default cut.load_audio()).

Return type:

Union[Tuple[Tensor, Tensor], Tuple[Tensor, Tensor, CutSet]]

Returns:

a tuple of tensors (audio, audio_lens), or (audio, audio_lens, cuts).

lhotse.dataset.collation.collate_video(cuts, pad_direction='right', executor=None, fault_tolerant=False)[source]

Load video and audio for all cuts and return them as a batch in torch tensors. The output video shape is (batch, time, channel, height, width). The output audio shape is (batch, channel, time). The cuts will be padded with silence if necessary.

Note

We expect each video to contain audio and the same number of audio channels. We may support padding missing channels at a later time.

Parameters:
  • cuts (CutSet) – a CutSet used to load the audio samples.

  • pad_direction (str) – where to apply the padding (right, left, or both).

  • executor (Optional[Executor]) – an instance of ThreadPoolExecutor or ProcessPoolExecutor; when provided, we will use it to read video concurrently.

  • fault_tolerant (bool) – when True, the cuts for which video/audio loading failed will be skipped. Setting this parameter will cause the function to return a 5-tuple, where the fifth element is a CutSet for which the audio data were successfully read.

Returns:

a tuple of tensors (video, video_lens, audio, audio_lens), or (video, video_lens, audio, audio_lens, cuts).

lhotse.dataset.collation.collate_custom_field(cuts, field, pad_value=None, pad_direction='right')[source]

Load custom arrays for all the cuts and return them as a batch in a torch tensor. The output shapes are:

  • (batch, d0, d1, d2, ...) for lhotse.array.Array of shape (d0, d1, d2, ...). Note: all arrays have to be of the same shape, as we expect these represent fixed-size embeddings.

  • (batch, d0, pad_dt, d1, ...) for lhotse.array.TemporalArray of shape (d0, dt, d1, ...), where dt indicates the temporal dimension (variable-sized), and pad_dt indicates the temporal dimension after padding (equal-sized for all cuts). We expect these represent temporal data, such as alignments, posteriors, features, etc.

  • (batch,) for anything else, such as int or float: we will simply stack them into a list and tensorize it.

Note

This function disregards the frame_shift attribute of lhotse.array.TemporalArray when padding; it simply pads all the arrays to the longest one found in the mini-batch. Because of that, the function will work correctly even if the user supplied inconsistent meta-data.

Note

Temporal arrays of integer type that are smaller than torch.int64 will be automatically promoted to torch.int64.

Parameters:
  • cuts (CutSet) – a CutSet used to load the features.

  • field (str) – name of the custom field to be retrieved.

  • pad_value (Union[None, int, float]) – value to be used for padding the temporal arrays. Ignored for non-temporal array and non-array attributes.

  • pad_direction (str) – where to apply the padding (right, left, or both).

Return type:

Union[Tensor, Tuple[Tensor, Tensor]]

Returns:

a collated data tensor, or a tuple of tensors (collated_data, sequence_lens).

lhotse.dataset.collation.collate_multi_channel_features(cuts)[source]

Load features for all the cuts and return them as a batch in a torch tensor. The cuts have to be of type MixedCut and their tracks will be interpreted as individual channels. The output shape is (batch, channel, time, features). The cuts will be padded with silence if necessary.

Return type:

Tensor

lhotse.dataset.collation.collate_vectors(tensors, padding_value=-100, matching_shapes=False)[source]

Convert an iterable of 1-D tensors (of possibly various lengths) into a single stacked tensor.

Parameters:
  • tensors (Iterable[Union[Tensor, ndarray]]) – an iterable of 1-D tensors.

  • padding_value (Union[int, float]) – the padding value inserted to make all tensors have the same length.

  • matching_shapes (bool) – when True, will fail when input tensors have different shapes.

Return type:

Tensor

Returns:

a tensor with shape (B, L) where B is the number of input tensors and L is the number of items in the longest tensor.
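
A minimal sketch of padding two variable-length 1-D tensors into a (2, 3) batch:

>>> import torch
>>> from lhotse.dataset.collation import collate_vectors
>>> batch = collate_vectors([torch.tensor([1, 2, 3]), torch.tensor([4])], padding_value=-100)
>>> batch
tensor([[   1,    2,    3],
        [   4, -100, -100]])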

lhotse.dataset.collation.collate_matrices(tensors, padding_value=0, matching_shapes=False)[source]

Convert an iterable of 2-D tensors (of possibly various first dimension, but consistent second dimension) into a single stacked tensor.

Parameters:
  • tensors (Iterable[Union[Tensor, ndarray]]) – an iterable of 2-D tensors.

  • padding_value (Union[int, float]) – the padding value inserted to make all tensors have the same length.

  • matching_shapes (bool) – when True, will fail when input tensors have different shapes.

Return type:

Tensor

Returns:

a tensor with shape (B, L, F) where B is the number of input tensors, L is the largest found shape[0], and F is equal to shape[1].

lhotse.dataset.collation.maybe_pad(cuts, duration=None, num_frames=None, num_samples=None, direction='right', preserve_id=False)[source]

Check if all cuts’ durations are equal; if not, pad them to match the longest cut.

Return type:

CutSet

lhotse.dataset.collation.read_audio_from_cuts(cuts, executor=None, suppress_errors=False, recording_field=None)[source]

Loads audio data from an iterable of cuts.

Parameters:
  • cuts (Iterable[Cut]) – a CutSet or iterable of cuts.

  • executor (Optional[Executor]) – optional Executor (e.g., ThreadPoolExecutor or ProcessPoolExecutor) to perform the audio reads in parallel.

  • suppress_errors (bool) – when set to True, will enable fault-tolerant data reads; we will skip the cuts and audio data for the instances that failed (and emit a warning). When False (default), the errors will not be suppressed.

  • recording_field (Optional[str]) – when specified, we will try to load recordings from a custom field with this name (i.e., cut.load_<recording_field>() instead of default cut.load_audio()).

Return type:

Tuple[List[Tensor], CutSet]

Returns:

a tuple of two items: a list of audio tensors (with different shapes), and a CutSet with the cuts for which we read the data successfully.

lhotse.dataset.collation.read_video_from_cuts(cuts, executor=None, suppress_errors=False)[source]

Loads video and audio data from an iterable of cuts.

Parameters:
  • cuts (Iterable[Cut]) – a CutSet or iterable of cuts.

  • executor (Optional[Executor]) – optional Executor (e.g., ThreadPoolExecutor or ProcessPoolExecutor) to perform the audio reads in parallel.

  • suppress_errors (bool) – when set to True, will enable fault-tolerant data reads; we will skip the cuts and audio data for the instances that failed (and emit a warning). When False (default), the errors will not be suppressed.

Return type:

Tuple[List[Tensor], List[Tensor], CutSet]

Returns:

a tuple of three items: a list of video tensors, a list of audio tensors, and a CutSet with the cuts for which we read the data successfully.

lhotse.dataset.collation.read_features_from_cuts(cuts, executor=None)[source]
Return type:

List[Tensor]

Dataloading seeding utilities

lhotse.dataset.dataloading.make_worker_init_fn(rank=None, world_size=None, set_different_node_and_worker_seeds=True, seed=42)[source]

Calling this function creates a worker_init_fn suitable to pass to PyTorch’s DataLoader.

It helps with two issues:

  • sets the random seeds differently for each worker and node, which helps with avoiding duplication in randomized data augmentation techniques.

  • sets environment variables that help WebDataset detect it’s inside multi-GPU (DDP) training, so that it correctly de-duplicates the data across nodes.

Return type:

Optional[Callable[[int], None]]
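
A minimal sketch of wiring it into a DataLoader, assuming a dataset and sampler defined as in the earlier examples (rank and world_size may also be passed explicitly instead of being inferred):

>>> from torch.utils.data import DataLoader
>>> from lhotse.dataset.dataloading import make_worker_init_fn
>>> worker_init = make_worker_init_fn(seed=42)
>>> dloader = DataLoader(dataset, sampler=sampler, batch_size=None,
...                      num_workers=4, worker_init_fn=worker_init)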

lhotse.dataset.dataloading.worker_init_fn(worker_id, rank=None, world_size=None, set_different_node_and_worker_seeds=True, seed=42)[source]

Function created by make_worker_init_fn(), refer to its documentation for details.

Return type:

None

lhotse.dataset.dataloading.resolve_seed(seed)[source]

Resolves the special values of random seed supported in Lhotse.

If it’s an integer, we’ll just return it.

If it’s “trng”, we’ll use the secrets module to generate a random seed using a true RNG (to the extent supported by the OS).

If it’s “randomized”, we’ll check whether we’re in a dataloading worker of torch.utils.data.DataLoader. If we are, we expect that the DataLoader was given the result of make_worker_init_fn() as its worker_init_fn argument, in which case we’ll return a special seed exclusive to that worker. If we are not in a dataloading worker (or num_workers was set to 0), we’ll return Python’s random module global seed.

Return type:

int
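
A minimal sketch of seeding a per-worker RNG inside a dataset or sampler:

>>> import random
>>> from lhotse.dataset.dataloading import resolve_seed
>>> rng = random.Random(resolve_seed("randomized"))  # unique per dataloader worker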

lhotse.dataset.dataloading.get_world_size()[source]

Source: https://github.com/danpovey/icefall/blob/74bf02bba6016c1eb37858a4e0e8a40f7d302bdb/icefall/dist.py#L56

Return type:

int

lhotse.dataset.dataloading.get_rank()[source]

Source: https://github.com/danpovey/icefall/blob/74bf02bba6016c1eb37858a4e0e8a40f7d302bdb/icefall/dist.py#L56

Return type:

int