import warnings
from concurrent.futures import Executor
from functools import partial
from itertools import repeat
from typing import Iterable, List, Optional, Tuple, Union
import numpy as np
import torch
from torch.nn import CrossEntropyLoss
from lhotse import CutSet, Recording
from lhotse.audio import suppress_audio_loading_errors
from lhotse.audio.utils import suppress_video_loading_errors
from lhotse.cut import Cut, MixedCut
from lhotse.utils import DEFAULT_PADDING_VALUE, compute_num_samples
class TokenCollater:
    """Collate list of tokens

    Map sentences to integers (one integer per character). Sentences are padded to equal length.
    Beginning and end-of-sequence symbols can be added.
    Characters that are not part of the vocabulary are mapped to ``unk_symbol``.

    Call .inverse(tokens_batch, tokens_lens) to reconstruct batch as string sentences.

    Example:
        >>> token_collater = TokenCollater(cuts)
        >>> tokens_batch, tokens_lens = token_collater(cuts.subset(first=32))
        >>> original_sentences = token_collater.inverse(tokens_batch, tokens_lens)

    Returns:
        tokens_batch: int64 tensor of shape (B, L)
            B: batch dimension, number of input sentences
            L: length of the longest sentence (including <bos>/<eos> when enabled)
        tokens_lens: IntTensor of shape (B,)
            Length of each sentence after adding <eos> and <bos>
            but before padding.
    """

    def __init__(
        self,
        cuts: CutSet,
        add_eos: bool = True,
        add_bos: bool = True,
        pad_symbol: str = "<pad>",
        bos_symbol: str = "<bos>",
        eos_symbol: str = "<eos>",
        unk_symbol: str = "<unk>",
    ):
        """
        Build the character vocabulary.

        :param cuts: cuts whose first supervision's text defines the set of known characters.
        :param add_eos: append ``eos_symbol`` to every sequence.
        :param add_bos: prepend ``bos_symbol`` to every sequence.
        :param pad_symbol: padding token; always gets index 0.
        :param bos_symbol: beginning-of-sequence token.
        :param eos_symbol: end-of-sequence token.
        :param unk_symbol: token used for characters missing from the vocabulary; always gets index 1.
        """
        self.pad_symbol = pad_symbol
        self.bos_symbol = bos_symbol
        self.eos_symbol = eos_symbol
        self.unk_symbol = unk_symbol

        self.add_eos = add_eos
        self.add_bos = add_bos

        # NOTE: only the first supervision of each cut contributes to the vocabulary;
        # this is kept as-is so that token indices stay stable for existing users.
        tokens = {char for cut in cuts for char in cut.supervisions[0].text}
        tokens_unique = (
            [pad_symbol, unk_symbol]
            + ([bos_symbol] if add_bos else [])
            + ([eos_symbol] if add_eos else [])
            + sorted(tokens)
        )

        self.token2idx = {token: idx for idx, token in enumerate(tokens_unique)}
        self.idx2token = list(tokens_unique)

    def __call__(self, cuts: CutSet) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Convert the text of ``cuts`` into a padded batch of token indices.

        The texts of all supervisions in a cut are joined with a single space.

        :param cuts: the cuts to collate.
        :return: a tuple ``(tokens_batch, tokens_lens)``.
        """
        token_sequences = [
            " ".join(supervision.text for supervision in cut.supervisions)
            for cut in cuts
        ]
        max_len = len(max(token_sequences, key=len))

        seqs = [
            ([self.bos_symbol] if self.add_bos else [])
            + list(seq)
            + ([self.eos_symbol] if self.add_eos else [])
            + [self.pad_symbol] * (max_len - len(seq))
            for seq in token_sequences
        ]

        # Characters unseen in __init__ (including the joining space, which may be absent
        # from the vocabulary) fall back to <unk> instead of raising KeyError.
        unk_idx = self.token2idx[self.unk_symbol]
        tokens_batch = torch.from_numpy(
            np.array(
                [[self.token2idx.get(token, unk_idx) for token in seq] for seq in seqs],
                dtype=np.int64,
            )
        )

        tokens_lens = torch.IntTensor(
            [
                len(seq) + int(self.add_eos) + int(self.add_bos)
                for seq in token_sequences
            ]
        )

        return tokens_batch, tokens_lens

    def inverse(
        self, tokens_batch: torch.LongTensor, tokens_lens: torch.IntTensor
    ) -> List[str]:
        """
        Reconstruct the sentences from the output of ``__call__``.

        The <bos>/<eos> symbols (when enabled) and the padding are stripped.

        :param tokens_batch: token indices of shape (B, L).
        :param tokens_lens: unpadded length of each row, including <bos>/<eos>.
        :return: a list of B strings.
        """
        start = 1 if self.add_bos else 0
        sentences = [
            "".join(
                [
                    self.idx2token[idx]
                    for idx in tokens_list[start : end - int(self.add_eos)]
                ]
            )
            for tokens_list, end in zip(tokens_batch, tokens_lens)
        ]
        return sentences
def collate_features(
    cuts: CutSet,
    pad_direction: str = "right",
    executor: Optional[Executor] = None,
    features_dtype: Optional[torch.dtype] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Read the features of every cut and stack them into one batch tensor
    of shape ``(batch, time, features)``.

    Cuts shorter than the longest one are padded to its number of frames.

    :param cuts: a :class:`CutSet` whose cuts all have features.
    :param pad_direction: where to apply the padding (``right``, ``left``, or ``both``).
    :param executor: an instance of ThreadPoolExecutor or ProcessPoolExecutor; when provided,
        the features are read through ``executor.map``.
    :param features_dtype: dtype of the allocated output tensor (passed to ``torch.empty``).
    :return: a tuple of tensors ``(features, features_lens)``; ``features_lens`` holds the
        number of frames of each cut before padding.
    """
    assert all(cut.has_features for cut in cuts)

    # Original (unpadded) frame counts are what the caller gets back as lengths.
    features_lens = torch.tensor([cut.num_frames for cut in cuts], dtype=torch.int)
    padded = cuts.pad(num_frames=max(features_lens).item(), direction=pad_direction)

    # After padding, every cut has the shape of the first one.
    template = next(iter(padded))
    batch = torch.empty(
        len(padded), template.num_frames, template.num_features, dtype=features_dtype
    )

    if executor is None:
        loaded = (_read_features(cut) for cut in padded)
    else:
        loaded = executor.map(_read_features, padded)
    for row, feats in enumerate(loaded):
        batch[row] = feats

    return batch, features_lens
