How to implement custom attention functions (especially attention masks)


I’m making my first foray into transformers, following this tutorial. I wanted to try experimenting with different attention functions, and found this previous discussion with recommendations on how to implement your own attention.

I’ve followed those recommendations, experimenting with just a single-head attention module, and have working code. However, I wasn’t sure how to handle masks, so I simply dropped them from my TransformerEncoderLayers. The result has mildly better validation loss than the vanilla TransformerEncoderLayer, presumably because my code is cheating: without the mask, each position can attend to later positions it should not see. I’d like help correcting my implementation so that it properly uses the attention masks (and probably the other kinds of masks too; they were always the part I glazed over).
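One way to check whether an attention function is "cheating" is to perturb the last token and see whether any earlier output changes. The sketch below does this with the built-in nn.MultiheadAttention, with and without a causal mask. The helper name `leaks_future` and the tensor sizes are made up for illustration, and the input layout (seq, batch, embed) is assumed to match the tutorial.

```python
import torch
import torch.nn as nn


def leaks_future(attn_fn, seq_len=5, batch=2, embed_dim=8):
    """Return True if changing the last token alters any earlier output."""
    torch.manual_seed(0)
    x = torch.randn(seq_len, batch, embed_dim)  # (seq, batch, embed)
    x_perturbed = x.clone()
    x_perturbed[-1] += 1.0  # change only the final position
    with torch.no_grad():
        out = attn_fn(x)
        out_perturbed = attn_fn(x_perturbed)
    # Under a causal mask, positions 0..seq_len-2 cannot see the last token,
    # so their outputs should be unchanged.
    return not torch.allclose(out[:-1], out_perturbed[:-1], atol=1e-6)


mha = nn.MultiheadAttention(embed_dim=8, num_heads=1)
mha.eval()

# Additive causal mask: -inf above the diagonal blocks attention to later positions.
causal = torch.triu(torch.full((5, 5), float("-inf")), diagonal=1)


def masked(x):
    return mha(x, x, x, attn_mask=causal, need_weights=False)[0]


def unmasked(x):
    return mha(x, x, x, need_weights=False)[0]


print("masked leaks:", leaks_future(masked))
print("unmasked leaks:", leaks_future(unmasked))
```

The same helper can be pointed at any custom attention callable that maps a (seq, batch, embed) tensor to a tensor of the same shape.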

I followed the linked tutorial exactly, with the following changes:

  1. I made a custom single head attention module like this:
class myAttentionModule(nn.MultiheadAttention):
    def __init__(self, embed_dim, num_heads):
        super(myAttentionModule, self).__init__(embed_dim, num_heads)
        self.kdim = embed_dim
        self.vdim = embed_dim
        self.embed_dim = embed_dim
        self.query = nn.Linear(self.embed_dim, self.kdim)
        self.key = nn.Linear(self.embed_dim, self.kdim)
        self.value = nn.Linear(self.embed_dim, self.vdim)
    def forward(self, x):
        q = self.query(x)
        k = self.key(x)
        v = self.value(x)
        y = nn.Softmax(dim=1)(torch.sum(q.unsqueeze(0)*k.unsqueeze(1), dim=2)).unsqueeze(2)
        y = torch.sum(y*v.unsqueeze(1), dim=1)
        return y

This is the “standard” dot product attention from Attention is All You Need, except that I didn’t divide by the square root of the embedding dimension (because I forgot to). I may have erred somehow in writing it, but I think it’s at least an approximation of correct.
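For comparison, a minimal sketch of single-head scaled dot-product self-attention that accepts both mask types might look like the following. The class name `MaskedSingleHeadAttention` is hypothetical. The mask conventions are assumed to follow nn.MultiheadAttention: `attn_mask` is (seq, seq), either boolean with True meaning "may not attend" or a float mask added to the scores, and `key_padding_mask` is a (batch, seq) boolean tensor with True marking padding. The input is assumed to be (seq, batch, embed).

```python
import math
from typing import Optional

import torch
import torch.nn as nn
from torch import Tensor


class MaskedSingleHeadAttention(nn.Module):
    """Hypothetical single-head self-attention that accepts both mask types."""

    def __init__(self, embed_dim: int):
        super().__init__()
        self.embed_dim = embed_dim
        self.query = nn.Linear(embed_dim, embed_dim)
        self.key = nn.Linear(embed_dim, embed_dim)
        self.value = nn.Linear(embed_dim, embed_dim)

    def forward(self, x: Tensor, attn_mask: Optional[Tensor] = None,
                key_padding_mask: Optional[Tensor] = None) -> Tensor:
        # x is (seq, batch, embed); work batch-first internally.
        q = self.query(x).transpose(0, 1)  # (batch, seq, embed)
        k = self.key(x).transpose(0, 1)
        v = self.value(x).transpose(0, 1)

        # scores[b, i, j] = q_i . k_j / sqrt(embed_dim)
        scores = q @ k.transpose(1, 2) / math.sqrt(self.embed_dim)

        if attn_mask is not None:
            if attn_mask.dtype == torch.bool:
                # Boolean convention: True means "query i may not attend to key j".
                scores = scores.masked_fill(attn_mask, float("-inf"))
            else:
                # Float convention: the mask is added to the scores.
                scores = scores + attn_mask

        if key_padding_mask is not None:
            # (batch, seq) -> (batch, 1, seq) so it broadcasts over all queries.
            scores = scores.masked_fill(key_padding_mask.unsqueeze(1), float("-inf"))

        weights = torch.softmax(scores, dim=-1)  # normalize over keys j
        return (weights @ v).transpose(0, 1)     # back to (seq, batch, embed)


attn = MaskedSingleHeadAttention(embed_dim=8)
x = torch.randn(5, 2, 8)
causal = torch.triu(torch.full((5, 5), float("-inf")), diagonal=1)
print(attn(x, attn_mask=causal).shape)
```

Masked positions get a score of -inf before the softmax, so their attention weight is exactly zero. A row in which every key is masked would produce NaN, which this sketch does not guard against.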

  2. I created my own version of TransformerEncoderLayer that uses this attention formula rather than the attention from nn.MultiheadAttention. I took the built-in implementation of TransformerEncoderLayer (I’m not allowed to give a third link as a new user; the source code for TransformerEncoderLayer is easy to find via Google, though) and made the following changes:
  • In the init:
# My code:
self.self_attn = myAttentionModule(d_model, nhead)

# Original code:
# self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first, **factory_kwargs)
  • I modified the self attention block, since myAttentionModule doesn’t take all the parameters MultiheadAttention does.
# My code
    def _sa_block(self, x: Tensor,
                  attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor], is_causal: bool = False) -> Tensor:
        # attn_mask, key_padding_mask, and is_causal are dropped here,
        # because myAttentionModule.forward only accepts x.
        x = self.self_attn(x)
        return self.dropout1(x)

# Original code:
    def _sa_block(self, x: Tensor,
                  attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor], is_causal: bool = False) -> Tensor:
        x = self.self_attn(x, x, x,
                           attn_mask=attn_mask,
                           key_padding_mask=key_padding_mask,
                           need_weights=False, is_causal=is_causal)[0]
        return self.dropout1(x)
  • I used this TransformerEncoderLayer instead of the standard, built-in one in PyTorch.
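The wiring described in the steps above could forward the masks rather than drop them, roughly as sketched here. This is a simplified post-norm layer written from scratch, not the PyTorch source: `TinyMaskedAttention` and `SketchEncoderLayer` are hypothetical names, the attention module is a compact stand-in that only handles an additive float `attn_mask` and a boolean `key_padding_mask`, and the layer is only shown being called directly, not plugged into nn.TransformerEncoder.

```python
import torch
import torch.nn as nn


class TinyMaskedAttention(nn.Module):
    """Stand-in for a custom attention module whose forward accepts masks."""

    def __init__(self, d_model: int):
        super().__init__()
        self.qkv = nn.Linear(d_model, 3 * d_model)

    def forward(self, x, attn_mask=None, key_padding_mask=None):
        # x is (seq, batch, d_model); q, k, v are each (batch, seq, d_model).
        q, k, v = self.qkv(x).transpose(0, 1).chunk(3, dim=-1)
        scores = q @ k.transpose(1, 2) / q.size(-1) ** 0.5
        if attn_mask is not None:
            scores = scores + attn_mask  # additive float mask, (seq, seq)
        if key_padding_mask is not None:
            # (batch, seq) bool, True marks padding keys.
            scores = scores.masked_fill(key_padding_mask.unsqueeze(1), float("-inf"))
        return (torch.softmax(scores, dim=-1) @ v).transpose(0, 1)


class SketchEncoderLayer(nn.Module):
    """Simplified post-norm encoder layer; an illustration, not the PyTorch source."""

    def __init__(self, d_model: int, dim_feedforward: int = 32, dropout: float = 0.1):
        super().__init__()
        self.self_attn = TinyMaskedAttention(d_model)
        self.ff = nn.Sequential(
            nn.Linear(d_model, dim_feedforward),
            nn.ReLU(),
            nn.Linear(dim_feedforward, d_model),
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

    def _sa_block(self, x, attn_mask, key_padding_mask):
        # Forward both masks to the custom attention instead of dropping them.
        x = self.self_attn(x, attn_mask=attn_mask, key_padding_mask=key_padding_mask)
        return self.dropout1(x)

    def forward(self, src, src_mask=None, src_key_padding_mask=None):
        x = self.norm1(src + self._sa_block(src, src_mask, src_key_padding_mask))
        return self.norm2(x + self.dropout2(self.ff(x)))


layer = SketchEncoderLayer(d_model=8)
src = torch.randn(5, 2, 8)
causal = torch.triu(torch.full((5, 5), float("-inf")), diagonal=1)
print(layer(src, src_mask=causal).shape)
```

The only change that matters for masking is inside `_sa_block`: the custom attention's forward has to accept the mask arguments, and the block has to pass them through.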

That’s all! Ideally, what I want is an easy way to carry the masking information from MultiheadAttention into myAttentionModule. Any advice would be greatly appreciated, though!