{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "JLsXiA6l5T91"
   },
   "source": [
    "# **b3 Transformers**\n",
    "## **homework — ANSWER KEY**\n",
    "\n",
    "### In this b3_homework, we are going to build a self-attention mechanism from scratch.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "bdr6LQIg3CRu"
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import torch\n",
    "import matplotlib.pyplot as plt\n",
    "import math"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "mo5mpQc-4Tcv"
   },
   "source": [
    "The self-attention mechanism described in Figure 3, maps an input X[L,D] into an output O[L,dv]. In this homework we are assuming dk = dv = D.\n",
    "\n",
    "Our inputs are going first sequences of length L with embedding dimension D, and later multiple sequence alignments (MSAs). "
   ]
  },
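  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Concretely, each attention block below computes $Q = XW^Q + b^Q$, $K = XW^K + b^K$, $V = XW^V + b^V$, and the output $O = \\mathrm{softmax}\\left(QK^{\\top} / \\sqrt{d_k}\\right)V$, where the softmax is taken over the key dimension."
   ]
  },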
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "WLQDcF4U5ZPq"
   },
   "outputs": [],
   "source": [
    "# Set seed, this way you will get the same random values every time you run it\n",
    "np.random.seed(3)\n",
    "\n",
    "# We start with the attention mechanism for sequences (1 dimension)\n",
    "#\n",
    "# Input dimensions\n",
    "L = 25 # sequence dimension\n",
    "D = 4 # embedding dimension per sequence position\n",
    "\n",
    "# Create a random input sequence\n",
    "# X[L,D]\n",
    "# the actual values are sampled from a N(0,1) normal distribution.\n",
    "x = []\n",
    "for l in range(L):\n",
    "  x.append(np.random.normal(size=(D)))\n",
    "\n",
    "x = np.array(x)\n",
    "print(x.shape)\n",
    "# Print out the inputs\n",
    "print(\"input lenght\", len(x))\n",
    "print(\"input embedding of position 0\", x[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "dp5IDPt27s2e"
   },
   "outputs": [],
   "source": [
    "# We also need to create an initialize the self-attention weights\n",
    "#\n",
    "# setting dk=dv=D, we need\n",
    "#\n",
    "# W^Q[D,D]\n",
    "# W^K[D,D]\n",
    "# W^V[D,D]\n",
    "#\n",
    "# and the biases\n",
    "#\n",
    "#b^Q[D]\n",
    "#b^K[D]\n",
    "#b^V[D]\n",
    "#\n",
    "WQ = np.random.normal(size=(D,D))\n",
    "WK = np.random.normal(size=(D,D))\n",
    "WV = np.random.normal(size=(D,D))\n",
    "bQ = np.random.normal(size=(D))\n",
    "bK = np.random.normal(size=(D))\n",
    "bV = np.random.normal(size=(D))\n",
    "\n",
    "print(\"Weigths shape\", WQ.shape)\n",
    "print(\"biases shape\", bQ.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "dYGEP-H789H4"
   },
   "outputs": [],
   "source": [
    "# ANSWER KEY — Section 1: Q/K/V computation\n",
    "#\n",
    "# q[L,D] = x[L,D] @ WQ[D,D] + bQ[D]\n",
    "# k[L,D] = x[L,D] @ WK[D,D] + bK[D]\n",
    "# v[L,D] = x[L,D] @ WV[D,D] + bV[D]\n",
    "#\n",
    "# NOTE: The original starter code had a typo — the third line\n",
    "#       assigned q instead of v. Students should catch this.\n",
    "\n",
    "q = x @ WQ + bQ\n",
    "k = x @ WK + bK\n",
    "v = x @ WV + bV  # <-- fixed from starter code typo\n",
    "\n",
    "print(\"queries shape\", q.shape)   # (25, 4)\n",
    "print(\"keys    shape\", k.shape)   # (25, 4)\n",
    "print(\"values  shape\", v.shape)   # (25, 4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "yf5TIUVOA_m2"
   },
   "outputs": [],
   "source": [
    "# ANSWER KEY — Section 2: Single-head self-attention\n",
    "\n",
    "# (1a) Transpose K\n",
    "k_T = k.T\n",
    "print(\"k_T shape\", k_T.shape)  # (4, 25)\n",
    "\n",
    "# (1b) Scaled dot-product score: Q @ K^T / sqrt(D)\n",
    "score = (q @ k_T) / np.sqrt(D)\n",
    "print(\"score shape\", score.shape)  # (25, 25)\n",
    "\n",
    "# (2) Softmax over the key dimension (last axis)\n",
    "attn = torch.softmax(torch.tensor(score), dim=-1).numpy()\n",
    "print(\"attn shape\", attn.shape)   # (25, 25)\n",
    "print(\"attn row 0 sums to\", attn[0].sum())  # should be ~1.0\n",
    "\n",
    "# (3) Output: attention-weighted values\n",
    "out = attn @ v\n",
    "print(\"out shape\", out.shape)  # (25, 4)"
   ]
  },
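  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional sanity check: re-derive the softmax in plain NumPy and confirm it\n",
    "# matches the torch.softmax result used above (uses score, attn, v and out\n",
    "# from the previous cell).\n",
    "attn_np = np.exp(score - score.max(axis=-1, keepdims=True))  # subtract the row max for numerical stability\n",
    "attn_np = attn_np / attn_np.sum(axis=-1, keepdims=True)\n",
    "print(\"numpy softmax matches torch:\", np.allclose(attn_np, attn))\n",
    "print(\"numpy output matches torch:\", np.allclose(attn_np @ v, out))"
   ]
  },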
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "vfnrtTiiW0KP"
   },
   "outputs": [],
   "source": [
    "## ANSWER KEY — Section 3: Multi-head attention (2 independent heads)\n",
    "\n",
    "# Input dimensions\n",
    "L = 25 # sequence dimension\n",
    "D = 4  # embedding dimension per sequence position\n",
    "\n",
    "# Create a random input sequence\n",
    "x = []\n",
    "for l in range(L):\n",
    "  x.append(np.random.normal(size=(D)))\n",
    "x = np.array(x)\n",
    "print(x.shape)\n",
    "print(\"input lenght\", len(x))\n",
    "print(\"input embedding of position 0\", x[0])\n",
    "\n",
    "# Added dimensions to deal with 2 heads\n",
    "H = 2             # number of heads\n",
    "D_H = int(D / 2)  # dimension per head\n",
    "print(\"D_H\", D_H)  # 2\n",
    "\n",
    "# ----- Parameter initialization -----\n",
    "\n",
    "# Head 1: W's are [D, D_H], biases are [D_H]\n",
    "WQ_1 = np.random.normal(size=(D, D_H))\n",
    "WK_1 = np.random.normal(size=(D, D_H))\n",
    "WV_1 = np.random.normal(size=(D, D_H))\n",
    "bQ_1 = np.random.normal(size=(D_H,))\n",
    "bK_1 = np.random.normal(size=(D_H,))\n",
    "bV_1 = np.random.normal(size=(D_H,))\n",
    "\n",
    "# Head 2\n",
    "WQ_2 = np.random.normal(size=(D, D_H))\n",
    "WK_2 = np.random.normal(size=(D, D_H))\n",
    "WV_2 = np.random.normal(size=(D, D_H))\n",
    "bQ_2 = np.random.normal(size=(D_H,))\n",
    "bK_2 = np.random.normal(size=(D_H,))\n",
    "bV_2 = np.random.normal(size=(D_H,))\n",
    "\n",
    "# WC: concatenation weight\n",
    "# After concatenating H heads of dim D_H each, we get [L, H*D_H] = [L, D]\n",
    "# WC maps [D] -> [D], so shape is [D, D]\n",
    "WC = np.random.normal(size=(D, D))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "ZZS0gL0AaC7G"
   },
   "outputs": [],
   "source": [
    "# ANSWER KEY — Section 3 (cont): Multi-head forward pass\n",
    "\n",
    "# (1) Build Q/K/V per head — shape [L, D_H]\n",
    "q_1 = x @ WQ_1 + bQ_1\n",
    "k_1 = x @ WK_1 + bK_1\n",
    "v_1 = x @ WV_1 + bV_1\n",
    "\n",
    "q_2 = x @ WQ_2 + bQ_2\n",
    "k_2 = x @ WK_2 + bK_2\n",
    "v_2 = x @ WV_2 + bV_2\n",
    "\n",
    "# (2) Attention scores — scale by sqrt(D_H), NOT sqrt(D)\n",
    "score_1 = (q_1 @ k_1.T) / np.sqrt(D_H)  # [L, L]\n",
    "score_2 = (q_2 @ k_2.T) / np.sqrt(D_H)  # [L, L]\n",
    "\n",
    "# (3) Softmax\n",
    "attn_1 = torch.softmax(torch.tensor(score_1), dim=-1).numpy()\n",
    "attn_2 = torch.softmax(torch.tensor(score_2), dim=-1).numpy()\n",
    "\n",
    "# (4) Weighted values per head — [L, D_H]\n",
    "out_1 = attn_1 @ v_1\n",
    "out_2 = attn_2 @ v_2\n",
    "\n",
    "# (5) Concatenate heads along embedding dimension — [L, D]\n",
    "out = np.concatenate([out_1, out_2], axis=-1)\n",
    "print(\"out1 shape\", out_1.shape)  # (25, 2)\n",
    "print(\"out shape\", out.shape)     # (25, 4)\n",
    "\n",
    "# (6) Final linear projection\n",
    "print(\"WC shape\", WC.shape)       # (4, 4)\n",
    "out = out @ WC\n",
    "print(\"out shape\", out.shape)     # (25, 4)"
   ]
  },
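  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional visualization: matplotlib is imported at the top, so we can look at\n",
    "# the two heads' [L, L] attention maps side by side (uses attn_1 and attn_2\n",
    "# from the cell above).\n",
    "fig, axes = plt.subplots(1, 2, figsize=(8, 4))\n",
    "axes[0].imshow(attn_1)\n",
    "axes[0].set_title(\"head 1 attention [L, L]\")\n",
    "axes[1].imshow(attn_2)\n",
    "axes[1].set_title(\"head 2 attention [L, L]\")\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  },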
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ANSWER KEY — Section 4: Multi-head attention, all heads via reshape\n",
    "\n",
    "# Input\n",
    "L = 60  # sequence dimension\n",
    "D = 20  # embedding dimension per sequence position\n",
    "x = []\n",
    "for l in range(L):\n",
    "  x.append(np.random.normal(size=(D)))\n",
    "x = np.array(x)\n",
    "print(\"X\", x.shape)  # (60, 20)\n",
    "\n",
    "# (1) Head dimension\n",
    "H   = 4\n",
    "D_H = int(D / H)  # = 5\n",
    "\n",
    "# (2) Combined parameters for all heads\n",
    "#     W's are [D, D] because they jointly produce all H heads\n",
    "#     biases are [D]\n",
    "WQ = np.random.normal(size=(D, D))\n",
    "WK = np.random.normal(size=(D, D))\n",
    "WV = np.random.normal(size=(D, D))\n",
    "bQ = np.random.normal(size=(D,))\n",
    "bK = np.random.normal(size=(D,))\n",
    "bV = np.random.normal(size=(D,))\n",
    "\n",
    "# (3) Q/K/V as [L, D]\n",
    "q = x @ WQ + bQ\n",
    "k = x @ WK + bK\n",
    "v = x @ WV + bV\n",
    "\n",
    "# (4) Reshape to [L, H, D_H]\n",
    "q = q.reshape(L, H, D_H)\n",
    "k = k.reshape(L, H, D_H)\n",
    "v = v.reshape(L, H, D_H)\n",
    "print(\"q\", q.shape)  # (60, 4, 5)\n",
    "\n",
    "# (5) Transpose to [H, L, D_H] — move head dim to front for batched matmul\n",
    "q = np.transpose(q, (1, 0, 2))\n",
    "k = np.transpose(k, (1, 0, 2))\n",
    "v = np.transpose(v, (1, 0, 2))\n",
    "print(\"q\", q.shape)  # (4, 60, 5)\n",
    "\n",
    "# (6) Attention scores [H, L, L]\n",
    "#     k transposed on last two dims: [H, D_H, L]\n",
    "score = (q @ np.transpose(k, (0, 2, 1))) / np.sqrt(D_H)\n",
    "\n",
    "# (7) Softmax [H, L, L]\n",
    "attn = torch.softmax(torch.tensor(score), dim=-1).numpy()\n",
    "\n",
    "# (8) Output [H, L, D_H]\n",
    "out = attn @ v\n",
    "print(\"out\", out.shape)  # (4, 60, 5)\n",
    "\n",
    "# (9) Reshape back to [L, D]\n",
    "#     First transpose to [L, H, D_H], then flatten last two dims\n",
    "out = np.transpose(out, (1, 0, 2))   # [L, H, D_H]\n",
    "out = out.reshape(L, D)               # [L, D]\n",
    "print(\"out\", out.shape)  # (60, 20)"
   ]
  },
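  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional check of the reshape trick (a sketch, beyond what the homework asks):\n",
    "# it should match running each head separately, where head h uses the weight\n",
    "# columns WQ[:, h*D_H:(h+1)*D_H] and the matching bias slice.\n",
    "outs = []\n",
    "for h in range(H):\n",
    "  sl = slice(h * D_H, (h + 1) * D_H)\n",
    "  q_h = x @ WQ[:, sl] + bQ[sl]\n",
    "  k_h = x @ WK[:, sl] + bK[sl]\n",
    "  v_h = x @ WV[:, sl] + bV[sl]\n",
    "  a_h = torch.softmax(torch.tensor((q_h @ k_h.T) / np.sqrt(D_H)), dim=-1).numpy()\n",
    "  outs.append(a_h @ v_h)  # [L, D_H]\n",
    "out_loop = np.concatenate(outs, axis=-1)  # [L, D]\n",
    "print(\"per-head loop matches reshape version:\", np.allclose(out_loop, out))"
   ]
  },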
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# ANSWER KEY — Section 5: MSA attention (row and column)\n",
    "\n",
    "B = 10   # batch size\n",
    "S = 50   # number of sequences\n",
    "L = 100  # alignment length\n",
    "D = 80   # embedding dimension\n",
    "msa = torch.rand(B, S, L, D)\n",
    "\n",
    "########################################\n",
    "# ATTENTION BY ROWS\n",
    "########################################\n",
    "\n",
    "# (1) Random weights (torch tensors)\n",
    "WQ = torch.randn(D, D)\n",
    "WK = torch.randn(D, D)\n",
    "WV = torch.randn(D, D)\n",
    "bQ = torch.randn(D)\n",
    "bK = torch.randn(D)\n",
    "bV = torch.randn(D)\n",
    "\n",
    "# (2) Reshape MSA to [B*S, L, D] — each row is an independent sequence\n",
    "msa_row = msa.reshape(B * S, L, D)\n",
    "print(\"msa\", msa_row.shape)  # (500, 100, 80)\n",
    "\n",
    "# Q/K/V via batched matmul: [B*S, L, D] @ [D, D] + [D]\n",
    "q = msa_row @ WQ + bQ\n",
    "k = msa_row @ WK + bK\n",
    "v = msa_row @ WV + bV\n",
    "print(\"q by row\", q.shape)  # (500, 100, 80)\n",
    "\n",
    "# (3) Score [B*S, L, L]\n",
    "score = (q @ k.transpose(-2, -1)) / math.sqrt(D)\n",
    "\n",
    "# (4) Attention [B*S, L, L]\n",
    "attn = torch.softmax(score, dim=-1)\n",
    "\n",
    "# (5) Output [B*S, L, D]\n",
    "out = attn @ v\n",
    "print('out by row', out.shape)  # (500, 100, 80)\n",
    "\n",
    "# (6) Reshape back to [B, S, L, D]\n",
    "out = out.reshape(B, S, L, D)\n",
    "print('out by row', out.shape)  # (10, 50, 100, 80)\n",
    "\n",
    "########################################\n",
    "# ATTENTION BY COLUMNS\n",
    "########################################\n",
    "\n",
    "# (1) New random weights\n",
    "WQ = torch.randn(D, D)\n",
    "WK = torch.randn(D, D)\n",
    "WV = torch.randn(D, D)\n",
    "bQ = torch.randn(D)\n",
    "bK = torch.randn(D)\n",
    "bV = torch.randn(D)\n",
    "\n",
    "# (2) Reshape MSA to [B*L, S, D]\n",
    "#     We need to swap the S and L axes so that for each position l,\n",
    "#     we attend across all S sequences at that position.\n",
    "#\n",
    "#     IMPORTANT: must use permute (not reshape) to swap S <-> L,\n",
    "#     otherwise data is silently scrambled.\n",
    "msa_col = msa.permute(0, 2, 1, 3).reshape(B * L, S, D)\n",
    "print(\"msa\", msa_col.shape)  # (1000, 50, 80)\n",
    "\n",
    "# Q/K/V\n",
    "q = msa_col @ WQ + bQ\n",
    "k = msa_col @ WK + bK\n",
    "v = msa_col @ WV + bV\n",
    "print(\"q by col\", q.shape)  # (1000, 50, 80)\n",
    "\n",
    "# (3) Score [B*L, S, S]\n",
    "score = (q @ k.transpose(-2, -1)) / math.sqrt(D)\n",
    "\n",
    "# (4) Attention [B*L, S, S]\n",
    "attn = torch.softmax(score, dim=-1)\n",
    "\n",
    "# (5) Output [B*L, S, D]\n",
    "out = attn @ v\n",
    "print(\"out by col\", out.shape)  # (1000, 50, 80)\n",
    "\n",
    "# (6) Reshape back to [B, S, L, D]\n",
    "#     Currently [B*L, S, D] -> first reshape to [B, L, S, D]\n",
    "#     then permute back to [B, S, L, D]\n",
    "out = out.reshape(B, L, S, D)\n",
    "out = out.permute(0, 2, 1, 3)\n",
    "print(\"out by col\", out.shape)  # (10, 50, 100, 80)"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "name": "python",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
