# Source code for pymc.variational.streaming

#   Copyright 2024 - present The PyMC Developers
#
#   Licensed under the Apache License, Version 2.0 (the "License");
#   you may not use this file except in compliance with the License.
#   You may obtain a copy of the License at
#
#       http://www.apache.org/licenses/LICENSE-2.0
#
#   Unless required by applicable law or agreed to in writing, software
#   distributed under the License is distributed on an "AS IS" BASIS,
#   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#   See the License for the specific language governing permissions and
#   limitations under the License.
"""Out-of-core minibatching for variational inference.

``pm.Minibatch`` random-indexes an array that is fully resident in memory; its
peak memory is therefore O(N) in the dataset size. This module instead streams
minibatches from an out-of-core source into a ``pm.Data`` placeholder, so peak
memory is set by the batch, the source chunk, and the optional shuffle buffer,
independent of N.

The API follows PyTorch's ``torch.utils.data``:

* :class:`IterableDataset`: a re-iterable, out-of-core source of rows
  (e.g. :func:`parquet_source` over a directory of shards). It never loads the
  whole dataset; it yields it a chunk at a time.
* :class:`DataLoader`: turns a dataset into fixed-size (optionally shuffled)
  minibatches; it is iterable (the minibatch stream) and sized. Note ``len(loader)``
  is the row count ``N`` (what the observed distribution needs for ``total_size``),
  not the batch count ``torch.utils.data.DataLoader.__len__`` returns.

With bounded source chunks the full data never sits in RAM at once. The model
graph observes only a ``(batch_size, *sample_shape)`` ``pm.Data`` placeholder
that is overwritten with the next minibatch every step. Passing a
directory of Parquet shards far larger than RAM still gives a model whose
resident footprint is one batch (:func:`parquet_source` reads one row group at
a time).

The unbiased-gradient rescaling is the same as for ``pm.Minibatch``: the
observed log-likelihood must be scaled by ``N / batch_size`` through the existing
:func:`~pymc.variational.minibatch_rv.create_minibatch_rv`. ``N`` is exactly
``len(loader)`` (the loader is sized; ``len`` returns the row count ``N``), so the
model passes ``total_size=len(loader)``. (Folding that scaling into the inference
step, so it drops out of the model body, is the next step in PyMC's VI rework.)

Batches have exactly ``batch_size`` rows, so each pass drops the final
``N mod batch_size`` rows (torch's ``drop_last``). With ``shuffle=True`` that
remainder is re-drawn every epoch, so all rows participate across epochs; with
a source that replays a fixed order, the same rows are dropped every pass (after
a one-time on-disk pre-shuffle that fixed remainder is a random subset).

One difference from ``pm.Minibatch`` is shuffling.
``pm.Minibatch`` draws a fresh uniform index over all N rows every step, so its
minibatches are i.i.d. by construction. A streaming source is only as well
mixed as the order it yields rows in: reading time/row-ordered data through a
bounded buffer is merely a block-shuffle, and the resulting non-representative
minibatches can bias the variational posterior.
Pre-shuffle the data once on disk (or interleave shards) and/or pass
``shuffle=True``.

Examples
--------
.. code-block:: python

    import numpy as np
    import pymc as pm
    from pymc.variational.streaming import DataLoader, parquet_source

    # The data was pre-shuffled on disk once (see the module note on shuffling),
    # so the loader streams it sequentially. The full table stays on disk.
    loader = DataLoader(
        parquet_source("shuffled/"),  # an IterableDataset over the shards
        batch_size=4096,
        sample_shape=(4,),  # 3 features + 1 observed column
        total_size="auto",  # infer N from Parquet metadata; N == len(loader)
    )

    with pm.Model() as model:
        b = pm.Normal("b", 0.0, 3.0, shape=4)
        batch = pm.Data("batch", np.zeros((4096, 4)))  # placeholder for one minibatch
        logit = b[0] + b[1] * batch[:, 0] + b[2] * batch[:, 1] + b[3] * batch[:, 2]
        pm.Bernoulli("y", logit_p=logit, observed=batch[:, 3], total_size=len(loader))

    # The loader is sized (len(loader) == N, what total_size needs) and iterable:
    # each epoch yields validated (batch_size, *sample_shape) minibatches. Stream
    # each into the "batch" placeholder with model.set_data before a step.
    with model:
        for minibatch in loader:
            model.set_data("batch", minibatch)
            ...  # one variational step over this minibatch
"""

from __future__ import annotations

