# **b4 LLMs **
## Transformer based encoders, decoders, encoder-decoders






In [None]:
## MultiHeadAttention
## Transformer block

import torch
import torch.nn as nn

# multi‑head self-attention
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    The input is projected to queries, keys and values of width ``d_model``,
    split into ``num_heads`` heads of width ``d_head = d_model // num_heads``,
    attended per head, concatenated and passed through a final linear layer.

    Args:
        d_model: embedding width D; must be divisible by ``num_heads``.
        num_heads: number of attention heads H.
    """

    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
        self.d_model   = d_model
        self.num_heads = num_heads
        self.d_head    = self.d_model // self.num_heads  # d_head = D / H

        # All four projections map d_model -> d_model.
        self.Wq = nn.Linear(d_model, d_model)
        self.Wk = nn.Linear(d_model, d_model)
        self.Wv = nn.Linear(d_model, d_model)
        self.Wo = nn.Linear(d_model, d_model)

    def forward(self,
                x,         # [B, L, D]
                mask=None  # [L, L] with -inf on masked
    ):
        """Apply self-attention to ``x``.

        Args:
            x: input of shape [B, L, D].
            mask: optional attention mask, broadcastable to [B, H, L, L].
                A floating-point mask is *added* to the attention scores
                (0 = keep, -inf = masked). A boolean mask marks the
                positions that may be attended to with ``True``.

        Returns:
            Tensor of shape [B, L, D].
        """
        B, L, D = x.shape

        # Split each projection into heads: [B, L, D] -> [B, H, L, d_head]
        q = self.Wq(x).reshape(B, L, self.num_heads, self.d_head).transpose(1, 2)
        k = self.Wk(x).reshape(B, L, self.num_heads, self.d_head).transpose(1, 2)
        v = self.Wv(x).reshape(B, L, self.num_heads, self.d_head).transpose(1, 2)

        # q                  [B, H, L, d_head]
        # k.transpose(-2,-1) [B, H, d_head, L]
        # scores             [B, H, L, L]
        scores = q @ k.transpose(-2, -1) / (self.d_head ** 0.5)
        if mask is not None:
            if mask.dtype == torch.bool:
                # Boolean mask: True = attend, False = masked.
                scores = scores.masked_fill(~mask, float('-inf'))
            else:
                # Additive mask: masked entries already hold -inf, so adding
                # it drives their softmax weight to zero and leaves the
                # other scores unchanged.
                scores = scores + mask.to(scores.dtype)

        attn = torch.softmax(scores, dim=-1)

        # print
        print("Attention head dimensions [B, H, L, L]", attn.shape)

        out = attn @ v  # [B, H, L, d_head]

        # Concatenate the heads:
        # out.transpose(1, 2) [B, L, H, d_head] -> [B, L, H*d_head]
        out = out.transpose(1, 2).reshape(B, L, D)

        # Final output projection.
        out = self.Wo(out)

        return out



# simple Transformer block = attention -> FF
class TransformerBlock(nn.Module):
    """One Transformer layer: self-attention followed by a feed-forward net.

    Each sub-layer receives a layer-normalised copy of its input and its
    output is added back onto the un-normalised input (residual connection).

    Args:
        d_model: embedding width of the input and output.
        num_heads: number of attention heads.
        d_ff: hidden width of the feed-forward sub-layer.
    """

    def __init__(self, d_model, num_heads, d_ff):
        super().__init__()

        # Sub-layer 1: multi-head self-attention.
        self.attn = MultiHeadAttention(d_model, num_heads)

        # Sub-layer 2: position-wise feed-forward, d_model -> d_ff -> d_model.
        widen  = nn.Linear(d_model, d_ff)
        narrow = nn.Linear(d_ff, d_model)
        self.ff = nn.Sequential(widen, nn.ReLU(), narrow)

        # One LayerNorm per sub-layer.
        self.norm_att = nn.LayerNorm(d_model)
        self.norm_ff  = nn.LayerNorm(d_model)

    def forward(self, x, mask=None):
        """Return the block output; ``mask`` is forwarded to the attention."""
        attended = self.attn(self.norm_att(x), mask)
        x = x + attended

        fed_forward = self.ff(self.norm_ff(x))
        return x + fed_forward



In [None]:
# Example usage: run one TransformerBlock on a random batch
#
B = 2   # batch size
L = 10  # sequence length
D = 90  # embedding dimension
d_diff  = 512  # hidden width of the feed-forward sub-layer (passed as d_ff)
n_heads = 2    # number of attention heads (D must be divisible by it)

# random sequences of length L and embedding dimension D
x = torch.randn(B, L, D)  # random input [B, L, D]

# This TransformerBlock outputs another sequence with the same length and
# embedding dimension as its input; we usually refer to this model as an encoder
model = TransformerBlock(d_model=D, num_heads=n_heads, d_ff=d_diff)
x_out = model(x)

print("input  ", x.shape)      # (B, L, D)
print("output ", x_out.shape)  # (B, L, D)


In [None]:
# ENCODER
class TransformerEncoder(nn.Module):
    """Encoder-only Transformer with a masked-language-modelling head.

    Args:
        abc_size: size of the token alphabet (vocabulary).
        d_model: embedding width.
        K: number of stacked TransformerBlock layers.
        n_heads: attention heads per layer.
        d_ff: hidden width of each layer's feed-forward sub-layer.
        max_len: number of rows in the learned positional-embedding table.
    """

    def __init__(self, abc_size, d_model, K, n_heads, d_ff, max_len):
        super().__init__()

        # Learned token and position embeddings, both of width d_model.
        self.token_emb = nn.Embedding(abc_size, d_model)
        self.pos_emb   = nn.Embedding(max_len,  d_model)

        # Use the constructor's own d_ff rather than a notebook-level global,
        # so the requested feed-forward width is actually honoured.
        self.layers = nn.ModuleList([
            TransformerBlock(d_model=d_model, num_heads=n_heads, d_ff=d_ff)
            for _ in range(K)
        ])

        # The prediction layer used for Masked-Language-Modeling (MLM):
        # each embedded vector is projected back to predict a token of the
        # input vocabulary.
        self.mlm_head = nn.Linear(d_model, abc_size)

    def forward(self, tokens, mask=None):
        """Return MLM logits of shape [B, L, abc_size] for ``tokens`` [B, L].

        ``mask`` is forwarded unchanged to every layer.
        """
        B, L = tokens.shape
        pos = torch.arange(L, device=tokens.device)
        # [B, L, D] + [L, D]: the position embedding broadcasts over the batch.
        x = self.token_emb(tokens) + self.pos_emb(pos)

        for layer in self.layers:
            x = layer(x, mask)

        return self.mlm_head(x)

In [None]:
# Build an encoder over a 4-letter alphabet.
encoder = TransformerEncoder(
    abc_size=4,
    d_model=512,
    K=6,
    n_heads=8,
    d_ff=2048,
    max_len=512
)

# random sequence[B,L] with values in [0,1,2,3]
B=20
L=400
tokens = torch.randint(0, 4, size=(B,L))
output = encoder(tokens)
print(output.shape)  # (B, L, abc_size): the MLM head emits one logit per alphabet token


In [None]:
# DECODER-only
#
import numpy as np

class PositionalEncoding(nn.Module):
    """Fixed sinusoidal positional encoding added to a batch of embeddings.

    Even feature columns hold ``sin(pos * w_i)`` and odd columns hold
    ``cos(pos * w_i)`` with ``w_i = 10000 ** (-2i / d_model)``. The table is
    computed once and stored as the non-trainable buffer ``pe`` of shape
    [1, max_len, d_model].

    Args:
        d_model: embedding width (even or odd).
        max_len: longest sequence length the table covers.
    """

    def __init__(self, d_model: int, max_len: int = 2048):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1)  # [max_len, 1]
        div_term = torch.exp(
            torch.arange(0, d_model, 2) * (-np.log(10000.0) / d_model)
        )                                                 # [ceil(d_model/2)]
        angles = position * div_term                      # [max_len, ceil(d_model/2)]

        pe[:, 0::2] = torch.sin(angles)
        # For odd d_model there is one fewer odd column than even columns,
        # so drop the last frequency to keep the shapes aligned.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])

        pe = pe.unsqueeze(0)  # [1, max_len, d_model]
        self.register_buffer("pe", pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` [B, L, D] plus the encoding of its first L positions."""
        L = x.size(1)
        return x + self.pe[:, :L]

