Hi,

I want to write a custom loss function that calculates a value from the confusion matrix, but computing the confusion matrix inside the loss function breaks the computation graph. Is there a way to write such a loss function without breaking the computation graph?

Thanks a lot.

My custom loss function is given below:

```
# my loss function
import torch
from softargmax import softargmax1d
from torchmetrics import ConfusionMatrix

def calc_my_loss(pred_y, real_y):
    ONES = torch.tensor(1, dtype=torch.float32, requires_grad=True)
    TWOS = torch.tensor(5, dtype=torch.float32, requires_grad=True)
    conf_matrix = ConfusionMatrix(num_classes=3)
    # print(f'softpred_ygradfn: {softargmax1d(pred_y).grad_fn}')
    conf_matrix = conf_matrix(softargmax1d(pred_y).type(torch.int32), real_y)
    # weight the off-diagonal confusion-matrix entries and sum them up
    my_metrics = torch.tensor(
        conf_matrix[0][1].item() * ONES + conf_matrix[0][2].item() * TWOS
        + conf_matrix[1][0].item() * ONES + conf_matrix[1][2].item() * ONES
        + conf_matrix[2][0].item() * TWOS + conf_matrix[2][1].item() * ONES,
        dtype=torch.float32)
    return my_metrics
```

and my `softargmax1d` function is given below. I used this function to get gradients for the confusion matrix.

```
import torch
import torch.nn as nn

def softargmax1d(input, beta=100):
    *_, n = input.shape
    input = nn.functional.softmax(beta * input, dim=-1)
    indices = torch.linspace(0, 1, n)
    result = torch.sum((n - 1) * input * indices, dim=-1)
    return result
```

The `pred_y` variable still has gradients just before the confusion matrix is computed. When I calculate the confusion matrix, the computation graph breaks and the gradients are lost.

This is a multiclass classification problem, and I want to penalize certain entries of the confusion matrix with specified weights. How can I do this?

Thank you so much.

When you call `torch.tensor(x)`, you’re re-wrapping your input data and hence breaking your graph. Secondly, when you call the `.item()` method you are taking a `Tensor` and converting it to a Python `float`, which PyTorch can no longer track and which will break your graph.

You need to rewrite `calc_my_loss` without the use of `torch.tensor` and `.item()`.
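A minimal demonstration of both failure modes, using toy tensors of my own:

```python
import torch

x = torch.ones(3, requires_grad=True)
y = x * 2                      # y is tracked: y.grad_fn is a MulBackward0 node

# Re-wrapping the result with torch.tensor() (here via .item(), which
# first converts to a plain Python float) produces a fresh leaf tensor
# with no history, so the graph back to x is gone:
z = torch.tensor(y.sum().item())
print(z.grad_fn)               # None: graph broken

# Staying in tensor operations keeps the graph intact:
w = y.sum()
print(w.grad_fn)               # SumBackward0
w.backward()
print(x.grad)                  # tensor([2., 2., 2.])
```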

I’ve just corrected my loss function according to your comments, and when I check it, `my_metrics.grad_fn` is `<AddBackward0 object at 0x2a738bfa0>`. I think my custom loss function is differentiable now and has gradients, but when I trained my model again and, after the `loss.backward()` call, checked the gradients of the `layer_1` weights as shown below (for the training phase), the result is `None`. Where could the problem be, do you have any idea?

```
for e in tqdm(range(1, EPOCHS + 1)):
    print(f'Epoch: {e}')
    # TRAINING
    train_epoch_loss = 0
    train_epoch_acc = 0
    model.train()
    for X_train_batch, y_train_batch in train_loader:
        optimizer.zero_grad()
        y_train_pred = model(X_train_batch)
        train_loss = calc_my_loss(y_train_pred, y_train_batch)
        train_acc = multi_acc(y_train_pred, y_train_batch)
        train_loss.backward()
        print(f'layer1GRADIENTS: {model.layer_1.weight.grad}')
        optimizer.step()
        train_epoch_loss += train_loss.item()
        train_epoch_acc += train_acc.item()
```

Hi Burak!

When you convert your `softargmax1d()` results to “hard,” integer labels, you “break the graph.” This is because integers, being discrete, can’t be differentiated, so you can’t backpropagate through them.

You will need to develop some sort of “soft” confusion matrix function that works (in a meaningful way) with soft, real-valued “labels.”

A further note: your use of `beta = 100` in `softargmax1d()` makes for a *very hard* version of `softmax()`. Even though it is technically differentiable, its derivative will be almost exactly zero almost everywhere, so, although you will be able to backpropagate through it, you’ll just get zeros for the gradients and your training won’t actually train.
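The `beta = 100` point is easy to check numerically. In this sketch (my own toy logits, assumed to be of order one), the gradient through the hardened `softmax()` is vanishingly small:

```python
import torch
import torch.nn as nn

logits = torch.tensor([1.0, 1.2, 0.9], requires_grad=True)

# Moderate beta: the softmax stays soft and gradients flow.
soft = nn.functional.softmax(1.0 * logits, dim=-1)
soft.max().backward()
print(logits.grad.abs().max())   # clearly non-zero (about 0.24 here)

logits.grad = None

# beta = 100: the output is essentially one-hot, and the gradient
# is almost exactly zero almost everywhere.
hard = nn.functional.softmax(100.0 * logits, dim=-1)
hard.max().backward()
print(logits.grad.abs().max())   # nearly zero
```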

Best.

K. Frank

Hi @KFrank,

Thanks for your reply.

Now I’ll try to develop a soft confusion matrix function, but I don’t yet know how to do this. I think my labels must be continuous for this, and the confusion matrix must be calculated according to these ranges, am I right? Or do you know any tutorials or other resources about this topic?

Best regards.

edit:

and also, is it a problem in my situation to have integer labels like 1, 2, 3 for my `y_train_batch` values?

Hi Burak!

I don’t understand your use case or how you want your confusion matrix to feed into your custom loss function. But one approach to a “soft” confusion matrix would be the following:

Let `pred_i` be the predicted *probability* for your `i`-th class (so, typically the output of a `softmax()`) and `targ_i` be the ground-truth *probability* of your `i`-th class, that is, your “soft” labels. (It is okay if your labels are “hard,” that is, if `targ_i` is `1.0` for exactly one value of `i` and `0.0` for all the rest.)

The `i, j` element of your “soft” confusion matrix could then be chosen to be the sum over samples of `targ_i * pred_j`. Note that as your predictions (and labels) become hard, that is, as exactly one class gets probability one and all the others zero, your soft confusion matrix will become the conventional confusion matrix.

Because the predictions used in your soft confusion matrix are now continuous probabilities, you will not have any issue backpropagating through the confusion matrix.
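One way to code this up (the function and variable names are my own) is a single matrix product, since summing `targ_i * pred_j` over samples is exactly `targ.T @ pred`:

```python
import torch

def soft_confusion_matrix(pred_probs, targ_probs):
    # pred_probs: (n_samples, n_classes) predicted probabilities
    #             (e.g. softmax output)
    # targ_probs: (n_samples, n_classes) target probabilities
    #             (one-hot rows if the labels are hard)
    # Entry (i, j) is the sum over samples of targ_i * pred_j.
    return targ_probs.T @ pred_probs

# Hard targets (classes 0 and 2), soft predictions:
pred = torch.tensor([[0.7, 0.2, 0.1],
                     [0.1, 0.8, 0.1]], requires_grad=True)
targ = torch.nn.functional.one_hot(torch.tensor([0, 2]), num_classes=3).float()
cm = soft_confusion_matrix(pred, targ)
print(cm)           # row i = true class, column j = predicted class
print(cm.grad_fn)   # not None, so backprop works
```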

(But note my previous comment that if you harden your `softmax()` with a `beta` as large as `100`, you won’t be able to *usefully* backpropagate through that `softmax()`.)

Best.

K. Frank

Hi @KFrank,

I have a multiclass (number_of_classes=3) problem. In my custom loss function, I want to penalize false predictions according to the confusion matrix. For example, if the predicted class is 1 but the actual class is 2, I will multiply the number of wrong predictions by 1, and if the predicted class is 1 but the actual class is 3, I will multiply the number of false predictions by 5, and so on. Finally, I will sum the weighted penalty scores.

In my custom loss function, my aim is to minimize this weighted penalty score.

Are this approach and this soft confusion matrix suitable for my custom loss function? I will try to implement this soft confusion matrix.
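Putting the soft confusion matrix together with my penalty weights, my plan is roughly the following sketch (the names and the weight-matrix layout are my own guesses, and I'm assuming 0-based class labels):

```python
import torch
import torch.nn as nn

# W[i, j]: cost of predicting class j when the true class is i.
# Zero diagonal, so correct predictions are not penalized; the
# 1 / 5 values follow the weighting scheme described above.
W = torch.tensor([[0., 1., 5.],
                  [1., 0., 1.],
                  [5., 1., 0.]])

def weighted_penalty_loss(logits, target, num_classes=3):
    pred = nn.functional.softmax(logits, dim=-1)             # soft predictions
    targ = nn.functional.one_hot(target, num_classes).float()
    soft_cm = targ.T @ pred                                  # soft confusion matrix
    return (W * soft_cm).sum()                               # weighted penalty score

# Gradients flow all the way back to the logits:
logits = torch.randn(8, 3, requires_grad=True)
target = torch.randint(0, 3, (8,))
loss = weighted_penalty_loss(logits, target)
loss.backward()
print(logits.grad is not None)   # True
```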

Thank you so much.

Best regards.