Embed_dim must be divisible by num_heads

I get an error: AssertionError: embed_dim must be divisible by num_heads
(assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads")
I can’t understand how to fix it.

import torch
import torch.nn as nn

class Trans(nn.Module):
    def __init__(self, d_model=10):
        super().__init__()
        self.tr = nn.Transformer(d_model)
        
    def forward(self, src, tgt):
        out = self.tr(src, tgt)
        
        return out #output: (T, N, E)
    
net = Trans() 

sr = torch.randn(5, 1, 10) #src: (S, N, E)
tg = torch.randn(1, 1, 10) #tgt: (T, N, E)
outputs = net(sr, tg) 

print(outputs)

The nn.Transformer module uses 8 attention heads by default. Since the nn.MultiheadAttention implementation slices the model dimension into num_heads blocks (simply via a view operation), the model dimension must be divisible by the number of heads. Please also see the documentation of nn.MultiheadAttention.
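For example (a minimal sketch based on the code in the question), either picking a d_model that is a multiple of the default 8 heads or passing a compatible nhead avoids the assertion:

import torch
import torch.nn as nn

# d_model=10 is not divisible by the default nhead=8, which triggers the assertion.
net_a = nn.Transformer(d_model=16)           # 16 % 8 == 0, keeps the default nhead=8
net_b = nn.Transformer(d_model=10, nhead=5)  # 10 % 5 == 0

src = torch.randn(5, 1, 10)  # (S, N, E)
tgt = torch.randn(1, 1, 10)  # (T, N, E)
print(net_b(src, tgt).shape)  # torch.Size([1, 1, 10])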


I don’t know why they implemented it like this; it’s a bit annoying. If, instead of taking embed_dim as an input, they asked you for head_dim and computed embed_dim as:

self.embed_dim = self.head_dim * num_heads

it would be much easier to understand, because you could reason per head when defining the shapes, and it would also guarantee you never get this error. The easiest way to deal with it is just to do the math yourself outside:

nhead = 5
head_dim = 32
d_model = nhead * head_dim

net = nn.Transformer(d_model=d_model, nhead=nhead)

I am also confused by the PyTorch implementation.

In the multi_head_attention_forward method, the Q, K, and V matrices are first computed with respect to the shapes of the source and target tensors, (S, N, E) and (T, N, E), E being the size of each embedding.
But then they are reshaped by splitting the embedding dimension, as @cgarciae just said.
So why? Isn't the goal of MHA to project these matrices with each head's parameters, then concatenate the heads and project again with W_0?
The formula is even in the documentation, but it doesn't seem to be applied as written.
This is very unclear; is there something I totally missed?
EDIT: I found what I missed. I was confused by the formula itself, which doesn't distinguish the Q/K/V slices for each head.
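In case it helps anyone else: projecting with one big (E, E) weight and then splitting the result into h chunks of size E/h gives the same numbers as applying h separate per-head matrices W_i^Q of shape (E, E/h), because each chunk only depends on its own block of columns. A small sketch (the names are illustrative; the actual PyTorch code also folds the batch dimension into the reshape, but the idea is the same):

import torch

E, h = 16, 4             # embed_dim and num_heads
head_dim = E // h
x = torch.randn(3, E)    # 3 tokens with embedding size E
W_q = torch.randn(E, E)  # one big query projection, as in a combined in_proj

# Variant 1: project once, then reshape-and-split into heads.
q_split = (x @ W_q).view(3, h, head_dim)

# Variant 2: per-head projections W_i^Q, obtained by slicing the columns of W_q.
q_per_head = torch.stack(
    [x @ W_q[:, i * head_dim:(i + 1) * head_dim] for i in range(h)], dim=1)

print(torch.allclose(q_split, q_per_head))  # True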


Hi @Nat
From my understanding, PyTorch forces the embedding size to be consistent throughout the computation. Hence embed_dim must be divisible by num_heads, so that later on, when you "concatenate" all heads, the resulting matrix size is again embed_dim.

The use of W0 shown in the documentation above is not for reshaping the concatenation of heads back to embed_dim.
Here is the evidence.
Notice these code lines:
attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
attn_output = linear(attn_output, out_proj_weight, out_proj_bias)

attn_output is already reshaped back to embed_dim (there is no need for W0, out_proj_weight, to help with reshaping).
In the second line, out_proj_weight (W0) is initialized with dimension (embed_dim, embed_dim).
So you can see that W0 is not used for reshaping the concatenated-heads matrix.

To summarize: with this implementation, PyTorch forces the matrix of concatenated attention heads to have dimension embed_dim. Hence, the role W0 plays in the Attention Is All You Need paper, projecting the concatenated heads back to embed_dim, involves no change of dimension here. Adding more evidence to this statement: as shown above, out_proj_weight (W0) is initialized with size (embed_dim, embed_dim), so W0 is not used to reshape the "concatenated" attention heads matrix.
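A quick way to check the shapes (just a sanity-check sketch using nn.MultiheadAttention directly):

import torch.nn as nn

mha = nn.MultiheadAttention(embed_dim=16, num_heads=4)

# The concatenated heads already have size embed_dim, so W0 (out_proj) is a
# square (embed_dim, embed_dim) matrix: it mixes information across heads
# rather than changing the dimension.
print(mha.out_proj.weight.shape)  # torch.Size([16, 16])
print(mha.in_proj_weight.shape)   # torch.Size([48, 16]), stacked Q/K/V projections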


Yes, this is quite confusing. Why is it required to have embed_dim = head_dim * num_heads?

I know it’s pretty late, but I also came across the same issue. After some searching, my guess is that PyTorch just didn’t implement the original version, which would have linearly transformed the input into multiple heads instead of reshaping and splitting. A nice snippet showing what the original version should look like can be found here. I’m not sure whether this would negatively impact performance, though, or whether one needs to manually append linear layers before and after.
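For reference, here is a minimal sketch of what such a per-head version could look like (illustrative only; this is not the linked snippet and not the PyTorch implementation): each head gets its own head_dim-sized projection, the concatenation has size num_heads * head_dim, and W0 maps it back to embed_dim, so no divisibility constraint is needed.

import torch
import torch.nn as nn
import torch.nn.functional as F

class PerHeadAttention(nn.Module):
    # Sketch of the paper's formulation: every head has its own head_dim-sized
    # projection, so embed_dim does not have to be divisible by num_heads.
    def __init__(self, embed_dim, num_heads, head_dim):
        super().__init__()
        self.num_heads, self.head_dim = num_heads, head_dim
        self.q_proj = nn.Linear(embed_dim, num_heads * head_dim)
        self.k_proj = nn.Linear(embed_dim, num_heads * head_dim)
        self.v_proj = nn.Linear(embed_dim, num_heads * head_dim)
        self.out_proj = nn.Linear(num_heads * head_dim, embed_dim)  # W0

    def forward(self, x):  # x: (N, L, embed_dim)
        N, L, _ = x.shape
        def split(t):      # (N, L, h*d) -> (N, h, L, d)
            return t.view(N, L, self.num_heads, self.head_dim).transpose(1, 2)
        q, k, v = split(self.q_proj(x)), split(self.k_proj(x)), split(self.v_proj(x))
        attn = F.softmax(q @ k.transpose(-2, -1) / self.head_dim ** 0.5, dim=-1)
        out = (attn @ v).transpose(1, 2).reshape(N, L, -1)  # concatenate heads
        return self.out_proj(out)  # project back to embed_dim

x = torch.randn(2, 5, 10)  # embed_dim=10 is not divisible by num_heads=4
print(PerHeadAttention(10, num_heads=4, head_dim=8)(x).shape)  # torch.Size([2, 5, 10])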


I do not understand why they enforce that the embedding dimension of the input equals the embedding dimension after the linear transform.