Source code for lhotse.dataset.collation

import warnings
from concurrent.futures import Executor
from functools import partial
from itertools import repeat
from typing import Iterable, List, Optional, Tuple, Union

import numpy as np
import torch
from torch.nn import CrossEntropyLoss

from lhotse import CutSet, Recording
from lhotse.audio import suppress_audio_loading_errors
from lhotse.audio.utils import suppress_video_loading_errors
from lhotse.cut import Cut, MixedCut
from lhotse.utils import DEFAULT_PADDING_VALUE, compute_num_samples


class TokenCollater:
    """Collate list of tokens

    Map sentences to integers. Sentences are padded to equal length.
    Beginning and end-of-sequence symbols can be added.
    Characters that are not part of the vocabulary are mapped to ``unk_symbol``.
    Call .inverse(tokens_batch, tokens_lens) to reconstruct batch as string sentences.

    Example:
        >>> token_collater = TokenCollater(cuts)
        >>> tokens_batch, tokens_lens = token_collater(cuts.subset(first=32))
        >>> original_sentences = token_collater.inverse(tokens_batch, tokens_lens)

    Returns:
        tokens_batch: int64 tensor of shape (B, L)
            B: batch dimension, number of input sentences
            L: length of the longest sentence, plus one slot each for
               <bos> and <eos> when they are enabled
        tokens_lens: IntTensor of shape (B,)
            Length of each sentence after adding <eos> and <bos>
            but before padding.
    """

    def __init__(
        self,
        cuts: CutSet,
        add_eos: bool = True,
        add_bos: bool = True,
        pad_symbol: str = "<pad>",
        bos_symbol: str = "<bos>",
        eos_symbol: str = "<eos>",
        unk_symbol: str = "<unk>",
    ):
        """
        Build the character vocabulary.

        :param cuts: cuts whose first supervision's text defines the character set.
        :param add_eos: append ``eos_symbol`` to every sentence.
        :param add_bos: prepend ``bos_symbol`` to every sentence.
        :param pad_symbol: symbol used for padding (always index 0).
        :param bos_symbol: beginning-of-sentence symbol.
        :param eos_symbol: end-of-sentence symbol.
        :param unk_symbol: symbol for characters outside the vocabulary (always index 1).
        """
        self.pad_symbol = pad_symbol
        self.bos_symbol = bos_symbol
        self.eos_symbol = eos_symbol
        self.unk_symbol = unk_symbol

        self.add_eos = add_eos
        self.add_bos = add_bos

        # Only the first supervision of each cut contributes characters; this is
        # kept as-is so that token indices stay stable for existing users.
        tokens = {char for cut in cuts for char in cut.supervisions[0].text}
        tokens_unique = (
            [pad_symbol, unk_symbol]
            + ([bos_symbol] if add_bos else [])
            + ([eos_symbol] if add_eos else [])
            + sorted(tokens)
        )

        self.token2idx = {token: idx for idx, token in enumerate(tokens_unique)}
        self.idx2token = list(tokens_unique)

    def __call__(self, cuts: CutSet) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Convert the text of ``cuts`` into a padded batch of token indices.

        All supervisions of a cut are joined with a single space.
        Characters missing from the vocabulary (e.g. characters that only occur in
        non-first supervisions, or the joining space) are mapped to ``unk_symbol``
        instead of raising ``KeyError``.

        :return: ``(tokens_batch, tokens_lens)`` as described in the class docstring.
        """
        token_sequences = [
            " ".join(supervision.text for supervision in cut.supervisions)
            for cut in cuts
        ]
        max_len = len(max(token_sequences, key=len))

        seqs = [
            ([self.bos_symbol] if self.add_bos else [])
            + list(seq)
            + ([self.eos_symbol] if self.add_eos else [])
            + [self.pad_symbol] * (max_len - len(seq))
            for seq in token_sequences
        ]

        unk_idx = self.token2idx[self.unk_symbol]
        tokens_batch = torch.from_numpy(
            np.array(
                [[self.token2idx.get(token, unk_idx) for token in seq] for seq in seqs],
                dtype=np.int64,
            )
        )

        tokens_lens = torch.IntTensor(
            [
                len(seq) + int(self.add_eos) + int(self.add_bos)
                for seq in token_sequences
            ]
        )

        return tokens_batch, tokens_lens

    def inverse(
        self, tokens_batch: torch.LongTensor, tokens_lens: torch.IntTensor
    ) -> List[str]:
        """
        Reconstruct sentences from the output of :meth:`__call__`.

        The <bos>/<eos> symbols (when enabled) and the padding are stripped using
        ``tokens_lens``; any <unk> indices are rendered as ``unk_symbol``.
        """
        start = 1 if self.add_bos else 0
        sentences = [
            "".join(
                [
                    self.idx2token[idx]
                    for idx in tokens_list[start : end - int(self.add_eos)]
                ]
            )
            for tokens_list, end in zip(tokens_batch, tokens_lens)
        ]
        return sentences
def collate_features(
    cuts: CutSet,
    pad_direction: str = "right",
    executor: Optional[Executor] = None,
    features_dtype: Optional[torch.dtype] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Load features for all the cuts and stack them into a single batch tensor.

    Before loading, every cut is padded (with ``CutSet.pad``) to the frame count of
    the longest cut, so the output shape is ``(batch, time, features)``.

    :param cuts: a :class:`CutSet` used to load the features.
    :param pad_direction: where to apply the padding (``right``, ``left``, or ``both``).
    :param executor: an instance of ThreadPoolExecutor or ProcessPoolExecutor;
        when provided, we will use it to read the features concurrently.
    :param features_dtype: dtype of the returned feature tensor; ``None`` selects
        torch's default dtype.
    :return: a tuple of tensors ``(features, features_lens)``, where ``features_lens``
        holds the frame count of every cut before padding.
    """
    assert all(cut.has_features for cut in cuts)

    # Frame counts must be captured before padding makes them all equal.
    features_lens = torch.tensor([cut.num_frames for cut in cuts], dtype=torch.int)
    padded_cuts = cuts.pad(
        num_frames=max(features_lens).item(), direction=pad_direction
    )

    reference_cut = next(iter(padded_cuts))
    batch = torch.empty(
        len(padded_cuts),
        reference_cut.num_frames,
        reference_cut.num_features,
        dtype=features_dtype,
    )

    loader = map if executor is None else executor.map
    for position, loaded in enumerate(loader(_read_features, padded_cuts)):
        batch[position] = loaded
    return batch, features_lens