class Decoder(nn.Module):
    """Decoder-only Transformer that predicts the next token.

    Args:
        abc_size: size of the token alphabet (vocabulary).
        d_model: embedding width.
        K: number of stacked TransformerBlock layers.
        n_heads: attention heads per layer.
        d_ff: hidden width of each layer's feed-forward sub-layer.
        max_len: longest sequence length accepted by ``forward``.
    """

    def __init__(self, abc_size, d_model, K, n_heads, d_ff, max_len):
        super().__init__()
        self.abc_size = abc_size
        self.d_model  = d_model
        self.max_len  = max_len

        self.token_emb = nn.Embedding(abc_size, d_model)
        self.pos_enc   = PositionalEncoding(d_model, max_len=max_len)

        # Use the constructor's own d_ff rather than a notebook-level global,
        # so the requested feed-forward width is actually honoured.
        self.layers = nn.ModuleList([
            TransformerBlock(d_model=d_model, num_heads=n_heads, d_ff=d_ff)
            for _ in range(K)
        ])

        self.ln = nn.LayerNorm(d_model)
        self.ff = nn.Linear(d_model, abc_size, bias=False)

    # if i <  j: mask[ij] = -inf,  attn_score + mask = -inf      , attn = 0
    # if i >= j: mask[ij] = 0   ,  attn_score + mask = attn_score, attn unchanged
    #
    def _causal_mask(self, seq_len: int, device):
        """Return an additive [seq_len, seq_len] mask: -inf above the diagonal, 0 elsewhere."""
        # upper triangular (1 strictly above the diagonal), then 1 -> -inf
        mask = torch.triu(torch.ones(seq_len, seq_len, device=device), diagonal=1)
        mask = mask.masked_fill(mask == 1, float("-inf"))

        # The illustrative entries only exist when there are at least two
        # positions; a one-token context must not crash here.
        if seq_len > 1:
            print("mask[0,1] = ", mask[0,1])
            print("mask[1,0] = ", mask[1,0])
        return mask  # [L, L]

    def forward(self, inputs: torch.Tensor):
        """
        inputs: [B, L]
        returns: logits [B, L, vocab_size]

        Raises ValueError when L exceeds ``max_len``.
        """
        B, L = inputs.shape
        print("inputs\nB L maxL ", B, L, self.max_len)

        if L > self.max_len:
            raise ValueError(f"seq_len {L} > max_len {self.max_len}")
        device = inputs.device

        # inputs embedding, scaled by sqrt(d_model) before adding positions
        x = self.token_emb(inputs) * np.sqrt(self.d_model)  # [B, L, D]
        print("inputs embedded ", x.shape)
        x = self.pos_enc(x)
        print("inputs embedded + pos encoding", x.shape)

        attn_mask = self._causal_mask(L, device=device)

        for layer in self.layers:
            x = layer(x, mask=attn_mask)

        x = self.ln(x)
        logits = self.ff(x)  # [B, L, V=abc_size]
        return logits

    @torch.no_grad()
    def generate_next_token(self, inputs: torch.Tensor) -> torch.Tensor:
        """
        inputs: [B, L] context
        returns: [B] next-token  (takes the val with max prob)
        """
        logits = self.forward(inputs)             # [B, L, V]
        last_logits = logits[:, -1, :]            # [B, V]
        next_token = last_logits.argmax(dim=-1)   # [B]
        return next_token


