Source code for lhotse.lazy

import os
import random
import types
import warnings
from collections import deque
from contextlib import contextmanager
from functools import partial
from json import JSONDecodeError
from typing import Any, Callable, Iterable, List, Literal, Optional, TypeVar, Union

from lhotse.serialization import (
    LazyMixin,
    decode_json_line,
    deserialize_item,
    open_best,
)
from lhotse.utils import Pathlike, fastcopy, is_module_available

T = TypeVar("T")


# ---------------------------------------------------------------------------
# Dill-backed pickling mixin
# ---------------------------------------------------------------------------


class Dillable:
    """
    Mix-in that will leverage ``dill`` instead of ``pickle``
    when pickling an object.

    It is useful when the user can't avoid ``pickle`` (e.g. in multiprocessing),
    but needs to use unpicklable objects such as lambdas.

    If ``dill`` is not installed, it defers to what ``pickle`` does by default.
    """

    _ENABLED_VALUES = {"1", "True", "true", "yes"}

    def __getstate__(self):
        if is_dill_enabled():
            import dill

            return dill.dumps(self.__dict__)
        else:
            return self.__dict__

    def __setstate__(self, state):
        if is_dill_enabled():
            import dill

            self.__dict__ = dill.loads(state)
        else:
            self.__dict__ = state


# ---------------------------------------------------------------------------
# Iterator node protocol
# ---------------------------------------------------------------------------


class IteratorNode(Dillable, Iterable):
    """
    Base protocol for nodes in Lhotse's lazy iterator graph.

    Conventions for child references:

    * ``self.source``  — single child iterator
    * ``self.sources`` — list of child iterators

    Iterator nodes are not necessarily checkpointable. Nodes that support
    checkpointing should set ``is_checkpointable = True`` and implement
    :meth:`state_dict` and :meth:`load_state_dict`.

    .. warning::
        Instances are **not thread-safe**. Mutable position/restoration flags
        are updated without synchronization. For multi-worker data loading use
        process-based parallelism (the default in PyTorch's ``DataLoader``),
        which gives each worker its own copy.
    """

    is_checkpointable = False
    is_indexed = False
    has_constant_time_access = False

    def state_dict(self) -> dict:
        raise NotImplementedError(
            f"{type(self).__name__} is not checkpointable and does not implement state_dict()."
        )

    def load_state_dict(self, sd: dict) -> None:
        raise NotImplementedError(
            f"{type(self).__name__} is not checkpointable and does not implement load_state_dict()."
        )

    def iter_children(self):
        """Yield child iterators following ``source``/``sources`` conventions."""
        if hasattr(self, "source"):
            yield getattr(self, "source")
        if hasattr(self, "sources"):
            yield from getattr(self, "sources")


def resolve_iterator_source(obj: Iterable) -> Iterable:
    """
    Return the effective iterator payload for graph nodes.

    Manifest wrappers such as ``CutSet`` expose their underlying iterator via
    ``.data``; using it avoids introducing wrapper objects into lazy iterator
    graphs.
    """
    try:
        from lhotse.cut import CutSet
    except Exception:
        return obj
    return obj.data if isinstance(obj, CutSet) else obj


def _try_collect_child_state(obj: Any) -> Optional[dict]:
    if isinstance(obj, IteratorNode):
        if type(obj).state_dict is IteratorNode.state_dict:
            if any(True for _ in obj.iter_children()):
                raise NotImplementedError(
                    f"{type(obj).__name__} does not support checkpointing. "
                    f"Remove it from the pipeline before checkpointing or implement "
                    f"state_dict/load_state_dict."
                )
            return None
        return obj.state_dict()
    if hasattr(obj, "state_dict") and callable(getattr(obj, "state_dict")):
        try:
            return obj.state_dict()
        except Exception:
            return None
    return None


def _try_restore_child_state(obj: Any, state: Optional[dict]) -> None:
    if state is None:
        return
    if isinstance(obj, IteratorNode):
        if type(obj).load_state_dict is IteratorNode.load_state_dict:
            raise NotImplementedError(
                f"{type(obj).__name__} does not support checkpoint restoration. "
                f"Remove it from the pipeline before checkpointing or implement "
                f"state_dict/load_state_dict."
            )
        obj.load_state_dict(state)
        return
    if hasattr(obj, "load_state_dict") and callable(getattr(obj, "load_state_dict")):
        obj.load_state_dict(state)


class GraphOriginDict(dict):
    """``dict`` subclass that can carry runtime attributes (e.g. ``_graph_origin``).

    Use as a thin wrapper when you need :func:`attach_graph_origin` to work on
    raw dict items (e.g. JSONL lines decoded but not deserialized into Cut/etc.).
    Plain dicts can't have attributes set on them, so :func:`attach_graph_origin`
    silently no-ops on them.
    """

    __slots__ = ("_graph_origin",)


def _attach_runtime_metadata(item: Any, name: str, value: Any) -> Any:
    """
    Attach iterator runtime metadata without routing through Cut.custom.

    Cut-like objects use ``CustomFieldMixin.__setattr__`` to redirect unknown
    attributes into the serialized ``custom`` field. Graph restore metadata such
    as ``_graph_origin`` must stay process-local and never appear in manifests,
    so we bypass ``__setattr__`` when possible.
    """
    try:
        object.__setattr__(item, name, value)
    except Exception:
        try:
            setattr(item, name, value)
        except Exception:
            pass
    return item


def normalize_graph_token(token: Any) -> Any:
    """Convert JSON-serialized graph tokens back to tuples recursively."""
    if isinstance(token, list):
        return tuple(normalize_graph_token(part) for part in token)
    if isinstance(token, tuple):
        return tuple(normalize_graph_token(part) for part in token)
    return token


def attach_graph_origin(item: Any, token: Any) -> Any:
    return _attach_runtime_metadata(item, "_graph_origin", token)


def get_graph_origin(item: Any) -> Any:
    return getattr(item, "_graph_origin", None)


def maybe_attach_graph_origin(item: Any, token: Any) -> Any:
    if token is None:
        return item
    return attach_graph_origin(item, token)


def require_graph_origin(item: Any, owner: str, what: str = "items") -> Any:
    token = get_graph_origin(item)
    if token is None:
        raise RuntimeError(
            f"{owner} requires '_graph_origin' on {what} from graph-restorable sources."
        )
    return token


def supports_graph_restore(source: Any, *, require_length: bool = False) -> bool:
    if not getattr(source, "has_constant_time_access", False):
        return False
    if not hasattr(source, "__getitem__"):
        return False
    return not require_length or hasattr(source, "__len__")


