Source code for lhotse.lazy

import os
import random
import types
import warnings
from contextlib import contextmanager
from functools import partial
from typing import Any, Callable, Iterable, List, Literal, Optional, TypeVar, Union

from lhotse.serialization import (
    LazyMixin,
    decode_json_line,
    deserialize_item,
    extension_contains,
    open_best,
)
from lhotse.utils import (
    Pathlike,
    build_rng,
    fastcopy,
    is_module_available,
    streaming_shuffle,
)

T = TypeVar("T")


class AlgorithmMixin(LazyMixin, Iterable):
    """
    Helper base class with methods that are supposed to work identically
    on Lhotse manifest classes such as CutSet, RecordingSet, etc.
    """

    def filter(self, predicate: Callable[[T], bool]):
        """
        Return a new manifest containing only the items that satisfy ``predicate``.
        If the manifest is lazy, the filtering will also be applied lazily.

        :param predicate: a function that takes a cut as an argument and returns bool.
        :return: a filtered manifest.
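
        Example (a minimal sketch; assumes ``cuts`` is an existing ``CutSet``)::

            >>> short_cuts = cuts.filter(lambda c: c.duration < 5.0)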
        """
        cls = type(self)

        if self.is_lazy:
            return cls(LazyFilter(self, predicate=predicate))

        return cls.from_items(cut for cut in self if predicate(cut))

    def map(self, transform_fn: Callable[[T], T]):
        """
        Apply `transform_fn` to each item in this manifest and return a new manifest.
        If the manifest is opened lazily, the transform is also applied lazily.

        :param transform_fn: A callable (function) that accepts a single item instance
            and returns a new (or the same) instance of the same type.
            E.g. with a CutSet, the callable accepts a ``Cut`` and also returns a ``Cut``.
        :return: a new ``CutSet`` with transformed cuts.
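
        Example (a minimal sketch; assumes ``cuts`` is an existing ``CutSet`` whose cuts
        support ``resample``)::

            >>> resampled = cuts.map(lambda cut: cut.resample(16000))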
        """
        cls = type(self)
        ans = cls(LazyMapper(self.data, fn=transform_fn))
        if self.is_lazy:
            return ans
        return ans.to_eager()

    @classmethod
    def mux(
        cls,
        *manifests,
        stop_early: bool = False,
        weights: Optional[List[Union[int, float]]] = None,
        seed: Union[int, Literal["trng", "randomized"]] = 0,
    ):
        """
        Merges multiple manifest iterables into a new iterable by lazily multiplexing them during iteration time.
        If one of the iterables is exhausted before the others, we will keep iterating until all iterables
        are exhausted. This behavior can be changed with ``stop_early`` parameter.

        :param manifests: iterables to be multiplexed.
            They can be either lazy or eager, but the resulting manifest will always be lazy.
        :param stop_early: should we stop the iteration as soon as we exhaust one of the manifests.
        :param weights: an optional weight for each iterable, affects the probability of it being sampled.
            The weights are uniform by default.
            If the lengths are known, it makes sense to pass them here so that, in expectation,
            the items are distributed uniformly.
        :param seed: the random seed, ensures deterministic order across multiple iterations.
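
        Example (a minimal sketch; assumes two existing lazy ``CutSet`` objects
        ``cuts_a`` and ``cuts_b``)::

            >>> cuts_mixed = CutSet.mux(cuts_a, cuts_b, weights=[0.7, 0.3], seed=0)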
        """
        return cls(
            LazyIteratorMultiplexer(
                *manifests, stop_early=stop_early, weights=weights, seed=seed
            )
        )

    @classmethod
    def infinite_mux(
        cls,
        *manifests,
        weights: Optional[List[Union[int, float]]] = None,
        seed: Union[int, Literal["trng", "randomized"]] = 0,
        max_open_streams: Optional[int] = None,
    ):
        """
        Merges multiple manifest iterables into a new iterable by lazily multiplexing them during iteration time.
        Unlike ``mux()``, this method allows limiting the maximum number of open sub-iterators at any given time.

        To enable this, it performs 2-stage sampling.
        First, it samples with replacement the set of iterators ``I`` to construct a subset ``I_sub``
        of size ``max_open_streams``.
        Then, for each iteration step, it samples an iterator ``i`` from ``I_sub``,
        fetches the next item from it, and yields it.
        Once ``i`` becomes exhausted, it is replaced with a new iterator ``j`` sampled (with replacement) from ``I``.

        .. caution:: Do not use this method with inputs that are infinitely iterable as they will
            silently break the multiplexing property by only using a subset of the input iterables.

        .. caution:: This method is not recommended when multiplexing for a small number of iterations,
            as it may be much less accurate than ``mux()`` depending on the number of open streams,
            iterable sizes, and the random seed.

        :param manifests: iterables to be multiplexed.
            They can be either lazy or eager, but the resulting manifest will always be lazy.
        :param weights: an optional weight for each iterable, affects the probability of it being sampled.
            The weights are uniform by default.
            If the lengths are known, it makes sense to pass them here so that, in expectation,
            the items are distributed uniformly.
        :param seed: the random seed, ensures deterministic order across multiple iterations.
        :param max_open_streams: the number of iterables that can be open simultaneously at any given time.
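
        Example (a minimal sketch; assumes ``shard_cut_sets`` is a list of lazy ``CutSet`` shards
        and ``shard_weights`` their per-shard weights)::

            >>> cuts_inf = CutSet.infinite_mux(
            ...     *shard_cut_sets, weights=shard_weights, seed=0, max_open_streams=10
            ... )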
        """
        return cls(
            LazyInfiniteApproximateMultiplexer(
                *manifests,
                weights=weights,
                seed=seed,
                max_open_streams=max_open_streams,
            )
        )

    def shuffle(
        self,
        rng: Optional[random.Random] = None,
        buffer_size: int = 10000,
    ):
        """
        Shuffles the elements and returns a shuffled variant of self.
        If the manifest is opened lazily, performs shuffling on-the-fly with a fixed buffer size.

        :param rng: an optional instance of ``random.Random`` for precise control of randomness.
        :param buffer_size: the size of the shuffle buffer used when the manifest is opened lazily.
        :return: a shuffled copy of self, or a manifest that is shuffled lazily.
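
        Example (a minimal sketch; assumes ``cuts`` is an existing ``CutSet``)::

            >>> shuffled = cuts.shuffle(buffer_size=50000)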
        """
        cls = type(self)

        if rng is None:
            rng = random

        if self.is_lazy:
            return cls(LazyShuffler(self.data, buffer_size=buffer_size, rng=rng))
        else:
            new: List = self.data.copy()
            rng.shuffle(new)
            return cls(new)

    def repeat(self, times: Optional[int] = None, preserve_id: bool = False):
        """
        Return a new, lazily evaluated manifest that iterates over the original elements ``times``
        number of times.

        :param times: how many times to repeat (infinite by default).
        :param preserve_id: when ``True``, we won't update the element ID with the repeat number.
        :return: a repeated manifest.
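
        Example (a minimal sketch; assumes ``cuts`` is an existing ``CutSet``)::

            >>> tripled = cuts.repeat(times=3)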
        """
        cls = type(self)
        return cls(LazyRepeater(self, times=times, preserve_id=preserve_id))

    def __add__(self, other):
        cls = type(self)
        return cls(LazyIteratorChain(self.data, other.data))


class Dillable:
    """
    Mix-in that will leverage ``dill`` instead of ``pickle``
    when pickling an object.

    It is useful when the user can't avoid ``pickle`` (e.g. in multiprocessing),
    but needs to use unpicklable objects such as lambdas.

    If ``dill`` is not installed, it defers to what ``pickle`` does by default.
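
    Example (a minimal sketch; assumes ``dill`` is installed)::

        >>> from lhotse.lazy import set_dill_enabled
        >>> set_dill_enabled(True)
        >>> # From now on, subclasses of Dillable (e.g. LazyFilter holding a lambda)
        >>> # can be pickled, e.g. by DataLoader worker processes.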
    """

    _ENABLED_VALUES = {"1", "True", "true", "yes"}

    def __getstate__(self):
        if is_dill_enabled():
            import dill

            return dill.dumps(self.__dict__)
        else:
            return self.__dict__

    def __setstate__(self, state):
        if is_dill_enabled():
            import dill

            self.__dict__ = dill.loads(state)
        else:
            self.__dict__ = state


def is_dill_enabled(_ENABLED_VALUES=frozenset(("1", "True", "true", "yes"))) -> bool:
    """Returns bool indicating if dill-based pickling in Lhotse is enabled or not."""
    return (
        is_module_available("dill")
        and os.environ.get("LHOTSE_DILL_ENABLED", "0") in _ENABLED_VALUES
    )


def set_dill_enabled(value: bool) -> None:
    """Enable or disable dill-based pickling in Lhotse."""
    assert is_module_available("dill"), (
        "Cannot enable dill because dill is not installed. "
        "Please run 'pip install dill' and try again."
    )
    # We use os.environ here so that sub-processes / forks will inherit this value
    os.environ["LHOTSE_DILL_ENABLED"] = "1" if value else "0"


@contextmanager
def dill_enabled(value: bool):
    """
    Context manager that overrides the setting of Lhotse's dill-backed pickling
    and restores the previous value after exit.

    Example::

        >>> import pickle
        ... with dill_enabled(True):
        ...    pickle.dump(CutSet(...).filter(lambda c: c.duration < 5), open("cutset.pickle", "wb"))
    """
    previous = is_dill_enabled()
    set_dill_enabled(value)
    try:
        yield
    finally:
        set_dill_enabled(previous)


class LazyTxtIterator:
    """
    LazyTxtIterator is a thin wrapper over the built-in ``open`` function to iterate over
    lines in a (possibly compressed) text file.
    It can also provide the number of lines via ``__len__`` using fast newline counting.
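
    Example (a minimal sketch; assumes a text file "document.txt" exists)::

        >>> for example in LazyTxtIterator("document.txt"):
        ...     print(example)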
    """

    def __init__(self, path: Pathlike, as_text_example: bool = True) -> None:
        self.path = path
        self.as_text_example = as_text_example
        self._len = None
    def __iter__(self):
        from lhotse.cut.text import TextExample

        tot = 0
        with open_best(self.path, "r") as f:
            for line in f:
                line = line.strip()
                if self.as_text_example:
                    line = TextExample(line)
                yield line
                tot += 1
        if self._len is None:
            self._len = tot

    def __len__(self) -> int:
        if self._len is None:
            self._len = count_newlines_fast(self.path)
        return self._len


class LazyJsonlIterator:
    """
    LazyJsonlIterator provides the ability to read JSON lines as Python dicts.
    It can also provide the number of lines via ``__len__`` using fast newline counting.
    """

    def __init__(self, path: Pathlike) -> None:
        self.path = path
        self._len = None

    def __iter__(self):
        tot = 0
        with open_best(self.path, "r") as f:
            for line in f:
                data = decode_json_line(line)
                yield data
                tot += 1
        if self._len is None:
            self._len = tot

    def __len__(self) -> int:
        if self._len is None:
            self._len = count_newlines_fast(self.path)
        return self._len


class LazyManifestIterator(Dillable):
    """
    LazyManifestIterator provides the ability to read Lhotse objects
    from a JSONL file on-the-fly, without reading its full contents into memory.

    This class is designed to be a partial "drop-in" replacement for ordinary dicts
    to support lazy loading of RecordingSet, SupervisionSet and CutSet.
    Since it does not support random access reads, some methods of these classes
    might not work properly.
    """

    def __init__(self, path: Pathlike) -> None:
        assert extension_contains(".jsonl", path) or str(path) == "-"
        self.source = LazyJsonlIterator(path)

    @property
    def path(self) -> Pathlike:
        return self.source.path

    def __iter__(self):
        yield from map(deserialize_item, self.source)

    def __len__(self) -> int:
        return len(self.source)

    def __add__(self, other) -> "LazyIteratorChain":
        return LazyIteratorChain(self, other)


class LazyIteratorChain(Dillable):
    """
    A thin wrapper over multiple iterators that enables combining lazy manifests
    in Lhotse. It iterates all underlying iterables sequentially.

    It also supports shuffling the sub-iterators when it's iterated over.
    This can be used to implement sharding (where each iterator is a shard)
    with randomized shard order. Every iteration of this object will increment
    an internal counter so that the next time it's iterated, the order of shards
    is again randomized.

    .. note:: if any of the input iterables is a dict, we'll iterate only its values.
    """

    def __init__(
        self,
        *iterators: Iterable,
        shuffle_iters: bool = False,
        seed: Optional[Union[int, Literal["trng", "randomized"]]] = None,
    ) -> None:
        self.iterators = []
        self.shuffle_iters = shuffle_iters
        self.seed = seed
        self.num_iters = 0
        for it in iterators:
            # Auto-flatten LazyIteratorChain instances if any are passed
            if isinstance(it, LazyIteratorChain):
                for sub_it in it.iterators:
                    self.iterators.append(sub_it)
            else:
                self.iterators.append(it)

    def __iter__(self):
        from lhotse.dataset.dataloading import resolve_seed

        iterators = self.iterators
        if self.shuffle_iters:
            if self.seed is None:
                rng = random  # global Python RNG
            else:
                rng = random.Random(resolve_seed(self.seed) + self.num_iters)
            rng.shuffle(iterators)
            self.num_iters += 1
        for it in iterators:
            if isinstance(it, dict):
                it = it.values()
            yield from it

    def __len__(self) -> int:
        return sum(len(it) for it in self.iterators)

    def __add__(self, other) -> "LazyIteratorChain":
        return LazyIteratorChain(self, other)


class LazyIteratorMultiplexer(Dillable):
    """
    A wrapper over multiple iterators that enables combining lazy manifests in Lhotse.
    During iteration, unlike :class:`.LazyIteratorChain`, :class:`.LazyIteratorMultiplexer`
    at each step randomly selects the iterable used to yield an item.

    Since the iterables might be of different length, we provide a ``weights`` parameter
    to let the user decide which iterables should be sampled more frequently than others.

    When an iterable is exhausted, we will keep sampling from the other iterables,
    until we exhaust them all, unless ``stop_early`` is set to ``True``.
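
    Example (a minimal sketch; assumes two existing lazy manifests ``cuts_a`` and ``cuts_b``)::

        >>> mux = LazyIteratorMultiplexer(cuts_a, cuts_b, weights=[3, 1], seed=0)
        >>> for item in mux:
        ...     pass  # items from cuts_a are sampled ~3x more often than items from cuts_b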
""" def __init__( self, *iterators: Iterable, stop_early: bool = False, weights: Optional[List[Union[int, float]]] = None, seed: Union[int, Literal["trng", "randomized"]] = 0, ) -> None: self.iterators = list(iterators) self.stop_early = stop_early self.seed = seed assert ( len(self.iterators) > 1 ), "There have to be at least two iterables to multiplex." if weights is None: self.weights = [1] * len(self.iterators) else: self.weights = weights assert len(self.iterators) == len(self.weights) def __iter__(self): from lhotse.dataset.dataloading import resolve_seed rng = random.Random(resolve_seed(self.seed)) iters = [iter(it) for it in self.iterators] exhausted = [False for _ in range(len(iters))] def should_continue(): if self.stop_early: return not any(exhausted) else: return not all(exhausted) while should_continue(): active_indexes, active_weights = zip( *[ (i, w) for i, (is_exhausted, w) in enumerate(zip(exhausted, self.weights)) if not is_exhausted ] ) idx = rng.choices(active_indexes, weights=active_weights, k=1)[0] selected = iters[idx] try: item = next(selected) yield item except StopIteration: exhausted[idx] = True continue def __len__(self) -> int: return sum(len(it) for it in self.iterators) def __add__(self, other) -> "LazyIteratorChain": return LazyIteratorChain(self, other) class LazyInfiniteApproximateMultiplexer(Dillable): """ A variant of :class:`.LazyIteratorMultiplexer` that allows to control the number of iterables that are simultaneously open. It is useful for large-scale data sets where opening multiple file handles in many processes leads to exhaustion of the operating system resources. If the data sets are sharded, it is recommended to pass each shard as a separate iterator when creating objects of this class. It is OK to assign a dataset-level weight to each shard (e.g., if a dataset has a weight of 0.5, assign weight 0.5 to each of its shards). There are several differences between this class and :class:`.LazyIteratorMultiplexer`: * Objects of this class are infinite iterators. * We hold a list of ``max_open_streams`` open iterators at any given time. This list is filled by sampling input iterators with replacement. These differences are necessary to guarantee the weighted sampling property. If we did not sample with replacement or make it infinite, we would simply exhaust highly-weighted iterators towards the beginning of each "epoch" and keep sampling only lowly-weighted iterators towards the end of each "epoch". """ def __init__( self, *iterators: Iterable, stop_early: bool = False, weights: Optional[List[Union[int, float]]] = None, seed: Union[int, Literal["trng", "randomized"]] = 0, max_open_streams: Optional[int] = None, ) -> None: self.iterators = list(iterators) self.stop_early = stop_early self.seed = seed self.max_open_streams = max_open_streams if max_open_streams is None or max_open_streams > len(self.iterators): self.max_open_streams = len(self.iterators) assert len(self.iterators) > 0 self.weights = weights if weights is None: self.weights = [1] * len(self.iterators) assert len(self.iterators) == len(self.weights) assert ( self.max_open_streams is None or self.max_open_streams >= 1 ), f"{self.max_open_streams=}" def __iter__(self): """ Assumptions - we have N streams but can only open M at the time (M < N) - the streams are finite - each stream needs to be "short" to ensure the mux property - each stream may be interpreted as a shard belonging to some larger group of streams (e.g. multiple shards of a given dataset). 
""" from lhotse.dataset.dataloading import resolve_seed rng = random.Random(resolve_seed(self.seed)) def shuffled_streams(): # Create an infinite iterable of our streams. # Assume N is "small" enough that shuffling it will be quick # # we need to incorporate weights into shuffling here # and sample iterators with replacement. # consider it0=[shard00, shard01] with weight 0.95 # and it1=[shard10, shard11] with weight 0.05 # so we have 4 streams [shard{01}{01}] # if we just shuffle randomly and sample without replacement # per each "epoch" (epoch = 4 shards) then we would have # ignored the weights because we'll just exhaust it0 shards # towards the beginning of an "epoch" and then keep yielding # from it1 shards until the epoch is finished and we can sample # from it0 again... indexes = list(range(len(self.iterators))) while True: selected = rng.choices(indexes, self.weights, k=1)[0] yield self.iterators[selected], self.weights[selected] # Initialize an infinite sequence of finite streams. # It is sampled with weights and replacement from ``self.iterators``, # which are of length N. stream_source = shuffled_streams() # Sample the first M active streams to be multiplexed. # As streams get depleted, we will replace them with # new streams sampled from the stream source. active_streams = [None] * self.max_open_streams active_weights = [None] * self.max_open_streams stream_indexes = list(range(self.max_open_streams)) def sample_new_stream_at(pos: int) -> None: sampled_stream, sampled_weight = next(stream_source) active_streams[pos] = iter(sampled_stream) active_weights[pos] = sampled_weight for stream_pos in range(self.max_open_streams): sample_new_stream_at(stream_pos) # The actual multiplexing loop. while True: # Select a stream from the currently active streams. # We actually sample an index so that we know which position # to replace if a stream is exhausted. stream_pos = rng.choices( stream_indexes, weights=active_weights if sum(active_weights) > 0 else None, k=1, )[0] selected = active_streams[stream_pos] try: # Sample from the selected stream. item = next(selected) yield item except StopIteration: # The selected stream is exhausted. Replace it with another one, # and return a sample from the newly opened stream. sample_new_stream_at(stream_pos) item = next(active_streams[stream_pos]) yield item class LazyShuffler(Dillable): """ A wrapper over an iterable that enables lazy shuffling. The shuffling algorithm is reservoir-sampling based. See :func:`lhotse.utils.streaming_shuffle` for details. """ def __init__( self, iterator: Iterable, buffer_size: int = 10000, rng: Optional[random.Random] = None, ) -> None: self.iterator = iterator self.buffer_size = buffer_size self.rng = rng def __iter__(self): return iter( streaming_shuffle( iter(self.iterator), bufsize=self.buffer_size, rng=self.rng, ) ) def __len__(self) -> int: return len(self.iterator) def __add__(self, other) -> "LazyIteratorChain": return LazyIteratorChain(self, other) class LazyFilter(Dillable): """ A wrapper over an iterable that enables lazy filtering. It works like Python's `filter` built-in by applying the filter predicate to each element and yielding it further if predicate returned ``True``. """ def __init__(self, iterator: Iterable, predicate: Callable[[Any], bool]) -> None: self.iterator = iterator self.predicate = predicate assert callable( self.predicate ), f"LazyFilter: 'predicate' arg must be callable (got {predicate})." 
        if (
            isinstance(self.predicate, types.LambdaType)
            and self.predicate.__name__ == "<lambda>"
            and not is_module_available("dill")
        ):
            warnings.warn(
                "A lambda was passed to LazyFilter: it may prevent you from forking this process. "
                "If you experience issues with num_workers > 0 in torch.utils.data.DataLoader, "
                "try passing a regular function instead."
            )

    def __iter__(self):
        return filter(self.predicate, self.iterator)

    def __add__(self, other) -> "LazyIteratorChain":
        return LazyIteratorChain(self, other)

    def __len__(self) -> int:
        raise TypeError(
            "LazyFilter does not support __len__ because it would require "
            "iterating over the whole iterator, which is not possible in a lazy fashion. "
            "If you really need to know the length, convert to eager mode first using "
            "`.to_eager()`. Note that this will require loading the whole iterator into memory."
        )


class LazyMapper(Dillable):
    """
    A wrapper over an iterable that enables lazy function evaluation on each item.
    It works like Python's `map` built-in by applying a callable ``fn``
    to each element ``x`` and yielding the result of ``fn(x)`` further.

    New in Lhotse v1.22.0: ``apply_fn`` can be provided to decide whether ``fn`` should
    be applied to a given example or not (in which case it will return it as-is,
    i.e., it does not filter).
    """

    def __init__(
        self,
        iterator: Iterable,
        fn: Callable[[Any], Any],
        apply_fn: Optional[Callable[[Any], bool]] = None,
    ) -> None:
        self.iterator = iterator
        self.fn = fn
        self.apply_fn = apply_fn
        assert callable(self.fn), f"LazyMapper: 'fn' arg must be callable (got {fn})."
        if self.apply_fn is not None:
            assert callable(
                self.apply_fn
            ), f"LazyMapper: 'apply_fn' arg must be callable (got {apply_fn})."
        if (
            (
                (isinstance(self.fn, types.LambdaType) and self.fn.__name__ == "<lambda>")
                or (
                    isinstance(self.apply_fn, types.LambdaType)
                    and self.apply_fn.__name__ == "<lambda>"
                )
            )
            and not is_dill_enabled()
        ):
            warnings.warn(
                "A lambda was passed to LazyMapper: it may prevent you from forking this process. "
                "If you experience issues with num_workers > 0 in torch.utils.data.DataLoader, "
                "try passing a regular function instead."
            )

    def __iter__(self):
        if self.apply_fn is None:
            yield from map(self.fn, self.iterator)
        else:
            for item in self.iterator:
                if self.apply_fn(item):
                    ans = self.fn(item)
                else:
                    ans = item
                yield ans

    def __len__(self) -> int:
        return len(self.iterator)

    def __add__(self, other) -> "LazyIteratorChain":
        return LazyIteratorChain(self, other)


class LazyFlattener(Dillable):
    """
    A wrapper over an iterable of collections that flattens it to an iterable of items.

    Example::

        >>> list_of_cut_sets: List[CutSet] = [CutSet(...), CutSet(...)]
        >>> list_of_cuts: List[Cut] = list(LazyFlattener(list_of_cut_sets))
    """

    def __init__(self, iterator: Iterable) -> None:
        self.iterator = iterator

    def __iter__(self):
        for cuts in self.iterator:
            yield from cuts

    def __add__(self, other) -> "LazyIteratorChain":
        return LazyIteratorChain(self, other)

    def __len__(self) -> int:
        raise TypeError(
            "LazyFlattener does not support __len__ because it would require "
            "iterating over the whole iterator, which is not possible in a lazy fashion. "
            "If you really need to know the length, convert to eager mode first using "
            "`.to_eager()`. Note that this will require loading the whole iterator into memory."
        )


class LazyRepeater(Dillable):
    """
    A wrapper over an iterable that enables repeating it N times or infinitely (the default).
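
    Example (a minimal sketch; assumes ``cuts`` is an existing eager ``CutSet``)::

        >>> twice = LazyRepeater(cuts, times=2)
        >>> assert len(twice) == 2 * len(cuts)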
""" def __init__( self, iterator: Iterable, times: Optional[int] = None, preserve_id: bool = False ) -> None: self.iterator = iterator self.times = times self.preserve_id = preserve_id assert self.times is None or self.times > 0 def __iter__(self): epoch = 0 while self.times is None or epoch < self.times: if self.preserve_id: iterator = self.iterator else: iterator = LazyMapper( self.iterator, partial(attach_repeat_idx_to_id, idx=epoch) ) yield from iterator epoch += 1 def __len__(self) -> int: if self.times is None: raise AttributeError() return len(self.iterator) * self.times def __add__(self, other) -> "LazyIteratorChain": return LazyIteratorChain(self, other) class LazySlicer(Dillable): """ A wrapper over an iterable that enables selecting k-th element every n elements. """ def __init__(self, iterator: Iterable, k: int, n: int) -> None: self.iterator = iterator assert ( k < n ), f"When selecting k-th element every n elements, k must be less than n (got k={k} n={n})." self.k = k self.n = n def __iter__(self): for idx, item in enumerate(self.iterator): if idx % self.n == self.k: yield item def __add__(self, other) -> "LazyIteratorChain": return LazyIteratorChain(self, other) def __len__(self) -> int: raise TypeError( "LazySlicer does not support __len__ because it would require " "iterating over the whole iterator, which is not possible in a lazy fashion. " "If you really need to know the length, convert to eager mode first using " "`.to_eager()`. Note that this will require loading the whole iterator into memory." ) def attach_repeat_idx_to_id(item: Any, idx: int) -> Any: if not hasattr(item, "id"): return item return fastcopy(item, id=f"{item.id}_repeat{idx}") def count_newlines_fast(path: Pathlike): """ Counts newlines in a file using buffered chunk reads. The fastest possible option in Python according to: https://stackoverflow.com/a/68385697/5285891 (This is a slightly modified variant of that answer.) """ def _make_gen(reader): b = reader(2**16) while b: yield b b = reader(2**16) read_mode = "rb" if not str(path) == "-" else "r" with open_best(path, read_mode) as f: count = sum(buf.count(b"\n") for buf in _make_gen(f.read)) return count