{ "cells": [ { "cell_type": "markdown", "id": "e579c0e1-bb12-4c7b-8a97-bdb6fad01755", "metadata": {}, "source": [ "# End-to-End Tutorial: Training a Neural Network with PyTorch and Xbatcher\n", "\n", "This tutorial demonstrates how to use xarray, xbatcher, and PyTorch to train a simple neural network on the FashionMNIST dataset." ] }, { "cell_type": "markdown", "id": "5aa4bf55-588a-465d-affb-de5d16a54cdd", "metadata": {}, "source": [ "## Step 1: Setup\n", "\n", "Import the necessary libraries and load the dataset." ] }, { "cell_type": "code", "execution_count": null, "id": "916bb4a8-d2df-49e8-9109-a92299960886", "metadata": {}, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", "import torch\n", "import torch.nn as nn\n", "import torch.optim as optim\n", "import torch.utils.data\n", "import xarray as xr\n", "\n", "import xbatcher as xb\n", "import xbatcher.loaders.torch" ] }, { "cell_type": "code", "execution_count": null, "id": "474b2cc1-9991-4060-92fe-559c15d96678", "metadata": { "scrolled": true }, "outputs": [], "source": [ "ds = xr.open_dataset(\n", "    's3://carbonplan-share/xbatcher/fashion-mnist-train.zarr',\n", "    engine='zarr',\n", "    chunks={},\n", "    backend_kwargs={'storage_options': {'anon': True}},\n", ")\n", "ds" ] }, { "cell_type": "code", "execution_count": null, "id": "d647846e-381a-4901-ba1c-d4d47ff7b1fa", "metadata": {}, "outputs": [], "source": [ "ds.sel(sample=1).images.plot(cmap='gray');" ] }, { "cell_type": "markdown", "id": "533b9827-8c27-4229-8035-cf39f3e99e54", "metadata": {}, "source": [ "## Step 2: Create batch generator and data loader\n", "\n", "We use `xbatcher` to create batch generators for the images (`X_bgen`) and labels (`y_bgen`), each yielding batches of 2,000 samples." ] }, { "cell_type": "code", "execution_count": null, "id": "303b2a1d-9126-44e7-b312-fa546eca8f2e", "metadata": {}, "outputs": [], "source": [ "# Define batch generators\n", "X_bgen = xb.BatchGenerator(\n", "    ds['images'],\n", "    input_dims={'sample': 2000, 'channel': 1, 'height': 28, 'width': 
28},\n", "    preload_batch=False,\n", ")\n", "y_bgen = xb.BatchGenerator(\n", "    ds['labels'], input_dims={'sample': 2000}, preload_batch=False\n", ")\n", "X_bgen[0]" ] }, { "cell_type": "code", "execution_count": null, "id": "1c1ab1a7-4bf2-4d73-a8e2-a88e7cdf6829", "metadata": {}, "outputs": [], "source": [ "# Map batches to a PyTorch-compatible dataset\n", "dataset = xbatcher.loaders.torch.MapDataset(X_bgen, y_bgen)" ] }, { "cell_type": "code", "execution_count": null, "id": "942ee9c8-5369-49d3-8038-0658b80ff851", "metadata": {}, "outputs": [], "source": [ "# Create a DataLoader\n", "train_dataloader = torch.utils.data.DataLoader(\n", "    dataset,\n", "    batch_size=None,  # Disable automatic batching; batches are already defined by the dataset (via xbatcher)\n", "    prefetch_factor=3,  # Each worker prefetches up to 3 batches in advance to reduce data loading latency\n", "    num_workers=4,  # Use 4 parallel worker processes to load data concurrently\n", "    persistent_workers=True,  # Keep workers alive between epochs for faster subsequent epochs\n", "    multiprocessing_context='forkserver',  # Start worker processes with the \"forkserver\" method for more stable multiprocessing\n", ")" ] }, { "cell_type": "code", "execution_count": null, "id": "16953fd7-53d8-4d37-80e8-57f5b4daccee", "metadata": {}, "outputs": [], "source": [ "train_features, train_labels = next(iter(train_dataloader))" ] }, { "cell_type": "code", "execution_count": null, "id": "119cbac9-a973-4b37-b42d-12e03f105826", "metadata": {}, "outputs": [], "source": [ "print(f'Feature batch shape: {train_features.size()}')\n", "print(f'Labels batch shape: {train_labels.size()}')\n", "img = train_features[0].squeeze()\n", "label = train_labels[0]\n", "plt.imshow(img, cmap='gray')\n", "plt.show()\n", "print(f'Label: {label}')" ] }, { "cell_type": "markdown", "id": "4a6c3219-b281-4782-a77d-58860f7f7c83", "metadata": {}, "source": [ "## Step 3: Define the Neural Network\n", "\n", "We define a simple feedforward neural network for classification." 
] }, { "cell_type": "code", "execution_count": null, "id": "09620988-bbda-4508-b39e-e1d81f2374c9", "metadata": {}, "outputs": [], "source": [ "class SimpleNN(nn.Module):\n", "    def __init__(self):\n", "        super().__init__()\n", "        self.flatten = nn.Flatten()\n", "        self.fc1 = nn.Linear(28 * 28, 128)\n", "        self.fc2 = nn.Linear(128, 10)\n", "\n", "    def forward(self, x):\n", "        x = self.flatten(x)\n", "        x = torch.relu(self.fc1(x))\n", "        x = self.fc2(x)\n", "        return x" ] }, { "cell_type": "code", "execution_count": null, "id": "cff27830-a15e-4f4a-b529-7fed8ea7632d", "metadata": {}, "outputs": [], "source": [ "# Instantiate the model\n", "model = SimpleNN()\n", "model" ] }, { "cell_type": "markdown", "id": "c12073c2-a380-4a28-bc24-68c1811739b5", "metadata": {}, "source": [ "## Step 4: Define Loss Function and Optimizer\n", "We use Cross-Entropy Loss and the Adam optimizer." ] }, { "cell_type": "code", "execution_count": null, "id": "ab29bcbc-11e6-48ac-8618-8b6da4627b8f", "metadata": {}, "outputs": [], "source": [ "loss_fn = nn.CrossEntropyLoss()\n", "optimizer = optim.Adam(model.parameters(), lr=0.001)" ] }, { "cell_type": "markdown", "id": "ec0bd97d-c9fd-42ae-8962-ec9ed6434331", "metadata": {}, "source": [ "## Step 5: Train the Model\n", "We train the model using the data loader." 
] }, { "cell_type": "code", "execution_count": null, "id": "76fe4f2d-d15d-43ba-a079-0c3c8c4d65a2", "metadata": {}, "outputs": [], "source": [ "%%time\n", "\n", "epochs = 5\n", "\n", "for epoch in range(epochs):\n", "    print(f'Epoch {epoch+1}/{epochs}')\n", "    for batch, (X, y) in enumerate(train_dataloader):\n", "        # Forward pass\n", "        predictions = model(X)\n", "        loss = loss_fn(predictions, y)\n", "\n", "        # Backward pass\n", "        optimizer.zero_grad()\n", "        loss.backward()\n", "        optimizer.step()\n", "\n", "        if batch % 10 == 0:\n", "            print(f'Batch {batch}: Loss = {loss.item():.4f}')\n", "\n", "print('Training completed!')" ] }, { "cell_type": "markdown", "id": "aa1db735-9670-41dc-a27e-74369e8c320d", "metadata": {}, "source": [ "## Step 6: Evaluate the Model\n", "You can evaluate the model on a held-out test set or visualize some predictions. Here we visualize the prediction for one sample from the training batch loaded in Step 2." ] }, { "cell_type": "code", "execution_count": null, "id": "ea26801a-18a9-4ffe-9d04-92157c42bc8a", "metadata": {}, "outputs": [], "source": [ "# Visualize a sample prediction\n", "img = train_features[0].squeeze()\n", "label = train_labels[0]\n", "predicted_label = torch.argmax(model(train_features[0:1]), dim=1).item()\n", "\n", "plt.imshow(img, cmap='gray')\n", "plt.title(f'True Label: {label}, Predicted: {predicted_label}')\n", "plt.show()" ] }, { "cell_type": "markdown", "id": "8c52dcc1-d583-4cfe-af34-8030d4b451f0", "metadata": {}, "source": [ "## Key Highlights\n", "\n", "- **Data Handling**: We use Xbatcher to create efficient, chunked data pipelines from Xarray datasets.\n", "- **Integration**: The `xbatcher.loaders.torch.MapDataset` class enables direct compatibility with PyTorch's DataLoader.\n", "- **Training**: PyTorch simplifies the model training loop while leveraging the custom data pipeline.\n" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": 
"text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.9" } }, "nbformat": 4, "nbformat_minor": 5 }