"""Source code for xbatcher.loaders.torch: PyTorch Dataset adapters for Xbatcher."""

from typing import Any, Callable, Optional, Tuple

try:
    import torch
except ImportError as exc:
    raise ImportError(
        "The Xbatcher PyTorch Dataset API depends on PyTorch. Please "
        "install PyTorch to proceed."
    ) from exc

# Notes:
# This module includes two PyTorch datasets.
#  - The MapDataset provides an indexable interface
#  - The IterableDataset provides a simple iterable interface
# Both can be provided as the dataset argument to the Torch DataLoader
# Assumptions made:
#  - Each dataset takes pre-configured X/y xbatcher generators (may not always want two generators in a dataset)
# TODOs:
#  - need to test with additional dataset parameters (e.g. transforms)


class MapDataset(torch.utils.data.Dataset):
    """Indexable (map-style) PyTorch Dataset backed by a pair of batch generators.

    Item ``i`` is the pair ``(X_generator[i], y_generator[i])``, each converted
    to a tensor via its ``.torch.to_tensor()`` accessor and then optionally
    passed through ``transform`` / ``target_transform``.
    """

    def __init__(
        self,
        X_generator,
        y_generator,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
    ) -> None:
        """
        PyTorch Dataset adapter for Xbatcher

        Parameters
        ----------
        X_generator : xbatcher.BatchGenerator
            Indexable source of input batches. Must support ``len()`` and
            integer indexing; each item must expose ``.torch.to_tensor()``.
        y_generator : xbatcher.BatchGenerator
            Indexable source of target batches, indexed with the same key as
            ``X_generator``.
        transform : callable, optional
            A function/transform that takes in an array and returns a transformed version.
        target_transform : callable, optional
            A function/transform that takes in the target and transforms it.
        """
        self.X_generator = X_generator
        self.y_generator = y_generator
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self) -> int:
        """Return the number of batches, taken from ``X_generator``."""
        return len(self.X_generator)

    def __getitem__(self, idx) -> Tuple[Any, Any]:
        """Return the ``(X, y)`` tensor pair for a single batch key.

        Parameters
        ----------
        idx : int or torch.Tensor
            Batch index. A tensor key must hold exactly one element; both a
            0-d tensor (``torch.tensor(3)``) and a one-element 1-d tensor
            (``torch.tensor([3])``) are accepted.

        Raises
        ------
        NotImplementedError
            If a tensor key holds more (or fewer) than one element.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()
            # ``tolist()`` on a 0-d tensor returns a bare Python scalar, not a
            # list, so only unwrap (and length-check) genuine list results.
            if isinstance(idx, list):
                if len(idx) != 1:
                    raise NotImplementedError(
                        f"{type(self).__name__}.__getitem__ currently requires a single integer key"
                    )
                idx = idx[0]

        X_batch = self.X_generator[idx].torch.to_tensor()
        y_batch = self.y_generator[idx].torch.to_tensor()

        # Compare against None rather than relying on truthiness: a callable
        # object may define __len__/__bool__ and evaluate as falsy.
        if self.transform is not None:
            X_batch = self.transform(X_batch)
        if self.target_transform is not None:
            y_batch = self.target_transform(y_batch)

        return X_batch, y_batch
class IterableDataset(torch.utils.data.IterableDataset):
    """Iterable-style PyTorch Dataset that streams paired batches.

    Iterates ``X_generator`` and ``y_generator`` in lockstep and yields
    ``(X, y)`` tensor pairs, each produced via the batch's
    ``.torch.to_tensor()`` accessor and optionally passed through
    ``transform`` / ``target_transform``.

    Iteration stops at the end of the shorter generator (``zip`` semantics).

    NOTE(review): per the PyTorch docs, with ``DataLoader(num_workers > 0)``
    every worker receives a replica of an IterableDataset. No per-worker
    sharding is done here, so each worker would replay the full stream;
    confirm callers use a single worker or shard upstream.
    """

    def __init__(
        self,
        X_generator,
        y_generator,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
    ) -> None:
        """
        PyTorch Dataset adapter for Xbatcher

        Parameters
        ----------
        X_generator : xbatcher.BatchGenerator
            Iterable source of input batches; each item must expose
            ``.torch.to_tensor()``.
        y_generator : xbatcher.BatchGenerator
            Iterable source of target batches, consumed in lockstep with
            ``X_generator``.
        transform : callable, optional
            A function/transform that takes in an array and returns a transformed version.
        target_transform : callable, optional
            A function/transform that takes in the target and transforms it.
        """
        self.X_generator = X_generator
        self.y_generator = y_generator
        self.transform = transform
        self.target_transform = target_transform

    def __iter__(self):
        """Yield ``(X, y)`` tensor pairs, applying the optional transforms."""
        for xb, yb in zip(self.X_generator, self.y_generator):
            X_batch = xb.torch.to_tensor()
            y_batch = yb.torch.to_tensor()
            if self.transform is not None:
                X_batch = self.transform(X_batch)
            if self.target_transform is not None:
                y_batch = self.target_transform(y_batch)
            yield (X_batch, y_batch)