def collate_audio(
    cuts: CutSet,
    pad_direction: str = "right",
    executor: Optional[Executor] = None,
    fault_tolerant: bool = False,
    recording_field: Optional[str] = None,
    mono_downmix: Optional[bool] = None,
) -> Union[
    Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor, CutSet]
]:
    """
    Load audio samples for all the cuts and return them as a batch in a torch tensor.
    The output shape is ``(batch, time)`` or ``(batch, channels, time)``.
    The cuts will be padded with silence if necessary.

    :param cuts: a :class:`CutSet` used to load the audio samples.
    :param pad_direction: where to apply the padding (``right``, ``left``, or ``both``).
    :param executor: an instance of ThreadPoolExecutor or ProcessPoolExecutor;
        when provided, we will use it to read audio concurrently.
    :param fault_tolerant: when ``True``, the cuts for which audio loading failed
        will be skipped. Setting this parameter will cause the function to return a 3-tuple,
        where the third element is a CutSet for which the audio data were successfully read.
    :param recording_field: when specified, we will try to load recordings from a custom field with this name
        (i.e., ``cut.load_<recording_field>()`` instead of default ``cut.load_audio()``).
    :param mono_downmix: controls channel handling.
        ``None`` (default): auto-detect — uses downmix semantics unless every cut in the
        batch is multichannel, in which case multichannel collation is used.
        ``True``: multichannel audio is downmixed to mono by averaging channels;
        output shape is ``(batch, time)``.
        ``False``: every audio with fewer channels than the batch maximum is zero-padded
        along the channel axis (mono audio is placed in channel 0);
        output shape is ``(batch, channels, time)``.
    :return: a tuple of tensors ``(audio, audio_lens)``, or ``(audio, audio_lens, cuts)``.
    """
    for cut in cuts:
        if recording_field is None:
            assert cut.has_recording, f"Missing recording in cut {cut.id}"
        else:
            assert cut.has_custom(
                recording_field
            ), f"Missing custom recording field {recording_field} in cut {cut.id}"

    # Remember how many samples were there in each cut (later, we might remove cuts that fail to load).
    sample_counts = []
    for cut in cuts:
        if recording_field is None:
            num_samples = cut.num_samples
        else:
            num_samples = compute_num_samples(
                cut.duration, sampling_rate=getattr(cut, recording_field).sampling_rate
            )
        sample_counts.append(num_samples)

    cuts = cuts.pad(
        duration=max(cut.duration for cut in cuts),
        direction=pad_direction,
        preserve_id=True,
    )

    # Note: returned "cuts" may be a subset of the original "cuts" if fault_tolerant=True.
    audios, cuts, sample_counts = read_audio_from_cuts(
        cuts,
        executor,
        suppress_errors=fault_tolerant,
        recording_field=recording_field,
        filter_aux_iter=sample_counts,
    )

    if mono_downmix is None:
        # Auto-detect: use False semantics only when every audio is multichannel
        # (single-channel audio arrives as 1-D tensors).
        mono_downmix = not all(a.dim() == 2 for a in audios)

    if mono_downmix:
        # Downmix multichannel audio to mono by averaging channels: (channels, time) -> (time,)
        audios = collate_vectors(
            [audio.mean(dim=0) if audio.dim() == 2 else audio for audio in audios],
            padding_value=0.0,
        )
    else:
        # Zero-pad every audio on the channel axis up to the batch maximum,
        # then collate as (batch, channels, time).
        max_channels = max(
            audio.shape[0] if audio.dim() == 2 else 1 for audio in audios
        )
        processed = []
        for audio in audios:
            if audio.dim() == 1:
                audio = audio.unsqueeze(0)  # (time,) -> (1, time)
            num_channels = audio.shape[0]
            if num_channels < max_channels:
                padded = audio.new_zeros(max_channels, audio.shape[1])
                padded[:num_channels] = audio
                audio = padded
            processed.append(audio)
        audios = collate_matrices(
            [a.transpose(0, 1) for a in processed], padding_value=0.0
        ).transpose(1, 2)

    audio_lens = torch.tensor(sample_counts, dtype=torch.int32)

    if fault_tolerant:
        return audios, audio_lens, cuts
    else:
        return audios, audio_lens


