How to customize the attention function in torch.nn.MultiheadAttention

Hi @jack_ma

Sorry for the late answer. You could write a custom class that uses the nn.MultiheadAttention module as its parent:

import torch
import torch.nn as nn

class myAttentionModule(nn.MultiheadAttention):
	def __init__(self, embed_dim, num_heads):
		super(myAttentionModule, self).__init__(embed_dim, num_heads)

	def forward(self, query, key, value):
		# write your own attention logic here;
		# as a placeholder, this falls back to the parent implementation,
		# which returns a tuple (attn_output, attn_weights)
		return super(myAttentionModule, self).forward(query, key, value)


# inputs have shape (seq_len, batch, embed_dim) -- batch_first is False by default
query = torch.rand((1, 1, 10))
key = torch.rand((1, 1, 10))
value = torch.rand((1, 1, 10))


my_attention_module = myAttentionModule(embed_dim=10, num_heads=2)
out, attn_weights = my_attention_module(query, key, value)
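
To make the idea more concrete, here is one possible way such a custom forward could look: a minimal sketch (not the official implementation) that re-implements plain scaled dot-product attention while reusing the parent module's projection weights. It assumes the default constructor settings (bias=True, batch_first=False, kdim == vdim == embed_dim), and the class name ScaledDotProductAttention is just illustrative:

import torch
import torch.nn as nn
import torch.nn.functional as F

class ScaledDotProductAttention(nn.MultiheadAttention):
	def forward(self, query, key, value):
		# query/key/value: (seq_len, batch, embed_dim)
		L, N, E = query.shape
		S = key.shape[0]
		H = self.num_heads
		D = E // H

		# reuse the parent's packed input projection (assumes bias=True
		# and kdim == vdim == embed_dim, so in_proj_weight exists)
		w_q, w_k, w_v = self.in_proj_weight.chunk(3)
		b_q, b_k, b_v = self.in_proj_bias.chunk(3)
		q = F.linear(query, w_q, b_q)
		k = F.linear(key, w_k, b_k)
		v = F.linear(value, w_v, b_v)

		# split into heads: (batch * heads, seq_len, head_dim)
		q = q.contiguous().view(L, N * H, D).transpose(0, 1)
		k = k.contiguous().view(S, N * H, D).transpose(0, 1)
		v = v.contiguous().view(S, N * H, D).transpose(0, 1)

		# scaled dot-product attention -- this is the part you would customize
		scores = torch.bmm(q, k.transpose(1, 2)) / (D ** 0.5)
		weights = F.softmax(scores, dim=-1)
		attn = torch.bmm(weights, v)

		# merge heads and apply the parent's output projection
		attn = attn.transpose(0, 1).contiguous().view(L, N, E)
		out = self.out_proj(attn)

		# average attention weights over heads, like the parent class does by default
		weights = weights.view(N, H, L, S).mean(dim=1)
		return out, weights


query = torch.rand((1, 1, 10))
key = torch.rand((1, 1, 10))
value = torch.rand((1, 1, 10))

attention = ScaledDotProductAttention(embed_dim=10, num_heads=2)
out, attn_weights = attention(query, key, value)

Anything that does not fit this pattern (extra masks, different score functions, etc.) would go in the block marked as the part to customize.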