import os
import random
import secrets
import sys
from functools import partial
from typing import Callable, Generator, Literal, Optional, Union
import torch
from torch import distributed as dist
from lhotse.utils import fix_random_seed
# Name of the environment variable in which :func:`worker_init_fn` stores the
# per-process seed it computed. :func:`resolve_seed` reads it back when asked
# for a ``"randomized"`` seed from inside a DataLoader worker subprocess.
LHOTSE_PROCESS_SEED = "LHOTSE_PROCESS_SEED"
# Set by :func:`worker_init_fn` (called either by PyTorch's DataLoader in worker
# subprocesses or eagerly in the main process for the ``num_workers=0`` iterable
# path). Acts as the signal that :func:`get_worker_partition` should return a
# non-trivial ``(shard_id, num_shards)`` partition, so that indexed lazy iterators
# can split sample indices across DP rank x DataLoader worker. Map-style mode
# never calls ``worker_init_fn``, so this stays unset and partition collapses to
# ``(0, 1)`` — the sampler's own over-sample-and-discard handles DP dedup there.
LHOTSE_USE_WORKER_PARTITION = "LHOTSE_USE_WORKER_PARTITION"
def make_worker_init_fn(
    rank: Optional[int] = None,
    world_size: Optional[int] = None,
    set_different_node_and_worker_seeds: bool = True,
    seed: Optional[int] = 42,
) -> Optional[Callable[[int], None]]:
    """
    Build a ``worker_init_fn`` callable to hand to PyTorch's DataLoader.

    The returned object is :func:`worker_init_fn` with every argument except
    ``worker_id`` already bound, so DataLoader can call it with the worker id only.
    Using it addresses two problems:

    * each worker and node gets its own random seed, which avoids duplicated
      outcomes in randomized data augmentation;
    * rank / world size are exported as environment variables inside the worker,
      so that WebDataset can tell it runs under multi-GPU (DDP) training and
      de-duplicates the data across nodes correctly.

    :param rank: rank of the current process, or ``None`` outside of DDP.
    :param world_size: total number of processes, or ``None`` outside of DDP.
    :param set_different_node_and_worker_seeds: whether the workers should be re-seeded.
    :param seed: base seed from which the per-worker seeds are derived.
    :return: a callable accepting a single ``worker_id`` argument.
    """
    bound_options = {
        "rank": rank,
        "world_size": world_size,
        "set_different_node_and_worker_seeds": set_different_node_and_worker_seeds,
        "seed": seed,
    }
    return partial(worker_init_fn, **bound_options)
def worker_init_fn(
    worker_id: int,
    rank: Optional[int] = None,
    world_size: Optional[int] = None,
    set_different_node_and_worker_seeds: bool = True,
    seed: Optional[int] = 42,
) -> None:
    """
    Per-worker initialization hook; normally obtained pre-configured from
    :func:`~lhotse.dataset.dataloading.make_worker_init_fn` (see its documentation).

    When ``set_different_node_and_worker_seeds`` is true, the process seed is
    ``seed + 100 * worker_id`` (plus ``100000 * rank`` when ``rank`` is given); it is
    applied with ``fix_random_seed`` and exported in the ``LHOTSE_PROCESS_SEED``
    environment variable.

    When both ``rank`` and ``world_size`` are given, ``RANK``, ``WORLD_SIZE`` and the
    worker-partition flag are exported as environment variables. When both are
    ``None`` nothing else happens; passing only one of them fails an assertion.
    """
    if set_different_node_and_worker_seeds:
        rank_offset = 0 if rank is None else 100000 * rank
        process_seed = seed + 100 * worker_id + rank_offset
        fix_random_seed(process_seed)
        os.environ[LHOTSE_PROCESS_SEED] = str(process_seed)

    has_rank = rank is not None
    has_world_size = world_size is not None
    if not has_rank and not has_world_size:
        # No distributed information was supplied: there is nothing to export.
        return

    assert (
        has_rank and has_world_size
    ), f"Both args must be not None: rank={rank}, world_size={world_size}"

    # Export the rank/world_size info so that WebDataset can read it in worker
    # subprocesses. Without it WebDataset would "think" this is always single-node
    # training, because DataLoader workers did not initialize torch.distributed.
    os.environ.update(RANK=str(rank), WORLD_SIZE=str(world_size))

    # Tell get_worker_partition that worker-level partitioning is active for
    # indexed lazy iterators. Map-style mode never calls this function, so the
    # flag stays unset there and the partition remains (0, 1).
    os.environ[LHOTSE_USE_WORKER_PARTITION] = "1"
