Why is the Transformer model size the same for different nhead values?

I have just started learning the Transformer architecture. I assumed that increasing the number of heads would increase the number of QKV matrices, and therefore the number of learnable parameters.

Why does the following code produce the same output?

import torch
from torch import nn

def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

print(count_parameters(nn.Transformer(nhead=8))) # 44140544
print(count_parameters(nn.Transformer(nhead=1))) # 44140544

The model size is actually the combined size of the QKV matrices: each head's QKV size is the model size divided by the number of heads. In terms of source code, it looks something like this:

qkv_size = max(model_size // num_heads, 1)
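A quick way to check this is to build a single nn.MultiheadAttention module (the attention block used inside each nn.Transformer layer) with different head counts. This is a sketch: `attention_summary` is a hypothetical helper, and it relies on the default case where query, key and value share `embed_dim`, so the packed `in_proj_weight` attribute exists.

```python
import torch
from torch import nn


def attention_summary(d_model: int, nhead: int):
    """Hypothetical helper: per-head size, packed QKV weight shape, parameter count."""
    mha = nn.MultiheadAttention(embed_dim=d_model, num_heads=nhead)
    n_params = sum(p.numel() for p in mha.parameters())
    # head_dim is derived as d_model // nhead; the packed Q, K, V weight does not depend on nhead
    return mha.head_dim, tuple(mha.in_proj_weight.shape), n_params


print(attention_summary(512, 1))  # (512, (1536, 512), 1050624)
print(attention_summary(512, 8))  # (64, (1536, 512), 1050624)
```

Only the way the projected vectors are split changes with nhead; the weight tensors themselves have identical shapes.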

That the number of parameters does not increase with the number of heads is by design. This also means that you will get an error if model_size cannot be evenly divided by num_heads.
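The divisibility requirement is easy to trigger. In the PyTorch versions I know of, nn.MultiheadAttention checks it with an assertion in its constructor; the sketch below catches both AssertionError and ValueError in case other versions raise differently. `try_build` is a hypothetical helper.

```python
from torch import nn


def try_build(d_model: int, nhead: int) -> str:
    """Hypothetical helper: report whether an attention block with these sizes can be built."""
    try:
        nn.MultiheadAttention(embed_dim=d_model, num_heads=nhead)
        return "ok"
    except (AssertionError, ValueError) as err:
        return f"error: {err}"


print(try_build(512, 8))  # 512 / 8 = 64 per head, builds fine
print(try_build(512, 3))  # 512 is not divisible by 3, construction fails
```

The same check fires for nn.Transformer(nhead=3), since each of its layers builds an nn.MultiheadAttention with the default d_model=512.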

Do I understand correctly that nhead is like a parameter to balance between a single big set of QKV matrices (nhead=1) and multiple smaller ones?

Then how should I create a new Transformer with QKV matrices of the same size as in t1 below, but with 8 times as many heads? (It should be 8 times larger.)

t1 = nn.Transformer(nhead=1)
t2 = nn.Transformer(qkv_size same as in t1 but 8 times more heads)
count_parameters(t1) * 8 == count_parameters(t2)

You can check the original "Attention Is All You Need" paper:

In this work we employ h = 8 parallel attention layers, or heads. For each of these we use $d_k = d_v = d_{model}/h = 64$. Due to the reduced dimension of each head, the total computational cost is similar to that of single-head attention with full dimensionality.
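The same point can be worked out as a parameter count in the paper's notation, assuming per-head projections $W_i^Q, W_i^K \in \mathbb{R}^{d_{model} \times d_k}$, $W_i^V \in \mathbb{R}^{d_{model} \times d_v}$, an output projection $W^O \in \mathbb{R}^{h d_v \times d_{model}}$, and ignoring biases:

```latex
\begin{aligned}
\#\text{weights}
  &= \underbrace{h\,(d_{model}\,d_k + d_{model}\,d_k + d_{model}\,d_v)}_{W_i^Q,\;W_i^K,\;W_i^V}
   + \underbrace{h\,d_v\,d_{model}}_{W^O} \\
  &= h\left(3\,d_{model}\cdot\frac{d_{model}}{h}\right) + h\cdot\frac{d_{model}}{h}\cdot d_{model}
   \qquad (d_k = d_v = d_{model}/h) \\
  &= 4\,d_{model}^2
\end{aligned}
```

The factor $h$ cancels. For $d_{model} = 512$ this gives 1,048,576 weights whatever $h$ is; PyTorch's nn.MultiheadAttention adds $4\,d_{model} = 2048$ bias terms on top, for 1,050,624 parameters per attention block.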

It doesn’t have to be done this way, but that’s how nn.Transformer is implemented anyway.
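To make the "one big matrix vs. several small ones" picture concrete: with this design the per-head Q matrices are just slices of one $d_{model} \times d_{model}$ projection. The sketch below shows that with a plain nn.Linear and the reshape that typical implementations use to split heads; the sizes are example values, and this is an illustration rather than nn.Transformer's internal code.

```python
import torch
from torch import nn

d_model, nhead = 512, 8
head_dim = d_model // nhead  # 64

# One big query projection: its (d_model x d_model) weight is the same whatever nhead is
q_proj = nn.Linear(d_model, d_model)

x = torch.randn(2, 10, d_model)  # (batch, seq, d_model)
q = q_proj(x)                    # (batch, seq, d_model)

# Split the output features into nhead chunks of head_dim: (batch, nhead, seq, head_dim)
q_heads = q.view(2, 10, nhead, head_dim).transpose(1, 2)

# Head i only uses rows i*head_dim:(i+1)*head_dim of the weight (nn.Linear stores weight as out x in)
i = 3
w_i = q_proj.weight[i * head_dim:(i + 1) * head_dim]  # (head_dim, d_model)
b_i = q_proj.bias[i * head_dim:(i + 1) * head_dim]    # (head_dim,)
q_i = x @ w_i.T + b_i                                  # (batch, seq, head_dim)

print(torch.allclose(q_heads[:, i], q_i, atol=1e-5))   # True
```

So yes: nhead decides how the same set of weights is partitioned, not how many weights there are.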

For learning purposes, I’ve implemented my own Transformer architecture from scratch, and there I could easily give each head the full model_size instead of dividing it by the number of heads.
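If you want what the question asks for, a fixed per-head size so that the parameter count grows linearly with the number of heads, a from-scratch module can do it. The class below is a minimal sketch, not the answerer's implementation and not a PyTorch API: `FixedHeadDimAttention` and its `head_dim` argument are hypothetical names, biases are left out so the count is exactly linear in nhead, and there is no masking or dropout.

```python
import torch
from torch import nn


class FixedHeadDimAttention(nn.Module):
    """Hypothetical self-attention where every head keeps head_dim,
    instead of using d_model // nhead, so parameters grow with nhead."""

    def __init__(self, d_model: int, nhead: int, head_dim: int):
        super().__init__()
        self.nhead = nhead
        self.head_dim = head_dim
        inner = nhead * head_dim  # total width of all heads together
        # Q, K, V projections: d_model -> nhead * head_dim (no bias, to keep counting simple)
        self.q_proj = nn.Linear(d_model, inner, bias=False)
        self.k_proj = nn.Linear(d_model, inner, bias=False)
        self.v_proj = nn.Linear(d_model, inner, bias=False)
        # Output projection back to d_model
        self.out_proj = nn.Linear(inner, d_model, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, s, _ = x.shape

        def split(t: torch.Tensor) -> torch.Tensor:
            # (b, s, nhead * head_dim) -> (b, nhead, s, head_dim)
            return t.view(b, s, self.nhead, self.head_dim).transpose(1, 2)

        q, k, v = split(self.q_proj(x)), split(self.k_proj(x)), split(self.v_proj(x))
        scores = q @ k.transpose(-2, -1) / self.head_dim ** 0.5  # scaled dot-product
        attn = scores.softmax(dim=-1)
        out = (attn @ v).transpose(1, 2).reshape(b, s, self.nhead * self.head_dim)
        return self.out_proj(out)


def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


a1 = FixedHeadDimAttention(d_model=512, nhead=1, head_dim=512)
a8 = FixedHeadDimAttention(d_model=512, nhead=8, head_dim=512)
print(count_parameters(a1), count_parameters(a8))  # 1048576 8388608
```

With biases removed every weight matrix has one dimension equal to nhead * head_dim, which is why the count is exactly 8 times larger here.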

EDIT: Yes, with nn.Transformer you would need to multiply your intended model_size by the number of heads.
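With nn.Transformer that would be something like nn.Transformer(d_model=512 * 8, nhead=8), which gives each head a size of 512. A full model at d_model=4096 is large, so the sketch below checks the idea on a single nn.MultiheadAttention block (still about 270 MB of float32 weights). Note that the attention weights scale with d_model squared, so this route grows them roughly 64 times rather than 8, and count_parameters(t1) * 8 == count_parameters(t2) will not hold exactly.

```python
from torch import nn

base_d_model, nhead = 512, 8


def n_params(module: nn.Module) -> int:
    return sum(p.numel() for p in module.parameters())


# Single head of size 512 vs. 8 heads of size 512 (d_model scaled by nhead)
mha1 = nn.MultiheadAttention(embed_dim=base_d_model, num_heads=1)
mha8 = nn.MultiheadAttention(embed_dim=base_d_model * nhead, num_heads=nhead)

print(mha1.head_dim, mha8.head_dim)     # 512 512
print(n_params(mha1), n_params(mha8))   # 1050624 67125248
```

Both counts follow $4\,d_{model}^2 + 4\,d_{model}$, so widening d_model to get bigger heads is not the same as adding independent heads of a fixed size.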