import glob
import numbers
import os
import warnings

from collections.abc import Callable, Iterable, Iterator

import numpy as np

__all__ = ["DataLoader", "IterableDataset", "parquet_source", "shuffle_buffer"]


def _is_positive_int(value: object) -> bool:
    """Report whether ``value`` is an integer strictly greater than zero.

    Anything registered as ``numbers.Integral`` qualifies (numpy integer types
    included); ``bool`` is rejected even though it is an ``Integral``.
    """
    if isinstance(value, bool) or not isinstance(value, numbers.Integral):
        return False
    return int(value) > 0


class IterableDataset:
    """A re-iterable, out-of-core source of rows, like ``torch.utils.data.IterableDataset``.

    Subclass and implement :meth:`__iter__` to yield ``np.ndarray`` blocks of rows
    (shape ``(rows, *sample_shape)``); :class:`DataLoader` re-batches those blocks
    into fixed-size minibatches. ``__iter__`` must return a fresh iterator each call
    so the dataset can be replayed across epochs.

    Optionally set :attr:`n_rows` (the total row count, if known cheaply, e.g. from
    file metadata) so a :class:`DataLoader` with ``total_size="auto"`` can resolve
    ``N`` without a counting pass.

    A plain zero-arg factory (``Callable[[], Iterator[np.ndarray]]``) or any
    re-iterable is also accepted directly by :class:`DataLoader`; this base class is
    only needed when you want to attach behavior or ``n_rows`` to a custom source.
    """

    # Total row count when it is known cheaply; ``None`` means "unknown".
    n_rows: int | None = None

    def __iter__(self) -> Iterator[np.ndarray]:
        """Return a fresh iterator over blocks of rows; subclasses must override."""
        raise NotImplementedError("IterableDataset subclasses must implement __iter__")
