xbatcher: Batch Generation from Xarray Datasets

Xbatcher is a small library for iterating over Xarray DataArrays and Datasets in batches. The goal is to make it easy to feed Xarray objects to machine learning libraries such as Keras.

Installation

Xbatcher can be installed from PyPI as:

python -m pip install xbatcher

Or via Conda as:

conda install -c conda-forge xbatcher

Or from source as:

python -m pip install git+https://github.com/xarray-contrib/xbatcher.git

Optional Dependencies

Note

The required dependencies installed with Xbatcher are Xarray, Dask, and NumPy. To use the TensorFlow or PyTorch data loaders or Xarray accessors, you will need to install the corresponding library separately.

To install Xbatcher and PyTorch via Conda:

conda install -c conda-forge xbatcher pytorch

Or via PyPI:

python -m pip install xbatcher[torch]

To install Xbatcher and TensorFlow via Conda:

conda install -c conda-forge xbatcher tensorflow

Or via PyPI:

python -m pip install xbatcher[tensorflow]
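
After installing, it can be useful to confirm which optional backends are importable in the current environment. The following is a minimal sketch using only the Python standard library; the helper name `available_backends` and the `OPTIONAL_BACKENDS` mapping are hypothetical and not part of Xbatcher.

```python
from importlib.util import find_spec

# Import names of the optional backends mentioned above.
# This mapping is illustrative; it is not provided by Xbatcher.
OPTIONAL_BACKENDS = {"PyTorch": "torch", "TensorFlow": "tensorflow"}


def available_backends():
    """Return a mapping of backend name -> True if its module can be found."""
    return {
        name: find_spec(module) is not None
        for name, module in OPTIONAL_BACKENDS.items()
    }


# Report each backend without actually importing it.
for name, found in available_backends().items():
    print(f"{name}: {'installed' if found else 'not installed'}")
```

`find_spec` only locates the module without importing it, so the check stays fast even when a large framework is installed.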

Basic Usage

Let’s say we have an Xarray DataArray

In [1]: import xarray as xr

In [2]: import numpy as np

In [3]: da = xr.DataArray(np.random.rand(1000, 100, 100), name='foo',
   ...:                   dims=['time', 'y', 'x']).chunk({'time': 1})
   ...: 

In [4]: da
Out[4]: 
<xarray.DataArray 'foo' (time: 1000, y: 100, x: 100)> Size: 80MB
dask.array<xarray-<this-array>, shape=(1000, 100, 100), dtype=float64, chunksize=(1, 100, 100), chunktype=numpy.ndarray>
Dimensions without coordinates: time, y, x

and we want to create batches of 10 steps along the time dimension. We can do so with a BatchGenerator:

In [5]: import xbatcher

In [6]: bgen = xbatcher.BatchGenerator(da, {'time': 10})

In [7]: for batch in bgen:
   ...:     pass
   ...: batch
   ...: 
Out[7]: 
<xarray.DataArray 'foo' (sample: 10000, time: 10)> Size: 800kB
array([[0.18331244, 0.95401645, 0.66911297, ..., 0.6002228 , 0.91207246,
        0.72293461],
       [0.77464729, 0.88832204, 0.24764826, ..., 0.21678624, 0.57430225,
        0.53489885],
       [0.67427496, 0.33217619, 0.29205809, ..., 0.73499709, 0.12988777,
        0.82830798],
       ...,
       [0.84169892, 0.6524452 , 0.89884177, ..., 0.28966143, 0.08839356,
        0.97866952],
       [0.23696842, 0.69764846, 0.59524178, ..., 0.46733703, 0.41171854,
        0.37998778],
       [0.54160734, 0.18213067, 0.78172645, ..., 0.70750948, 0.2426151 ,
        0.87082846]])
Coordinates:
  * sample   (sample) object 80kB MultiIndex
  * y        (sample) int64 80kB 0 0 0 0 0 0 0 0 0 ... 99 99 99 99 99 99 99 99
  * x        (sample) int64 80kB 0 1 2 3 4 5 6 7 8 ... 92 93 94 95 96 97 98 99
Dimensions without coordinates: time

or via a built-in Xarray accessor:

In [8]: import xbatcher

In [9]: for batch in da.batch.generator({'time': 10}):
   ...:     pass
   ...: batch
   ...: 
Out[9]: 
<xarray.DataArray 'foo' (sample: 10000, time: 10)> Size: 800kB
array([[0.18331244, 0.95401645, 0.66911297, ..., 0.6002228 , 0.91207246,
        0.72293461],
       [0.77464729, 0.88832204, 0.24764826, ..., 0.21678624, 0.57430225,
        0.53489885],
       [0.67427496, 0.33217619, 0.29205809, ..., 0.73499709, 0.12988777,
        0.82830798],
       ...,
       [0.84169892, 0.6524452 , 0.89884177, ..., 0.28966143, 0.08839356,
        0.97866952],
       [0.23696842, 0.69764846, 0.59524178, ..., 0.46733703, 0.41171854,
        0.37998778],
       [0.54160734, 0.18213067, 0.78172645, ..., 0.70750948, 0.2426151 ,
        0.87082846]])
Coordinates:
  * sample   (sample) object 80kB MultiIndex
  * y        (sample) int64 80kB 0 0 0 0 0 0 0 0 0 ... 99 99 99 99 99 99 99 99
  * x        (sample) int64 80kB 0 1 2 3 4 5 6 7 8 ... 92 93 94 95 96 97 98 99
Dimensions without coordinates: time
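
The batch shape in the outputs above, `(sample: 10000, time: 10)`, can look surprising at first. Judging from the printed coordinates, the dimensions that are not being batched (`y` and `x`) appear to be stacked into a single `sample` dimension of length 100 × 100 = 10000, with `y` varying slowest. The sketch below imitates that layout in plain NumPy on a smaller array. It is an illustration of the resulting shape under that assumption, not Xbatcher's implementation, and all names in it are hypothetical.

```python
import numpy as np

# Smaller than the 1000 x 100 x 100 array above so the sketch runs quickly.
n_time, n_y, n_x = 20, 4, 5
batch_len = 10  # plays the role of {'time': 10}

data = np.arange(n_time * n_y * n_x, dtype=float).reshape(n_time, n_y, n_x)

batches = []
# n_time divides evenly by batch_len here, so every window is complete.
for start in range(0, n_time, batch_len):
    window = data[start:start + batch_len]            # dims: (time, y, x)
    # Flatten (y, x) into one axis, then put it first: (sample, time).
    stacked = window.reshape(batch_len, n_y * n_x).T
    batches.append(stacked)

print(len(batches), batches[0].shape)

# Row s of a batch corresponds to the pixel (y, x) with s = y * n_x + x.
y, x, t = 2, 3, 7
assert batches[0][y * n_x + x, t] == data[t, y, x]
```

With the sizes from the example above (1000 time steps, batches of 10, a 100 × 100 grid), the same arithmetic gives 100 batches, each of shape (10000, 10).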