collate_multi_channel_audio = collate_audio  # alias for backwards compatibility
def collate_video(
    cuts: CutSet,
    with_audio: bool = True,
    pad_direction: str = "right",
    executor: Optional[Executor] = None,
    fault_tolerant: bool = False,
    recording_field: Optional[str] = None,
):
    """
    Load video and audio for all cuts and return them as a batch in torch tensors.
    The output video shape is ``(batch, time, channel, height, width)``.
    The output audio shape is ``(batch, channel, time)``.
    The cuts will be padded with silence if necessary.

    .. note:: We expect each video to contain audio and the same number of audio channels.
        We may support padding missing channels at a later time.

    :param cuts: a :class:`CutSet` used to load the audio samples.
    :param with_audio: should the audio data be loaded.
    :param pad_direction: where to apply the padding (``right``, ``left``, or ``both``).
    :param executor: an instance of ThreadPoolExecutor or ProcessPoolExecutor;
        when provided, we will use it to read video concurrently.
    :param fault_tolerant: when ``True``, the cuts for which video/audio loading failed
        will be skipped. Setting this parameter will cause the function to return a 5-tuple,
        where the fifth element is a CutSet for which the data were successfully read.
    :param recording_field: when specified, we will try to load recordings from a custom field with this name
        (i.e., ``cut.load_<recording_field>()`` instead of default ``cut.load_video()``).
        The same field is used both for the length bookkeeping and for loading.
    :return: a tuple of tensors ``(video, video_lens, audio, audio_lens)``,
        or ``(video, video_lens, audio, audio_lens, cuts)``.
        ``audio`` and ``audio_lens`` are ``None`` when ``with_audio=False``.
    """
    for cut in cuts:
        if recording_field is None:
            assert cut.has_video, f"Missing video in the recording of cut {cut.id}"
        else:
            assert cut.has_custom(
                recording_field
            ), f"Missing custom recording field {recording_field} in cut {cut.id}"
            assert getattr(
                cut, recording_field
            ).has_video, f"Missing video in custom recording field {recording_field} of cut {cut.id}"

    # Remember how many samples were there in each cut (later, we might remove cuts that fail to load).
    # Keyed by cut ID; padding below preserves the IDs.
    id2lens = {}
    for cut in cuts:
        if recording_field is None:
            video = cut.video
            num_samples = cut.num_samples
        else:
            recording = getattr(cut, recording_field)
            video = recording.video
            num_samples = compute_num_samples(cut.duration, recording.sampling_rate)
        id2lens[cut.id] = (num_samples, video.num_frames)

    cuts = cuts.pad(
        duration=max(c.duration for c in cuts),
        direction=pad_direction,
        preserve_id=True,
    )

    # Note: returned "cuts" may be a subset of the original "cuts" if fault_tolerant=True.
    # recording_field must be forwarded, otherwise the default video would be loaded.
    videos, audios, cuts = read_video_from_cuts(
        cuts,
        with_audio=with_audio,
        executor=executor,
        suppress_errors=fault_tolerant,
        recording_field=recording_field,
    )

    videos = torch.stack(videos)  # B x T x C x H x W
    video_lens = torch.tensor([id2lens[cut.id][1] for cut in cuts], dtype=torch.int32)

    if with_audio:
        audios = torch.stack(audios)  # B x C x T
        audio_lens = torch.tensor(
            [id2lens[cut.id][0] for cut in cuts], dtype=torch.int32
        )
    else:
        audios, audio_lens = None, None

    if fault_tolerant:
        return videos, video_lens, audios, audio_lens, cuts
    else:
        return videos, video_lens, audios, audio_lens