class DataLoader:
    """Turn an out-of-core dataset into fixed-size minibatches for variational inference.

    Like ``torch.utils.data.DataLoader``, it batches (and optionally shuffles) an
    :class:`IterableDataset` into a minibatch stream for variational inference. It is
    iterable and sized (``len(loader)`` is the dataset size ``N``). With bounded source
    chunks the full dataset is never resident at once.

    Parameters
    ----------
    dataset : IterableDataset | Iterable[np.ndarray] | Callable[[], Iterator[np.ndarray]]
        The source of rows. An :class:`IterableDataset`, a re-iterable (including a
        plain ``np.ndarray``), or a zero-arg factory returning a fresh iterator
        (preferred, so the stream can be restarted each epoch). It may yield single
        samples (e.g. the rows of a raw array) or blocks of any size; the loader
        re-batches them, in order, to exactly ``batch_size`` rows. Trailing rows that
        do not fill a final batch are dropped at the end of a pass, like
        ``drop_last=True`` in PyTorch (required here because the model observes a
        fixed-shape placeholder). With ``shuffle=True`` the dropped remainder differs
        per epoch; with a fixed replay order it is the same rows every pass.
    batch_size : int
        Leading dimension of every yielded minibatch.
    shuffle : bool, default False
        If ``True``, wrap the source in a bounded :func:`shuffle_buffer` of
        ``buffer_size`` rows. This only approximates i.i.d. batches for an already
        unordered stream; a bounded buffer cannot fix strongly time/row-ordered data
        (pre-shuffle on disk for that; see the module docstring).
    buffer_size : int, optional
        Shuffle-buffer size in rows when ``shuffle=True``. Defaults to
        ``50 * batch_size``. Ignored when ``shuffle=False``. A buffer at least as
        large as the dataset holds all of it in memory (a full shuffle).
    seed : int, optional
        Seed for the shuffle buffer (ignored when ``shuffle=False``).
    sample_shape : tuple of int, optional
        Trailing shape of a single observation. ``()`` for scalar observations,
        ``(k,)`` to stream ``k`` columns (e.g. features + the observed column).
        Defaults to ``dataset.shape[1:]`` for a raw ``np.ndarray`` source (its rows
        are the samples, like torch's ``TensorDataset``), else ``()``.
    dtype : str, default "float64"
        Dtype each prepared batch is cast to; match the dtype of the ``pm.Data``
        placeholder the batches are streamed into.
    total_size : int or "auto", optional
        The true dataset size ``N`` (a positive integer), or ``"auto"`` to infer it
        (from the source's ``n_rows`` if available, else a single counting pass).
        Pass it on to the observed distribution as ``total_size=len(loader)`` so the
        minibatch log-likelihood is rescaled by ``N / batch_size`` (the same
        mechanism as ``pm.Minibatch``). Unlike ``pm.Minibatch`` it cannot be inferred
        from a resident array; ``None`` warns at construction and a non-positive
        value raises (it would otherwise silently disable or invert the rescaling).
    preprocess_fn : callable, optional
        Pure transform applied to each batch before validation (e.g. normalization).
        It must preserve the row count and ``sample_shape``; to select columns, do it
        at the source instead (``parquet_source(columns=...)``).
    """

    def __init__(
        self,
        dataset: IterableDataset | Iterable[np.ndarray] | Callable[[], Iterator[np.ndarray]],
        *,
        batch_size: int,
        shuffle: bool = False,
        buffer_size: int | None = None,
        seed: int | None = None,
        sample_shape: tuple[int, ...] | None = None,
        dtype: str = "float64",
        total_size: int | str | None = None,
        preprocess_fn: Callable[[np.ndarray], np.ndarray] | None = None,
    ):
        if not _is_positive_int(batch_size):
            raise ValueError(f"batch_size must be a positive integer, got {batch_size!r}")

        if sample_shape is None:
            # A raw array is rows-of-samples; without this default a 2-D array
            # would be read as blocks of scalars and silently flattened.
            sample_shape = dataset.shape[1:] if isinstance(dataset, np.ndarray) else ()
        sample_shape = tuple(sample_shape)

        raw_factory = _make_factory(dataset)
        source_factory = raw_factory
        if shuffle:
            if buffer_size is None:
                buffer_size = 50 * int(batch_size)
            # shuffle_buffer concatenates yields along the leading axis, so single
            # samples must be promoted to one-row blocks before shuffling.
            source_factory = shuffle_buffer(
                _block_factory(raw_factory, sample_shape),
                buffer_size=buffer_size,
                batch_size=batch_size,
                seed=seed,
            )
        self._source_factory = source_factory

        if isinstance(total_size, str):
            if total_size != "auto":
                raise ValueError(f"total_size string must be 'auto', got {total_size!r}")
            # Count the unshuffled source: the shuffle wrapper drops the trailing
            # partial batch, so counting through it would undercount N.
            total_size = _auto_total_size(raw_factory, dataset, sample_shape)
        elif total_size is None:
            warnings.warn(
                "DataLoader created with total_size=None: the minibatch "
                "log-likelihood will not be rescaled and the posterior will be "
                "biased. Pass total_size=N (the true dataset size) or total_size='auto'.",
                UserWarning,
                stacklevel=2,
            )
        elif not _is_positive_int(total_size):
            # 0 is falsy (the rescaling would be silently skipped) and a negative
            # value flips the sign of the data log-likelihood; raise on both.
            raise ValueError(
                "total_size must be a positive integer (the true dataset size N) so "
                "the minibatch log-likelihood is rescaled by N / batch_size; got "
                f"{total_size!r}."
            )

        # Plain Python ints: create_minibatch_rv rejects np.int64 for total_size.
        self._batch_size = int(batch_size)
        self._sample_shape = sample_shape
        self._dtype = dtype
        self._total_size = None if total_size is None else int(total_size)
        self._preprocess_fn = preprocess_fn
        self._batches_seen = 0
        self._rows_streamed = 0
        self._warned_size = False

    @property
    def batch_size(self) -> int:
        """Number of rows in every yielded minibatch."""
        return self._batch_size

    @property
    def total_size(self) -> int | None:
        """The dataset size ``N`` (pass to the distribution's ``total_size``)."""
        return self._total_size

    @property
    def batches_seen(self) -> int:
        """Number of batches counted so far by the accounting stream."""
        return self._batches_seen

    @property
    def rows_streamed(self) -> int:
        """Total rows streamed into the model (grows past ``N`` across epochs)."""
        return self._rows_streamed

    def _rebatched(self) -> Iterator[np.ndarray]:
        """A fresh pass of exactly ``batch_size``-row batches from the source."""
        return _rebatch(self._source_factory(), self._batch_size, self._sample_shape)

    def __iter__(self) -> Iterator[np.ndarray]:
        """Yield one epoch of validated ``(batch_size, *sample_shape)`` minibatches.

        Stream each into the model's ``pm.Data`` placeholder with ``model.set_data``
        before a step. Plain iteration leaves :attr:`batches_seen` /
        :attr:`rows_streamed` untouched (it does not run the internal accounting
        path); re-iterate the loader for another epoch.
        """
        for batch in self._rebatched():
            yield self._prepare(batch)

    def __len__(self) -> int:
        """The dataset size ``N`` (row count); pass it to the distribution's ``total_size``.

        ``total_size=len(loader)`` is how the model gets the ``N / batch_size``
        rescaling. Note this returns the row count ``N``, not the batch count that
        ``torch.utils.data.DataLoader.__len__`` returns; ``total_size`` needs ``N``.
        :attr:`total_size` is the same value.

        Raises
        ------
        TypeError
            If the loader was built with ``total_size=None``.
        """
        if self._total_size is None:
            raise TypeError(
                "len(DataLoader) is the dataset size N, but this loader was built with "
                "total_size=None; construct it with total_size=N or total_size='auto'."
            )
        return self._total_size

    def _stream_batches(self) -> Iterator[np.ndarray]:
        """One epoch of prepared minibatches, with accounting (the consumer's path).

        Like :meth:`__iter__` but it updates :attr:`batches_seen` /
        :attr:`rows_streamed` and runs the one-shot ``total_size`` sanity check on
        the pass's final batch. The rebatcher is kept one batch ahead so the check
        still fires when a fit stops exactly at the pass boundary; without the
        lookahead the generator would be abandoned right before its epilogue.
        :meth:`__iter__` stays side-effect-free so plain iteration does not mutate
        counters.
        """
        seen_this_pass = 0
        it = self._rebatched()
        batch = next(it, None)
        while batch is not None:
            # Pull the next batch first so "this is the last one" is known
            # before the current batch is handed to the consumer.
            following = next(it, None)
            prepared = self._prepare(batch)
            self._batches_seen += 1
            self._rows_streamed += int(prepared.shape[0])
            seen_this_pass += int(prepared.shape[0])
            if following is None:
                self._maybe_warn_total_size(seen_this_pass)
            yield prepared
            batch = following

    def _prepare(self, batch: np.ndarray) -> np.ndarray:
        """Preprocess, validate, and return an owned copy of one batch.

        A source may legitimately yield views into a reused array; the copy
        prevents the consumer from aliasing it.
        """
        if self._preprocess_fn is not None:
            batch = self._preprocess_fn(batch)
        self._validate(batch)
        return np.array(batch, dtype=self._dtype)

    def _maybe_warn_total_size(self, seen: int) -> None:
        """Warn once if ``total_size`` is inconsistent with the rows of one full pass.

        ``seen`` is the row count of the pass that just completed (not the
        cumulative :attr:`rows_streamed`, which keeps growing across partial streams
        and earlier fits). A correct ``N`` satisfies ``seen <= N < seen + batch_size``
        after a full pass (the trailing partial batch is dropped), so that window
        never warns; outside it a 10% slack absorbs sources that are only
        approximately sized.
        """
        if self._warned_size or self._total_size is None:
            return
        # The check is one-shot: the flag is set even when nothing is emitted.
        self._warned_size = True
        if not seen or seen <= self._total_size < seen + self._batch_size:
            return
        if abs(self._total_size - seen) > 0.1 * seen:
            warnings.warn(
                f"total_size={self._total_size} disagrees with the {seen} rows streamed "
                f"in one full pass; the N/batch_size rescaling, and therefore the "
                f"posterior width, is likely wrong. Pass the true dataset size (or, if "
                f"'auto' resolved it from the source's n_rows, fix that attribute).",
                UserWarning,
                stacklevel=3,
            )

    def _validate(self, batch: np.ndarray) -> None:
        """Check that ``batch`` is a ``(batch_size, *sample_shape)`` ndarray.

        Raises
        ------
        TypeError
            If ``batch`` is not an ``np.ndarray``.
        ValueError
            If it is 0-d, has the wrong number of rows, or the wrong trailing shape.
        """
        if not isinstance(batch, np.ndarray):
            raise TypeError(f"expected np.ndarray batch, got {type(batch).__name__}")
        if batch.ndim < 1:
            raise ValueError(
                "batch needs a leading batch dimension; got a scalar array with "
                f"shape {batch.shape}."
            )
        if batch.shape[0] != self._batch_size:
            raise ValueError(
                f"batch shape[0] = {batch.shape[0]} does not match batch_size = {self._batch_size}."
            )
        if batch.shape[1:] != self._sample_shape:
            raise ValueError(
                f"batch sample-shape {batch.shape[1:]} does not match declared "
                f"sample_shape={self._sample_shape}"
            )
