I’m trying to write Python code for a GNN. In my node update I aggregate the edge data using scatter_mean, scatter_add and scatter_std. As long as I only use the first two, everything works fine, but when I also use scatter_std, after the first batch the weights of the edge_encoder become NaN and the model no longer works. I have no idea how the node model can influence the encoder, but somehow this happens. Has anyone seen this before and knows how to solve it?
Below is the code of my edge and node models together with the network:
class EdgeModel_1(torch.nn.Module):
    """First edge-update MLP: maps [src ‖ dest ‖ edge_attr] to new edge features.

    Layer sizes come from the module-level list ``NN_edge_layers_1``
    (assumed defined elsewhere in this file — confirm against caller).
    """

    def __init__(self):
        # BUG FIX: the original defined ``init`` instead of ``__init__``, so
        # the constructor was never invoked and the module had no layers.
        super(EdgeModel_1, self).__init__()
        self.edge_mlp = nn.ModuleList()
        for l in range(len(NN_edge_layers_1) - 1):
            # One linear layer per consecutive pair of layer sizes.
            self.edge_mlp.append(nn.Linear(NN_edge_layers_1[l], NN_edge_layers_1[l + 1]))

    def forward(self, src, dest, edge_attr, u=None, batch=None):
        """Update edge features; node information arrives via ``src``/``dest``.

        ``u`` (global attributes) and ``batch`` are accepted only to satisfy
        the MetaLayer edge-model interface and are unused.
        """
        out = torch.cat([src, dest, edge_attr], 1)
        # ELU on every hidden layer; the final layer stays linear.
        for layer in self.edge_mlp[:-1]:
            out = F.elu(layer(out))
        out = self.edge_mlp[-1](out)
        return out
class EdgeModel_2(torch.nn.Module):
    """Edge-update MLP used by the second and later MetaLayer passes.

    Layer sizes come from the module-level list ``NN_edge_layers_2``
    (assumed defined elsewhere in this file — confirm against caller).
    """

    def __init__(self):
        # BUG FIX: the original defined ``init`` instead of ``__init__``, so
        # the constructor was never invoked and the module had no layers.
        super(EdgeModel_2, self).__init__()
        self.edge_mlp = nn.ModuleList()
        for l in range(len(NN_edge_layers_2) - 1):
            # One linear layer per consecutive pair of layer sizes.
            self.edge_mlp.append(nn.Linear(NN_edge_layers_2[l], NN_edge_layers_2[l + 1]))

    def forward(self, src, dest, edge_attr, u=None, batch=None):
        """Update edge features; node information arrives via ``src``/``dest``.

        ``u`` (global attributes) and ``batch`` are accepted only to satisfy
        the MetaLayer edge-model interface and are unused.
        """
        out = torch.cat([src, dest, edge_attr], 1)
        # ELU on every hidden layer; the final layer stays linear.
        for layer in self.edge_mlp[:-1]:
            out = F.elu(layer(out))
        out = self.edge_mlp[-1](out)
        return out
class NodeModel_1(torch.nn.Module):
    """First node-update MLP: concatenates per-node edge aggregates
    (mean over incoming, sum and std over outgoing) with the node features
    and maps them through an MLP sized by ``NN_node_layers_1``.
    """

    # Small constant added under the square root so the std gradient stays finite.
    _STD_EPS = 1e-8

    def __init__(self):
        # BUG FIX: the original defined ``init`` instead of ``__init__``, so
        # the constructor was never invoked and the module had no layers.
        super(NodeModel_1, self).__init__()
        self.node_mlp = nn.ModuleList()
        for l in range(len(NN_node_layers_1) - 1):
            # One linear layer per consecutive pair of layer sizes.
            self.node_mlp.append(nn.Linear(NN_node_layers_1[l], NN_node_layers_1[l + 1]))

    def forward(self, node_attr, edge_index, edge_attr, u=None, batch=None):
        """Update node features from aggregated edge features.

        ``u`` and ``batch`` are accepted only for the MetaLayer node-model
        interface and are unused.
        """
        row, col = edge_index
        n = node_attr.size(0)
        agg_mean_in = scatter_mean(edge_attr, col, dim=0, dim_size=n)
        agg_sum_out = scatter_add(edge_attr, row, dim=0, dim_size=n)
        # BUG FIX: scatter_std's gradient blows up (d sqrt(v)/dv -> inf as
        # v -> 0) whenever all of a node's edges carry the same value — e.g. a
        # node with a single outgoing edge has variance exactly 0. The inf
        # gradient backpropagates through edge_attr into the edge encoder and
        # turns its weights to NaN after the first optimizer step. Compute the
        # biased std (unbiased=False) by hand and add eps under the sqrt.
        mean_out = scatter_mean(edge_attr, row, dim=0, dim_size=n)
        mean_sq_out = scatter_mean(edge_attr * edge_attr, row, dim=0, dim_size=n)
        var_out = (mean_sq_out - mean_out * mean_out).clamp(min=0.0)
        agg_std_out = torch.sqrt(var_out + self._STD_EPS)
        out = torch.cat([agg_mean_in, agg_sum_out, agg_std_out], dim=1)
        out = torch.cat([out, node_attr], dim=1)
        # ELU on every hidden layer; the final layer stays linear.
        for layer in self.node_mlp[:-1]:
            out = F.elu(layer(out))
        out = self.node_mlp[-1](out)
        return out