def _temporal_slice(pad_direction: str, length: int, max_length: int) -> slice:
    """Return where an array of ``length`` items goes inside a padded buffer of ``max_length`` items."""
    if pad_direction == "right":
        return slice(0, length)
    if pad_direction == "left":
        return slice(max_length - length, max_length)
    if pad_direction == "both":
        # When the total padding is odd, the extra element goes on the right.
        start = (max_length - length) // 2
        return slice(start, start + length)
    raise ValueError(f"Unexpected pad_direction argument: '{pad_direction}'")


def collate_custom_field(
    cuts: CutSet,
    field: str,
    pad_value: Union[None, int, float] = None,
    pad_direction: str = "right",
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    """
    Load custom arrays for all the cuts and return them as a batch in a torch tensor.
    The output shapes are:

        - ``(batch, d0, d1, d2, ...)`` for :class:`lhotse.array.Array` of shape ``(d0, d1, d2, ...)``.
            Note: all arrays have to be of the same shape, as we expect these represent fixed-size embeddings.

        - ``(batch, d0, pad_dt, d1, ...)`` for :class:`lhotse.array.TemporalArray` of shape ``(d0, dt, d1, ...)``
            where ``dt`` indicates temporal dimension (variable-sized), and ``pad_dt`` indicates
            temporal dimension after padding (equal-sized for all cuts). We expect these represent
            temporal data, such as alignments, posteriors, features, etc.

        - ``(batch, )`` for anything else, such as int or float: we will simply stack them into
            a list and tensorize it.

    .. note:: This function disregards the ``frame_shift`` attribute of
        :class:`lhotse.array.TemporalArray` when padding; it simply pads all the arrays
        to the longest one found in the mini-batch. Because of that, the function
        will work correctly even if the user supplied inconsistent meta-data.

    .. note:: Temporal arrays of integer type that are smaller than torch.int64,
        will be automatically promoted to torch.int64.

    :param cuts: a :class:`CutSet` used to load the features.
    :param field: name of the custom field to be retrieved.
    :param pad_value: value to be used for padding the temporal arrays.
        Ignored for non-temporal array and non-array attributes.
    :param pad_direction: where to apply the padding (``right``, ``left``, or ``both``).
        With ``both`` and an odd amount of padding, the extra element goes on the right.
    :return: a collated data tensor, or a tuple of tensors ``(collated_data, sequence_lens)``.
    """
    from lhotse.array import Array, TemporalArray
    from lhotse.image import Image

    first_manifest = getattr(cuts[0], field)
    if isinstance(first_manifest, Array):
        # Expected data type: fixed-size embeddings.
        # Simply stack across a new dimension inserted at 0.
        assert all(getattr(c, field).shape == first_manifest.shape for c in cuts), (
            "Cannot collate manifests of type Array with different shapes, "
            "because we don't know which dimension must be padded. "
            "Use TemporalArray manifests and try again."
        )
        return torch.stack([torch.from_numpy(c.load_custom(field)) for c in cuts])
    elif isinstance(first_manifest, TemporalArray):
        # Expected data type: variable-sized tensors (along only one dimension).
        # Pad across that dimension, then stack at dimension 0.
        if pad_value is None:
            warnings.warn(
                f"Argument 'pad_value' not passed -- we will pad field '{field}' "
                f"with {DEFAULT_PADDING_VALUE}."
            )
            pad_value = DEFAULT_PADDING_VALUE
        temporal_dim = first_manifest.temporal_dim

        # We avoid cuts.pad() because the users might be defining frame_shift differently
        # that we typically do in Lhotse. This may result in extra padding where they
        # expected none to happen. See: https://github.com/lhotse-speech/lhotse/issues/478
        # Instead, we're going to load everything and pad to the longest sequence.
        arrs = [torch.from_numpy(c.load_custom(field)) for c in cuts]
        arr_lens = torch.tensor(
            [a.shape[temporal_dim] for a in arrs], dtype=torch.int32
        )
        largest_arr = max(arrs, key=torch.numel)
        maxlen = largest_arr.shape[temporal_dim]
        collated_shape = (len(arrs), *largest_arr.shape)
        dtype = largest_arr.dtype
        if dtype in (torch.uint8, torch.int8, torch.int16, torch.int32):
            dtype = torch.int64
        # Note: the multiplication follows torch type promotion, so a float
        # pad_value with an integer dtype yields a floating point result.
        tensors = pad_value * torch.ones(collated_shape, dtype=dtype)
        for aidx, a in enumerate(arrs):
            # Build an index such as tensors[aidx, :, start:stop, :] programmatically;
            # every axis is ':' except the temporal one, which depends on pad_direction.
            temporal_slice = _temporal_slice(pad_direction, a.shape[temporal_dim], maxlen)
            indices = (aidx,) + tuple(
                temporal_slice if i == temporal_dim else slice(None)
                for i in range(a.dim())
            )
            tensors[indices] = a

        return tensors, arr_lens
    elif isinstance(first_manifest, Image):
        return collate_images(cuts, field)
    elif isinstance(first_manifest, Recording):
        return collate_audio(cuts, recording_field=field, pad_direction=pad_direction)
    else:
        # Expected data type: int, float, string, etc.
        # Get a list of them and convert to a tensor.
        return torch.tensor([getattr(c, field) for c in cuts])
def collate_multi_channel_features(cuts: CutSet) -> torch.Tensor:
    """
    Load features for all the cuts and return them as a batch in a torch tensor.
    The cuts have to be of type ``MixedCut`` and their tracks will be interpreted as individual channels.
    The output shape is ``(batch, channel, time, features)``.
    Every cut is padded (with ``CutSet.pad``) to the frame count of the longest cut.
    """
    assert all(cut.has_features for cut in cuts)
    assert all(isinstance(cut, MixedCut) for cut in cuts)
    # Pad to the longest cut in frames. (Previously ``cuts.pad(cuts)`` passed the
    # CutSet itself as the first positional argument, i.e. as ``duration``.)
    cuts = cuts.pad(num_frames=max(cut.num_frames for cut in cuts))
    # Output tensor shape: (B, C, T, F) -> (batch_size, num_channels, num_frames, num_features)
    first_cut = next(iter(cuts))
    # TODO: make MixedCut more friendly to use with multi channel audio;
    #  discount PaddingCuts in "tracks" when specifying the number of channels
    features = torch.empty(
        len(cuts), len(first_cut.tracks), first_cut.num_frames, first_cut.num_features
    )
    for idx, cut in enumerate(cuts):
        features[idx] = torch.from_numpy(cut.load_features(mixed=False))
    return features
def collate_vectors(
    tensors: Iterable[Union[torch.Tensor, np.ndarray]],
    padding_value: Union[int, float] = CrossEntropyLoss().ignore_index,
    pad_direction: str = "right",
    matching_shapes: bool = False,
) -> torch.Tensor:
    """
    Convert an iterable of 1-D tensors (of possibly various lengths)
    into a single stacked tensor.

    :param tensors: an iterable of 1-D tensors; numpy arrays are converted with
        ``torch.from_numpy``.
    :param padding_value: the padding value inserted to make all tensors have the same length.
        Defaults to ``CrossEntropyLoss().ignore_index``.
    :param pad_direction: where to apply the padding (``right`` or ``left``).
    :param matching_shapes: when ``True``, will fail when input tensors have different shapes.
    :return: a tensor with shape ``(B, L)`` where ``B`` is the number of input tensors and
        ``L`` is the number of items in the longest tensor.
    :raises ValueError: when ``pad_direction`` is neither ``left`` nor ``right``.
    """
    tensors = [
        t if isinstance(t, torch.Tensor) else torch.from_numpy(t) for t in tensors
    ]
    assert all(len(t.shape) == 1 for t in tensors), "Expected only 1-D input tensors."
    if pad_direction not in ("left", "right"):
        raise ValueError(
            f"pad_direction must be 'left' or 'right', got {pad_direction}"
        )
    longest = max(tensors, key=lambda t: t.shape[0])
    if matching_shapes:
        assert all(
            t.shape == longest.shape for t in tensors
        ), "All tensors must have the same shape when matching_shapes is set to True."
    max_len = longest.shape[0]
    # Multiplying (rather than new_full) keeps torch's type promotion: an integer
    # input padded with a float value yields a floating point result.
    result = longest.new_ones(len(tensors), max_len) * padding_value
    for i, t in enumerate(tensors):
        if pad_direction == "right":
            result[i, : t.shape[0]] = t
        else:
            # An explicit start index is required: ``-t.shape[0]:`` would be ``0:``
            # (the whole row) for an empty tensor.
            result[i, max_len - t.shape[0] :] = t
    return result
def collate_matrices(
    tensors: Iterable[Union[torch.Tensor, np.ndarray]],
    padding_value: Union[int, float] = 0,
    matching_shapes: bool = False,
    pad_direction: str = "right",
) -> torch.Tensor:
    """
    Convert an iterable of 2-D tensors (of possibly various first dimension, but consistent second dimension)
    into a single stacked tensor.

    :param tensors: an iterable of 2-D tensors; numpy arrays are converted with
        ``torch.from_numpy``.
    :param padding_value: the padding value inserted to make all tensors have the same length.
    :param matching_shapes: when ``True``, will fail when input tensors have different shapes.
    :param pad_direction: where to apply the padding along the first dimension
        (``right`` (default) or ``left``).
    :return: a tensor with shape ``(B, L, F)`` where ``B`` is the number of input tensors,
        ``L`` is the largest found shape[0], and ``F`` is equal to shape[1].
    :raises ValueError: when ``pad_direction`` is neither ``left`` nor ``right``.
    """
    tensors = [
        t if isinstance(t, torch.Tensor) else torch.from_numpy(t) for t in tensors
    ]
    assert all(len(t.shape) == 2 for t in tensors), "Expected only 2-D input tensors."
    if pad_direction not in ("left", "right"):
        raise ValueError(
            f"pad_direction must be 'left' or 'right', got {pad_direction}"
        )
    longest = max(tensors, key=lambda t: t.shape[0])
    if matching_shapes:
        assert all(
            t.shape == longest.shape for t in tensors
        ), "All tensors must have the same shape when matching_shapes is set to True."
    max_len = longest.shape[0]
    result = longest.new_ones(len(tensors), *longest.shape) * padding_value
    for i, t in enumerate(tensors):
        if pad_direction == "right":
            result[i, : t.shape[0]] = t
        else:
            # Explicit start index; ``-0:`` would select every row.
            result[i, max_len - t.shape[0] :] = t
    return result
""" Helper functions to dispatch jobs to the concurrent executors. """
[docs]def read_audio_from_cuts( cuts: Iterable[Cut], executor: Optional[Executor] = None, suppress_errors: bool = False, recording_field: Optional[str] = None, filter_aux_iter: Optional[Iterable] = None, ) -> Union[Tuple[List[torch.Tensor], CutSet], Tuple[List[torch.Tensor], CutSet, List]]: """ Loads audio data from an iterable of cuts. :param cuts: a CutSet or iterable of cuts. :param executor: optional Executor (e.g., ThreadPoolExecutor or ProcessPoolExecutor) to perform the audio reads in parallel. :param suppress_errors: when set to ``True``, will enable fault-tolerant data reads; we will skip the cuts and audio data for the instances that failed (and emit a warning). When ``False`` (default), the errors will not be suppressed. :param recording_field: when specified, we will try to load recordings from a custom field with this name (i.e., ``cut.load_<recording_field>()`` instead of default ``cut.load_audio()``). :param filter_aux_iter: when specified, we will iterate over this iterator and discard the elements for which a corresponding cut failed to load audio, if ``suppress_errors`` is set to ``True``. This iterator is expected to be of the same length as ``cuts``. :return: a tuple of two items: a list of audio tensors (with different shapes), and a list of cuts for which we read the data successfully. If ``filter_aux_iter`` is specified, it returns a 3-tuple where the third element is the filtered auxiliary iterator. 
""" aux_requested = True if filter_aux_iter is None: filter_aux_iter = repeat([None]) aux_requested = False map_fn = map if executor is None else executor.map audios = [] ok_cuts = [] aux_iter_out = [] for idx, (cut, maybe_audio, aux_item) in enumerate( zip( cuts, map_fn( partial( _read_audio, suppress_errors=suppress_errors, recording_field=recording_field, ), cuts, ), filter_aux_iter, ) ): if maybe_audio is None: continue else: audios.append(maybe_audio) ok_cuts.append(cut) aux_iter_out.append(aux_item) ans = (audios, CutSet.from_cuts(ok_cuts)) if aux_requested: ans = ans + (aux_iter_out,) return ans
def read_video_from_cuts(
    cuts: Iterable[Cut],
    with_audio: bool = True,
    executor: Optional[Executor] = None,
    suppress_errors: bool = False,
    recording_field: Optional[str] = None,
) -> Tuple[List[torch.Tensor], List[torch.Tensor], CutSet]:
    """
    Loads video (and optionally audio) data from an iterable of cuts.

    :param cuts: a CutSet or iterable of cuts (a one-shot iterator is fine;
        it is consumed exactly once).
    :param with_audio: should the audio data be loaded.
    :param executor: optional Executor (e.g., ThreadPoolExecutor or ProcessPoolExecutor)
        to perform the video reads in parallel.
    :param suppress_errors: when set to ``True``, will enable fault-tolerant data reads;
        we will skip the cuts and video data for the instances that failed (and emit a warning).
        When ``False`` (default), the errors will not be suppressed.
    :param recording_field: when specified, we will try to load recordings from a custom field with this name
        (i.e., ``cut.load_<recording_field>()`` instead of default ``cut.load_video()``).
    :return: a tuple of three items: a list of video tensors, a list of audio tensors
        (entries are whatever the loader returned for audio, e.g. ``None`` when
        ``with_audio=False`` — TODO confirm), and a CutSet with the cuts for which
        we read the data successfully.
    """
    # Materialize once so results stay aligned with their cuts for one-shot iterators.
    cuts = list(cuts)
    map_fn = map if executor is None else executor.map
    load_fn = partial(
        _read_video,
        suppress_errors=suppress_errors,
        with_audio=with_audio,
        recording_field=recording_field,
    )
    videos = []
    audios = []
    ok_cuts = []
    for cut, maybe_ans in zip(cuts, map_fn(load_fn, cuts)):
        if maybe_ans is None:
            # Loading failed and the error was suppressed: drop this cut.
            continue
        video, audio = maybe_ans
        videos.append(video)
        audios.append(audio)
        ok_cuts.append(cut)
    return videos, audios, CutSet.from_cuts(ok_cuts)
def read_features_from_cuts(
    cuts: Iterable[Cut], executor: Optional[Executor] = None
) -> List[torch.Tensor]:
    """
    Load the features of every cut, in order, as a list of tensors.

    :param cuts: a CutSet or iterable of cuts.
    :param executor: optional Executor used to read the features in parallel;
        when ``None``, they are read sequentially in the calling thread.
    :return: one feature tensor per input cut.
    """
    if executor is not None:
        return list(executor.map(_read_features, cuts))
    return [_read_features(cut) for cut in cuts]
def _read_audio(
    cut: Cut, suppress_errors: bool = False, recording_field: Optional[str] = None
) -> Optional[torch.Tensor]:
    """
    Load the audio of ``cut`` (or of its custom ``recording_field``) as a tensor.

    If the first dimension of the loaded array has size 1 (presumably a single
    channel), it is squeezed away. Returns ``None`` when an error raised while
    loading was suppressed because ``suppress_errors`` was set to ``True``.
    """
    with suppress_audio_loading_errors(enabled=suppress_errors):
        if recording_field is not None:
            manifest = getattr(cut, recording_field)
            assert isinstance(
                manifest, Recording
            ), f"Expected 'getattr(cut, {recording_field})' to yield Recording, got {type(manifest)}"
            samples = cut.load_custom(recording_field)
        else:
            samples = cut.load_audio()
        if samples.shape[0] == 1:
            samples = samples.squeeze(0)  # collapse channel dim if mono
        return torch.from_numpy(samples)
    # Only reached when the context manager above suppressed an exception.
    return None


def _read_features(cut: Cut) -> torch.Tensor:
    """Load the features of ``cut`` and wrap the resulting array in a tensor."""
    loaded = cut.load_features()
    return torch.from_numpy(loaded)


def _read_video(
    cut: Cut,
    with_audio: bool = True,
    suppress_errors: bool = False,
    recording_field: Optional[str] = None,
) -> Optional[Tuple[torch.Tensor, Optional[torch.Tensor]]]:
    """
    Load the video (plus audio, when ``with_audio`` is set) of ``cut`` or of its
    custom ``recording_field``.

    Returns ``None`` when an error raised while loading was suppressed because
    ``suppress_errors`` was set to ``True``.
    """
    with suppress_video_loading_errors(enabled=suppress_errors):
        if recording_field is None:
            return cut.load_video(with_audio=with_audio)
        manifest = getattr(cut, recording_field)
        assert isinstance(
            manifest, Recording
        ), f"Expected 'getattr(cut, {recording_field})' to yield Recording, got {type(manifest)}"
        return cut.load_custom(recording_field, with_audio=with_audio)
    # Only reached when the context manager above suppressed an exception.
    return None
def collate_images(
    cuts: CutSet,
    image_field: str = "image",
) -> torch.Tensor:
    """
    Load the image stored in ``image_field`` of every cut and stack them into one tensor.

    A new leading batch dimension is added, so for images loaded as
    ``(height, width, channel)`` the output is ``(batch, height, width, channel)``.

    :param cuts: a :class:`CutSet` used to load the images.
    :param image_field: the field in the cut to load the images from.
    :return: tensor of collated images
    """
    loaded_images = []
    for cut in cuts:
        loaded_images.append(torch.as_tensor(cut.load_custom(image_field)))
    return torch.stack(loaded_images)