Source code for lhotse.dataset.sampling.cut_pairs
import warnings
from typing import Any, Dict, Optional, Tuple
from lhotse import CutSet, Seconds
from lhotse.dataset.sampling.base import CutSampler, TimeConstraint
from lhotse.dataset.sampling.data_source import DataSource
class CutPairsSampler(CutSampler):
"""
Samples pairs of cuts from a "source" and "target" CutSet.
    It expects that both CutSets consist strictly of cuts with corresponding IDs.
    It behaves like an iterable that yields pairs of ``CutSet``'s (source and target cuts).
    When one of :attr:`max_source_duration`, :attr:`max_target_duration`, or :attr:`max_cuts` is specified,
    the batch size is dynamic.
    These constraints may be combined; a batch is complete as soon as any of them would be exceeded.
    Padding required to collate the batch does not contribute to
    :attr:`max_source_duration`/:attr:`max_target_duration`.
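
    Example (a minimal sketch; ``source_cuts`` and ``target_cuts`` are assumed to be
    two ``CutSet``'s with identical IDs and ordering)::

        >>> sampler = CutPairsSampler(
        ...     source_cuts, target_cuts, max_source_duration=200.0, max_cuts=64
        ... )
        >>> for source_batch, target_batch in sampler:
        ...     pass  # each batch is a pair of CutSets; corresponding cuts share IDs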
"""
    def __init__(
self,
source_cuts: CutSet,
target_cuts: CutSet,
        max_source_duration: Optional[Seconds] = None,
        max_target_duration: Optional[Seconds] = None,
max_cuts: Optional[int] = None,
shuffle: bool = False,
drop_last: bool = False,
world_size: Optional[int] = None,
rank: Optional[int] = None,
seed: int = 0,
):
"""
CutPairsSampler's constructor.
:param source_cuts: the first ``CutSet`` to sample data from.
:param target_cuts: the second ``CutSet`` to sample data from.
:param max_source_duration: The maximum total recording duration from ``source_cuts``.
:param max_target_duration: The maximum total recording duration from ``target_cuts``.
:param max_cuts: The maximum number of cuts sampled to form a mini-batch.
By default, this constraint is off.
        :param shuffle: When ``True``, the cuts will be shuffled at the start of iteration.
            Convenient when the mini-batch loop is nested inside an outer epoch-level loop, e.g.
            ``for epoch in range(10): for batch in dataset: ...``, as every epoch will see a
            different order of cuts.
:param drop_last: When ``True``, the last batch is dropped if it's incomplete.
:param world_size: Total number of distributed nodes. We will try to infer it by default.
:param rank: Index of distributed node. We will try to infer it by default.
:param seed: Random seed used to consistently shuffle the dataset across different processes.
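
        Example of an epoch loop with shuffling (a sketch; ``set_epoch`` comes from
        the base ``CutSampler`` API and re-seeds the shuffling for each epoch)::

            >>> sampler = CutPairsSampler(source_cuts, target_cuts, max_cuts=32, shuffle=True)
            >>> for epoch in range(10):
            ...     sampler.set_epoch(epoch)
            ...     for source_batch, target_batch in sampler:
            ...         pass  # train on the paired mini-batches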
"""
super().__init__(
drop_last=drop_last,
shuffle=shuffle,
world_size=world_size,
rank=rank,
seed=seed,
)
self.source_cuts = DataSource(source_cuts)
self.target_cuts = DataSource(target_cuts)
# Constraints
self.source_constraints = TimeConstraint(
max_duration=max_source_duration,
max_cuts=max_cuts,
)
self.target_constraints = TimeConstraint(
max_duration=max_target_duration,
max_cuts=max_cuts,
)
@property
def remaining_duration(self) -> Optional[float]:
"""
Remaining duration of data left in the sampler (may be inexact due to float arithmetic).
Not available when the CutSet is read in lazy mode (returns None).
        .. note:: For :class:`.CutPairsSampler`, we return the source cuts' duration.
"""
return self.source_cuts.remaining_duration
@property
def remaining_cuts(self) -> Optional[int]:
"""
Remaining number of cuts in the sampler.
Not available when the CutSet is read in lazy mode (returns None).
"""
return self.source_cuts.remaining_cuts
@property
def num_cuts(self) -> Optional[int]:
"""
Total number of cuts in the sampler.
Not available when the CutSet is read in lazy mode (returns None).
"""
if self.source_cuts.is_lazy:
return None
return len(self.source_cuts)
    def state_dict(self) -> Dict[str, Any]:
"""
Return the current state of the sampler in a state_dict.
Together with ``load_state_dict()``, this can be used to restore the
training loop's state to the one stored in the state_dict.
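
        A minimal checkpointing sketch (assumes ``torch`` is imported; the file name
        is hypothetical)::

            >>> torch.save(sampler.state_dict(), "sampler_state.pt")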
"""
state_dict = super().state_dict()
state_dict.update(
{
"source_constraints": self.source_constraints.state_dict(),
"target_constraints": self.target_constraints.state_dict(),
}
)
return state_dict
    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
"""
Restore the state of the sampler that is described in a state_dict.
This will result in the sampler yielding batches from where the previous training left it off.
.. caution::
The samplers are expected to be initialized with the same CutSets,
but this is not explicitly checked anywhere.
        .. caution::
            The input ``state_dict`` is mutated: we remove each consumed key and expect
            it to be empty at the end of loading. If you don't want this behavior, pass a
            copy into this function (e.g., made with ``copy.deepcopy``).
        .. note::
            For implementers of sub-classes of CutSampler: the flag ``self._just_restored_state``
            has to be handled in ``__iter__`` so that it avoids resetting the just-restored state
            (but only once).
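
        A restoration sketch mirroring the :meth:`state_dict` example (the file name is
        hypothetical; the sampler should be constructed with the same CutSets)::

            >>> sampler = CutPairsSampler(source_cuts, target_cuts, max_cuts=64)
            >>> sampler.load_state_dict(torch.load("sampler_state.pt"))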
"""
source_constraints = TimeConstraint(**state_dict.pop("source_constraints"))
if self.source_constraints != source_constraints:
warnings.warn(
"CutPairsSampler.load_state_dict(): Inconsistent source_constraint:\n"
f"expected {self.source_constraints}\n"
f"received {source_constraints}\n"
f"We will overwrite the settings with the received state_dict."
)
self.source_constraints = source_constraints
target_constraints = TimeConstraint(**state_dict.pop("target_constraints"))
        if self.target_constraints != target_constraints:
warnings.warn(
"CutPairsSampler.load_state_dict(): Inconsistent target_constraint:\n"
f"expected {self.target_constraints}\n"
f"received {target_constraints}\n"
f"We will overwrite the settings with the received state_dict."
)
self.target_constraints = target_constraints
super().load_state_dict(state_dict)
# Restore the data source's state
if self.shuffle:
self.source_cuts.shuffle(self.seed + self.epoch)
self.target_cuts.shuffle(self.seed + self.epoch)
self.source_cuts.fast_forward(self.diagnostics.current_epoch_stats.total_cuts)
self.target_cuts.fast_forward(self.diagnostics.current_epoch_stats.total_cuts)
def __iter__(self) -> "CutPairsSampler":
"""
Prepare the dataset for iterating over a new epoch. Will shuffle the data if requested.
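
        A manual iteration sketch (``__next__`` is provided by the base ``CutSampler``
        and calls :meth:`_next_batch`)::

            >>> it = iter(sampler)
            >>> source_batch, target_batch = next(it)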
"""
# Restored state with load_state_dict()? Skip resetting only this once.
if self._just_restored_state:
return self
# Why reset the current epoch?
# Either we are iterating the epoch for the first time and it's a no-op,
# or we are iterating the same epoch again, in which case setting more steps
# than are actually available per epoch would have broken the checkpoint restoration.
self.diagnostics.reset_current_epoch()
# Reset the state to the beginning of the epoch.
if self.shuffle:
self.source_cuts.shuffle(self.seed + self.epoch)
self.target_cuts.shuffle(self.seed + self.epoch)
iter(self.source_cuts)
iter(self.target_cuts)
return self
def _next_batch(self) -> Tuple[CutSet, CutSet]:
        # Keep iterating the underlying CutSets until we hit or exceed the constraints
        # provided by the user (the max source/target duration or the max number of cuts).
# Note: no actual data is loaded into memory yet because the manifests contain all the metadata
# required to do this operation.
self.source_constraints.reset()
self.target_constraints.reset()
source_cuts = []
target_cuts = []
while True:
# Check that we have not reached the end of the dataset.
try:
# We didn't - grab the next cut
next_source_cut = next(self.source_cuts)
next_target_cut = next(self.target_cuts)
                assert next_source_cut.id == next_target_cut.id, (
                    "Sampled source and target cuts with differing IDs. "
                    "Ensure that your source and target CutSets have the same number "
                    "of cuts, with matching IDs in the same order."
                )
except StopIteration:
# No more cuts to sample from: if we have a partial batch,
# we may output it, unless the user requested to drop it.
# We also check if the batch is "almost there" to override drop_last.
if source_cuts and (
not self.drop_last
or self.source_constraints.close_to_exceeding()
or self.target_constraints.close_to_exceeding()
):
# We have a partial batch and we can return it.
assert len(source_cuts) == len(
target_cuts
), "Unexpected state: some cuts in source / target are missing their counterparts..."
return CutSet.from_cuts(source_cuts), CutSet.from_cuts(target_cuts)
else:
# There is nothing more to return or it's discarded:
# signal the iteration code to stop.
self.diagnostics.discard(source_cuts)
raise StopIteration()
# Check whether the cuts we're about to sample satisfy optional user-requested predicate.
if not self._filter_fn(next_source_cut) or not self._filter_fn(
next_target_cut
):
# No - try another one.
self.diagnostics.discard_single(next_source_cut)
continue
self.source_constraints.add(next_source_cut)
self.target_constraints.add(next_target_cut)
            # Did we exceed any of the max_duration or max_cuts constraints?
if (
not self.source_constraints.exceeded()
and not self.target_constraints.exceeded()
):
# No - add the next cut to the batch, and keep trying.
source_cuts.append(next_source_cut)
target_cuts.append(next_target_cut)
else:
# Yes. Do we have at least one cut in the batch?
if source_cuts:
# Yes. Return it.
self.source_cuts.take_back(next_source_cut)
self.target_cuts.take_back(next_target_cut)
break
else:
                    # No. We'll warn the user that the constraints might be too tight,
                    # and return the cut anyway.
                    warnings.warn(
                        "The first cut drawn in batch collection violates one of the "
                        "max_... constraints; we'll return it anyway. "
                        "Consider increasing max_source_duration/max_cuts/etc."
                    )
source_cuts.append(next_source_cut)
target_cuts.append(next_target_cut)
assert len(source_cuts) == len(
target_cuts
), "Unexpected state: some cuts in source / target are missing their counterparts..."
return CutSet.from_cuts(source_cuts), CutSet.from_cuts(target_cuts)