Custom dice_loss function does not minimize the loss

I am new to PyTorch. I'm working on semantic segmentation, so I'd like to use dice_loss to update the model's parameters (previously, I tested the model with the CrossEntropy loss function and it works reasonably).

Here is some code:

for epoch in range(epochs):
    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        _, predicted = torch.max(outputs.data, 1)  # since I will work with 1s and 0s
        loss = dice_loss(predicted, labels)
        loss.backward()  # backward propagation
        optimizer.step()

A loss output looks like tensor(0.7531, grad_fn=<AddBackward0>), and loss.backward() does not fail, but in my tests the loss never decreases from its initial value.

Here is my loss function in detail:

def dice_loss(predicted, labels):
    """Dice coeff loss for a batch"""

    # both the predictions and the labels are one-hot encoded
    onehot_pred = torch.Tensor()
    onehot_lab = torch.Tensor()
    for batch, data in enumerate(zip(predicted, labels)):
        # to_categorical is a function adapted from Keras
        pred = utils.to_categorical(data[0]).unsqueeze(dim=0)
        lab = utils.to_categorical(data[1]).unsqueeze(dim=0)
        onehot_pred = torch.cat((onehot_pred, pred), dim=0)
        onehot_lab = torch.cat((onehot_lab, lab), dim=0)

    # calculate the loss function
    ratio = 1 / predicted.size(0)  # i.e., divide by the batch size

    # loss accumulator
    dc = torch.tensor(0).float()

    # I set requires_grad in order to create a grad_fn;
    # without it, loss.backward() does not work
    onehot_pred.requires_grad = True
    onehot_lab.requires_grad = True

    for batch, data in enumerate(zip(onehot_pred, onehot_lab)):
        dc += dice_coeff(data[0], data[1]) * ratio
    return dc

and the other function

def dice_coeff(predicted, labels):
    """Dice coeff for a single plane"""
    eps = 0.0001
    inter = torch.dot(predicted.view(-1), labels.view(-1))
    union = torch.sum(predicted) + torch.sum(labels) + eps
    dice_coeff = 1 - (2 * inter.float() + eps) / union.float()
    return dice_coeff

Hi

The problem is that if you have to add onehot_pred.requires_grad = True, that means it was False before. And that means some non-differentiable operation was applied to it.

Also, FYI, you can change `dc = torch.tensor(0).float()` to `dc = 0`.
And change `onehot_pred = torch.Tensor()` to `onehot_pred = []` and `onehot_pred = torch.cat((onehot_pred, pred), dim=0)` to `onehot_pred.append(pred)`, then add `onehot_pred = torch.cat(onehot_pred, dim=0)` after the loop, so the tensor is concatenated once instead of being copied on every iteration.
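For example, a minimal sketch of that pattern (using torch.nn.functional.one_hot as a stand-in for the Keras-adapted utils.to_categorical; the helper name and signature here are my own):

import torch
import torch.nn.functional as F

def one_hot_batch(predicted, labels, num_classes):
    """Accumulate per-sample one-hot tensors in Python lists and
    concatenate once at the end, instead of calling torch.cat per iteration."""
    onehot_pred, onehot_lab = [], []
    for pred, lab in zip(predicted, labels):
        # F.one_hot is used here as a stand-in for utils.to_categorical;
        # it expects int64 input and returns int64, hence the casts
        onehot_pred.append(F.one_hot(pred.long(), num_classes).float().unsqueeze(0))
        onehot_lab.append(F.one_hot(lab.long(), num_classes).float().unsqueeze(0))
    return torch.cat(onehot_pred, dim=0), torch.cat(onehot_lab, dim=0)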

I see that the problem starts at

`_, predicted = torch.max(outputs.data, 1) `

I read that torch.max and torch.argmax are non-differentiable functions. Up to this line, outputs has a grad_fn and requires_grad is True, but torch.max returns a torch.return_types.max object, and the unpacked predicted holds integer indices, so it has no grad_fn and can no longer propagate gradients backward.

Right.
You should never use .data. It both breaks the graph and is unsafe.
Max is a differentiable function where the gradient just flows back to the maximum value and all the other entries get 0.
Argmax, on the other hand, is not differentiable, as it returns integer values, for which you can't get gradients.
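A quick way to see the difference (a minimal sketch with hypothetical values):

import torch

x = torch.randn(3, 4, requires_grad=True)

values, indices = torch.max(x, dim=1)
print(values.grad_fn)    # <MaxBackward0 ...>: differentiable path through the values
print(indices.dtype)     # torch.int64: integer indices carry no grad_fn

values.sum().backward()  # works: grad is 1 at each row's max, 0 elsewhere
print(x.grad)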

Hi @albanD, sorry for my delayed answer. I hope you can help me. I modified my function in many ways (using Function, using Module, …), but the final computed loss is only a plain tensor, without any grad_fn (that is, it is outside the computation graph).

My goal is to compute dice_coeff and backpropagate through it (the formula isn't exact).

Allow me to show it line by line:

criterion = DiceCoeffLoss() # custom loss function declaration, its code is below

outputs = model(inputs)

and its gradient function:

outputs.grad_fn
Out[285]: <SigmoidBackward at 0x1b936013e08>  # right

loss = criterion(outputs,targets)

outputs shape is [n_batch, n_class, H, W]
targets shape is [n_batch, H, W]

Inside criterion, in order to calculate the loss, I first have to get the predicted labels along the channel dimension:

pred = torch.max(outputs, dim=1)

Here, pred has no direct grad_fn attribute; pred.values still contains the output values, and pred.values.grad_fn has the gradient function <MaxBackward0 at 0x1b93601d1c8>.

However, since the next operations are based on the predicted labels contained in pred.indices, the final loss is only a plain tensor.

Here is the code (disclaimer: the formula isn't right):

def dice_coeff(predicted, targets):
    smooth = 1.0
    inter = predicted.eq(targets).sum()  # count only the true predictions
    union = float(predicted.nelement() + targets.nelement())  # all elements
    dc = (2 * inter + smooth) / (union + smooth)
    return dc

# I used nn.Module thinking it would add those operations to the graph
class DiceCoeffLoss(nn.Module):
    def __init__(self):
        super(DiceCoeffLoss, self).__init__()

    def forward(self, outputs, targets):
        pred = torch.max(outputs, dim=1)
        coeff = dice_coeff(pred.indices, targets)
        return 1 - coeff

As I said above, the problem is that the argmax function is not differentiable. You will have to use something else.
You can try to use soft attention: instead of the one-hot encoding of the argmax, use the result of the softmax layer.
But what will or won't work depends on your exact problem.
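For illustration, a minimal sketch of a "soft" Dice loss in that spirit, assuming outputs holds raw per-class scores of shape [n_batch, n_class, H, W] and targets holds integer labels of shape [n_batch, H, W] (the exact formula is mine, not from this thread):

import torch
import torch.nn.functional as F

def soft_dice_loss(outputs, targets, eps=1e-6):
    """Dice loss computed on softmax probabilities instead of argmax
    indices, so gradients can flow back through the class scores."""
    n_class = outputs.size(1)
    probs = F.softmax(outputs, dim=1)                    # [N, C, H, W], differentiable
    onehot = F.one_hot(targets.long(), n_class)          # [N, H, W, C]
    onehot = onehot.permute(0, 3, 1, 2).float()          # [N, C, H, W]
    dims = (0, 2, 3)                                     # sum over batch and pixels
    inter = torch.sum(probs * onehot, dims)
    union = torch.sum(probs + onehot, dims)
    dice = (2 * inter + eps) / (union + eps)             # per-class Dice score
    return 1 - dice.mean()                               # 1 - mean Dice as the loss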

I resolved it. I learned that the model output should only be transformed by differentiable functions (ones that produce a grad_fn), such as exp, log, etc., in order to compute a loss value. Perhaps it was my lack of DL fundamentals and PyTorch skills. Thanks again.
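For reference, with a differentiable loss like the soft_dice_loss sketched above, the training loop from the original post would simply pass the raw outputs to the criterion (no torch.max, no .data):

for epoch in range(epochs):
    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = soft_dice_loss(outputs, labels)  # grad_fn preserved end to end
        loss.backward()
        optimizer.step()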