xbatcher: Batch Generation from Xarray Datasets

Xbatcher is a small library for iterating over xarray DataArrays and Datasets in batches. Its goal is to make it easy to feed xarray data to machine learning libraries such as Keras.

Installation

Xbatcher can be installed from PyPI with pip:

pip install xbatcher

Or from conda-forge with conda:

conda install -c conda-forge xbatcher

Or directly from the source repository on GitHub:

pip install git+https://github.com/xarray-contrib/xbatcher.git
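
To confirm that the install succeeded, one option is to query the installed version from the package metadata. The following is a minimal sketch using only the Python standard library; the helper name `installed_version` is hypothetical and is not part of xbatcher.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_version(package_name):
    """Return the installed version string of a distribution, or None if absent."""
    try:
        return version(package_name)
    except PackageNotFoundError:
        # The distribution is not installed in the current environment.
        return None


found = installed_version("xbatcher")
print("xbatcher version:", found if found is not None else "not installed")
```

If this prints "not installed", the active Python environment is probably not the one that pip or conda installed into.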

Basic Usage

Suppose we have an xarray DataArray:

In [1]: import xarray as xr

In [2]: import numpy as np

In [3]: da = xr.DataArray(np.random.rand(1000, 100, 100), name='foo',
   ...:                   dims=['time', 'y', 'x']).chunk({'time': 1})
   ...: 

In [4]: da
Out[4]: 
<xarray.DataArray 'foo' (time: 1000, y: 100, x: 100)>
dask.array<xarray-<this-array>, shape=(1000, 100, 100), dtype=float64, chunksize=(1, 100, 100), chunktype=numpy.ndarray>
Dimensions without coordinates: time, y, x

and we want to create batches of 10 steps along the time dimension. We can do this with a BatchGenerator:

In [5]: import xbatcher

In [6]: bgen = xbatcher.BatchGenerator(da, {'time': 10})

In [7]: for batch in bgen:
   ...:     pass
   ...: batch
   ...: 
Out[7]: 
<xarray.DataArray 'foo' (sample: 10000, time: 10)>
array([[0.80642756, 0.26770488, 0.99932225, ..., 0.72890861, 0.89804776,
        0.94005062],
       [0.42923058, 0.67952858, 0.82109481, ..., 0.21543373, 0.40831671,
        0.82564444],
       [0.67910516, 0.27079572, 0.73764915, ..., 0.09239936, 0.44451992,
        0.19537579],
       ...,
       [0.2468853 , 0.78345663, 0.19955691, ..., 0.08913012, 0.01200241,
        0.20069452],
       [0.0839067 , 0.20304608, 0.92391009, ..., 0.62428469, 0.87942952,
        0.81704658],
       [0.06223387, 0.55298096, 0.5590063 , ..., 0.51424031, 0.84413156,
        0.09685334]])
Coordinates:
  * sample   (sample) object MultiIndex
  * y        (sample) int64 0 0 0 0 0 0 0 0 0 0 ... 99 99 99 99 99 99 99 99 99
  * x        (sample) int64 0 1 2 3 4 5 6 7 8 9 ... 91 92 93 94 95 96 97 98 99
Dimensions without coordinates: time

The same batches can be generated through the built-in xarray accessor:

In [8]: import xbatcher

In [9]: for batch in da.batch.generator({'time': 10}):
   ...:     pass
   ...: batch
   ...: 
Out[9]: 
<xarray.DataArray 'foo' (sample: 10000, time: 10)>
array([[0.80642756, 0.26770488, 0.99932225, ..., 0.72890861, 0.89804776,
        0.94005062],
       [0.42923058, 0.67952858, 0.82109481, ..., 0.21543373, 0.40831671,
        0.82564444],
       [0.67910516, 0.27079572, 0.73764915, ..., 0.09239936, 0.44451992,
        0.19537579],
       ...,
       [0.2468853 , 0.78345663, 0.19955691, ..., 0.08913012, 0.01200241,
        0.20069452],
       [0.0839067 , 0.20304608, 0.92391009, ..., 0.62428469, 0.87942952,
        0.81704658],
       [0.06223387, 0.55298096, 0.5590063 , ..., 0.51424031, 0.84413156,
        0.09685334]])
Coordinates:
  * sample   (sample) object MultiIndex
  * y        (sample) int64 0 0 0 0 0 0 0 0 0 0 ... 99 99 99 99 99 99 99 99 99
  * x        (sample) int64 0 1 2 3 4 5 6 7 8 9 ... 91 92 93 94 95 96 97 98 99
Dimensions without coordinates: time
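
The batch shape in the outputs above, (sample: 10000, time: 10), can be reproduced by hand. The following NumPy-only sketch illustrates that shape arithmetic and is not xbatcher's implementation: it assumes non-overlapping windows of 10 time steps, and it assumes the remaining y and x dimensions are flattened into a single sample axis with y varying slowest, as the coordinate listing above suggests. The function name `emulate_batches` is hypothetical.

```python
import numpy as np


def emulate_batches(array, window):
    """Yield (sample, time) blocks from a (time, y, x) array.

    Illustrative only: mimics the batch shape shown above, assuming
    non-overlapping windows and dropping any incomplete trailing window.
    """
    n_time = array.shape[0]
    for start in range(0, n_time - window + 1, window):
        block = array[start:start + window]   # shape (window, y, x)
        block = np.moveaxis(block, 0, -1)     # shape (y, x, window)
        yield block.reshape(-1, window)       # shape (y * x, window)


data = np.random.rand(1000, 100, 100)

n_batches = 0
for batch in emulate_batches(data, window=10):
    n_batches += 1

print(n_batches, batch.shape)  # 100 (10000, 10)
```

Each emulated batch holds 100 × 100 = 10000 samples of 10 time steps, and the 1000 time steps yield 1000 / 10 = 100 such batches.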