Attention mask shape error - shape should be (1,1)

I’m implementing a Transformers architecture from the ground up on 1 dummy sentence.

Here is the code

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

class PositionalEncoding(nn.Module):

    def __init__(self, context_size, d_model):
        super().__init__()

        self.encoding = torch.zeros(context_size, d_model)

        pos = torch.arange(0, context_size).unsqueeze(dim=1)
        dim = torch.arange(
            0, d_model, 2)  # dim is i in the positional encoding formula
        self.encoding[:, 0::2] = torch.sin(pos / (10000**(2 * dim / d_model)))
        self.encoding[:, 1::2] = torch.cos(pos / (10000**(2 * dim / d_model)))

    def forward(self, x):
        seq_len = x.size(1)
        return self.encoding[:seq_len, :]

class PositionwiseFeedForward(nn.Module):

    def __init__(self, d_model, d_ff):
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.linear1(x)
        x = self.relu(x)
        x = self.linear2(x)
        return x

class EncoderBlock(nn.Module):
    def __init__(self, d_model, num_heads, d_ff):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, num_heads)
        self.feed_forward = PositionwiseFeedForward(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x):
        hidden_states, _ = self.self_attn(query=x, key=x, value=x)
        x = self.norm1(x + hidden_states)
        ff_output = self.feed_forward(x)
        x = self.norm2(x + ff_output)
        return x

class Encoder(nn.Module):
    # input_size - # rows in token embedding
    # context size - # rows in positional embedding
    # d_ff - internal dimension of the FF network
    # num encoder blocks
    def __init__(self, input_size, context_size, d_model, d_ff, num_heads,
                 n_blocks):
        super().__init__()

        self.embedding = nn.Embedding(input_size, d_model)
        self.pos_embedding = PositionalEncoding(context_size, d_model)

        self.blocks = nn.ModuleList([
            EncoderBlock(
                d_model=d_model,
                num_heads=num_heads,
                d_ff=d_ff,
            ) for _ in range(n_blocks)
        ])

    def forward(self, x):
        x = self.embedding(x) + self.pos_embedding(x)
        for block in self.blocks:
            x = block(x)
        return x

class DecoderBlock(nn.Module):
    def __init__(self, d_model, num_heads, d_ff):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, num_heads)
        self.cross_attn = nn.MultiheadAttention(d_model, num_heads)
        self.feed_forward = PositionwiseFeedForward(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)

    def forward(self, x, enc_output):
        lookahead_mask = torch.triu(torch.ones(x.shape[1], x.shape[1])).bool().transpose(0,1)  # lookahead mask shape should be context size (1st dim)
        hidden_states, _ = self.self_attn(x, x, x, attn_mask = lookahead_mask)
        x = self.norm1(x + hidden_states)
        hidden_states, _ = self.cross_attn(
                             query=x, key=enc_output, value=enc_output)
        x = self.norm2(x + hidden_states)
        ff_output = self.feed_forward(x)
        x = self.norm3(x + ff_output)
        return x

class Decoder(nn.Module):
    def __init__(self, output_size, context_size,
                 d_model, d_ff, num_heads, n_blocks):
        super().__init__()
        self.embedding = nn.Embedding(output_size, d_model)
        self.pos_embedding = PositionalEncoding(context_size, d_model)

        self.blocks = nn.ModuleList([
            DecoderBlock(
                d_model=d_model,
                num_heads=num_heads,
                d_ff=d_ff,
            )
            for _ in range(n_blocks)
        ])

        self.out = nn.Linear(d_model, output_size)

    def forward(self, x, enc_output):
        x = self.embedding(x) + self.pos_embedding(x)

        for block in self.blocks:
            x = block(x, enc_output)

        output = self.out(x)
        return output

class Transformer(nn.Module):

    def __init__(self, vocab_size, context_size,
                 d_model, d_ff, num_heads, n_blocks):
        super().__init__()

        self.encoder = Encoder(
            vocab_size,
            context_size,
            d_model,
            d_ff,
            num_heads,
            n_blocks
        )

        self.decoder = Decoder(
            vocab_size,
            context_size,
            d_model,
            d_ff,
            num_heads,
            n_blocks
        )

    def forward(self, input_encoder, input_decoder):
        enc_output = self.encoder(input_encoder)   # (64, 100, 10)
        output = self.decoder(input_decoder, enc_output)  # input_decoder shape - (64, 99)
        return output

SOS_token = 0
EOS_token = 1
PAD_token = 2   # Need to have padding so that the input & output sentences
                # are the same length - required for the cross-attention computation

index2words = {
    SOS_token: 'SOS',
    EOS_token: 'EOS',
    PAD_token: 'PAD'
}

words = "How are you doing ? I am good and you ?"
words_list = set(words.lower().split(' '))
for word in words_list:
    index2words[len(index2words)] = word

words2index = {w: i for i, w in index2words.items()}

def convert2tensors(sentence, max_len):
    words_list = sentence.lower().split(' ')
    padding = ['PAD'] * (max_len - len(words_list))
    words_list.extend(padding)
    indexes = [words2index[word] for word in words_list]
    return torch.tensor(indexes, dtype=torch.long).view(1, -1)

D_MODEL = 10
VOCAB_SIZE = len(words2index)
N_BLOCKS = 10
D_FF = 20
CONTEXT_SIZE = 100
NUM_HEADS = 2

transformer = Transformer(
    vocab_size=VOCAB_SIZE,
    context_size=CONTEXT_SIZE,
    d_model=D_MODEL,
    d_ff=D_FF,  # internal dimension of the feed forward network
    num_heads=NUM_HEADS,
    n_blocks=N_BLOCKS
)

input_sentence = "How are you doing ?"
output_sentence = "I am good and"

input_encoder = convert2tensors(input_sentence, CONTEXT_SIZE)
input_decoder = convert2tensors(output_sentence, CONTEXT_SIZE)

output_toy = transformer(input_encoder, input_decoder)

I’m adding a lookahead mask in the DecoderBlock and I get an error on this line of the forward method.
hidden_states, _ = self.self_attn(x, x, x, attn_mask = lookahead_mask)

Error:
RuntimeError: The shape of the 2D attn_mask is torch.Size([100, 100]), but should be (1, 1).

Why should the shape be (1,1)? x has shape (1,100,10) - batch, context size, d_model.
lookahead_mask has shape (100,100).

Also going through this, realized not sure whether it 's ok to apply the attention mask after summing the embeddings with the positional embeddings.

Hello. By any chance, have you found a solution to this ? I’m facing a similar problem, but with tgt_mask.