In [1]:
import torch
import os

# Simple transformer layer

In [2]:
import torch
import torch.nn as nn

# Simple self-attention
class SelfAttention(nn.Module):
    """Single-head scaled dot-product self-attention.

    Queries, keys and values are three separate linear projections of the same
    input, each of width ``d_model``. There is no output projection.

    Args:
        d_model: size of the last (embedding) dimension of the input and output.
    """

    def __init__(self, d_model):
        super().__init__()
        self.d_model = d_model

        # every projection maps d_model -> d_model
        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)

    def forward(self, x):
        """Attend over the second-to-last dimension of ``x``.

        Args:
            x: tensor of shape (..., L, d_model). The usual case is (B, L, d_model),
                but any number of leading batch dimensions is accepted because the
                projections, the matmuls and the softmax only touch the last two
                dimensions.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # q, k, v: (..., L, d_model)
        q = self.q_proj(x)
        k = self.k_proj(x)
        v = self.v_proj(x)

        # scores: (..., L, L) = q (..., L, d_model) @ k^T (..., d_model, L)
        scores = q @ k.transpose(-2, -1) / (self.d_model ** 0.5)

        # softmax over the last dimension, i.e. the L positions coming from the keys
        attn = torch.softmax(scores, dim=-1)

        # out: (..., L, d_model) = attn (..., L, L) @ v (..., L, d_model)
        return attn @ v


# simple Transformer block = attention -> MLP
class SimpleTransformerBlock(nn.Module):
    """Transformer block: single-head self-attention followed by a two-layer MLP.

    Both sub-layers are wrapped in a residual connection and each one normalises
    its own input first (pre-norm): ``norm_att`` feeds the attention and
    ``norm_ff`` feeds the MLP, the same ordering ``TransformerBlock`` uses.

    Args:
        d_model: embedding size of the input and output.
        d_ff: hidden width of the feed-forward MLP.
    """

    def __init__(self, d_model, d_ff):
        super().__init__()

        self.attn = SelfAttention(d_model)

        self.mlp = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Linear(d_ff, d_model),
        )

        self.norm_att = nn.LayerNorm(d_model)
        self.norm_ff  = nn.LayerNorm(d_model)

    def forward(self, x):
        """Apply the attention and MLP sub-layers.

        Args:
            x: tensor whose last dimension is ``d_model``; the example below uses
                shape (B, L, d_model).

        Returns:
            Tensor with the same shape as ``x``.
        """
        # attention sub-layer: normalise with norm_att, attend, add the residual
        x = x + self.attn(self.norm_att(x))
        # feed-forward sub-layer: normalise with norm_ff, apply the MLP, add the residual
        x = x + self.mlp(self.norm_ff(x))
        return x


# Example usage: push one random batch through the single-head block.
B, L, D = 2, 10, 90   # batch size, sequence length, embedding dimension
d_ff    = 512         # hidden width of the feed-forward MLP
n_heads = 2           # not passed to the single-head block below

# random input sequences of length L and embedding dimension D, shape (B, L, D)
x = torch.randn(B, L, D)

# SimpleTransformerBlock produces as output another sequence with the same
# embedding dimension; we usually refer to this kind of model as an encoder.
model = SimpleTransformerBlock(d_model=D, d_ff=d_ff)
x_out = model(x)

# Transformer layer with multi-head self-attention

In [3]:
import torch
import torch.nn as nn

# multi-head self-attention
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    The ``d_model`` channels are split into ``num_heads`` heads of width
    ``d_head = d_model // num_heads``. Each head attends independently; the
    heads are then concatenated and passed through a final linear layer.

    Args:
        d_model: embedding size of the input and output; must be a multiple of
            ``num_heads`` (checked with an ``assert``).
        num_heads: number of attention heads.
    """

    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0
        self.d_model   = d_model
        self.num_heads = num_heads
        self.d_head    = d_model // num_heads  # per-head width

        # all four projections map d_model -> d_model
        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)
        self.o_proj = nn.Linear(d_model, d_model)

    def _split_heads(self, t, n_batch, seq_len):
        """Turn (B, L, d_model) into (B, H, L, d_head)."""
        per_head = t.reshape(n_batch, seq_len, self.num_heads, self.d_head)
        return per_head.transpose(1, 2)

    def forward(self, x):
        """Run self-attention on ``x`` of shape (B, L, d_model).

        Prints the shape of the attention maps as a side effect and returns a
        tensor of shape (B, L, d_model).
        """
        n_batch, seq_len, dim = x.shape

        # queries / keys / values, each (B, H, L, d_head)
        queries = self._split_heads(self.q_proj(x), n_batch, seq_len)
        keys    = self._split_heads(self.k_proj(x), n_batch, seq_len)
        values  = self._split_heads(self.v_proj(x), n_batch, seq_len)

        # scores (B, H, L, L) = queries (B, H, L, d_head) @ keys^T (B, H, d_head, L)
        scale = self.d_head ** 0.5
        scores = torch.matmul(queries, keys.transpose(-2, -1)) / scale
        attn = scores.softmax(dim=-1)

        # report the shape of the attention maps
        print("Attention head dimensions [B, H, L, L]", attn.shape)

        # per-head outputs, (B, H, L, d_head)
        heads = torch.matmul(attn, values)

        # concatenate the heads: (B, H, L, d_head) -> (B, L, H, d_head) -> (B, L, H*d_head)
        merged = heads.transpose(1, 2).reshape(n_batch, seq_len, dim)

        # final fully connected layer
        return self.o_proj(merged)



