# Confusion matrix breaks the computation graph for custom loss function

Hi,
I want to write a custom loss function which calculates a value from the confusion matrix. But getting confusion matrix in custom loss function breaks the computation graph. Is there a possible way to write a loss function without breaking the computation graph?
Thanks a lot.

Hi @burakydev,

Can you share what you've done so far?

My custom loss function is like given below:

```python
# my loss function
import torch
from torchmetrics import ConfusionMatrix

from softargmax import softargmax1d

def calc_my_loss(pred_y, real_y):
    conf_matrix = ConfusionMatrix(num_classes=3)
    conf_matrix = conf_matrix(softargmax1d(pred_y).type(torch.int32), real_y)

    # ONES and TWOS are scalar penalty weights defined elsewhere
    my_metrics = torch.tensor(
        conf_matrix[0][1].item() * ONES + conf_matrix[0][2].item() * TWOS
        + conf_matrix[1][0].item() * ONES + conf_matrix[1][2].item() * ONES
        + conf_matrix[2][0].item() * TWOS + conf_matrix[2][1].item() * ONES,
        dtype=torch.float32,
    )

    return my_metrics
```

and my softargmax1d function is given below. I use this function to get gradients for the confusion matrix.

```python
import torch
import torch.nn as nn

def softargmax1d(input, beta=100):
    # Differentiable approximation of argmax along the last dimension
    *_, n = input.shape
    input = nn.functional.softmax(beta * input, dim=-1)
    indices = torch.linspace(0, 1, n)
    result = torch.sum((n - 1) * input * indices, dim=-1)
    return result
```

The `pred_y` variable still has gradients just before the confusion matrix is computed. When I calculate the confusion matrix, the computation graph breaks and the gradients are lost.

This is a multiclass classification problem, and I want to penalize some entries of the confusion matrix with specified weights. How can I do this?

Thank you so much.

When you call `torch.tensor(x)`, you're re-wrapping your input data and hence breaking your graph. Secondly, when you call the `.item()` method you are taking a `Tensor` and converting it to a `float`, which PyTorch can no longer track, and that will break your graph.

You need to rewrite `calc_my_loss` without using `torch.tensor` and `.item()`.
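For illustration, the indexing could look something like the sketch below - this is not the poster's actual code. The confusion matrix is passed in directly to isolate the indexing, and `ONES`/`TWOS` are placeholder values for the weights defined elsewhere in the original script:

```python
import torch

# Placeholder penalty weights (defined elsewhere in the original script).
ONES, TWOS = 1.0, 2.0

def calc_my_loss(conf_matrix: torch.Tensor) -> torch.Tensor:
    # Plain tensor indexing and arithmetic keep every term inside
    # the computation graph, so grad_fn is preserved on the result.
    return (conf_matrix[0, 1] * ONES + conf_matrix[0, 2] * TWOS
            + conf_matrix[1, 0] * ONES + conf_matrix[1, 2] * ONES
            + conf_matrix[2, 0] * TWOS + conf_matrix[2, 1] * ONES)
```

Because no intermediate value leaves the tensor world, the returned loss carries a `grad_fn` and `backward()` can flow through it.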


I've just corrected my loss function according to your comments, and the `grad_fn` of `my_metrics` is now `<AddBackward0 object at 0x2a738bfa0>`. I think my custom loss function is differentiable now and has gradients. But after training the model again and calling `loss.backward()`, I checked the gradients of the `layer_1` weights (training phase shown below), and the result is `None`. Where could the problem be? Do you have any idea?

```python
for e in tqdm(range(1, EPOCHS + 1)):
    print(f'Epoch: {e}')

    # TRAINING
    train_epoch_loss = 0
    train_epoch_acc = 0
    model.train()

    y_train_pred = model(X_train_batch)

    train_loss = calc_my_loss(y_train_pred, y_train_batch)
    train_acc = multi_acc(y_train_pred, y_train_batch)

    train_loss.backward()

    optimizer.step()

    train_epoch_loss += train_loss.item()
    train_epoch_acc += train_acc.item()
```

Hi Burak!

When you convert your `softargmax1d()` results to "hard," integer labels,
you "break the graph." This is because integers, being discrete, can't be
differentiated, so you can't backpropagate through them.

You will need to develop some sort of "soft" confusion matrix function
that works (in a meaningful way) with soft, real-valued "labels."

A further note: your use of `beta = 100` in `softargmax1d()` makes
for a very hard version of `softmax()`. Even though it is technically
differentiable, its derivative will be almost exactly zero almost everywhere,
so, although you will be able to backpropagate through it, you'll just get
vanishingly small gradients that won't drive any useful training.
Best.

K. Frank


Hi @KFrank,

Now I'll try to develop a soft confusion matrix function, but I don't yet know how to do it. I think my labels must be continuous for this, and the confusion matrix must be calculated according to these ranges - am I right? Or do you know any tutorials or other resources about this topic?

Best regards.

edit:

and also, is it a problem in my situation that the `y_train_batch` labels are integers like 1, 2, 3?

Hi Burak!

I don't understand your use case or how you want your confusion matrix
to feed into your custom loss function.

But one approach for a "soft" confusion matrix would be the following:

Let `pred_i` be the predicted probability for your `i`-th class (so, typically
the output of a `softmax()`) and `targ_i` be the ground-truth probability
of your `i`-th class - that is, your "soft" labels.

(It is okay if your labels are "hard," that is, that `targ_i` is `1.0` for exactly
one value of `i` and `0.0` for all the rest.)

The `i, j` element of your "soft" confusion matrix could then be chosen to
be the sum over samples of `targ_i * pred_j`. Note that as your predictions
(and labels) become hard - that is, exactly one class has probability one
and all the others zero - your soft confusion matrix will become the
conventional confusion matrix.

Because the predictions used in your soft confusion matrix are now
continuous probabilities, you will not have any issue backpropagating
through the confusion matrix.
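A minimal sketch of this construction (the function name and the batched `(N, C)` shapes are my assumptions, not code from the thread):

```python
import torch

def soft_confusion_matrix(pred_probs: torch.Tensor,
                          targ_probs: torch.Tensor) -> torch.Tensor:
    """C[i, j] = sum over samples of targ_i * pred_j.

    pred_probs: (N, C) predicted class probabilities (e.g. softmax output).
    targ_probs: (N, C) target probabilities (one-hot rows for hard labels).
    Returns a (C, C) tensor that stays in the computation graph.
    """
    return targ_probs.t() @ pred_probs
```

With hard one-hot rows on both sides this reduces to the ordinary count-based confusion matrix, and because it is just a matrix product of float tensors, gradients flow through `pred_probs`.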

(But note my previous comment that if you harden your `softmax()` with
a `beta` as large as `100`, you won't be able to usefully backpropagate
through that `softmax()`.)

Best.

K. Frank


Hi @KFrank,
I have a multiclass (number_of_classes=3) problem. In my custom loss function, I want to penalize the false predictions according to the confusion matrix. For example, if the predicted class is 1 but the actual class is 2, I will multiply the number of wrong predictions by 1; and if the predicted class is 1 but the actual class is 3, I will multiply the number of false predictions by 5, and so on. Finally, I will sum the weighted penalty scores.
In my custom loss function, my aim is to minimize this weighted penalty score.

Are this approach and the soft confusion matrix suitable for my custom loss function? I will try to implement the soft confusion matrix.

Thank you so much.
Best regards.
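The weighted penalty described above can be combined with the soft confusion matrix roughly as follows. Everything here is an illustrative sketch: the penalty matrix `W`, the mapping of the thread's classes 1..3 to indices 0..2, and all names are my assumptions, not code from the thread. `W[i, j]` is the cost of predicting class `j` when the true class is `i`, with a zero diagonal so correct predictions incur no penalty:

```python
import torch
import torch.nn.functional as F

# Illustrative 3-class penalty matrix; the off-diagonal values mirror
# the example weights mentioned in the thread (assumed, not confirmed).
W = torch.tensor([[0., 1., 5.],
                  [1., 0., 1.],
                  [5., 1., 0.]])

def weighted_penalty_loss(logits: torch.Tensor,
                          labels: torch.Tensor) -> torch.Tensor:
    # Soft predictions keep the computation graph intact.
    pred_probs = F.softmax(logits, dim=-1)                   # (N, 3)
    targ_probs = F.one_hot(labels, num_classes=3).float()    # (N, 3)
    soft_cm = targ_probs.t() @ pred_probs                    # (3, 3)
    # Element-wise weighting and summation: the diagonal
    # (correct predictions) contributes nothing.
    return (W * soft_cm).sum()
```

Minimizing this drives probability mass off the heavily weighted off-diagonal cells, and `backward()` works because no hard argmax or `.item()` appears anywhere in the path.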