def resolve_iteration_seed(
    seed: Optional[Union[int, Literal["trng", "randomized"]]]
) -> int:
    from lhotse.dataset.dataloading import resolve_seed

    if seed is None:
        return random.getrandbits(31)
    return resolve_seed(seed)


class AlgorithmMixin(LazyMixin, Iterable):
    """
    Helper base class with methods that are supposed to work identically
    on Lhotse manifest classes such as CutSet, RecordingSet, etc.
    """

    def filter(self, predicate: Callable[[T], bool]):
        """
        Return a new manifest containing only the items that satisfy ``predicate``.
        If the manifest is lazy, the filtering will also be applied lazily.

        :param predicate: a function that takes a cut as an argument and returns bool.
        :return: a filtered manifest.
        """
        cls = type(self)

        if self.is_lazy:
            return cls(LazyFilter(resolve_iterator_source(self), predicate=predicate))

        return cls.from_items(cut for cut in self if predicate(cut))

    def map(self, transform_fn: Callable[[T], T]):
        """
        Apply `transform_fn` to each item in this manifest and return a new manifest.
        If the manifest is opened lazy, the transform is also applied lazily.

        :param transform_fn: A callable (function) that accepts a single item instance
            and returns a new (or the same) instance of the same type.
            E.g. with CutSet, callable accepts ``Cut`` and returns also ``Cut``.
        :return: a new ``CutSet`` with transformed cuts.
        """
        cls = type(self)
        ans = cls(LazyMapper(resolve_iterator_source(self), fn=transform_fn))
        if self.is_lazy:
            return ans
        return ans.to_eager()

    @classmethod
    def mux(
        cls,
        *manifests,
        stop_early: bool = False,
        weights: Optional[List[Union[int, float]]] = None,
        seed: Union[int, Literal["trng", "randomized"]] = 0,
    ):
        """
        Merges multiple manifest iterables into a new iterable by lazily multiplexing them during iteration time.
        If one of the iterables is exhausted before the others, we will keep iterating until all iterables
        are exhausted. This behavior can be changed with ``stop_early`` parameter.

        :param manifests: iterables to be multiplexed.
            They can be either lazy or eager, but the resulting manifest will always be lazy.
        :param stop_early: should we stop the iteration as soon as we exhaust one of the manifests.
        :param weights: an optional weight for each iterable, affects the probability of it being sampled.
            The weights are uniform by default.
            If lengths are known, it makes sense to pass them here for uniform distribution of
            items in the expectation.
        :param seed: the random seed, ensures deterministic order across multiple iterations.
        """
        manifests = [resolve_iterator_source(m) for m in manifests]
        return cls(
            LazyIteratorMultiplexer(
                *manifests, stop_early=stop_early, weights=weights, seed=seed
            )
        )

    @classmethod
    def infinite_mux(
        cls,
        *manifests,
        weights: Optional[List[Union[int, float]]] = None,
        seed: Union[int, Literal["trng", "randomized"]] = 0,
        max_open_streams: Optional[int] = None,
    ):
        """
        Merges multiple manifest iterables into a new iterable by lazily multiplexing them during iteration time.
        Unlike ``mux()``, this method allows to limit the number of max open sub-iterators at any given time.

        To enable this, it performs 2-stage sampling.
        First, it samples with replacement the set of iterators ``I`` to construct a subset ``I_sub``
        of size ``max_open_streams``.
        Then, for each iteration step, it samples an iterator ``i`` from ``I_sub``,
        fetches the next item from it, and yields it.
        Once ``i`` becomes exhausted, it is replaced with a new iterator ``j`` sampled from ``I_sub``.

        .. caution:: Do not use this method with inputs that are infinitely iterable as they will
            silently break the multiplexing property by only using a subset of the input iterables.

        .. caution:: This method is not recommended for multiplexing for a small amount of iterations,
            as it may be much less accurate than ``mux()`` depending on the number of open streams,
            iterable sizes, and the random seed.

        :param manifests: iterables to be multiplexed.
            They can be either lazy or eager, but the resulting manifest will always be lazy.
        :param weights: an optional weight for each iterable, affects the probability of it being sampled.
            The weights are uniform by default.
            If lengths are known, it makes sense to pass them here for uniform distribution of
            items in the expectation.
        :param seed: the random seed, ensures deterministic order across multiple iterations.
        :param max_open_streams: the number of iterables that can be open simultaneously at any given time.
        """
        manifests = [resolve_iterator_source(m) for m in manifests]
        return cls(
            LazyInfiniteApproximateMultiplexer(
                *manifests,
                weights=weights,
                seed=seed,
                max_open_streams=max_open_streams,
            )
        )

    def shuffle(
        self,
        rng: Optional[random.Random] = None,
        buffer_size: int = 10000,
    ):
        """
        Shuffles the elements and returns a shuffled variant of self.
        If the manifest is opened lazily, performs shuffling on-the-fly with a fixed buffer size.

        :param rng: an optional instance of ``random.Random`` for precise control of randomness.
        :return: a shuffled copy of self, or a manifest that is shuffled lazily.
        """
        cls = type(self)

        if rng is None:
            rng = random

        if self.is_lazy:
            return cls(
                LazyShuffler(
                    resolve_iterator_source(self), buffer_size=buffer_size, rng=rng
                )
            )
        else:
            new: List = self.data.copy()
            rng.shuffle(new)
            return cls(new)

    def repeat(self, times: Optional[int] = None, preserve_id: bool = False):
        """
        Return a new, lazily evaluated manifest that iterates over the original elements ``times``
        number of times.

        :param times: how many times to repeat (infinite by default).
        :param preserve_id: when ``True``, we won't update the element ID with repeat number.
        :return: a repeated manifest.
        """
        cls = type(self)
        return cls(
            LazyRepeater(
                resolve_iterator_source(self), times=times, preserve_id=preserve_id
            )
        )

    def __add__(self, other):
        cls = type(self)
        return cls(
            LazyIteratorChain(
                resolve_iterator_source(self), resolve_iterator_source(other)
            )
        )


def is_dill_enabled(_ENABLED_VALUES=frozenset(("1", "True", "true", "yes"))) -> bool:
    """Returns bool indicating if dill-based pickling in Lhotse is enabled or not."""
    return (
        is_module_available("dill")
        and os.environ.get("LHOTSE_DILL_ENABLED", "0") in _ENABLED_VALUES
    )


def set_dill_enabled(value: bool) -> None:
    """Enable or disable dill-based pickling in Lhotse."""
    assert is_module_available("dill"), (
        "Cannot enable dill because dill is not installed. "
        "Please run 'pip install dill' and try again."
    )
    # We use os.environ here so that sub-processes / forks will inherit this value
    os.environ["LHOTSE_DILL_ENABLED"] = "1" if value else "0"