# simple Transformer block = attention -> MLP
class TransformerBlock(nn.Module):
    """Pre-norm transformer block: multi-head self-attention, then a two-layer MLP.

    Each sub-layer normalises its input with its own ``LayerNorm`` and adds the
    result back onto the residual stream.

    Args:
        d_model: embedding size of the input and output.
        num_heads: number of attention heads.
        d_ff: hidden width of the feed-forward MLP.
    """

    def __init__(self, d_model, num_heads, d_ff):
        super().__init__()

        self.attn = MultiHeadAttention(d_model, num_heads)

        # position-wise feed-forward network: d_model -> d_ff -> d_model
        feed_forward = [
            nn.Linear(d_model, d_ff),
            nn.ReLU(),
            nn.Linear(d_ff, d_model),
        ]
        self.mlp = nn.Sequential(*feed_forward)

        self.norm_att = nn.LayerNorm(d_model)
        self.norm_ff  = nn.LayerNorm(d_model)

    def forward(self, x):
        """Return ``x`` after the attention and MLP residual updates (same shape)."""
        after_attn = x + self.attn(self.norm_att(x))
        return after_attn + self.mlp(self.norm_ff(after_attn))



In [4]:
# Example usage: push one random batch through the multi-head block.
B, L, D = 2, 10, 90   # batch size, sequence length, embedding dimension
d_diff  = 512         # hidden width of the feed-forward MLP (passed as d_ff)
n_heads = 2           # number of attention heads

# random input sequences of length L and embedding dimension D, shape (B, L, D)
x = torch.randn(B, L, D)

# TransformerBlock produces as output another sequence with the same embedding
# dimension; we usually refer to this kind of model as an encoder.
model = TransformerBlock(d_model=D, num_heads=n_heads, d_ff=d_diff)
x_out = model(x)

# both shapes are (B, L, D)
print("input  ", x.shape)
print("output ", x_out.shape)


Attention head dimensions [B, H, L, L] torch.Size([2, 2, 10, 10])
input   torch.Size([2, 10, 90])
output  torch.Size([2, 10, 90])


In [5]:
# Row attention =
# attention across residues for each sequence (MSA row)
# thus: row[S] is fixed,  L and D move
class RowAttention(nn.Module):
    """Multi-head self-attention run separately inside every MSA row.

    The batch and sequence axes are folded together so that each of the B*S rows
    is handed to ``MultiHeadAttention`` as an independent (L, d_model) sequence.

    Args:
        d_model: embedding size of the input and output.
        num_heads: number of attention heads.
    """

    def __init__(self, d_model, num_heads):
        super().__init__()
        self.attn = MultiHeadAttention(d_model, num_heads)

    def forward(self, x):
        """Map ``x`` of shape (B, S, L, d_model) to a tensor of the same shape."""
        n_batch, n_seqs, seq_len, dim = x.shape

        # (B, S, L, D) -> (B*S, L, D): one attention problem per row
        rows = x.reshape(n_batch * n_seqs, seq_len, dim)
        attended = self.attn(rows)

        # unfold back to (B, S, L, D)
        return attended.reshape(n_batch, n_seqs, seq_len, dim)

