I saw the following implementation of attention in PyTorch:
import torch
import torch.nn as nn
from torch_scatter import scatter_add, scatter_max


class WeightedAttention(nn.Module):
    """
    Weighted softmax attention layer.
    """

    def __init__(self, gate_nn, message_nn, num_heads=1):
        """
        Inputs
        ----------
        gate_nn: nn.Module mapping features to a scalar attention gate
        message_nn: nn.Module mapping features to the messages being pooled
        """
        super(WeightedAttention, self).__init__()
        self.gate_nn = gate_nn
        self.message_nn = message_nn
        # num_heads is currently unused
        self.pow = torch.nn.Parameter(torch.randn(1))

    def forward(self, fea, index, weights):
        """ forward pass """
        gate = self.gate_nn(fea)
        # subtract the per-segment max before exponentiating (softmax stability trick)
        gate = gate - scatter_max(gate, index, dim=0)[0][index]
        # scale each element's exponent by its weight raised to a learnable power
        gate = (weights ** self.pow) * gate.exp()
        # gate = weights * gate.exp()
        # gate = gate.exp()
        # normalise within each segment so the gates sum to one
        gate = gate / (scatter_add(gate, index, dim=0)[index] + 1e-13)
        fea = self.message_nn(fea)
        # weighted sum of messages per segment
        out = scatter_add(gate * fea, index, dim=0)
        return out

    def __repr__(self):
        return '{}(gate_nn={})'.format(self.__class__.__name__,
                                       self.gate_nn)
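For context, here is a minimal sketch of how I understand it would be called. The gate_nn and message_nn modules and the dummy data below are my own assumptions (the original doesn't show them), and I'm assuming the scatter ops come from the torch_scatter package:

import torch
import torch.nn as nn

# hypothetical gate/message networks: the gate maps each feature vector to a
# scalar score, the message network keeps the feature dimension
feat_dim = 16
gate_nn = nn.Linear(feat_dim, 1)
message_nn = nn.Linear(feat_dim, feat_dim)

attn = WeightedAttention(gate_nn, message_nn)

# dummy batch: 7 elements grouped into 3 segments by `index`,
# with a positive per-element weight (positive so weights ** pow stays real)
fea = torch.randn(7, feat_dim)
index = torch.tensor([0, 0, 0, 1, 1, 2, 2])
weights = torch.rand(7, 1) + 0.1  # shape (N, 1) so it broadcasts against the (N, 1) gate

out = attn(fea, index, weights)
print(out.shape)  # torch.Size([3, 16]), one pooled vector per segment

Note that I pass weights with shape (N, 1); a 1-D weights tensor would broadcast against the (N, 1) gate into an (N, N) matrix, which looks unintended.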
Is this a correct way of implementing attention in PyTorch?