Scatter_std changes weights of encoder network

I’m trying to write a python code of a GNN. In my node update I give the edge data by using scatter_mean, scatter_add and scatter_std. As long as I only use these first two everything works fine, but when I also use scatter_add, after the first batch the weights of the edge_encoder become nan and the model does not work any more. I have no idea how the edge model influences the encoder, but somehow this happens. Has anyone seen this before and know how to solve it?

Below is the code of my edge and node models together with the network:

class EdgeModel_1(torch.nn.Module):
def init(self):
super(EdgeModel_1, self).init()
self.edge_mlp = nn.ModuleList()

    for l in range(len(NN_edge_layers_1)-1):
        # Add linear layers to the neural networks
        self.edge_mlp.append(nn.Linear(NN_edge_layers_1[l], NN_edge_layers_1[l+1]))
    
def forward(self, src, dest, edge_attr, u= None, batch = None): # information about nodes is in src and dest
     # feed data to the the nn
    out = torch.cat([src, dest, edge_attr], 1)
    for layer in self.edge_mlp[:-1]:
        out =  F.elu(layer(out))
    out = self.edge_mlp[-1](out)
    return out

class EdgeModel_2(torch.nn.Module):
def init(self):
super(EdgeModel_2, self).init()
self.edge_mlp = nn.ModuleList()

    for l in range(len(NN_edge_layers_2)-1):
        # Add linear layers to the neural networks
        self.edge_mlp.append(nn.Linear(NN_edge_layers_2[l], NN_edge_layers_2[l+1]))
    
def forward(self, src, dest, edge_attr, u= None, batch = None): # information about nodes is in src and dest
     # feed data to the the nn
    out = torch.cat([src, dest, edge_attr], 1)

    for layer in self.edge_mlp[:-1]:
        out =  F.elu(layer(out))

    out = self.edge_mlp[-1](out)
    
    return out

class NodeModel_1(torch.nn.Module):
def init(self):
super(NodeModel_1, self).init()
self.node_mlp = nn.ModuleList()

    for l in range(len(NN_node_layers_1)-1):
        # Add linear layers to the neural networks
        self.node_mlp.append(nn.Linear(NN_node_layers_1[l], NN_node_layers_1[l+1]))
    
def forward(self, node_attr, edge_index, edge_attr, u = None, batch = None):
    row, col = edge_index
    
    out = torch.cat([scatter_mean(edge_attr, col, dim=0, dim_size=node_attr.size(0)),scatter_add(edge_attr, row, dim=0, dim_size=node_attr.size(0)),scatter_std(edge_attr, row, dim=0, dim_size=node_attr.size(0), unbiased = False)], dim = 1)
    out = torch.cat([out, node_attr],dim = 1)

    for layer in self.node_mlp[:-1]:
        out =  F.elu(layer(out))
    out = self.node_mlp[-1](out)
    return out

class NodeModel_2(torch.nn.Module):
def init(self):
super(NodeModel_2, self).init()
self.node_mlp = nn.ModuleList()

    for l in range(len(NN_node_layers_2)-1):
        # Add linear layers to the neural networks
        self.node_mlp.append(nn.Linear(NN_node_layers_2[l], NN_node_layers_2[l+1]))
    
def forward(self, node_attr, edge_index, edge_attr, u = None, batch = None):
    row, col = edge_index
    out = torch.cat([scatter_mean(edge_attr, col, dim=0, dim_size=node_attr.size(0)),scatter_add(edge_attr, row, dim=0, dim_size=node_attr.size(0)),scatter_mean(edge_attr, row, dim=0, dim_size=node_attr.size(0))], dim = 1)
    out = torch.cat([out, node_attr],dim = 1)
    for layer in self.node_mlp[:-1]:
        out =  F.elu(layer(out))
    out = self.node_mlp[-1](out)
    return out

class Net(nn.Module):
def init(self):
super(Net, self).init()
self.network_1= MetaLayer(EdgeModel_1(), NodeModel_1())
self.network_2= MetaLayer(EdgeModel_2(), NodeModel_2())
self.network_3= MetaLayer(EdgeModel_2(), NodeModel_2())
#self.network_4= MetaLayer(EdgeModel_2(), NodeModel_2())
#Initialize encoders and decoders
self.encoder_node = nn.ModuleList()
for l in range(len(encoder_node_layers)-1):
# Add linear layers to the neural networks
self.encoder_node.append(nn.Linear(encoder_node_layers[l], encoder_node_layers[l+1]))

    self.encoder_edge = nn.ModuleList()
    for l in range(len(encoder_edge_layers)-1):
        # Add linear layers to the neural networks
        self.encoder_edge.append(nn.Linear(encoder_edge_layers[l], encoder_edge_layers[l+1]))
    
    self.decoder_node = nn.ModuleList()
    for l in range(len(decoder_node_layers)-1):
        # Add linear layers to the neural networks
        self.decoder_node.append(nn.Linear(decoder_node_layers[l], decoder_node_layers[l+1]))

def forward(self, data):

    node_attr, edge_index, edge_attr = data.x, data.edge_index, data.edge_attr
    #First put data through the encoders
    
    for layer in self.encoder_node[:-1]:
        node_attr =  F.elu(layer(node_attr))
    node_attr = self.encoder_node[-1](node_attr)
    for layer in self.encoder_edge[:-1]:
        edge_attr =  F.elu(layer(edge_attr))
    edge_attr = self.encoder_edge[-1](edge_attr)
    
    # Save initial input to feed to the  neural network 
    node_attr_zero, edge_attr_zero = 1*node_attr, 1*edge_attr

    # Update network for the first time
    node_attr , edge_attr, u = self.network_1(node_attr, edge_index, edge_attr)
    
    edge_attr = torch.cat([edge_attr,edge_attr_zero],dim =1)
    node_attr = torch.cat([node_attr,node_attr_zero],dim =1)

    node_attr , edge_attr, u = self.network_2(node_attr, edge_index, edge_attr)
    
    edge_attr = torch.cat([edge_attr,edge_attr_zero],dim =1)
    node_attr = torch.cat([node_attr,node_attr_zero],dim =1)
    node_attr , edge_attr, u = self.network_3(node_attr, edge_index, edge_attr)

    #edge_attr = torch.cat([edge_attr,edge_attr_zero],dim =1)
    #node_attr = torch.cat([node_attr,node_attr_zero],dim =1)
    #node_attr , edge_attr, u = self.network_4(node_attr, edge_index, edge_attr)
    
    # Decode node information
    for layer in self.decoder_node[:-1]:
        node_attr =  F.elu(layer(node_attr))
    node_attr = self.decoder_node[-1](node_attr)

    return node_attr

I don’t know where these scatter_ methods are defined, but since you are getting invalid values using scatter_std I assume that this particular method might return NaN (outputs or gradients) for edge cases, e.g. when the values are constant.