# ----- Example: generate next token -----

if __name__ == "__main__":
    torch.manual_seed(0)

    abc_size = 4
    decoder = Decoder(
        abc_size=abc_size,
        d_model=256,
        K=2,
        n_heads=8,
        d_ff=512,
        max_len=32,
    )

    # Dummy input: batch=1, seq_len=5
    # ATTCG
    # Named `context` rather than `input` so the builtin input() is not
    # shadowed for the rest of the session.
    context = torch.tensor([[0, 3, 3, 1, 2]])  # shape [1, 5]
    print("input:", context.tolist())

    # Greedy choice of the next token given the context.
    next_token = decoder.generate_next_token(context)
    print("next token:", next_token.shape, next_token.item())

    # If you want to extend the sequence by 1:
    # [B] -> [B, 1], then append along the sequence axis.
    extended = torch.cat([context, next_token.unsqueeze(1)], dim=1)
    print("Extended:", extended.tolist())

In [None]:
#===========================
# ENCODER-DECODER
#===========================
#
# The encoder of the Vaswani encoder-decoder is similar to the encoder-only model.
# The decoder is a modification of the decoder-only model, as it includes two MHA layers: one of them is standard (self-attention), the other one is crossed (cross-attention).
#
# We are going to start by modifying the MHA routine so it can work both for self-attention and cross-attention.
# It requires two inputs instead of one (a target x_t and a source x_s), so you can choose which one is used to make each of the Q/K/V projections.
#
# self attention:  Q = WQ @ x
#                  K = WK @ x
#                  V = WV @ x
#
# cross attention: Q = WQ @ x_target
#                  K = WK @ x_source
#                  V = WV @ x_source
#
# MultiHeadAttentionGeneral
#    multi‑head s(elf/cross) attention
#
class MultiHeadAttentionGeneral(nn.Module):
    """Multi-head attention usable as self- or cross-attention.

    Queries are projected from the target ``x_t``; keys and values are
    projected from the source ``x_s``. Passing the same tensor for both
    gives self-attention.

    Args:
        d_model: embedding width D; must be divisible by ``num_heads``.
        num_heads: number of attention heads H.
        dropout: dropout probability applied to the attention weights.
    """

    def __init__(self, d_model, num_heads, dropout: float = 0.1):
        super().__init__()
        assert d_model % num_heads == 0
        self.d_model   = d_model
        self.num_heads = num_heads
        self.d_head    = self.d_model // self.num_heads  # d_head = D / H

        # Every projection maps d_model -> d_model.
        self.Wq = nn.Linear(d_model, d_model)
        self.Wk = nn.Linear(d_model, d_model)
        self.Wv = nn.Linear(d_model, d_model)
        self.Wo = nn.Linear(d_model, d_model)

        self.dropout = nn.Dropout(dropout)

    def _split_heads(self, projected, batch, length):
        # [batch, length, D] -> [batch, H, length, d_head]
        per_head = projected.reshape(batch, length, self.num_heads, self.d_head)
        return per_head.transpose(1, 2)

    def forward(self,
                x_t,         # [B, Lt, D]  Lt = target sequence length
                x_s,         # [B, Ls, D]  Ls = source sequence length
                mask=None    # [Lt, Ls] with -inf on masked
    ):
        """Attend from the target ``x_t`` to the source ``x_s``.

        ``mask`` is additive: a 2-D [Lt, Ls] mask is shared by every batch
        element and head, any other mask gets a head axis inserted at dim 1.
        Returns a tensor of shape [B, Lt, D].
        """
        B, Lt, D = x_t.shape
        B, Ls, D = x_s.shape

        q = self._split_heads(self.Wq(x_t), B, Lt)  # [B, H, Lt, d_head]
        k = self._split_heads(self.Wk(x_s), B, Ls)  # [B, H, Ls, d_head]
        v = self._split_heads(self.Wv(x_s), B, Ls)  # [B, H, Ls, d_head]

        # Scaled dot-product scores: [B, H, Lt, Ls]
        scale = self.d_head ** 0.5
        scores = torch.matmul(q, k.transpose(-2, -1)) / scale

        if mask is not None:
            if mask.dim() == 2:
                bias = mask.view(1, 1, Lt, Ls)  # [Lt, Ls] -> [1, 1, Lt, Ls]
            else:
                bias = mask.unsqueeze(1)        # [B, Lt, Ls] -> [B, 1, Lt, Ls]
            scores = scores + bias

        attn = self.dropout(torch.softmax(scores, dim=-1))
        # print
        print("Attention head dimensions [B, H, Lt, Ls]", attn.shape)

        per_head_out = torch.matmul(attn, v)  # [B, H, Lt, d_head]

        # Put the heads back side by side: [B, Lt, H, d_head] -> [B, Lt, H*d_head]
        merged = per_head_out.transpose(1, 2).reshape(B, Lt, D)

        # Final output projection.
        return self.Wo(merged)


