My GraphNet predicts the same result for all events in a batch; this output is roughly the average of all labels within the batch. But I am working on an optimization problem and my labels are quite distinct, so this behaviour is not useful.
An easy workaround is to set the batch size to 1. With that, my model trains just fine and I reach good accuracy. Still, I would like to go back to training on batches to reduce the training time.
I have already tuned the learning rate, the width and depth of the model, the loss function, the activation function, dropout, and pooling. With batch size 1 the training results are fine.
Here is the code of my model:
import gc
import torch
import torch.nn as nn
from dgl.nn.pytorch.conv import EdgeConv
from dgl.nn.pytorch.glob import MaxPooling

class Net(nn.Module):
    def __init__(self, n_feats_fc, in_feats_g, parallel_layers, Dropout):
        super(Net, self).__init__()
        # Three EdgeConv layers that widen the node features step by step
        self.edge1 = EdgeConv(50, 100)
        self.edge2 = EdgeConv(100, 300)
        self.edge3 = EdgeConv(300, 600)
        self.Dropout = nn.Dropout(Dropout)
        # Graph-level max pooling: one 600-dim vector per graph in the batch
        self.pooling = MaxPooling()
        # Fully connected head with 9 outputs per graph
        self.fc1 = nn.Linear(600, 200)
        self.fc2 = nn.Linear(200, 200)
        self.fc3 = nn.Linear(200, 200)
        self.fc4 = nn.Linear(200, 200)
        self.fc_out = nn.Linear(200, 9)

    def forward(self, graph, batch_size):
        feat = torch.tanh(self.edge1(graph, graph.ndata['x']))
        feat = self.Dropout(feat)
        feat = torch.tanh(self.edge2(graph, feat))
        feat = torch.tanh(self.edge3(graph, feat))
        # feat = torch.max(feat, dim=1)
        feat = self.pooling(graph, feat)
        feat = torch.tanh(self.fc1(feat))
        feat = self.Dropout(feat)
        feat = torch.tanh(self.fc2(feat))
        feat = torch.tanh(self.fc3(feat))
        feat = torch.tanh(self.fc4(feat))
        out = torch.clamp(self.fc_out(feat), min=-2, max=2)
        del feat
        gc.collect()
        return out, graph
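For context, my training loop follows the usual DGL pattern of merging the graphs of one batch with dgl.batch. This is only a simplified sketch; the dataset handling, the hyperparameter values, and the constructor arguments below are placeholders:

import dgl
import torch

# train_set is assumed to be a list of (graph, label_tensor) pairs with 9-dim labels
model = Net(n_feats_fc=200, in_feats_g=50, parallel_layers=4, Dropout=0.2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
criterion = MyLoss()
batch_size = 16

for epoch in range(100):
    for i in range(0, len(train_set), batch_size):
        graphs, labels = zip(*train_set[i:i + batch_size])
        bg = dgl.batch(graphs)            # merge the graphs into one batched graph
        labels = torch.stack(labels)      # shape (batch_size, 9)

        pred, _ = model(bg, len(graphs))  # the model returns (out, graph)
        loss = criterion(pred, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()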
I use tanh because my output is in the range [-2, 2]. I also have a custom loss function, but the problem occurs with MSELoss as well.
Here is my loss:
from torch.nn.modules.loss import _Loss

class MyLoss(_Loss):
    __constants__ = ['reduction']

    def __init__(self):
        super(MyLoss, self).__init__()

    def forward(self, pred, tru):
        loss = torch.tensor(0.)
        for p, t in zip(pred, tru):
            l = ....  # loss calculation
            # accumulate the per-sample losses
            loss = torch.add(loss, l.item())
        loss.requires_grad = True
        return l
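For comparison, the MSELoss variant I tried is just the standard usage (sketch):

criterion = torch.nn.MSELoss()   # default reduction='mean'
loss = criterion(pred, tru)      # pred and tru have shape (batch_size, 9)
loss.backward()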
Maybe it is a backpropagation problem, but I don't understand why it doesn't work.
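If it really is a backpropagation problem, I would expect the gradients not to reach the EdgeConv layers. A rough way to check that (just a sketch, I have not verified it) would be to print the gradient norms after backward():

pred, _ = model(bg, len(graphs))
loss = criterion(pred, labels)
loss.backward()

# All-zero or None gradients in the EdgeConv layers would indicate
# that the loss is not connected to the computation graph.
for name, param in model.named_parameters():
    grad_norm = None if param.grad is None else param.grad.norm().item()
    print(name, grad_norm)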
I hope someone can help me.