Demo#
Author: Cindy Chiao. Last modified: Nov 16, 2021
What is xbatcher?#
Xbatcher is a small library for iterating through xarray objects (DataArrays and Datasets) in batches. The goal is to make it easy to feed xarray objects to machine learning libraries such as Keras and PyTorch.
What is included in this notebook?#
showcase current capabilities with example data
brief discussion of the current development track and ideas for future work
[1]:
import xarray as xr
import xbatcher
import fsspec
Example data#
Here we will load an example dataset from a global climate model. The data come from the CMIP6 historical experiment and represent 60 days of daily maximum air temperature.
[2]:
store = fsspec.get_mapper(
    "az://carbonplan-share/example_cmip6_data.zarr", account_name="carbonplan"
)
ds = xr.open_zarr(store, consolidated=True)

# the attributes contain a lot of useful information, but clutter the printout
# when we inspect the outputs; clear them throughout this demo to avoid confusion
ds.attrs = {}

# inspect the dataset
display(ds)
<xarray.Dataset>
Dimensions:  (lat: 145, lon: 192, time: 60)
Coordinates:
  * lat      (lat) float64 -90.0 -88.75 -87.5 -86.25 ... 86.25 87.5 88.75 90.0
  * lon      (lon) float64 0.0 1.875 3.75 5.625 7.5 ... 352.5 354.4 356.2 358.1
  * time     (time) datetime64[ns] 1850-01-01T12:00:00 ... 1850-03-01T12:00:00
Data variables:
    tasmax   (time, lat, lon) float32 dask.array<chunksize=(60, 145, 192), meta=np.ndarray>
[3]:
# plot the first time step
ds.isel(time=0).tasmax.plot()
[3]:
<matplotlib.collections.QuadMesh at 0x7f95fc348550>
Batch generation#
Xbatcher’s BatchGenerator can be used to generate batches, with several arguments controlling the exact behavior.

The input_dims argument takes a dictionary specifying the size of the inputs in each dimension. For example, {'time': 10} means that each input sample will have 10 time points, while all other dimensions are flattened to a “sample” dimension.

Note that even though ds in this case only has one variable, the generator can operate on multiple variables at the same time.
[4]:
n_timepoint_in_each_sample = 10

bgen = xbatcher.BatchGenerator(
    ds=ds,
    input_dims={"time": n_timepoint_in_each_sample},
)

n_batch = 0
for batch in bgen:
    n_batch += 1

print(f"{n_batch} batches")
display(batch)
6 batches
<xarray.Dataset>
Dimensions:  (sample: 27840, time: 10)
Coordinates:
  * time     (time) datetime64[ns] 1850-02-20T12:00:00 ... 1850-03-01T12:00:00
  * sample   (sample) MultiIndex
  - lat      (sample) float64 -90.0 -90.0 -90.0 -90.0 ... 90.0 90.0 90.0 90.0
  - lon      (sample) float64 0.0 1.875 3.75 5.625 ... 352.5 354.4 356.2 358.1
Data variables:
    tasmax   (sample, time) float32 226.1 226.2 224.0 ... 251.5 245.5 242.9
We can verify that the outputs have the expected shapes. For example, there are 60 time points in our input dataset and we are asking for 10 time points per sample, so we expect 6 batches.
[5]:
expected_n_batch = len(ds.time) / n_timepoint_in_each_sample
print(f"Expecting {expected_n_batch} batches, getting {n_batch} batches")
Expecting 6.0 batches, getting 6 batches
There are 145 lat points and 192 lon points, thus we’re expecting 145 * 192 = 27840 samples in a batch.
[6]:
expected_batch_size = len(ds.lat) * len(ds.lon)
print(
    f"Expecting {expected_batch_size} samples per batch, getting {len(batch.sample)} samples per batch"
)
Expecting 27840 samples per batch, getting 27840 samples per batch
Controlling the size/shape of batches#
We can use the batch_dims and concat_input_dims options to control how many samples end up in each batch. For example, we can specify 10 time points for each sample but 20 time points in each batch. This should yield half as many batches and twice as many samples per batch as the example above. Note the difference in dimension names in this case (time becomes time_input).
[7]:
n_timepoint_in_each_sample = 10
n_timepoint_in_each_batch = 20

bgen = xbatcher.BatchGenerator(
    ds=ds,
    input_dims={"time": n_timepoint_in_each_sample},
    batch_dims={"time": n_timepoint_in_each_batch},
    concat_input_dims=True,
)

n_batch = 0
for batch in bgen:
    n_batch += 1

print(f"{n_batch} batches")
display(batch)
3 batches
<xarray.Dataset>
Dimensions:      (sample: 55680, time_input: 10)
Coordinates:
    time         (sample, time_input) datetime64[ns] 1850-02-10T12:00:00 ... ...
  * sample       (sample) MultiIndex
  - input_batch  (sample) int64 0 0 0 0 0 0 0 0 0 0 0 ... 1 1 1 1 1 1 1 1 1 1 1
  - lat          (sample) float64 -90.0 -90.0 -90.0 -90.0 ... 90.0 90.0 90.0
  - lon          (sample) float64 0.0 1.875 3.75 5.625 ... 354.4 356.2 358.1
Dimensions without coordinates: time_input
Data variables:
    tasmax       (sample, time_input) float32 238.8 235.2 234.7 ... 245.5 242.9
Last batch behavior#
If the length of the input ds along a dimension is not evenly divisible by the specified input_dims, the remainder will be discarded rather than yielding a fractional batch. See https://github.com/xarray-contrib/xbatcher/issues/5 for more discussion on this topic.
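How much data gets discarded follows from simple modular arithmetic. A plain-Python sketch (not part of the xbatcher API), using the 60 time points of this dataset and the 31-point window from the cell below:

```python
# with 60 time points and windows of 31 time points,
# only one full batch fits and the remainder is dropped
n_time = 60
n_timepoint_in_batch = 31

n_full_batches = n_time // n_timepoint_in_batch  # floor division
n_discarded = n_time % n_timepoint_in_batch      # remainder, silently dropped

print(n_full_batches, n_discarded)  # → 1 29
```

This is why the last time point of the single batch below is Jan 31 even though the dataset runs through Mar 1.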
[8]:
n_timepoint_in_batch = 31

bgen = xbatcher.BatchGenerator(ds=ds, input_dims={"time": n_timepoint_in_batch})

for batch in bgen:
    print(f"last time point in ds is {ds.time[-1].values}")
    print(f"last time point in batch is {batch.time[-1].values}")
    display(batch)
last time point in ds is 1850-03-01T12:00:00.000000000
last time point in batch is 1850-01-31T12:00:00.000000000
<xarray.Dataset>
Dimensions:  (sample: 27840, time: 31)
Coordinates:
  * time     (time) datetime64[ns] 1850-01-01T12:00:00 ... 1850-01-31T12:00:00
  * sample   (sample) MultiIndex
  - lat      (sample) float64 -90.0 -90.0 -90.0 -90.0 ... 90.0 90.0 90.0 90.0
  - lon      (sample) float64 0.0 1.875 3.75 5.625 ... 352.5 354.4 356.2 358.1
Data variables:
    tasmax   (sample, time) float32 252.6 250.9 250.4 ... 257.6 256.9 243.3
Overlapping inputs#
In the examples above, all samples have distinct time points. That is, for any lat/lon pixel, sample 1 has time points 1-10, sample 2 has time points 11-20, and they do not overlap. However, in many machine learning applications, we want overlapping samples (e.g. sample 1 has time points 1-10, sample 2 has time points 2-11, and so on). We can use the input_overlap argument to get this behavior.
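The number of overlapping samples that fit in a batch follows the usual sliding-window count. A plain-Python sketch (illustrative only, not part of the xbatcher API), using the sizes from the cell below:

```python
def n_sliding_windows(batch_len, sample_len, overlap):
    """Number of overlapping windows of length sample_len that fit in batch_len."""
    stride = sample_len - overlap
    return (batch_len - sample_len) // stride + 1

# 20 time points per batch, 10 per sample, overlap of 9 -> stride of 1
print(n_sliding_windows(20, 10, 9))  # → 11
```

With no overlap this reduces to simple division, e.g. n_sliding_windows(20, 10, 0) == 2, matching the earlier examples.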
[9]:
n_timepoint_in_each_sample = 10
n_timepoint_in_each_batch = 20
input_overlap = 9

bgen = xbatcher.BatchGenerator(
    ds=ds,
    input_dims={"time": n_timepoint_in_each_sample},
    batch_dims={"time": n_timepoint_in_each_batch},
    concat_input_dims=True,
    input_overlap={"time": input_overlap},
)

n_batch = 0
for batch in bgen:
    n_batch += 1

print(f"{n_batch} batches")
batch
3 batches
[9]:
<xarray.Dataset>
Dimensions:      (sample: 306240, time_input: 10)
Coordinates:
    time         (sample, time_input) datetime64[ns] 1850-02-10T12:00:00 ... ...
  * sample       (sample) MultiIndex
  - input_batch  (sample) int64 0 0 0 0 0 0 0 0 0 ... 10 10 10 10 10 10 10 10 10
  - lat          (sample) float64 -90.0 -90.0 -90.0 -90.0 ... 90.0 90.0 90.0
  - lon          (sample) float64 0.0 1.875 3.75 5.625 ... 354.4 356.2 358.1
Dimensions without coordinates: time_input
Data variables:
    tasmax       (sample, time_input) float32 238.8 235.2 234.7 ... 245.5 242.9
We can inspect the samples in a batch for a single lat/lon pixel, noting that the overlap only applies within a batch, not across batches. Thus, within the 20 time points of a batch, we get 11 samples, each with 10 time points, 9 of which overlap with the previous sample.
[10]:
lat = -90
lon = 0
pixel = batch.sel(lat=lat, lon=lon)
display(pixel)
print(
    f"sample 1 goes from {pixel.isel(input_batch=0).time[0].values} to {pixel.isel(input_batch=0).time[-1].values}"
)
print(
    f"sample 2 goes from {pixel.isel(input_batch=1).time[0].values} to {pixel.isel(input_batch=1).time[-1].values}"
)
<xarray.Dataset>
Dimensions:      (input_batch: 11, time_input: 10)
Coordinates:
    time         (input_batch, time_input) datetime64[ns] 1850-02-10T12:00:00...
  * input_batch  (input_batch) int64 0 1 2 3 4 5 6 7 8 9 10
Dimensions without coordinates: time_input
Data variables:
    tasmax       (input_batch, time_input) float32 238.8 235.2 ... 226.3 227.0
sample 1 goes from 1850-02-10T12:00:00.000000000 to 1850-02-19T12:00:00.000000000
sample 2 goes from 1850-02-11T12:00:00.000000000 to 1850-02-20T12:00:00.000000000
Example applications#
These batches can then be used to train a downstream machine learning model while preserving the indices of each sample.

As an example, let’s say we want to train a simple CNN model to predict the daily maximum air temperature at each lat/lon pixel. To predict the temperature at lat/lon/time (i, j, t), we’ll use as features the temperature of a 9 x 9 grid centered at (i, j), from times t-9 to t-1 (so the shape of the input should be (n_samples_in_each_batch, 9, 9, 9)). Note that in this example, we subset the dataset to a smaller domain for efficiency.
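The 600 samples per batch reported below can be derived from the same sliding-window arithmetic applied to each dimension. A plain-Python sketch (illustrative only), using the batch_dims (18, 18, 15), input_dims (9, 9, 10), and input_overlap (8, 8, 9) from the cell below:

```python
# window count along one dimension:
# (batch_len - input_len) // stride + 1, where stride = input_len - overlap
n_lat = (18 - 9) // (9 - 8) + 1     # 10 lat positions
n_lon = (18 - 9) // (9 - 8) + 1     # 10 lon positions
n_time = (15 - 10) // (10 - 9) + 1  # 6 time positions

print(n_lat * n_lon * n_time)  # → 600
```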
[11]:
bgen = xbatcher.BatchGenerator(
    ds=ds[["tasmax"]].isel(lat=slice(0, 18), lon=slice(0, 18), time=slice(0, 30)),
    input_dims={"lat": 9, "lon": 9, "time": 10},
    batch_dims={"lat": 18, "lon": 18, "time": 15},
    concat_input_dims=True,
    input_overlap={"lat": 8, "lon": 8, "time": 9},
)

for i, batch in enumerate(bgen):
    print(f"batch {i}")
    # make sure the ordering of dimensions is consistent
    batch = batch.transpose("input_batch", "lat_input", "lon_input", "time_input")
    # only use the first 9 time points as features, since the last time point is the label to be predicted
    features = batch.tasmax.isel(time_input=slice(0, 9))
    # select the center pixel at the last time point as the label to be predicted
    # the actual lat/lon/time of each sample can be accessed in labels.coords
    labels = batch.tasmax.isel(lat_input=5, lon_input=5, time_input=9)
    print("feature shape", features.shape)
    print("label shape", labels.shape)
    print("shape of lat of each sample", labels.coords["lat"].shape)
    print("")
batch 0
feature shape (600, 9, 9, 9)
label shape (600,)
shape of lat of each sample (600,)
batch 1
feature shape (600, 9, 9, 9)
label shape (600,)
shape of lat of each sample (600,)
We can also use Xarray’s stack method to transform these into 2D inputs of shape (n_samples, n_features), suitable for other machine learning algorithms implemented in libraries such as sklearn and xgboost. In this case, we expect 9 x 9 x 9 = 729 features in total.
[12]:
for i, batch in enumerate(bgen):
    print(f"batch {i}")
    # make sure the ordering of dimensions is consistent
    batch = batch.transpose("input_batch", "lat_input", "lon_input", "time_input")
    # only use the first 9 time points as features, since the last time point is the label to be predicted
    features = batch.tasmax.isel(time_input=slice(0, 9))
    features = features.stack(features=["lat_input", "lon_input", "time_input"])
    # select the center pixel at the last time point as the label to be predicted
    # the actual lat/lon/time of each sample can be accessed in labels.coords
    labels = batch.tasmax.isel(lat_input=5, lon_input=5, time_input=9)
    print("feature shape", features.shape)
    print("label shape", labels.shape)
    print("shape of lat of each sample", labels.coords["lat"].shape, "\n")
batch 0
feature shape (600, 729)
label shape (600,)
shape of lat of each sample (600,)
batch 1
feature shape (600, 729)
label shape (600,)
shape of lat of each sample (600,)
What’s next?#
There are many additional useful features that have yet to be implemented for batch generation in the context of downstream machine learning model training. One current effort is to improve the set of data loaders.
Additional features of interest can include:
Handling overlaps across batches. Batching in machine learning training commonly involves generating all samples first, then grouping them into batches. When overlap is enabled, this yields different results from first generating batches and then creating the possible samples within each batch.
Shuffling/randomization of samples across batches. It is often desirable for the samples in each batch to be grouped randomly instead of along a specific dimension.
Improving memory efficiency. When overlap is enabled, each sample mostly comprises values repeated from adjacent samples. It would be beneficial to generate each batch/sample lazily to avoid storing these duplicative values.
Handling preprocessing steps. For example, data augmentation, scaling/normalization, outlier detection, etc.
More thoughts on 1. and 2. can be found in this issue. Interested users are welcome to comment or submit other issues on GitHub.