xbatcher.BatchGenerator
- class xbatcher.BatchGenerator(ds: xarray.core.dataset.Dataset | xarray.core.dataarray.DataArray, input_dims: dict[collections.abc.Hashable, int], input_overlap: Optional[dict[collections.abc.Hashable, int]] = None, batch_dims: Optional[dict[collections.abc.Hashable, int]] = None, concat_input_dims: bool = False, preload_batch: bool = True, cache: Optional[dict[str, Any]] = None, cache_preprocess: Optional[Callable] = None)
Create a generator for iterating through Xarray DataArrays / Datasets in batches.
- Parameters:
- ds : xarray.Dataset or xarray.DataArray
The data to iterate over.
- input_dims : dict
A dictionary specifying the size of the inputs in each dimension, e.g. {'lat': 30, 'lon': 30}. These are the dimensions the ML library will see. All other dimensions will be stacked into one dimension called sample.
- input_overlap : dict, optional
A dictionary specifying the overlap along each dimension, e.g. {'lat': 3, 'lon': 3}.
- batch_dims : dict, optional
A dictionary specifying the size of the batch along each dimension, e.g. {'time': 10}. These will always be iterated over.
- concat_input_dims : bool, optional
If True, the dimension chunks specified in input_dims will be concatenated and stacked into the sample dimension. The batch index will be included as a new level input_batch in the sample coordinate. If False, the dimension chunks specified in input_dims will be iterated over.
- preload_batch : bool, optional
If True, each batch will be loaded into memory before reshaping / processing, triggering any dask arrays to be computed.
- cache : dict, optional
A dict-like object in which to cache batches (e.g., a Zarr DirectoryStore). Note: The caching API is experimental and subject to change.
- cache_preprocess : callable, optional
A function to apply to batches prior to caching. Note: The caching API is experimental and subject to change.
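How input_dims and input_overlap partition the data can be illustrated with a small pure-Python model. It assumes that each window advances by its size minus the overlap, and that a trailing window that would run past the end of the dimension is dropped. The helpers window_slices and patch_selectors are hypothetical names, not part of xbatcher. The sketch models the iteration over input_dims chunks described above for concat_input_dims=False; it is not the library's implementation.

```python
from itertools import product
from typing import Dict, Iterator, List, Optional


def window_slices(dim_size: int, size: int, overlap: int = 0) -> List[slice]:
    """Hypothetical helper: windows of length `size` along one dimension.

    Assumption: windows start every `size - overlap` positions, and a
    trailing window that would run past `dim_size` is dropped.
    """
    if overlap >= size:
        # A window must advance by at least one position.
        raise ValueError("overlap must be smaller than the window size")
    stride = size - overlap
    return [
        slice(start, start + size)
        for start in range(0, dim_size, stride)
        if start + size <= dim_size
    ]


def patch_selectors(
    dim_sizes: Dict[str, int],
    input_dims: Dict[str, int],
    input_overlap: Optional[Dict[str, int]] = None,
) -> Iterator[Dict[str, slice]]:
    """Hypothetical helper: one {dim: slice} selector per combination of
    per-dimension windows (a model of concat_input_dims=False)."""
    input_overlap = input_overlap or {}
    per_dim = {
        dim: window_slices(dim_sizes[dim], size, input_overlap.get(dim, 0))
        for dim, size in input_dims.items()
    }
    # Cartesian product: every lat window is paired with every lon window.
    for combo in product(*per_dim.values()):
        yield dict(zip(per_dim.keys(), combo))


# 70 x 60 grid, 30 x 30 patches, 10-cell overlap along lat only.
patches = list(
    patch_selectors({"lat": 70, "lon": 60}, {"lat": 30, "lon": 30}, {"lat": 10})
)
# lat: starts 0, 20, 40 -> 3 windows; lon: starts 0, 30 -> 2 windows
print(len(patches))  # 6
print(patches[0])    # {'lat': slice(0, 30, None), 'lon': slice(0, 30, None)}
```

Each selector in this sketch is a plain dict of slices, so it could be passed to an xarray object's isel method to extract the corresponding patch.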
- Yields:
- ds_slice : xarray.Dataset or xarray.DataArray
Slices of the array matching the given batch size specification.
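The stacking rule given for input_dims (all other dimensions are stacked into one dimension called sample) can be sketched by tracking only dimension lengths. The function stacked_batch_shape and the member dimension in the example are hypothetical. Placing sample first is an assumption of this sketch, and xbatcher's actual dimension order and edge cases may differ.

```python
from math import prod
from typing import Dict


def stacked_batch_shape(
    slice_sizes: Dict[str, int], input_dims: Dict[str, int]
) -> Dict[str, int]:
    """Hypothetical helper: {dim: length} of one yielded slice.

    Every dimension of the selected slice that is not in input_dims is
    folded into a single 'sample' dimension whose length is the product
    of their sizes. Putting 'sample' first is an assumption of this sketch.
    """
    other_dims = [d for d in slice_sizes if d not in input_dims]
    shape = {"sample": prod(slice_sizes[d] for d in other_dims)}
    # The input dimensions keep their lengths from the selected slice.
    shape.update({d: slice_sizes[d] for d in input_dims})
    return shape


# A slice covering 10 time steps and 3 ensemble members on a 30 x 30 patch.
print(stacked_batch_shape({"time": 10, "member": 3, "lat": 30, "lon": 30},
                          {"lat": 30, "lon": 30}))
# {'sample': 30, 'lat': 30, 'lon': 30}
```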
- __init__(ds: xarray.core.dataset.Dataset | xarray.core.dataarray.DataArray, input_dims: dict[collections.abc.Hashable, int], input_overlap: Optional[dict[collections.abc.Hashable, int]] = None, batch_dims: Optional[dict[collections.abc.Hashable, int]] = None, concat_input_dims: bool = False, preload_batch: bool = True, cache: Optional[dict[str, Any]] = None, cache_preprocess: Optional[Callable] = None)
Methods
- __init__(ds, input_dims[, input_overlap, ...])
Attributes
- batch_dims
- concat_input_dims
- input_dims
- input_overlap
- preload_batch