API reference

This page provides an auto-generated summary of Xbatcher’s API.

Dataset.batch and DataArray.batch

Dataset.batch.generator(*args, **kwargs)

Return a BatchGenerator via the batch accessor

DataArray.batch.generator(*args, **kwargs)

Return a BatchGenerator via the batch accessor

Core

class xbatcher.BatchGenerator(ds: xarray.core.dataset.Dataset, input_dims: Dict[Hashable, int], input_overlap: Dict[Hashable, int] = {}, batch_dims: Dict[Hashable, int] = {}, concat_input_dims: bool = False, preload_batch: bool = True)

Create a generator for iterating through xarray DataArrays / Datasets in batches.

Parameters
ds : xarray.Dataset or xarray.DataArray

The data to iterate over

input_dims : dict

A dictionary specifying the size of the inputs in each dimension, e.g. {'lat': 30, 'lon': 30}. These are the dimensions the ML library will see. All other dimensions will be stacked into one dimension called sample.

input_overlap : dict, optional

A dictionary specifying the overlap along each dimension, e.g. {'lat': 3, 'lon': 3}.

batch_dims : dict, optional

A dictionary specifying the size of the batch along each dimension, e.g. {'time': 10}. These will always be iterated over.

concat_input_dims : bool, optional

If True, the dimension chunks specified in input_dims will be concatenated and stacked into the sample dimension. The batch index will be included as a new level input_batch in the sample coordinate. If False, the dimension chunks specified in input_dims will be iterated over.

preload_batch : bool, optional

If True, each batch will be loaded into memory before reshaping / processing, triggering any dask arrays to be computed.

Yields
ds_slice : xarray.Dataset or xarray.DataArray

Slices of the array matching the given batch size specification.
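To make the interaction between input_dims and input_overlap concrete, the sketch below computes window slices along a single axis using only the standard library. It is an illustration of the arithmetic, not xbatcher's code: the function name window_slices is hypothetical, and it assumes that consecutive windows share exactly the overlap, so the start advances by the window size minus the overlap, and that a trailing partial window is dropped.

```python
from typing import Iterator


def window_slices(size: int, window: int, overlap: int = 0) -> Iterator[slice]:
    """Yield slices of length `window` along an axis of length `size`.

    Assumptions (illustrative, not taken from the library source):
    consecutive windows share `overlap` elements, and a trailing
    window that would run past the end of the axis is dropped.
    """
    if not 0 <= overlap < window:
        raise ValueError("overlap must satisfy 0 <= overlap < window")
    stride = window - overlap  # how far each window start advances
    for start in range(0, size - window + 1, stride):
        yield slice(start, start + window)


# An axis of length 100, windows of 10, neighbouring windows sharing 2 points.
slices = list(window_slices(size=100, window=10, overlap=2))

# The same axis without overlap, for comparison.
plain = list(window_slices(size=100, window=10))
```

Under these assumptions, a larger overlap produces more windows from the same axis, which is worth keeping in mind when estimating the number of batches per epoch.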

__init__(ds: xarray.core.dataset.Dataset, input_dims: Dict[Hashable, int], input_overlap: Dict[Hashable, int] = {}, batch_dims: Dict[Hashable, int] = {}, concat_input_dims: bool = False, preload_batch: bool = True)

Dataloaders

class xbatcher.loaders.torch.MapDataset(*args: Any, **kwargs: Any)

__init__(X_generator, y_generator, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None) -> None

PyTorch Dataset adapter for Xbatcher

Parameters
X_generator : xbatcher.BatchGenerator

y_generator : xbatcher.BatchGenerator

transform : callable, optional

A function/transform that takes in an array and returns a transformed version.

target_transform : callable, optional

A function/transform that takes in the target and transforms it.
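The sketch below illustrates the general shape of a map-style adapter that pairs input and target batches and applies the two optional transforms. It is a standard-library stand-in, not the MapDataset implementation: the class name PairedBatches is hypothetical, and plain lists stand in for the two BatchGenerator objects. The only fact it relies on is that a map-style PyTorch dataset exposes __len__ and __getitem__.

```python
from typing import Any, Callable, Optional, Sequence, Tuple


class PairedBatches:
    """Hypothetical map-style adapter; NOT xbatcher's MapDataset."""

    def __init__(
        self,
        X_batches: Sequence[Any],
        y_batches: Sequence[Any],
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
    ) -> None:
        if len(X_batches) != len(y_batches):
            raise ValueError("inputs and targets must have the same length")
        self.X_batches = X_batches
        self.y_batches = y_batches
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self) -> int:
        return len(self.X_batches)

    def __getitem__(self, idx: int) -> Tuple[Any, Any]:
        X, y = self.X_batches[idx], self.y_batches[idx]
        # transform touches only the input, target_transform only the target.
        if self.transform is not None:
            X = self.transform(X)
        if self.target_transform is not None:
            y = self.target_transform(y)
        return X, y


# Lists stand in for batch generators; the lambda doubles each input value.
pairs = PairedBatches(
    [[1, 2], [3, 4]],
    [0, 1],
    transform=lambda x: [v * 2 for v in x],
)
```

Keeping the two transforms separate means input preprocessing, such as normalisation, can change without affecting how targets are encoded.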

class xbatcher.loaders.torch.IterableDataset(*args: Any, **kwargs: Any)

__init__(X_generator, y_generator) -> None

PyTorch Dataset adapter for Xbatcher

Parameters
X_generator : xbatcher.BatchGenerator

y_generator : xbatcher.BatchGenerator

class xbatcher.loaders.keras.CustomTFDataset(*args: Any, **kwargs: Any)

__init__(X_generator, y_generator, *, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None) -> None

Keras Dataset adapter for Xbatcher

Parameters
X_generator : xbatcher.BatchGenerator

y_generator : xbatcher.BatchGenerator

transform : callable, optional

A function/transform that takes in an array and returns a transformed version.

target_transform : callable, optional

A function/transform that takes in the target and transforms it.