Source code for xbatcher.loaders.keras
from typing import Any, Callable, Optional, Tuple
try:
import tensorflow as tf
except ImportError as exc:
raise ImportError(
"The Xbatcher TensorFlow Dataset API depends on TensorFlow. Please "
"install TensorFlow to proceed."
) from exc
# Notes:
# This module includes one Keras dataset, which can be provided to model.fit().
# - The CustomTFDataset provides an indexable interface
# Assumptions made:
# - The dataset takes pre-configured X/y xbatcher generators (may not always want two generators in a dataset)
class CustomTFDataset(tf.keras.utils.Sequence):
    """
    Keras ``Sequence`` adapter that pairs two Xbatcher batch generators.

    Item ``idx`` of the dataset is the tuple ``(X_batch, y_batch)`` built from
    ``X_generator[idx]`` and ``y_generator[idx]``. The number of batches is
    taken from ``X_generator`` alone; ``y_generator`` is indexed with the same
    ``idx`` and is not length-checked here.
    """

    def __init__(
        self,
        X_generator,
        y_generator,
        *,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
    ) -> None:
        """
        Keras Dataset adapter for Xbatcher

        Parameters
        ----------
        X_generator : xbatcher.BatchGenerator
            Generator supplying the input batches.
        y_generator : xbatcher.BatchGenerator
            Generator supplying the target batches.
        transform : callable, optional
            A function/transform that takes in an array and returns a transformed version.
        target_transform : callable, optional
            A function/transform that takes in the target and transforms it.
        """
        # Initialise the Keras base class first. Newer Keras releases emit a
        # warning (and fall back to default worker settings) when a Sequence
        # subclass skips the base constructor.
        super().__init__()

        self.X_generator = X_generator
        self.y_generator = y_generator
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self) -> int:
        """Return the number of batches, as reported by ``X_generator``."""
        return len(self.X_generator)

    def __getitem__(self, idx: int) -> Tuple[Any, Any]:
        """
        Return the ``(X_batch, y_batch)`` pair at position ``idx``.

        The ``.data`` payload of each generator item is converted to a
        TensorFlow tensor, after which the optional transforms are applied.
        """
        X_batch = tf.convert_to_tensor(self.X_generator[idx].data)
        y_batch = tf.convert_to_tensor(self.y_generator[idx].data)

        # TODO: Should the transformations be applied before tensor conversion?
        # Compare against None explicitly so that a callable which happens to
        # evaluate as falsy (e.g. defines __len__ returning 0) is still applied.
        if self.transform is not None:
            X_batch = self.transform(X_batch)

        if self.target_transform is not None:
            y_batch = self.target_transform(y_batch)

        return X_batch, y_batch