For purely educational purposes, my goal is to implement a basic Transformer architecture from scratch. So far I have focused on the encoder for classification tasks and assumed that all samples in a batch have the same length, which means I didn't have to care about any masking.
However, now I want to support masking. I like to think that I understand the purpose of, e.g., the target mask, so the decoder cannot "peek into the future". I generate this mask as follows:
import torch

source_batch = torch.LongTensor([
    [1, 2, 3, 0, 0, 0],
    [1, 2, 3, 4, 5, 6],
    [1, 2, 3, 4, 5, 0]
])
batch_size, seq_len = source_batch.shape

def generate_tgt_mask(size):
    # Upper-triangular matrix: -inf above the diagonal, 0 elsewhere
    return torch.triu(torch.ones(size, size) * float('-inf'), diagonal=1)

print(generate_tgt_mask(seq_len))
yielding:
tensor([[0., -inf, -inf, -inf, -inf, -inf],
        [0., 0., -inf, -inf, -inf, -inf],
        [0., 0., 0., -inf, -inf, -inf],
        [0., 0., 0., 0., -inf, -inf],
        [0., 0., 0., 0., 0., -inf],
        [0., 0., 0., 0., 0., 0.]])
which should be the expected outcome when I check the PyTorch docs. This mask has a shape of (L, L), where L is the sequence length of the source or target sequence. Again, this matches the docs.
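As a quick sanity check (using all-zero scores as a toy stand-in for real attention scores), adding this mask to the raw scores and pushing the result through a softmax gives every position zero weight on all future positions:

scores = torch.zeros(seq_len, seq_len)  # toy raw attention scores
weights = torch.softmax(scores + generate_tgt_mask(seq_len), dim=-1)
print(weights[0])  # tensor([1., 0., 0., 0., 0., 0.])             -> position 0 attends only to itself
print(weights[2])  # tensor([0.3333, 0.3333, 0.3333, 0., 0., 0.]) -> position 2 attends to positions 0..2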
I use this mask in my implementation of Scaled Dot Product Attention as follows:
import torch.nn as nn
import torch.nn.functional as f

class Attention(nn.Module):
    """Implements Scaled Dot Product Attention."""

    def __init__(self):
        super().__init__()

    def forward(self, Q, K, V, mask=None, dropout=None):
        # All shapes: (batch_size, seq_len, hidden_size)
        # Perform Q*K^T (* is the dot product here);
        # we have to use torch.matmul since we work with batches!
        out = torch.matmul(Q, K.transpose(1, 2))  # => shape: (B, L, L)
        # Divide by the scaling factor
        out = out / (Q.shape[-1] ** 0.5)
        # Optional: src_mask/tgt_mask (shape: (L, L); masked values are -inf)
        if mask is not None:
            out += mask.unsqueeze(0)  # broadcast, since it's the same mask for all samples in the batch
        # Push through the softmax layer
        out = f.softmax(out, dim=-1)
        # Optional: dropout (dropout is the drop probability)
        if dropout is not None:
            out = f.dropout(out, p=dropout)
        # Multiply with the values V
        out = torch.matmul(out, V)
        return out
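As a smoke test (hidden_size is arbitrary here), the module runs with the target mask from above and preserves the input shape:

hidden_size = 8  # arbitrary, just for this test

Q = torch.rand(batch_size, seq_len, hidden_size)
K = torch.rand(batch_size, seq_len, hidden_size)
V = torch.rand(batch_size, seq_len, hidden_size)

attention = Attention()
out = attention(Q, K, V, mask=generate_tgt_mask(seq_len))
print(out.shape)  # torch.Size([3, 6, 8])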
So far, so good… at least I like to think. However, my problem is now the mask that addresses the padding (e.g., src_key_padding_mask). From different tutorials using nn.Transformer, this mask can be generated as follows:
pad_token_index = 0
src_key_padding_mask = (source_batch != pad_token_index)
print(src_key_padding_mask)
yielding:
tensor([[ True,  True,  True, False, False, False],
        [ True,  True,  True,  True,  True,  True],
        [ True,  True,  True,  True,  True, False]])
having a shape of (N, L), which again matches the docs.
What I'm now missing is: how do I have to incorporate this matrix into my implementation of Attention?
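My best guess so far, completely untested, is that the (N, L) mask has to be broadcast over the query dimension, i.e., reshaped to (N, 1, L), so that the scores of padded key positions are set to -inf before the softmax. Something along these lines, where key_padding_mask would be a hypothetical extra argument to forward:

# Untested guess, inside Attention.forward after the scaling step:
# key_padding_mask has shape (batch_size, seq_len) with True = real token, False = padding.
if key_padding_mask is not None:
    # (B, L) -> (B, 1, L): broadcasts over the query dimension of the (B, L, L) scores,
    # so every query is prevented from attending to padded key positions.
    out = out.masked_fill(~key_padding_mask.unsqueeze(1), float('-inf'))

Is this the right way to do it, or am I missing something?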