{ "cells": [ { "cell_type": "markdown", "id": "b314e777-7ffb-4e62-b4c5-ce8a785c5181", "metadata": {}, "source": [ "# End-to-End Tutorial: Training a Neural Network with Keras and Xbatcher\n", "\n", "## Import Required Libraries" ] }, { "cell_type": "code", "execution_count": null, "id": "5d912ff0-d808-4704-8dea-b9e1b5a53bf1", "metadata": {}, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", "import tensorflow as tf\n", "import xarray as xr\n", "from keras import layers, models, optimizers\n", "\n", "import xbatcher as xb\n", "import xbatcher.loaders.keras" ] }, { "cell_type": "code", "execution_count": null, "id": "7fb892c1-50fd-48c8-8567-b150946b53c9", "metadata": {}, "outputs": [], "source": [ "# Open the dataset stored in Zarr format\n", "ds = xr.open_dataset(\n", "    's3://carbonplan-share/xbatcher/fashion-mnist-train.zarr',\n", "    engine='zarr',\n", "    chunks={},\n", "    backend_kwargs={'storage_options': {'anon': True}},\n", ")" ] }, { "cell_type": "markdown", "id": "c98134fe-581f-412a-93e3-6b07b7706078", "metadata": {}, "source": [ "## Define Batch Generators" ] }, { "cell_type": "code", "execution_count": null, "id": "c680ebd7-0310-4f40-91b5-e7cc1a59e853", "metadata": {}, "outputs": [], "source": [ "# Define batch generators for features (X) and labels (y)\n", "X_bgen = xb.BatchGenerator(\n", "    ds['images'],\n", "    input_dims={'sample': 2000, 'channel': 1, 'height': 28, 'width': 28},\n", "    preload_batch=False,  # Load each batch dynamically\n", ")\n", "y_bgen = xb.BatchGenerator(\n", "    ds['labels'], input_dims={'sample': 2000}, preload_batch=False\n", ")" ] }, { "cell_type": "markdown", "id": "91d63180-e3a6-49f7-a8e7-67b8b698b08c", "metadata": {}, "source": [ "## Map Batches to a Keras-Compatible Dataset" ] }, { "cell_type": "code", "execution_count": null, "id": "d1195057-269b-44ba-a3e7-aeedaa4ba8df", "metadata": {}, "outputs": [], "source": [ "# Wrap the generators with xbatcher's CustomTFDataset\n", "dataset = 
xbatcher.loaders.keras.CustomTFDataset(X_bgen, y_bgen)\n", "\n", "# Build a tf.data.Dataset that streams batches from the wrapped generators\n", "train_dataloader = tf.data.Dataset.from_generator(\n", "    lambda: iter(dataset),\n", "    output_signature=(\n", "        tf.TensorSpec(shape=(2000, 1, 28, 28), dtype=tf.float32),  # Images\n", "        tf.TensorSpec(shape=(2000,), dtype=tf.int64),  # Labels\n", "    ),\n", ").prefetch(3)  # Prefetch up to 3 batches so loading overlaps with training" ] }, { "cell_type": "markdown", "id": "1892411c-ca17-4d7f-b76b-5b5decaa78c1", "metadata": {}, "source": [ "## Visualize a Sample Batch" ] }, { "cell_type": "code", "execution_count": null, "id": "133b24bc-e7bc-4734-ad0a-22a848dd204c", "metadata": {}, "outputs": [], "source": [ "# Take one batch from the DataLoader and display its first image\n", "for train_features, train_labels in train_dataloader.take(1):\n", "    print(f'Feature batch shape: {train_features.shape}')\n", "    print(f'Labels batch shape: {train_labels.shape}')\n", "\n", "    img = train_features[0].numpy().squeeze()  # Extract the first image\n", "    label = train_labels[0].numpy()\n", "    plt.imshow(img, cmap='gray')\n", "    plt.title(f'Label: {label}')\n", "    plt.show()\n", "    break" ] }, { "cell_type": "markdown", "id": "1e5d6a66-1943-47da-be67-9b54d51defed", "metadata": {}, "source": [ "## Build a Simple Neural Network with Keras" ] }, { "cell_type": "code", "execution_count": null, "id": "8b0490e5-7ccc-47fe-90ec-d41a81c4eb20", "metadata": {}, "outputs": [], "source": [ "# Define a simple feedforward neural network\n", "model = models.Sequential(\n", "    [\n", "        layers.Flatten(input_shape=(1, 28, 28)),  # Flatten input images\n", "        layers.Dense(128, activation='relu'),  # Fully connected layer with 128 units\n", "        layers.Dense(10, activation='softmax'),  # Output layer for 10 classes\n", "    ]\n", ")\n", "\n", "# Compile the model\n", "model.compile(\n", "    optimizer=optimizers.Adam(learning_rate=0.001),\n", "    loss='sparse_categorical_crossentropy',\n", "    metrics=['accuracy'],\n", ")\n", 
"\n", "# Display model summary\n", "model.summary()" ] }, { "cell_type": "markdown", "id": "838df9c6-0753-4120-a0e0-dcc1480416b4", "metadata": {}, "source": [ "## Train the Model " ] }, { "cell_type": "code", "execution_count": null, "id": "25e86eba-4d4e-47cc-a6a7-9f0be244b009", "metadata": {}, "outputs": [], "source": [ "%%time\n", "\n", "# Train the model for 5 epochs\n", "epochs = 5\n", "\n", "model.fit(\n", " train_dataloader, # Pass the DataLoader directly\n", " epochs=epochs,\n", " verbose=1, # Print progress during training\n", ")" ] }, { "cell_type": "markdown", "id": "a0f4246c-6461-4e2a-a49d-df6c1ce770fc", "metadata": {}, "source": [ "## Visualize a Sample Prediction" ] }, { "cell_type": "code", "execution_count": null, "id": "9361cb65-3c0d-40d6-be5c-18b309626817", "metadata": {}, "outputs": [], "source": [ "# Visualize a prediction on a sample image\n", "for train_features, train_labels in train_dataloader.take(1):\n", " img = train_features[0].numpy().squeeze()\n", " label = train_labels[0].numpy()\n", " predicted_label = tf.argmax(model.predict(train_features[:1]), axis=1).numpy()[0]\n", "\n", " plt.imshow(img, cmap='gray')\n", " plt.title(f'True Label: {label}, Predicted: {predicted_label}')\n", " plt.show()\n", " break" ] }, { "cell_type": "markdown", "id": "372d0e0a-1542-4aa0-b3b9-9fd4337459ba", "metadata": {}, "source": [ "## Key Highlights \n", "\n", "- **Dynamic Batching**: Xbatcher and the MapDataset class allow for dynamic loading of batches, which reduces memory usage and speeds up data processing.\n", "- **Prefetching**: The prefetch feature in `tf.data.Dataset` overlaps data loading with model training to minimize idle time.\n", "- **Compatibility**: The pipeline works seamlessly with `keras.Model.fit`, simplifying training workflows." 
] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.9" } }, "nbformat": 4, "nbformat_minor": 5 }