@contextmanager
def dill_enabled(value: bool):
    """
    Context manager that overrides the setting of Lhotse's dill-backed pickling
    and restores the previous value after exit.

    Example::

        >>> import pickle
        ... with dill_enabled(True):
        ...    pickle.dump(CutSet(...).filter(lambda c: c.duration < 5), open("cutset.pickle", "wb"))
    """
    previous = is_dill_enabled()
    set_dill_enabled(value)
    yield
    set_dill_enabled(previous)


[docs]class LazyTxtIterator(IteratorNode): """ LazyTxtIterator is a thin wrapper over builtin ``open`` function to iterate over lines in a (possibly compressed) text file. It can also provide the number of lines via __len__ via fast newlines counting. """
[docs] def __init__(self, path: Pathlike, as_text_example: bool = True) -> None: self.path = path self.as_text_example = as_text_example self._len = None
def __iter__(self): from lhotse.cut.text import TextExample tot = 0 with open_best(self.path, "r") as f: for line in f: line = line.strip() if self.as_text_example: line = TextExample(line) yield line tot += 1 if self._len is None: self._len = tot def __len__(self) -> int: if self._len is None: self._len = count_newlines_fast(self.path) return self._len
class LazyJsonlIterator(IteratorNode): """ LazyJsonlIterator provides the ability to read JSON lines as Python dicts. It can also provide the number of lines via __len__ via fast newlines counting. """ def __init__(self, path: Pathlike) -> None: self.path = path self._len = None self._position = 0 self._restored = False def __iter__(self): start = self._position if self._restored else 0 self._restored = False self._position = start tot = 0 with open_best(self.path, "r") as f: for line in f: tot += 1 if tot <= start: continue data = decode_json_line(line) self._position = tot yield data if self._len is None: self._len = tot def __len__(self) -> int: if self._len is None: self._len = count_newlines_fast(self.path) return self._len def state_dict(self) -> dict: """Return ``{"position": int}``.""" return {"position": self._position} def load_state_dict(self, sd: dict) -> None: """Restore position. Actual seeking happens in ``__iter__``.""" self._position = sd["position"] self._restored = True class LazyManifestIterator(IteratorNode): """ LazyManifestIterator provides the ability to read Lhotse objects from a JSONL file on-the-fly, without reading its full contents into memory. This class is designed to be a partial "drop-in" replacement for ordinary dicts to support lazy loading of RecordingSet, SupervisionSet and CutSet. Since it does not support random access reads, some methods of these classes might not work properly. Supports checkpointing via :meth:`state_dict` / :meth:`load_state_dict` (delegates to the inner :class:`LazyJsonlIterator`). """ is_checkpointable = True def __init__(self, path: Pathlike) -> None: self.source = LazyJsonlIterator(path) @property def path(self) -> Pathlike: return self.source.path def __iter__(self): yield from map(deserialize_item, self.source) def __len__(self) -> int: return len(self.source) def __add__(self, other) -> "LazyIteratorChain": return LazyIteratorChain(self, other) def state_dict(self) -> dict: return {"source": self.source.state_dict()} def load_state_dict(self, sd: dict) -> None: self.source.load_state_dict(sd["source"]) class LazyIndexedManifestIterator(IteratorNode): """ Lazy manifest iterator backed by an :class:`~lhotse.indexing.IndexedJsonlReader`. Supports O(1) random access via ``__getitem__`` and optional Feistel-shuffled iteration via :class:`~lhotse.indexing.LazyShuffledRange`. Unlike :class:`LazyManifestIterator`, this class requires an uncompressed JSONL file (the binary ``.idx`` index is created automatically if missing). Supports checkpointing via :meth:`state_dict` / :meth:`load_state_dict`. :param decode: callable invoked on each raw JSONL dict to produce the yielded item. Defaults to :func:`~lhotse.serialization.deserialize_item` which materializes a Lhotse Cut / Recording / etc. Pass :class:`GraphOriginDict` to keep raw dicts that can still carry graph-restore metadata. """ is_checkpointable = True def __init__( self, path: Pathlike, shuffle: bool = False, seed: int = 0, index_path: Optional[Pathlike] = None, decode: Optional[Callable[[dict], Any]] = None, skip_decode_errors: bool = False, decode_error_callback: Optional[ Callable[[BaseException, int, Pathlike], None] ] = None, ) -> None: from lhotse.dataset.dataloading import PartitionedIndexedIterator from lhotse.indexing import IndexedJsonlReader self.path = path self.shuffle = shuffle self.seed = seed self.index_path = index_path self.skip_decode_errors = skip_decode_errors self.decode_error_callback = decode_error_callback self._decode = decode if decode is not None else deserialize_item self._reader = IndexedJsonlReader(path, index_path=index_path) self._iter_state = PartitionedIndexedIterator(shuffle=shuffle, seed=seed) @property def is_indexed(self) -> bool: return True @property def has_constant_time_access(self) -> bool: return True def __getitem__(self, idx: int) -> Any: """O(1) random access: decodes the *idx*-th item.""" return self._decode_index(idx) def _decode_index(self, idx: int) -> Any: return attach_graph_origin(self._decode(self._reader[idx]), idx) def __iter__(self): for phys_idx in self._iter_state.iterate(len(self._reader)): try: yield self._decode_index(phys_idx) except (JSONDecodeError, UnicodeDecodeError) as ex: if not self.skip_decode_errors: raise if self.decode_error_callback is not None: self.decode_error_callback(ex, phys_idx, self.path) else: warnings.warn( f"Skipping malformed indexed JSONL record path={self.path!r} " f"idx={phys_idx}: {type(ex).__name__}: {ex}" ) def __len__(self) -> int: return len(self._reader) def __add__(self, other) -> "LazyIteratorChain": return LazyIteratorChain(self, other) def state_dict(self) -> dict: # ``shuffle`` and ``seed`` are surfaced in the dict for inspection # and forward-compat — they are not consumed on load (the iterator # is reconstructed with the same constructor args). return { **self._iter_state.state_dict(), "shuffle": self.shuffle, "seed": self.seed, } def load_state_dict(self, sd: dict) -> None: if self.shuffle and "range" not in sd: raise ValueError( "LazyIndexedManifestIterator with shuffle=True requires " "'range' in state_dict, but it was not found. " "The checkpoint may have been created without shuffling." ) self._iter_state.load_state_dict(sd) self._restored = True class LazyIteratorChain(IteratorNode): """ A thin wrapper over multiple iterators that enables to combine lazy manifests in Lhotse. It iterates all underlying iterables sequentially. It also supports shuffling via ``shuffle_iters``. The shuffling strategy is chosen automatically based on whether all sub-iterators are indexed: * **Non-indexed sources** — shuffles the *order of sub-iterators* (shard-level shuffling). Every iteration increments an internal counter so the shard order is re-randomized. * **Indexed sources** — uses a Feistel-cipher permutation over the combined index range for true *item-level* shuffling that crosses sub-iterator boundaries, via O(1) random access. Supports checkpointing via :meth:`state_dict` / :meth:`load_state_dict`. .. note:: if any of the input iterables is a dict, we'll iterate only its values. """ is_checkpointable = True def __init__( self, *iterators: Iterable, shuffle_iters: bool = False, seed: Optional[Union[int, Literal["trng", "randomized"]]] = None, ) -> None: self.sources = [] self.shuffle_iters = shuffle_iters self.seed = seed self.num_iters = 0 for it in iterators: it = resolve_iterator_source(it) # Auto-flatten LazyIteratorChain instances if any are passed if isinstance(it, LazyIteratorChain): for sub_it in it.sources: self.sources.append(sub_it) else: self.sources.append(it) # Iteration tracking (sequential path) self._current_iter_idx = 0 self._iter_order: Optional[list] = None self._restored = False # Iteration tracking (globally-shuffled path) self._global_position = 0 self._global_seed = None @property def is_indexed(self) -> bool: return all(getattr(s, "is_indexed", False) for s in self.sources) @property def has_constant_time_access(self) -> bool: if self.shuffle_iters and not self.is_indexed: return False # shard order changes per iteration return all(supports_graph_restore(s, require_length=True) for s in self.sources) def __getitem__(self, idx: Any) -> Any: idx = normalize_graph_token(idx) if isinstance(idx, tuple) and len(idx) == 2: src_idx, source_token = idx return attach_graph_origin(self.sources[src_idx][source_token], idx) from bisect import bisect_right cum = self._cumulative_lengths() total = cum[-1] if idx < 0: idx += total if idx < 0 or idx >= total: raise IndexError("index out of range for LazyIteratorChain") src_idx = bisect_right(cum, idx) offset = idx - cum[src_idx - 1] if src_idx > 0 else idx return attach_graph_origin(self.sources[src_idx][offset], idx) def _cumulative_lengths(self) -> list: if getattr(self, "_cum_lens", None) is None: self._cum_lens = [] total = 0 for s in self.sources: total += len(s) self._cum_lens.append(total) return self._cum_lens def __iter__(self): if self.shuffle_iters and self.is_indexed: return self._iter_globally_shuffled() return self._iter_sequential() # ------------------------------------------------------------------ # Sequential iteration (original path — with optional shard shuffle) # ------------------------------------------------------------------ def _iter_sequential(self): from lhotse.dataset.dataloading import resolve_seed if self._restored: self._restored = False # Restore exact shard order and skip to the current shard. start_idx = self._current_iter_idx order = self._iter_order if order is None or len(order) != len(self.sources): order = list(range(len(self.sources))) else: start_idx = 0 order = list(range(len(self.sources))) if self.shuffle_iters: if self.seed is None: rng = random # global Python RNG else: rng = random.Random(resolve_seed(self.seed) + self.num_iters) rng.shuffle(order) self.num_iters += 1 self._iter_order = order self._current_iter_idx = 0 self._iter_order = order cum = self._cumulative_lengths() for idx in range(start_idx, len(order)): src_idx = order[idx] it = self.sources[src_idx] self._current_iter_idx = idx if isinstance(it, dict): it = it.values() for item in it: if self.has_constant_time_access and not self.shuffle_iters: maybe_attach_graph_origin(item, (src_idx, get_graph_origin(item))) yield item # ------------------------------------------------------------------ # Globally-shuffled iteration (O(1) random access across all sources) # ------------------------------------------------------------------ def _iter_globally_shuffled(self): from lhotse.dataset.dataloading import get_worker_partition from lhotse.indexing import LazyShuffledRange total = len(self) shard_id, num_shards = get_worker_partition() if self._restored: self._restored = False start = self._global_position base_seed = self._global_seed if base_seed is None: base_seed = resolve_iteration_seed(self.seed) saved_shard_id = getattr(self, "_global_shard_id", None) saved_num_shards = getattr(self, "_global_num_shards", None) if saved_num_shards is not None and ( saved_shard_id != shard_id or saved_num_shards != num_shards ): raise ValueError( f"LazyIteratorChain global-shuffle partition mismatch on resume: " f"saved (shard_id={saved_shard_id}, num_shards={saved_num_shards}), " f"current (shard_id={shard_id}, num_shards={num_shards}). " f"Resuming with a different DP/worker topology is not supported." ) else: start = 0 self._global_position = 0 base_seed = resolve_iteration_seed(self.seed) self._global_seed = base_seed self._global_shard_id = shard_id self._global_num_shards = num_shards shuffled = LazyShuffledRange( total, seed=base_seed + self.num_iters, shard_id=shard_id, num_shards=num_shards, ) shard_len = len(shuffled) for i in range(start, shard_len): self._global_position = i + 1 yield self[shuffled[i]] self.num_iters += 1 def __len__(self) -> int: return sum(len(it) for it in self.sources) def __add__(self, other) -> "LazyIteratorChain": return LazyIteratorChain(self, other) def state_dict(self) -> dict: sd = { "current_iter_idx": self._current_iter_idx, "num_iters": self.num_iters, "iter_order": self._iter_order, "global_position": self._global_position, "global_seed": getattr(self, "_global_seed", None), "global_shard_id": getattr(self, "_global_shard_id", None), "global_num_shards": getattr(self, "_global_num_shards", None), } # Save inner states for stateful children inner_states = [] for s in self.sources: inner_states.append(_try_collect_child_state(s)) sd["inner_states"] = inner_states return sd def load_state_dict(self, sd: dict) -> None: self._current_iter_idx = sd["current_iter_idx"] self.num_iters = sd["num_iters"] self._iter_order = sd.get("iter_order") self._global_position = sd.get("global_position", 0) self._global_seed = sd.get("global_seed") self._global_shard_id = sd.get("global_shard_id") self._global_num_shards = sd.get("global_num_shards") if self.shuffle_iters and self.is_indexed: # Globally-shuffled path: position + num_iters (+ stored per-epoch # resolved seed) are enough to reconstruct permutation deterministically. self._restored = True return # Sequential path: only restore sources that will still be iterated # (at or after current_iter_idx in iter_order). order = ( self._iter_order if self._iter_order is not None else list(range(len(self.sources))) ) active = set(order[self._current_iter_idx :]) for i, (s, inner_sd) in enumerate( zip(self.sources, sd.get("inner_states", [])) ): if i not in active or inner_sd is None: continue _try_restore_child_state(s, inner_sd) self._restored = True class LazyIteratorMultiplexer(IteratorNode): """ A wrapper over multiple iterators that enables to combine lazy manifests in Lhotse. During iteration, unlike :class:`.LazyIteratorChain`, :class:`.LazyIteratorMultiplexer` at each step randomly selects the iterable used to yield an item. Since the iterables might be of different length, we provide a ``weights`` parameter to let the user decide which iterables should be sampled more frequently than others. When an iterable is exhausted, we will keep sampling from the other iterables, until we exhaust them all, unless ``stop_early`` is set to ``True``. Supports checkpointing via :meth:`state_dict` / :meth:`load_state_dict`. """ is_checkpointable = True def __init__( self, *iterators: Iterable, stop_early: bool = False, weights: Optional[List[Union[int, float]]] = None, seed: Union[int, Literal["trng", "randomized"]] = 0, ) -> None: self.sources = [resolve_iterator_source(it) for it in iterators] self.stop_early = stop_early self.seed = seed assert ( len(self.sources) > 1 ), "There have to be at least two iterables to multiplex." if weights is None: self.weights = [1] * len(self.sources) else: self.weights = weights assert len(self.sources) == len(self.weights) # Iteration state self._rng_state = None self._exhausted: Optional[list] = None self._restored = False @property def is_indexed(self) -> bool: return all(getattr(s, "is_indexed", False) for s in self.sources) @property def has_constant_time_access(self) -> bool: return all(supports_graph_restore(s) for s in self.sources) def __getitem__(self, token: Any) -> Any: token = normalize_graph_token(token) if not isinstance(token, tuple) or len(token) != 2: raise TypeError( "LazyIteratorMultiplexer expects graph restore tokens shaped like " "(source_index, source_token)." ) source_idx, source_token = token return attach_graph_origin(self.sources[source_idx][source_token], token) def __iter__(self): from lhotse.dataset.dataloading import get_worker_partition, resolve_seed _, num_shards = get_worker_partition() # Only enforce against `seed='randomized'` when all child sources are indexed. # Indexed sources use `LazyShuffledRange(shard_id, num_shards)` to slice their # index ranges per-shard, so the multiplexer MUST pick the same source at each # step across shards or the global per-source proportions drift. Streaming # (non-indexed) sources don't partition their index ranges — each shard reads # its DDP-derived dedup slice in full regardless of how the multiplexer routes, # so per-shard RNG drift just shuffles the order within the expected ratio. # ``self.is_indexed`` is True iff every child source is indexed. if num_shards > 1 and self.seed == "randomized" and self.is_indexed: raise ValueError( "LazyIteratorMultiplexer cannot use seed='randomized' under multi-shard " "(DP rank x DataLoader worker) iteration with indexed sources: each " "shard would draw a different RNG state and pick a different source at the " "same step, causing the global weighted source distribution to drift across " "ranks. Use a fixed integer seed." ) rng = random.Random(resolve_seed(self.seed)) iters = [iter(it) for it in self.sources] restored = self._restored if restored: self._restored = False exhausted = ( list(self._exhausted) if self._exhausted is not None else [False] * len(iters) ) if self._rng_state is not None: rng.setstate(self._rng_state) else: exhausted = [False for _ in range(len(iters))] self._exhausted = exhausted def should_continue(): if self.stop_early: return not any(exhausted) else: return not all(exhausted) while should_continue(): active_indexes, active_weights = zip( *[ (i, w) for i, (is_exhausted, w) in enumerate(zip(exhausted, self.weights)) if not is_exhausted ] ) idx = rng.choices(active_indexes, weights=active_weights, k=1)[0] self._rng_state = rng.getstate() selected = iters[idx] try: item = next(selected) graph_token = None if self.has_constant_time_access: graph_token = require_graph_origin( item, "LazyIteratorMultiplexer", "items" ) maybe_attach_graph_origin( item, None if graph_token is None else (idx, graph_token) ) yield item except StopIteration: exhausted[idx] = True continue def __len__(self) -> int: return sum(len(it) for it in self.sources) def __add__(self, other) -> "LazyIteratorChain": return LazyIteratorChain(self, other) def state_dict(self) -> dict: sd = { "rng_state": self._rng_state, "exhausted": list(self._exhausted) if self._exhausted is not None else None, } inner_states = [] for s in self.sources: inner_states.append(_try_collect_child_state(s)) sd["inner_states"] = inner_states return sd def load_state_dict(self, sd: dict) -> None: self._rng_state = sd["rng_state"] self._exhausted = sd["exhausted"] active = None if self._exhausted is not None: active = {i for i, exhausted in enumerate(self._exhausted) if not exhausted} for i, (s, inner_sd) in enumerate( zip(self.sources, sd.get("inner_states", [])) ): if active is not None and i not in active: continue _try_restore_child_state(s, inner_sd) self._restored = True class LazyInfiniteApproximateMultiplexer(IteratorNode): """ A variant of :class:`.LazyIteratorMultiplexer` that allows to control the number of iterables that are simultaneously open. It is useful for large-scale data sets where opening multiple file handles in many processes leads to exhaustion of the operating system resources. If the data sets are sharded, it is recommended to pass each shard as a separate iterator when creating objects of this class. It is OK to assign a dataset-level weight to each shard (e.g., if a dataset has a weight of 0.5, assign weight 0.5 to each of its shards). There are several differences between this class and :class:`.LazyIteratorMultiplexer`: * Objects of this class are infinite iterators. * We hold a list of ``max_open_streams`` open iterators at any given time. This list is filled by sampling input iterators with replacement. These differences are necessary to guarantee the weighted sampling property. If we did not sample with replacement or make it infinite, we would simply exhaust highly-weighted iterators towards the beginning of each "epoch" and keep sampling only lowly-weighted iterators towards the end of each "epoch". .. note:: This class does **not** support checkpointing (``state_dict`` / ``load_state_dict``). Its infinite, approximate nature with dynamically replaced streams makes exact restoration infeasible. For resumable multiplexed iteration, use :class:`.LazyIteratorMultiplexer` with finite sources instead. """ def __init__( self, *iterators: Iterable, stop_early: bool = False, weights: Optional[List[Union[int, float]]] = None, seed: Union[int, Literal["trng", "randomized"]] = 0, max_open_streams: Optional[int] = None, ) -> None: self.sources = [resolve_iterator_source(it) for it in iterators] self.stop_early = stop_early self.seed = seed self.max_open_streams = max_open_streams if max_open_streams is None or max_open_streams > len(self.sources): self.max_open_streams = len(self.sources) assert len(self.sources) > 0 self.weights = weights if weights is None: self.weights = [1] * len(self.sources) assert len(self.sources) == len(self.weights) assert ( self.max_open_streams is None or self.max_open_streams >= 1 ), f"{self.max_open_streams=}" def __iter__(self): """ Assumptions - we have N streams but can only open M at the time (M < N) - the streams are finite - each stream needs to be "short" to ensure the mux property - each stream may be interpreted as a shard belonging to some larger group of streams (e.g. multiple shards of a given dataset). """ from lhotse.dataset.dataloading import resolve_seed rng = random.Random(resolve_seed(self.seed)) def shuffled_streams(): indexes = list(range(len(self.sources))) while True: selected = rng.choices(indexes, self.weights, k=1)[0] yield self.sources[selected], self.weights[selected], selected # Initialize an infinite sequence of finite streams. stream_source = shuffled_streams() # Sample the first M active streams to be multiplexed. active_streams = [None] * self.max_open_streams active_weights = [None] * self.max_open_streams stream_indexes = list(range(self.max_open_streams)) def sample_new_stream_at(pos: int) -> None: sampled_stream, sampled_weight, _ = next(stream_source) active_streams[pos] = iter(sampled_stream) active_weights[pos] = sampled_weight for stream_pos in range(self.max_open_streams): sample_new_stream_at(stream_pos) # The actual multiplexing loop. while True: stream_pos = rng.choices( stream_indexes, weights=active_weights if sum(active_weights) > 0 else None, k=1, )[0] selected = active_streams[stream_pos] try: item = next(selected) yield item except StopIteration: sample_new_stream_at(stream_pos) item = next(active_streams[stream_pos]) yield item class LazyShuffler(IteratorNode): """ A wrapper over an iterable that enables lazy shuffling. The shuffling algorithm is reservoir-sampling based. See :func:`lhotse.utils.streaming_shuffle` for details. With graph-restorable indexed sources, the shuffle buffer and RNG state can be checkpointed and restored exactly. """ def __init__( self, iterator: Iterable, buffer_size: int = 10000, rng: Optional[random.Random] = None, ) -> None: self.source = resolve_iterator_source(iterator) self.buffer_size = buffer_size self.rng = rng if rng is not None else random.Random(random.getrandbits(64)) self._buffer = deque() self._startup = True self._source_exhausted = False self._restored = False @property def is_checkpointable(self) -> bool: return supports_graph_restore(self.source) @property def is_indexed(self) -> bool: return getattr(self.source, "is_indexed", False) @property def has_constant_time_access(self) -> bool: return supports_graph_restore(self.source) def __getitem__(self, token: Any) -> Any: token = normalize_graph_token(token) return attach_graph_origin(self.source[token], token) def _reset_iteration_state(self) -> None: self._buffer.clear() self._startup = True self._source_exhausted = False def _next_source_item(self, source_iter) -> Any: try: return next(source_iter) except StopIteration: self._source_exhausted = True return None def _maybe_fill_buffer(self, source_iter) -> None: if len(self._buffer) >= self.buffer_size: return item = self._next_source_item(source_iter) if item is not None: self._buffer.append(item) def _swap_with_buffer(self, sample: Any) -> Any: if not self._buffer: return sample swap_idx = self.rng.randint(0, len(self._buffer) - 1) sample, self._buffer[swap_idx] = self._buffer[swap_idx], sample return sample def _startup_phase(self, source_iter): while self._startup and not self._source_exhausted: sample = self._next_source_item(source_iter) if sample is None: break self._maybe_fill_buffer(source_iter) sample = self._swap_with_buffer(sample) if len(self._buffer) < self.buffer_size: self._buffer.append(sample) continue self._startup = False yield sample def _steady_state_phase(self, source_iter): while not self._source_exhausted: sample = self._next_source_item(source_iter) if sample is None: break self._maybe_fill_buffer(source_iter) yield self._swap_with_buffer(sample) def __iter__(self): source_iter = iter(self.source) if self._restored: self._restored = False else: self._reset_iteration_state() yield from self._startup_phase(source_iter) yield from self._steady_state_phase(source_iter) while self._buffer: yield self._buffer.popleft() def __len__(self) -> int: return len(self.source) def __add__(self, other) -> "LazyIteratorChain": return LazyIteratorChain(self, other) def state_dict(self) -> dict: if not self.is_checkpointable: raise NotImplementedError( "LazyShuffler does not support checkpointing unless its source " "supports graph restoration." ) from lhotse.checkpoint import _rng_state_to_json source_state = _try_collect_child_state(self.source) return { "buffer": [ require_graph_origin(item, "LazyShuffler", "buffered items") for item in self._buffer ], "startup": self._startup, "source_exhausted": self._source_exhausted, "rng_state": _rng_state_to_json(self.rng.getstate()), "source": source_state, } def load_state_dict(self, sd: dict) -> None: if not self.is_checkpointable: raise NotImplementedError( "LazyShuffler does not support checkpointing unless its source " "supports graph restoration." ) from lhotse.checkpoint import _rng_state_from_json _try_restore_child_state(self.source, sd.get("source")) self._buffer = deque( self.source[normalize_graph_token(token)] for token in sd.get("buffer", []) ) self._startup = sd.get("startup", True) self._source_exhausted = sd.get("source_exhausted", False) self.rng.setstate(_rng_state_from_json(sd["rng_state"])) self._restored = True class LazyFilter(IteratorNode): """ A wrapper over an iterable that enables lazy filtering. It works like Python's `filter` built-in by applying the filter predicate to each element and yielding it further if predicate returned ``True``. Supports checkpointing via :meth:`state_dict` / :meth:`load_state_dict` (delegates to the inner ``source`` iterator; the filter itself is stateless). """ is_checkpointable = True def __init__(self, iterator: Iterable, predicate: Callable[[Any], bool]) -> None: self.source = resolve_iterator_source(iterator) self.predicate = predicate assert callable( self.predicate ), f"LazyFilter: 'predicate' arg must be callable (got {predicate})." if ( isinstance(self.predicate, types.LambdaType) and self.predicate.__name__ == "<lambda>" and not is_module_available("dill") ): warnings.warn( "A lambda was passed to LazyFilter: it may prevent you from forking this process. " "If you experience issues with num_workers > 0 in torch.utils.data.DataLoader, " "try passing a regular function instead." ) @property def is_indexed(self) -> bool: return getattr(self.source, "is_indexed", False) @property def has_constant_time_access(self) -> bool: return supports_graph_restore(self.source) def __getitem__(self, token: Any) -> Any: token = normalize_graph_token(token) item = self.source[token] if not self.predicate(item): raise RuntimeError( "LazyFilter received a graph restore token that does not satisfy its " "predicate." ) return attach_graph_origin(item, token) def __iter__(self): for item in self.source: if self.predicate(item): yield item def __add__(self, other) -> "LazyIteratorChain": return LazyIteratorChain(self, other) def __len__(self) -> int: raise TypeError( "LazyFilter does not support __len__ because it would require " "iterating over the whole iterator, which is not possible in a lazy fashion. " "If you really need to know the length, convert to eager mode first using " "`.to_eager()`. Note that this will require loading the whole iterator into memory." ) def state_dict(self) -> dict: sd = {} source_state = _try_collect_child_state(self.source) if source_state is not None: sd["source"] = source_state return sd def load_state_dict(self, sd: dict) -> None: _try_restore_child_state(self.source, sd.get("source")) class LazyMapper(IteratorNode): """ A wrapper over an iterable that enables lazy function evaluation on each item. It works like Python's `map` built-in by applying a callable ``fn`` to each element ``x`` and yielding the result of ``fn(x)`` further. New in Lhotse v1.22.0: ``apply_fn`` can be provided to decide whether ``fn`` should be applied to a given example or not (in which case it will return it as-is, i.e., it does not filter). Supports checkpointing via :meth:`state_dict` / :meth:`load_state_dict` (delegates to the inner ``source`` iterator; the mapper itself is stateless). """ is_checkpointable = True def __init__( self, iterator: Iterable, fn: Callable[[Any], Any], apply_fn: Optional[Callable[[Any], bool]] = None, ) -> None: self.source = resolve_iterator_source(iterator) self.fn = fn self.apply_fn = apply_fn assert callable(self.fn), f"LazyMapper: 'fn' arg must be callable (got {fn})." if self.apply_fn is not None: assert callable( self.apply_fn ), f"LazyMapper: 'apply_fn' arg must be callable (got {fn})." if ( (isinstance(self.fn, types.LambdaType) and self.fn.__name__ == "<lambda>") or ( isinstance(self.apply_fn, types.LambdaType) and self.apply_fn.__name__ == "<lambda>" ) and not is_dill_enabled() ): warnings.warn( "A lambda was passed to LazyMapper: it may prevent you from forking this process. " "If you experience issues with num_workers > 0 in torch.utils.data.DataLoader, " "try passing a regular function instead." ) @property def is_indexed(self) -> bool: return getattr(self.source, "is_indexed", False) @property def has_constant_time_access(self) -> bool: return supports_graph_restore(self.source) def __getitem__(self, idx: Any) -> Any: graph_token = normalize_graph_token(idx) item = self.source[graph_token] if self.apply_fn is None or self.apply_fn(item): item = self.fn(item) return attach_graph_origin(item, graph_token) def __iter__(self): if self.apply_fn is None: for item in self.source: graph_idx = get_graph_origin(item) item = self.fn(item) yield maybe_attach_graph_origin(item, graph_idx) else: for item in self.source: graph_idx = get_graph_origin(item) if self.apply_fn(item): ans = self.fn(item) else: ans = item yield maybe_attach_graph_origin(ans, graph_idx) def __len__(self) -> int: return len(self.source) def __add__(self, other) -> "LazyIteratorChain": return LazyIteratorChain(self, other) def state_dict(self) -> dict: sd = {} source_state = _try_collect_child_state(self.source) if source_state is not None: sd["source"] = source_state return sd def load_state_dict(self, sd: dict) -> None: _try_restore_child_state(self.source, sd.get("source")) class LazyFlattener(IteratorNode): """ A wrapper over an iterable of collections that flattens it to an iterable of items. With graph-restorable outer sources, this node checkpoints exactly by saving the current outer-item token and the local offset within that collection. Example:: >>> list_of_cut_sets: List[CutSet] = [CutSet(...), CutSet(...)] >>> list_of_cuts: List[Cut] = list(LazyFlattener(list_of_cut_sets)) """ def __init__(self, iterator: Iterable) -> None: self.source = resolve_iterator_source(iterator) self._active_outer_token = None self._inner_position = 0 self._restored = False @property def is_checkpointable(self) -> bool: return supports_graph_restore(self.source) @property def is_indexed(self) -> bool: return getattr(self.source, "is_indexed", False) @property def has_constant_time_access(self) -> bool: return supports_graph_restore(self.source) def _resolve_collection(self, collection: Any) -> Any: return resolve_iterator_source(collection) def _inner_token(self, item: Any, inner_idx: int) -> Any: token = get_graph_origin(item) return inner_idx if token is None else token def _restore_inner_item(self, collection: Any, token: Any) -> Any: collection = self._resolve_collection(collection) token = normalize_graph_token(token) if isinstance(token, int): if hasattr(collection, "__getitem__"): return collection[token] for idx, item in enumerate(collection): if idx == token: return item raise IndexError( f"LazyFlattener inner index {token} is out of range for {type(collection).__name__}." ) if supports_graph_restore(collection): return collection[token] raise RuntimeError( "LazyFlattener received a non-integer inner graph token for a collection " "that does not support graph restoration." ) def __getitem__(self, idx: Any) -> Any: token = normalize_graph_token(idx) if not isinstance(token, tuple) or len(token) != 2: raise TypeError( "LazyFlattener expects graph restore tokens shaped like " "(outer_token, inner_token)." ) outer_token, inner_token = token collection = self.source[outer_token] item = self._restore_inner_item(collection, inner_token) return attach_graph_origin(item, token) def _iter_collection( self, collection: Any, outer_token: Any, start_inner: int = 0 ) -> Iterable[Any]: collection = self._resolve_collection(collection) for inner_idx, item in enumerate(collection): if inner_idx < start_inner: continue self._active_outer_token = outer_token self._inner_position = inner_idx + 1 token = None if outer_token is not None: token = (outer_token, self._inner_token(item, inner_idx)) yield maybe_attach_graph_origin(item, token) self._active_outer_token = None self._inner_position = 0 def __iter__(self): if self._restored and self._active_outer_token is not None: collection = self.source[self._active_outer_token] yield from self._iter_collection( collection, self._active_outer_token, start_inner=self._inner_position, ) self._restored = False for cuts in self.source: outer_token = ( require_graph_origin(cuts, "LazyFlattener", "outer collections") if self.is_checkpointable else None ) yield from self._iter_collection(cuts, outer_token) def __add__(self, other) -> "LazyIteratorChain": return LazyIteratorChain(self, other) def __len__(self) -> int: raise TypeError( "LazyFlattener does not support __len__ because it would require " "iterating over the whole iterator, which is not possible in a lazy fashion. " "If you really need to know the length, convert to eager mode first using " "`.to_eager()`. Note that this will require loading the whole iterator into memory." ) def state_dict(self) -> dict: if not self.is_checkpointable: raise NotImplementedError( "LazyFlattener does not support checkpointing unless its outer " "source supports graph restoration." ) source_state = _try_collect_child_state(self.source) state = { "active_outer_token": self._active_outer_token, "inner_position": self._inner_position, "source": source_state, } return state def load_state_dict(self, sd: dict) -> None: if not self.is_checkpointable: raise NotImplementedError( "LazyFlattener does not support checkpointing unless its outer " "source supports graph restoration." ) self._active_outer_token = normalize_graph_token(sd.get("active_outer_token")) self._inner_position = sd.get("inner_position", 0) _try_restore_child_state(self.source, sd.get("source")) self._restored = True class LazyRepeater(IteratorNode): """ A wrapper over an iterable that enables to repeat it N times or infinitely (default). Supports checkpointing via :meth:`state_dict` / :meth:`load_state_dict`. Captures the current epoch and the state of the inner ``source`` iterator. """ is_checkpointable = True def __init__( self, iterator: Iterable, times: Optional[int] = None, preserve_id: bool = False ) -> None: self.source = resolve_iterator_source(iterator) self.times = times self.preserve_id = preserve_id assert self.times is None or self.times > 0 self._current_epoch = 0 self._restored = False @property def is_indexed(self) -> bool: return getattr(self.source, "is_indexed", False) @property def has_constant_time_access(self) -> bool: return supports_graph_restore(self.source) def __getitem__(self, idx: Any) -> Any: graph_token = normalize_graph_token(idx) if isinstance(graph_token, tuple) and len(graph_token) == 2: repeat_idx, source_token = graph_token item = self.source[source_token] else: n = len(self.source) repeat_idx = graph_token // n item = self.source[graph_token % n] if self.preserve_id: return attach_graph_origin(item, graph_token) return attach_graph_origin( attach_repeat_idx_to_id(item, repeat_idx), graph_token ) def __iter__(self): restored = self._restored epoch = self._current_epoch if restored else 0 self._restored = False while self.times is None or epoch < self.times: self._current_epoch = epoch if self.preserve_id: iterator = self.source else: iterator = LazyMapper( self.source, partial(attach_repeat_idx_to_id, idx=epoch) ) at_least_once = False for item in iterator: at_least_once = True source_idx = get_graph_origin(item) maybe_attach_graph_origin( item, None if source_idx is None else (epoch, source_idx) ) yield item if not at_least_once and not restored: return # Detect empty iterables to avoid hanging the program. # After the first (possibly restored) epoch, behave normally. restored = False epoch += 1 def __len__(self) -> int: if self.times is None: raise TypeError( f"object of type '{type(self).__name__}' is an infinite iterator" ) return len(self.source) * self.times def __add__(self, other) -> "LazyIteratorChain": return LazyIteratorChain(self, other) def state_dict(self) -> dict: sd = {"current_epoch": self._current_epoch} source_state = _try_collect_child_state(self.source) if source_state is not None: sd["source"] = source_state return sd def load_state_dict(self, sd: dict) -> None: self._current_epoch = sd["current_epoch"] _try_restore_child_state(self.source, sd.get("source")) self._restored = True class LazySlicer(IteratorNode): """ A wrapper over an iterable that enables selecting k-th element every n elements. Supports checkpointing via :meth:`state_dict` / :meth:`load_state_dict` (delegates to the inner ``source`` iterator). """ is_checkpointable = True def __init__(self, iterator: Iterable, k: int, n: int) -> None: self.source = resolve_iterator_source(iterator) assert ( k < n ), f"When selecting k-th element every n elements, k must be less than n (got k={k} n={n})." self.k = k self.n = n self._source_offset = 0 self._restored = False @property def is_indexed(self) -> bool: return getattr(self.source, "is_indexed", False) @property def has_constant_time_access(self) -> bool: return supports_graph_restore(self.source) def __getitem__(self, idx: Any) -> Any: graph_token = normalize_graph_token(idx) if ( isinstance(graph_token, tuple) and len(graph_token) == 2 and graph_token[0] == "source" ): return attach_graph_origin(self.source[graph_token[1]], graph_token) if isinstance(graph_token, int): return attach_graph_origin(self.source[graph_token * self.n + self.k], idx) return attach_graph_origin(self.source[graph_token], graph_token) def __iter__(self): start = self._source_offset if self._restored else 0 self._restored = False for idx, item in enumerate(self.source, start=start): self._source_offset = idx + 1 if idx % self.n == self.k: source_idx = get_graph_origin(item) maybe_attach_graph_origin( item, None if source_idx is None else ("source", source_idx) ) yield item def __add__(self, other) -> "LazyIteratorChain": return LazyIteratorChain(self, other) def __len__(self) -> int: raise TypeError( "LazySlicer does not support __len__ because it would require " "iterating over the whole iterator, which is not possible in a lazy fashion. " "If you really need to know the length, convert to eager mode first using " "`.to_eager()`. Note that this will require loading the whole iterator into memory." ) def state_dict(self) -> dict: sd = {"source_offset": self._source_offset} source_state = _try_collect_child_state(self.source) if source_state is not None: sd["source"] = source_state return sd def load_state_dict(self, sd: dict) -> None: self._source_offset = sd.get("source_offset", 0) _try_restore_child_state(self.source, sd.get("source")) self._restored = True def attach_repeat_idx_to_id(item: Any, idx: int) -> Any: if not hasattr(item, "id"): return item return fastcopy(item, id=f"{item.id}_repeat{idx}") def count_newlines_fast(path: Pathlike): """ Counts newlines in a file using buffered chunk reads. The fastest possible option in Python according to: https://stackoverflow.com/a/68385697/5285891 (This is a slightly modified variant of that answer.) """ def _make_gen(reader): b = reader(2**16) while b: yield b b = reader(2**16) read_mode = "rb" if not str(path) == "-" else "r" with open_best(path, read_mode) as f: count = sum(buf.count(b"\n") for buf in _make_gen(f.read)) return count