In [None]:
# EncoderBlock
#
# similar to the encoder-only block but rewritten using the general MHA
#
# also including dropout
#
class EncoderBlock(nn.Module):
    """Encoder layer: self-attention then feed-forward.

    Each sub-layer output goes through dropout, is added to its input
    (residual) and the sum is layer-normalised.

    Args:
        d_model: embedding width.
        n_heads: number of attention heads.
        d_ff: hidden width of the feed-forward sub-layer.
        dropout: dropout probability.
    """

    def __init__(self, d_model: int, n_heads: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        # Attention sub-layer and its LayerNorm.
        self.self_attn = MultiHeadAttentionGeneral(d_model, n_heads, dropout)
        self.ln1 = nn.LayerNorm(d_model)

        # Feed-forward sub-layer (d_model -> d_ff -> d_model) and its LayerNorm.
        widen  = nn.Linear(d_model, d_ff)
        narrow = nn.Linear(d_ff, d_model)
        self.ffn = nn.Sequential(widen, nn.ReLU(), narrow)
        self.ln2 = nn.LayerNorm(d_model)

        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor, s_mask: torch.Tensor | None = None):
        """Return the encoded sequence; ``s_mask`` is passed to the attention."""
        # Self-attention: x serves as both target and source.
        print("encoder self-attention")
        attended = self.self_attn(x, x, mask=s_mask)
        x = self.ln1(x + self.dropout(attended))

        # Feed-forward network.
        x = self.ln2(x + self.dropout(self.ffn(x)))
        return x

