import logging
from collections import defaultdict
from dataclasses import dataclass
from itertools import groupby, islice
from typing import (
Any,
Callable,
Dict,
Iterable,
List,
Mapping,
NamedTuple,
Optional,
Union,
)
from tqdm import tqdm
from lhotse.custom import CustomFieldMixin
from lhotse.lazy import AlgorithmMixin
from lhotse.serialization import Serializable
from lhotse.utils import (
Pathlike,
Seconds,
TimeSpan,
add_durations,
asdict_nonull,
compute_num_samples,
exactly_one_not_null,
fastcopy,
ifnone,
is_equal_or_contains,
overspans,
perturb_num_samples,
split_manifest_lazy,
split_sequence,
)
[docs]
class AlignmentItem(NamedTuple):
"""
This class contains an alignment item, for example a word, along with its
start time (w.r.t. the start of recording) and duration. It can potentially
be used to store other kinds of alignment items, such as subwords, pdfid's etc.
"""
symbol: str
start: Seconds
duration: Seconds
# Score is an optional aligner-specific measure of confidence.
# A simple measure can be an average probability of "symbol" across
# frames covered by the AlignmentItem.
score: Optional[float] = None
[docs]
@staticmethod
def deserialize(data: Union[List, Dict]) -> "AlignmentItem":
if isinstance(data, dict):
# Support loading alignments stored in the format we had before Lhotse v1.8
return AlignmentItem(*list(data.values()))
return AlignmentItem(*data)
[docs]
def serialize(self) -> list:
return list(self)
@property
def end(self) -> Seconds:
return round(self.start + self.duration, ndigits=8)
[docs]
def with_offset(self, offset: Seconds) -> "AlignmentItem":
"""Return an identical ``AlignmentItem``, but with the ``offset`` added to the ``start`` field."""
return AlignmentItem(
start=add_durations(self.start, offset, sampling_rate=48000),
duration=self.duration,
symbol=self.symbol,
score=self.score,
)
[docs]
def perturb_speed(self, factor: float, sampling_rate: int) -> "AlignmentItem":
"""
Return an ``AlignmentItem`` that has time boundaries matching the
recording/cut perturbed with the same factor. See :meth:`SupervisionSegment.perturb_speed`
for details.
"""
start_sample = compute_num_samples(self.start, sampling_rate)
num_samples = compute_num_samples(self.duration, sampling_rate)
new_start = perturb_num_samples(start_sample, factor) / sampling_rate
new_duration = perturb_num_samples(num_samples, factor) / sampling_rate
return AlignmentItem(
symbol=self.symbol, start=new_start, duration=new_duration, score=self.score
)
[docs]
def trim(self, end: Seconds, start: Seconds = 0) -> "AlignmentItem":
"""
See :meth:`SupervisionSegment.trim`.
"""
assert start >= 0
start_exceeds_by = abs(min(0, self.start - start))
end_exceeds_by = max(0, self.end - end)
return AlignmentItem(
symbol=self.symbol,
start=max(start, self.start),
duration=add_durations(
self.duration, -end_exceeds_by, -start_exceeds_by, sampling_rate=48000
),
)
[docs]
@dataclass
class SupervisionSegment(CustomFieldMixin):
"""
:class:`~lhotse.supervsion.SupervisionSegment` represents a time interval (segment) annotated with some
supervision labels and/or metadata, such as the transcription, the speaker identity, the language, etc.
Each supervision has unique ``id`` and always refers to a specific recording (via ``recording_id``)
and one or more ``channel`` (by default, 0). Note that multiple channels of the recording
may share the same supervision, in which case the ``channel`` field will be a list of integers.
It's also characterized by the start time (relative to the beginning of a :class:`~lhotse.audio.Recording`
or a :class:`~lhotse.cut.Cut`) and a duration, both expressed in seconds.
The remaining fields are all optional, and their availability depends on specific corpora.
Since it is difficult to predict all possible types of metadata, the ``custom`` field (a dict) can be used to
insert types of supervisions that are not supported out of the box.
:class:`~lhotse.supervsion.SupervisionSegment` may contain multiple types of alignments.
The ``alignment`` field is a dict, indexed by alignment's type (e.g., ``word`` or ``phone``),
and contains a list of :class:`~lhotse.supervision.AlignmentItem` objects -- simple structures
that contain a given symbol and its time interval.
Alignments can be read from CTM files or created programatically.
Examples
A simple segment with no supervision information::
>>> from lhotse import SupervisionSegment
>>> sup0 = SupervisionSegment(
... id='rec00001-sup00000', recording_id='rec00001',
... start=0.5, duration=5.0, channel=0
... )
Typical supervision containing transcript, speaker ID, gender, and language::
>>> sup1 = SupervisionSegment(
... id='rec00001-sup00001', recording_id='rec00001',
... start=5.5, duration=3.0, channel=0,
... text='transcript of the second segment',
... speaker='Norman Dyhrentfurth', language='English', gender='M'
... )
Two supervisions denoting overlapping speech on two separate channels in a microphone array/multiple headsets
(pay attention to ``start``, ``duration``, and ``channel``)::
>>> sup2 = SupervisionSegment(
... id='rec00001-sup00002', recording_id='rec00001',
... start=15.0, duration=5.0, channel=0,
... text="i have incredibly good news for you",
... speaker='Norman Dyhrentfurth', language='English', gender='M'
... )
>>> sup3 = SupervisionSegment(
... id='rec00001-sup00003', recording_id='rec00001',
... start=18.0, duration=3.0, channel=1,
... text="say what",
... speaker='Hervey Arman', language='English', gender='M'
... )
A supervision with a phone alignment::
>>> from lhotse.supervision import AlignmentItem
>>> sup4 = SupervisionSegment(
... id='rec00001-sup00004', recording_id='rec00001',
... start=33.0, duration=1.0, channel=0,
... text="ice",
... speaker='Maryla Zechariah', language='English', gender='F',
... alignment={
... 'phone': [
... AlignmentItem(symbol='AY0', start=33.0, duration=0.6),
... AlignmentItem(symbol='S', start=33.6, duration=0.4)
... ]
... }
... )
A supervision shared across multiple channels of a recording (e.g. a microphone array)::
>>> sup5 = SupervisionSegment(
... id='rec00001-sup00005', recording_id='rec00001',
... start=33.0, duration=1.0, channel=[0, 1],
... text="ice",
... speaker='Maryla Zechariah',
... )
Converting :class:`~lhotse.supervsion.SupervisionSegment` to a ``dict``::
>>> sup0.to_dict()
{'id': 'rec00001-sup00000', 'recording_id': 'rec00001', 'start': 0.5, 'duration': 5.0, 'channel': 0}
"""
id: str
recording_id: str
start: Seconds
duration: Seconds
channel: Union[int, List[int]] = 0
text: Optional[str] = None
language: Optional[str] = None
speaker: Optional[str] = None
gender: Optional[str] = None
custom: Optional[Dict[str, Any]] = None
alignment: Optional[Dict[str, List[AlignmentItem]]] = None
@property
def end(self) -> Seconds:
return round(self.start + self.duration, ndigits=8)
[docs]
def with_alignment(
self, kind: str, alignment: List[AlignmentItem]
) -> "SupervisionSegment":
alis = self.alignment
if alis is None:
alis = {}
alis[kind] = alignment
return fastcopy(self, alignment=alis)
[docs]
def with_offset(self, offset: Seconds) -> "SupervisionSegment":
"""Return an identical ``SupervisionSegment``, but with the ``offset`` added to the ``start`` field."""
return SupervisionSegment(
id=self.id,
recording_id=self.recording_id,
start=round(self.start + offset, ndigits=8),
duration=self.duration,
channel=self.channel,
text=self.text,
language=self.language,
speaker=self.speaker,
gender=self.gender,
custom=self.custom,
alignment=self.alignment,
)
[docs]
def perturb_speed(
self, factor: float, sampling_rate: int, affix_id: bool = True
) -> "SupervisionSegment":
"""
Return a ``SupervisionSegment`` that has time boundaries matching the
recording/cut perturbed with the same factor.
:param factor: The speed will be adjusted this many times (e.g. factor=1.1 means 1.1x faster).
:param sampling_rate: The sampling rate is necessary to accurately perturb the start
and duration (going through the sample counts).
:param affix_id: When true, we will modify the ``id`` and ``recording_id`` fields
by affixing it with "_sp{factor}".
:return: a modified copy of the current ``SupervisionSegment``.
"""
start_sample = compute_num_samples(self.start, sampling_rate)
num_samples = compute_num_samples(self.duration, sampling_rate)
new_start = perturb_num_samples(start_sample, factor) / sampling_rate
new_duration = perturb_num_samples(num_samples, factor) / sampling_rate
return fastcopy(
self,
id=f"{self.id}_sp{factor}" if affix_id else self.id,
recording_id=f"{self.recording_id}_sp{factor}"
if affix_id
else self.recording_id,
start=new_start,
duration=new_duration,
alignment={
type: [
item.perturb_speed(factor=factor, sampling_rate=sampling_rate)
for item in ali
]
for type, ali in self.alignment.items()
}
if self.alignment
else None,
)
[docs]
def perturb_tempo(
self, factor: float, sampling_rate: int, affix_id: bool = True
) -> "SupervisionSegment":
"""
Return a ``SupervisionSegment`` that has time boundaries matching the
recording/cut perturbed with the same factor.
:param factor: The tempo will be adjusted this many times (e.g. factor=1.1 means 1.1x faster).
:param sampling_rate: The sampling rate is necessary to accurately perturb the start
and duration (going through the sample counts).
:param affix_id: When true, we will modify the ``id`` and ``recording_id`` fields
by affixing it with "_tp{factor}".
:return: a modified copy of the current ``SupervisionSegment``.
"""
# speed and tempo perturbation have the same effect on supervisions
perturbed = self.perturb_speed(factor, sampling_rate, affix_id=False)
return fastcopy(
perturbed,
id=f"{self.id}_tp{factor}" if affix_id else self.id,
recording_id=f"{self.recording_id}_tp{factor}"
if affix_id
else self.recording_id,
)
[docs]
def perturb_volume(
self, factor: float, affix_id: bool = True
) -> "SupervisionSegment":
"""
Return a ``SupervisionSegment`` with modified ids.
:param factor: The volume will be adjusted this many times (e.g. factor=1.1 means 1.1x louder).
:param affix_id: When true, we will modify the ``id`` and ``recording_id`` fields
by affixing it with "_vp{factor}".
:return: a modified copy of the current ``SupervisionSegment``.
"""
return fastcopy(
self,
id=f"{self.id}_vp{factor}" if affix_id else self.id,
recording_id=f"{self.recording_id}_vp{factor}"
if affix_id
else self.recording_id,
)
[docs]
def reverb_rir(
self, affix_id: bool = True, channel: Optional[Union[int, List[int]]] = None
) -> "SupervisionSegment":
"""
Return a ``SupervisionSegment`` with modified ids.
:param affix_id: When true, we will modify the ``id`` and ``recording_id`` fields
by affixing it with "_rvb".
:return: a modified copy of the current ``SupervisionSegment``.
"""
return fastcopy(
self,
id=f"{self.id}_rvb" if affix_id else self.id,
recording_id=f"{self.recording_id}_rvb" if affix_id else self.recording_id,
channel=channel if channel is not None else self.channel,
)
[docs]
def trim(self, end: Seconds, start: Seconds = 0) -> "SupervisionSegment":
"""
Return an identical ``SupervisionSegment``, but ensure that ``self.start`` is not negative (in which case
it's set to 0) and ``self.end`` does not exceed the ``end`` parameter. If a `start` is optionally
provided, the supervision is trimmed from the left (note that start should be relative to the cut times).
This method is useful for ensuring that the supervision does not exceed a cut's bounds,
in which case pass ``cut.duration`` as the ``end`` argument, since supervision times are relative to the cut.
"""
assert start >= 0
start_exceeds_by = abs(min(0, self.start - start))
end_exceeds_by = max(0, self.end - end)
return fastcopy(
self,
start=max(start, self.start),
duration=add_durations(
self.duration, -end_exceeds_by, -start_exceeds_by, sampling_rate=48000
),
alignment={
type: [item.trim(end=end, start=start) for item in ali]
for type, ali in self.alignment.items()
}
if self.alignment
else None,
)
[docs]
def map(
self, transform_fn: Callable[["SupervisionSegment"], "SupervisionSegment"]
) -> "SupervisionSegment":
"""
Return a copy of the current segment, transformed with ``transform_fn``.
:param transform_fn: a function that takes a segment as input, transforms it and returns a new segment.
:return: a modified ``SupervisionSegment``.
"""
return transform_fn(self)
[docs]
def transform_text(
self, transform_fn: Callable[[str], str]
) -> "SupervisionSegment":
"""
Return a copy of the current segment with transformed ``text`` field.
Useful for text normalization, phonetic transcription, etc.
:param transform_fn: a function that accepts a string and returns a string.
:return: a ``SupervisionSegment`` with adjusted text.
"""
if self.text is None:
return self
return fastcopy(self, text=transform_fn(self.text))
[docs]
def to_dict(self) -> dict:
if self.alignment is None:
return asdict_nonull(self)
else:
alis = {
kind: [item.serialize() for item in ali]
for kind, ali in self.alignment.items()
}
data = asdict_nonull(fastcopy(self, alignment=None))
data["alignment"] = alis
return data
[docs]
@staticmethod
def from_dict(data: dict) -> "SupervisionSegment":
from lhotse.serialization import deserialize_custom_field
if "custom" in data:
deserialize_custom_field(data["custom"])
if "alignment" in data:
data["alignment"] = {
k: [AlignmentItem.deserialize(x) for x in v]
for k, v in data["alignment"].items()
}
return SupervisionSegment(**data)
[docs]
class SupervisionSet(Serializable, AlgorithmMixin):
"""
:class:`~lhotse.supervision.SupervisionSet` represents a collection of segments containing some
supervision information (see :class:`~lhotse.supervision.SupervisionSegment`).
It acts as a Python ``list``, extended with an efficient ``find`` operation that indexes and caches
the supervision segments in an interval tree.
It allows to quickly find supervision segments that correspond to a specific time interval.
However, it can also work with lazy iterables.
When coming from Kaldi, think of :class:`~lhotse.supervision.SupervisionSet` as a ``segments`` file on steroids,
that may also contain *text*, *utt2spk*, *utt2gender*, *utt2dur*, etc.
Examples
Building a :class:`~lhotse.supervision.SupervisionSet`::
>>> from lhotse import SupervisionSet, SupervisionSegment
>>> sups = SupervisionSet.from_segments([SupervisionSegment(...), ...])
Writing/reading a :class:`~lhotse.supervision.SupervisionSet`::
>>> sups.to_file('supervisions.jsonl.gz')
>>> sups2 = SupervisionSet.from_file('supervisions.jsonl.gz')
Using :class:`~lhotse.supervision.SupervisionSet` like a dict::
>>> 'rec00001-sup00000' in sups
True
>>> sups['rec00001-sup00000']
SupervisionSegment(id='rec00001-sup00000', recording_id='rec00001', start=0.5, ...)
>>> for segment in sups:
... pass
Searching by ``recording_id`` and time interval::
>>> matched_segments = sups.find(recording_id='rec00001', start_after=17.0, end_before=25.0)
Manipulation::
>>> longer_than_5s = sups.filter(lambda s: s.duration > 5)
>>> first_100 = sups.subset(first=100)
>>> split_into_4 = sups.split(num_splits=4)
>>> shuffled = sups.shuffle()
"""
[docs]
def __init__(self, segments: Optional[Iterable[SupervisionSegment]] = None) -> None:
self.segments = ifnone(segments, {})
def __eq__(self, other: "SupervisionSet") -> bool:
return self.segments == other.segments
@property
def data(
self,
) -> Union[Dict[str, SupervisionSegment], Iterable[SupervisionSegment]]:
"""Alias property for ``self.segments``"""
return self.segments
@property
def ids(self) -> Iterable[str]:
return (s.id for s in self)
[docs]
@staticmethod
def from_segments(segments: Iterable[SupervisionSegment]) -> "SupervisionSet":
return SupervisionSet(list(segments))
from_items = from_segments
[docs]
@staticmethod
def from_dicts(data: Iterable[Dict]) -> "SupervisionSet":
return SupervisionSet.from_segments(
SupervisionSegment.from_dict(s) for s in data
)
[docs]
@staticmethod
def from_rttm(path: Union[Pathlike, Iterable[Pathlike]]) -> "SupervisionSet":
"""
Read an RTTM file located at ``path`` (or an iterator) and create a :class:`.SupervisionSet` manifest for them.
Can be used to create supervisions from custom RTTM files (see, for example, :class:`lhotse.dataset.DiarizationDataset`).
.. code:: python
>>> from lhotse import SupervisionSet
>>> sup1 = SupervisionSet.from_rttm('/path/to/rttm_file')
>>> sup2 = SupervisionSet.from_rttm(Path('/path/to/rttm_dir').rglob('ref_*'))
The following description is taken from the [dscore](https://github.com/nryant/dscore#rttm) toolkit:
Rich Transcription Time Marked (RTTM) files are space-delimited text files
containing one turn per line, each line containing ten fields:
- ``Type`` -- segment type; should always by ``SPEAKER``
- ``File ID`` -- file name; basename of the recording minus extension (e.g.,
``rec1_a``)
- ``Channel ID`` -- channel (1-indexed) that turn is on; should always be
``1``
- ``Turn Onset`` -- onset of turn in seconds from beginning of recording
- ``Turn Duration`` -- duration of turn in seconds
- ``Orthography Field`` -- should always by ``<NA>``
- ``Speaker Type`` -- should always be ``<NA>``
- ``Speaker Name`` -- name of speaker of turn; should be unique within scope
of each file
- ``Confidence Score`` -- system confidence (probability) that information
is correct; should always be ``<NA>``
- ``Signal Lookahead Time`` -- should always be ``<NA>``
For instance:
SPEAKER CMU_20020319-1400_d01_NONE 1 130.430000 2.350 <NA> <NA> juliet <NA> <NA>
SPEAKER CMU_20020319-1400_d01_NONE 1 157.610000 3.060 <NA> <NA> tbc <NA> <NA>
SPEAKER CMU_20020319-1400_d01_NONE 1 130.490000 0.450 <NA> <NA> chek <NA> <NA>
:param path: Path to RTTM file or an iterator of paths to RTTM files.
:return: a new ``SupervisionSet`` instance containing segments from the RTTM file.
"""
from pathlib import Path
path = [path] if isinstance(path, (Path, str)) else path
segments = []
for file in path:
with open(file, "r") as f:
for idx, line in enumerate(f):
parts = line.strip().split()
assert len(parts) == 10, f"Invalid RTTM line in file {file}: {line}"
recording_id = parts[1]
if float(parts[4]) == 0: # skip empty segments
continue
segments.append(
SupervisionSegment(
id=f"{recording_id}-{idx:06d}",
recording_id=recording_id,
channel=int(parts[2]),
start=float(parts[3]),
duration=float(parts[4]),
speaker=parts[7],
)
)
return SupervisionSet.from_segments(segments)
[docs]
def with_alignment_from_ctm(
self,
ctm_file: Pathlike,
type: str = "word",
match_channel: bool = False,
verbose: bool = False,
) -> "SupervisionSet":
"""
Add alignments from CTM file to the supervision set.
:param ctm: Path to CTM file.
:param type: Alignment type (optional, default = `word`).
:param match_channel: if True, also match channel between CTM and SupervisionSegment
:param verbose: if True, show progress bar
:return: A new SupervisionSet with AlignmentItem objects added to the segments.
"""
ctm_words = []
# Sometimes the channels may not be integers, so we map them here.
channel_to_int = {}
with open(ctm_file) as f:
f = tqdm(f, desc="Reading words from CTM file") if verbose else f
for line in f:
reco_id, channel, start, duration, symbol, *score = line.strip().split()
ctm_words.append(
(
reco_id,
int(channel),
float(start),
float(duration),
symbol,
float(score[0]) if score else None,
)
)
ctm_words = sorted(ctm_words, key=lambda x: (x[0], x[2]))
reco_to_ctm = defaultdict(
list, {k: list(v) for k, v in groupby(ctm_words, key=lambda x: x[0])}
)
segments = []
num_total = len(ctm_words)
num_overspanned = 0
recordings = set([s.recording_id for s in self])
recordings = (
tqdm(recordings, desc="Adding alignments") if verbose else recordings
)
for reco_id in recordings:
if reco_id in reco_to_ctm:
for seg in self.find(recording_id=reco_id):
alignment = [
AlignmentItem(
symbol=word[4],
start=word[2],
duration=word[3],
score=word[5],
)
for word in reco_to_ctm[reco_id]
if overspans(seg, TimeSpan(word[2], word[2] + word[3]))
and (seg.channel == word[1] or not match_channel)
]
num_overspanned += len(alignment)
segments.append(fastcopy(seg, alignment={type: alignment}))
else:
segments.extend(
[
fastcopy(s, alignment={type: []})
for s in self.find(recording_id=reco_id)
]
)
logging.info(
f"{num_overspanned} alignments added out of {num_total} total. If there are several"
" missing, there could be a mismatch problem."
)
return SupervisionSet.from_segments(segments)
[docs]
def write_alignment_to_ctm(self, ctm_file: Pathlike, type: str = "word") -> None:
"""
Write alignments to CTM file.
:param ctm_file: Path to output CTM file (will be created if not exists)
:param type: Alignment type to write (default = `word`)
"""
with open(ctm_file, "w") as f:
for s in self:
if type in s.alignment:
for ali in s.alignment[type]:
c = s.channel[0] if isinstance(s.channel, list) else s.channel
if ali.score is None:
f.write(
f"{s.recording_id} {c} {ali.start:.02f} {ali.duration:.02f} {ali.symbol}\n"
)
else:
f.write(
f"{s.recording_id} {c} {ali.start:.02f} {ali.duration:.02f} {ali.symbol} {ali.score:.02f}\n"
)
[docs]
def to_dicts(self) -> Iterable[dict]:
return (s.to_dict() for s in self)
[docs]
def split(
self, num_splits: int, shuffle: bool = False, drop_last: bool = False
) -> List["SupervisionSet"]:
"""
Split the :class:`~lhotse.SupervisionSet` into ``num_splits`` pieces of equal size.
:param num_splits: Requested number of splits.
:param shuffle: Optionally shuffle the recordings order first.
:param drop_last: determines how to handle splitting when ``len(seq)`` is not divisible
by ``num_splits``. When ``False`` (default), the splits might have unequal lengths.
When ``True``, it may discard the last element in some splits to ensure they are
equally long.
:return: A list of :class:`~lhotse.SupervisionSet` pieces.
"""
return [
SupervisionSet.from_segments(subset)
for subset in split_sequence(
self, num_splits=num_splits, shuffle=shuffle, drop_last=drop_last
)
]
[docs]
def split_lazy(
self, output_dir: Pathlike, chunk_size: int, prefix: str = ""
) -> List["SupervisionSet"]:
"""
Splits a manifest (either lazily or eagerly opened) into chunks, each
with ``chunk_size`` items (except for the last one, typically).
In order to be memory efficient, this implementation saves each chunk
to disk in a ``.jsonl.gz`` format as the input manifest is sampled.
.. note:: For lowest memory usage, use ``load_manifest_lazy`` to open the
input manifest for this method.
:param it: any iterable of Lhotse manifests.
:param output_dir: directory where the split manifests are saved.
Each manifest is saved at: ``{output_dir}/{prefix}.{split_idx}.jsonl.gz``
:param chunk_size: the number of items in each chunk.
:param prefix: the prefix of each manifest.
:return: a list of lazily opened chunk manifests.
"""
return split_manifest_lazy(
self, output_dir=output_dir, chunk_size=chunk_size, prefix=prefix
)
[docs]
def subset(
self, first: Optional[int] = None, last: Optional[int] = None
) -> "SupervisionSet":
"""
Return a new ``SupervisionSet`` according to the selected subset criterion.
Only a single argument to ``subset`` is supported at this time.
:param first: int, the number of first supervisions to keep.
:param last: int, the number of last supervisions to keep.
:return: a new ``SupervisionSet`` with the subset results.
"""
assert exactly_one_not_null(
first, last
), "subset() can handle only one non-None arg."
if first is not None:
assert first > 0
out = SupervisionSet.from_items(islice(self, first))
return out
if last is not None:
assert last > 0
if last > len(self):
return self
return SupervisionSet.from_segments(
islice(self, len(self) - last, len(self))
)
[docs]
def transform_text(self, transform_fn: Callable[[str], str]) -> "SupervisionSet":
"""
Return a copy of the current ``SupervisionSet`` with the segments having a transformed ``text`` field.
Useful for text normalization, phonetic transcription, etc.
:param transform_fn: a function that accepts a string and returns a string.
:return: a ``SupervisionSet`` with adjusted text.
"""
return SupervisionSet.from_segments(
s.transform_text(transform_fn) for s in self
)
[docs]
def find(
self,
recording_id: str,
channel: Optional[int] = None,
start_after: Seconds = 0,
end_before: Optional[Seconds] = None,
adjust_offset: bool = False,
tolerance: Seconds = 0.001,
) -> Iterable[SupervisionSegment]:
"""
Return an iterable of segments that match the provided ``recording_id``.
:param recording_id: Desired recording ID.
:param channel: When specified, return supervisions in that channel - otherwise, in all channels.
:param start_after: When specified, return segments that start after the given value.
:param end_before: When specified, return segments that end before the given value.
:param adjust_offset: When true, return segments as if the recordings had started at ``start_after``.
This is useful for creating Cuts. From a user perspective, when dealing with a Cut, it is no
longer helpful to know when the supervisions starts in a recording - instead, it's useful to
know when the supervision starts relative to the start of the Cut.
In the anticipated use-case, ``start_after`` and ``end_before`` would be
the beginning and end of a cut;
this option converts the times to be relative to the start of the cut.
:param tolerance: Additional margin to account for floating point rounding errors
when comparing segment boundaries.
:return: An iterator over supervision segments satisfying all criteria.
"""
segment_by_recording_id = self._index_by_recording_id_and_cache()
return (
# We only modify the offset - the duration remains the same, as we're only shifting the segment
# relative to the Cut's start, and not truncating anything.
segment.with_offset(-start_after) if adjust_offset else segment
for segment in segment_by_recording_id.get(recording_id, [])
if (channel is None or is_equal_or_contains(segment.channel, channel))
and segment.start >= start_after - tolerance
and (end_before is None or segment.end <= end_before + tolerance)
)
# This is a cache that significantly speeds up repeated ``find()`` queries.
_segments_by_recording_id: Optional[Dict[str, List[SupervisionSegment]]] = None
def _index_by_recording_id_and_cache(self):
if self._segments_by_recording_id is None:
from cytoolz import groupby
self._segments_by_recording_id = groupby(lambda seg: seg.recording_id, self)
return self._segments_by_recording_id
def __repr__(self) -> str:
return f"SupervisionSet(len={len(self)})"
def __getitem__(self, index_or_id: Union[int, str]) -> SupervisionSegment:
try:
return self.segments[index_or_id] # int passed, eager manifest, fast
except TypeError:
# either lazy manifest or str passed, both are slow
if self.is_lazy:
return next(item for idx, item in enumerate(self) if idx == index_or_id)
else:
# string id passed, support just for backward compatibility, not recommended
return next(item for item in self if item.id == index_or_id)
def __contains__(self, other: Union[str, SupervisionSegment]) -> bool:
if isinstance(other, str):
return any(other == item.id for item in self)
else:
return any(other.id == item.id for item in self)
def __iter__(self) -> Iterable[SupervisionSegment]:
yield from self.segments
def __len__(self) -> int:
return len(self.segments)