Attention mask shape error - shape should be (1,1)

I’m implementing a Transformers architecture from the ground up on 1 dummy sentence.

Here is the code

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

class PositionalEncoding(nn.Module):

    def __init__(self, context_size, d_model):
        super().__init__()

        self.encoding = torch.zeros(context_size, d_model)

        pos = torch.arange(0, context_size).unsqueeze(dim=1)
        dim = torch.arange(
            0, d_model, 2)  # dim is i in the positional encoding formula
        self.encoding[:, 0::2] = torch.sin(pos / (10000**(2 * dim / d_model)))
        self.encoding[:, 1::2] = torch.cos(pos / (10000**(2 * dim / d_model)))

    def forward(self, x):
        seq_len = x.size(1)
        return self.encoding[:seq_len, :]

class PositionwiseFeedForward(nn.Module):

    def __init__(self, d_model, d_ff):
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.linear1(x)
        x = self.relu(x)
        x = self.linear2(x)
        return x

class EncoderBlock(nn.Module):
    def __init__(self, d_model, num_heads, d_ff):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, num_heads)
        self.feed_forward = PositionwiseFeedForward(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(self, x):
        hidden_states, _ = self.self_attn(query=x, key=x, value=x)
        x = self.norm1(x + hidden_states)
        ff_output = self.feed_forward(x)
        x = self.norm2(x + ff_output)
        return x

class Encoder(nn.Module):
    # input_size - # rows in token embedding
    # context size - # rows in positional embedding
    # d_ff - internal dimension of the FF network
    # num encoder blocks
    def __init__(self, input_size, context_size, d_model, d_ff, num_heads,
                 n_blocks):
        super().__init__()

        self.embedding = nn.Embedding(input_size, d_model)
        self.pos_embedding = PositionalEncoding(context_size, d_model)

        self.blocks = nn.ModuleList([
            EncoderBlock(
                d_model=d_model,
                num_heads=num_heads,
                d_ff=d_ff,
            ) for _ in range(n_blocks)
        ])

    def forward(self, x):
        x = self.embedding(x) + self.pos_embedding(x)
        for block in self.blocks:
            x = block(x)
        return x

class DecoderBlock(nn.Module):
    def __init__(self, d_model, num_heads, d_ff):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, num_heads)
        self.cross_attn = nn.MultiheadAttention(d_model, num_heads)
        self.feed_forward = PositionwiseFeedForward(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)

    def forward(self, x, enc_output):
        lookahead_mask = torch.triu(torch.ones(x.shape[1], x.shape[1])).bool().transpose(0,1)  # lookahead mask shape should be context size (1st dim)
        hidden_states, _ = self.self_attn(x, x, x, attn_mask = lookahead_mask)
        x = self.norm1(x + hidden_states)
        hidden_states, _ = self.cross_attn(
                             query=x, key=enc_output, value=enc_output)
        x = self.norm2(x + hidden_states)
        ff_output = self.feed_forward(x)
        x = self.norm3(x + ff_output)
        return x

class Decoder(nn.Module):
    def __init__(self, output_size, context_size,
                 d_model, d_ff, num_heads, n_blocks):
        super().__init__()
        self.embedding = nn.Embedding(output_size, d_model)
        self.pos_embedding = PositionalEncoding(context_size, d_model)

        self.blocks = nn.ModuleList([
            DecoderBlock(
                d_model=d_model,
                num_heads=num_heads,
                d_ff=d_ff,
            )
            for _ in range(n_blocks)
        ])

        self.out = nn.Linear(d_model, output_size)

    def forward(self, x, enc_output):
        x = self.embedding(x) + self.pos_embedding(x)

        for block in self.blocks:
            x = block(x, enc_output)

        output = self.out(x)
        return output

class Transformer(nn.Module):

    def __init__(self, vocab_size, context_size,
                 d_model, d_ff, num_heads, n_blocks):
        super().__init__()

        self.encoder = Encoder(
            vocab_size,
            context_size,
            d_model,
            d_ff,
            num_heads,
            n_blocks
        )

        self.decoder = Decoder(
            vocab_size,
            context_size,
            d_model,
            d_ff,
            num_heads,
            n_blocks
        )

    def forward(self, input_encoder, input_decoder):
        enc_output = self.encoder(input_encoder)   # (64, 100, 10)
        output = self.decoder(input_decoder, enc_output)  # input_decoder shape - (64, 99)
        return output

SOS_token = 0
EOS_token = 1
PAD_token = 2   # Need to have padding so that the input & output sentences
                # are the same length - required for the cross-attention computation

index2words = {
    SOS_token: 'SOS',
    EOS_token: 'EOS',
    PAD_token: 'PAD'
}

words = "How are you doing ? I am good and you ?"
words_list = set(words.lower().split(' '))
for word in words_list:
    index2words[len(index2words)] = word

words2index = {w: i for i, w in index2words.items()}

def convert2tensors(sentence, max_len):
    words_list = sentence.lower().split(' ')
    padding = ['PAD'] * (max_len - len(words_list))
    words_list.extend(padding)
    indexes = [words2index[word] for word in words_list]
    return torch.tensor(indexes, dtype=torch.long).view(1, -1)

D_MODEL = 10
VOCAB_SIZE = len(words2index)
N_BLOCKS = 10
D_FF = 20
CONTEXT_SIZE = 100
NUM_HEADS = 2

transformer = Transformer(
    vocab_size=VOCAB_SIZE,
    context_size=CONTEXT_SIZE,
    d_model=D_MODEL,
    d_ff=D_FF,  # internal dimension of the feed forward network
    num_heads=NUM_HEADS,
    n_blocks=N_BLOCKS
)

input_sentence = "How are you doing ?"
output_sentence = "I am good and"

input_encoder = convert2tensors(input_sentence, CONTEXT_SIZE)
input_decoder = convert2tensors(output_sentence, CONTEXT_SIZE)

output_toy = transformer(input_encoder, input_decoder)

I’m adding a lookahead mask in the DecoderBlock and I get an error on this line of the forward method.
hidden_states, _ = self.self_attn(x, x, x, attn_mask = lookahead_mask)

Error:
RuntimeError: The shape of the 2D attn_mask is torch.Size([100, 100]), but should be (1, 1).

Why should the shape be (1,1)? x has shape (1,100,10) - batch, context size, d_model.
lookahead_mask has shape (100,100).

Also going through this, realized not sure whether it 's ok to apply the attention mask after summing the embeddings with the positional embeddings.

1 Like

Hello. By any chance, have you found a solution to this ? I’m facing a similar problem, but with tgt_mask.