(I have read this: https://discuss.pytorch.org/t/got-no-graph-nodes-that-require-computing-gradients-when-use-torch-max/7952, but it isn't what I'm looking for.)

Hello, in a segmentation task I have

`logits` of shape BxCxHxW and `targets` of shape BxHxW

and want to calculate

`TP, FP, TN, FN`

so I did something like this:

```
normalizeds = F.softmax(logits, dim=1)   # softmax over the class dimension
_, preds = torch.max(normalizeds, dim=1) # <-- torch.max is NOT differentiable
```

then used code like:

```
for clazz in range(logits.shape[1]):  # loop over channels (i.e. classes)
    predict_true = preds == clazz     # predicted mask for this class
    real_true = targets == clazz      # ground-truth mask for this class
    TP = (predict_true & real_true).int().sum()  # true positives for this class
```
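For completeness, the other counts fall out of the same boolean masks. A minimal self-contained sketch (the shapes and random `preds`/`targets` here are just dummy stand-ins for the real tensors):

```python
import torch

# hypothetical small shapes; preds/targets are BxHxW class-index maps
B, C, H, W = 2, 3, 4, 4
preds = torch.randint(0, C, (B, H, W))
targets = torch.randint(0, C, (B, H, W))

for clazz in range(C):
    predict_true = preds == clazz    # predicted mask for this class
    real_true = targets == clazz     # ground-truth mask for this class
    TP = (predict_true & real_true).sum()
    FP = (predict_true & ~real_true).sum()
    FN = (~predict_true & real_true).sum()
    TN = (~predict_true & ~real_true).sum()
    # sanity check: the four counts partition all pixels
    assert TP + FP + FN + TN == B * H * W
```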

The problem is that `torch.max` is NOT differentiable. Is there any workaround to make a differentiable `torch.max`, or, ultimately, to calculate TP, TN, FP, FN without losing differentiability?
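One workaround I have been considering (just a sketch, not sure it is the standard approach): skip the argmax entirely and use the softmax probabilities as *soft* indicators, multiplying them against one-hot targets. The counts become real-valued approximations, but they stay differentiable:

```python
import torch
import torch.nn.functional as F

# dummy shapes/tensors for illustration
torch.manual_seed(0)
B, C, H, W = 2, 3, 4, 4
logits = torch.randn(B, C, H, W, requires_grad=True)
targets = torch.randint(0, C, (B, H, W))

probs = F.softmax(logits, dim=1)  # BxCxHxW, differentiable

# one-hot encode targets to BxCxHxW so they align with probs
targets_onehot = F.one_hot(targets, num_classes=C).permute(0, 3, 1, 2).float()

# soft confusion-matrix counts per class (summed over batch and pixels)
soft_TP = (probs * targets_onehot).sum(dim=(0, 2, 3))
soft_FP = (probs * (1 - targets_onehot)).sum(dim=(0, 2, 3))
soft_FN = ((1 - probs) * targets_onehot).sum(dim=(0, 2, 3))
soft_TN = ((1 - probs) * (1 - targets_onehot)).sum(dim=(0, 2, 3))

# the counts remain differentiable, so e.g. a soft Dice/F1 loss backprops
dice = (2 * soft_TP / (2 * soft_TP + soft_FP + soft_FN + 1e-7)).mean()
loss = 1 - dice
loss.backward()  # gradients flow back into logits
```

As the network grows confident the probabilities approach 0/1 and the soft counts approach the hard ones, which is why soft-Dice-style losses are built this way.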