{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Section: Variational Autoencoders (VAE)\n",
    "### MCB128: AI in Molecular Biology — Section Notebook\n",
    "## ✅ ANSWER KEY\n",
    "\n",
    "**Objectives (45 min):**\n",
    "1. Understand what an autoencoder (AE) is and what the bottleneck does\n",
    "2. Understand what makes a VAE different from a plain AE\n",
    "3. Implement a VAE in PyTorch from scratch\n",
    "4. Use the trained VAE as a generative model — sample new digits!\n",
    "\n",
    "---"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Part 0: Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from torch.utils.data import DataLoader\n",
    "from torchvision import datasets, transforms\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "\n",
    "# Use GPU if available\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "print(f'Using device: {device}')\n",
    "\n",
    "# Reproducibility\n",
    "torch.manual_seed(42)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load MNIST — images are 28x28 = 784 pixels, pixel values in [0,1]\n",
    "transform = transforms.ToTensor()\n",
    "\n",
    "train_dataset = datasets.MNIST(root='./data', train=True,  download=True, transform=transform)\n",
    "test_dataset  = datasets.MNIST(root='./data', train=False, download=True, transform=transform)\n",
    "\n",
    "train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)\n",
    "test_loader  = DataLoader(test_dataset,  batch_size=128, shuffle=False)\n",
    "\n",
    "print(f'Training samples: {len(train_dataset)}')\n",
    "print(f'Test samples:     {len(test_dataset)}')\n",
    "print(f'Image shape:      {train_dataset[0][0].shape}  →  flattened to 784')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Peek at the data\n",
    "fig, axes = plt.subplots(1, 10, figsize=(12, 1.5))\n",
    "for i, ax in enumerate(axes):\n",
    "    img, label = train_dataset[i]\n",
    "    ax.imshow(img.squeeze(), cmap='gray')\n",
    "    ax.set_title(str(label), fontsize=10)\n",
    "    ax.axis('off')\n",
    "plt.suptitle('Sample MNIST digits', y=1.05)\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---\n",
    "## Part 1: The Plain Autoencoder — a quick recap\n",
    "\n",
    "Before building a VAE, let's pin down the AE idea.\n",
    "\n",
    "An autoencoder has three parts:\n",
    "\n",
    "```\n",
    "x  ──[Encoder]──►  z  ──[Decoder]──►  x'\n",
    "784               latent             784\n",
    "```\n",
    "\n",
    "- **Encoder**: compresses the input `x` into a low-dimensional **latent code** `z`.\n",
    "- **Bottleneck**: the latent code `z` — this is the compact representation.\n",
    "- **Decoder**: reconstructs `x'` from `z`.\n",
    "- **Loss**: reconstruction error, e.g., MSE or Binary Cross-Entropy between `x` and `x'`.\n",
    "\n",
    "Training objective:  \n",
    "$$\\text{minimize}\\quad \\|x - x'\\|^2$$\n",
    "\n",
    "### ❓ Conceptual check\n",
    "\n",
    "> **Q1.** The bottleneck in a standard AE maps each input `x` to a **single point** `z` in latent space. Why does this cause a problem if we want to *generate* new data by sampling from latent space?\n",
    ">\n",
    "> *(Think: what does the region between two known z-points look like? Is it meaningful?)*\n",
    "\n",
    "### ✅ Answer\n",
    "\n",
    "A plain AE is only trained to encode and decode specific data points — it has no incentive to make the **space between** those points meaningful. The latent space can be highly discontinuous: if you pick a random point `z` that wasn't explicitly produced by encoding a real input, the decoder may output garbage (e.g., visual noise rather than a recognizable digit). There are \"holes\" or unstructured gaps between the encoded points. Generation requires sampling arbitrary `z` values, but without any structure or coverage guarantee over the latent space, most random samples will land in these meaningless regions."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---\n",
    "## Part 2: From AE to VAE — the key idea\n",
    "\n",
    "### The problem with a regular AE for generation\n",
    "\n",
    "A plain AE learns one specific point `z` per data point `x`. The latent space has no guarantee of being **continuous** or **structured** — holes and gaps between encoded points may decode to garbage.\n",
    "\n",
    "### The VAE solution: encode a *distribution*, not a point\n",
    "\n",
    "Instead of mapping `x → z`, the VAE encoder maps:\n",
    "$$x \\;\\longrightarrow\\; (\\mu(x),\\; \\sigma^2(x))$$\n",
    "\n",
    "These are the mean and variance of a **Gaussian distribution** over `z`. We then **sample** `z` from this distribution:\n",
    "$$z \\sim \\mathcal{N}(\\mu(x),\\; \\sigma^2(x))$$\n",
    "\n",
    "The decoder reconstructs `x'` from this sampled `z`, exactly as before.\n",
    "\n",
    "### The VAE loss: two terms\n",
    "\n",
    "The VAE is trained to maximize the **Evidence Lower BOund (ELBO)**:\n",
    "\n",
    "$$\\mathcal{L} = \\underbrace{\\mathbb{E}_{q_\\phi(z|x)}[\\log p_\\theta(x|z)]}_{\\text{reconstruction term}} - \\underbrace{D_{KL}(q_\\phi(z|x) \\| p(z))}_{\\text{KL regularization term}}$$\n",
    "\n",
    "In practice we **minimize** the negative ELBO:\n",
    "$$\\text{Loss} = \\underbrace{\\text{BCE}(x, x')}_{\\text{reconstruction}} + \\underbrace{D_{KL}(\\mathcal{N}(\\mu,\\sigma^2) \\| \\mathcal{N}(0,1))}_{\\text{KL term}}$$\n",
    "\n",
    "| Term | What it does |\n",
    "|---|---|\n",
    "| **Reconstruction** | Pushes decoded `x'` to be close to the input `x` |\n",
    "| **KL divergence** | Pushes the learned `q(z\\|x)` toward a standard Normal `N(0,1)` — this regularizes the latent space to be smooth and filled |\n",
    "\n",
    "For two Gaussians, the KL term has a **closed form**:\n",
    "$$D_{KL}(\\mathcal{N}(\\mu,\\sigma^2) \\| \\mathcal{N}(0,1)) = -\\frac{1}{2}\\sum_{j=1}^{d}\\left(1 + \\log\\sigma_j^2 - \\mu_j^2 - \\sigma_j^2\\right)$$\n",
    "\n",
    "### The reparameterization trick\n",
    "\n",
    "We need gradients to flow through the sampling step `z ~ N(μ, σ²)`. Sampling is not differentiable — so we reparameterize:\n",
    "\n",
    "$$z = \\mu + \\epsilon \\cdot \\sigma, \\qquad \\epsilon \\sim \\mathcal{N}(0, 1)$$\n",
    "\n",
    "Now `ε` is the only random variable, and `z` is a **deterministic function** of `μ`, `σ`, and `ε` — so gradients can flow back through `μ` and `σ`.\n",
    "\n",
    "### ❓ Conceptual check\n",
    "\n",
    "> **Q2.** In your own words, explain why the KL term is necessary. What would happen at training time if you removed it and only kept the reconstruction loss?\n",
    "\n",
    "### ✅ Answer\n",
    "\n",
    "Without the KL term, the encoder has no incentive to produce distributions that overlap with or resemble `N(0,1)`. It would collapse to behaving like a plain AE: each input `x` gets mapped to a tiny, near-zero-variance Gaussian (essentially a single point), which perfectly minimizes reconstruction loss but leaves the latent space fragmented and unstructured. There would be no coverage guarantee over the latent space — arbitrary samples from `N(0,1)` would mostly land in regions the decoder has never seen, decoding to noise. The KL term regularizes the encoder to spread its distributions across a well-defined prior, making the latent space continuous and allowing meaningful generation.\n",
    "\n",
    "> **Q3.** Why can't we just backpropagate through the sampling step `z ~ N(μ, σ²)` directly? What does the reparameterization trick change?\n",
    "\n",
    "### ✅ Answer\n",
    "\n",
    "Sampling is a **stochastic operation** — it doesn't have a well-defined gradient. If we write `z = sample(N(μ, σ²))`, there is no way for the gradient of the loss w.r.t. `z` to flow back through the sampling step to `μ` and `σ`, because the function `sample(·)` is not differentiable with respect to its parameters.\n",
    "\n",
    "The reparameterization trick sidesteps this by rewriting the sample as `z = μ + ε·σ` where `ε ~ N(0, 1)` is drawn **independently** of the parameters. Now `z` is a differentiable (deterministic) function of `μ` and `σ` — the randomness is \"pushed outside\" the computation graph into `ε`. Gradients can then flow through `μ` and `σ` as usual."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---\n",
    "## Part 3: Implement the VAE\n",
    "\n",
    "Architecture:\n",
    "```\n",
    "Encoder: 784 → 400 → (μ[latent_dim], log σ²[latent_dim])\n",
    "Decoder: latent_dim → 400 → 784  (sigmoid output, since pixels ∈ [0,1])\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class VAE(nn.Module):\n",
    "    def __init__(self, input_dim=784, hidden_dim=400, latent_dim=20):\n",
    "        super(VAE, self).__init__()\n",
    "        self.latent_dim = latent_dim\n",
    "\n",
    "        # ── Encoder ──────────────────────────────────────────────────────────\n",
    "        # Shared encoder layer: input_dim → hidden_dim (use ReLU activation)\n",
    "        self.encoder_fc = nn.Linear(input_dim, hidden_dim)\n",
    "\n",
    "        # Two separate heads — one for μ, one for log(σ²)\n",
    "        # Both map: hidden_dim → latent_dim\n",
    "        # ✅ ANSWER\n",
    "        self.fc_mu     = nn.Linear(hidden_dim, latent_dim)\n",
    "        self.fc_logvar = nn.Linear(hidden_dim, latent_dim)\n",
    "\n",
    "        # ── Decoder ──────────────────────────────────────────────────────────\n",
    "        # Two FC layers: latent_dim → hidden_dim → input_dim\n",
    "        # Use ReLU on the hidden layer, sigmoid on the output\n",
    "        # ✅ ANSWER\n",
    "        self.decoder_fc1 = nn.Linear(latent_dim, hidden_dim)\n",
    "        self.decoder_fc2 = nn.Linear(hidden_dim, input_dim)\n",
    "\n",
    "    # ── Encoder forward pass ─────────────────────────────────────────────────\n",
    "    def encode(self, x):\n",
    "        \"\"\"\n",
    "        x: (batch, 784)\n",
    "        Returns: mu (batch, latent_dim), logvar (batch, latent_dim)\n",
    "        \"\"\"\n",
    "        # ✅ ANSWER\n",
    "        # 1. Pass x through self.encoder_fc with ReLU\n",
    "        # 2. Compute mu and logvar from their respective heads\n",
    "        h      = F.relu(self.encoder_fc(x))\n",
    "        mu     = self.fc_mu(h)\n",
    "        logvar = self.fc_logvar(h)\n",
    "        return mu, logvar\n",
    "\n",
    "    # ── Reparameterization trick ──────────────────────────────────────────────\n",
    "    def reparameterize(self, mu, logvar):\n",
    "        \"\"\"\n",
    "        Sample z = mu + eps * sigma  where  eps ~ N(0, I)\n",
    "        logvar = log(sigma^2), so sigma = exp(0.5 * logvar)\n",
    "        \"\"\"\n",
    "        # ✅ ANSWER\n",
    "        # 1. Compute std from logvar\n",
    "        # 2. Sample eps from standard normal (same shape as std)\n",
    "        # 3. Return z = mu + eps * std\n",
    "        std = torch.exp(0.5 * logvar)\n",
    "        eps = torch.randn_like(std)      # same shape as std, drawn from N(0,1)\n",
    "        z   = mu + eps * std\n",
    "        return z\n",
    "\n",
    "    # ── Decoder forward pass ─────────────────────────────────────────────────\n",
    "    def decode(self, z):\n",
    "        \"\"\"\n",
    "        z: (batch, latent_dim)\n",
    "        Returns: x_recon (batch, 784) — pixel probabilities in [0, 1]\n",
    "        \"\"\"\n",
    "        # ✅ ANSWER\n",
    "        # Pass z through decoder_fc1 (ReLU) then decoder_fc2 (sigmoid)\n",
    "        x_recon = torch.sigmoid(self.decoder_fc2(F.relu(self.decoder_fc1(z))))\n",
    "        return x_recon\n",
    "\n",
    "    # ── Full forward pass ─────────────────────────────────────────────────────\n",
    "    def forward(self, x):\n",
    "        \"\"\"\n",
    "        x: (batch, 1, 28, 28)  →  flatten  →  encode  →  reparameterize  →  decode\n",
    "        Returns: x_recon, mu, logvar\n",
    "        \"\"\"\n",
    "        x = x.view(-1, 784)          # flatten the image\n",
    "        mu, logvar = self.encode(x)\n",
    "        z = self.reparameterize(mu, logvar)\n",
    "        x_recon = self.decode(z)\n",
    "        return x_recon, mu, logvar"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### ELBO Loss\n",
    "\n",
    "Now implement the loss. Recall:\n",
    "$$\\text{Loss} = \\text{BCE}(x, x') - \\frac{1}{2}\\sum_{j}\\left(1 + \\log\\sigma_j^2 - \\mu_j^2 - \\sigma_j^2\\right)$$\n",
    "\n",
    "Note: `logvar = log(σ²)`, so `σ² = exp(logvar)`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def elbo_loss(x_recon, x, mu, logvar):\n",
    "    \"\"\"\n",
    "    x_recon : (batch, 784)  — reconstructed pixel probabilities\n",
    "    x       : (batch, 784)  — original pixels, flattened\n",
    "    mu      : (batch, latent_dim)\n",
    "    logvar  : (batch, latent_dim)\n",
    "\n",
    "    Returns scalar loss (sum over batch).\n",
    "    \"\"\"\n",
    "    # ✅ ANSWER\n",
    "    # Reconstruction loss: binary cross-entropy, summed over pixels\n",
    "    recon_loss = F.binary_cross_entropy(x_recon, x, reduction='sum')\n",
    "\n",
    "    # ✅ ANSWER\n",
    "    # KL divergence: -0.5 * sum(1 + logvar - mu^2 - exp(logvar))\n",
    "    # Sum over latent dims, sum over batch\n",
    "    kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())\n",
    "\n",
    "    return recon_loss + kl_loss"
   ]
  },
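  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A quick sanity check on the KL term, sketched under the assumption that `torch.distributions` is an acceptable reference: the closed form used in `elbo_loss` is compared against `kl_divergence` for random, made-up `mu` and `logvar`.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "from torch.distributions import Normal, kl_divergence\n",
    "\n",
    "torch.manual_seed(0)\n",
    "mu = torch.randn(3, 5)               # (batch, latent_dim), arbitrary values\n",
    "logvar = torch.randn(3, 5)\n",
    "\n",
    "# Closed form from elbo_loss, summed over batch and latent dimensions\n",
    "kl_closed = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())\n",
    "\n",
    "# Reference: elementwise KL between N(mu, sigma^2) and N(0, 1), then summed\n",
    "q = Normal(mu, torch.exp(0.5 * logvar))\n",
    "p = Normal(torch.zeros_like(mu), torch.ones_like(mu))\n",
    "kl_ref = kl_divergence(q, p).sum()\n",
    "\n",
    "print(kl_closed.item(), kl_ref.item())\n",
    "```\n",
    "\n",
    "The two printed numbers should agree up to floating-point error."
   ]
  },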
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---\n",
    "## Part 4: Train the VAE"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hyperparameters\n",
    "LATENT_DIM = 20\n",
    "EPOCHS     = 10\n",
    "LR         = 1e-3\n",
    "\n",
    "model     = VAE(latent_dim=LATENT_DIM).to(device)\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=LR)\n",
    "\n",
    "print(model)\n",
    "total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
    "print(f'\\nTrainable parameters: {total_params:,}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def train_epoch(model, loader, optimizer):\n",
    "    model.train()\n",
    "    total_loss = 0\n",
    "    for x, _ in loader:                        # we don't need labels\n",
    "        x = x.to(device)\n",
    "        optimizer.zero_grad()\n",
    "        x_recon, mu, logvar = model(x)\n",
    "        loss = elbo_loss(x_recon, x.view(-1, 784), mu, logvar)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        total_loss += loss.item()\n",
    "    return total_loss / len(loader.dataset)\n",
    "\n",
    "def eval_epoch(model, loader):\n",
    "    model.eval()\n",
    "    total_loss = 0\n",
    "    with torch.no_grad():\n",
    "        for x, _ in loader:\n",
    "            x = x.to(device)\n",
    "            x_recon, mu, logvar = model(x)\n",
    "            loss = elbo_loss(x_recon, x.view(-1, 784), mu, logvar)\n",
    "            total_loss += loss.item()\n",
    "    return total_loss / len(loader.dataset)\n",
    "\n",
    "# Training loop\n",
    "train_losses, test_losses = [], []\n",
    "\n",
    "for epoch in range(1, EPOCHS + 1):\n",
    "    tr = train_epoch(model, train_loader, optimizer)\n",
    "    te = eval_epoch(model, test_loader)\n",
    "    train_losses.append(tr)\n",
    "    test_losses.append(te)\n",
    "    print(f'Epoch {epoch:2d}/{EPOCHS} | Train loss: {tr:.2f} | Test loss: {te:.2f}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "plt.figure(figsize=(7, 3))\n",
    "plt.plot(train_losses, label='Train ELBO loss')\n",
    "plt.plot(test_losses,  label='Test ELBO loss', linestyle='--')\n",
    "plt.xlabel('Epoch')\n",
    "plt.ylabel('Loss per sample')\n",
    "plt.title('VAE Training Curve')\n",
    "plt.legend()\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---\n",
    "## Part 5: Evaluate Reconstructions\n",
    "\n",
    "Pass real test images through the VAE and compare input vs. reconstruction."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.eval()\n",
    "x_batch, labels = next(iter(test_loader))\n",
    "x_batch = x_batch.to(device)\n",
    "\n",
    "with torch.no_grad():\n",
    "    x_recon, mu, logvar = model(x_batch)\n",
    "\n",
    "n = 8\n",
    "fig, axes = plt.subplots(2, n, figsize=(12, 3))\n",
    "\n",
    "for i in range(n):\n",
    "    # Original\n",
    "    axes[0, i].imshow(x_batch[i].cpu().squeeze(), cmap='gray')\n",
    "    axes[0, i].axis('off')\n",
    "    if i == 0: axes[0, i].set_title('Original', loc='left', fontsize=9)\n",
    "\n",
    "    # Reconstruction\n",
    "    axes[1, i].imshow(x_recon[i].cpu().view(28, 28), cmap='gray')\n",
    "    axes[1, i].axis('off')\n",
    "    if i == 0: axes[1, i].set_title('Reconstructed', loc='left', fontsize=9)\n",
    "\n",
    "plt.suptitle('VAE Reconstructions', fontsize=12)\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---\n",
    "## Part 6: Generation — Sample from the Prior\n",
    "\n",
    "This is the key payoff of the VAE over a plain AE.  \n",
    "Because the KL term regularized the latent space toward `N(0, I)`, we can **sample `z` directly from the prior** and decode it into a new, never-before-seen digit.\n",
    "\n",
    "$$z \\sim \\mathcal{N}(0, I) \\quad \\longrightarrow \\quad x_{\\text{generated}} = \\text{decode}(z)$$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.eval()\n",
    "\n",
    "with torch.no_grad():\n",
    "    # Sample z from the prior N(0, I)\n",
    "    z = torch.randn(16, LATENT_DIM).to(device)\n",
    "    generated = model.decode(z).cpu().view(-1, 28, 28)\n",
    "\n",
    "fig, axes = plt.subplots(2, 8, figsize=(12, 3.5))\n",
    "for i, ax in enumerate(axes.flat):\n",
    "    ax.imshow(generated[i], cmap='gray')\n",
    "    ax.axis('off')\n",
    "\n",
    "plt.suptitle('Generated digits — sampled from prior N(0, I)', fontsize=12)\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### ❓ Conceptual check\n",
    "\n",
 **Q4.** The generated">
    "> **Q4.** The generated images look like digits but are slightly blurry. Where does this blurriness come from?\n",
    ">\n",
    "> *(Hint: think about what the reconstruction loss is averaging over.)*\n",
    "\n",
    "### ✅ Answer\n",
    "\n",
    "The blurriness arises from two related causes:\n",
    "\n",
    "1. **The BCE/MSE reconstruction loss averages over all possible realizations of `z`**: because the encoder produces a *distribution* over `z` (not a single point), the decoder must reconstruct `x` from many possible `z` values. The loss averages over these samples, and the gradient signal pushes the decoder to produce outputs that are a kind of average of all likely reconstructions. Averaging over possibilities in pixel space produces blurry, \"superimposed\" images.\n",
    "\n",
    "2. **Pixel-wise loss does not capture perceptual sharpness**: BCE and MSE treat each pixel independently and reward \"safe\", middle-ground predictions rather than committing to sharp edges. Sharper generative models (e.g., GANs, diffusion models) use adversarial or score-based objectives that directly penalize blurry outputs."
   ]
  },
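  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A toy illustration of the averaging effect described above, assuming a single pixel that is off in one plausible image and on in another. The helper `bce` and the grid search are illustrative only and are not part of the notebook's training code.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "def bce(p, t):\n",
    "    # Binary cross-entropy for one pixel: prediction p, target t\n",
    "    return -(t * math.log(p) + (1 - t) * math.log(1 - p))\n",
    "\n",
    "targets = [0.0, 1.0]                      # same pixel, two plausible images\n",
    "grid = [i / 100 for i in range(1, 100)]   # candidate predictions 0.01 .. 0.99\n",
    "\n",
    "# Total loss when a single prediction must serve both targets\n",
    "total = {p: sum(bce(p, t) for t in targets) for p in grid}\n",
    "best_p = min(total, key=total.get)\n",
    "print(best_p)                             # 0.5, the mean of the two targets\n",
    "```\n",
    "\n",
    "A prediction of 0.5 is a gray pixel: the loss prefers hedging between the two images over committing to either one."
   ]
  },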
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---\n",
    "## Part 7: Interpolation in Latent Space\n",
    "\n",
    "One of the most compelling demonstrations of a structured latent space: smoothly interpolate between two real digits.\n",
    "\n",
    "With a plain AE, the path between two `z` points may pass through regions that decode to nonsense. With a VAE, the KL regularization fills in the latent space — so interpolations should stay meaningful."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.eval()\n",
    "\n",
    "# Pick two test images\n",
    "x1, label1 = test_dataset[0]\n",
    "x2, label2 = test_dataset[1]\n",
    "\n",
    "with torch.no_grad():\n",
    "    mu1, _ = model.encode(x1.view(1, 784).to(device))\n",
    "    mu2, _ = model.encode(x2.view(1, 784).to(device))\n",
    "\n",
    "steps = 10\n",
    "alphas = torch.linspace(0, 1, steps)\n",
    "\n",
    "fig, axes = plt.subplots(1, steps, figsize=(14, 2))\n",
    "for i, alpha in enumerate(alphas):\n",
    "    z_interp = (1 - alpha) * mu1 + alpha * mu2\n",
    "    with torch.no_grad():\n",
    "        img = model.decode(z_interp).cpu().view(28, 28)\n",
    "    axes[i].imshow(img, cmap='gray')\n",
    "    axes[i].axis('off')\n",
    "\n",
    "axes[0].set_title(f'digit {label1}', fontsize=8)\n",
    "axes[-1].set_title(f'digit {label2}', fontsize=8)\n",
    "plt.suptitle('Latent space interpolation between two digits', fontsize=11)\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### ❓ Conceptual check\n",
    "\n",
    "> **Q5.** Notice that the interpolation passes through intermediate images that look like plausible (if blurry) digits or transitions between digits. What property of the VAE's training makes this possible? Could you do the same with a plain AE? Why or why not?\n",
    "\n",
    "### ✅ Answer\n",
    "\n",
    "The smooth interpolation is possible because of the **KL regularization term**. By pushing all learned posteriors `q(z|x)` toward the standard normal prior `N(0,1)`, the KL term forces the encoder distributions to overlap — there are no large empty gaps between the encoded regions of different digits. As a result, linear paths through latent space stay in regions that the decoder has been trained on, producing semantically coherent intermediate images.\n",
    "\n",
    "A plain AE **cannot** reliably do this. Without KL regularization, the AE is free to place each digit's latent code anywhere in latent space with no requirement for overlap or coverage. The regions between two encoded points `z1` and `z2` were never seen during training, so the decoder has no meaningful signal about how to map them — the interpolation will likely pass through nonsensical or visually garbled outputs."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---\n",
    "## Part 8 (Bonus): Visualizing the Latent Space\n",
    "\n",
    "If you finish early: reduce `LATENT_DIM = 2`, retrain for a few epochs, and plot the latent space directly — colored by digit class. You should see that the VAE organizes the digits into a smooth, well-separated structure in 2D, similar to what the lecture notes show for gene expression data.\n",
    "\n",
    "*(This is exactly what the GENEEX autoencoder does for cell types — except with 1,000 gene dimensions instead of 784 pixel dimensions.)*"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Bonus: retrain with latent_dim=2 to plot directly\n",
    "# (skip if short on time — the interpolation above is sufficient)\n",
    "\n",
    "model_2d = VAE(latent_dim=2).to(device)\n",
    "optimizer_2d = torch.optim.Adam(model_2d.parameters(), lr=1e-3)\n",
    "\n",
    "# Train for 5 epochs\n",
    "for epoch in range(1, 6):\n",
    "    tr = train_epoch(model_2d, train_loader, optimizer_2d)\n",
    "    print(f'Epoch {epoch}/5 | Train loss: {tr:.2f}')\n",
    "\n",
    "# Encode all test images and collect mu values\n",
    "model_2d.eval()\n",
    "all_mu, all_labels = [], []\n",
    "\n",
    "with torch.no_grad():\n",
    "    for x, y in test_loader:\n",
    "        mu, _ = model_2d.encode(x.view(-1, 784).to(device))\n",
    "        all_mu.append(mu.cpu())\n",
    "        all_labels.append(y)\n",
    "\n",
    "all_mu     = torch.cat(all_mu).numpy()\n",
    "all_labels = torch.cat(all_labels).numpy()\n",
    "\n",
    "plt.figure(figsize=(8, 6))\n",
    "scatter = plt.scatter(all_mu[:, 0], all_mu[:, 1],\n",
    "                      c=all_labels, cmap='tab10', alpha=0.4, s=5)\n",
    "plt.colorbar(scatter, label='Digit class')\n",
    "plt.xlabel('z₁'); plt.ylabel('z₂')\n",
    "plt.title('VAE latent space (2D) — colored by digit class')\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "---\n",
    "## Summary: AE vs VAE at a glance\n",
    "\n",
    "| | Autoencoder (AE) | Variational Autoencoder (VAE) |\n",
    "|---|---|---|\n",
    "| Encoder output | A single point `z` | Parameters `(μ, σ²)` of a distribution |\n",
    "| Latent variable | Deterministic | Stochastic (sampled via reparameterization) |\n",
    "| Latent space | Unstructured, may have holes | Regularized toward `N(0,1)` → continuous |\n",
    "| Loss | Reconstruction only | Reconstruction + KL divergence (ELBO) |\n",
    "| Can generate new samples? | ❌ Not reliably | ✅ Yes — sample `z ~ N(0,1)`, decode |\n",
    "| Framework | Discriminative | Generative (probabilistic) |\n",
    "| Biology example | GENEEX (gene expression compression) | scVI (single-cell RNA-seq) |\n",
    "\n",
    "---\n",
    "*MCB128 — Section Notebook Answer Key | Based on Prof. Rivas's Block 5 lecture*"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": "python",
   "version": "3.10.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
