{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "GdaFPpZ-ylNM"
   },
   "source": [
    "# **b5 autoencoders, variational autoencoders / gene expression**"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### In this homework we are going to do single-cell RNA seq analysis using both an autoencoder (AE) and a variational autoencoder (VAE).\n",
    "\n",
    "#### The data includes cells from several different brain cells, and our goal is to investigate:\n",
    "\n",
    "  *  Do the different brain regions express different cell types?\n",
    "  \n",
    "  * For any two regions for which the previous answer is yes, which genes are resposible for the different cell types?\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "\n",
    "# Here we download the data\n",
    "# \n",
    "#scRNAseq = np.loadtxt(\"http://rivaslab.org/teaching/MCB128_AIMB/downloads/b5_homework_data_scRNAseq.csv\", delimiter=\",\")\n",
    "scRNAseq = np.loadtxt(\"../web/_site/downloads/b5_homework_data_scRNAseq.csv\", delimiter=\",\")\n",
    "#scRNAseq = np.loadtxt(\"../web/downloads/b5_homework_data_scRNAseq.csv\", delimiter=\",\")\n",
    "\n",
    "# the structure of the data is:\n",
    "#   scRNAseq[N,G+1] # N = number of cells\n",
    "#                   # G = number of genes\n",
    "# \n",
    "# such that \n",
    "#   expression counts for each gene in each cell: scRNAseq(N, G) # all columns except last → shape (N, G)\n",
    "#   cell label as to the region it belongs to:    scRNAseq(N,-1) # last column  shape (N,)\n",
    "#\n",
    "#  \n",
    "#\n",
    "# (1) Inspect the data\n",
    "#\n",
    "#   (1.a) extract counts and labels\n",
    "#   (1.b) calculate N and G\n",
    "#   (1.c) A cel label determines the brain region from which the cell comes from. \n",
    "#         Determine n_regions, that is, the number of labels/brain regions do we have data from?\n",
    "#\n",
    "# (1.a)\n",
    "counts =    # all columns except last → shape (N, G)\n",
    "labels =    # last column             → shape (N,)\n",
    "\n",
    "counts = torch.from_numpy(counts).float() # add this so that the DataLoader works \n",
    "labels = torch.from_numpy(labels).long()  # add this to the DataLoader works \n",
    "\n",
    "# (1.b)\n",
    "N = \n",
    "G = \n",
    "print(\"# of cells\", N)\n",
    "print(\"# genes per cell\", G)\n",
    "\n",
    "# (1.c)\n",
    "n_regions = \n",
    "print('n_regions', n_regions)\n",
    "\n"
   ]
  },
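  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# A minimal sketch for (1.a)-(1.c), assuming the last column holds integer region\n",
    "# labels 0..n_regions-1. This is one possible way to fill in the blanks above,\n",
    "# not the only one.\n",
    "\n",
    "counts = torch.from_numpy(scRNAseq[:, :-1]).float()  # all columns except last → (N, G)\n",
    "labels = torch.from_numpy(scRNAseq[:, -1]).long()    # last column             → (N,)\n",
    "\n",
    "N, G = counts.shape                       # (1.b) number of cells, genes per cell\n",
    "n_regions = int(labels.max().item()) + 1  # (1.c) assumes labels are 0..n_regions-1\n",
    "print('# of cells', N, '| # genes per cell', G, '| n_regions', n_regions)\n"
   ]
  },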
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# (1) Inspect the data further\n",
    "#\n",
    "#   (1.d) How many cells per region?\n",
    "#   (1.e) What is the range of expression values?\n",
    "#         Calculate the min, max, avg number of counts per region\n",
    "#\n",
    "\n",
    "\n",
    "   "
   ]
  },
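  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# A minimal sketch for (1.d)-(1.e), assuming counts, labels, and n_regions from\n",
    "# above: torch.bincount gives the number of cells per region, and simple\n",
    "# reductions give the per-region count statistics.\n",
    "\n",
    "cells_per_region = torch.bincount(labels, minlength=n_regions)  # (1.d)\n",
    "print('cells per region:', cells_per_region.tolist())\n",
    "\n",
    "for k in range(n_regions):                                      # (1.e)\n",
    "    region_counts = counts[labels == k]\n",
    "    print(f'region {k}: min={region_counts.min().item():.1f} '\n",
    "          f'max={region_counts.max().item():.1f} '\n",
    "          f'avg={region_counts.mean().item():.2f}')\n"
   ]
  },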
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "Vhuw8TMNzLaF"
   },
   "outputs": [],
   "source": [
    "# (2) Train an AE to explore the information learned by a latent variable z of dimention 2.\n",
    "#\n",
    "# (2.a) Build the AE (you use the code from the class lecture)\n",
    "#       Hidden dimension can be h1=512, h2=256, latent_dim = 2\n",
    "#\n",
    "# (2.b) Load the data into the DataLoader\n",
    "#\n",
    "# (2.c) Write the training loop, using a MSE loss and Adam optimization. \n",
    "#       Justify metaparameters selection: learing rate, number of epochs, batch_size\n",
    "#\n",
    "# (2.d) plot the latent variable Z[2] and draw conclusions. \n",
    "#       What can you say about the cells in the different brain reagions?\n",
    "#\n",
    "# (2.e) Compare your results for the latent variable z to those obtained by PCA\n",
    "\n",
    "\n",
    "\n",
    "# (2.a) The autoencoder\n",
    "#\n",
    "class GeneExprAE(nn.Module):\n",
    "    def __init__(self, input_dim, h1=512, h2=256, latent_dim=32):\n",
    "        super().__init__()\n",
    "        # ----- Encoder: input -> h1 -> h2 -> latent(z) -----\n",
    "        self.enc_fc1 = nn.Linear(input_dim, h1)\n",
    "        self.enc_fc2 = nn.Linear(h1, h2)\n",
    "        self.enc_latent = nn.Linear(h2, latent_dim) \n",
    "\n",
    "        # ----- Decoder: latent -> h2 -> h1 -> input -----\n",
    "        self.dec_fc1 = nn.Linear(latent_dim, h2)\n",
    "        self.dec_fc2 = nn.Linear(h2, h1)\n",
    "        self.dec_out = nn.Linear(h1, input_dim)\n",
    "\n",
    "    def encode(self, x):\n",
    "        h = F.relu(self.enc_fc1(x))   # 1st hidden\n",
    "        h = F.relu(self.enc_fc2(h))   # 2nd hidden\n",
    "        z = self.enc_latent(h)        # bottleneck\n",
    "        return z\n",
    "\n",
    "    def decode(self, z):\n",
    "        h = F.relu(self.dec_fc1(z))   # 1st decoder hidden\n",
    "        h = F.relu(self.dec_fc2(h))   # 2nd decoder hidden\n",
    "        x_ae = self.dec_out(h)        # linear output for real-valued expression\n",
    "        return x_ae\n",
    "\n",
    "    def forward(self, x):\n",
    "        z = self.encode(x)\n",
    "        x_ae = self.decode(z)\n",
    "        return x_ae, z"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "6I3a3vhehyW_",
    "outputId": "797a726a-a3c1-4dea-a9b9-d59cb04185d8"
   },
   "outputs": [],
   "source": [
    "from torch.utils.data import DataLoader, TensorDataset\n",
    "\n",
    "# (2.b) Loading of the data into a DataLoader, and create the model\n",
    "\n",
    "# (2.c) The training loop\n",
    "# MSELoss = Mean square error = squared L2 norm\n",
    "#\n",
    "#  loss(n) = 1/G \\sum_g ( x[n,g] - x_out[n,g] )^2\n",
    "#\n",
    "#\n"
   ]
  },
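  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# A minimal sketch for (2.b)-(2.c), assuming counts, labels, and G from above.\n",
    "# The hyperparameters (lr=1e-3, 50 epochs, batch_size=64) are placeholder choices\n",
    "# to tune and justify, not prescribed values.\n",
    "\n",
    "dataset = TensorDataset(counts, labels)\n",
    "loader = DataLoader(dataset, batch_size=64, shuffle=True)\n",
    "\n",
    "model_ae = GeneExprAE(input_dim=G, latent_dim=2)\n",
    "optimizer = torch.optim.Adam(model_ae.parameters(), lr=1e-3)\n",
    "criterion = nn.MSELoss()\n",
    "\n",
    "model_ae.train()\n",
    "for epoch in range(50):\n",
    "    total = 0.0\n",
    "    for x, _ in loader:  # labels are not used to train the AE\n",
    "        optimizer.zero_grad()\n",
    "        x_ae, z = model_ae(x)\n",
    "        loss = criterion(x_ae, x)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        total += loss.item() * x.size(0)\n",
    "    if epoch % 10 == 0:\n",
    "        print(f'epoch {epoch}: avg MSE {total / len(dataset):.4f}')\n"
   ]
  },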
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "R2WHmrhDi1sC",
    "outputId": "c679cfe6-58e1-435e-8eba-5e4ae15c27c2",
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# (2.c) The latent variable for all samples calculated by runing the model in inference mode\n",
    "model_ae.eval()\n",
    "with torch.inference_mode():\n",
    "    \n",
    "# You can use this ploting function (or a different of your liking)\n",
    "def plot_z(z, labels):\n",
    "    plt.figure(figsize=(6, 5))\n",
    "    \n",
    "    n_regions = labels.max() + 1\n",
    "    for k in range(n_regions):\n",
    "        mask = (labels.detach().cpu().numpy() == k)\n",
    "        plt.scatter(z[mask, 0], z[mask, 1], label=f\"class {k}\", alpha=0.7, s=20)\n",
    "\n",
    "        \n",
    "# Plot the latent variable z[N,2]\n"
   ]
  },
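  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# A minimal sketch for (2.d), assuming model_ae and plot_z from above: encode all\n",
    "# cells in inference mode and plot the 2-D latent variable colored by region.\n",
    "\n",
    "model_ae.eval()\n",
    "with torch.inference_mode():\n",
    "    z_all = model_ae.encode(counts)\n",
    "\n",
    "plot_z(z_all.detach().cpu().numpy(), labels)\n"
   ]
  },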
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 525
    },
    "id": "iF-CQC1ulN87",
    "outputId": "f30692b1-d3b6-475f-f319-afe954ad234f",
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# (2.d)\n",
    "def PCA(X,y):\n",
    "    torch.manual_seed(0)\n",
    "\n",
    "    N, D = X.shape\n",
    "    num_classes = int(y.max().item()) + 1\n",
    "    print(\"N D classes\", N, G, num_classes)\n",
    "\n",
    "    # 1. Center the data\n",
    "    X_mean = X.mean(dim=0, keepdim=True)\n",
    "    X_centered = X - X_mean\n",
    "\n",
    "    # 2. Compute covariance matrix (D x D)\n",
    "    cov = X_centered.t().mm(X_centered) / (N - 1)\n",
    "\n",
    "    # 3. Eigen-decomposition\n",
    "    eigenvalues, eigenvectors = torch.linalg.eigh(cov)   # eigh since cov is symmetric\n",
    "\n",
    "    # 4. Take top 2 principal components (largest eigenvalues)\n",
    "    # eigenvalues are in ascending order -> take last 2 columns\n",
    "    pc_vectors = eigenvectors[:, -2:]    # shape [D, 2]\n",
    "\n",
    "    # 5. Project data onto PCs -> X_pca [N, 2]\n",
    "    X_pca = X_centered.mm(pc_vectors)\n",
    "\n",
    "    # 6. Plot with colors for classes\n",
    "    X_pca = X_pca.detach().cpu().numpy()\n",
    "    y_np = y.detach().cpu().numpy()\n",
    "\n",
    "    plt.figure(figsize=(6, 5))\n",
    "    for k in range(num_classes):\n",
    "        mask = (y_np == k)\n",
    "        plt.scatter(X_pca[mask, 0], X_pca[mask, 1], label=f\"class {k}\", alpha=0.7, s=20)\n",
    "\n",
    "    plt.xlabel(\"PC1\")\n",
    "    plt.ylabel(\"PC2\")\n",
    "    plt.legend()\n",
    "    plt.title(\"PCA (2D) of X, colored by class\")\n",
    "    plt.tight_layout()\n",
    "    plt.show()\n",
    "    \n",
    "\n"
   ]
  },
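  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# A minimal usage sketch for (2.e), assuming counts and labels from above: project\n",
    "# the raw counts onto the top two PCs and compare with the AE latent space.\n",
    "\n",
    "PCA(counts, labels)\n"
   ]
  },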
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "Y2h2wzVn2dqO"
   },
   "outputs": [],
   "source": [
    "# (3) Analysis of the same expression data using a VAE.\n",
    "#     While both AE and VAE can do classification by analyzing the latent variable, \n",
    "#     one thing that the VAE allows us to do is to test differential gene expression.\n",
    "#\n",
    "#     If we find, two types of cells that appear to be differntially expressed, \n",
    "#     which actual genes are resposible ofr that difference? \n",
    "#     We are exploring that aspect in this section\n",
    "# \n",
    "#      (3.a) Build the forward loop.\n",
    "#            you can use the code from b5_lecture_code_scVI.ipynb, \n",
    "#            but to make it similar to the AE above, please add one more layer to the decoder, \n",
    "#            and set the encoder/decoder parameters similar to those of the AE. \n",
    "#            In particular, the latent variable z with dimension 2.\n",
    "#\n",
    "#       (3.b) Build the training loop and train.\n",
    "#             Caution: you may need to add a BachNorm layer  to the encoder, as\n",
    "#                    nn.BatchNorm1d(n_hidden1),\n",
    "#             Justify if you do.\n",
    "#\n",
    "#       (3.c) Inspect the result for the latent variable and compare to those of the AE.\n",
    "#\n",
    "#       (3.d) Use the PCA code to compara the output of the decoder to the inputs.\n",
    "#\n",
    "#       (3.e) Differential gene expression. Select two of the region that are classified as differnt,\n",
    "#             and estimate the genes responsible for that overall cell type difference.\n",
    "#    \n",
    "    \n",
    "# (3.a) The VAE model\n",
    "#\n",
    " "
   ]
  },
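  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# A minimal sketch for (3.a), mirroring the AE above. This is a plain Gaussian-\n",
    "# encoder VAE whose decoder outputs a positive mean expression mu_x (via softplus);\n",
    "# the full ZINB decoder with dispersion theta from b5_lecture_code_scVI.ipynb\n",
    "# would replace the output layer.\n",
    "\n",
    "class GeneExprVAE(nn.Module):\n",
    "    def __init__(self, input_dim, h1=512, h2=256, latent_dim=2):\n",
    "        super().__init__()\n",
    "        # ----- Encoder: input -> h1 -> h2 -> (mu_z, logvar_z) -----\n",
    "        self.enc_fc1 = nn.Linear(input_dim, h1)\n",
    "        self.enc_bn1 = nn.BatchNorm1d(h1)  # see (3.b): may stabilize training\n",
    "        self.enc_fc2 = nn.Linear(h1, h2)\n",
    "        self.enc_mu = nn.Linear(h2, latent_dim)\n",
    "        self.enc_logvar = nn.Linear(h2, latent_dim)\n",
    "\n",
    "        # ----- Decoder: latent -> h2 -> h1 -> mu_x -----\n",
    "        self.dec_fc1 = nn.Linear(latent_dim, h2)\n",
    "        self.dec_fc2 = nn.Linear(h2, h1)\n",
    "        self.dec_out = nn.Linear(h1, input_dim)\n",
    "\n",
    "    def encode(self, x):\n",
    "        h = F.relu(self.enc_bn1(self.enc_fc1(x)))\n",
    "        h = F.relu(self.enc_fc2(h))\n",
    "        return self.enc_mu(h), self.enc_logvar(h)\n",
    "\n",
    "    def reparameterize(self, mu_z, logvar_z):\n",
    "        std = torch.exp(0.5 * logvar_z)\n",
    "        eps = torch.randn_like(std)\n",
    "        return mu_z + eps * std  # z ~ N(mu_z, std^2), differentiable in mu_z and std\n",
    "\n",
    "    def decode(self, z):\n",
    "        h = F.relu(self.dec_fc1(z))\n",
    "        h = F.relu(self.dec_fc2(h))\n",
    "        return F.softplus(self.dec_out(h))  # positive mean expression mu_x\n",
    "\n",
    "    def forward(self, x):\n",
    "        mu_z, logvar_z = self.encode(x)\n",
    "        z = self.reparameterize(mu_z, logvar_z)\n",
    "        mu_x = self.decode(z)\n",
    "        return mu_x, mu_z, logvar_z, z\n"
   ]
  },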
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# (3.b) Training loop\n",
    "from torch.utils.data import DataLoader, TensorDataset\n",
    "\n",
    "\n"
   ]
  },
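  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# A minimal training sketch for (3.b), assuming the GeneExprVAE sketch above and\n",
    "# counts/labels/G from earlier. It uses an MSE reconstruction term on mu_x plus\n",
    "# the analytic KL to N(0, I); the lecture's ZINB likelihood would replace the MSE.\n",
    "# lr, epochs, and batch_size are placeholders to tune and justify.\n",
    "\n",
    "dataset = TensorDataset(counts, labels)\n",
    "loader = DataLoader(dataset, batch_size=64, shuffle=True)\n",
    "\n",
    "model = GeneExprVAE(input_dim=G, latent_dim=2)\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)\n",
    "\n",
    "model.train()\n",
    "for epoch in range(50):\n",
    "    total = 0.0\n",
    "    for x, _ in loader:\n",
    "        optimizer.zero_grad()\n",
    "        mu_x, mu_z, logvar_z, z = model(x)\n",
    "        recon = F.mse_loss(mu_x, x)\n",
    "        kl = -0.5 * torch.mean(1 + logvar_z - mu_z.pow(2) - logvar_z.exp())\n",
    "        loss = recon + kl\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        total += loss.item() * x.size(0)\n",
    "    if epoch % 10 == 0:\n",
    "        print(f'epoch {epoch}: avg loss {total / len(dataset):.4f}')\n"
   ]
  },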
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# (3.c) Latent variables\n",
    "model.eval()\n",
    "with torch.inference_mode():\n",
    "    \n",
    "# Ploting both z and mu_z\n",
    "\n",
    "# (3.d) Compare the actual counts to the mu_x values\n",
    "\n",
    "\n"
   ]
  },
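  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# A minimal sketch for (3.c)-(3.d), assuming model, plot_z, and PCA from above:\n",
    "# run the VAE in inference mode, plot mu_z and z colored by region, and compare\n",
    "# the decoder means mu_x to the inputs with the PCA helper.\n",
    "\n",
    "model.eval()\n",
    "with torch.inference_mode():\n",
    "    mu_x, mu_z, logvar_z, z = model(counts)\n",
    "\n",
    "plot_z(mu_z.detach().cpu().numpy(), labels)  # latent means\n",
    "plot_z(z.detach().cpu().numpy(), labels)     # sampled latents\n",
    "\n",
    "PCA(mu_x, labels)  # (3.d) decoder output; compare with PCA(counts, labels)\n"
   ]
  },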
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# (3,e) In this section you could follow the paper and assess differential gene expression by \n",
    "# using the zinb with parameters mu_x and theta to generate a lot samples for a given gene in both cell types.\n",
    "#\n",
    "# More simply, you can use the mu_x[n,g] which are the expression mean values.\n",
    "#\n",
    "# Let's assume you select cell_type a and cell type b\n",
    "#\n",
    "# For each gene, we are going to consider two different hypotheses:\n",
    "# \n",
    "#  H1: g is differentially expressed in type_a and type_b\n",
    "#  H0: there in no different in the expression of g between type_a and type_b\n",
    "#\n",
    "#\n",
    "#  We quantify H1 such that for any two cells of type_a and type_b, \n",
    "#\n",
    "#        h1 = # of cases such that |mu_x(a) - mu_x(b)| > score_thresh\n",
    "#\n",
    "#  We quantify H0 such that for any two cells of type_a and type_b, \n",
    "#\n",
    "#        h2 = # of cases such that |mu_x(a) - mu_x(b)| <= score_thresh\n",
    "#        \n",
    "#\n",
    "# then we select as differentially expressed genes such that h1/h0 > alpha\n",
    "#\n",
    "# (3.e) How would you select reasonable values of score_thresh and alpha?\n",
    "#\n",
    "#\n",
    "mask0 = (labels == 0)\n",
    "mask1 = (labels == 1)\n",
    "mask2 = (labels == 2)\n",
    "mask3 = (labels == 3)\n",
    "\n",
    "mu_x_0 = mu_x[mask0]  # cells with label 0, shape: (N1, 150)\n",
    "mu_x_1 = mu_x[mask1]  # cells with label 1, shape: (N2, 150)\n",
    "mu_x_2 = mu_x[mask2]  # cells with label 2, shape: (N3, 150)\n",
    "mu_x_3 = mu_x[mask3]  # cells with label 3, shape: (N4, 150)\n",
    "print(\"mu_x_0\", mu_x_0.shape)\n",
    "print(\"mu_x_1\", mu_x_1.shape)\n",
    "print(\"mu_x_2\", mu_x_2.shape)\n",
    "print(\"mu_x_3\", mu_x_3.shape)\n",
    "\n",
    "\n",
    "def differential(mu_x_a, mu_x_b, alpha, score_thresh, label):\n",
    "    Na = mu_x_a.shape[0]\n",
    "    Nb = mu_x_b.shape[0]\n",
    "    G = mu_x_a.shape[1]\n",
    "    if (mu_x_b.shape[1] != G):\n",
    "        raise ValueError(\"G must be\", G, \"but I found\", mu_x_b.shape[1])\n",
    "\n",
    "    pH1_H0 = torch.full((G,), 0.0)\n",
    "    \n",
    "    # For each gene g, do a comparison between all Na to all Nb cells,\n",
    "    # for each decide if h1 wins (h1+1) or h0 wins (h0+1), save the counts, and\n",
    "    # calculate \n",
    "    pH1_H0 = h1/h0\n",
    "    mask_diff  = pH1_H0 > alpha\n",
    "  \n",
    "    \n",
    "    count_diff = mask_diff.sum().item()     # number of genes where H1 > H0\n",
    "    idx_diff   = torch.nonzero(mask_diff, as_tuple=True)[0]  # tensor of indices\n",
    "    print(count_diff, \"genes different expression for\", label)\n",
    "    print(idx_diff)\n",
    "\n",
    "    \n",
    "alpha = \n",
    "diff_thresh =\n",
    "differential(mu_x_0, mu_x_1, alpha, diff_thresh, \"0-1\")\n",
    "differential(mu_x_0, mu_x_2, alpha, diff_thresh, \"0-2\")\n",
    "differential(mu_x_0, mu_x_3, alpha, diff_thresh, \"0-3\")\n",
    "differential(mu_x_1, mu_x_2, alpha, diff_thresh, \"1-2\")\n",
    "differential(mu_x_1, mu_x_3, alpha, diff_thresh, \"1-3\")\n",
    "differential(mu_x_2, mu_x_3, alpha, diff_thresh, \"2-3\")\n",
    "\n"
   ]
  },
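  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# A minimal vectorized sketch of the pairwise comparison inside differential(),\n",
    "# assuming mu_x_a and mu_x_b are tensors of shape (Na, G) and (Nb, G). The alpha\n",
    "# and score_thresh values in the usage line are placeholders, not prescribed answers.\n",
    "\n",
    "def differential_sketch(mu_x_a, mu_x_b, alpha, score_thresh, label):\n",
    "    Na, G = mu_x_a.shape\n",
    "    Nb = mu_x_b.shape[0]\n",
    "\n",
    "    # |mu_x(a) - mu_x(b)| for every (a, b) pair and every gene: shape (Na, Nb, G).\n",
    "    # This can be memory-heavy for large Na*Nb; loop over genes if needed.\n",
    "    diff = (mu_x_a[:, None, :] - mu_x_b[None, :, :]).abs()\n",
    "\n",
    "    h1 = (diff > score_thresh).sum(dim=(0, 1)).float()  # pairs supporting H1, per gene\n",
    "    h0 = (Na * Nb) - h1                                 # remaining pairs support H0\n",
    "    pH1_H0 = h1 / h0.clamp(min=1.0)                     # guard against division by zero\n",
    "\n",
    "    mask_diff = pH1_H0 > alpha\n",
    "    idx_diff = torch.nonzero(mask_diff, as_tuple=True)[0]\n",
    "    print(int(mask_diff.sum()), 'genes differentially expressed for', label)\n",
    "    print(idx_diff)\n",
    "\n",
    "# Example usage (placeholder values):\n",
    "# differential_sketch(mu_x_0, mu_x_1, alpha=1.0, score_thresh=0.5, label='0-1')\n"
   ]
  },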
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "colab": {
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
