Pruning/Compressing heads in attention blocks

I have a conceptual question.

BERT-base uses a dimension of 768 for the query, key and value projections and has 12 heads (hidden dimension = 768, number of heads = 12). The printed BERT-base architecture shows the same:

(self): BertSelfAttention(
  (query): Linear(in_features=768, out_features=768, bias=True)
  (key): Linear(in_features=768, out_features=768, bias=True)
  (value): Linear(in_features=768, out_features=768, bias=True)
  (dropout): Dropout(p=0.1, inplace=False)
)

Now, my question is:

Can I consider the first 64 neurons of out_features as the first head, the next 64 neurons as the second head, and so on? (Sec. 3.2.2 of the original paper; Link)

Basically, I am wondering whether the Linear module representing the query matrix, which is 768x768, can be thought of as (768x64), (768x64)… 12 times. The same goes for the key and value modules.

If so, could someone provide some starter code? I am unable to wrap my head around it. Any help is appreciated (and I’ve put a sample in the contribution section).

P.S.: Here’s the issue I raised on GitHub (link)

I don’t think(!) you can make this consideration.

Yes, internally, all attention heads are represented by a single nn.Linear layer. However, their combined output is then still pushed through another nn.Linear layer (W^O in the cited paper, top of page 5). This is the output you “see”, and here you can no longer say that the first 64 neurons refer to the first head.
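To make the "single nn.Linear for all heads" part concrete, here is a minimal sketch in plain PyTorch. It assumes the common layout in which the 768 outputs of the projection are reshaped to (12, 64), so that head h reads outputs h*64 to (h+1)*64. The layer and the input are random stand-ins, not actual BERT weights, and the helper head_projection is made up for this illustration.

```python
import torch
from torch import nn

torch.manual_seed(0)
hidden, num_heads = 768, 12
head_dim = hidden // num_heads  # 64

query = nn.Linear(hidden, hidden)  # same shape as the query projection above
x = torch.randn(2, 5, hidden)      # (batch, seq_len, hidden), random stand-in


def head_projection(h: int) -> torch.Tensor:
    """Project x with only the weight rows and bias entries of head h."""
    # nn.Linear stores weight as (out_features, in_features), so a head
    # owns a block of 64 ROWS of the weight matrix.
    rows = slice(h * head_dim, (h + 1) * head_dim)
    return x @ query.weight[rows].T + query.bias[rows]  # (2, 5, 64)


with torch.no_grad():
    q_full = query(x)                                 # (2, 5, 768)
    q_heads = q_full.view(2, 5, num_heads, head_dim)  # (2, 5, 12, 64)
    max_diff = max(
        (head_projection(h) - q_heads[:, :, h]).abs().max().item()
        for h in range(num_heads)
    )

print(f"largest per-head difference: {max_diff:.2e}")
```

Under that assumption, the per-head slices and the reshaped full projection agree up to floating-point error. This only holds for the query/key/value projections themselves, i.e. before the heads are concatenated and mixed by W^O.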

I have a transformer implementation from scratch available here; see the class “MultiHeadAttention”. Please note that this implementation treats each head as its own nn.Linear layer. I also have an optimized version that uses a single nn.Linear layer for all the heads, but it is not yet on GitHub. In any case, both implementations do a

out = self.Wo(out)

at the very end, as this last layer is the same.
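A small sketch of why the output after W^O can no longer be split by head. Everything here is a random stand-in (Wo plays the role of W^O, heads the role of the per-head attention outputs); it is not taken from my implementation or from BERT.

```python
import torch
from torch import nn

torch.manual_seed(0)
hidden, num_heads = 768, 12
head_dim = hidden // num_heads  # 64

Wo = nn.Linear(hidden, hidden)                  # stand-in for W^O
heads = torch.randn(2, 5, num_heads, head_dim)  # stand-in per-head outputs
concat = heads.reshape(2, 5, hidden)            # heads concatenated

with torch.no_grad():
    out = Wo(concat)  # (2, 5, 768)

    # Head h enters W^O through its own 64 input COLUMNS of Wo.weight,
    # but those columns feed all 768 output neurons.
    contributions = [
        heads[:, :, h] @ Wo.weight[:, h * head_dim:(h + 1) * head_dim].T
        for h in range(num_heads)
    ]
    rebuilt = sum(contributions) + Wo.bias

    # How much the LAST head contributes to the FIRST 64 output neurons.
    leak = contributions[-1][..., :head_dim].abs().max().item()

print(torch.allclose(out, rebuilt, atol=1e-4))
print(f"last head's contribution to the first 64 outputs: {leak:.3f}")
```

The output is a sum of one term per head, and every term touches every output neuron, so the first 64 outputs are not "head 1". The same decomposition suggests how a head could be removed in principle: drop its 64 rows in the query, key and value layers together with its 64 input columns in W^O.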

Thanks for the response, Chris. I understand that, since the heads are combined at the very end, this is difficult to achieve.

There’s an active discussion going on in the Hugging Face GitHub repository about this, and it turns out HF has a few functions which support this functionality. I hope this helps.

Link: https://github.com/huggingface/transformers/issues/27044