def shuffle_buffer(
    chunk_source: Callable[[], Iterator[np.ndarray]],
    *,
    buffer_size: int,
    batch_size: int,
    seed: int | None = None,
) -> Callable[[], Iterator[np.ndarray]]:
    """Wrap a chunk source into a shuffled, fixed-size batch source.

    Accumulates rows from ``chunk_source`` into a buffer of at least ``buffer_size``
    rows, shuffles it, and yields ``batch_size`` slices; rows that do not fill a
    final batch are carried over into the next buffer (never dropped) until the
    source is exhausted, at which point a single trailing partial batch
    (< ``batch_size`` rows) is dropped. This approximates i.i.d. minibatches from an
    unordered or pre-shuffled stream.

    :class:`DataLoader` calls this for you when ``shuffle=True``; use it directly
    when you want explicit control over ``buffer_size`` independently of the loader.
    It does not by itself fix a strongly time/row-ordered stream (a bounded buffer
    only block-shuffles such data); pre-shuffle on disk, or interleave shards into
    ``chunk_source``, for that.

    ``buffer_size`` is a lower bound: each fill accumulates at least
    ``max(buffer_size, batch_size)`` rows before shuffling (so a ``buffer_size``
    smaller than ``batch_size`` still yields full batches; the final fill stops at
    whatever the source has left), and the chunk that crosses the threshold is kept
    whole, so the buffer holds fewer than ``max(buffer_size, batch_size)`` plus one
    chunk's rows. Concatenating a fill into one shuffleable array transiently
    allocates a second copy of those rows, so peak allocation is about twice that
    bound.

    Each epoch (each call of the returned factory) draws a fresh permutation from a
    sub-stream of ``seed``, so the shuffle order differs across epochs while staying
    reproducible for a given ``seed``.

    Parameters
    ----------
    chunk_source : callable
        Zero-arg factory returning an iterable of ``(rows, ...)`` blocks. A known
        ``n_rows`` attribute on it is forwarded onto the returned factory.
    buffer_size : int
        Lower bound, in rows, on each shuffled fill.
    batch_size : int
        Number of rows in every yielded batch.
    seed : int, optional
        Seed of the ``SeedSequence`` the per-epoch generators are spawned from.

    Returns
    -------
    callable
        Zero-arg factory; each call yields one epoch of shuffled batches.

    Raises
    ------
    ValueError
        If ``batch_size`` or ``buffer_size`` is not a positive integer, or (while
        iterating) the source yields a 0-d array.
    """
    if not _is_positive_int(batch_size):
        raise ValueError(f"batch_size must be a positive integer, got {batch_size!r}")
    if not _is_positive_int(buffer_size):
        raise ValueError(f"buffer_size must be a positive integer, got {buffer_size!r}")
    seed_seq = np.random.SeedSequence(seed)

    def factory() -> Iterator[np.ndarray]:
        # A fresh sub-stream per epoch: re-iterating reshuffles instead of
        # replaying one fixed permutation, yet stays reproducible per seed.
        rng = np.random.default_rng(seed_seq.spawn(1)[0])
        # A factory may return a re-iterable (a list of chunks, ...); normalize so
        # each buffer fill continues one stream instead of restarting it forever.
        it = iter(chunk_source())
        carry: np.ndarray | None = None
        exhausted = False
        # Accumulate at least one batch even when buffer_size < batch_size,
        # otherwise the guard below would silently discard the whole stream.
        target = max(buffer_size, batch_size)
        while not exhausted:
            bufs: list[np.ndarray] = []
            have = 0
            if carry is not None:
                bufs.append(carry)
                have += carry.shape[0]
                carry = None
            for arr in it:
                a = np.asarray(arr)
                if a.ndim == 0:
                    # Rows are counted and concatenated along axis 0; a 0-d yield
                    # has no such axis and used to fail with a bare IndexError.
                    raise ValueError(
                        "shuffle_buffer needs chunks with a leading row axis, got a "
                        "0-d array; promote single samples to one-row blocks first"
                    )
                bufs.append(a)
                have += a.shape[0]
                if have >= target:
                    break
            else:
                exhausted = True
            if have < batch_size:
                # Only reachable once the source is exhausted: drop the final
                # partial batch.
                return
            # np.concatenate allocates a new array, so the in-place shuffle never
            # touches the arrays the source handed over.
            buf = np.concatenate(bufs, axis=0)
            rng.shuffle(buf)
            n_full = buf.shape[0] // batch_size
            for i in range(n_full):
                yield buf[i * batch_size : (i + 1) * batch_size]
            rem = buf.shape[0] - n_full * batch_size
            carry = buf[n_full * batch_size :].copy() if rem else None

    # Forward a known row count so total_size="auto" stays metadata-cheap
    # through the shuffle wrapper.
    source_n_rows = getattr(chunk_source, "n_rows", None)
    if source_n_rows is not None:
        factory.n_rows = source_n_rows  # type: ignore[attr-defined]
    return factory
