Hi @jack_ma
Sorry for the late reply. You could write a custom class that subclasses the `nn.MultiheadAttention` module:
```python
import torch
import torch.nn as nn


class myAttentionModule(nn.MultiheadAttention):
    def __init__(self, embed_dim, num_heads):
        super().__init__(embed_dim, num_heads)

    def forward(self, query, key, value):
        # your own forward function goes here; as a placeholder, delegate
        # to the parent, which returns (attn_output, attn_output_weights)
        return super().forward(query, key, value)


# unbatched inputs of shape (seq_len, embed_dim)
query = torch.rand((1, 10))
key = torch.rand((1, 10))
value = torch.rand((1, 10))

my_attention_module = myAttentionModule(embed_dim=10, num_heads=2)
out, attn_weights = my_attention_module(query, key, value)
```
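To make the stub concrete, here is one hypothetical way you could fill in the custom `forward` (the class name `CausalAttentionModule` and the causal-masking behavior are just an illustration, not something your use case requires): build a boolean causal mask and pass it to the parent's `forward` via its `attn_mask` argument.

```python
import torch
import torch.nn as nn


class CausalAttentionModule(nn.MultiheadAttention):
    # hypothetical example: a custom forward that applies a causal mask
    def forward(self, query, key, value):
        seq_len = query.size(0)  # unbatched input: (seq_len, embed_dim)
        # upper-triangular boolean mask: True entries are masked out,
        # so position i cannot attend to any position j > i
        causal_mask = torch.triu(
            torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1
        )
        return super().forward(query, key, value, attn_mask=causal_mask)


query = key = value = torch.rand((4, 10))  # seq_len=4, embed_dim=10
module = CausalAttentionModule(embed_dim=10, num_heads=2)
attn_output, attn_weights = module(query, key, value)
```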