# Source code for xbatcher.generators

"""Classes for iterating through Xarray DataArrays / Datasets in batches."""

import itertools
import warnings
from operator import itemgetter
from typing import (
    Any,
    Callable,
    Dict,
    Hashable,
    Iterator,
    List,
    Optional,
    Sequence,
    Union,
)

import numpy as np
import xarray as xr

# Mapping from a dimension name to the slice selecting one window along it;
# this is the form passed to ``.isel`` on Xarray objects.
_DimSlices = Dict[Hashable, slice]

# Lazily produced stream of per-patch selectors.
PatchGenerator = Iterator[_DimSlices]
# The patch selectors that together make up one batch.
BatchSelector = List[_DimSlices]
# Batch number -> the patch selectors belonging to that batch.
BatchSelectorSet = Dict[int, BatchSelector]

class BatchSchema:
    """
    A representation of the indices and stacking/transposing parameters needed
    to generate batches from Xarray DataArrays and Datasets using
    xbatcher.BatchGenerator.

    Parameters
    ----------
    ds : ``xarray.Dataset`` or ``xarray.DataArray``
        The data to iterate over. Unlike for the BatchGenerator, the data is
        not retained as a class attribute for the BatchSchema.
    input_dims : dict
        A dictionary specifying the size of the inputs in each dimension,
        e.g. ``{'lat': 30, 'lon': 30}``
        These are the dimensions the ML library will see. All other dimensions
        will be stacked into one dimension called ``sample``.
    input_overlap : dict, optional
        A dictionary specifying the overlap along each dimension
        e.g. ``{'lat': 3, 'lon': 3}``
    batch_dims : dict, optional
        A dictionary specifying the size of the batch along each dimension
        e.g. ``{'time': 10}``. These will always be iterated over.
    concat_input_bins : bool, optional
        Stored as the ``concat_input_dims`` attribute. If ``True``, the
        dimension chunks specified in ``input_dims`` will be concatenated and
        stacked into the ``sample`` dimension. The batch index will be
        included as a new level ``input_batch`` in the ``sample`` coordinate.
        If ``False``, the dimension chunks specified in ``input_dims`` will be
        iterated over.
    preload_batch : bool, optional
        If ``True``, each batch will be loaded into memory before reshaping /
        processing, triggering any dask arrays to be computed.

    Notes
    -----
    The BatchSchema is experimental and subject to change without notice.
    """

    def __init__(
        self,
        ds: Union[xr.Dataset, xr.DataArray],
        input_dims: Dict[Hashable, int],
        input_overlap: Optional[Dict[Hashable, int]] = None,
        batch_dims: Optional[Dict[Hashable, int]] = None,
        concat_input_bins: bool = True,
        preload_batch: bool = True,
    ):
        # Copy every mapping so that later mutation (e.g. through
        # BatchGenerator's properties) can never leak into a caller-owned dict
        # or a shared default argument.
        self.input_dims = dict(input_dims)
        self.input_overlap = dict(input_overlap) if input_overlap is not None else {}
        self.batch_dims = dict(batch_dims) if batch_dims is not None else {}
        self.concat_input_dims = concat_input_bins
        self.preload_batch = preload_batch
        # Store helpful information based on arguments
        # Batch dims that are also input dims.
        self._duplicate_batch_dims: Dict[Hashable, int] = {
            dim: length
            for dim, length in self.batch_dims.items()
            if self.input_dims.get(dim) is not None
        }
        # Batch dims that are not input dims.
        self._unique_batch_dims: Dict[Hashable, int] = {
            dim: length
            for dim, length in self.batch_dims.items()
            if self.input_dims.get(dim) is None
        }
        # Step between consecutive input windows along each input dim.
        self._input_stride: Dict[Hashable, int] = {
            dim: length - self.input_overlap.get(dim, 0)
            for dim, length in self.input_dims.items()
        }
        # Dict-literal merge (rather than ``dict(**a, **b)``) so that
        # non-string hashable dimension names are accepted.
        self._all_sliced_dims: Dict[Hashable, int] = {
            **self._unique_batch_dims,
            **self.input_dims,
        }
        self.selectors: BatchSelectorSet = self._gen_batch_selectors(ds)

    def _gen_batch_selectors(
        self, ds: Union[xr.DataArray, xr.Dataset]
    ) -> BatchSelectorSet:
        """
        Create batch selectors dict, which can be used to create a batch
        from an Xarray data object.
        """
        # Create an iterator that returns an object usable for .isel in xarray
        patch_selectors = self._gen_patch_selectors(ds)
        # Create the Dict containing batch selectors
        if self.concat_input_dims:
            # Combine the patches into batches
            return self._combine_patches_into_batch(ds, patch_selectors)
        # Each patch gets its own batch
        return {ind: [value] for ind, value in enumerate(patch_selectors)}

    def _gen_patch_selectors(
        self, ds: Union[xr.DataArray, xr.Dataset]
    ) -> PatchGenerator:
        """
        Create an iterator that can be used to index an Xarray
        Dataset/DataArray.
        """
        if self._duplicate_batch_dims and not self.concat_input_dims:
            warnings.warn(
                "The following dimensions were included in both ``input_dims`` "
                "and ``batch_dims``. Since ``concat_input_dims`` is ``False``, "
                f"these dimensions will not impact batch generation: {self._duplicate_batch_dims}"
            )
        # Generate the slices by iterating over batch_dims and input_dims
        return _iterate_through_dimensions(
            ds,
            dims=self._all_sliced_dims,
            overlap=self.input_overlap,
        )

    def _combine_patches_into_batch(
        self, ds: Union[xr.DataArray, xr.Dataset], patch_selectors: PatchGenerator
    ) -> BatchSelectorSet:
        """
        Combine the patch selectors to form a batch.

        Raises
        ------
        AssertionError
            If called while ``concat_input_dims`` is ``False``.
        """
        # Check that patches are only combined with concat_input_dims
        if not self.concat_input_dims:
            raise AssertionError(
                "Patches should only be combined into batches when ``concat_input_dims`` is ``True``"
            )
        if not self.batch_dims:
            return self._combine_patches_into_one_batch(patch_selectors)
        if self._duplicate_batch_dims:
            return self._combine_patches_grouped_by_input_and_batch_dims(
                ds=ds, patch_selectors=patch_selectors
            )
        return self._combine_patches_grouped_by_batch_dims(patch_selectors)

    def _combine_patches_into_one_batch(
        self, patch_selectors: PatchGenerator
    ) -> BatchSelectorSet:
        """
        Group all patches into a single batch (batch number 0).
        """
        return {0: list(patch_selectors)}

    def _combine_patches_grouped_by_batch_dims(
        self, patch_selectors: PatchGenerator
    ) -> BatchSelectorSet:
        """
        Group patches based on the unique slices for dimensions in
        ``batch_dims``.

        ``itertools.groupby`` only merges *consecutive* patches with equal
        batch-dim slices; this relies on the patch iterator emitting patches
        of the same batch contiguously.
        """
        batch_selectors = [
            list(value)
            for _, value in itertools.groupby(
                patch_selectors, key=itemgetter(*self.batch_dims)
            )
        ]
        return dict(enumerate(batch_selectors))

    def _combine_patches_grouped_by_input_and_batch_dims(
        self, ds: Union[xr.DataArray, xr.Dataset], patch_selectors: PatchGenerator
    ) -> BatchSelectorSet:
        """
        Combine patches with multiple slices along ``batch_dims`` grouped into
        each patch. Required when a dimension is duplicated between
        ``batch_dims`` and ``input_dims``.
        """
        self._gen_patch_numbers(ds)
        self._gen_batch_numbers(ds)
        batch_id_per_patch = self._get_batch_multi_index_per_patch()
        patch_in_range = self._get_batch_in_range_per_batch(
            batch_multi_index=batch_id_per_patch
        )
        batch_id_per_patch = self._ravel_batch_multi_index(batch_id_per_patch)
        batch_selectors = self._gen_empty_batch_selectors()
        for i, patch in enumerate(patch_selectors):
            # Patches that fall outside every complete batch are dropped.
            if patch_in_range[i]:
                batch_selectors[batch_id_per_patch[i]].append(patch)
        return batch_selectors

    def _gen_empty_batch_selectors(self) -> BatchSelectorSet:
        """
        Create an empty batch selector set that can be populated by appending
        patches to each batch.

        The number of batches is the product of the per-dimension batch counts
        computed by ``_gen_batch_numbers``.
        """
        # ``np.prod`` replaces the deprecated ``np.product``; ``int`` makes the
        # result usable by ``range`` even when the product is a NumPy scalar
        # or a float (empty input).
        n_batches = int(np.prod(list(self._n_batches_per_dim.values())))
        return {k: [] for k in range(n_batches)}

    def _gen_patch_numbers(self, ds: Union[xr.DataArray, xr.Dataset]):
        """
        Calculate the number of patches per dimension and the number of
        patches in each batch per dimension.
        """
        self._n_patches_per_batch: Dict[Hashable, int] = {
            dim: int(np.ceil(length / self._input_stride.get(dim, length)))
            for dim, length in self.batch_dims.items()
        }
        self._n_patches_per_dim: Dict[Hashable, int] = {
            dim: int(
                (ds.sizes[dim] - self.input_overlap.get(dim, 0))
                // (length - self.input_overlap.get(dim, 0))
            )
            for dim, length in self._all_sliced_dims.items()
        }

    def _gen_batch_numbers(self, ds: Union[xr.DataArray, xr.Dataset]):
        """
        Calculate the number of batches per dimension.

        Dimensions not listed in ``batch_dims`` count as a single batch.
        """
        self._n_batches_per_dim: Dict[Hashable, int] = {
            dim: int(ds.sizes[dim] // self.batch_dims.get(dim, ds.sizes[dim]))
            for dim in self._all_sliced_dims.keys()
        }

    def _get_batch_multi_index_per_patch(self):
        """
        Calculate the batch multi-index for each patch.

        Returns an integer array with one row per sliced dimension and one
        column per patch (in ``itertools.product`` order).
        """
        batch_id_per_dim: Dict[Hashable, Any] = {
            # For dims without a batch size, the divisor exceeds every patch
            # number so all patches land in batch 0 along that dim.
            dim: np.floor(
                np.arange(0, n_patches)
                / self._n_patches_per_batch.get(dim, n_patches + 1)
            ).astype(np.int64)
            for dim, n_patches in self._n_patches_per_dim.items()
        }
        batch_id_per_patch = np.array(
            list(itertools.product(*batch_id_per_dim.values()))
        ).transpose()
        return batch_id_per_patch

    def _ravel_batch_multi_index(self, batch_multi_index):
        """
        Convert the batch multi-index to a flat index for each patch.

        ``mode="clip"`` keeps out-of-range indices from raising; such patches
        are filtered out separately by ``_get_batch_in_range_per_batch``.
        """
        return np.ravel_multi_index(
            multi_index=batch_multi_index,
            dims=tuple(self._n_batches_per_dim.values()),
            mode="clip",
        )

    def _get_batch_in_range_per_batch(self, batch_multi_index):
        """
        Determine whether each patch is contained within any of the batches.

        Returns a boolean array with one entry per patch.
        """
        batch_id_maximum = np.fromiter(self._n_batches_per_dim.values(), dtype=int)
        # Pad so the maximum vector has one entry per sliced dimension; both
        # dicts are keyed on ``_all_sliced_dims`` here, so the pad width is
        # presumably zero in practice.
        batch_id_maximum = np.pad(
            batch_id_maximum,
            (0, (len(self._n_patches_per_dim) - len(self._n_batches_per_dim))),
            constant_values=(1),
        )
        batch_id_maximum = batch_id_maximum[:, np.newaxis]
        batch_in_range_per_patch = np.all(batch_multi_index < batch_id_maximum, axis=0)
        return batch_in_range_per_patch
def _gen_slices(*, dim_size: int, slice_size: int, overlap: int = 0) -> List[slice]:
    """
    Return the windows that chop a single dimension into equal-size slices.

    Consecutive windows start ``slice_size - overlap`` elements apart. Only
    windows lying entirely inside ``[0, dim_size)`` are returned, so a
    trailing partial window is dropped.

    Raises
    ------
    ValueError
        If ``overlap`` is not smaller than ``slice_size`` (the stride would be
        zero or negative).
    """
    if overlap >= slice_size:
        raise ValueError(
            "input overlap must be less than the input sample length, but "
            f"the input sample length is {slice_size} and the overlap is {overlap}"
        )
    stride = slice_size - overlap
    return [
        slice(start, start + slice_size)
        for start in range(0, dim_size, stride)
        if start + slice_size <= dim_size
    ]


def _iterate_through_dimensions(
    ds: Union[xr.Dataset, xr.DataArray],
    *,
    dims: Dict[Hashable, int],
    overlap: Optional[Dict[Hashable, int]] = None,
) -> Iterator[Dict[Hashable, slice]]:
    """
    Yield one ``.isel``-compatible selector per combination of windows.

    Windows are generated independently per dimension with ``_gen_slices`` and
    combined with ``itertools.product``, so the first dimension in ``dims``
    varies slowest. This is a generator: validation errors are raised on the
    first ``next()``, not at call time.

    Parameters
    ----------
    ds
        Object exposing ``.sizes`` (dimension name -> length).
    dims
        Window length for each dimension to slice.
    overlap
        Optional overlap per dimension; missing dimensions use 0.

    Raises
    ------
    ValueError
        If a window is longer than its dimension, or an overlap is not smaller
        than its window.
    """
    # ``None`` instead of a shared mutable ``{}`` default.
    if overlap is None:
        overlap = {}
    dim_slices = []
    for dim, slice_size in dims.items():
        dim_size = ds.sizes[dim]
        if slice_size > dim_size:
            raise ValueError(
                "input sample length must be less than or equal to the "
                f"dimension length, but the sample length of {slice_size} "
                f"is greater than the dimension length of {dim_size} "
                f"for {dim}"
            )
        dim_slices.append(
            _gen_slices(
                dim_size=dim_size,
                slice_size=slice_size,
                overlap=overlap.get(dim, 0),
            )
        )
    for slices in itertools.product(*dim_slices):
        yield dict(zip(dims, slices))


def _drop_input_dims(
    ds: Union[xr.Dataset, xr.DataArray],
    input_dims: Dict[Hashable, int],
    suffix: str = "_input",
) -> Union[xr.Dataset, xr.DataArray]:
    """
    Rename each input dimension to ``f"{dim}{suffix}"`` and demote its
    coordinate to a non-index coordinate.

    Removing the index coordinate lets patches taken from different locations
    be concatenated without xarray aligning (and NaN-filling) them on their
    differing coordinate values. The original coordinate values and attrs are
    kept as a coordinate named ``dim`` along the renamed dimension.
    """
    out = ds.copy()
    for dim in input_dims.keys():
        newdim = f"{dim}{suffix}"
        out = out.rename({dim: newdim})
        # Extra steps needed if there is a coordinate. Check ``out.coords``
        # explicitly: for a DataArray, ``key in out`` tests the array's values,
        # not its coordinates.
        if newdim in out.coords:
            out = out.drop_vars(newdim)
            out.coords[dim] = newdim, ds[dim].data, ds[dim].attrs
    return out


def _maybe_stack_batch_dims(
    ds: Union[xr.Dataset, xr.DataArray],
    input_dims: Sequence[Hashable],
) -> Union[xr.Dataset, xr.DataArray]:
    """
    Stack every non-input dimension into a single ``sample`` dimension.

    Nothing is done when fewer than two non-input dimensions exist. Otherwise
    the result is transposed to ``("sample", *input_dims)``.
    """
    batch_dims = [d for d in ds.sizes if d not in input_dims]
    if len(batch_dims) < 2:
        return ds
    ds_stack = ds.stack(sample=batch_dims)
    # ensure correct order
    dim_order = ("sample",) + tuple(input_dims)
    return ds_stack.transpose(*dim_order)
class BatchGenerator:
    """Create generator for iterating through Xarray DataArrays / Datasets in
    batches.

    Parameters
    ----------
    ds : ``xarray.Dataset`` or ``xarray.DataArray``
        The data to iterate over
    input_dims : dict
        A dictionary specifying the size of the inputs in each dimension,
        e.g. ``{'lat': 30, 'lon': 30}``
        These are the dimensions the ML library will see. All other dimensions
        will be stacked into one dimension called ``sample``.
    input_overlap : dict, optional
        A dictionary specifying the overlap along each dimension
        e.g. ``{'lat': 3, 'lon': 3}``
    batch_dims : dict, optional
        A dictionary specifying the size of the batch along each dimension
        e.g. ``{'time': 10}``. These will always be iterated over.
    concat_input_dims : bool, optional
        If ``True``, the dimension chunks specified in ``input_dims`` will be
        concatenated and stacked into the ``sample`` dimension. The batch index
        will be included as a new level ``input_batch`` in the ``sample``
        coordinate.
        If ``False``, the dimension chunks specified in ``input_dims`` will be
        iterated over.
    preload_batch : bool, optional
        If ``True``, each batch will be loaded into memory before reshaping /
        processing, triggering any dask arrays to be computed.
    cache : dict, optional
        Dict-like object to cache batches in (e.g., Zarr DirectoryStore). Note:
        The caching API is experimental and subject to change.
    cache_preprocess: callable, optional
        A function to apply to batches prior to caching.
        Note: The caching API is experimental and subject to change.

    Yields
    ------
    ds_slice : ``xarray.Dataset`` or ``xarray.DataArray``
        Slices of the array matching the given batch size specification.
    """

    def __init__(
        self,
        ds: Union[xr.Dataset, xr.DataArray],
        input_dims: Dict[Hashable, int],
        input_overlap: Optional[Dict[Hashable, int]] = None,
        batch_dims: Optional[Dict[Hashable, int]] = None,
        concat_input_dims: bool = False,
        preload_batch: bool = True,
        cache: Optional[Dict[str, Any]] = None,
        cache_preprocess: Optional[Callable] = None,
    ):
        self.ds = ds
        self.cache = cache
        self.cache_preprocess = cache_preprocess
        # ``None`` replaces the former shared ``{}`` defaults: BatchSchema
        # substitutes a fresh empty dict for ``None``, so no mapping is shared
        # between generator instances.
        self._batch_selectors: BatchSchema = BatchSchema(
            ds,
            input_dims=input_dims,
            input_overlap=input_overlap,
            batch_dims=batch_dims,
            concat_input_bins=concat_input_dims,
            preload_batch=preload_batch,
        )

    @property
    def input_dims(self):
        return self._batch_selectors.input_dims

    @property
    def input_overlap(self):
        return self._batch_selectors.input_overlap

    @property
    def batch_dims(self):
        return self._batch_selectors.batch_dims

    @property
    def concat_input_dims(self):
        return self._batch_selectors.concat_input_dims

    @property
    def preload_batch(self):
        return self._batch_selectors.preload_batch

    def __iter__(self) -> Iterator[Union[xr.DataArray, xr.Dataset]]:
        for idx in self._batch_selectors.selectors:
            yield self[idx]

    def __len__(self) -> int:
        return len(self._batch_selectors.selectors)

    def __getitem__(self, idx: int) -> Union[xr.Dataset, xr.DataArray]:
        """
        Return batch number ``idx`` (negative indices count from the end).

        Raises
        ------
        NotImplementedError
            If ``idx`` is not a Python or NumPy integer.
        IndexError
            If no batch exists for ``idx``.
        """
        # NumPy integers (e.g. from a sampler) are accepted alongside ``int``.
        if not isinstance(idx, (int, np.integer)):
            raise NotImplementedError(
                f"{type(self).__name__}.__getitem__ currently requires a single integer key"
            )
        idx = int(idx)

        if idx < 0:
            idx = list(self._batch_selectors.selectors)[idx]

        if self.cache and self._batch_in_cache(idx):
            return self._get_cached_batch(idx)

        if idx not in self._batch_selectors.selectors:
            raise IndexError("list index out of range")

        if self.concat_input_dims:
            batch = self._concat_patches_batch(idx)
        else:
            batch = self._single_patch_batch(idx)

        if self.cache is not None and self.cache_preprocess is not None:
            batch = self.cache_preprocess(batch)
        if self.cache is not None:
            self._cache_batch(idx, batch)

        return batch

    def _concat_patches_batch(self, idx: int) -> Union[xr.Dataset, xr.DataArray]:
        """Build batch ``idx`` by concatenating its patches along ``input_batch``."""
        new_dim_suffix = "_input"
        patch_selectors = self._batch_selectors.selectors[idx]

        # Smallest window covering every patch of this batch along each batch
        # dimension.
        batch_selector = {
            dim: slice(
                min(sel[dim].start for sel in patch_selectors),
                max(sel[dim].stop for sel in patch_selectors),
            )
            for dim in self._batch_selectors.batch_dims.keys()
        }
        batch_ds = self.ds.isel(batch_selector)
        if self.preload_batch:
            batch_ds.load()

        all_dsets: List = []
        for selector in patch_selectors:
            # Slice patches out of ``batch_ds`` (re-based to its origin) so the
            # preloaded data is reused instead of being read again from
            # ``self.ds``.
            local_selector = {
                dim: (
                    slice(
                        sl.start - batch_selector[dim].start,
                        sl.stop - batch_selector[dim].start,
                    )
                    if dim in batch_selector
                    else sl
                )
                for dim, sl in selector.items()
            }
            patch_ds = batch_ds.isel(local_selector)
            all_dsets.append(
                _drop_input_dims(
                    patch_ds,
                    self.input_dims,
                    suffix=new_dim_suffix,
                )
            )
        dsc = xr.concat(all_dsets, dim="input_batch")
        new_input_dims = [str(dim) + new_dim_suffix for dim in self.input_dims]
        return _maybe_stack_batch_dims(dsc, new_input_dims)

    def _single_patch_batch(self, idx: int) -> Union[xr.Dataset, xr.DataArray]:
        """Build batch ``idx`` from its single patch selector."""
        batch_ds = self.ds.isel(self._batch_selectors.selectors[idx][0])
        if self.preload_batch:
            batch_ds.load()
        return _maybe_stack_batch_dims(
            batch_ds,
            list(self.input_dims),
        )

    def _batch_in_cache(self, idx: int) -> bool:
        return self.cache is not None and f"{idx}/.zgroup" in self.cache

    def _cache_batch(self, idx: int, batch: Union[xr.Dataset, xr.DataArray]) -> None:
        batch.to_zarr(self.cache, group=str(idx), mode="a")

    def _get_cached_batch(self, idx: int) -> xr.Dataset:
        ds = xr.open_zarr(self.cache, group=str(idx))
        if self.preload_batch:
            ds = ds.load()
        return ds