def resolve_seed(seed: Union[int, Literal["trng", "randomized"], None]) -> int:
    """
    Turn any of the seed specifications supported in Lhotse into a concrete integer.

    * An integer is returned unchanged.
    * ``None`` yields the first word of the state of Python's global ``random``
      module generator (referred to as its "global seed").
    * ``"randomized"`` checks whether we are inside a dataloading worker of
      ``torch.utils.data.DataLoader``. If so, the DataLoader is expected to have been
      given the result of :func:`~lhotse.dataset.dataloading.make_worker_init_fn` as
      its ``worker_init_fn``, and the seed assigned to this particular worker is
      returned. Otherwise (including ``num_workers=0``) the result is the same as
      for ``None``.
    * ``"trng"`` draws a seed with the ``secrets`` module, i.e. from a true RNG
      (to the extent supported by the OS).

    :raises ValueError: when ``seed`` is none of the supported values.
    """

    def python_global_seed() -> int:
        # First word of the state of Python's global ``random`` generator.
        return random.getstate()[1][0]

    # An explicit number wins over everything else.
    if isinstance(seed, int):
        return seed

    # Nothing specific was requested.
    if seed is None:
        return python_global_seed()

    # Deterministic "randomized" resolution: every dataloading worker and DDP rank
    # gets its own seed; outside of a worker we fall back to the global one.
    if seed == "randomized":
        inside_worker = torch.utils.data.get_worker_info() is not None
        if not inside_worker:
            return python_global_seed()
        assert LHOTSE_PROCESS_SEED in os.environ, (
            "Requested seed='randomized' for shuffling shards differently "
            "on each DataLoader node and worker, "
            "but lhotse.dataset.dataloading.worker_init_fn was not called."
        )
        return int(os.environ[LHOTSE_PROCESS_SEED])

    # "Complete randomness" requested via a true random number generator.
    if seed == "trng":
        # The bound is 2**31 rather than 2**32, because adding anything to a
        # seed close to 2**32 may trigger:
        #   File "_mt19937.pyx", line 180, in numpy.random._mt19937.MT19937._legacy_seeding
        #   ValueError: Seed must be between 0 and 2**32 - 1
        return secrets.randbelow(2**31)

    raise ValueError(
        f"Unexpected type or value of seed: {type(seed)=} {seed=}. "
        f"Supported values are: None, int, 'trng', and 'randomized'."
    )
def get_worker_partition() -> tuple:
    """
    Compute the global ``(shard_id, num_shards)`` partition of the caller by
    combining its DP rank with its DataLoader worker id::

        shard_id   = rank * num_workers + worker_id
        num_shards = world_size * num_workers

    The trivial partition ``(0, 1)`` is returned whenever the
    ``LHOTSE_USE_WORKER_PARTITION`` environment variable is not equal to ``"1"``,
    i.e. when :func:`worker_init_fn` has not set it. That keeps map-style mode
    (where the sampler runs in the main process and performs its own
    over-sample-and-discard DP dedup) unaffected even if RANK/WORLD_SIZE are
    already present in the environment (e.g. exported by torchrun).

    The DP topology comes from :func:`get_rank` / :func:`get_world_size`, which
    consult the environment first. The worker topology comes from
    :func:`torch.utils.data.get_worker_info`; outside of a DataLoader worker
    (e.g. ``num_workers=0``) the caller counts as the only worker
    (``worker_id=0, num_workers=1``).

    :return: a ``(shard_id, num_shards)`` pair.
    """
    partition_enabled = os.environ.get(LHOTSE_USE_WORKER_PARTITION) == "1"
    if not partition_enabled:
        return 0, 1

    rank, world_size = get_rank(), get_world_size()

    info = torch.utils.data.get_worker_info()
    # Not being in a worker subprocess is equivalent to being worker 0 out of 1.
    worker_id, num_workers = (
        (0, 1) if info is None else (info.id, max(info.num_workers, 1))
    )

    return rank * num_workers + worker_id, world_size * num_workers
