xbatcher: Batch Generation from Xarray Datasets

Xbatcher is a small library for iterating over xarray DataArrays and Datasets in batches. The goal is to make it easy to feed xarray data to machine learning libraries such as Keras.

Installation

Xbatcher can be installed from PyPI as:

pip install xbatcher

Or via Conda as:

conda install -c conda-forge xbatcher

Or from source as:

pip install git+https://github.com/pangeo-data/xbatcher.git

Basic Usage

Let’s say we have an xarray DataArray:

In [1]: import xarray as xr

In [2]: import numpy as np

In [3]: da = xr.DataArray(np.random.rand(1000, 100, 100), name='foo',
   ...:                   dims=['time', 'y', 'x']).chunk({'time': 1})
   ...: 

In [4]: da
Out[4]: 
<xarray.DataArray 'foo' (time: 1000, y: 100, x: 100)>
dask.array<xarray-<this-array>, shape=(1000, 100, 100), dtype=float64, chunksize=(1, 100, 100), chunktype=numpy.ndarray>
Dimensions without coordinates: time, y, x

and we want to create batches of 10 steps along the time dimension. We can do it like this:

In [5]: import xbatcher

In [6]: bgen = xbatcher.BatchGenerator(da, {'time': 10})

In [7]: for batch in bgen:
   ...:     pass
   ...: batch
   ...: 
Out[7]: 
<xarray.Dataset>
Dimensions:  (time: 10, sample: 10000)
Coordinates:
  * sample   (sample) object MultiIndex
  * y        (sample) int64 0 0 0 0 0 0 0 0 0 0 ... 99 99 99 99 99 99 99 99 99
  * x        (sample) int64 0 1 2 3 4 5 6 7 8 9 ... 91 92 93 94 95 96 97 98 99
Dimensions without coordinates: time
Data variables:
    foo      (sample, time) float64 0.07604 0.1066 0.7487 ... 0.928 0.538 0.1897

or via the batch accessor that xbatcher registers on xarray objects when it is imported:

In [8]: import xbatcher

In [9]: for batch in da.batch.generator({'time': 10}):
   ...:     pass
   ...: batch
   ...: 
Out[9]: 
<xarray.Dataset>
Dimensions:  (time: 10, sample: 10000)
Coordinates:
  * sample   (sample) object MultiIndex
  * y        (sample) int64 0 0 0 0 0 0 0 0 0 0 ... 99 99 99 99 99 99 99 99 99
  * x        (sample) int64 0 1 2 3 4 5 6 7 8 9 ... 91 92 93 94 95 96 97 98 99
Dimensions without coordinates: time
Data variables:
    foo      (sample, time) float64 0.07604 0.1066 0.7487 ... 0.928 0.538 0.1897
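
To make the shapes in the output above easier to read, here is a minimal pure-Python sketch of the arithmetic behind them. It is not xbatcher's implementation: the helper name batch_slices is hypothetical, and dropping an incomplete trailing window is an assumption of this sketch. It only illustrates how a time axis of length 1000 is cut into windows of 10, and how the remaining y and x dimensions account for the sample dimension of size 10000.

```python
from math import prod


def batch_slices(dim_size, batch_size):
    """Yield consecutive, non-overlapping slices of length batch_size.

    Hypothetical helper, not part of xbatcher. Assumption of this sketch:
    a trailing window shorter than batch_size is dropped, not padded.
    """
    for start in range(0, dim_size - batch_size + 1, batch_size):
        yield slice(start, start + batch_size)


# Dimension sizes of the example DataArray above.
sizes = {"time": 1000, "y": 100, "x": 100}
# Same batching request as in the example: 10 time steps per batch.
input_dims = {"time": 10}

# One slice per batch along the time dimension.
time_slices = list(batch_slices(sizes["time"], input_dims["time"]))

# Dimensions not named in input_dims are flattened into a single
# "sample" axis, so its length is the product of their sizes.
sample_size = prod(n for dim, n in sizes.items() if dim not in input_dims)

print(len(time_slices))                  # number of batches
print(time_slices[0], time_slices[-1])   # first and last time window
print(sample_size)                       # length of the sample axis
```

With these sizes the sketch yields 100 windows, from slice(0, 10) to slice(990, 1000), and a sample length of 10000, which agrees with the (time: 10, sample: 10000) dimensions shown for each batch.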