def _promote_to_block(a: np.ndarray, sample_shape: tuple[int, ...]) -> np.ndarray:
    """Return ``a`` as a ``(rows, *sample_shape)`` block; a single sample becomes one row.

    Raises
    ------
    ValueError
        If ``a`` is neither one sample of shape ``sample_shape`` nor a block whose
        trailing shape is ``sample_shape``.
    """
    if a.shape == sample_shape:
        return a[None, ...]
    if a.ndim != len(sample_shape) + 1 or a.shape[1:] != sample_shape:
        raise ValueError(
            f"source yielded shape {a.shape}; expected one sample of shape "
            f"{sample_shape} or a (rows, *sample_shape) block; if the source is "
            f"right, declare its trailing shape with DataLoader(sample_shape=...)"
        )
    return a


def _block_factory(
    factory: Callable[[], Iterator[np.ndarray]],
    sample_shape: tuple[int, ...],
) -> Callable[[], Iterator[np.ndarray]]:
    """Wrap ``factory`` so every yield is a block, promoting single samples.

    :func:`shuffle_buffer` counts and concatenates yields along the leading axis,
    so single-sample yields (e.g. the rows of a raw array) must be promoted to
    one-row blocks before shuffling. A known ``.n_rows`` is forwarded.
    """

    def f() -> Iterator[np.ndarray]:
        for arr in factory():
            yield _promote_to_block(np.asarray(arr), sample_shape)

    n_rows = getattr(factory, "n_rows", None)
    if n_rows is not None:
        f.n_rows = n_rows  # type: ignore[attr-defined]
    return f