# Column attention =
# attention across MSA sequences at each residue (MSA column)
# L is fixed; we move along S and D.
# Thus we need to put S and D as the last two dimensions
class ColumnAttention(nn.Module):
    """Multi-head self-attention run separately inside every MSA column.

    The S and L axes are swapped and the batch and L axes folded together, so
    that each of the B*L columns is handed to ``MultiHeadAttention`` as an
    independent (S, d_model) sequence.

    Args:
        d_model: embedding size of the input and output.
        num_heads: number of attention heads.
    """

    def __init__(self, d_model, num_heads):
        super().__init__()
        self.attn = MultiHeadAttention(d_model, num_heads)

    def forward(self, x):
        """Map ``x`` of shape (B, S, L, d_model) to a tensor of the same shape."""
        n_batch, n_seqs, seq_len, dim = x.shape

        # (B, S, L, D) -> (B, L, S, D) -> (B*L, S, D): one attention problem per column
        cols = x.permute(0, 2, 1, 3).reshape(n_batch * seq_len, n_seqs, dim)
        attended = self.attn(cols)

        # (B*L, S, D) -> (B, L, S, D) -> (B, S, L, D)
        return attended.reshape(n_batch, seq_len, n_seqs, dim).permute(0, 2, 1, 3)


In [6]:
# Example usage: row- and column-attention on a random batch of alignments.
B = 2           # batch size (number of alignments)
S, L = 10, 50   # sequences per alignment, sequence length
D = 90          # embedding dimension
n_heads = 2     # number of attention heads

# random alignments: S sequences of length L with embedding dimension D, shape (B, S, L, D)
x = torch.randn(B, S, L, D)

# row attention: input and output are both (B, S, L, D)
RA = RowAttention(d_model=D, num_heads=n_heads)
x_out = RA(x)

print("input  ", x.shape)
print("RA output", x_out.shape)

# column attention: input and output are both (B, S, L, D)
CA = ColumnAttention(d_model=D, num_heads=n_heads)
x_out = CA(x)

print("input  ", x.shape)
print("CA output", x_out.shape)



Attention head dimensions [B, H, L, L] torch.Size([20, 2, 50, 50])
input   torch.Size([2, 10, 50, 90])
RA output torch.Size([2, 10, 50, 90])
Attention head dimensions [B, H, L, L] torch.Size([100, 2, 10, 10])
input   torch.Size([2, 10, 50, 90])
CA output torch.Size([2, 10, 50, 90])


In [22]:
## TIED ATTENTIONs
import torch
import torch.nn as nn

