How to customize the attention function in torch.nn.MultiheadAttention

Hello! I looked through the docs and cannot find which attention function torch.nn.MultiheadAttention uses: MultiheadAttention — PyTorch 1.7.0 documentation
What's more, I don't know how to customize the attention function.

Hi jack_ma,

I would advise you to read the paper Attention Is All You Need carefully if you want to build your own attention module from scratch: https://arxiv.org/abs/1706.03762

Basically, multi-head attention takes as inputs queries, keys and values that are mapped from a sequence of embeddings (of words, for instance). The keys convey information about the nature of the embeddings, the values encode the content itself, while the queries indicate the information the embeddings need (https://arxiv.org/abs/1911.07757). I hope this provides some intuition about the query - key - value triplet and about how the attention module can learn which embeddings in the sequence to focus on.
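For some concrete intuition, here is a minimal sketch of the scaled dot-product attention described in the paper. The function name and tensor shapes are illustrative, not the internals of nn.MultiheadAttention:

import math
import torch

def scaled_dot_product_attention(query, key, value):
    # query, key, value: (batch, seq_len, d_k) — shapes are illustrative
    d_k = query.size(-1)
    # similarity of every query with every key, scaled by sqrt(d_k)
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
    # softmax over the keys turns the scores into attention weights
    weights = torch.softmax(scores, dim=-1)
    # each output is a weighted sum of the values
    return torch.matmul(weights, value), weights

Multi-head attention simply runs several of these in parallel on learned projections of the inputs and concatenates the results.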

Thank you. Does that mean there is no API provided in torch.nn.MultiheadAttention to do this? If I want to use my own attention function, do I need to implement it from scratch without the torch.nn.MultiheadAttention module?

Hi @jack_ma

Sorry for the late answer. You could write a custom class that uses nn.MultiheadAttention as its parent:

import torch
import torch.nn as nn

class myAttentionModule(nn.MultiheadAttention):
    def __init__(self, embed_dim, num_heads):
        super(myAttentionModule, self).__init__(embed_dim, num_heads)

    def forward(self, query, key, value):
        # your own forward function; as a placeholder, delegate to the parent
        return super(myAttentionModule, self).forward(query, key, value)


# nn.MultiheadAttention expects inputs of shape (seq_len, batch, embed_dim)
query = torch.rand((5, 1, 10))
key = torch.rand((5, 1, 10))
value = torch.rand((5, 1, 10))

my_attention_module = myAttentionModule(embed_dim=10, num_heads=2)
out, attn_weights = my_attention_module(query, key, value)
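If you would rather not subclass at all, you can also build a module from scratch and plug in any attention function you like. A minimal single-head sketch, with all names illustrative:

import math
import torch
import torch.nn as nn

class MyAttention(nn.Module):
    def __init__(self, embed_dim):
        super().__init__()
        # learned projections that map embeddings to queries, keys and values
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)

    def forward(self, query, key, value):
        q = self.q_proj(query)
        k = self.k_proj(key)
        v = self.v_proj(value)
        # swap these three lines for your own attention function
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(q.size(-1))
        weights = torch.softmax(scores, dim=-1)
        return torch.matmul(weights, v), weights

This version operates on (batch, seq_len, embed_dim) tensors and returns both the output and the attention weights, so you have full control over the attention computation.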