def _rebatch(
    blocks: Iterable[np.ndarray],
    batch_size: int,
    sample_shape: tuple[int, ...],
) -> Iterator[np.ndarray]:
    """Slice a stream of samples/blocks into exact ``batch_size``-row batches, in order.

    Accepts single samples (shape ``sample_shape``, e.g. the rows of a raw array)
    and blocks of any size (shape ``(rows, *sample_shape)``), carrying remainders
    across blocks so no row is lost mid-stream. Trailing rows that do not fill a
    final batch are dropped when the stream ends (``drop_last=True`` behavior; the
    model observes a fixed-shape placeholder, so a partial batch cannot be fed).
    Sources that already yield exact ``batch_size`` blocks (e.g.
    :func:`shuffle_buffer`) pass through without copying.
    """
    buf: list[np.ndarray] = []
    have = 0
    for arr in blocks:
        a = _promote_to_block(np.asarray(arr), sample_shape)
        buf.append(a)
        have += a.shape[0]
        if have < batch_size:
            continue
        # A lone pending block is sliced as-is (views, no copy); several pending
        # pieces are merged first.
        merged = np.concatenate(buf, axis=0) if len(buf) > 1 else buf[0]
        n_full = merged.shape[0] // batch_size
        for i in range(n_full):
            yield merged[i * batch_size : (i + 1) * batch_size]
        rem = merged.shape[0] - n_full * batch_size
        # Copy the remainder so the carry does not keep the whole merged array alive.
        buf = [merged[n_full * batch_size :].copy()] if rem else []
        have = rem


def _make_factory(
    source: Iterable[np.ndarray] | Callable[[], Iterator[np.ndarray]],
) -> Callable[[], Iterator[np.ndarray]]:
    """Coerce ``source`` into a zero-arg callable returning a fresh iterator.

    A callable that is not itself an iterator is treated as the factory; a bare
    iterator is wrapped (and refuses a second epoch); any other iterable (incl. an
    :class:`IterableDataset`) is re-``iter``-ed each epoch. A known ``.n_rows`` is
    forwarded onto the returned factory so ``total_size="auto"`` stays cheap.
    """
    if callable(source) and not isinstance(source, Iterator):
        # A factory may return any iterable (a list of batches, a generator, ...);
        # normalize so the loader always pulls from a true iterator.
        def _factory() -> Iterator[np.ndarray]:
            return iter(source())  # type: ignore[operator]

    elif isinstance(source, Iterator):
        consumed = {"done": False}

        def _factory() -> Iterator[np.ndarray]:
            if consumed["done"]:
                raise RuntimeError(
                    "source is a bare iterator and was already consumed; the loader "
                    "restarts the stream each epoch, so pass a zero-arg factory or a "
                    "re-iterable instead"
                )
            consumed["done"] = True
            return source

    else:

        def _factory() -> Iterator[np.ndarray]:
            return iter(source)

    n_rows = getattr(source, "n_rows", None)
    if n_rows is not None:
        _factory.n_rows = n_rows  # type: ignore[attr-defined]
    return _factory


