
mlm.unbatch

iter_unbatch(batch, length, *, dim=0, strict=True, broadcast_non_sliceable=False)

Convert a "batched" dict into an iterator of per-item dicts.

Parameters

batch: Mapping from field name to batched value (tensor, ndarray, list, etc.).
length: Number of items to unbatch. It is supplied by the caller, not inferred from the values.
dim: Which dimension is the batch dimension for array-likes (torch/numpy). For Python lists/tuples, only dim=0 is meaningful.
strict: If True, validate (when possible) that each value has batch length == length.
broadcast_non_sliceable: If True, values that cannot be indexed are copied as-is into every item.

Yields

dict[str, Any]: One dict per item in the batch.
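The parameter descriptions above can be read as the following minimal sketch. It is a hypothetical re-implementation written for illustration, not the mlm.unbatch source: the helper names `_batch_len`, `_take` and `iter_unbatch_sketch` are invented here, and it assumes that "array-like" means "has a `shape` attribute" and that strings, scalars, 0-d arrays and `None` count as non-sliceable.

```python
from typing import Any, Dict, Iterator, Mapping, Optional


def _batch_len(value: Any, dim: int) -> Optional[int]:
    """Batch length of `value` along `dim`, or None if treated as non-sliceable."""
    # Assumption: anything with a `shape` deep enough is an array-like (numpy/torch).
    if hasattr(value, "shape") and len(value.shape) > dim:
        return int(value.shape[dim])
    if isinstance(value, (list, tuple)):
        if dim != 0:
            raise ValueError("lists/tuples only support dim=0")
        return len(value)
    return None  # assumption: scalars, strings, 0-d arrays, None, ... are non-sliceable


def _take(value: Any, i: int, dim: int) -> Any:
    """Select item `i` along the batch dimension `dim`."""
    if hasattr(value, "shape"):
        # e.g. dim=1 -> value[:, i]
        return value[(slice(None),) * dim + (i,)]
    return value[i]


def iter_unbatch_sketch(
    batch: Mapping[str, Any],
    length: int,
    *,
    dim: int = 0,
    strict: bool = True,
    broadcast_non_sliceable: bool = False,
) -> Iterator[Dict[str, Any]]:
    """Hypothetical sketch of the documented iter_unbatch behavior."""
    # Classify and validate every field once. Because this is a generator,
    # errors surface on the first next() call, not at call time.
    sliceable: Dict[str, bool] = {}
    for key, value in batch.items():
        n = _batch_len(value, dim)
        if n is None:
            if not broadcast_non_sliceable:
                raise TypeError(f"field {key!r} cannot be indexed along dim {dim}")
            sliceable[key] = False
        else:
            if strict and n != length:
                raise ValueError(f"field {key!r} has batch length {n}, expected {length}")
            sliceable[key] = True
    # One dict per item; broadcast fields reuse the same object (no deep copy).
    for i in range(length):
        yield {
            key: _take(value, i, dim) if sliceable[key] else value
            for key, value in batch.items()
        }
```

Validating all fields before yielding anything means a bad field is reported before any item is produced, rather than partway through iteration.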

unbatch(batch, length, *, dim=0, strict=False, broadcast_non_sliceable=True)

List-returning wrapper around iter_unbatch. Note that its defaults are more lenient: strict=False and broadcast_non_sliceable=True.
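To illustrate what the lenient defaults of unbatch mean in practice, here is a small self-contained sketch. It assumes a simplified stand-in for iter_unbatch that only splits Python lists/tuples along dim=0 and broadcasts everything else; `_iter_unbatch_lists` and `unbatch_sketch` are hypothetical names, not part of mlm.unbatch.

```python
from typing import Any, Dict, Iterator, List, Mapping


def _iter_unbatch_lists(
    batch: Mapping[str, Any],
    length: int,
    *,
    dim: int = 0,
    strict: bool = True,
    broadcast_non_sliceable: bool = False,
) -> Iterator[Dict[str, Any]]:
    """Hypothetical stand-in for iter_unbatch: only lists/tuples are split (dim=0)."""
    if dim != 0:
        raise ValueError("this stand-in only supports dim=0")
    for key, value in batch.items():
        if isinstance(value, (list, tuple)):
            if strict and len(value) != length:
                raise ValueError(f"field {key!r} has length {len(value)}, expected {length}")
        elif not broadcast_non_sliceable:
            raise TypeError(f"field {key!r} is not sliceable")
    for i in range(length):
        yield {k: v[i] if isinstance(v, (list, tuple)) else v for k, v in batch.items()}


def unbatch_sketch(
    batch: Mapping[str, Any],
    length: int,
    *,
    dim: int = 0,
    strict: bool = False,
    broadcast_non_sliceable: bool = True,
) -> List[Dict[str, Any]]:
    """Hypothetical list-returning wrapper mirroring unbatch's documented defaults."""
    return list(
        _iter_unbatch_lists(
            batch,
            length,
            dim=dim,
            strict=strict,
            broadcast_non_sliceable=broadcast_non_sliceable,
        )
    )


# With the lenient defaults, the scalar "k" is broadcast and the extra
# element of "x" is silently ignored.
items = unbatch_sketch({"x": [1, 2, 3, 4], "k": 7}, 2)
print(items)  # [{'x': 1, 'k': 7}, {'x': 2, 'k': 7}]
```

In this sketch, passing strict=True to the same call raises ValueError, and with strict=False a field shorter than length fails with IndexError once iteration reaches the missing item. Whether the real unbatch behaves identically in those edge cases is not stated in the documentation above.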