# Tied Row Attention (TRA)
# S = rows
# L = cols
#
class TiedRowAttention(nn.Module):
    """Multi-head row attention in which all rows share one attention map.

    Scaled dot-product scores are computed per row, pooled across the S rows
    (before the softmax) into a single (L, L) map per head, and that shared map
    is then applied to every row's own values.

    Args:
        d_model: embedding size of the input and output; must be divisible by
            ``num_heads``.
        num_heads: number of attention heads.
        tie_mode: how per-row scores are pooled across rows, ``"mean"``
            (default) or ``"sum"``.

    Raises:
        ValueError: if ``d_model`` is not divisible by ``num_heads`` or
            ``tie_mode`` is not one of the supported values.
    """

    def __init__(self, d_model, num_heads, tie_mode="mean"):
        super().__init__()
        if d_model % num_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by num_heads ({num_heads})"
            )
        if tie_mode not in ("mean", "sum"):
            raise ValueError(f"tie_mode must be 'mean' or 'sum', got {tie_mode!r}")

        self.num_heads = num_heads
        self.head_dim  = d_model // num_heads
        self.tie_mode = tie_mode   # "mean" or "sum"

        # shared projections across all rows
        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)
        self.o_proj = nn.Linear(d_model, d_model)

    def forward(self, x):
        """Apply tied row attention.

        Args:
            x: tensor of shape (B, S, L, D) with D == d_model.

        Returns:
            Tensor of shape (B, S, L, D). The shape of the tied attention maps is
            printed as a side effect.
        """
        # x: (B, S, L, D)
        B, S, L, D = x.shape

        # flatten rows so the same projections are applied to every row
        x_flat = x.reshape(B*S, L, D)

        # queries, keys, values [B*S, L, D]
        Q = self.q_proj(x_flat)
        K = self.k_proj(x_flat)
        V = self.v_proj(x_flat)

        # multi-head split
        # [B*S, L, D] -> [B*S, L, H, DH] (reshape) -> [B*S, H, L, DH] (transpose 1,2)
        Q = Q.reshape(B*S, L, self.num_heads, self.head_dim).transpose(1, 2)
        K = K.reshape(B*S, L, self.num_heads, self.head_dim).transpose(1, 2)
        V = V.reshape(B*S, L, self.num_heads, self.head_dim).transpose(1, 2)

        # per-row attention scores
        # R_scores[B*S, H, L, L] = Q[B*S, H, L, DH] * KT[B*S, H, DH, L]
        R_scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.head_dim ** 0.5)
        # R_scores[B, S, H, L, L]
        R_scores = R_scores.reshape(B, S, self.num_heads, L, L)

        # tie the scores across rows according to tie_mode
        # TR_scores[B, 1, H, L, L]
        if self.tie_mode == "sum":
            TR_scores = R_scores.sum(dim=1, keepdim=True)
        else:
            TR_scores = R_scores.mean(dim=1, keepdim=True)

        # softmax AFTER tying
        # TR_attn[B, 1, H, L, L]
        TR_attn = torch.softmax(TR_scores, dim=-1)

        # share the tied attention with all rows; expand() is a view, so the
        # map is not copied S times
        # TR_attn[B, S, H, L, L]
        TR_attn = TR_attn.expand(-1, S, -1, -1, -1)

        # print TR_attn shape
        print("\nTied-row attention (TR_attn) maps dimensions [B, S, H, L, L]", TR_attn.shape)

        # reshape V back to [B, S, H, L, DH] (from [B*S, H, L, DH])
        V = V.reshape(B, S, self.num_heads, L, self.head_dim)

        # out[B, S, H, L, DH] = TR_attn[B, S, H, L, L] * V[B, S, H, L, DH]
        out = torch.matmul(TR_attn, V)

        # merge heads
        # out[B, S, H, L, DH] -> out[B, S, L, H, DH] (transpose 2,3) -> out[B, S, L, D]
        out = out.transpose(2, 3).reshape(B, S, L, D)

        # final output projection
        out = self.o_proj(out)

        return out

