Source code for lhotse.dataset.dataloading

import os
import random
import secrets
from functools import partial
from typing import Callable, Literal, Optional, Union

import torch
from torch import distributed as dist

from lhotse.utils import fix_random_seed

LHOTSE_PROCESS_SEED = "LHOTSE_PROCESS_SEED"


def make_worker_init_fn(
    rank: Optional[int] = None,
    world_size: Optional[int] = None,
    set_different_node_and_worker_seeds: bool = True,
    seed: Optional[int] = 42,
) -> Optional[Callable[[int], None]]:
    """
    Calling this function creates a worker_init_fn suitable to pass to PyTorch's DataLoader.

    It helps with two issues:

    * sets the random seeds differently for each worker and node, which helps with
      avoiding duplication in randomized data augmentation techniques.
    * sets environment variables that help WebDataset detect it's inside multi-GPU (DDP)
      training, so that it correctly de-duplicates the data across nodes.
    """
    return partial(
        worker_init_fn,
        rank=rank,
        world_size=world_size,
        set_different_node_and_worker_seeds=set_different_node_and_worker_seeds,
        seed=seed,
    )
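
A minimal usage sketch (not part of this module): the returned callable is passed to the
DataLoader's ``worker_init_fn`` argument; ``dataset`` below is a hypothetical placeholder.

    # Hypothetical usage sketch; `dataset` is a placeholder, not defined in this module.
    from torch.utils.data import DataLoader

    loader = DataLoader(
        dataset,
        batch_size=None,
        num_workers=4,
        worker_init_fn=make_worker_init_fn(
            rank=get_rank(), world_size=get_world_size(), seed=0
        ),
    )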
def worker_init_fn(
    worker_id: int,
    rank: Optional[int] = None,
    world_size: Optional[int] = None,
    set_different_node_and_worker_seeds: bool = True,
    seed: Optional[int] = 42,
) -> None:
    """
    Function created by :func:`~lhotse.dataset.dataloading.make_worker_init_fn`,
    refer to its documentation for details.
    """
    if set_different_node_and_worker_seeds:
        process_seed = seed + 100 * worker_id
        if rank is not None:
            process_seed += 100000 * rank
        fix_random_seed(process_seed)
        os.environ[LHOTSE_PROCESS_SEED] = str(process_seed)

    if rank is None and world_size is None:
        return
    assert (
        rank is not None and world_size is not None
    ), f"Both args must be not None: rank={rank}, world_size={world_size}"

    # This sets the rank/world_size info for WebDataset to read it in worker subprocesses.
    # If we didn't do it, WebDataset will "think" this is always single-node training,
    # because DataLoader workers did not initialize torch.distributed.
    os.environ["RANK"] = str(rank)
    os.environ["WORLD_SIZE"] = str(world_size)
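
A worked example of the seed arithmetic above (illustrative only): with the default
``seed=42``, worker 2 on rank 1 is assigned the process seed
42 + 100 * 2 + 100000 * 1 = 100242, which is also exported as ``LHOTSE_PROCESS_SEED``.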
def resolve_seed(seed: Union[int, Literal["trng", "randomized"]]) -> int:
    """
    Resolves the special values of random seed supported in Lhotse.

    If it's an integer, we'll just return it.

    If it's "trng", we'll use the ``secrets`` module to generate a random seed
    using a true RNG (to the extent supported by the OS).

    If it's "randomized", we'll check whether we're in a dataloading worker of
    ``torch.utils.data.DataLoader``. If we are, we expect that it was passed the result
    of :func:`~lhotse.dataset.dataloading.make_worker_init_fn` into its ``worker_init_fn``
    argument, in which case we'll return a special seed exclusive to that worker.
    If we are not in a dataloading worker (or ``num_workers`` was set to ``0``),
    we'll return Python's ``random`` module global seed.
    """
    if isinstance(seed, int):
        return seed

    if seed == "randomized":
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is None:
            # not in a dataloader sub-process: get python global random seed
            return random.getstate()[1][0]
        else:
            # in a dataloader sub-process: read out the seed we assigned to it
            assert LHOTSE_PROCESS_SEED in os.environ, (
                "Requested seed='randomized' for shuffling shards differently "
                "on each DataLoader node and worker, "
                "but lhotse.dataset.dataloading.worker_init_fn was not called."
            )
            return int(os.environ[LHOTSE_PROCESS_SEED])

    if seed == "trng":
        return secrets.randbelow(2**32)

    raise ValueError(
        f"Unexpected type or value of seed: {type(seed)=} {seed=}. "
        f"Supported values are: int, 'trng', and 'randomized'."
    )
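
A sketch of the three accepted forms (illustrative only; the concrete values for
"trng" and "randomized" depend on the OS RNG and the worker environment at runtime):

    resolve_seed(13)            # plain integers pass through unchanged -> 13
    resolve_seed("trng")        # fresh value from the OS true RNG, e.g. 2187513046
    resolve_seed("randomized")  # per-worker seed set by worker_init_fn, or Python's
                                # global random seed when not inside a DataLoader worker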
def get_world_size() -> int:
    """Source: https://github.com/danpovey/icefall/blob/74bf02bba6016c1eb37858a4e0e8a40f7d302bdb/icefall/dist.py#L56"""
    if "WORLD_SIZE" in os.environ:
        return int(os.environ["WORLD_SIZE"])
    if dist.is_available() and dist.is_initialized():
        return dist.get_world_size()
    else:
        return 1
def get_rank() -> int:
    """Source: https://github.com/danpovey/icefall/blob/74bf02bba6016c1eb37858a4e0e8a40f7d302bdb/icefall/dist.py#L56"""
    if "RANK" in os.environ:
        return int(os.environ["RANK"])
    elif dist.is_available() and dist.is_initialized():
        return dist.get_rank()
    else:
        return 0
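
Both helpers share the same resolution order: environment variables take precedence, then
an initialized ``torch.distributed`` process group, then the single-process defaults
(rank 0, world size 1). A small sketch of that precedence, assuming no process group is set up:

    # Environment variables win over torch.distributed state (sketch only).
    os.environ["RANK"] = "1"
    os.environ["WORLD_SIZE"] = "4"
    assert get_rank() == 1
    assert get_world_size() == 4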