Hi.

My GraphNet predicts the same result for every event in a batch. This output is roughly the average of all labels in the batch. That is a problem for my optimization task, because my labels are quite distinct.

The easy workaround is setting the batch size to 1. With that, my model trains just fine and I reach good accuracy. Still, I want to go back to training on batches to reduce the training time.

I have already tuned the learning rate, the width and depth of the model, the loss function, the activation function, dropout, and pooling. With batch size 1 the training results are fine.

Here is the code of my model:

```python
import gc

import torch
import torch.nn as nn
from dgl.nn import EdgeConv, MaxPooling


class Net(nn.Module):
    def __init__(self, n_feats_fc, in_feats_g, parallel_layers, Dropout):
        super(Net, self).__init__()
        self.edge1 = EdgeConv(50, 100)
        self.edge2 = EdgeConv(100, 300)
        self.edge3 = EdgeConv(300, 600)
        self.Dropout = nn.Dropout(Dropout)
        self.pooling = MaxPooling()
        self.fc1 = nn.Linear(600, 200)
        self.fc2 = nn.Linear(200, 200)
        self.fc3 = nn.Linear(200, 200)
        self.fc4 = nn.Linear(200, 200)
        self.fc_out = nn.Linear(200, 9)

    def forward(self, graph, batch_size):
        feat = torch.tanh(self.edge1(graph, graph.ndata['x']))
        feat = self.Dropout(feat)
        feat = torch.tanh(self.edge2(graph, feat))
        feat = torch.tanh(self.edge3(graph, feat))
        #feat = torch.max(feat, dim=1)[0]
        feat = self.pooling(graph, feat)
        feat = torch.tanh(self.fc1(feat))
        feat = self.Dropout(feat)
        feat = torch.tanh(self.fc2(feat))
        feat = torch.tanh(self.fc3(feat))
        feat = torch.tanh(self.fc4(feat))
        out = torch.clamp(self.fc_out(feat), min=-2, max=2)
        del feat
        gc.collect()
        return out, graph
```

I use tanh because my output is in the range [-2, 2]. I also have a custom loss function, but it didn't work with MSELoss either.
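One thing worth checking here (an observation, not necessarily the cause): `torch.clamp` has zero gradient wherever its input is outside the bounds, so once `fc_out` saturates past [-2, 2] those outputs stop receiving any learning signal. A scaled tanh such as `2 * torch.tanh(...)` also bounds the output to (-2, 2) but keeps a small nonzero gradient everywhere. A minimal comparison:

```python
import torch

# Gradient through clamp for an input outside the clamping range:
x = torch.tensor([3.0], requires_grad=True)
torch.clamp(x, min=-2, max=2).backward()
print(x.grad)  # tensor([0.]) -- no learning signal once saturated

# Gradient through a scaled tanh at the same input:
z = torch.tensor([3.0], requires_grad=True)
(2 * torch.tanh(z)).backward()
print(z.grad)  # small but nonzero, so the weights can still move
```

Whether this matters for your training depends on how often the pre-clamp outputs leave [-2, 2]; it is independent of the batching issue.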

Here is my loss:

```python
import torch
from torch.nn.modules.loss import _Loss


class MyLoss(_Loss):
    __constants__ = ['reduction']

    def __init__(self):
        super(MyLoss, self).__init__()

    def forward(self, pred, tru):
        loss = torch.tensor([0])
        for p, t in zip(pred, tru):
            l = ....  # loss calculation
            loss = torch.add(loss, l.item())
        loss.requires_grad = True
        return l
```
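The `.item()` call is likely the core of the problem: it converts the per-sample loss to a plain Python number, which detaches it from the autograd graph, so `backward()` never reaches the model's weights. Setting `requires_grad = True` afterwards only makes the accumulated constant a new leaf; it does not reconnect it to the network. The function also returns `l` (the last sample's loss) rather than the accumulated `loss`. A minimal gradient-preserving sketch, using a squared error as a stand-in for the elided per-sample calculation (an assumption on my part):

```python
import torch
from torch.nn.modules.loss import _Loss


class MyLoss(_Loss):
    def __init__(self):
        super(MyLoss, self).__init__()

    def forward(self, pred, tru):
        # Scalar zero tensor on the same device/dtype as the predictions.
        loss = pred.new_zeros(())
        for p, t in zip(pred, tru):
            # Stand-in for the elided per-sample loss calculation (assumption):
            l = ((p - t) ** 2).mean()
            # Tensor addition (no .item()) keeps the autograd graph intact.
            loss = loss + l
        # Mean over the batch; still differentiable, no manual requires_grad.
        return loss / len(pred)
```

With tensor accumulation like this, `loss.backward()` propagates gradients through every sample in the batch, which is what per-sample losses averaged over a batch should do. If the per-sample calculation can be written with tensor ops, a fully vectorized version (no Python loop) would be faster still.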

Maybe it is a backpropagation problem, but I don't understand why it doesn't work.

I hope someone can help me.