def _auto_total_size(
    factory: Callable[[], Iterator[np.ndarray]],
    source: object,
    sample_shape: tuple[int, ...] = (),
) -> int:
    """Resolve ``total_size="auto"``: a source ``.n_rows`` (cheap) else a counting pass.

    Fast path: if ``source`` advertises ``.n_rows`` (e.g. :func:`parquet_source`,
    which reads it from Parquet metadata without scanning the data) use it directly.
    Otherwise do a single counting pass over a finite, re-readable source. A bare
    one-shot iterator cannot be auto-counted (counting consumes it) and an infinite
    stream would make the pass hang; both must pass ``total_size`` explicitly.

    Raises
    ------
    ValueError
        If ``n_rows`` is not a positive integer, the source is a one-shot iterator,
        a yield does not match ``sample_shape``, no rows are counted, or the factory
        cannot produce a second non-empty stream.
    """
    n = getattr(source, "n_rows", None)
    if n is None:
        n = getattr(factory, "n_rows", None)
    if n is not None:
        if not _is_positive_int(n):
            raise ValueError(f"source.n_rows must be a positive integer, got {n!r}")
        return int(n)

    if isinstance(source, Iterator):
        raise ValueError(
            "total_size='auto' needs a re-readable source (a zero-arg factory or an "
            "iterable), not a one-shot iterator; pass total_size=N explicitly instead."
        )

    warnings.warn(
        "total_size='auto' is doing a full counting pass over the source; for a cheap "
        "path use a source exposing .n_rows (e.g. parquet_source, from Parquet metadata).",
        UserWarning,
        stacklevel=3,
    )
    first_iter = factory()
    count = 0
    for chunk in first_iter:
        # Count with the same promotion rule the rebatcher applies: a yield of
        # shape exactly `sample_shape` is one sample, anything else must be a
        # (rows, *sample_shape) block. A malformed yield therefore fails here
        # with the descriptive ValueError instead of an IndexError or a silent
        # miscount that only surfaces when streaming starts.
        count += int(_promote_to_block(np.asarray(chunk), sample_shape).shape[0])
    if count <= 0:
        raise ValueError("total_size='auto' counted 0 rows (empty or non-re-readable source).")

    # A genuine factory yields a fresh, non-empty stream each call; one that
    # returns the same exhausted iterator (or a new generator over consumed
    # state) would leave the loader with nothing to stream. The probe costs one
    # chunk, which the counting pass has already dwarfed.
    second_iter = factory()
    if second_iter is first_iter or next(second_iter, None) is None:
        raise ValueError(
            "total_size='auto' counted rows but the factory's next stream was empty "
            "(it returns the same one-shot iterator, or closes over an already-"
            "consumed one); pass a factory that creates a fresh iterator each call, "
            "or total_size=N explicitly."
        )
    return count


