Custom loss functions

Hi,

I’m implementing a custom loss function in PyTorch 0.4. Reading the docs and the forums, it seems that there are two ways to define a custom loss function:

  • Extending torch.autograd.Function and implementing the forward and backward methods.
  • Extending nn.Module and implementing only the forward method.
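For the first approach, here is a minimal sketch of a mean squared error written as a torch.autograd.Function with a hand-written backward. The class name MyMSELoss is hypothetical, and the code uses the static-method Function API of current PyTorch releases. torch.autograd.gradcheck compares the manual gradients with numerical ones.

```python
import torch


class MyMSELoss(torch.autograd.Function):
    """Mean squared error with a manual backward pass (illustrative only)."""

    @staticmethod
    def forward(ctx, output, target):
        # Keep the inputs so backward can recompute the difference.
        ctx.save_for_backward(output, target)
        return ((output - target) ** 2).mean()

    @staticmethod
    def backward(ctx, grad_loss):
        output, target = ctx.saved_tensors
        # d/d(output) of mean((output - target)**2) is 2 * (output - target) / N.
        grad_output = grad_loss * 2.0 * (output - target) / output.numel()
        # One gradient per forward input; the target receives the negated gradient.
        return grad_output, -grad_output


torch.manual_seed(0)
output = torch.randn(4, 3, dtype=torch.double, requires_grad=True)
target = torch.randn(4, 3, dtype=torch.double, requires_grad=True)
loss = MyMSELoss.apply(output, target)
loss.backward()

# gradcheck needs double-precision inputs that require gradients.
a = torch.randn(4, 3, dtype=torch.double, requires_grad=True)
b = torch.randn(4, 3, dtype=torch.double, requires_grad=True)
ok = torch.autograd.gradcheck(MyMSELoss.apply, (a, b))
print(ok)
```

A manual backward like this is only worth writing when autograd cannot differentiate the operations (or does so inefficiently); for losses built from torch operations, autograd derives the same gradients automatically.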

With that in mind, my question is:

  • Can I write a Python function that takes my model outputs as inputs and uses torch.* functions to compute my loss (without extending Function or Module)? If not, why?

Regards,


Sure, as long as you use PyTorch operations, you should be fine.
Here is a dummy implementation of nn.MSELoss using the mean:

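import torch
import torch.nn as nn

def my_loss(output, target):
    # mean squared error built only from differentiable torch operations
    loss = torch.mean((output - target)**2)
    return loss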

model = nn.Linear(2, 2)
x = torch.randn(1, 2)
target = torch.randn(1, 2)
output = model(x)
loss = my_loss(output, target)
loss.backward()
print(model.weight.grad)
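To check that such a plain Python function behaves like the built-in criterion, one can compare the gradients it produces with those from nn.MSELoss on the same model and batch. This is a small sketch; the batch size of 4 and the seed are arbitrary choices.

```python
import torch
import torch.nn as nn


def my_loss(output, target):
    return torch.mean((output - target) ** 2)


torch.manual_seed(0)
model = nn.Linear(2, 2)
x = torch.randn(4, 2)
target = torch.randn(4, 2)

# Gradients from the hand-written loss.
my_loss(model(x), target).backward()
grad_custom = model.weight.grad.clone()

# Clear the accumulated gradients and repeat with the built-in criterion.
model.zero_grad()
nn.MSELoss()(model(x), target).backward()
grad_builtin = model.weight.grad.clone()

print(torch.allclose(grad_custom, grad_builtin))
```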

This is a quite simple implementation of a custom loss function without extra parameters. Could you please share a solution for loss functions that take extra parameters?
For example, in Keras, you can implement a weighted loss as follows:

def label_depend_loss(alpha):
    def label_depend(output, target):
        pos_loss = ...  # placeholder: loss on the positive samples
        neg_loss = ...  # placeholder: loss on the negative samples
        return pos_loss + alpha * neg_loss
    return label_depend

Besides, BCELoss may not suit this case.
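The same closure pattern works in PyTorch, because a loss is just a Python callable. In this sketch the positive and negative terms are assumed to be per-element binary cross-entropy values on positive and negative targets; that choice is hypothetical, since the original left them unspecified, so substitute your own terms.

```python
import torch
import torch.nn.functional as F


def label_depend_loss(alpha):
    """Return a loss that scales the negative-sample term by alpha."""
    def label_depend(output, target):
        # output: raw logits; target: 0/1 floats with the same shape.
        per_elem = F.binary_cross_entropy_with_logits(output, target, reduction="none")
        pos_loss = (per_elem * target).mean()        # contribution of positive targets
        neg_loss = (per_elem * (1 - target)).mean()  # contribution of negative targets
        return pos_loss + alpha * neg_loss
    return label_depend


torch.manual_seed(0)
criterion = label_depend_loss(alpha=0.5)
logits = torch.randn(8, requires_grad=True)
labels = torch.randint(0, 2, (8,)).float()
loss = criterion(logits, labels)
loss.backward()
print(loss.item(), logits.grad.shape)
```

With alpha=1.0 the two terms add back up to the plain mean binary cross-entropy, which is a quick sanity check for the split.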


Depending on your loss function, you could just multiply the positive and negative losses by your weights.
Alternatively, nn.BCEWithLogitsLoss with its pos_weight argument might fit your use case.
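As a sketch of what pos_weight does, the built-in loss can be reproduced by scaling only the positive-target term of the binary cross-entropy. The shapes and weight values below are arbitrary; pos_weight holds one entry per class along the last dimension.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(6, 3)
targets = torch.randint(0, 2, (6, 3)).float()
pos_weight = torch.tensor([1.0, 2.0, 3.0])  # one weight per class

builtin = nn.BCEWithLogitsLoss(pos_weight=pos_weight)(logits, targets)

# Manual version: pos_weight multiplies only the y * log(sigmoid(x)) term.
log_p = F.logsigmoid(logits)       # log(sigmoid(x))
log_not_p = F.logsigmoid(-logits)  # log(1 - sigmoid(x))
manual = -(pos_weight * targets * log_p + (1 - targets) * log_not_p).mean()

print(torch.allclose(builtin, manual))
```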


Yes, thanks for your advice. I can split it into two loss terms and sum them with weights. Many thanks!

Dear @ptrblck,

Could you explain further the weight argument in nn.CrossEntropyLoss() and nn.BCELoss(), and the pos_weight argument in nn.BCEWithLogitsLoss()?

  • weight in CrossEntropyLoss is a Tensor of size C, but why should it have the size of nbatch in nn.BCELoss()? It also seems that weight in BCELoss does not work for unbalanced data, right? (Because the weight is tied to nbatch.)
  • Does pos_weight have the same effect as weight in nn.CrossEntropyLoss?

Thanks in advance.

The weight argument in nn.BCE(WithLogits)Loss has the shape of the input batch, since these loss functions take floating point targets, which do not correspond to a class weighting scheme. pos_weight, on the other hand, is closer to a class weighting, as it only weights the positive examples. Furthermore, you can trade off recall and precision by changing the pos_weight argument.
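A small sketch of the per-element behavior of weight in nn.BCELoss: it rescales each element's loss before the mean reduction. Here a (batch, 1) weight is used, one value per sample, which broadcasts over the columns; the values and the soft targets are arbitrary.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
probs = torch.sigmoid(torch.randn(4, 3))  # predictions in (0, 1)
targets = torch.rand(4, 3)                # BCELoss accepts soft floating point targets
weight = torch.tensor([[1.0], [2.0], [0.5], [1.0]])  # one weight per sample

weighted = nn.BCELoss(weight=weight)(probs, targets)
unreduced = nn.BCELoss(reduction="none")(probs, targets)

# weight multiplies each element's loss; the mean is then taken as usual.
print(torch.allclose(weighted, (weight * unreduced).mean()))
```

Because the weight is attached to samples (or elements) rather than to classes, it does not by itself rebalance classes; pos_weight is the argument that acts per class.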


Hey @ptrblck, can you share a similar dummy function for cross-entropy loss? It would be helpful to me. Thanks in advance!

Sure, here is a simple version without weighting, different reduction types, etc.:

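import torch
import torch.nn as nn
import torch.nn.functional as F

def my_cross_entropy(x, y):
    # negative log-probabilities over the class dimension
    log_prob = -1.0 * F.log_softmax(x, 1)
    # pick the entry belonging to each sample's target class
    loss = log_prob.gather(1, y.unsqueeze(1))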
    loss = loss.mean()
    return loss


criterion = nn.CrossEntropyLoss()

batch_size = 5
nb_classes = 10
x = torch.randn(batch_size, nb_classes, requires_grad=True)
y = torch.randint(0, nb_classes, (batch_size,))

loss_reference = criterion(x, y)
loss = my_cross_entropy(x, y)

print(loss_reference - loss)
> tensor(0., grad_fn=<SubBackward0>)
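Class weighting can be added to the same idea. With the default mean reduction and a weight tensor, nn.CrossEntropyLoss divides by the sum of the weights of the target classes rather than by the batch size. The sketch below (the function name my_weighted_cross_entropy is hypothetical) reproduces that.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def my_weighted_cross_entropy(x, y, weight):
    # Per-sample negative log-likelihood of the target class.
    nll = -F.log_softmax(x, dim=1).gather(1, y.unsqueeze(1)).squeeze(1)
    # Weight of each sample's target class.
    w = weight[y]
    # Weighted mean: normalize by the sum of the weights actually used.
    return (w * nll).sum() / w.sum()


torch.manual_seed(0)
batch_size, nb_classes = 5, 10
x = torch.randn(batch_size, nb_classes)
y = torch.randint(0, nb_classes, (batch_size,))
class_weight = torch.rand(nb_classes) + 0.1  # one weight per class, size C

reference = nn.CrossEntropyLoss(weight=class_weight)(x, y)
print(torch.allclose(reference, my_weighted_cross_entropy(x, y, class_weight)))
```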

But the loss inputs are tensors. How does that work?

Could you add some more details to your question please?
What is confusing about input tensors to a loss function? :slight_smile:

I mean, for each batch the input to the loss function contains all predictions and labels in the current batch, but the loss is written for a single prediction and label.

So how should it be implemented?

And another thing: how should the backward() of a custom function be implemented?

I think you could index your output and target at the desired location and pass them to your criterion:

import torch
import torch.nn as nn
from torchvision import models

model = models.resnet18()
output = model(torch.randn(10, 3, 224, 224))
target = torch.randint(0, 1000, (10,))

criterion = nn.CrossEntropyLoss()

loss = criterion(output[0:1], target[0:1])
loss.backward()
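The built-in criteria already operate on the whole batch at once. A sketch of this, using random logits in place of the resnet18 output to keep it light: reduction="none" returns one loss per sample, and the default mean reduction is their average.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
output = torch.randn(10, 1000, requires_grad=True)  # stands in for a batch of model logits
target = torch.randint(0, 1000, (10,))

# One loss value per sample instead of a single scalar.
per_sample = nn.CrossEntropyLoss(reduction="none")(output, target)
print(per_sample.shape)

# Slicing out the first sample gives the same value as per_sample[0].
single = nn.CrossEntropyLoss()(output[0:1], target[0:1])

# The default "mean" reduction averages the per-sample losses.
batch_loss = nn.CrossEntropyLoss()(output, target)
print(torch.allclose(batch_loss, per_sample.mean()))
```

A custom loss written with batched tensor operations, like my_cross_entropy above, follows the same pattern, and autograd provides its backward without extra code.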

@netaglazer
I believe that if you are worried about the first dimension being the batch index, PyTorch automatically extracts the individual predictions and accumulates the loss as the batch loss. So you can write your loss function assuming your batch has only one sample.
@ptrblck could you please correct me if my understanding about loss function above is wrong?

Hi, could you please help with some code I have?

Feel free to create a new topic with your problem description as well as your code, so that we can have a look. :slight_smile:



Do we need to implement a forward pass for the my_cross_entropy function?

my_cross_entropy is implemented as a simple function, so you can just call it.
You could of course wrap it in an nn.Module and put the operations in the forward method, if that’s more convenient or if you need to store some internal state.
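A minimal sketch of such a wrapper: the class name MyCrossEntropy and its reduction option are hypothetical, chosen only to show a piece of state stored on the module and used in forward. Calling the module instance runs forward.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class MyCrossEntropy(nn.Module):
    def __init__(self, reduction="mean"):
        super().__init__()
        # Internal state kept on the module (hypothetical option for illustration).
        self.reduction = reduction

    def forward(self, x, y):
        # Per-sample negative log-likelihood of the target class.
        loss = -F.log_softmax(x, dim=1).gather(1, y.unsqueeze(1)).squeeze(1)
        if self.reduction == "mean":
            return loss.mean()
        if self.reduction == "sum":
            return loss.sum()
        return loss  # "none": one value per sample


torch.manual_seed(0)
criterion = MyCrossEntropy(reduction="sum")
x = torch.randn(5, 10, requires_grad=True)
y = torch.randint(0, 10, (5,))
loss = criterion(x, y)  # calls forward()
loss.backward()
print(torch.allclose(loss, nn.CrossEntropyLoss(reduction="sum")(x, y)))
```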
