Hey everyone,
Currently I am developing a multiclass classifier. I am trying to replace nn.CrossEntropyLoss with my own loss function. The background is that I want to penalize some misclassifications more strongly than others. For example, with 4 classes, the following matrix describes the penalty:
0 , 1 , 4 , 5
1 , 0 , 3 , 8
4 , 3 , 0 , 2
5 , 8 , 2 , 0
As you can see, misclassifying a sample from class 1 as class 2 is not as fatal as misclassifying it as class 4. In the first case my loss function should return 1, and in the second case 5. I started to implement the function as follows:
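To make the intended lookup concrete, here is a minimal sketch of the matrix as a tensor (the name `penalty` is my own; the values are the ones above, with row = true class and column = predicted class):

```python
import torch

# Penalty matrix from above: penalty[true_class, predicted_class]
penalty = torch.tensor([
    [0., 1., 4., 5.],
    [1., 0., 3., 8.],
    [4., 3., 0., 2.],
    [5., 8., 2., 0.],
])

# A sample from class 1 (index 0) predicted as class 2 (index 1) costs 1;
# predicted as class 4 (index 3) it costs 5.
print(penalty[0, 1].item())  # 1.0
print(penalty[0, 3].item())  # 5.0
```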
class MyLoss(nn.Module):
    def __init__(self, loss_matrix):
        super(MyLoss, self).__init__()
        # loss_matrix: the 4x4 penalty matrix as a tensor
        self._loss_matrix = loss_matrix

    def forward(self, output, target):
        _, predicted = torch.max(output, axis=1)
        loss = []
        for i in range(len(predicted)):
            loss.append(self._loss_matrix[predicted[i], target[i]])
        return torch.mean(torch.stack(loss))
I figured out that the loss cannot be backpropagated properly, since the torch.max function (argmax over the class dimension) breaks the backpropagation graph. Does anybody have an idea how to replace the torch.max function with a differentiable one? I know that I can use softmax(output)[:, 1] for a 2-class classification problem. Is there something similar for the n-class problem?
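For context, one idea I experimented with is to skip the hard argmax entirely and take the expected penalty under the softmax distribution, i.e. weight each row entry of the penalty matrix by the predicted class probability. I am not sure this is the right approach, but it is differentiable (the class name and the `penalty` argument are my own):

```python
import torch
import torch.nn as nn

class ExpectedPenaltyLoss(nn.Module):
    """Sketch: expected misclassification penalty under the softmax
    distribution. `penalty` is the 4x4 matrix from above; this is my
    own attempt, not a standard PyTorch loss."""
    def __init__(self, penalty):
        super().__init__()
        # Register as a buffer so it moves with .to(device)
        self.register_buffer("penalty", penalty)

    def forward(self, output, target):
        probs = torch.softmax(output, dim=1)   # (batch, n_classes)
        # Penalty row for each sample's true class: (batch, n_classes)
        per_class = self.penalty[target]
        # Expected penalty per sample, then batch mean
        return (probs * per_class).sum(dim=1).mean()
```

Since the correct class has penalty 0, a confident correct prediction drives this loss toward 0, and gradients flow through the softmax instead of the argmax.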
Thanks in advance!!