class NodeModel_2(torch.nn.Module):
    """Node-update MLP used by the second and later MetaLayer passes:
    concatenates per-node edge aggregates (mean over incoming, sum and mean
    over outgoing) with the node features and maps them through an MLP sized
    by ``NN_node_layers_2``.
    """

    def __init__(self):
        # BUG FIX: the original defined ``init`` instead of ``__init__``, so
        # the constructor was never invoked and the module had no layers.
        super(NodeModel_2, self).__init__()
        self.node_mlp = nn.ModuleList()
        for l in range(len(NN_node_layers_2) - 1):
            # One linear layer per consecutive pair of layer sizes.
            self.node_mlp.append(nn.Linear(NN_node_layers_2[l], NN_node_layers_2[l + 1]))

    def forward(self, node_attr, edge_index, edge_attr, u=None, batch=None):
        """Update node features from aggregated edge features.

        ``u`` and ``batch`` are accepted only for the MetaLayer node-model
        interface and are unused.
        """
        row, col = edge_index
        # Hoist the shared dim_size so all three aggregates agree on it.
        n = node_attr.size(0)
        out = torch.cat(
            [
                scatter_mean(edge_attr, col, dim=0, dim_size=n),
                scatter_add(edge_attr, row, dim=0, dim_size=n),
                scatter_mean(edge_attr, row, dim=0, dim_size=n),
            ],
            dim=1,
        )
        out = torch.cat([out, node_attr], dim=1)
        # ELU on every hidden layer; the final layer stays linear.
        for layer in self.node_mlp[:-1]:
            out = F.elu(layer(out))
        out = self.node_mlp[-1](out)
        return out
class Net(nn.Module):
    """Encode-process-decode GNN.

    Pipeline: node/edge encoders -> three MetaLayer message-passing blocks
    (the encoded features are re-concatenated as a skip connection before the
    second and third blocks) -> node decoder. Layer sizes come from the
    module-level lists ``encoder_node_layers``, ``encoder_edge_layers`` and
    ``decoder_node_layers`` (assumed defined elsewhere in this file).
    """

    def __init__(self):
        # BUG FIX: the original defined ``init`` instead of ``__init__``, so
        # the constructor was never invoked and no submodules were registered.
        super(Net, self).__init__()
        self.network_1 = MetaLayer(EdgeModel_1(), NodeModel_1())
        self.network_2 = MetaLayer(EdgeModel_2(), NodeModel_2())
        self.network_3 = MetaLayer(EdgeModel_2(), NodeModel_2())
        # Node encoder MLP.
        self.encoder_node = nn.ModuleList()
        for l in range(len(encoder_node_layers) - 1):
            self.encoder_node.append(nn.Linear(encoder_node_layers[l], encoder_node_layers[l + 1]))
        # Edge encoder MLP.
        self.encoder_edge = nn.ModuleList()
        for l in range(len(encoder_edge_layers) - 1):
            self.encoder_edge.append(nn.Linear(encoder_edge_layers[l], encoder_edge_layers[l + 1]))
        # Node decoder MLP.
        self.decoder_node = nn.ModuleList()
        for l in range(len(decoder_node_layers) - 1):
            self.decoder_node.append(nn.Linear(decoder_node_layers[l], decoder_node_layers[l + 1]))

    def forward(self, data):
        """Run the full encode/process/decode pass on a PyG ``Data`` batch and
        return the decoded node features."""
        node_attr, edge_index, edge_attr = data.x, data.edge_index, data.edge_attr
        # Encode nodes and edges: ELU on hidden layers, linear final layer.
        for layer in self.encoder_node[:-1]:
            node_attr = F.elu(layer(node_attr))
        node_attr = self.encoder_node[-1](node_attr)
        for layer in self.encoder_edge[:-1]:
            edge_attr = F.elu(layer(edge_attr))
        edge_attr = self.encoder_edge[-1](edge_attr)
        # Keep copies of the encoded features for the skip connections
        # (``.clone()`` replaces the original's ``1 * x`` idiom).
        node_attr_zero, edge_attr_zero = node_attr.clone(), edge_attr.clone()
        # Message-passing block 1.
        node_attr, edge_attr, u = self.network_1(node_attr, edge_index, edge_attr)
        # Skip connection, then block 2.
        edge_attr = torch.cat([edge_attr, edge_attr_zero], dim=1)
        node_attr = torch.cat([node_attr, node_attr_zero], dim=1)
        node_attr, edge_attr, u = self.network_2(node_attr, edge_index, edge_attr)
        # Skip connection, then block 3.
        edge_attr = torch.cat([edge_attr, edge_attr_zero], dim=1)
        node_attr = torch.cat([node_attr, node_attr_zero], dim=1)
        node_attr, edge_attr, u = self.network_3(node_attr, edge_index, edge_attr)
        # Decode node features.
        for layer in self.decoder_node[:-1]:
            node_attr = F.elu(layer(node_attr))
        node_attr = self.decoder_node[-1](node_attr)
        return node_attr