class PartitionedIndexedIterator:
    """Partition-aware driver producing the indices an indexed leaf iterator should visit.

    It bundles what every indexed ``IteratorNode`` needs in order to repeat
    correctly: the ``(shard_id, num_shards)`` lookup, tracking of the position
    within the shard, and a resume that validates the topology. Only global
    indices into the leaf source are yielded; turning an index into the
    user-facing item is the caller's job.

    The iteration mode is chosen at construction time:

    * **Stride** (``shuffle=False``, default) yields
      ``shard_id, shard_id + num_shards, shard_id + 2 * num_shards, …``
      for as long as the index stays below ``total_len``.
    * **Shuffled** (``shuffle=True``, with ``seed``) delegates the ordering to
      :class:`~lhotse.indexing.LazyShuffledRange`, built for the full range and
      this shard. Useful when the source has a deterministic on-disk order but
      the consumer wants item-level shuffling.

    Typical wiring::

        class MyIndexedIterator(IteratorNode):
            def __init__(self, ...):
                ...
                self._iter_state = PartitionedIndexedIterator()

            def __iter__(self):
                for global_idx in self._iter_state.iterate(self._total_len):
                    item = self._decode_at(global_idx)
                    if item is None:
                        continue
                    yield item

            def state_dict(self) -> dict:
                return {**self._iter_state.state_dict(), "epoch": self.epoch}

            def load_state_dict(self, sd: dict) -> None:
                self._iter_state.load_state_dict(sd)
                self.epoch = sd.get("epoch", 0)

    Notes:

    * The partition is whatever :func:`get_worker_partition` reports at
      iteration time; when it is ``(0, 1)`` the whole range is covered.
    * ``state_dict`` holds the local-within-shard ``position`` together with the
      ``(shard_id, num_shards)`` topology seen by the latest iteration. Resuming
      under a different topology is refused, because the per-shard index
      sequence would diverge.
    """

    def __init__(self, shuffle: bool = False, seed: int = 0) -> None:
        # Iteration mode and the seed used by the shuffled mode.
        self._shuffle = shuffle
        self._seed = seed
        # Progress within the shard and the topology it was made under.
        self._position = 0
        self._shard_id: Optional[int] = None
        self._num_shards: Optional[int] = None
        # True between a load_state_dict() call and the next run of iterate().
        self._restored = False
        # The shuffled range is built lazily inside :meth:`iterate`, so that the
        # partition info is read in the same process that owns the DataLoader
        # worker env. Its restored state waits in ``_pending_range_state``.
        self._range = None
        self._pending_range_state = None

    @property
    def position(self) -> int:
        """Index, local to the current shard, of the next element to be produced."""
        return self._position

    def iterate(self, total_len: int) -> Generator[int, None, None]:
        """Yield the global indices of ``range(total_len)`` that belong to this shard.

        This is a generator: the partition lookup and all bookkeeping happen
        when the first index is requested, not when the method is called.
        After :meth:`load_state_dict`, iteration continues from the restored
        position; otherwise it starts from the beginning of the shard.

        Raises:
            ValueError: if resuming from a saved state under a different
                ``(shard_id, num_shards)`` topology than the one recorded
                at save time.
        """
        shard_id, num_shards = get_worker_partition()

        resuming, self._restored = self._restored, False
        if not resuming:
            self._position = 0
        elif self._num_shards is not None and (
            (self._shard_id, self._num_shards) != (shard_id, num_shards)
        ):
            raise ValueError(
                f"PartitionedIndexedIterator topology mismatch on resume: "
                f"saved (shard_id={self._shard_id}, num_shards={self._num_shards}), "
                f"current (shard_id={shard_id}, num_shards={num_shards}). "
                f"Resuming with a different DP rank / DataLoader worker count "
                f"is not supported (per-shard index sequence would diverge)."
            )
        first = self._position

        self._shard_id, self._num_shards = shard_id, num_shards

        if not self._shuffle:
            self._range = None
            # Number of elements in shard_id, shard_id + num_shards, ... < total_len.
            shard_len = (
                (total_len - shard_id + num_shards - 1) // num_shards
                if total_len > shard_id
                else 0
            )
        else:
            from lhotse.indexing import LazyShuffledRange

            self._range = LazyShuffledRange(
                total_len, seed=self._seed, shard_id=shard_id, num_shards=num_shards
            )
            saved_range_state = self._pending_range_state
            if saved_range_state is not None:
                self._range.load_state_dict(saved_range_state)
                self._pending_range_state = None
            shard_len = len(self._range)

        for local_idx in range(first, shard_len):
            # Advance before yielding: a state saved while the consumer holds
            # this element resumes with the element that follows it.
            self._position = local_idx + 1
            if self._range is None:
                yield shard_id + local_idx * num_shards
            else:
                yield self._range[local_idx]

    def state_dict(self) -> dict:
        """Return the resumable state: position, topology and, if any, the range state."""
        snapshot = dict(
            position=self._position,
            shard_id=self._shard_id,
            num_shards=self._num_shards,
        )
        if self._range is not None:
            snapshot["range"] = self._range.state_dict()
        elif self._pending_range_state is not None:
            # Restored but not iterated yet: hand back the state we were given.
            snapshot["range"] = self._pending_range_state
        return snapshot

    def load_state_dict(self, sd: dict) -> None:
        """Restore the state produced by :meth:`state_dict`; takes effect in :meth:`iterate`."""
        self._position = sd.get("position", 0)
        self._shard_id = sd.get("shard_id")
        self._num_shards = sd.get("num_shards")
        if self._shuffle:
            # Only stash the LazyShuffledRange state here. It is applied in
            # iterate() once the current (shard_id, num_shards) is known, so
            # that the topology-mismatch check there can fire first with a
            # clearer error than what LazyShuffledRange itself would raise.
            self._pending_range_state = sd.get("range")
        self._range = None
        self._restored = True
def get_world_size() -> int:
    """
    Return the number of distributed processes.

    The ``WORLD_SIZE`` environment variable takes precedence; otherwise an
    initialized ``torch.distributed`` is queried; otherwise the answer is ``1``.

    Source: https://github.com/danpovey/icefall/blob/74bf02bba6016c1eb37858a4e0e8a40f7d302bdb/icefall/dist.py#L56
    """
    from_env = os.environ.get("WORLD_SIZE")
    if from_env is not None:
        return int(from_env)
    distributed_ready = dist.is_available() and dist.is_initialized()
    return dist.get_world_size() if distributed_ready else 1
def get_rank() -> int:
    """
    Return the rank of the current process.

    The ``RANK`` environment variable takes precedence; otherwise an
    initialized ``torch.distributed`` is queried; otherwise the answer is ``0``.

    Source: https://github.com/danpovey/icefall/blob/74bf02bba6016c1eb37858a4e0e8a40f7d302bdb/icefall/dist.py#L56
    """
    from_env = os.environ.get("RANK")
    if from_env is not None:
        return int(from_env)
    distributed_ready = dist.is_available() and dist.is_initialized()
    return dist.get_rank() if distributed_ready else 0