# Tied Column Attention (TCA)
# S = rows
# L = cols
#
class TiedColumnAttention(nn.Module):
    """Multi-head column attention in which all columns share one attention map.

    Scaled dot-product scores are computed per column, pooled across the L
    columns (before the softmax) into a single (S, S) map per head, and that
    shared map is then applied to every column's own values.

    Args:
        d_model: embedding size of the input and output; must be divisible by
            ``num_heads``.
        num_heads: number of attention heads.
        tie_mode: how per-column scores are pooled across columns, ``"mean"``
            (default) or ``"sum"``.

    Raises:
        ValueError: if ``d_model`` is not divisible by ``num_heads`` or
            ``tie_mode`` is not one of the supported values.
    """

    def __init__(self, d_model, num_heads, tie_mode="mean"):
        super().__init__()
        if d_model % num_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by num_heads ({num_heads})"
            )
        if tie_mode not in ("mean", "sum"):
            raise ValueError(f"tie_mode must be 'mean' or 'sum', got {tie_mode!r}")

        self.num_heads = num_heads
        self.head_dim  = d_model // num_heads
        self.tie_mode = tie_mode   # "mean" or "sum"

        # shared projections across all columns
        self.q_proj = nn.Linear(d_model, d_model)
        self.k_proj = nn.Linear(d_model, d_model)
        self.v_proj = nn.Linear(d_model, d_model)
        self.o_proj = nn.Linear(d_model, d_model)

    def forward(self, x):
        """Apply tied column attention.

        Args:
            x: tensor of shape (B, S, L, D) with D == d_model.

        Returns:
            Tensor of shape (B, S, L, D). The shape of the tied attention maps is
            printed as a side effect.
        """
        # x: (B, S, L, D)
        B, S, L, D = x.shape

        # transpose so that each column becomes a sequence over S
        # x[B, S, L, D] -> x[B, L, S, D]
        x = x.transpose(1, 2)

        # flatten columns so the same projections are applied to every column
        x_flat = x.reshape(B*L, S, D)

        # queries, keys, values [B*L, S, D]
        Q = self.q_proj(x_flat)
        K = self.k_proj(x_flat)
        V = self.v_proj(x_flat)

        # multi-head split
        # [B*L, S, D] -> [B*L, S, H, DH] (reshape) -> [B*L, H, S, DH] (transpose 1,2)
        Q = Q.reshape(B*L, S, self.num_heads, self.head_dim).transpose(1, 2)
        K = K.reshape(B*L, S, self.num_heads, self.head_dim).transpose(1, 2)
        V = V.reshape(B*L, S, self.num_heads, self.head_dim).transpose(1, 2)

        # per-column attention scores
        # C_scores[B*L, H, S, S] = Q[B*L, H, S, DH] * KT[B*L, H, DH, S]
        C_scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.head_dim ** 0.5)
        # C_scores[B, L, H, S, S]
        C_scores = C_scores.reshape(B, L, self.num_heads, S, S)

        # tie the scores across columns according to tie_mode
        # TC_scores[B, 1, H, S, S]
        if self.tie_mode == "sum":
            TC_scores = C_scores.sum(dim=1, keepdim=True)
        else:
            TC_scores = C_scores.mean(dim=1, keepdim=True)

        # softmax AFTER tying
        # TC_attn[B, 1, H, S, S]
        TC_attn = torch.softmax(TC_scores, dim=-1)

        # share the tied attention with all columns; expand() is a view, so the
        # map is not copied L times
        # TC_attn[B, L, H, S, S]
        TC_attn = TC_attn.expand(-1, L, -1, -1, -1)

        # print TC_attn shape
        print("\ntied-column attention (TC_attn) maps dimensions [B, L, H, S, S]", TC_attn.shape)

        # reshape V to [B, L, H, S, DH] (from [B*L, H, S, DH])
        V = V.reshape(B, L, self.num_heads, S, self.head_dim)

        # out[B, L, H, S, DH] = TC_attn[B, L, H, S, S] * V[B, L, H, S, DH]
        out = torch.matmul(TC_attn, V)

        # merge heads
        # out[B, L, H, S, DH] -> out[B, L, S, H, DH] (transpose 2,3) -> out[B, L, S, D]
        out = out.transpose(2, 3).reshape(B, L, S, D)

        # back to the input layout: out[B, L, S, D] -> out[B, S, L, D]
        out = out.transpose(1, 2)

        # final output projection
        out = self.o_proj(out)

        return out

In [23]:
# Example usage: tied row- and column-attention on a random batch of alignments.
B = 2           # batch size (number of alignments)
S, L = 10, 50   # sequences per alignment, sequence length
D = 90          # embedding dimension
n_heads = 2     # number of attention heads

# random alignments: S sequences of length L with embedding dimension D, shape (B, S, L, D)
x = torch.randn(B, S, L, D)

# tied row attention: input and output are both (B, S, L, D)
TRA = TiedRowAttention(d_model=D, num_heads=n_heads)
x_out = TRA(x)
print("input     ", x.shape)
print("TRA output", x_out.shape)

# tied column attention: input and output are both (B, S, L, D)
TCA = TiedColumnAttention(d_model=D, num_heads=n_heads)
x_out = TCA(x)
print("input     ", x.shape)
print("TCA output", x_out.shape)



Tied-row attention (TR_attn) maps dimensions [B, S, H, L, L] torch.Size([2, 10, 2, 50, 50])
input      torch.Size([2, 10, 50, 90])
TRA output torch.Size([2, 10, 50, 90])

tied-column attention (TC_attn) maps dimensions [B, L, H, S, S] torch.Size([2, 50, 2, 10, 10])
input      torch.Size([2, 10, 50, 90])
TCA output torch.Size([2, 10, 50, 90])