def collate_audio(
    cuts: CutSet,
    pad_direction: str = "right",
    executor: Optional[Executor] = None,
    fault_tolerant: bool = False,
    recording_field: Optional[str] = None,
    mono_downmix: Optional[bool] = None,
) -> Union[
    Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor, CutSet]
]:
    """
    Read the audio of every cut and stack it into one batch tensor of shape
    ``(batch, time)`` or ``(batch, channels, time)``.

    Cuts shorter than the longest one are padded to its duration.

    :param cuts: a :class:`CutSet` used to load the audio samples.
    :param pad_direction: where to apply the padding (``right``, ``left``, or ``both``).
    :param executor: an instance of ThreadPoolExecutor or ProcessPoolExecutor; when provided,
        we will use it to read audio concurrently.
    :param fault_tolerant: when ``True``, the cuts for which audio loading failed
        will be skipped. Setting this parameter will cause the function to return a 3-tuple,
        where the third element is a CutSet for which the audio data were successfully read.
    :param recording_field: when specified, we will try to load recordings from a custom field
        with this name (i.e., ``cut.load_<recording_field>()`` instead of default
        ``cut.load_audio()``).
    :param mono_downmix: controls channel handling.
        ``None`` (default): auto-detect -- multichannel collation is used only when every
        loaded audio is multichannel; otherwise downmix semantics apply.
        ``True``: multichannel audio is downmixed to mono by averaging channels; output shape
        is ``(batch, time)``.
        ``False``: mono audio is placed in channel 0 with remaining channels zero-filled to
        match the batch maximum; output shape is ``(batch, channels, time)``.
    :return: a tuple of tensors ``(audio, audio_lens)``, or ``(audio, audio_lens, cuts)``.
    """
    for cut in cuts:
        if recording_field is not None:
            assert cut.has_custom(
                recording_field
            ), f"Missing custom recording field {recording_field} in cut {cut.id}"
        else:
            assert cut.has_recording, f"Missing recording in cut {cut.id}"

    # Sample counts are captured before padding; entries of cuts that fail to load
    # are dropped later by read_audio_from_cuts (via filter_aux_iter).
    sample_counts = [
        cut.num_samples
        if recording_field is None
        else compute_num_samples(
            cut.duration, sampling_rate=getattr(cut, recording_field).sampling_rate
        )
        for cut in cuts
    ]

    longest_duration = max(cut.duration for cut in cuts)
    cuts = cuts.pad(
        duration=longest_duration, direction=pad_direction, preserve_id=True
    )

    # With fault_tolerant=True the returned "cuts" may be a subset of the input.
    audios, cuts, sample_counts = read_audio_from_cuts(
        cuts,
        executor,
        suppress_errors=fault_tolerant,
        recording_field=recording_field,
        filter_aux_iter=sample_counts,
    )

    if mono_downmix is None:
        # Any non-2D (i.e. mono) audio in the batch switches us to downmix semantics.
        mono_downmix = any(a.dim() != 2 for a in audios)

    if mono_downmix:
        # (channels, time) -> (time,) by averaging over channels; mono passes through.
        flattened = [a.mean(dim=0) if a.dim() == 2 else a for a in audios]
        audios = collate_vectors(flattened, padding_value=0.0)
    else:
        max_channels = max(a.shape[0] if a.dim() == 2 else 1 for a in audios)
        multichannel = []
        for a in audios:
            if a.dim() == 1:
                # Mono goes to channel 0; the remaining channels stay zero.
                widened = a.new_zeros(max_channels, a.shape[0])
                widened[0] = a
                a = widened
            multichannel.append(a)
        # collate_matrices pads along dim 0, so collate as (time, channels) and swap back.
        time_major = collate_matrices(
            [a.transpose(0, 1) for a in multichannel], padding_value=0.0
        )
        audios = time_major.transpose(1, 2)

    audio_lens = torch.tensor(sample_counts, dtype=torch.int32)

    if fault_tolerant:
        return audios, audio_lens, cuts
    return audios, audio_lens
