# How does passing a neural network as an input work?

Recently I have been studying GNNs and came across the following piece of code, which I am unable to understand:

``````
self.pooling = nn.ModuleList([WeightedAttention(
    gate_nn=SimpleNetwork(2*fea_len, 1, hidden_ele),
    message_nn=SimpleNetwork(2*fea_len, fea_len, hidden_msg),
    # message_nn=nn.Linear(2*fea_len, fea_len),
    # message_nn=nn.Identity(),
)])
``````

Is `WeightedAttention` taking a `SimpleNetwork` (which consists of linear layers) as an input? Also, how does it work inside the `WeightedAttention` class? The `WeightedAttention` and `SimpleNetwork` classes are as follows:

``````
import torch
import torch.nn as nn
# scatter_max and scatter_add are not built into PyTorch;
# they presumably come from the torch_scatter package
from torch_scatter import scatter_max, scatter_add


class WeightedAttention(nn.Module):
    """
    Weighted softmax attention layer
    """
    def __init__(self, gate_nn, message_nn):
        """
        Inputs
        ----------
        gate_nn: Variable(nn.Module)
        message_nn: Variable(nn.Module)
        """
        super(WeightedAttention, self).__init__()
        self.gate_nn = gate_nn
        self.message_nn = message_nn
        self.pow = torch.nn.Parameter(torch.randn((1)))

    def forward(self, fea, index, weights):
        """ forward pass """

        gate = self.gate_nn(fea)

        gate = gate - scatter_max(gate, index, dim=0)[0][index]
        gate = (weights ** self.pow) * gate.exp()
        # gate = weights * gate.exp()
        # gate = gate.exp()
        gate = gate / (scatter_add(gate, index, dim=0)[index] + 1e-13)

        fea = self.message_nn(fea)
        out = scatter_add(gate * fea, index, dim=0)

        return out

    def __repr__(self):
        return '{}(gate_nn={})'.format(self.__class__.__name__,
                                       self.gate_nn)


class SimpleNetwork(nn.Module):
    """
    Simple Feed Forward Neural Network
    """
    def __init__(self, input_dim, output_dim, hidden_layer_dims):
        """
        Inputs
        ----------
        input_dim: int
        output_dim: int
        hidden_layer_dims: list(int)
        """
        super(SimpleNetwork, self).__init__()

        dims = [input_dim] + hidden_layer_dims

        self.fcs = nn.ModuleList([nn.Linear(dims[i], dims[i+1])
                                  for i in range(len(dims)-1)])
        self.acts = nn.ModuleList([nn.LeakyReLU() for _ in range(len(dims)-1)])

        self.fc_out = nn.Linear(dims[-1], output_dim)

    def forward(self, fea):
        for fc, act in zip(self.fcs, self.acts):
            fea = act(fc(fea))

        return self.fc_out(fea)

    def __repr__(self):
        return '{}'.format(self.__class__.__name__)
``````

`WeightedAttention` uses a `SimpleNetwork` for `gate_nn` and another one for `message_nn`.
The usage of these modules is shown in `WeightedAttention.forward`.
I don’t know how `scatter_max` and `scatter_add` are implemented, but the `SimpleNetwork` modules are used like this:

``````
gate = self.gate_nn(fea)
...
fea = self.message_nn(fea)
...
``````
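
For context, this pattern works because every `nn.Module` is callable: passing one module into another module's constructor just stores a reference to it, and calling that reference inside `forward` runs its own `forward` method. Here is a minimal, self-contained sketch of the same pattern (the `Wrapper` name and the dimensions are made up for illustration):

``````
import torch
import torch.nn as nn

class Wrapper(nn.Module):
    """Toy module that receives another module as a constructor argument."""
    def __init__(self, inner_nn):
        super().__init__()
        # Assigning the module to an attribute registers its parameters
        # with this module, so both are trained together.
        self.inner_nn = inner_nn

    def forward(self, x):
        # Calling the stored module runs its forward pass.
        return self.inner_nn(x)

wrapper = Wrapper(inner_nn=nn.Linear(4, 2))
out = wrapper(torch.randn(3, 4))
print(out.shape)  # torch.Size([3, 2])
``````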

`SimpleNetwork` uses an `nn.ModuleList` holding several linear layers, another `nn.ModuleList` storing the activation functions, as well as a final linear layer.
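
As a quick shape check, here is a hypothetical usage of the `SimpleNetwork` class defined above (the dimensions are made up):

``````
import torch

net = SimpleNetwork(input_dim=8, output_dim=1, hidden_layer_dims=[16, 16])
fea = torch.randn(5, 8)   # batch of 5 feature vectors
out = net(fea)            # two hidden LeakyReLU layers, then fc_out
print(out.shape)          # torch.Size([5, 1])
``````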

But how can you apply `scatter_max` or `scatter_add` to a neural network? What will it affect: the weights of the neural network or something else?

`scatter_add` can be understood as an advanced indexing method.
Have a look at the example from the `Tensor.scatter_` docs, which illustrates the same indexing logic:

``````
>>> x = torch.rand(2, 5)
>>> x
tensor([[ 0.3992,  0.2908,  0.9044,  0.4850,  0.6004],
        [ 0.5735,  0.9006,  0.6797,  0.4152,  0.1732]])
>>> torch.zeros(3, 5).scatter_(0, torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]), x)
tensor([[ 0.3992,  0.9006,  0.6797,  0.4850,  0.6004],
        [ 0.0000,  0.2908,  0.0000,  0.4152,  0.0000],
        [ 0.5735,  0.0000,  0.9044,  0.0000,  0.1732]])

>>> z = torch.zeros(2, 4).scatter_(1, torch.tensor([[2], [3]]), 1.23)
>>> z
tensor([[ 0.0000,  0.0000,  1.2300,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  1.2300]])
``````

`scatter_max` isn’t a built-in PyTorch method, if I’m not mistaken (it appears to come from the `torch_scatter` extension package), but I assume it uses the `max` operation instead of adding the values.
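
To make the semantics concrete, both operations can be reproduced in plain PyTorch for the 1-D case. This is just an illustrative sketch of what the `torch_scatter`-style calls compute, not their actual implementation:

``````
import torch

src = torch.tensor([1.0, 2.0, 3.0, 4.0])
index = torch.tensor([0, 0, 1, 1])  # entries 0-1 belong to segment 0, 2-3 to segment 1

# scatter_add: sum all entries that share an index
sums = torch.zeros(2).index_add_(0, index, src)
print(sums)  # tensor([3., 7.])

# scatter_max: maximum entry per index (plain loop for clarity)
maxes = torch.full((2,), float('-inf'))
for i, v in zip(index.tolist(), src.tolist()):
    maxes[i] = max(maxes[i].item(), v)
print(maxes)  # tensor([2., 4.])
``````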

I would recommend printing all tensors that are used in these operations and checking the outputs of the operations in question.

These operations won’t affect the weights or other parameters directly; they are most likely used for some form of indexed aggregation.
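
Putting it together: in `WeightedAttention.forward`, `index` groups the rows of `fea` into segments (e.g., nodes into graphs), and the scatter calls implement a per-segment softmax over the gate values. A 1-D sketch of that normalization, again using only plain PyTorch and made-up numbers:

``````
import torch

gate = torch.tensor([0.5, 1.5, 2.0, -1.0])  # one gate value per node
index = torch.tensor([0, 0, 1, 1])          # segment id of each node

# Subtract the per-segment max for numerical stability, then exponentiate.
seg_max = torch.full((2,), float('-inf'))
for i, v in zip(index.tolist(), gate.tolist()):
    seg_max[i] = max(seg_max[i].item(), v)
gate = (gate - seg_max[index]).exp()

# Normalize so the attention weights within each segment sum to 1.
norm = torch.zeros(2).index_add_(0, index, gate)
gate = gate / (norm[index] + 1e-13)
print(gate)  # tensor([0.2689, 0.7311, 0.9526, 0.0474])
``````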