Problem creating a custom loss function

Hi everyone,
I’m currently having trouble creating my own loss function. After watching several tutorials, I thought I had done it the right way, but my loss does not decrease: it stays perfectly constant. I used Variables with requires_grad=True and only performed torch operations, so I really do not understand why. Does anyone have an idea?
You can find my code below:

import torch
from torch.autograd import Variable

# `dtype` is not defined in the snippet as posted

class DistanceLoss(torch.nn.Module):

    def __init__(self):
        super(DistanceLoss, self).__init__()

    def forward(self, output, target):
        output = Variable(output, requires_grad=True).to('cuda')
        target = Variable(target, requires_grad=True).to('cuda')

        binarized_output = torch.argmin(output, 1).type(dtype)
        one_hot_output = torch.nn.functional.one_hot(binarized_output.to(torch.int64)).type(dtype)[0, :, :, :]

        for c in range(2):
            target_coordinates = ((target[:, :, c] == 1).nonzero(as_tuple=False)).type(dtype)
            output_coordinates = ((one_hot_output[:, :, c] == 1).nonzero(as_tuple=False)).type(dtype)
            dist_matrix = torch.cdist(output_coordinates, target_coordinates,
                                      p=2.0, compute_mode='use_mm_for_euclid_dist_if_necessary')
            loss = torch.sum(torch.amin(dist_matrix, 1))
        return Variable(loss, requires_grad=True).to('cuda')

What I understand is that the operations I perform between the predicted image and the computed loss do not preserve the grad_fn of the intermediate tensors. Nevertheless, I need these operations to calculate the loss. Is there a way to do it?

Thank you very much!

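Variable is deprecated and no longer needed: since PyTorch 0.4, plain tensors carry requires_grad themselves.
Besides, wrapping a tensor in Variable resets its history, so backpropagation is cut at that point.
Lastly, you are applying a lot of slicing and conditioning. Keep in mind that gradients only flow back to the elements you select, and the selection itself (the comparison) is not differentiable.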
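A minimal sketch of the slicing point, with arbitrary tensor values: when a loss is built from a boolean-mask selection, autograd sends gradient only to the selected elements, and the mask itself carries no gradient, so which elements get selected is never learned.

```python
import torch

# A leaf tensor we want gradients for (values are arbitrary).
x = torch.tensor([1.0, -2.0, 3.0, -4.0], requires_grad=True)

# Comparison produces a bool tensor: it has no grad_fn and no gradient.
mask = x > 0
# Only the selected (positive) entries take part in the loss.
loss = x[mask].sum()
loss.backward()

print(x.grad)  # tensor([1., 0., 1., 0.])
```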

I don’t think you need to force the output to require the gradient. (For the target it is usually not necessary at all, since it doesn’t depend on the model parameters you want to update.) The gradient is supposed to come from the forward pass: if the model parameters require the gradient, the output of the model will also require it, and so will the final loss, provided you don’t use functions whose derivative is zero almost everywhere, such as indicator functions, in your model or in the loss calculation.
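A small sketch of both halves of that statement, using torch.round as an example of a function whose derivative is zero almost everywhere (the parameter and input values are arbitrary): the output of a parameterised operation requires the gradient automatically, but a piecewise-constant step keeps the graph intact while sending back a zero gradient.

```python
import torch

w = torch.nn.Parameter(torch.tensor(0.7))
x = torch.tensor([1.0, 2.0, 3.0])

# No need to force anything: the output inherits requires_grad from w.
out = w * x
print(out.requires_grad, out.grad_fn is not None)  # True True

# torch.round is piecewise constant: the graph is intact, but its derivative is 0.
loss = torch.round(out).sum()
loss.backward()
print(w.grad)  # tensor(0.)
```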

Start by checking the grad_fn attribute (MulBackward0, SumBackward0, ...) of output and loss: if they don’t have one, this could be the cause of the problem.
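One way to do that check, sketched with a hypothetical helper `report` and a random tensor of shape (1, 3, 4, 4) standing in for the network output (the argmin over dimension 1 mirrors the code in the question):

```python
import torch

def report(name, t):
    # grad_fn is None for leaf tensors and for tensors cut off from the graph.
    print(f"{name}: requires_grad={t.requires_grad}, grad_fn={t.grad_fn}")

# Hypothetical stand-in for a segmentation network output.
logits = torch.randn(1, 3, 4, 4, requires_grad=True)

scaled = logits * 2
report("scaled", scaled)   # grad_fn is a MulBackward0 node

labels = torch.argmin(logits, dim=1)
report("labels", labels)   # integer indices: requires_grad=False, grad_fn=None

cut = logits.detach().requires_grad_(True)
report("cut", cut)         # a new leaf: requires_grad=True, but grad_fn=None
```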

Here is an example to explain what I mean:

import torch
class Model(torch.nn.Module) :
    def __init__(self) :
        super().__init__()
        self.a = torch.nn.parameter.Parameter(data = torch.tensor(1.), requires_grad = True)

    def forward(self, x) :
        """The output automatically requires the gradient. If you set requires_grad = False on self.a,
        it no longer does, and calling loss.backward() will raise an error."""
        return self.a * x  

def mae(y, y_pred) :
    """mean absolute error"""
    return (y - y_pred).abs().sum()
# or 
class MAE(torch.nn.Module):
    """One of the main advantages of defining a loss function as a class is that you can do loss.to('device_name')"""
    def __init__(self) :
        super().__init__()
        
    def forward(self, y, y_pred) :
        return (y - y_pred).abs().sum() # mae(y, y_pred)
# data : no need to require the gradient (except if you want to calculate dy_pred/dy for example)
x = torch.tensor([1., 3, 3])
y = 2 * x
# model
model = Model()
# loss (optional, we can directly use the mae function here)
mae = MAE() 
1. Normal process
# model.a.grad is None 
# zero all the gradients (example: optimizer.zero_grad())
y_pred = model(x) # tensor([1., 3., 3.], grad_fn=<MulBackward0>)
loss = mae(y_pred, y) # tensor(7., grad_fn=<SumBackward0>)

loss.backward() 
# model.a.grad = tensor(-7.)

In this case, optimizer.step() will update the parameters (for example a = a - learning_rate*a.grad), and the loss will change at the next iteration because the model parameters, and therefore the output, have changed.
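To see the loss actually move, here is a sketch of a few training iterations with the same Model and data. torch.optim.SGD, the learning rate 0.05 and the three iterations are arbitrary choices for illustration:

```python
import torch

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.a = torch.nn.Parameter(torch.tensor(1.0))

    def forward(self, x):
        return self.a * x

x = torch.tensor([1.0, 3.0, 3.0])
y = 2 * x
model = Model()
# SGD with lr=0.05 is an arbitrary choice for illustration.
optimizer = torch.optim.SGD(model.parameters(), lr=0.05)

losses = []
for step in range(3):
    optimizer.zero_grad()
    loss = (y - model(x)).abs().sum()
    loss.backward()    # model.a.grad = -7 while a < 2
    optimizer.step()   # a <- a - 0.05 * a.grad, i.e. a grows by 0.35
    losses.append(loss.item())

print(losses)  # approximately [7.0, 4.55, 2.1]
```

Each step moves a from 1.0 towards 2.0 (1.35, 1.7, 2.05), so the L1 loss shrinks by 0.35 * 7 per iteration.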

2. Abnormal process
model = Model()
# model.a.grad is None 
# zero all the gradients (example: optimizer.zero_grad())

y_pred = model(x).detach().requires_grad_(True) # tensor([1., 3., 3.], requires_grad=True)
loss = mae(y_pred, y) # tensor(7., grad_fn=<SumBackward0>)
loss.backward()

# model.a.grad is None

model.a.grad stays None (or 0, if it was zeroed with zero_grad), so a is never updated: with a zero gradient, a = a - learning_rate*a.grad = a.

model(x).detach() makes y_pred no longer depend on model.a. Calling .requires_grad_(True) doesn’t change this; it only lets the loss depend on y_pred, as you can see:

y_pred.grad = tensor([-1., -1., -1.]) # d|y - y_pred|/dy_pred = -sign(y - y_pred) = -1, since y > y_pred

So make sure you don’t force anything yourself, otherwise you will break the computation graph (as I did here with .detach()).
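The same detach mistake, run through an optimizer step, confirms that a stays put. This is a minimal sketch that uses a bare Parameter instead of the Model class and SGD with lr=0.05, both assumptions for brevity:

```python
import torch

a = torch.nn.Parameter(torch.tensor(1.0))
x = torch.tensor([1.0, 3.0, 3.0])
y = 2 * x
optimizer = torch.optim.SGD([a], lr=0.05)

optimizer.zero_grad()
y_pred = (a * x).detach().requires_grad_(True)  # the graph back to `a` is cut here
loss = (y - y_pred).abs().sum()
loss.backward()
optimizer.step()  # a.grad is None, so SGD skips `a`

print(a.grad)       # None
print(y_pred.grad)  # tensor([-1., -1., -1.])
print(a.item())     # 1.0
```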

Thank you very much for answering! Indeed, I realised that some operations like:

binarized_output = torch.argmin(output, 1).type(dtype)

make binarized_output.grad_fn equal to None.

As a consequence, I need to find a differentiable way to perform these operations, right?

Yes. With such an operation you cannot compute the gradient by ordinary backpropagation: argmin returns integer indices, so the graph is cut there, and more generally the gradient of the loss with respect to any parameter before a piecewise-constant function is zero. In your case that means all the parameters, since you apply it to the network output.
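One common way to get a differentiable version is a soft relaxation: replace argmin with softmin probabilities, and weight each pixel’s distance to the nearest target pixel of class c by the predicted probability of class c. The thread doesn’t say what was finally used, so this is only a sketch of that technique. The name soft_distance_loss is hypothetical, and the shapes, output (1, C, H, W) and target (H, W, C) one-hot, are assumptions read off the indexing in the question’s code.

```python
import torch
import torch.nn.functional as F

def soft_distance_loss(output, target, num_classes=2):
    """Hypothetical differentiable stand-in for argmin + nearest-pixel distance.

    output: (1, C, H, W) scores where the *lowest* score marks the class
            (matching the argmin in the question).
    target: (H, W, C) one-hot target mask.
    """
    probs = F.softmin(output, dim=1)[0]  # (C, H, W): soft version of argmin
    H, W = probs.shape[1:]
    # Pixel coordinates in row-major order, matching reshape(-1) of an (H, W) map.
    coords = torch.cartesian_prod(
        torch.arange(H, device=output.device), torch.arange(W, device=output.device)
    ).to(probs.dtype)  # (H*W, 2)

    loss = output.new_zeros(())
    for c in range(num_classes):
        target_coords = coords[target[..., c].reshape(-1) == 1]  # (K, 2), constant
        if target_coords.numel() == 0:
            continue
        # Distance from every pixel to its nearest target pixel of class c (no grad needed).
        nearest = torch.cdist(coords, target_coords).amin(dim=1)  # (H*W,)
        # Weight each pixel's distance by its predicted probability of class c.
        loss = loss + (probs[c].reshape(-1) * nearest).sum()
    return loss

# Usage with a hypothetical 8x8, 2-class example.
torch.manual_seed(0)
labels = torch.zeros(8, 8, dtype=torch.long)
labels[:, 4:] = 1                           # right half is class 1
target = F.one_hot(labels, num_classes=2)   # (8, 8, 2)
output = torch.randn(1, 2, 8, 8, requires_grad=True)

loss = soft_distance_loss(output, target)
loss.backward()
print(loss.grad_fn is not None)  # True: the graph reaches `output`
```

The distance term depends only on fixed pixel coordinates, so all of the gradient reaches output through the softmin probabilities. The loop also accumulates over classes, whereas the loop in the question overwrites loss on every iteration, so only the last class was counted there.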

Alright, in the end I completely changed the way I compute the loss so it is differentiable.
Thanks for your help!