import logging
import random
from pathlib import Path
from typing import (
Callable,
Dict,
Generator,
Iterable,
List,
Optional,
Sequence,
Tuple,
Union,
)
import torch
from cytoolz import compose_left
from lhotse import CutSet, Seconds
from lhotse.cut.set import deserialize_cut
from lhotse.dataset.dataloading import get_rank, get_world_size
from lhotse.dataset.sampling.base import SamplingDiagnostics
from lhotse.lazy import Dillable
from lhotse.serialization import decode_json_line
from lhotse.utils import Pathlike
PathlikeAndScale = Tuple[Pathlike, float]
class StatelessSampler(torch.utils.data.Sampler, Dillable):
"""
An infinite and stateless cut sampler that selects data at random from one or more cut manifests.
    The main idea is to make training resumption easy while guaranteeing that the data seen
    by the model is shuffled differently each time it is iterated.
    It discards the notion of an "epoch" and never finishes iteration.
    It makes no strong guarantees about avoiding data duplication, but in practice you would
    rarely see duplicated data.
    The recommended way to use this sampler is to place it in a dataloader worker
    subprocess with Lhotse's :class:`~lhotse.dataset.iterable_dataset.IterableDatasetWrapper`,
    so that each worker has its own sampler replica that uses a slightly different random seed::
>>> import torch
>>> import lhotse
>>> dloader = torch.utils.data.DataLoader(
... lhotse.dataset.iterable_dataset.IterableDatasetWrapper(
... dataset=lhotse.dataset.K2SpeechRecognitionDataset(...),
... sampler=StatelessSampler(...),
... ),
... batch_size=None,
... num_workers=4,
... )
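    Note that ``batch_size=None`` is passed to the ``DataLoader`` because the sampler already
    emits complete mini-batches as ``CutSet`` objects, so PyTorch's automatic batching is not needed.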
This sampler's design was originally proposed by Dan Povey. For details see:
https://github.com/lhotse-speech/lhotse/issues/1096
    Example 1: Get a non-bucketing :class:`.StatelessSampler`::
>>> sampler = StatelessSampler(
... cuts_paths=["data/cuts_a.jsonl", "data/cuts_b.jsonl"],
        ...     index_path="data/files.idx",
        ...     base_seed=0,
... max_duration=600.0,
... )
    Example 2: Get a bucketing :class:`.StatelessSampler`::
>>> sampler = StatelessSampler(
... cuts_paths=["data/cuts_a.jsonl", "data/cuts_b.jsonl"],
        ...     index_path="data/files.idx",
        ...     base_seed=0,
... max_duration=600.0,
... num_buckets=50,
... quadratic_duration=30.0,
... )
    Example 3: Get a bucketing :class:`.StatelessSampler` with scaled weights for each cutset::
>>> sampler = StatelessSampler(
... cuts_paths=[
... ("data/cuts_a.jsonl", 2.0),
... ("data/cuts_b.jsonl", 1.0),
... ],
        ...     index_path="data/files.idx",
        ...     base_seed=0,
... max_duration=600.0,
... num_buckets=50,
... quadratic_duration=30.0,
... )
    .. note:: This sampler works only with uncompressed ``.jsonl`` manifests, as it creates extra
        index files with line byte offsets to quickly find and sample JSON lines.
        This means it will not work with the WebDataset or Lhotse Shar data formats.
    :param cuts_paths: A path, a list of paths, or a list of ``(path, scale)`` tuples pointing to cutset files.
    :param index_path: Path to a file that contains the index of all cutsets and their line counts
        (it will be auto-created the first time this object is initialized).
    :param base_seed: Int, the user-provided part of the seed used to initialize the RNG
        for sampling (each node and worker will still produce different results).
        When resuming training, it should be a function of the number of training steps
        so that the model doesn't see identical mini-batches again (see Example 4 below).
:param max_duration: Maximum total number of audio seconds in a mini-batch (dynamic batch size).
:param max_cuts: Maximum number of examples in a mini-batch (static batch size).
:param num_buckets: If set, enables bucketing (each mini-batch has examples of a similar duration).
    :param duration_bins: A list of floats (seconds); when provided, we skip the initial
        estimation of bucket duration bins (useful to speed up launching experiments).
    :param quadratic_duration: If set, adds a penalty term for longer-duration cuts.
        Works well with models that have quadratic time complexity, as it keeps GPU utilization
        similar across buckets. Suggested values are between 30 and 45.
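
    Example 4: Derive ``base_seed`` from the training step when resuming training, so that the
    sampler does not replay identical mini-batches (a sketch; ``global_step`` is a hypothetical
    counter maintained by your training loop)::

        >>> sampler = StatelessSampler(
        ...     cuts_paths=["data/cuts_a.jsonl", "data/cuts_b.jsonl"],
        ...     index_path="data/files.idx",
        ...     base_seed=global_step,
        ...     max_duration=600.0,
        ... )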
"""
    def __init__(
self,
cuts_paths: Union[Pathlike, Iterable[Pathlike], Iterable[PathlikeAndScale]],
index_path: Pathlike,
base_seed: int,
max_duration: Optional[Seconds] = None,
max_cuts: Optional[int] = None,
num_buckets: Optional[int] = None,
        duration_bins: Optional[List[Seconds]] = None,
quadratic_duration: Optional[Seconds] = None,
) -> None:
super().__init__(data_source=None)
# Parse file paths and scales
self.paths = []
self.scales = []
if isinstance(cuts_paths, (Path, str)):
# Single path input
self.paths.append(Path(cuts_paths))
self.scales.append(1.0)
else:
# Multiple path input
cuts_paths = list(cuts_paths)
if isinstance(cuts_paths[0], (Path, str)):
# Without scales
for p in cuts_paths:
assert isinstance(
p, (Path, str)
), "Mixing paths with and without scales is not allowed."
self.paths.append(Path(p))
self.scales.append(1.0)
            else:
                # With scales
                for tpl in cuts_paths:
assert len(tpl) == 2, (
f"Expected (path, scale) but got: {tpl} "
f"[note: mixing paths with and without scales is not allowed]"
)
p, scale = tpl
assert isinstance(
p, (Path, str)
), f"Path must be a string or Path, got: {p}"
assert isinstance(
scale, (int, float)
), f"Scale must be an int or float, got: {scale}"
self.paths.append(Path(p))
self.scales.append(scale)
self.index_path = index_path
self.max_duration = max_duration
self.max_cuts = max_cuts
self.num_buckets = num_buckets
self.duration_bins = duration_bins
self.quadratic_duration = quadratic_duration
self.base_seed = base_seed
assert any(
v is not None for v in (self.max_duration, self.max_cuts)
), "At least one of max_duration or max_cuts has to be set."
self.diagnostics = SamplingDiagnostics()
self.index = ManifestIndex(self.paths, self.index_path)
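        # Per-manifest sampling weights: manifests with more cuts (or a larger user-provided
        # scale) are selected proportionally more often during iteration.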
self.scaled_line_counts = [
lc * scale
for lc, scale in zip(self.index.line_counts.values(), self.scales)
]
self._transforms = []
# DDP related info
self.rank = get_rank()
self.world_size = get_world_size()
    def map(self, fn: Callable[[CutSet], CutSet]) -> "StatelessSampler":
"""Apply ``fn`` to each mini-batch of ``CutSet`` before yielding it."""
self._transforms.append(fn)
return self
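    # A minimal usage sketch for ``map`` (any ``CutSet -> CutSet`` callable works;
    # ``CutSet.sort_by_duration`` is an existing Lhotse method used here for illustration):
    #   sampler = StatelessSampler(...).map(lambda cuts: cuts.sort_by_duration())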
    def state_dict(self) -> Dict:
"""Stub state_dict method that returns nothing - this sampler is stateless."""
return {}
    def load_state_dict(self, state_dict: Dict) -> None:
"""Stub load_state_dict method that does nothing - this sampler is stateless."""
return
def __iter__(self) -> Generator[CutSet, None, None]:
from lhotse.dataset import DynamicBucketingSampler, DynamicCutSampler
worker_info = torch.utils.data.get_worker_info()
worker_id = 0 if worker_info is None else worker_info.id
my_id = worker_id + 1000 * self.rank
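        # Each (DDP rank, dataloader worker) pair gets a distinct id (assuming fewer than 1000
        # dataloader workers per rank), so every sampler replica uses a different RNG stream.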
seed = self.base_seed + my_id
rng = random.Random(seed)
logging.info(
f"[{type(self).__name__}] Initialized sampler RNG with seed {seed} (== base_seed={self.base_seed} + my_id={my_id}) "
f"[ddp_rank={self.rank} worker_id={worker_id}]"
)
def _inner():
"""
Infinite generator of cuts.
            Each cut is sampled in two steps:
            - first, we select a cutset file, weighted by its (scaled) line count (i.e. the number of cuts);
            - then, we select a line from that file uniformly at random.
"""
n = 0
while True:
# Choose a file
path = rng.choices(self.paths, self.scaled_line_counts)[0]
# Choose a line
line_offsets = self.index.line_offsets[path]
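                # The offsets list also contains the end-of-file position, so the last entry
                # is excluded when drawing a line index.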
begin_idx = rng.randrange(len(line_offsets) - 1)
begin, end = line_offsets[begin_idx], line_offsets[begin_idx + 1]
# Read JSON line and initialize a Cut object
with path.open() as f:
f.seek(begin)
line = f.read(end - begin)
data = decode_json_line(line)
cut = deserialize_cut(data)
                # Make the cut ID unique, since the same item may be drawn more than once
                # and land in the same mini-batch (CutSet prohibits duplicate IDs).
cut.id = f"{cut.id}_it{n}"
yield cut
n += 1
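        # Bucketing is enabled when either num_buckets or duration_bins is provided.
        # We pass world_size=1 and rank=0 because each dataloader worker already runs its own
        # independently-seeded sampler replica, so no further rank-based sharding is needed here.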
if self.num_buckets is not None or self.duration_bins is not None:
inner_sampler = DynamicBucketingSampler(
_inner(),
max_duration=self.max_duration,
max_cuts=self.max_cuts,
num_buckets=self.num_buckets,
duration_bins=self.duration_bins,
shuffle=False,
drop_last=False,
quadratic_duration=self.quadratic_duration,
world_size=1,
rank=0,
)
else:
inner_sampler = DynamicCutSampler(
_inner(),
max_duration=self.max_duration,
max_cuts=self.max_cuts,
shuffle=False,
drop_last=False,
world_size=1,
rank=0,
)
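        # Note: compose_left() with no arguments returns the identity function,
        # so this is a no-op unless transforms were registered via .map().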
inner_sampler.map(compose_left(*self._transforms))
self.diagnostics = inner_sampler.diagnostics
yield from inner_sampler
    def get_report(self) -> str:
"""Returns a string describing the statistics of the sampling process so far."""
return self.diagnostics.get_report()
class ManifestIndex:
"""
An index of line count and line offset for each cutset manifest.
    When created for the first time, it writes a ``.jsonl.idx`` sidecar file next to each ``.jsonl``
    manifest, containing the byte offset of each line.
    It also writes a file at ``index_path`` that lists the line count and path of each manifest.
    When this object is instantiated again (e.g. when resuming training), it simply loads the
    existing index files from disk.
    Objects of this class expose two members, ``line_counts: Dict[Path, int]`` and
    ``line_offsets: Dict[Path, Tuple[int, ...]]``, to simplify working with manifests
    (see the example below).
:param manifest_paths: A list of paths to cut sets.
:param index_path: A path where we should write the line count index (if it doesn't exist).
:param force: When true, we'll ignore existing files and reindex the cutsets.
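
    Example (a sketch; the paths are hypothetical and the index files are created on first use)::

        >>> index = ManifestIndex(["data/cuts_a.jsonl", "data/cuts_b.jsonl"], "data/files.idx")
        >>> num_lines = index.line_counts[Path("data/cuts_a.jsonl")]
        >>> first_offsets = index.line_offsets[Path("data/cuts_a.jsonl")][:3]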
"""
def __init__(
self,
manifest_paths: Sequence[Pathlike],
index_path: Pathlike,
force: bool = False,
    ) -> None:
        # Accept both str and Path inputs; index_path.is_file()/.open() are used below.
        index_path = Path(index_path)
        self.line_counts: Dict[Path, int] = {}
        self.line_offsets: Dict[Path, Tuple[int, ...]] = {}
for p in map(Path, manifest_paths):
assert (
p.suffix == ".jsonl"
), f"We only support uncompressed .jsonl files in this sampler, but received: {p}"
offset_path = p.with_suffix(".jsonl.idx")
if offset_path.is_file() and not force:
offsets = self._load(offset_path)
else:
offsets = self._process(p, offset_path)
            self.line_counts[p] = len(offsets) - 1  # the final offset marks EOF, not a line start
self.line_offsets[p] = offsets
# Write a single cutset index in format:
# <number-of-lines> <cutset-path>
# Example:
# 10015 data/cuts-part-0001.jsonl
# 376101 data/cuts-part-0002.jsonl
# 572 data/cuts-part-0003.jsonl
if not index_path.is_file() or force:
with index_path.open("w") as index_f:
for p, lc in self.line_counts.items():
print(f"{lc} {p}", file=index_f)
    def _load(self, file_index: Path) -> Tuple[int, ...]:
with file_index.open() as f:
offsets = tuple(map(int, f))
return offsets
    def _process(self, manifest: Path, file_index: Path) -> Tuple[int, ...]:
# Write line index for each cutset in format <begin-byte> per line, e.g.:
# 0
# 214
# 357
# ...
offsets = [0]
with manifest.open() as cuts_f, file_index.open("w") as index_f:
print(0, file=index_f)
line = cuts_f.readline()
while line:
offsets.append(cuts_f.tell())
print(offsets[-1], file=index_f)
line = cuts_f.readline()
return tuple(offsets)