If it’s of any use, please find my implementation of a Vanilla Transformer below. Please note, however, that the focus was on ease of understanding and ease of implementation, since I want to cover it in my classes. In particular, note that each `AttentionHead` has its own linear layers `Wq`, `Wk`, and `Wv`. Most “from scratch” implementations of a Vanilla Transformer I came across have only one set of linear layers in `MultiHeadAttention` and use reshaping and slicing to handle the different heads (a minimal sketch of that variant is shown below for comparison).
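For comparison, here is a minimal sketch of that fused-projection variant. It is only an illustration (the class name and details are made up, not part of my implementation below):

```
import torch
import torch.nn as nn
import torch.nn.functional as f

class FusedMultiHeadAttention(nn.Module):
    """Illustrative sketch: one set of Q/K/V projections shared by all heads."""
    def __init__(self, num_heads, model_size):
        super().__init__()
        assert model_size % num_heads == 0
        self.num_heads = num_heads
        self.head_size = model_size // num_heads
        # One projection each for Q, K, V, covering all heads at once
        self.Wq = nn.Linear(model_size, model_size)
        self.Wk = nn.Linear(model_size, model_size)
        self.Wv = nn.Linear(model_size, model_size)
        self.Wo = nn.Linear(model_size, model_size)

    def forward(self, query, key, value, mask=None):
        B = query.size(0)
        # Project, then split the last dimension into (num_heads, head_size)
        def split(x):
            return x.view(B, -1, self.num_heads, self.head_size).transpose(1, 2)  # (B, H, L, D)
        Q, K, V = split(self.Wq(query)), split(self.Wk(key)), split(self.Wv(value))
        # Scaled dot-product attention per head
        scores = torch.matmul(Q, K.transpose(-2, -1)) / self.head_size ** 0.5     # (B, H, L, L)
        if mask is not None:
            scores = scores + mask
        out = torch.matmul(f.softmax(scores, dim=-1), V)                          # (B, H, L, D)
        # Re-merge the heads and apply the output projection
        out = out.transpose(1, 2).contiguous().view(B, -1, self.num_heads * self.head_size)
        return self.Wo(out)
```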

Also, there's no guarantee of correctness.

```
import math

import torch
import torch.nn as nn
import torch.nn.functional as f


class PositionalEncoding(nn.Module):
    """
    https://pytorch.org/tutorials/beginner/transformer_tutorial.html
    """
    def __init__(self, d_model, vocab_size=5000, dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        # vocab_size acts as the maximum sequence length here
        pe = torch.zeros(vocab_size, d_model)
        position = torch.arange(0, vocab_size, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float()
            * (-math.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer("pe", pe)

    def forward(self, x):
        x = x + self.pe[:, : x.size(1), :]
        return self.dropout(x)


class Attention(nn.Module):
    ### Implements Scaled Dot-Product Attention

    def __init__(self):
        super().__init__()

    def forward(self, Q, K, V, mask=None, dropout=None):
        # All shapes: (batch_size, seq_len, hidden_size)
        # Perform Q*K^T (batched matrix multiplication)
        # We have to use torch.matmul since we work with batches!
        out = torch.matmul(Q, K.transpose(1, 2))  # => shape: (B, L, L)
        # Divide by scaling factor
        out = out / (Q.shape[-1] ** 0.5)
        # Optional: src_mask/tgt_mask (shape: (S, S); masked positions are represented by -inf)
        if mask is not None:
            out += mask  # Broadcast since it's the same mask for all samples in batch
        # Push through softmax layer
        out = f.softmax(out, dim=-1)
        # Optional: dropout (here dropout is the drop probability)
        if dropout is not None:
            out = f.dropout(out, p=dropout)
        # Multiply with values V
        out = torch.matmul(out, V)
        return out


class AttentionHead(nn.Module):

    def __init__(self, model_size, qkv_size):
        super().__init__()
        self.Wq = nn.Linear(model_size, qkv_size)
        self.Wk = nn.Linear(model_size, qkv_size)
        self.Wv = nn.Linear(model_size, qkv_size)
        self.attention = Attention()
        self._init_parameters()

    def _init_parameters(self):
        nn.init.xavier_uniform_(self.Wq.weight)
        nn.init.xavier_uniform_(self.Wk.weight)
        nn.init.xavier_uniform_(self.Wv.weight)

    def forward(self, query, key, value):
        return self.attention(self.Wq(query), self.Wk(key), self.Wv(value))


class MultiHeadAttention(nn.Module):

    def __init__(self, num_heads, model_size, qkv_size):
        super().__init__()
        self.heads = nn.ModuleList(
            [AttentionHead(model_size, qkv_size) for _ in range(num_heads)]
        )
        # Linear layer to "unify" all heads into one
        self.Wo = nn.Linear(num_heads * qkv_size, model_size)
        self._init_parameters()

    def _init_parameters(self):
        nn.init.xavier_uniform_(self.Wo.weight)

    def forward(self, query, key, value):
        out_heads = [attention_head(query, key, value) for attention_head in self.heads]
        out = torch.cat(out_heads, dim=-1)
        return self.Wo(out)


class FeedForward(nn.Module):

    def __init__(self, model_size, hidden_size=2048):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(model_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, model_size),
        )

    def forward(self, X):
        return self.net(X)


class TransformerEncoderLayer(nn.Module):

    def __init__(self, model_size, num_heads, ff_hidden_size, dropout):
        super().__init__()
        qkv_size = max(model_size // num_heads, 1)
        # MultiHeadAttention block
        self.mha1 = MultiHeadAttention(num_heads, model_size, qkv_size)
        self.dropout1 = nn.Dropout(dropout)
        self.norm1 = nn.LayerNorm(model_size)
        # FeedForward block
        self.ff = FeedForward(model_size, ff_hidden_size)
        self.dropout2 = nn.Dropout(dropout)
        self.norm2 = nn.LayerNorm(model_size)

    def forward(self, source):
        # MultiHeadAttention block
        out1 = self.mha1(source, source, source)
        out1 = self.dropout1(out1)
        out1 = self.norm1(out1 + source)
        # FeedForward block
        out2 = self.ff(out1)
        out2 = self.dropout2(out2)
        out2 = self.norm2(out2 + out1)
        # Return final output
        return out2


class TransformerEncoder(nn.Module):

    def __init__(self, num_layers=6, model_size=512, num_heads=8, ff_hidden_size=2048, dropout=0.1):
        super().__init__()
        self.layers = nn.ModuleList(
            [TransformerEncoderLayer(model_size, num_heads, ff_hidden_size, dropout) for _ in range(num_layers)]
        )

    def forward(self, source):
        for layer in self.layers:
            source = layer(source)
        return source


##
## Decoder
##

class TransformerDecoderLayer(nn.Module):

    def __init__(self, model_size, num_heads, ff_hidden_size, dropout):
        super().__init__()
        qkv_size = max(model_size // num_heads, 1)
        # 1st MultiHeadAttention block (decoder input only)
        self.mha1 = MultiHeadAttention(num_heads, model_size, qkv_size)
        self.dropout1 = nn.Dropout(dropout)
        self.norm1 = nn.LayerNorm(model_size)
        # 2nd MultiHeadAttention block (encoder & decoder)
        self.mha2 = MultiHeadAttention(num_heads, model_size, qkv_size)
        self.dropout2 = nn.Dropout(dropout)
        self.norm2 = nn.LayerNorm(model_size)
        # FeedForward block
        self.ff = FeedForward(model_size, ff_hidden_size)
        self.dropout3 = nn.Dropout(dropout)
        self.norm3 = nn.LayerNorm(model_size)

    def forward(self, target, memory):
        # 1st MultiHeadAttention block
        out1 = self.mha1(target, target, target)
        out1 = self.dropout1(out1)
        out1 = self.norm1(out1 + target)
        # 2nd MultiHeadAttention block
        out2 = self.mha2(out1, memory, memory)
        out2 = self.dropout2(out2)
        out2 = self.norm2(out2 + out1)
        # FeedForward block
        out3 = self.ff(out2)
        out3 = self.dropout3(out3)
        out3 = self.norm3(out3 + out2)
        # Return final output
        return out3


class TransformerDecoder(nn.Module):

    def __init__(self, num_layers=6, model_size=512, num_heads=8, ff_hidden_size=2048, dropout=0.1):
        super().__init__()
        self.layers = nn.ModuleList(
            [TransformerDecoderLayer(model_size, num_heads, ff_hidden_size, dropout) for _ in range(num_layers)]
        )

    def forward(self, target, memory):
        for layer in self.layers:
            target = layer(target, memory)
        return target


class Transformer(nn.Module):

    def __init__(self, num_encoder_layers=6, num_decoder_layers=6, model_size=512, num_heads=8, ff_hidden_size=2048, dropout=0.1):
        super().__init__()
        self.encoder = TransformerEncoder(
            num_layers=num_encoder_layers,
            model_size=model_size,
            num_heads=num_heads,
            ff_hidden_size=ff_hidden_size,
            dropout=dropout
        )
        self.decoder = TransformerDecoder(
            num_layers=num_decoder_layers,
            model_size=model_size,
            num_heads=num_heads,
            ff_hidden_size=ff_hidden_size,
            dropout=dropout
        )

    def forward(self, source, target):
        memory = self.encoder(source)
        return self.decoder(target, memory)
```
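In case it helps for testing, a quick sanity check along these lines should work (the parameter values are just made up; note that embeddings, the `PositionalEncoding` module, the final output projection, and any masks are not wired into the `Transformer` class, so the inputs here are simply random "already embedded" tensors):

```
# Assumes the code above has already been run/imported
model = Transformer(num_encoder_layers=2, num_decoder_layers=2, model_size=128,
                    num_heads=4, ff_hidden_size=256, dropout=0.1)

batch_size, src_len, tgt_len, model_size = 8, 10, 12, 128
source = torch.rand(batch_size, src_len, model_size)  # stand-in for embedded + pos-encoded source
target = torch.rand(batch_size, tgt_len, model_size)  # stand-in for embedded + pos-encoded target

out = model(source, target)
print(out.shape)  # torch.Size([8, 12, 128])
```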