pymc.variational.DataLoader
- class pymc.variational.DataLoader(dataset, *, batch_size, shuffle=False, buffer_size=None, seed=None, sample_shape=None, dtype='float64', total_size=None, preprocess_fn=None)
Turn an out-of-core dataset into fixed-size minibatches for variational inference.
Like torch.utils.data.DataLoader, it batches (and optionally shuffles) an IterableDataset into a stream of minibatches. It is iterable and sized (len(loader) is the dataset size N). As long as the source yields bounded chunks, the full dataset is never resident in memory at once.
- Parameters:
- dataset
IterableDataset | Iterable[np.ndarray] | Callable[[], Iterator[np.ndarray]]
The source of rows: an IterableDataset, a re-iterable (including a plain np.ndarray), or a zero-argument factory returning a fresh iterator (preferred, so the stream can be restarted each epoch). It may yield single samples (e.g. the rows of a raw array) or blocks of any size; the loader re-batches them, in order, to exactly batch_size rows. Trailing rows that do not fill a final batch are dropped at the end of a pass, like drop_last=True in PyTorch (required here because the model observes a fixed-shape placeholder). With shuffle=True the dropped remainder differs per epoch; with a fixed replay order the same rows are dropped on every pass.
- batch_size
int
Leading dimension of every yielded minibatch.
- shuffle
bool, default False
If True, wrap the source in a bounded shuffle_buffer() of buffer_size rows. This yields approximately i.i.d. batches only when the stream is already unordered; a bounded buffer cannot fix strongly time- or row-ordered data (pre-shuffle on disk for that; see the module docstring).
- buffer_size
int, optional
Shuffle-buffer size in rows when shuffle=True. Defaults to 50 * batch_size. Ignored when shuffle=False. A buffer at least as large as the dataset holds all of it in memory (a full shuffle).
- seed
int, optional
Seed for the shuffle buffer (ignored when shuffle=False).
- sample_shape
tuple of int, optional
Trailing shape of a single observation: () for scalar observations, (k,) to stream k columns (e.g. features + the observed column). Defaults to dataset.shape[1:] for a raw np.ndarray source (its rows are the samples, like torch's TensorDataset), else ().
- dtype
str, default "float64"
Dtype each prepared batch is cast to; match the dtype of the pm.Data placeholder the batches are streamed into.
- total_size
int or "auto", optional
The true dataset size N (a positive integer), or "auto" to infer it (from the source's n_rows if available, else a single counting pass). Pass it on to the observed distribution as total_size=len(loader) so the minibatch log-likelihood is rescaled by N / batch_size (the same mechanism as pm.Minibatch). Unlike with pm.Minibatch, N cannot be inferred from a resident array here. None warns at construction and a non-positive value raises (it would otherwise silently disable or invert the rescaling).
- preprocess_fn
callable, optional
Pure transform applied to each batch before validation (e.g. normalization). It must preserve the row count and sample_shape; to select columns, do it at the source instead (parquet_source(columns=...)).
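The re-batching rule described for dataset (chunks of any size in, exactly batch_size rows out, trailing remainder dropped) can be illustrated with a small NumPy-only sketch. The helper name rebatch is hypothetical and is not part of the DataLoader API; the sketch also assumes every chunk already carries a leading row axis, which is narrower than what the loader accepts.

```python
import numpy as np


def rebatch(chunks, batch_size):
    """Hypothetical helper: yield arrays of exactly `batch_size` rows, in order.

    Assumes each chunk has a leading row axis. Rows left over at the end
    of the pass are dropped, mirroring drop_last=True.
    """
    pending = None  # rows received but not yet emitted
    for chunk in chunks:
        chunk = np.asarray(chunk)
        pending = chunk if pending is None else np.concatenate([pending, chunk])
        # Emit as many full batches as the pending rows allow.
        while len(pending) >= batch_size:
            yield pending[:batch_size]
            pending = pending[batch_size:]
    # Anything still pending here is an incomplete batch and is discarded.


data = np.arange(10)
chunks = [data[:3], data[3:7], data[7:]]  # uneven source chunks: 3, 4, 3 rows
batches = [b.tolist() for b in rebatch(chunks, batch_size=4)]
print(batches)
```

With ten rows and batch_size=4, two full batches come out and the last two rows are dropped, which is why a fixed replay order drops the same rows on every pass.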
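Why a bounded shuffle buffer cannot repair strongly ordered data can be seen in a minimal sketch. The function shuffle_buffer_sketch below is a hypothetical illustration of the general technique, not the library's shuffle_buffer(), whose internals may differ. It keeps at most buffer_size rows, emits a random one each time a new row arrives, and flushes the rest in random order at the end.

```python
import numpy as np


def shuffle_buffer_sketch(rows, buffer_size, seed=None):
    """Hypothetical bounded shuffle: holds at most `buffer_size` rows at once."""
    rng = np.random.default_rng(seed)
    buffer = []
    for row in rows:
        if len(buffer) < buffer_size:
            buffer.append(row)  # still filling the buffer
            continue
        j = rng.integers(buffer_size)  # pick a random slot to emit
        yield buffer[j]
        buffer[j] = row  # the new row takes the emitted row's slot
    # Flush whatever remains, in random order.
    for j in rng.permutation(len(buffer)):
        yield buffer[j]


# A strongly ordered stream: row i carries the value i.
out = list(shuffle_buffer_sketch(range(100), buffer_size=10, seed=0))

# How far ahead of its original position any row was emitted.
max_lead = max(value - position for position, value in enumerate(out))
print(max_lead)
```

In this sketch a row can never be emitted more than buffer_size - 1 positions earlier than it arrived, so the first batches of an ordered stream still contain only early rows. A buffer at least as large as the dataset removes that limit and gives a full shuffle.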
Methods
DataLoader.__init__(dataset, *, batch_size)
Attributes
batch_size
batches_seen
rows_streamed
Total rows streamed into the model (grows past N across epochs).
total_size
The dataset size N (pass to the distribution's total_size).
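The role of total_size can be checked with plain arithmetic. The sketch below uses only NumPy and a hand-written normal log-density (normal_logp is a hypothetical helper); it shows the N / batch_size rescaling itself, not how PyMC applies it internally. Each minibatch log-likelihood is multiplied by N / batch_size, and averaging those scaled values over the disjoint batches of one pass recovers the full-data log-likelihood.

```python
import numpy as np

rng = np.random.default_rng(0)
N, batch_size = 1000, 100
y = rng.normal(size=N)


def normal_logp(y, mu=0.0, sigma=1.0):
    """Elementwise log-density of Normal(mu, sigma); illustrative helper."""
    return -0.5 * np.log(2 * np.pi * sigma**2) - 0.5 * ((y - mu) / sigma) ** 2


full = normal_logp(y).sum()  # log-likelihood of the whole dataset

scale = N / batch_size  # the rescaling factor implied by total_size=N
estimates = [
    scale * normal_logp(batch).sum()  # scaled minibatch log-likelihood
    for batch in y.reshape(-1, batch_size)  # 10 disjoint batches of 100 rows
]
mean_estimate = np.mean(estimates)

print(scale, np.isclose(mean_estimate, full))
```

Without the factor, each minibatch term would be about batch_size / N of the full value, which understates the likelihood relative to the prior. This is why a missing or non-positive total_size is flagged at construction.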