Xbatcher Caching Feature
This notebook demonstrates the new caching feature added to xbatcher’s BatchGenerator. This feature allows you to cache batches, potentially improving performance for repeated access to the same batches.
Introduction
The caching feature in xbatcher’s BatchGenerator stores generated batches in a cache, which can speed up later access to the same batches. This is most useful when you iterate over the same dataset multiple times, for example across several training epochs.
The cache is pluggable: any dict-like object that zarr can write to can hold the cached batches. This flexibility allows for various storage backends, including local storage, distributed storage systems, or cloud storage solutions.
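"Dict-like" here means an object implementing Python's collections.abc.MutableMapping interface. The sketch below is a hypothetical InMemoryStore class (not part of xbatcher or zarr) that implements the five required methods and counts the writes it receives. It assumes string keys and bytes values, as zarr v2 stores use. The keys are only illustrative.

```python
from collections.abc import Iterator, MutableMapping


class InMemoryStore(MutableMapping):
    """Hypothetical dict-like store with string keys and bytes values."""

    def __init__(self) -> None:
        self._data: dict = {}
        self.writes = 0  # number of __setitem__ calls, to show what gets written

    def __getitem__(self, key: str) -> bytes:
        return self._data[key]

    def __setitem__(self, key: str, value: bytes) -> None:
        self.writes += 1
        self._data[key] = bytes(value)

    def __delitem__(self, key: str) -> None:
        del self._data[key]

    def __iter__(self) -> Iterator[str]:
        return iter(self._data)

    def __len__(self) -> int:
        return len(self._data)


store = InMemoryStore()
# illustrative keys in the style of zarr v2 group metadata and chunk names
store['0/.zgroup'] = b'{"zarr_format": 2}'
store['0/air/0.0.0'] = b'\x00' * 8
print(len(store), store.writes, '0/.zgroup' in store)  # → 2 2 True
```

MutableMapping derives __contains__, keys(), get(), pop() and update() from those five methods, so such an object can stand in wherever a dict is expected. Whether a particular store works as an xbatcher cache also depends on zarr accepting it.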
Installation
To use the caching feature, you’ll need to have xbatcher installed, along with zarr for serialization. If you haven’t already, you can install these using pip:
python -m pip install xbatcher zarr
or using conda:
conda install -c conda-forge xbatcher zarr
Basic Usage
Let’s start with a basic example of how to use the caching feature:
[1]:
import tempfile
import xarray as xr
import zarr
import xbatcher
[2]:
# create a cache using Zarr's DirectoryStore
directory = f'{tempfile.mkdtemp()}/xbatcher-cache'
print(directory)
cache = zarr.storage.DirectoryStore(directory)
/tmp/tmphgywdjbm/xbatcher-cache
In this example, we’re using a local directory to store the cache, but you could use any zarr-compatible store, such as one backed by S3 or Redis.
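Directories created with tempfile.mkdtemp() persist until you delete them, and this notebook creates a new one for each cache. If a cache only needs to last for one session, the standard library's tempfile.TemporaryDirectory removes the directory automatically. In the sketch below, the zarr and xbatcher lines are left as comments so the block runs with the standard library alone. Don't use this pattern if you want the cache to survive between sessions.

```python
import os
import tempfile

with tempfile.TemporaryDirectory() as tmp:
    directory = os.path.join(tmp, 'xbatcher-cache')
    # cache = zarr.storage.DirectoryStore(directory)  # zarr 2.x API, as used above
    # gen = xbatcher.BatchGenerator(ds, input_dims={'lat': 10, 'lon': 10}, cache=cache)
    # ... iterate over gen while the temporary directory still exists ...
    print(os.path.isdir(tmp))  # → True

print(os.path.exists(tmp))  # → False: removed when the with-block exits
```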
[3]:
# load a sample dataset
ds = xr.tutorial.open_dataset('air_temperature', chunks={})
ds
[3]:
<xarray.Dataset> Size: 31MB
Dimensions: (lat: 25, time: 2920, lon: 53)
Coordinates:
* lat (lat) float32 100B 75.0 72.5 70.0 67.5 65.0 ... 22.5 20.0 17.5 15.0
* lon (lon) float32 212B 200.0 202.5 205.0 207.5 ... 325.0 327.5 330.0
* time (time) datetime64[ns] 23kB 2013-01-01 ... 2014-12-31T18:00:00
Data variables:
air (time, lat, lon) float64 31MB dask.array<chunksize=(2920, 25, 53), meta=np.ndarray>
Attributes:
Conventions: COARDS
title: 4x daily NMC reanalysis (1948)
description: Data is from NMC initialized reanalysis\n(4x/day). These a...
platform: Model
references: http://www.esrl.noaa.gov/psd/data/gridded/data.ncep.reanaly...
[4]:
# create a BatchGenerator with caching enabled
gen = xbatcher.BatchGenerator(ds, input_dims={'lat': 10, 'lon': 10}, cache=cache)
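Conceptually, a cached generator behaves like a read-through cache keyed by batch index:
- On a miss, it builds the batch and writes it to the store.
- On a hit, it returns the stored copy instead.

The stdlib-only sketch below models this logic with a hypothetical CachedBatches class and a plain dict as the store. It illustrates the idea only and is not xbatcher's implementation, which serializes each batch with zarr.

```python
from collections.abc import Callable, Iterator, MutableMapping


class CachedBatches:
    """Hypothetical read-through cache keyed by batch index (not xbatcher code)."""

    def __init__(self, make_batch: Callable[[int], list], n_batches: int,
                 cache: MutableMapping) -> None:
        self.make_batch = make_batch  # the expensive step, e.g. slicing and loading data
        self.n_batches = n_batches
        self.cache = cache
        self.hits = 0
        self.misses = 0

    def __len__(self) -> int:
        return self.n_batches

    def __getitem__(self, idx: int) -> list:
        key = str(idx)  # one cache entry per batch index
        if key in self.cache:
            self.hits += 1
            return self.cache[key]
        self.misses += 1
        batch = self.make_batch(idx)
        self.cache[key] = batch  # populate the cache on a miss
        return batch

    def __iter__(self) -> Iterator[list]:
        for idx in range(self.n_batches):
            yield self[idx]


model = CachedBatches(lambda i: [i * 10 + j for j in range(3)], n_batches=4, cache={})
first = list(model)   # first pass: every batch is a miss and is written
second = list(model)  # second pass: every batch is a hit
print(model.misses, model.hits, first == second)  # → 4 4 True
```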
Performance Comparison
Let’s compare the performance with and without caching:
[5]:
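import time
def time_iteration(gen):
    """Return the elapsed seconds taken to iterate over every batch in gen."""
    start = time.perf_counter()  # monotonic, high-resolution timer
    for batch in gen:
        pass
    return time.perf_counter() - start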
[6]:
# fresh, empty cache for the cached runs in the next cell
directory = f'{tempfile.mkdtemp()}/xbatcher-cache'
cache = zarr.storage.DirectoryStore(directory)
# Without cache
gen_no_cache = xbatcher.BatchGenerator(ds, input_dims={'lat': 10, 'lon': 10})
time_no_cache = time_iteration(gen_no_cache)
print(f'Time without cache: {time_no_cache:.2f} seconds')
Time without cache: 0.14 seconds
[7]:
# With cache
gen_with_cache = xbatcher.BatchGenerator(
    ds, input_dims={'lat': 10, 'lon': 10}, cache=cache
)
# first run: each batch is generated and written to the cache
time_first_run = time_iteration(gen_with_cache)
print(f'Time with cache (first run): {time_first_run:.2f} seconds')
# second run: each batch is read back from the cache
time_second_run = time_iteration(gen_with_cache)
print(f'Time with cache (second run): {time_second_run:.2f} seconds')
Time with cache (first run): 0.18 seconds
/home/docs/checkouts/readthedocs.org/user_builds/xbatcher/conda/latest/lib/python3.10/site-packages/xarray/core/dataset.py:2292: SerializationWarning: saving variable None with floating point data as an integer dtype without any _FillValue to use for NaNs
return to_zarr( # type: ignore[call-overload,misc]
Time with cache (second run): 0.07 seconds
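In this run, the second cached pass took 0.07 seconds. The first cached pass, which also writes every batch to the store, took 0.18 seconds, and the run without a cache took 0.14 seconds. Exact timings depend on hardware, storage backend, and dataset size. The pattern to expect is a slower first pass followed by faster repeat passes.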
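The first cached pass pays for generating and writing every batch, so caching only pays off once the batches are read more than once. The small helper below finds the number of passes at which the cached total time drops below the uncached total. It assumes every later cached pass costs the same. It uses the timings printed above, which come from this single run and will differ on your machine.

```python
from typing import Optional


def break_even_passes(t_none: float, t_first: float, t_hit: float,
                      max_passes: int = 1000) -> Optional[int]:
    """Smallest number of passes for which the cached total is lower, or None."""
    for n in range(1, max_passes + 1):
        cached = t_first + (n - 1) * t_hit  # one populating pass, then cache hits
        uncached = n * t_none               # every pass regenerates all batches
        if cached < uncached:
            return n
    return None  # cached passes are not faster than regenerating


# timings printed above for this notebook run, in seconds
t_none, t_first, t_hit = 0.14, 0.18, 0.07
for n in range(1, 4):
    print(f'passes={n}  cached={t_first + (n - 1) * t_hit:.2f} s  uncached={n * t_none:.2f} s')
print(break_even_passes(t_none, t_first, t_hit))  # → 2
```

In the notebook itself, you could pass time_no_cache, time_first_run and time_second_run instead of the hard-coded values.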
Advanced Usage
Custom Cache Preprocessing
You can also pass a preprocessing function through the cache_preprocess argument. It is applied to each batch before the batch is cached:
[8]:
# create a cache using Zarr's DirectoryStore
directory = f'{tempfile.mkdtemp()}/xbatcher-cache'
cache = zarr.storage.DirectoryStore(directory)
def preprocess_batch(batch):
    # example: add a new variable to each batch
    batch['new_var'] = batch['air'] * 2
    return batch
gen_with_preprocess = xbatcher.BatchGenerator(
    ds,
    input_dims={'lat': 10, 'lon': 10},
    cache=cache,
    cache_preprocess=preprocess_batch,
)
# Now, each cached batch will include the 'new_var' variable
for batch in gen_with_preprocess:
    print(batch)
    break
<xarray.Dataset> Size: 5MB
Dimensions: (lat: 10, time: 2920, lon: 10)
Coordinates:
* lat (lat) float32 40B 75.0 72.5 70.0 67.5 65.0 62.5 60.0 57.5 55.0 52.5
* lon (lon) float32 40B 200.0 202.5 205.0 207.5 ... 217.5 220.0 222.5
* time (time) datetime64[ns] 23kB 2013-01-01 ... 2014-12-31T18:00:00
Data variables:
air (time, lat, lon) float64 2MB 241.2 242.5 243.5 ... 281.8 282.3
new_var (time, lat, lon) float64 2MB 482.4 485.0 487.0 ... 563.6 564.6
Attributes:
Conventions: COARDS
title: 4x daily NMC reanalysis (1948)
description: Data is from NMC initialized reanalysis\n(4x/day). These a...
platform: Model
references: http://www.esrl.noaa.gov/psd/data/gridded/data.ncep.reanaly...
/home/docs/checkouts/readthedocs.org/user_builds/xbatcher/conda/latest/lib/python3.10/site-packages/xarray/core/dataset.py:2292: SerializationWarning: saving variable None with floating point data as an integer dtype without any _FillValue to use for NaNs
return to_zarr( # type: ignore[call-overload,misc]
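Because the preprocessing function runs before a batch is written, the stored batches already contain its output. Later passes can read them without preprocessing again. The stdlib sketch below models this ordering with two pieces:
- a hypothetical get_batch function, with a dict standing in for the store;
- a counting add_new_var function in place of preprocess_batch.

It illustrates the ordering only and is not xbatcher's code.

```python
from collections.abc import Callable, MutableMapping

calls = {'preprocess': 0}


def add_new_var(batch: dict) -> dict:
    """Stand-in for preprocess_batch: derive a new field from 'air'."""
    calls['preprocess'] += 1
    batch['new_var'] = [2 * x for x in batch['air']]
    return batch


def get_batch(idx: int, cache: MutableMapping,
              preprocess: Callable[[dict], dict]) -> dict:
    key = str(idx)
    if key in cache:
        return cache[key]  # hit: the stored batch is already preprocessed
    batch = {'air': [float(idx), float(idx) + 0.5]}  # hypothetical raw batch
    batch = preprocess(batch)  # applied before caching
    cache[key] = batch
    return batch


preprocessed_store: dict = {}
for _ in range(3):  # three passes over two batches
    batches = [get_batch(i, preprocessed_store, add_new_var) for i in range(2)]

print(calls['preprocess'], batches[1]['new_var'])  # → 2 [2.0, 3.0]
```

In this model, editing add_new_var would not change batches already in preprocessed_store. The same reasoning suggests clearing the cache after changing cache_preprocess.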
Using Different Storage Backends
While we’ve been using a local directory for caching, you can use any dict-like object that zarr can write to. For example, you could use an S3 bucket as the cache storage backend:
import s3fs
import xbatcher
# Set up S3 filesystem (you'll need appropriate credentials)
s3 = s3fs.S3FileSystem(anon=False)
# 'my-bucket' is a placeholder; use a bucket you can write to
cache = s3.get_mapper('s3://my-bucket/my-cache.zarr')
# Use this cache with BatchGenerator (ds is the dataset loaded earlier)
gen_s3 = xbatcher.BatchGenerator(ds, input_dims={'lat': 10, 'lon': 10}, cache=cache)
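get_mapper returns a dict-like view rooted at a path inside the bucket, so several caches can share one bucket without their keys colliding. The same idea can be sketched for any mapping with a hypothetical PrefixedStore wrapper, which is not an fsspec or zarr class. It prepends a fixed prefix to every key. A plain dict stands in for the shared bucket.

```python
from collections.abc import Iterator, MutableMapping


class PrefixedStore(MutableMapping):
    """Hypothetical wrapper exposing the keys under `prefix` of a larger mapping."""

    def __init__(self, backing: MutableMapping, prefix: str) -> None:
        self.backing = backing
        self.prefix = prefix.rstrip('/') + '/'

    def __getitem__(self, key: str) -> bytes:
        return self.backing[self.prefix + key]

    def __setitem__(self, key: str, value: bytes) -> None:
        self.backing[self.prefix + key] = value

    def __delitem__(self, key: str) -> None:
        del self.backing[self.prefix + key]

    def __iter__(self) -> Iterator[str]:
        # yield keys relative to the prefix, skipping everything else in the backing map
        for key in list(self.backing):
            if key.startswith(self.prefix):
                yield key[len(self.prefix):]

    def __len__(self) -> int:
        return sum(1 for _ in self)


bucket: dict = {}  # stands in for a shared bucket
cache_a = PrefixedStore(bucket, 'caches/run-a.zarr')
cache_b = PrefixedStore(bucket, 'caches/run-b.zarr')
cache_a['0/.zgroup'] = b'{}'
cache_b['0/.zgroup'] = b'{}'
print(sorted(bucket))  # → ['caches/run-a.zarr/0/.zgroup', 'caches/run-b.zarr/0/.zgroup']
print(list(cache_a), len(cache_b))  # → ['0/.zgroup'] 1
```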
Considerations and Best Practices
Storage Space: Be mindful of the storage space required for your cache, especially when working with large datasets.
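Cache Invalidation: The current implementation doesn’t handle cache invalidation. If your source data changes, you’ll need to manually clear or update the cache. The same applies if you reuse an existing cache after changing the generator settings or the preprocessing function.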
Performance Tradeoffs: While caching can significantly speed up repeated access to the same data, the initial caching process may be slower than processing without a cache. Consider your use case to determine if caching is beneficial.
Storage Backend: Choose a storage backend that’s appropriate for your use case. Local storage might be fastest for single-machine applications, while distributed or cloud storage might be necessary for cluster computing or cloud-based workflows.
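Since the cache is not invalidated automatically, one workaround is to derive the cache location from everything that determines the batches. Then changing the data version, the batch dimensions, or the preprocessing yields a new, empty cache instead of silently reusing stale batches. The helper below is a hypothetical pattern, not an xbatcher feature. The data_version and preprocess_name strings are assumptions you would maintain yourself, for example a dataset release tag and a label you bump whenever the preprocessing changes.

```python
import hashlib
import json


def cache_key(data_version: str, input_dims: dict, preprocess_name: str = '') -> str:
    """Hypothetical fingerprint of the settings that determine cached batches."""
    settings = {
        'data_version': data_version,
        'input_dims': input_dims,
        'preprocess': preprocess_name,
    }
    # sort_keys makes the JSON (and so the hash) independent of dict key order
    payload = json.dumps(settings, sort_keys=True).encode()
    return hashlib.sha256(payload).hexdigest()[:16]


k1 = cache_key('2014-12-31', {'lat': 10, 'lon': 10})
k2 = cache_key('2014-12-31', {'lon': 10, 'lat': 10})
k3 = cache_key('2014-12-31', {'lat': 5, 'lon': 5})
print(k1 == k2, k1 == k3)  # → True False
# e.g. directory = f'{base_dir}/xbatcher-cache-{k1}'  (base_dir is hypothetical)
```

To clear a local DirectoryStore cache by hand, deleting its directory (for example with shutil.rmtree(directory)) removes every cached batch.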