class _ParquetDataset(IterableDataset):
    """An :class:`IterableDataset` over a directory of Parquet shards.

    Yields one ``(rows, n_columns)`` array per row group (so peak read memory is
    one row group, not one file), in the fixed column order chosen at
    construction, and exposes :attr:`n_rows` read from Parquet metadata (no data
    scan).
    """

    def __init__(self, paths: list[str], columns: list[str], n_rows: int):
        # Own copies: the shard list and the column order are frozen at
        # construction, so a caller mutating its lists afterwards must not be
        # able to reorder or drop streamed columns.
        self._paths = list(paths)
        self._columns = list(columns)
        self.n_rows = n_rows

    def __iter__(self) -> Iterator[np.ndarray]:
        """Yield one stacked array per row group, validating each shard's schema first."""
        # pyarrow is an optional dependency, so it is imported on use.
        import pyarrow as pa
        import pyarrow.parquet as pq

        def is_numeric(arrow_type: object) -> bool:
            return (
                pa.types.is_integer(arrow_type)
                or pa.types.is_floating(arrow_type)
                or pa.types.is_boolean(arrow_type)
            )

        for path in self._paths:
            shard = pq.ParquetFile(path)
            schema = shard.schema_arrow
            missing = [c for c in self._columns if c not in schema.names]
            if missing:
                # read_row_group(columns=...) silently drops unknown names, so a
                # malformed shard must be named here, not surface as a bare
                # KeyError with no path.
                raise ValueError(f"columns {missing} not found in {path!r}")
            non_numeric = [c for c in self._columns if not is_numeric(schema.field(c).type)]
            if non_numeric:
                # parquet_source validates types against the first shard only; a
                # later shard whose column turned non-numeric would otherwise
                # become an object array and fail at the batch cast with no path.
                raise ValueError(
                    f"columns {non_numeric} in {path!r} are not numeric and cannot be "
                    f"streamed into a float batch; select numeric columns with columns=."
                )
            for i in range(shard.metadata.num_row_groups):
                table = shard.read_row_group(i, columns=self._columns)
                # Stack by the frozen column names, not the file's own order, so
                # a shard with a permuted schema cannot silently swap features.
                yield np.column_stack([table.column(c).to_numpy() for c in self._columns])
def parquet_source(
    directory: str,
    *,
    columns: list[str] | None = None,
    pattern: str = "*.parquet",
) -> _ParquetDataset:
    """An :class:`IterableDataset` over a directory of Parquet files.

    Yields one ``(rows, n_columns)`` array per row group (one or more per file),
    so peak read memory is one row group, not one file. The column order is
    frozen at construction (``columns`` if given, else the first file's schema
    order) and every shard is read in that order, so a shard with a permuted
    schema cannot silently reorder features mid-stream.

    Carries an ``n_rows`` attribute read from Parquet metadata (no data scan) so
    that ``DataLoader(parquet_source(dir), ..., total_size="auto")`` resolves the
    dataset size for free. Pass ``shuffle=True`` to the :class:`DataLoader` (or
    wrap in :func:`shuffle_buffer`) to get shuffled batches.

    Parameters
    ----------
    directory : str
        Directory searched (non-recursively) for shards.
    columns : list of str, optional
        Columns to stream, in this order. Defaults to every column of the first
        shard, in its schema order.
    pattern : str, default "*.parquet"
        Glob pattern, relative to ``directory``, selecting the shard files.

    Raises
    ------
    ValueError
        If no file matches, no column is selected, a requested column is missing
        from the first shard, or a selected column is not integer, floating or
        boolean.
    """
    # pyarrow is an optional dependency, so it is imported on use.
    import pyarrow as pa
    import pyarrow.parquet as pq

    # Sorted so the shard order, and hence the replay order, is deterministic.
    paths = sorted(glob.glob(os.path.join(directory, pattern)))
    if not paths:
        raise ValueError(f"no Parquet files match {os.path.join(directory, pattern)!r}")

    schema = pq.read_schema(paths[0])
    if columns is None:
        columns = list(schema.names)
    else:
        # Own copy: accepts any sequence of names and keeps the frozen order
        # independent of later mutations of the caller's list.
        columns = list(columns)
        missing = sorted(set(columns) - set(schema.names))
        if missing:
            raise ValueError(
                f"columns {missing} not found in {paths[0]!r}; available: {sorted(schema.names)}"
            )
    if not columns:
        # An empty selection can never be stacked into a batch; fail here, where
        # the cause is obvious, rather than on the first row group.
        raise ValueError(
            f"no columns selected from {paths[0]!r}; pass at least one column name"
        )

    def is_numeric(arrow_type: object) -> bool:
        return (
            pa.types.is_integer(arrow_type)
            or pa.types.is_floating(arrow_type)
            or pa.types.is_boolean(arrow_type)
        )

    non_numeric = [c for c in columns if not is_numeric(schema.field(c).type)]
    if non_numeric:
        # A string/dictionary column would turn whole chunks object-dtype and only
        # fail later at the batch cast, without naming the column.
        raise ValueError(
            f"columns {non_numeric} in {paths[0]!r} are not numeric and cannot be "
            f"streamed into a float batch; select numeric columns with columns=."
        )

    # Row counts come from each file's footer metadata; no data pages are read.
    n_rows = sum(pq.read_metadata(p).num_rows for p in paths)
    return _ParquetDataset(paths, columns, n_rows)