import warnings
from typing import Any, Dict, Optional
import torch
from lhotse import validate
from lhotse.audio import AudioLoadingError, DurationMismatchError
from lhotse.augmentation import AugmentFn
from lhotse.cut import CutSet
from lhotse.dataset.collation import collate_audio, collate_features, collate_matrices
from lhotse.features import FeatureExtractor
from lhotse.utils import NonPositiveEnergyError, suppress_and_warn
class UnsupervisedDataset(torch.utils.data.Dataset):
    """
    Dataset that contains no supervision - it only provides the features extracted from recordings.

    Every cut in a batch must already have features attached (checked in ``_validate``).

    .. code-block::

        {
            'cuts': CutSet (the input cuts, returned unchanged)
            'features': (B x T x F) tensor
            'features_lens': (B, ) tensor
        }
    """

    def __init__(self) -> None:
        super().__init__()

    def __getitem__(self, cuts: CutSet) -> Dict[str, Any]:
        """
        Collate the pre-computed features of ``cuts`` into a single batch.

        :param cuts: the cuts that make up one mini-batch.
        :return: a dict with the input ``cuts`` and the ``features`` / ``features_lens``
            pair produced by ``collate_features``.
        :raises AssertionError: if any cut has no features attached.
        """
        self._validate(cuts)
        features, features_lens = collate_features(cuts)
        return {
            "cuts": cuts,
            "features": features,
            "features_lens": features_lens,
        }

    def _validate(self, cuts: CutSet) -> None:
        """Run Lhotse's ``validate`` on ``cuts`` and check that every cut has features."""
        validate(cuts)
        for cut in cuts:
            # Stays an ``assert`` (AssertionError) to keep the original contract,
            # but now reports which cut is missing its features.
            assert cut.has_features, (
                f"UnsupervisedDataset requires pre-computed features, "
                f"but cut '{cut.id}' has none."
            )
class DynamicUnsupervisedDataset(UnsupervisedDataset):
    """
    An example dataset that shows how to use on-the-fly feature extraction in Lhotse.

    It accepts two additional inputs - a ``FeatureExtractor`` and an optional ``AugmentFn``
    for time-domain data augmentation.
    The features are approximately the same as those of the ``UnsupervisedDataset`` -
    there might be slight differences for ``MixedCut``s, because this dataset mixes them in the time domain,
    and ``UnsupervisedDataset`` does that in the feature domain.
    Cuts that are not mixed will yield identical results in both dataset classes.

    Unlike ``UnsupervisedDataset``, ``__getitem__`` returns only the collated feature
    tensor - there are no ``'cuts'`` or ``'features_lens'`` entries.
    Cuts for which feature computation raises ``AudioLoadingError``,
    ``DurationMismatchError`` or ``NonPositiveEnergyError`` are skipped (a warning is
    issued via ``suppress_and_warn``), so the result may hold fewer entries than there
    are input cuts.
    """

    def __init__(
        self,
        feature_extractor: FeatureExtractor,
        augment_fn: Optional[AugmentFn] = None,
    ) -> None:
        """
        :param feature_extractor: extractor passed to ``cut.compute_features`` for every cut.
        :param augment_fn: optional augmentation function, passed through to
            ``cut.compute_features`` unchanged.
        """
        super().__init__()
        self.feature_extractor = feature_extractor
        self.augment_fn = augment_fn

    def __getitem__(self, cuts: CutSet) -> torch.Tensor:
        """
        Compute features for ``cuts`` on the fly and collate them into one tensor.

        :param cuts: the cuts that make up one mini-batch.
        :return: the tensor built by ``collate_matrices`` from the successfully computed
            features, in the order of ``cuts`` minus any skipped cut.
        :raises AssertionError: if any cut has no recording.
        """
        self._validate(cuts)
        features = []
        for cut in cuts:
            # A recoverable per-cut failure drops only that cut (with a warning)
            # instead of failing the whole batch.
            with suppress_and_warn(
                AudioLoadingError, DurationMismatchError, NonPositiveEnergyError
            ):
                features.append(
                    cut.compute_features(
                        extractor=self.feature_extractor,
                        augment_fn=self.augment_fn,
                    )
                )
        # NOTE(review): if every cut was skipped, ``features`` is empty; how
        # collate_matrices handles that is not visible here - confirm.
        return collate_matrices(features)

    def _validate(self, cuts: CutSet) -> None:
        """Run Lhotse's ``validate`` on ``cuts`` and check that every cut has a recording."""
        validate(cuts)
        for cut in cuts:
            # Stays an ``assert`` (AssertionError) to keep the original contract,
            # but now reports which cut is missing its recording.
            assert cut.has_recording, (
                f"DynamicUnsupervisedDataset requires audio recordings, "
                f"but cut '{cut.id}' has none."
            )