# Confusion matrix breaks the computation graph for custom loss function

Hi,
I want to write a custom loss function which calculates a value from the confusion matrix. But getting confusion matrix in custom loss function breaks the computation graph. Is there a possible way to write a loss function without breaking the computation graph?
Thanks a lot.

Hi @burakydev,

Can you share what you've done so far?

My custom loss function is like given below:

```python
# my loss function
import torch
from torchmetrics import ConfusionMatrix

from softargmax import softargmax1d

def calc_my_loss(pred_y, real_y):
    conf_matrix = ConfusionMatrix(num_classes=3)
    conf_matrix = conf_matrix(softargmax1d(pred_y).type(torch.int32), real_y)

    # ONES and TWOS are scalar penalty weights defined elsewhere
    my_metrics = torch.tensor(
        conf_matrix[0][1].item() * ONES + conf_matrix[0][2].item() * TWOS
        + conf_matrix[1][0].item() * ONES + conf_matrix[1][2].item() * ONES
        + conf_matrix[2][0].item() * TWOS + conf_matrix[2][1].item() * ONES,
        dtype=torch.float32,
    )

    return my_metrics
```

and my softargmax1d function is given below. I use this function to get gradients for the confusion matrix.

```python
import torch
import torch.nn as nn

def softargmax1d(input, beta=100):
    # Differentiable approximation of argmax along the last dimension
    *_, n = input.shape
    input = nn.functional.softmax(beta * input, dim=-1)
    indices = torch.linspace(0, 1, n)
    result = torch.sum((n - 1) * input * indices, dim=-1)
    return result
```

The `pred_y` variable still has gradients just before the confusion matrix is computed. When I calculate the confusion matrix, the computation graph breaks and the gradients are lost.

This is a multiclass classification problem, and I want to penalize some entries of the confusion matrix with specified weights. How can I do this?

Thank you so much.

When you call `torch.tensor(x)`, you're re-wrapping your input data and hence breaking your graph. Secondly, when you call the `.item()` method you are taking a `Tensor` and converting it to a `float`, which PyTorch can no longer track, and that will break your graph.

You need to rewrite `calc_my_loss` without using `torch.tensor` and `.item()`.
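For illustration, the indexing could look something like the sketch below - this is not the poster's actual code. The confusion matrix is passed in directly to isolate the indexing, and `ONES`/`TWOS` are placeholder values for the weights defined elsewhere in the original script:

```python
import torch

# Placeholder penalty weights (defined elsewhere in the original script).
ONES, TWOS = 1.0, 2.0

def calc_my_loss(conf_matrix: torch.Tensor) -> torch.Tensor:
    # Plain tensor indexing and arithmetic keep every term inside
    # the computation graph, so grad_fn is preserved on the result.
    return (conf_matrix[0, 1] * ONES + conf_matrix[0, 2] * TWOS
            + conf_matrix[1, 0] * ONES + conf_matrix[1, 2] * ONES
            + conf_matrix[2, 0] * TWOS + conf_matrix[2, 1] * ONES)
```

Because no intermediate value leaves the tensor world, the returned loss carries a `grad_fn` and `backward()` can flow through it.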


I've just corrected my loss function according to your comments, and the `grad_fn` of `my_metrics` is now `<AddBackward0 object at 0x2a738bfa0>`. I think my custom loss function is differentiable now and has gradients. But after training the model again and calling `loss.backward()`, I checked the gradients of the `layer_1` weights (training phase shown below), and the result is `None`. Where could the problem be? Do you have any idea?

```python
for e in tqdm(range(1, EPOCHS + 1)):
    print(f'Epoch: {e}')

    # TRAINING
    train_epoch_loss = 0
    train_epoch_acc = 0
    model.train()

    y_train_pred = model(X_train_batch)

    train_loss = calc_my_loss(y_train_pred, y_train_batch)
    train_acc = multi_acc(y_train_pred, y_train_batch)

    train_loss.backward()

    optimizer.step()

    train_epoch_loss += train_loss.item()
    train_epoch_acc += train_acc.item()
```

Hi Burak!

When you convert your `softargmax1d()` results to "hard," integer labels,
you "break the graph." This is because integers, being discrete, can't be
differentiated, so you can't backpropagate through them.

You will need to develop some sort of "soft" confusion matrix function
that works (in a meaningful way) with soft, real-valued "labels."

A further note: your use of `beta = 100` in `softargmax1d()` makes
for a very hard version of `softmax()`. Even though it is technically
differentiable, its derivative will be almost exactly zero almost everywhere,
so, although you will be able to backpropagate through it, you'll just get
vanishingly small gradients that won't drive any useful training.
Best.

K. Frank


Hi @KFrank,

Now I'll try to develop a soft confusion matrix function, but I don't yet know how to do it. I think my labels must be continuous for this, and the confusion matrix must be calculated according to these ranges - am I right? Or do you know any tutorials or other resources about this topic?

Best regards.

edit:

and also, is it a problem in my situation that the `y_train_batch` labels are integers like 1, 2, 3?

Hi Burak!

I don't understand your use case or how you want your confusion matrix
to feed into your custom loss function.

But one approach for a "soft" confusion matrix would be the following:

Let `pred_i` be the predicted probability for your `i`-th class (so, typically
the output of a `softmax()`) and `targ_i` be the ground-truth probability
of your `i`-th class - that is, your "soft" labels.

(It is okay if your labels are "hard," that is, that `targ_i` is `1.0` for exactly
one value of `i` and `0.0` for all the rest.)

The `i, j` element of your "soft" confusion matrix could then be chosen to
be the sum over samples of `targ_i * pred_j`. Note that as your predictions
(and labels) become hard - that is, exactly one class has probability one
and all the others zero - your soft confusion matrix will become the
conventional confusion matrix.

Because the predictions used in your soft confusion matrix are now
continuous probabilities, you will not have any issue backpropagating
through the confusion matrix.
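A minimal sketch of this construction (the function name and the batched `(N, C)` shapes are my assumptions, not code from the thread):

```python
import torch

def soft_confusion_matrix(pred_probs: torch.Tensor,
                          targ_probs: torch.Tensor) -> torch.Tensor:
    """C[i, j] = sum over samples of targ_i * pred_j.

    pred_probs: (N, C) predicted class probabilities (e.g. softmax output).
    targ_probs: (N, C) target probabilities (one-hot rows for hard labels).
    Returns a (C, C) tensor that stays in the computation graph.
    """
    return targ_probs.t() @ pred_probs
```

With hard one-hot rows on both sides this reduces to the ordinary count-based confusion matrix, and because it is just a matrix product of float tensors, gradients flow through `pred_probs`.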

(But note my previous comment that if you harden your `softmax()` with
a `beta` as large as `100`, you won't be able to usefully backpropagate
through that `softmax()`.)

Best.

K. Frank


Hi @KFrank,
I have a multiclass (number_of_classes=3) problem. In my custom loss function, I want to penalize the false predictions according to the confusion matrix. For example, if the predicted class is 1 but the actual class is 2, I will multiply the number of wrong predictions by 1; and if the predicted class is 1 but the actual class is 3, I will multiply the number of false predictions by 5, and so on. Finally, I will sum the weighted penalty scores.
In my custom loss function, my aim is to minimize this weighted penalty score.

Are this approach and the soft confusion matrix suitable for my custom loss function? I will try to implement the soft confusion matrix.

Thank you so much.
Best regards.
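The weighted penalty described above can be combined with the soft confusion matrix roughly as follows. Everything here is an illustrative sketch: the penalty matrix `W`, the mapping of the thread's classes 1..3 to indices 0..2, and all names are my assumptions, not code from the thread. `W[i, j]` is the cost of predicting class `j` when the true class is `i`, with a zero diagonal so correct predictions incur no penalty:

```python
import torch
import torch.nn.functional as F

# Illustrative 3-class penalty matrix; the off-diagonal values mirror
# the example weights mentioned in the thread (assumed, not confirmed).
W = torch.tensor([[0., 1., 5.],
                  [1., 0., 1.],
                  [5., 1., 0.]])

def weighted_penalty_loss(logits: torch.Tensor,
                          labels: torch.Tensor) -> torch.Tensor:
    # Soft predictions keep the computation graph intact.
    pred_probs = F.softmax(logits, dim=-1)                   # (N, 3)
    targ_probs = F.one_hot(labels, num_classes=3).float()    # (N, 3)
    soft_cm = targ_probs.t() @ pred_probs                    # (3, 3)
    # Element-wise weighting and summation: the diagonal
    # (correct predictions) contributes nothing.
    return (W * soft_cm).sum()
```

Minimizing this drives probability mass off the heavily weighted off-diagonal cells, and `backward()` works because no hard argmax or `.item()` appears anywhere in the path.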