If it’s of any use, please find my implementation of a Vanilla Transformer below. Please note, however, that the focus was on ease of understanding and ease of implementation, since I want to cover it in my classes. In particular, note that each AttentionHead has its own linear layers Wq, Wk, and Wv. Most “from scratch” implementations of a Vanilla Transformer I came across have only one set of linear layers in MultiHeadAttention and use reshaping/slicing to handle the different heads (for contrast, a sketch of that variant is included below, right after MultiHeadAttention). Also, no guarantee of correctness.
import math
import torch
import torch.nn as nn
import torch.nn.functional as f

class PositionalEncoding(nn.Module):
    """
    https://pytorch.org/tutorials/beginner/transformer_tutorial.html
    """

    def __init__(self, d_model, vocab_size=5000, dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(vocab_size, d_model)
        position = torch.arange(0, vocab_size, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float()
            * (-math.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer("pe", pe)

    def forward(self, x):
        x = x + self.pe[:, : x.size(1), :]
        return self.dropout(x)
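
# Just to illustrate the expected shapes (this helper is only for demonstration
# and is not used by the model below): the module keeps the
# (batch_size, seq_len, d_model) shape unchanged. Note that vocab_size
# effectively acts as the maximum supported sequence length; the parameter name
# is taken over from the PyTorch tutorial linked above.
def _positional_encoding_demo():
    pos_enc = PositionalEncoding(d_model=512)
    x = torch.randn(2, 10, 512)  # (batch_size, seq_len, d_model)
    return pos_enc(x)            # => shape: (2, 10, 512)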

class Attention(nn.Module):
    """Implements Scaled Dot-Product Attention."""

    def __init__(self):
        super().__init__()

    def forward(self, Q, K, V, mask=None, dropout=None):
        # All shapes: (batch_size, seq_len, hidden_size)
        # Compute Q*K^T ("*" being the dot product here);
        # we have to use torch.matmul since we work with batches!
        out = torch.matmul(Q, K.transpose(1, 2))  # => shape: (B, L, L)
        # Divide by scaling factor sqrt(d_k)
        out = out / (Q.shape[-1] ** 0.5)
        # Optional: src_mask/tgt_mask (shape: (S, S); masked positions are -inf)
        if mask is not None:
            out += mask  # Broadcast since it's the same mask for all samples in the batch
        # Push through softmax layer
        out = f.softmax(out, dim=-1)
        # Optional: dropout on the attention weights (dropout is a probability here)
        if dropout is not None:
            out = f.dropout(out, p=dropout, training=self.training)
        # Multiply with values V
        out = torch.matmul(out, V)
        return out
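
# Just to illustrate how Attention can be called with an additive mask (this
# helper is only a demonstration and is not used by the model below): masked
# positions are set to -inf before the softmax, e.g. for causal masking.
def _attention_demo():
    attention = Attention()
    Q = K = V = torch.randn(2, 5, 64)  # (batch_size, seq_len, qkv_size)
    # -inf above the diagonal => no attention to future positions
    causal_mask = torch.triu(torch.full((5, 5), float("-inf")), diagonal=1)
    return attention(Q, K, V, mask=causal_mask)  # => shape: (2, 5, 64)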

class AttentionHead(nn.Module):
    def __init__(self, model_size, qkv_size):
        super().__init__()
        self.Wq = nn.Linear(model_size, qkv_size)
        self.Wk = nn.Linear(model_size, qkv_size)
        self.Wv = nn.Linear(model_size, qkv_size)
        self.attention = Attention()
        self._init_parameters()

    def _init_parameters(self):
        nn.init.xavier_uniform_(self.Wq.weight)
        nn.init.xavier_uniform_(self.Wk.weight)
        nn.init.xavier_uniform_(self.Wv.weight)

    def forward(self, query, key, value):
        return self.attention(self.Wq(query), self.Wk(key), self.Wv(value))

class MultiHeadAttention(nn.Module):
    def __init__(self, num_heads, model_size, qkv_size):
        super().__init__()
        self.heads = nn.ModuleList(
            [AttentionHead(model_size, qkv_size) for _ in range(num_heads)]
        )
        # Linear layer to "unify" all heads into one
        self.Wo = nn.Linear(num_heads * qkv_size, model_size)
        self._init_parameters()

    def _init_parameters(self):
        nn.init.xavier_uniform_(self.Wo.weight)

    def forward(self, query, key, value):
        out_heads = [attention_head(query, key, value) for attention_head in self.heads]
        out = torch.cat(out_heads, dim=-1)
        return self.Wo(out)
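
# For comparison only (not used by the model below): a minimal sketch of the
# "single set of linear layers" variant mentioned at the top, where the
# projections are done once and the result is reshaped into the individual
# heads. The class name and code are just for illustration.
class FusedMultiHeadAttention(nn.Module):
    def __init__(self, num_heads, model_size):
        super().__init__()
        assert model_size % num_heads == 0
        self.num_heads = num_heads
        self.head_size = model_size // num_heads
        self.Wq = nn.Linear(model_size, model_size)
        self.Wk = nn.Linear(model_size, model_size)
        self.Wv = nn.Linear(model_size, model_size)
        self.Wo = nn.Linear(model_size, model_size)

    def _split_heads(self, X):
        # (batch_size, seq_len, model_size) => (batch_size, num_heads, seq_len, head_size)
        batch_size, seq_len, _ = X.shape
        return X.view(batch_size, seq_len, self.num_heads, self.head_size).transpose(1, 2)

    def forward(self, query, key, value):
        Q = self._split_heads(self.Wq(query))
        K = self._split_heads(self.Wk(key))
        V = self._split_heads(self.Wv(value))
        # Scaled dot-product attention for all heads at once
        scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.head_size ** 0.5)
        out = torch.matmul(f.softmax(scores, dim=-1), V)
        # Concatenate the heads again: (batch_size, seq_len, model_size)
        batch_size, _, seq_len, _ = out.shape
        out = out.transpose(1, 2).contiguous().view(batch_size, seq_len, -1)
        return self.Wo(out)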

class FeedForward(nn.Module):
    def __init__(self, model_size, hidden_size=2048):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(model_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, model_size),
        )

    def forward(self, X):
        return self.net(X)

##
## Encoder
##
class TransformerEncoderLayer(nn.Module):
    def __init__(self, model_size, num_heads, ff_hidden_size, dropout):
        super().__init__()
        qkv_size = max(model_size // num_heads, 1)
        # MultiHeadAttention block
        self.mha1 = MultiHeadAttention(num_heads, model_size, qkv_size)
        self.dropout1 = nn.Dropout(dropout)
        self.norm1 = nn.LayerNorm(model_size)
        # FeedForward block
        self.ff = FeedForward(model_size, ff_hidden_size)
        self.dropout2 = nn.Dropout(dropout)
        self.norm2 = nn.LayerNorm(model_size)

    def forward(self, source):
        # MultiHeadAttention block
        out1 = self.mha1(source, source, source)
        out1 = self.dropout1(out1)
        out1 = self.norm1(out1 + source)
        # FeedForward block
        out2 = self.ff(out1)
        out2 = self.dropout2(out2)
        out2 = self.norm2(out2 + out1)
        # Return final output
        return out2

class TransformerEncoder(nn.Module):
    def __init__(self, num_layers=6, model_size=512, num_heads=8, ff_hidden_size=2048, dropout=0.1):
        super().__init__()
        self.layers = nn.ModuleList(
            [TransformerEncoderLayer(model_size, num_heads, ff_hidden_size, dropout) for _ in range(num_layers)]
        )

    def forward(self, source):
        for layer in self.layers:
            source = layer(source)
        return source
##
## Decoder
##
class TransformerDecoderLayer(nn.Module):
    def __init__(self, model_size, num_heads, ff_hidden_size, dropout):
        super().__init__()
        qkv_size = max(model_size // num_heads, 1)
        # 1st MultiHeadAttention block (decoder input only)
        self.mha1 = MultiHeadAttention(num_heads, model_size, qkv_size)
        self.dropout1 = nn.Dropout(dropout)
        self.norm1 = nn.LayerNorm(model_size)
        # 2nd MultiHeadAttention block (encoder & decoder)
        self.mha2 = MultiHeadAttention(num_heads, model_size, qkv_size)
        self.dropout2 = nn.Dropout(dropout)
        self.norm2 = nn.LayerNorm(model_size)
        # FeedForward block
        self.ff = FeedForward(model_size, ff_hidden_size)
        self.dropout3 = nn.Dropout(dropout)
        self.norm3 = nn.LayerNorm(model_size)

    def forward(self, target, memory):
        # 1st MultiHeadAttention block
        out1 = self.mha1(target, target, target)
        out1 = self.dropout1(out1)
        out1 = self.norm1(out1 + target)
        # 2nd MultiHeadAttention block
        out2 = self.mha2(out1, memory, memory)
        out2 = self.dropout2(out2)
        out2 = self.norm2(out2 + out1)
        # FeedForward block
        out3 = self.ff(out2)
        out3 = self.dropout3(out3)
        out3 = self.norm3(out3 + out2)
        # Return final output
        return out3

class TransformerDecoder(nn.Module):
    def __init__(self, num_layers=6, model_size=512, num_heads=8, ff_hidden_size=2048, dropout=0.1):
        super().__init__()
        self.layers = nn.ModuleList(
            [TransformerDecoderLayer(model_size, num_heads, ff_hidden_size, dropout) for _ in range(num_layers)]
        )

    def forward(self, target, memory):
        for layer in self.layers:
            target = layer(target, memory)
        return target

class Transformer(nn.Module):
    def __init__(self, num_encoder_layers=6, num_decoder_layers=6, model_size=512, num_heads=8, ff_hidden_size=2048, dropout=0.1):
        super().__init__()
        self.encoder = TransformerEncoder(
            num_layers=num_encoder_layers,
            model_size=model_size,
            num_heads=num_heads,
            ff_hidden_size=ff_hidden_size,
            dropout=dropout,
        )
        self.decoder = TransformerDecoder(
            num_layers=num_decoder_layers,
            model_size=model_size,
            num_heads=num_heads,
            ff_hidden_size=ff_hidden_size,
            dropout=dropout,
        )

    def forward(self, source, target):
        memory = self.encoder(source)
        return self.decoder(target, memory)
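
To give an idea of how the pieces fit together, here is a minimal usage sketch with random dummy inputs (just a sketch; in a real setup you would of course feed embedded and positionally encoded token sequences instead of random tensors):

model = Transformer()
pos_enc = PositionalEncoding(d_model=512)

# Dummy "embeddings": batch of 2, source length 10, target length 7
source = pos_enc(torch.randn(2, 10, 512))
target = pos_enc(torch.randn(2, 7, 512))

# Note: the mask argument of Attention is not wired through MultiHeadAttention,
# so no causal masking is applied in the decoder here.
out = model(source, target)
print(out.shape)  # => torch.Size([2, 7, 512])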