# DecoderBlock
#
# the decoder block is different from the decoder-only model because it includes
# a masked self-attention layer (for the emerging target sequence)
# and a cross-attention layer (between the target (queries) and the source (keys, values))
#
# also including dropout
#
class DecoderBlock(nn.Module):
    """Decoder layer: masked self-attention, cross-attention, feed-forward.

    Each sub-layer output goes through dropout, is added to its input
    (residual) and the sum is layer-normalised.

    Args:
        d_model: embedding width.
        n_heads: number of attention heads.
        d_ff: hidden width of the feed-forward sub-layer.
        dropout: dropout probability.
    """

    def __init__(self, d_model: int, n_heads: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        # Self-attention over the target.
        self.self_attn = MultiHeadAttentionGeneral(d_model, n_heads, dropout)
        self.ln1 = nn.LayerNorm(d_model)

        # Cross-attention from the target to the source.
        self.cross_attn = MultiHeadAttentionGeneral(d_model, n_heads, dropout)
        self.ln2 = nn.LayerNorm(d_model)

        # Feed-forward network, d_model -> d_ff -> d_model.
        widen  = nn.Linear(d_model, d_ff)
        narrow = nn.Linear(d_ff, d_model)
        self.ffn = nn.Sequential(widen, nn.ReLU(), narrow)
        self.ln3 = nn.LayerNorm(d_model)

        self.dropout = nn.Dropout(dropout)

    def _residual(self, stream, sublayer_out, norm):
        # dropout -> residual add -> LayerNorm
        return norm(stream + self.dropout(sublayer_out))

    def forward(
        self,
        x_t: torch.Tensor,                 # [B, L_t, D] target sequence
        x_s: torch.Tensor,                 # [B, L_s, D] source sequence
        t_mask: torch.Tensor | None,       # [L_t, L_t] (causal)
        s_mask: torch.Tensor | None,       # [L_t, L_s]
    ):
        """Return the updated target representation, shape [B, L_t, D]."""
        # 1) Self-attention over the target, restricted by t_mask.
        print("\ndecoder self-attention")
        self_out = self.self_attn(x_t, x_t, mask=t_mask)
        print("^^attn out", self_out.shape, self_out[0])
        x_t = self._residual(x_t, self_out, self.ln1)

        # 2) Cross-attention: queries from the target, keys/values from the source.
        print("\ndecoder cross-attention")
        cross_out = self.cross_attn(x_t, x_s, mask=s_mask)
        x_t = self._residual(x_t, cross_out, self.ln2)

        # 3) Feed-forward network.
        x_t = self._residual(x_t, self.ffn(x_t), self.ln3)

        return x_t


# =========================
# Full Encoder / Decoder
# =========================

class TransformerEncoder(nn.Module):
    """Encoder stack of the seq2seq model.

    Embeds source tokens, adds positional encodings, applies ``K``
    EncoderBlock layers and a final LayerNorm.

    Args:
        abc_size: size of the source alphabet.
        d_model: embedding width.
        n_heads: attention heads per layer.
        K: number of EncoderBlock layers.
        d_ff: hidden width of each feed-forward sub-layer.
        max_len: length of the positional-encoding table.
        dropout: dropout probability used inside each layer.
    """

    def __init__(
        self,
        abc_size: int,
        d_model: int = 512,
        n_heads: int = 8,
        K: int = 6,
        d_ff: int = 2048,
        max_len: int = 512,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.token_emb = nn.Embedding(abc_size, d_model)
        self.pos_enc = PositionalEncoding(d_model, max_len=max_len)

        blocks = [EncoderBlock(d_model, n_heads, d_ff, dropout) for _ in range(K)]
        self.layers = nn.ModuleList(blocks)

        self.ln_f = nn.LayerNorm(d_model)

    def forward(self, src_ids: torch.Tensor, s_mask: torch.Tensor | None = None):
        """Encode ``src_ids`` [B, L_s] into a tensor of shape [B, L_s, D]."""
        # Scale the embeddings by sqrt(d_model) before adding positions.
        scale = np.sqrt(self.token_emb.embedding_dim)
        hidden = self.pos_enc(self.token_emb(src_ids) * scale)

        for block in self.layers:
            hidden = block(hidden, s_mask=s_mask)

        return self.ln_f(hidden)  # [B, L_s, D]


class TransformerDecoder(nn.Module):
    def __init__(
        self,
        abc_size: int,
        d_model: int = 512,
        n_heads: int = 8,
        K: int = 6,
        d_ff: int = 2048,
        max_len: int = 512,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.token_emb = nn.Embedding(abc_size, d_model)
        self.pos_enc = PositionalEncoding(d_model, max_len=max_len)

        self.layers = nn.ModuleList([
            DecoderBlock(d_model, n_heads, d_ff, dropout)
            for _ in range(K)
        ])
        self.ln_f = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, abc_size, bias=False)

    def _causal_mask(self, L: int, device):
        # [L, L] with -inf above diagonal
        mask = torch.triu(torch.ones(L, L, device=device), diagonal=1)
        mask = mask.masked_fill(mask == 1, float("-inf"))
        return mask

    def forward(
        self,
        tgt_ids: torch.Tensor,               # [B, L_t]
        x_s: torch.Tensor,                   # [B, L_s, D]
        s_mask: torch.Tensor | None = None,  # often None
    ):
        B, L_t = tgt_ids.shape
        device = tgt_ids.device

        x_t = self.token_emb(tgt_ids) * np.sqrt(self.token_emb.embedding_dim)
        x_t = self.pos_enc(x_t)

        # Causal mask for decoder self-attn
        t_mask = self._causal_mask(L_t, device=device)  # [L_t, L_t]

        for layer in self.layers:
            x_t = layer(
                x_t,
                x_s=x_s,
                t_mask=t_mask,
                s_mask=s_mask,
            )

        x_t = self.ln_f(x_t)
        logits = self.head(x_t)  # [B, L_t, abc_size]
        return logits


# =========================
# Full Seq2Seq Transformer
# =========================

class TransformerSeq2Seq(nn.Module):
    """Encoder-decoder Transformer mapping a source to a target sequence.

    The encoder and decoder are built with the same hyper-parameters and
    differ only in their alphabet size.

    Args:
        s_abc_size: size of the source alphabet.
        t_abc_size: size of the target alphabet.
        d_model: embedding width.
        n_heads: attention heads per layer.
        K: number of layers in each stack.
        d_ff: hidden width of each feed-forward sub-layer.
        max_len: length of the positional-encoding tables.
        dropout: dropout probability.
    """

    def __init__(
        self,
        s_abc_size: int,
        t_abc_size: int,
        d_model: int = 512,
        n_heads: int = 8,
        K: int = 6,
        d_ff: int = 2048,
        max_len: int = 512,
        dropout: float = 0.1,
    ):
        super().__init__()
        # Hyper-parameters shared by both stacks.
        shared = dict(
            d_model=d_model,
            n_heads=n_heads,
            K=K,
            d_ff=d_ff,
            max_len=max_len,
            dropout=dropout,
        )
        self.encoder = TransformerEncoder(s_abc_size, **shared)
        self.decoder = TransformerDecoder(t_abc_size, **shared)

    def forward(
        self,
        src_ids: torch.Tensor,      # [B, L_s]
        tgt_ids: torch.Tensor,      # [B, L_t]
        s_mask: torch.Tensor | None = None,
        memory_mask: torch.Tensor | None = None,
    ):
        """Return target logits of shape [B, L_t, t_abc_size].

        ``s_mask`` is given to the encoder; ``memory_mask`` is given to the
        decoder as its ``s_mask`` argument.
        """
        memory = self.encoder(src_ids, s_mask=s_mask)  # [B, L_s, D]
        logits = self.decoder(tgt_ids, memory, s_mask=memory_mask)
        print("logits", logits)

        return logits  # [B, L_t, t_abc_size]




In [None]:
# usage
# extract an mRNA sequence from a protein sequence
#
import torch.nn.functional as F

if __name__ == "__main__":
    # Fix the seed so the random model weights and sequences are reproducible.
    torch.manual_seed(0)

    prot_abc = 20 # protein abc
    rna_abc  = 4  # RNA abc
    model = TransformerSeq2Seq(
        s_abc_size=prot_abc,
        t_abc_size=rna_abc,
        d_model=256,
        n_heads=4,
        K=2,
        d_ff=512,
        max_len=300,
    )

    B = 1       # batch size
    L_s = 100   # source (protein) length
    L_t = 300   # target (mRNA) length

    # Random token ids standing in for real sequences.
    protein = torch.randint(0, prot_abc, (B, L_s))
    mrna    = torch.randint(0, rna_abc,  (B, L_t))  # teacher-forcing input

    # Teacher forcing: the decoder reads every token but the last, and the
    # prediction targets are the same sequence shifted by one position.
    mrna_in  = mrna[:, :-1]   # decoder input (all but last)
    mrna_out = mrna[:, 1:]    # prediction targets (all but first)
    logits  = model(protein, mrna_in)       # [B, L_t-1, rna_abc]
    print("mrna_out shape:", mrna_out.shape)
    print("Logits shape:", logits.shape)

    # Flatten batch and position so every target token is one classification example.
    loss = F.cross_entropy(
      logits.reshape(-1, rna_abc), # [B*(Lt-1), 4]
      mrna_out.reshape(-1),        # [B*(Lt-1)]
    )
    loss.backward()
    print("Loss:", loss.item())

#Suggested project.

Modify the previous TransformerSeq2Seq() so that the mRNA tokens are triplets (which makes perfect sense for mRNAs, similar to DNABERT). Then use this model to translate proteins into mRNAs, and investigate whether the cross-attention maps allow you to discover the genetic code, and maybe variations of the genetic code.


In [4]:
# DNABERT-like tokenization
#
#
import torch
from itertools import product

# for k, the number of words in the alphabet is
# 4^k + 4
def build_kmer_abc(k=6):
    """Build the k-mer vocabulary over A/C/G/T plus four special tokens.

    Ids 0-3 are reserved for ``[PAD]``, ``[UNK]``, ``[CLS]`` and ``[SEP]``;
    the k-mers follow from id 4 in ``itertools.product`` order, so the
    dictionary holds ``4**k + 4`` entries.
    """
    bases = "ACGT"
    specials = ("[PAD]", "[UNK]", "[CLS]", "[SEP]")

    abc = {}
    # k-mers first, numbered from 4 so that ids 0-3 stay free.
    for idx, letters in enumerate(product(bases, repeat=k), start=4):
        abc["".join(letters)] = idx
    # Special tokens take the reserved ids 0-3.
    for idx, token in enumerate(specials):
        abc[token] = idx
    return abc

def dna_to_abc_idx(seq, abc, k=6, stride=1):
    """Tokenize a DNA string into k-mer ids framed by ``[CLS]`` and ``[SEP]``.

    Args:
        seq: the nucleotide string to tokenize.
        abc: mapping from k-mer / special token to integer id; it must
            contain ``[CLS]``, ``[SEP]`` and ``[UNK]``.
        k: k-mer length; should match the k used to build ``abc``.
        stride: step between consecutive k-mer start positions. The default
            of 1 gives overlapping k-mers; ``stride=k`` gives non-overlapping
            ones (e.g. codons for ``k=3``).

    Returns:
        A 1-D ``torch.long`` tensor ``[CLS, kmer ids..., SEP]``. K-mers that
        are not in ``abc`` map to ``[UNK]``; a sequence shorter than ``k``
        yields only ``[CLS, SEP]``.

    Raises:
        ValueError: if ``stride`` is smaller than 1.
    """
    if stride < 1:
        raise ValueError(f"stride must be >= 1, got {stride}")

    unk = abc['[UNK]']
    kmer_ids = [
        abc.get(seq[start:start + k], unk)
        for start in range(0, len(seq) - k + 1, stride)
    ]
    return torch.tensor([abc['[CLS]'], *kmer_ids, abc['[SEP]']], dtype=torch.long)

# 6-mer vocabulary: 4**6 k-mers plus 4 special tokens.
abc6 = build_kmer_abc(k=6)
print("abc k=6", len(abc6))
seq = "ACGTACGT"
# An 8-letter sequence has 3 overlapping 6-mers, framed by [CLS] and [SEP].
tokens_ids = dna_to_abc_idx(seq, abc6, k=6)
print("6-kmer tokenization of seq", seq, "\n", tokens_ids)

# 3-mer vocabulary: 4**3 k-mers plus 4 special tokens.
abc3 = build_kmer_abc(k=3)
print("abc k=3", len(abc3))
seq = "ACGTACGT"
# The same sequence has 6 overlapping 3-mers, framed by [CLS] and [SEP].
tokens_ids = dna_to_abc_idx(seq, abc3, k=3)
print("3-kmer tokenization of seq", seq, "\n", tokens_ids)

abc k=6 4100
6-kmer tokenization of seq ACGTACGT 
 tensor([   2,  437, 1738, 2847,    3])
abc k=3 68
3-kmer tokenization of seq ACGTACGT 
 tensor([ 2, 10, 31, 48, 53, 10, 31,  3])