collate_multi_channel_audio = collate_audio  # old name, kept as an alias for backwards compatibility
def collate_video(
    cuts: CutSet,
    with_audio: bool = True,
    pad_direction: str = "right",
    executor: Optional[Executor] = None,
    fault_tolerant: bool = False,
    recording_field: Optional[str] = None,
):
    """
    Load video and audio for all cuts and return them as a batch in torch tensors.
    The output video shape is ``(batch, time, channel, height, width)``.
    The output audio shape is ``(batch, channel, time)``.

    The cuts will be padded with silence if necessary.

    .. note:: We expect each video to contain audio and the same number of audio channels.
        We may support padding missing channels at a later time.

    :param cuts: a :class:`CutSet` used to load the video (and audio) data.
    :param with_audio: should the audio data be loaded.
    :param pad_direction: where to apply the padding (``right``, ``left``, or ``both``).
    :param executor: an instance of ThreadPoolExecutor or ProcessPoolExecutor; when provided,
        we will use it to read video concurrently.
    :param fault_tolerant: when ``True``, the cuts for which video/audio loading failed
        will be skipped. Setting this parameter will cause the function to return a 5-tuple,
        where the fifth element is a CutSet for which the data were successfully read.
    :param recording_field: when specified, we will try to load recordings from a custom field
        with this name (i.e., ``cut.load_<recording_field>()`` instead of default
        ``cut.load_video()``).
    :return: a tuple ``(video, video_lens, audio, audio_lens)``,
        or ``(video, video_lens, audio, audio_lens, cuts)``.
        ``audio`` and ``audio_lens`` are ``None`` when ``with_audio=False``.
    """
    for cut in cuts:
        if recording_field is None:
            assert cut.has_video, f"Missing video in the recording of cut {cut.id}"
        else:
            assert cut.has_custom(
                recording_field
            ), f"Missing custom recording field {recording_field} in cut {cut.id}"
            assert getattr(
                cut, recording_field
            ).has_video, f"Missing video in custom recording field {recording_field} of cut {cut.id}"

    # Remember the lengths of each cut before padding, keyed by cut ID
    # (later, we might remove cuts that fail to load).
    id2lens = {}
    for cut in cuts:
        if recording_field is None:
            video = cut.video
            num_samples = cut.num_samples
        else:
            video = getattr(cut, recording_field).video
            num_samples = compute_num_samples(
                cut.duration, getattr(cut, recording_field).sampling_rate
            )
        id2lens[cut.id] = (num_samples, video.num_frames)

    # preserve_id=True keeps the IDs stable so that id2lens lookups work after padding.
    cuts = cuts.pad(
        duration=max(c.duration for c in cuts),
        direction=pad_direction,
        preserve_id=True,
    )

    # Note: returned "cuts" may be a subset of the original "cuts" if fault_tolerant=True.
    # recording_field has to be forwarded; otherwise the default recording would be
    # loaded even though the validation and lengths above refer to the custom field.
    videos, audios, cuts = read_video_from_cuts(
        cuts,
        with_audio=with_audio,
        executor=executor,
        suppress_errors=fault_tolerant,
        recording_field=recording_field,
    )

    videos = torch.stack(videos)  # B x T x C x H x W
    video_lens = torch.tensor([id2lens[cut.id][1] for cut in cuts], dtype=torch.int32)

    if with_audio:
        audios = torch.stack(audios)  # B x C x T
        audio_lens = torch.tensor(
            [id2lens[cut.id][0] for cut in cuts], dtype=torch.int32
        )
    else:
        audios, audio_lens = None, None

    if fault_tolerant:
        return videos, video_lens, audios, audio_lens, cuts
    else:
        return videos, video_lens, audios, audio_lens
