I am trying to understand some graph neural network code that implements a weighted attention layer as follows:

```
class WeightedAttention(nn.Module):
    """Weighted softmax attention pooling layer.

    Pools a variable number of node/element features into one vector per
    segment (given by `index`), using a softmax over learned gate logits
    that is additionally scaled by per-element weights raised to a
    learnable power.
    """

    def __init__(self, gate_nn, message_nn, num_heads=1):
        """
        Parameters
        ----------
        gate_nn : nn.Module
            Maps a feature vector to a scalar attention logit.
        message_nn : nn.Module
            Transforms feature vectors into the messages that get pooled.
        num_heads : int, optional
            Accepted for interface compatibility; currently unused in this
            implementation.
        """
        super().__init__()
        self.gate_nn = gate_nn
        self.message_nn = message_nn
        # Learnable exponent applied to the element weights.
        # randn(1) creates a 1-element parameter; the original randn((1))
        # worked only because (1) is the int 1, not a tuple.
        self.pow = torch.nn.Parameter(torch.randn(1))

    def forward(self, fea, index, weights):
        """Compute the weighted-softmax-attention pooled features.

        Mathematically, for each segment S defined by `index`:
            a_i = w_i^p * exp(g_i - max_{j in S} g_j)
            out_S = sum_{i in S} (a_i / sum_{j in S} a_j) * message(x_i)
        where g_i = gate_nn(x_i) and p = self.pow.

        Parameters
        ----------
        fea : torch.Tensor
            Per-element feature vectors.
        index : torch.LongTensor
            Segment id of each element; pooling is done per segment.
        weights : torch.Tensor
            Per-element weights. NOTE(review): weights ** self.pow is NaN
            for negative weights — this assumes weights > 0; confirm with
            callers.
        """
        gate = self.gate_nn(fea)
        # Subtract the per-segment max before exp() for numerical
        # stability (standard log-sum-exp trick).
        gate = gate - scatter_max(gate, index, dim=0)[0][index]
        # Scale each exponentiated logit by weights**pow (learnable power).
        gate = (weights ** self.pow) * gate.exp()
        # Normalise within each segment; the epsilon guards against a
        # zero denominator.
        gate = gate / (scatter_add(gate, index, dim=0)[index] + 1e-13)
        fea = self.message_nn(fea)
        # Attention-weighted sum of messages, one row per segment.
        out = scatter_add(gate * fea, index, dim=0)
        return out

    def __repr__(self):
        return '{}(gate_nn={})'.format(self.__class__.__name__,
                                       self.gate_nn)
```

I am unable to find anything that explains how this weighted attention network works. Can someone explain what it is, and how the code implements it using `scatter_max` and `scatter_add`? Is there a mathematical formula that corresponds to this computation?