Recently I have been studying GNNs, and I came across the following piece of code, which I am unable to understand:
# Builds num_heads independent attention-pooling heads. Each head owns two
# small MLPs: a gate network mapping a 2*fea_len feature row to a single
# scalar score, and a message network mapping it back to fea_len features.
# The commented-out lines are alternative (simpler) message networks.
self.pooling = nn.ModuleList([WeightedAttention(
gate_nn=SimpleNetwork(2*fea_len, 1, hidden_ele),
message_nn=SimpleNetwork(2*fea_len, fea_len, hidden_msg),
# message_nn=nn.Linear(2*fea_len, fea_len),
# message_nn=nn.Identity(),
) for _ in range(num_heads)])
Is WeightedAttention taking a SimpleNetwork — which consists of linear layers — as input? Also, how does it work inside the WeightedAttention class? The WeightedAttention and SimpleNetwork classes are as follows:
class WeightedAttention(nn.Module):
    """Weighted softmax-attention pooling over segmented rows.

    Each row of the input is scored by ``gate_nn``; scores are turned into a
    per-segment softmax (segments given by ``index``), each attention weight
    is additionally scaled by ``weights ** self.pow`` (a learnable exponent),
    and the ``message_nn``-transformed rows are summed within each segment.
    """

    def __init__(self, gate_nn, message_nn, num_heads=1):
        """
        Inputs
        ----------
        gate_nn: nn.Module
            Maps a feature row to a scalar attention score.
        message_nn: nn.Module
            Transforms feature rows before pooling.
        num_heads: int
            Accepted for interface compatibility; not used by this class.
        """
        super().__init__()
        self.gate_nn = gate_nn
        self.message_nn = message_nn
        # Learnable exponent applied to the externally supplied row weights.
        self.pow = torch.nn.Parameter(torch.randn((1)))

    def forward(self, fea, index, weights):
        """Pool rows of ``fea`` into one output row per segment of ``index``."""
        scores = self.gate_nn(fea)
        # Subtract each segment's max score (broadcast back via [index])
        # for numerical stability before exponentiating.
        scores = scores - scatter_max(scores, index, dim=0)[0][index]
        attn = (weights ** self.pow) * scores.exp()
        # Per-segment normalisation (softmax); tiny eps avoids divide-by-zero.
        attn = attn / (scatter_add(attn, index, dim=0)[index] + 1e-13)
        messages = self.message_nn(fea)
        return scatter_add(attn * messages, index, dim=0)

    def __repr__(self):
        return '{}(gate_nn={})'.format(self.__class__.__name__,
                                       self.gate_nn)
class SimpleNetwork(nn.Module):
    """Simple feed-forward network: Linear + LeakyReLU hidden layers,
    followed by a final linear projection with no activation."""

    def __init__(self, input_dim, output_dim, hidden_layer_dims):
        """
        Inputs
        ----------
        input_dim: int
        output_dim: int
        hidden_layer_dims: list(int)
        """
        super().__init__()
        widths = [input_dim, *hidden_layer_dims]
        # One Linear per consecutive pair of widths, each with its own
        # LeakyReLU (stateless, but kept as modules for repr/parity).
        self.fcs = nn.ModuleList(
            nn.Linear(w_in, w_out) for w_in, w_out in zip(widths, widths[1:])
        )
        self.acts = nn.ModuleList(nn.LeakyReLU() for _ in widths[1:])
        self.fc_out = nn.Linear(widths[-1], output_dim)

    def forward(self, fea):
        """Run the hidden stack, then the unactivated output layer."""
        for layer, activation in zip(self.fcs, self.acts):
            fea = activation(layer(fea))
        return self.fc_out(fea)

    def __repr__(self):
        return self.__class__.__name__