Source code for xbatcher.loaders.torch

from __future__ import annotations

from collections.abc import Callable
from types import ModuleType

import xarray as xr

from xbatcher import BatchGenerator

try:
    import torch
except ImportError as exc:
    raise ImportError(
        'The Xbatcher PyTorch Dataset API depends on PyTorch. Please '
        'install PyTorch to proceed.'
    ) from exc

try:
    import dask
except ImportError:
    dask: ModuleType | None = None  # type: ignore[no-redef]

# Type alias for "either xarray container": used in this module to annotate the
# input of ``to_tensor`` and of the ``transform`` / ``target_transform`` callables.
T_DataArrayOrSet = xr.DataArray | xr.Dataset

# Notes:
# This module includes two PyTorch datasets.
#  - The MapDataset provides an indexable interface
#  - The IterableDataset provides a simple iterable interface
# Both can be provided as arguments to the Torch DataLoader.
# Assumptions made:
#  - Each dataset takes pre-configured X/y xbatcher generators (may not always want two generators in a dataset)
# TODOs:
#  - need to test with additional dataset parameters (e.g. transforms)


def to_tensor(xr_obj: T_DataArrayOrSet) -> torch.Tensor:
    """Convert an xarray DataArray or Dataset into a ``torch.Tensor``.

    Parameters
    ----------
    xr_obj : xarray.DataArray or xarray.Dataset
        Object to convert. A Dataset is first turned into a DataArray with
        ``to_array()`` and its ``'variable'`` dimension is squeezed out; a
        DataArray then contributes its underlying ``.data`` array.

    Returns
    -------
    torch.Tensor
        The result of calling ``torch.tensor`` on the extracted data.
    """
    payload = xr_obj

    # Step 1: collapse a Dataset down to a single DataArray.
    if isinstance(payload, xr.Dataset):
        stacked = payload.to_array()
        payload = stacked.squeeze(dim='variable')

    # Step 2: unwrap the DataArray to the array object it holds.
    if isinstance(payload, xr.DataArray):
        payload = payload.data

    return torch.tensor(payload)


class MapDataset(torch.utils.data.Dataset):
    """Indexable (map-style) PyTorch dataset backed by xbatcher generators.

    ``dataset[i]`` looks up batch ``i`` in ``X_generator`` (and in
    ``y_generator`` when one was given), computes the batches through dask if
    dask is importable, and passes them through the configured transforms.
    """

    def __init__(
        self,
        X_generator: BatchGenerator,
        y_generator: BatchGenerator | None = None,
        transform: Callable[[T_DataArrayOrSet], torch.Tensor] = to_tensor,
        target_transform: Callable[[T_DataArrayOrSet], torch.Tensor] = to_tensor,
    ) -> None:
        """
        PyTorch Dataset adapter for Xbatcher

        Parameters
        ----------
        X_generator : xbatcher.BatchGenerator
        y_generator : xbatcher.BatchGenerator, optional
            When omitted, ``__getitem__`` returns only the X tensor.
        transform, target_transform : callable, optional
            A function/transform that takes in an Xarray object and returns a
            transformed version in the form of a torch.Tensor.
        """
        self.X_generator = X_generator
        self.y_generator = y_generator
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self) -> int:
        """Return the number of batches in ``X_generator``."""
        return len(self.X_generator)

    def __getitem__(self, idx) -> tuple[torch.Tensor, torch.Tensor] | torch.Tensor:
        """Return the transformed batch (or ``(X, y)`` pair) at position ``idx``.

        Parameters
        ----------
        idx : int or torch.Tensor
            Batch position. A tensor is accepted only when it holds exactly
            one element (including 0-d tensors such as ``torch.tensor(3)``).

        Returns
        -------
        torch.Tensor or tuple of torch.Tensor
            ``transform(X_batch)`` alone when no ``y_generator`` was given,
            otherwise ``(transform(X_batch), target_transform(y_batch))``.

        Raises
        ------
        NotImplementedError
            If ``idx`` is a tensor that does not hold exactly one element.
        """
        if torch.is_tensor(idx):
            # ``tolist()`` on a 0-d tensor returns a bare Python scalar, so a
            # ``len()`` check on it would raise TypeError. Count elements
            # instead and unwrap the single value with ``item()``.
            if idx.numel() != 1:
                raise NotImplementedError(
                    f'{type(self).__name__}.__getitem__ currently requires a single integer key'
                )
            idx = idx.item()

        # generate batch (or batches)
        X_batch = self.X_generator[idx]
        y_batch = self.y_generator[idx] if self.y_generator is not None else None

        # load batch (or batches) with dask if possible
        if dask is not None:
            X_batch, y_batch = dask.compute(X_batch, y_batch)

        # apply transformation(s)
        X_batch_tensor = self.transform(X_batch)
        assert isinstance(X_batch_tensor, torch.Tensor), self.transform
        if y_batch is None:
            return X_batch_tensor

        y_batch_tensor = self.target_transform(y_batch)
        assert isinstance(y_batch_tensor, torch.Tensor), self.target_transform
        return X_batch_tensor, y_batch_tensor
class IterableDataset(torch.utils.data.IterableDataset):
    """Iterable-style PyTorch dataset backed by xbatcher generators.

    Iterating the dataset walks the generator(s) in order and converts every
    batch to a tensor via the batch's ``.torch.to_tensor()`` accessor method.
    """

    def __init__(
        self,
        X_generator,
        y_generator=None,
    ) -> None:
        """
        PyTorch Dataset adapter for Xbatcher

        Parameters
        ----------
        X_generator : xbatcher.BatchGenerator
        y_generator : xbatcher.BatchGenerator, optional
            When omitted, iteration yields only the X tensors, mirroring
            ``MapDataset``.
        """
        super().__init__()
        self.X_generator = X_generator
        self.y_generator = y_generator

    def __iter__(self):
        """Yield one tensor, or one ``(X, y)`` tensor pair, per batch.

        Yields
        ------
        torch.Tensor or tuple of torch.Tensor
            ``xb.torch.to_tensor()`` alone when no ``y_generator`` was given,
            otherwise ``(xb.torch.to_tensor(), yb.torch.to_tensor())``.
        """
        # NOTE(review): relies on every batch exposing a ``.torch`` accessor;
        # presumably registered by xbatcher's accessor module - confirm it is
        # imported before iterating.
        if self.y_generator is None:
            for xb in self.X_generator:
                yield xb.torch.to_tensor()
            return

        # ``zip`` stops as soon as the shorter of the two generators ends.
        for xb, yb in zip(self.X_generator, self.y_generator):
            yield (xb.torch.to_tensor(), yb.torch.to_tensor())