def collate_custom_field(
    cuts: CutSet,
    field: str,
    pad_value: Union[None, int, float] = None,
    pad_direction: str = "right",
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    """
    Load custom arrays for all the cuts and return them as a batch in a torch tensor.
    The output shapes are:

        - ``(batch, d0, d1, d2, ...)`` for :class:`lhotse.array.Array` of shape ``(d0, d1, d2, ...)``.
            Note: all arrays have to be of the same shape, as we expect these represent fixed-size
            embeddings.

        - ``(batch, d0, pad_dt, d1, ...)`` for :class:`lhotse.array.TemporalArray` of shape
            ``(d0, dt, d1, ...)`` where ``dt`` indicates temporal dimension (variable-sized),
            and ``pad_dt`` indicates temporal dimension after padding (equal-sized for all cuts).
            We expect these represent temporal data, such as alignments, posteriors, features, etc.

        - ``(batch, )`` for anything else, such as int or float: we will simply stack them into
            a list and tensorize it.

    Fields holding an ``Image`` are delegated to :func:`collate_images` and fields holding
    a ``Recording`` to :func:`collate_audio`.

    .. note:: This function disregards the ``frame_shift`` attribute of
        :class:`lhotse.array.TemporalArray` when padding; it simply pads all the arrays
        to the longest one found in the mini-batch. Because of that, the function
        will work correctly even if the user supplied inconsistent meta-data.

    .. note:: Temporal arrays of integer type that are smaller than torch.int64,
        will be automatically promoted to torch.int64.

    :param cuts: a :class:`CutSet` used to load the features.
    :param field: name of the custom field to be retrieved.
    :param pad_value: value to be used for padding the temporal arrays.
        Ignored for non-temporal array and non-array attributes.
    :param pad_direction: where to apply the padding (``right``, ``left``, or ``both``).
        With ``both`` and an odd amount of padding, the extra padded element goes to the right.
    :return: a collated data tensor, or a tuple of tensors ``(collated_data, sequence_lens)``.
    :raises ValueError: for an unknown ``pad_direction`` (temporal arrays only).
    """
    from lhotse.array import Array, TemporalArray
    from lhotse.image import Image

    # The type of the first cut's field decides how the whole batch is collated.
    first_manifest = getattr(cuts[0], field)

    if isinstance(first_manifest, Array):
        # Expected data type: fixed-size embeddings.
        # Simply stack across a new dimension inserted at 0.
        assert all(getattr(c, field).shape == first_manifest.shape for c in cuts), (
            "Cannot collate manifests of type Array with different shapes, "
            "because we don't know which dimension must be padded. "
            "Use TemporalArray manifests and try again."
        )
        return torch.stack([torch.from_numpy(c.load_custom(field)) for c in cuts])

    elif isinstance(first_manifest, TemporalArray):
        # Expected data type: variable-sized tensors (along only one dimension).
        # Pad across that dimension, then stack at dimension 0.
        if pad_value is None:
            warnings.warn(
                f"Argument 'pad_value' not passed -- we will pad field '{field}' "
                f"with {DEFAULT_PADDING_VALUE}."
            )
            pad_value = DEFAULT_PADDING_VALUE
        temporal_dim = first_manifest.temporal_dim

        # We avoid cuts.pad() because the users might be defining frame_shift differently
        # that we typically do in Lhotse. This may result in extra padding where they
        # expected none to happen. See: https://github.com/lhotse-speech/lhotse/issues/478
        # Instead, we're going to load everything and pad to the longest sequence.
        arrs = [torch.from_numpy(c.load_custom(field)) for c in cuts]
        arr_lens = torch.tensor(
            [a.shape[temporal_dim] for a in arrs], dtype=torch.int32
        )
        largest_arr = max(arrs, key=torch.numel)
        maxlen = largest_arr.shape[temporal_dim]
        collated_shape = (len(arrs), *largest_arr.shape)
        dtype = largest_arr.dtype
        if any(d == dtype for d in (torch.uint8, torch.int8, torch.int16, torch.int32)):
            dtype = torch.int64
        tensors = pad_value * torch.ones(collated_shape, dtype=dtype)
        for aidx, a in enumerate(arrs):
            alen = a.shape[temporal_dim]
            # Construct an index expression such as tensors[:, :alen, :, :] programmatically;
            # All indices are set to ':', besides temporal dim which is determined on pad_direction.
            if pad_direction == "right":
                start = 0
            elif pad_direction == "left":
                start = maxlen - alen
            elif pad_direction == "both":
                start = (maxlen - alen) // 2
            else:
                raise ValueError(
                    f"Unexpected pad_direction argument: '{pad_direction}'"
                )
            # The slice must span exactly ``alen`` items; deriving the end from the start
            # keeps that true also when the total amount of padding is odd.
            temporal_slice = slice(start, start + alen)
            indices = (aidx,) + tuple(
                temporal_slice if i == temporal_dim else slice(None, None, None)
                for i in range(len(a.shape))
            )
            tensors[indices] = a

        return tensors, arr_lens

    elif isinstance(first_manifest, Image):
        return collate_images(cuts, field)

    elif isinstance(first_manifest, Recording):
        return collate_audio(cuts, recording_field=field, pad_direction=pad_direction)

    else:
        # Expected data type: int, float, string, etc.
        # Get a list of them and convert to a tensor.
        return torch.tensor([getattr(c, field) for c in cuts])
def collate_multi_channel_features(cuts: CutSet) -> torch.Tensor:
    """
    Load features for all the cuts and return them as a batch in a torch tensor.
    The cuts have to be of type ``MixedCut`` and their tracks will be interpreted as individual channels.
    The output shape is ``(batch, channel, time, features)``.
    The cuts will be padded with silence if necessary.

    :param cuts: a :class:`CutSet` consisting only of ``MixedCut`` objects with features.
    :return: a features tensor of shape ``(batch, channel, time, features)``.
    """
    assert all(cut.has_features for cut in cuts)
    assert all(isinstance(cut, MixedCut) for cut in cuts)
    # Pad with the method's defaults. The CutSet must not be passed as an argument:
    # the positional parameters of ``pad`` are padding sizes (every other call in this
    # module passes ``duration=``/``num_frames=`` scalars), not a collection of cuts.
    cuts = cuts.pad()

    # Output tensor shape: (B, C, T, F) -> (batch_size, num_channels, num_frames, num_features)
    first_cut = next(iter(cuts))
    # TODO: make MixedCut more friendly to use with multi channel audio;
    #  discount PaddingCuts in "tracks" when specifying the number of channels
    features = torch.empty(
        len(cuts), len(first_cut.tracks), first_cut.num_frames, first_cut.num_features
    )
    for idx, cut in enumerate(cuts):
        # mixed=False is expected to keep the tracks separate -- TODO confirm against
        # MixedCut.load_features.
        features[idx] = torch.from_numpy(cut.load_features(mixed=False))
    return features
def collate_vectors(
    tensors: Iterable[Union[torch.Tensor, np.ndarray]],
    padding_value: Union[int, float] = CrossEntropyLoss().ignore_index,
    pad_direction: str = "right",
    matching_shapes: bool = False,
) -> torch.Tensor:
    """
    Convert an iterable of 1-D tensors (of possibly various lengths)
    into a single stacked tensor.

    :param tensors: an iterable of 1-D tensors (numpy arrays are converted to tensors).
    :param padding_value: the padding value inserted to make all tensors have the same length.
        Defaults to the ``ignore_index`` of ``CrossEntropyLoss``.
    :param pad_direction: where to apply the padding (``right`` or ``left``).
    :param matching_shapes: when ``True``, will fail when input tensors have different shapes.
    :return: a tensor with shape ``(B, L)`` where ``B`` is the number of input tensors and
        ``L`` is the number of items in the longest tensor.
    :raises ValueError: when ``pad_direction`` is neither ``left`` nor ``right``.
    """
    tensors = [
        t if isinstance(t, torch.Tensor) else torch.from_numpy(t) for t in tensors
    ]
    assert all(len(t.shape) == 1 for t in tensors), "Expected only 1-D input tensors."
    if pad_direction not in ("left", "right"):
        raise ValueError(
            f"pad_direction must be 'left' or 'right', got {pad_direction}"
        )
    longest = max(tensors, key=lambda t: t.shape[0])
    if matching_shapes:
        assert all(
            t.shape == longest.shape for t in tensors
        ), "All tensors must have the same shape when matching_shapes is set to True."
    max_len = longest.shape[0]
    # The result inherits dtype/device from the longest tensor (before any type
    # promotion caused by multiplying with ``padding_value``).
    result = longest.new_ones(len(tensors), max_len) * padding_value
    for i, t in enumerate(tensors):
        if pad_direction == "right":
            result[i, : t.shape[0]] = t
        else:
            # Use an explicit start offset: the negative-index form ``-len:`` degenerates
            # to ``-0:`` (the whole row) for an empty tensor and fails on assignment.
            result[i, max_len - t.shape[0] :] = t
    return result
def collate_matrices(
    tensors: Iterable[Union[torch.Tensor, np.ndarray]],
    padding_value: Union[int, float] = 0,
    matching_shapes: bool = False,
) -> torch.Tensor:
    """
    Stack an iterable of 2-D tensors into one 3-D tensor, padding along the first
    dimension (the second dimension has to be consistent across inputs).

    :param tensors: an iterable of 2-D tensors (numpy arrays are converted to tensors).
    :param padding_value: the value filling the rows that a shorter input does not cover.
    :param matching_shapes: when ``True``, will fail when input tensors have different shapes.
    :return: a tensor with shape ``(B, L, F)`` where ``B`` is the number of input tensors,
        ``L`` is the largest found shape[0], and ``F`` is equal to shape[1].
    """
    matrices = []
    for item in tensors:
        if not isinstance(item, torch.Tensor):
            item = torch.from_numpy(item)
        matrices.append(item)

    assert all(len(m.shape) == 2 for m in matrices), "Expected only 2-D input tensors."

    # The input with the most rows provides the output's row count, dtype and device.
    tallest = max(matrices, key=lambda m: m.shape[0])
    if matching_shapes:
        assert all(
            m.shape == tallest.shape for m in matrices
        ), "All tensors must have the same shape when matching_shapes is set to True."

    batch = tallest.new_ones(len(matrices), *tallest.shape) * padding_value
    for row, m in enumerate(matrices):
        num_rows = m.shape[0]
        batch[row, :num_rows] = m
    return batch
"""
Helper functions to dispatch jobs to the concurrent executors.
"""
def read_audio_from_cuts(
    cuts: Iterable[Cut],
    executor: Optional[Executor] = None,
    suppress_errors: bool = False,
    recording_field: Optional[str] = None,
    filter_aux_iter: Optional[Iterable] = None,
) -> Union[Tuple[List[torch.Tensor], CutSet], Tuple[List[torch.Tensor], CutSet, List]]:
    """
    Loads audio data from an iterable of cuts.

    :param cuts: a CutSet or iterable of cuts.
    :param executor: optional Executor (e.g., ThreadPoolExecutor or ProcessPoolExecutor)
        to perform the audio reads in parallel.
    :param suppress_errors: when set to ``True``, will enable fault-tolerant data reads;
        we will skip the cuts and audio data for the instances that failed (and emit a warning).
        When ``False`` (default), the errors will not be suppressed.
    :param recording_field: when specified, we will try to load recordings from a custom field
        with this name (i.e., ``cut.load_<recording_field>()`` instead of default
        ``cut.load_audio()``).
    :param filter_aux_iter: when specified, we will iterate over this iterator and discard the
        elements for which a corresponding cut failed to load audio, if ``suppress_errors`` is
        set to ``True``. This iterator is expected to be of the same length as ``cuts``.
    :return: a tuple of two items: a list of audio tensors (with different shapes),
        and a CutSet with the cuts for which we read the data successfully.
        If ``filter_aux_iter`` is specified, it returns a 3-tuple where the third element is
        the filtered auxiliary items as a list.
    """
    aux_requested = filter_aux_iter is not None
    if not aux_requested:
        # Endless dummy stream so that the zip() below is not cut short.
        filter_aux_iter = repeat([None])

    # ``cuts`` is consumed twice below (once by zip, once by the mapped reader).
    # Materialize it so that single-pass iterables (e.g. generators) are paired
    # correctly with their audio instead of being interleaved/skipped.
    cuts = list(cuts)

    map_fn = map if executor is None else executor.map
    read_fn = partial(
        _read_audio,
        suppress_errors=suppress_errors,
        recording_field=recording_field,
    )

    audios = []
    ok_cuts = []
    aux_iter_out = []
    for cut, maybe_audio, aux_item in zip(cuts, map_fn(read_fn, cuts), filter_aux_iter):
        if maybe_audio is None:
            # The read failed and the error was suppressed -- drop this example.
            continue
        audios.append(maybe_audio)
        ok_cuts.append(cut)
        aux_iter_out.append(aux_item)

    ans = (audios, CutSet.from_cuts(ok_cuts))
    if aux_requested:
        ans = ans + (aux_iter_out,)
    return ans
def read_video_from_cuts(
    cuts: Iterable[Cut],
    with_audio: bool = True,
    executor: Optional[Executor] = None,
    suppress_errors: bool = False,
    recording_field: Optional[str] = None,
) -> Tuple[List[torch.Tensor], List[torch.Tensor], CutSet]:
    """
    Loads video (and optionally audio) data from an iterable of cuts.

    :param cuts: a CutSet or iterable of cuts.
    :param with_audio: should the audio data be loaded.
    :param executor: optional Executor (e.g., ThreadPoolExecutor or ProcessPoolExecutor)
        to perform the video reads in parallel.
    :param suppress_errors: when set to ``True``, will enable fault-tolerant data reads;
        we will skip the cuts and their data for the instances that failed (and emit a warning).
        When ``False`` (default), the errors will not be suppressed.
    :param recording_field: when specified, we will try to load recordings from a custom field
        with this name (i.e., ``cut.load_<recording_field>()`` instead of default
        ``cut.load_video()``).
    :return: a tuple of three items: a list of video tensors, a list with the corresponding
        audio items (as returned by the loader), and a CutSet with the cuts for which
        we read the data successfully.
    """
    # ``cuts`` is consumed twice below (once by zip, once by the mapped reader).
    # Materialize it so that single-pass iterables (e.g. generators) are paired
    # correctly with their data instead of being interleaved/skipped.
    cuts = list(cuts)

    map_fn = map if executor is None else executor.map
    read_fn = partial(
        _read_video,
        suppress_errors=suppress_errors,
        with_audio=with_audio,
        recording_field=recording_field,
    )

    videos = []
    audios = []
    ok_cuts = []
    for cut, maybe_ans in zip(cuts, map_fn(read_fn, cuts)):
        if maybe_ans is None:
            # The read failed and the error was suppressed -- drop this example.
            continue
        video, audio = maybe_ans
        videos.append(video)
        audios.append(audio)
        ok_cuts.append(cut)
    return videos, audios, CutSet.from_cuts(ok_cuts)
def read_features_from_cuts(
    cuts: Iterable[Cut], executor: Optional[Executor] = None
) -> List[torch.Tensor]:
    """
    Loads features from an iterable of cuts.

    :param cuts: a CutSet or iterable of cuts.
    :param executor: optional Executor; when provided, the reads go through ``executor.map``.
    :return: a list with one features tensor per cut, in the order of ``cuts``.
    """
    if executor is None:
        return [_read_features(cut) for cut in cuts]
    return list(executor.map(_read_features, cuts))
def _read_audio(
    cut: Cut, suppress_errors: bool = False, recording_field: Optional[str] = None
) -> Optional[torch.Tensor]:
    """
    Read the audio of a single cut as a torch tensor.

    Returns None when loading failed and ``suppress_errors`` was set to ``True``.
    Audio with a single channel is returned without the channel dimension.
    """
    with suppress_audio_loading_errors(enabled=suppress_errors):
        if recording_field is not None:
            attr = getattr(cut, recording_field)
            assert isinstance(
                attr, Recording
            ), f"Expected 'getattr(cut, {recording_field})' to yield Recording, got {type(attr)}"
            samples = cut.load_custom(recording_field)
        else:
            samples = cut.load_audio()

        single_channel = samples.shape[0] == 1
        if single_channel:
            # collapse channel dim if mono
            samples = samples.squeeze(0)
        return torch.from_numpy(samples)
def _read_features(cut: Cut) -> torch.Tensor:
    """Load the features of ``cut`` and wrap the resulting numpy array in a torch tensor."""
    feats = cut.load_features()
    return torch.from_numpy(feats)
def _read_video(
    cut: Cut,
    with_audio: bool = True,
    suppress_errors: bool = False,
    recording_field: Optional[str] = None,
) -> Optional[Tuple[torch.Tensor, Optional[torch.Tensor]]]:
    """
    Read the video (and, if requested, audio) data of a single cut.

    Returns None when loading failed and ``suppress_errors`` was set to ``True``.
    """
    with suppress_video_loading_errors(enabled=suppress_errors):
        if recording_field is not None:
            attr = getattr(cut, recording_field)
            assert isinstance(
                attr, Recording
            ), f"Expected 'getattr(cut, {recording_field})' to yield Recording, got {type(attr)}"
            return cut.load_custom(recording_field, with_audio=with_audio)
        return cut.load_video(with_audio=with_audio)
def collate_images(
    cuts: CutSet,
    image_field: str = "image",
) -> torch.Tensor:
    """
    Load the image stored under ``image_field`` in every cut and stack the images
    along a new leading batch dimension.

    The output image shape is ``(batch, height, width, channel)``.

    :param cuts: a :class:`CutSet` used to load the images.
    :param image_field: the field in the cut to load the images from.
    :return: tensor of collated images
    """
    loaded = []
    for cut in cuts:
        loaded.append(torch.as_tensor(cut.load_custom(image_field)))
    return torch.stack(loaded)