pymc.variational.DataLoader

class pymc.variational.DataLoader(dataset, *, batch_size, shuffle=False, buffer_size=None, seed=None, sample_shape=None, dtype='float64', total_size=None, preprocess_fn=None)

Turn an out-of-core dataset into fixed-size minibatches for variational inference.

Like torch.utils.data.DataLoader, it batches (and optionally shuffles) an IterableDataset into a stream of minibatches. The loader is iterable and sized: len(loader) returns the dataset size N, not the number of batches. Provided the source yields bounded chunks, the full dataset is never held in memory at once.
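The re-batching idea can be illustrated with a minimal, self-contained sketch in plain NumPy. It is not the library's implementation: `rebatch` and `make_chunks` are hypothetical names, and the sketch assumes every chunk is an array with a leading row axis. It regroups blocks of arbitrary size into batches of exactly `batch_size` rows, preserves order, and drops the trailing remainder.

```python
from typing import Callable, Iterator

import numpy as np


def rebatch(make_chunks: Callable[[], Iterator[np.ndarray]], batch_size: int) -> Iterator[np.ndarray]:
    """Yield arrays of exactly `batch_size` rows, in source order.

    Hypothetical helper for illustration only; not part of PyMC.
    Assumes each chunk has a leading row axis.
    """
    pending = []    # chunks (or leftovers) not yet emitted
    n_pending = 0   # total number of rows held in `pending`
    for chunk in make_chunks():
        pending.append(np.asarray(chunk))
        n_pending += len(chunk)
        while n_pending >= batch_size:
            rows = np.concatenate(pending)
            yield rows[:batch_size]
            rest = rows[batch_size:]
            pending = [rest] if len(rest) else []
            n_pending = len(rest)
    # Rows still in `pending` do not fill a batch and are dropped.


def make_chunks() -> Iterator[np.ndarray]:
    """Zero-argument factory: every call returns a fresh iterator over blocks of 3, 4 and 3 rows."""
    data = np.arange(10)
    return iter([data[:3], data[3:7], data[7:]])


batches = list(rebatch(make_chunks, batch_size=4))
```

With ten rows and `batch_size=4`, the sketch yields two batches and drops the last two rows. Because `make_chunks` is a factory, calling `rebatch` again replays the stream from the start, which is why a factory is the preferred form of source.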

Parameters:
dataset : IterableDataset | Iterable[np.ndarray] | Callable[[], Iterator[np.ndarray]]

The source of rows: an IterableDataset, a re-iterable (including a plain np.ndarray), or a zero-argument factory that returns a fresh iterator (preferred, because the stream can then be restarted each epoch). The source may yield single samples (e.g. the rows of a raw array) or blocks of any size; the loader re-batches them, in order, into batches of exactly batch_size rows. Trailing rows that do not fill a final batch are dropped at the end of each pass, like drop_last=True in PyTorch; this is required because the model observes a fixed-shape placeholder. With shuffle=True the dropped remainder differs from epoch to epoch; with a fixed replay order the same rows are dropped on every pass.

batch_size : int

Leading dimension of every yielded minibatch.

shuffle : bool, default False

If True, wrap the source in a bounded shuffle_buffer() of buffer_size rows. A bounded buffer only approximates i.i.d. batches, and only when the stream is already roughly unordered; it cannot fix strongly time-ordered or row-ordered data. For such data, pre-shuffle on disk (see the module docstring).

buffer_size : int, optional

Shuffle-buffer size in rows when shuffle=True. Defaults to 50 * batch_size. Ignored when shuffle=False. A buffer at least as large as the dataset gives a full shuffle, but holds the whole dataset in memory.

seed : int, optional

Seed for the shuffle buffer (ignored when shuffle=False).

sample_shape : tuple of int, optional

Trailing shape of a single observation. () for scalar observations, (k,) to stream k columns (e.g. features + the observed column). Defaults to dataset.shape[1:] for a raw np.ndarray source (its rows are the samples, like torch’s TensorDataset), else ().

dtype : str, default "float64"

Dtype each prepared batch is cast to; match the dtype of the pm.Data placeholder the batches are streamed into.

total_size : int or "auto", optional

The true dataset size N (a positive integer), or "auto" to infer it (from the source's n_rows if available, otherwise by a single counting pass). Pass it on to the observed distribution as total_size=len(loader) so that the minibatch log-likelihood is rescaled by N / batch_size (the same mechanism as pm.Minibatch). Unlike with pm.Minibatch, N cannot be inferred from a resident array; None triggers a warning at construction, and a non-positive value raises an error (it would otherwise silently disable or invert the rescaling).

preprocess_fn : callable, optional

Pure transform applied to each batch before validation (e.g. normalization). It must preserve the row count and sample_shape; to select columns, do it at the source instead (parquet_source(columns=...)).
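The behaviour of a bounded shuffle buffer, as described under shuffle and buffer_size, can be illustrated with a small sketch. This is an assumption about how such a buffer typically works, not the code behind shuffle_buffer(); `bounded_shuffle` is a hypothetical name. The sketch keeps at most `buffer_size` items resident and, once the buffer is full, emits a random resident item for each new arrival.

```python
import numpy as np


def bounded_shuffle(rows, buffer_size, seed=None):
    """Yield the items of `rows` in a locally shuffled order.

    Hypothetical sketch for illustration; assumes buffer_size >= 1.
    At most `buffer_size` items are held in memory at any time.
    """
    rng = np.random.default_rng(seed)
    buffer = []
    for row in rows:
        if len(buffer) < buffer_size:
            buffer.append(row)  # still filling the buffer
            continue
        # Buffer is full: emit a random resident item and put the new one in its slot.
        j = int(rng.integers(len(buffer)))
        yield buffer[j]
        buffer[j] = row
    # Source exhausted: emit whatever is left, in random order.
    for j in rng.permutation(len(buffer)):
        yield buffer[j]


stream = range(100)  # a strongly ordered source
out = list(bounded_shuffle(stream, buffer_size=10, seed=0))

# An item can never be emitted more than buffer_size - 1 positions early.
max_lead = max(item - position for position, item in enumerate(out))
```

In this sketch the item at output position k always comes from the first k + buffer_size source rows, so a sorted stream stays roughly sorted. That is the sense in which a bounded buffer cannot fix strongly ordered data, while a buffer at least as large as the dataset reduces to a full shuffle.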

Methods

DataLoader.__init__(dataset, *, batch_size)

Attributes

batch_size

batches_seen

rows_streamed

Total rows streamed into the model (grows past N across epochs).

total_size

The dataset size N (pass to the distribution's total_size).
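The N / batch_size rescaling that total_size enables can be checked numerically. The sketch below uses plain NumPy and does not call PyMC; the Normal model, its parameter values, and the helper `normal_logpdf` are arbitrary choices made for illustration. It shows that minibatch log-likelihoods, each scaled by N / batch_size, average to the full-data log-likelihood over one pass.

```python
import numpy as np


def normal_logpdf(x, mu, sigma):
    """Elementwise log-density of Normal(mu, sigma)."""
    return -0.5 * np.log(2.0 * np.pi * sigma**2) - 0.5 * ((x - mu) / sigma) ** 2


rng = np.random.default_rng(0)
N, batch_size = 1000, 100  # N is a multiple of batch_size, so no rows are dropped
y = rng.normal(loc=1.0, scale=2.0, size=N)

# Log-likelihood of the full dataset.
full_logp = normal_logpdf(y, 1.0, 2.0).sum()

# Each minibatch sees only batch_size rows, so its log-likelihood is scaled up.
scale = N / batch_size
batch_estimates = [
    scale * normal_logpdf(y[start:start + batch_size], 1.0, 2.0).sum()
    for start in range(0, N, batch_size)
]

# Averaged over one full pass, the rescaled estimates recover the full value.
mean_estimate = float(np.mean(batch_estimates))
```

Without the scale factor each batch would contribute only about batch_size / N of the full log-likelihood, which understates the weight of the data relative to the prior. A wrong or non-positive total_size distorts this factor in the same way.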