How does passing a neural network as input work?

Recently I have been studying GNNs, and I came across the following piece of code which I am unable to understand:

self.pooling = nn.ModuleList([WeightedAttention(
            gate_nn=SimpleNetwork(2*fea_len, 1, hidden_ele),
            message_nn=SimpleNetwork(2*fea_len, fea_len, hidden_msg),
            # message_nn=nn.Linear(2*fea_len, fea_len),
            # message_nn=nn.Identity(),
            ) for _ in range(num_heads)])

Is WeightedAttention taking a SimpleNetwork as input, which comprises linear layers? Also, how does it work inside the WeightedAttention class? The WeightedAttention and SimpleNetwork classes are as follows:

import torch
import torch.nn as nn
# scatter_add and scatter_max are not part of core PyTorch; they most
# likely come from the torch_scatter extension package
from torch_scatter import scatter_add, scatter_max


class WeightedAttention(nn.Module):
    """
    Weighted softmax attention layer
    """
    def __init__(self, gate_nn, message_nn, num_heads=1):
        """
        Inputs
        ----------
        gate_nn: nn.Module mapping features to a scalar attention score
        message_nn: nn.Module transforming the features before pooling
        """
        
        super(WeightedAttention, self).__init__()
        self.gate_nn = gate_nn
        self.message_nn = message_nn
        self.pow = torch.nn.Parameter(torch.randn((1)))

    def forward(self, fea, index, weights):
        """ forward pass """

        # per-row attention scores
        gate = self.gate_nn(fea)

        # numerically stable softmax over each group defined by index:
        # subtract the per-group maximum before exponentiating
        gate = gate - scatter_max(gate, index, dim=0)[0][index]
        gate = (weights ** self.pow) * gate.exp()
        # gate = weights * gate.exp()
        # gate = gate.exp()
        gate = gate / (scatter_add(gate, index, dim=0)[index] + 1e-13)

        # transform the features, then sum them per group,
        # weighted by the attention coefficients
        fea = self.message_nn(fea)
        out = scatter_add(gate * fea, index, dim=0)

        return out

    def __repr__(self):
        return '{}(gate_nn={})'.format(self.__class__.__name__,
                                       self.gate_nn)


class SimpleNetwork(nn.Module):
    """
    Simple Feed Forward Neural Network
    """
    def __init__(self, input_dim, output_dim, hidden_layer_dims):
        """
        Inputs
        ----------
        input_dim: int
        output_dim: int
        hidden_layer_dims: list(int)

        """
        super(SimpleNetwork, self).__init__()

        dims = [input_dim]+hidden_layer_dims
        
        self.fcs = nn.ModuleList([nn.Linear(dims[i], dims[i+1])
                                  for i in range(len(dims)-1)])
        self.acts = nn.ModuleList([nn.LeakyReLU() for _ in range(len(dims)-1)])

        self.fc_out = nn.Linear(dims[-1], output_dim)

    def forward(self, fea):
        for fc, act in zip(self.fcs, self.acts):
            fea = act(fc(fea))

        return self.fc_out(fea)

    def __repr__(self):
        return '{}'.format(self.__class__.__name__)

WeightedAttention uses one SimpleNetwork for gate_nn and another one for message_nn.
The usage of these modules is shown in WeightedAttention.forward.
I don’t know how scatter_max and scatter_add are implemented, but the usage of the SimpleNetwork modules is:

gate = self.gate_nn(fea)
...
fea = self.message_nn(fea)
...
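
Passing a network as an argument works because an nn.Module is an ordinary Python object: assigning it to an attribute in __init__ registers it as a submodule, so its parameters are trained together with the rest of the model, and calling it in forward is just a function call. A minimal sketch (the Wrapper class is hypothetical, not from the code above):

import torch
import torch.nn as nn

class Wrapper(nn.Module):
    def __init__(self, inner_nn):
        super().__init__()
        self.inner_nn = inner_nn  # stored attribute is registered as a submodule

    def forward(self, x):
        return self.inner_nn(x)  # plain call; gradients flow through it

wrapper = Wrapper(nn.Linear(4, 2))
out = wrapper(torch.randn(3, 4))
print(out.shape)                        # torch.Size([3, 2])
print(len(list(wrapper.parameters())))  # 2: the inner Linear's weight and bias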

SimpleNetwork uses an nn.ModuleList with several linear layers, another nn.ModuleList holding the activation functions, and a final linear output layer.
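
For example, with hypothetical dimensions (assuming the SimpleNetwork class above is defined), SimpleNetwork(10, 1, [32, 16]) builds Linear(10, 32) -> LeakyReLU -> Linear(32, 16) -> LeakyReLU -> Linear(16, 1):

import torch

net = SimpleNetwork(input_dim=10, output_dim=1, hidden_layer_dims=[32, 16])
out = net(torch.randn(5, 10))  # batch of 5 samples with 10 features each
print(out.shape)               # torch.Size([5, 1])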

But how can you apply scatter_max or scatter_add to a neural network? What will they affect, the weights of the neural network or something else?

scatter_add can be understood as an advanced indexing method.
Have a look at the docs; the example there uses the closely related in-place scatter_:

>>> x = torch.rand(2, 5)
>>> x
tensor([[ 0.3992,  0.2908,  0.9044,  0.4850,  0.6004],
        [ 0.5735,  0.9006,  0.6797,  0.4152,  0.1732]])
>>> torch.zeros(3, 5).scatter_(0, torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]), x)
tensor([[ 0.3992,  0.9006,  0.6797,  0.4850,  0.6004],
        [ 0.0000,  0.2908,  0.0000,  0.4152,  0.0000],
        [ 0.5735,  0.0000,  0.9044,  0.0000,  0.1732]])

>>> z = torch.zeros(2, 4).scatter_(1, torch.tensor([[2], [3]]), 1.23)
>>> z
tensor([[ 0.0000,  0.0000,  1.2300,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  1.2300]])

scatter_max isn’t a built-in PyTorch method, if I’m not mistaken; it most likely comes from the torch_scatter extension package, and I assume it takes the max of the scattered values instead of adding them.
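
A minimal sketch, assuming scatter_add from the torch_scatter package (which this code most likely imports): rows of src that share the same index value are summed into a single output row, which is exactly the grouped aggregation used in WeightedAttention.forward.

import torch
from torch_scatter import scatter_add  # assumes the torch_scatter package

src = torch.tensor([[1., 2.],
                    [3., 4.],
                    [5., 6.]])
index = torch.tensor([0, 0, 1])       # rows 0 and 1 belong to group 0
out = scatter_add(src, index, dim=0)  # sums the rows of each group
print(out)
# tensor([[4., 6.],
#         [5., 6.]])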

I would recommend printing all tensors which are used in these operations and checking the output of the operations in question.

These operations won’t affect the weights or other parameters directly, but are most likely used for some indexing.
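
For example, running the WeightedAttention layer from the question on hypothetical data shows that index only controls how rows are grouped and pooled; the parameters of the submodules are untouched by the scatter calls:

import torch

# hypothetical sizes, using the classes posted in the question
attn = WeightedAttention(
    gate_nn=SimpleNetwork(4, 1, [8]),
    message_nn=SimpleNetwork(4, 4, [8]),
)
fea = torch.randn(6, 4)                   # 6 nodes with 4 features each
index = torch.tensor([0, 0, 0, 1, 1, 1])  # nodes 0-2 form graph 0, nodes 3-5 graph 1
weights = torch.rand(6, 1) + 0.5          # positive per-node weights
out = attn(fea, index, weights)
print(out.shape)  # torch.Size([2, 4]): one pooled feature vector per graph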