Custom loss functions

Here is a simple version without class weighting, alternative reduction types, etc.:

def my_cross_entropy(x, y, reduction='mean'):
    """Manual re-implementation of ``nn.CrossEntropyLoss``.

    Args:
        x: raw (unnormalized) logits of shape ``(batch, num_classes)``.
        y: integer class indices of shape ``(batch,)``.
        reduction: ``'mean'`` (default, matches the previous behavior),
            ``'sum'``, or ``'none'`` to return the per-sample losses.

    Returns:
        Scalar loss tensor for ``'mean'``/``'sum'``, or a ``(batch,)``
        tensor of per-sample losses for ``'none'``.

    Raises:
        ValueError: if ``reduction`` is not one of the supported modes.
    """
    # Negative log-likelihood per class: -log softmax over the class dim.
    log_prob = -1.0 * F.log_softmax(x, 1)
    # Pick out the NLL of the target class for each sample; squeeze the
    # gathered (batch, 1) column back to (batch,) so 'none' returns the
    # conventional per-sample shape.
    loss = log_prob.gather(1, y.unsqueeze(1)).squeeze(1)
    if reduction == 'mean':
        return loss.mean()
    if reduction == 'sum':
        return loss.sum()
    if reduction == 'none':
        return loss
    raise ValueError(f"Unsupported reduction: {reduction!r}")


# Sanity check: compare the manual implementation against PyTorch's
# built-in nn.CrossEntropyLoss on a random batch.
criterion = nn.CrossEntropyLoss()

batch_size, nb_classes = 5, 10
x = torch.randn(batch_size, nb_classes, requires_grad=True)
y = torch.randint(0, nb_classes, (batch_size,))

loss_reference = criterion(x, y)
loss = my_cross_entropy(x, y)

# The difference should print as tensor(0.) (up to floating-point noise).
print(loss_reference - loss)
> tensor(0., grad_fn=<SubBackward0>)
11 Likes