Is there any workaround to make torch.max differentiable?

(I have read this, but it isn't what I am looking for.)

Hello, in a segmentation task I have
logits of shape BxCxHxW and targets of shape BxHxW,
and I want to calculate
so I did something like this:

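    import torch
    import torch.nn.functional as F
    normalizeds = F.softmax(logits, dim=1)      # per-pixel class probabilities
    _, preds = torch.max(normalizeds, dim=1)    # preds = argmax indices (integers) <--- NOT differentiable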
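To pin down what exactly is not differentiable: `torch.max(x, dim)` returns both the max values and their indices. The values stay in the autograd graph (the gradient goes to the winning entry), but the indices are int64 tensors with no gradient, and `preds` holds exactly those indices. A quick check, with made-up sizes and random logits:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, C, H, W = 2, 3, 4, 4                      # hypothetical example sizes
logits = torch.randn(B, C, H, W, requires_grad=True)

normalizeds = F.softmax(logits, dim=1)       # per-pixel class probabilities
values, preds = torch.max(normalizeds, dim=1)

# The max *values* are still connected to the autograd graph ...
print(values.requires_grad, values.grad_fn is not None)  # True True
# ... but the argmax *indices* are plain integers with no gradient.
print(preds.dtype, preds.requires_grad)                  # torch.int64 False
```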

Then I used code like this:

    for clazz in range(logits.shape[1]):        # loop over channels (i.e. classes)
        predict_true = preds == clazz           # pixels predicted as this class
        real_true = targets == clazz            # pixels that truly belong to this class
        TP = (predict_true & real_true).sum()   # true positives for this class
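For reference, here is a self-contained version of the hard-count loop, extended to all four counts. The FP/FN/TN lines, the sizes and the random `targets` are illustrative additions, not the original code. Since softmax is monotonic, `logits.argmax(dim=1)` gives the same `preds` as `torch.max(F.softmax(logits, dim=1), 1)`.

```python
import torch

torch.manual_seed(0)
B, C, H, W = 2, 3, 4, 4                         # hypothetical example sizes
logits = torch.randn(B, C, H, W, requires_grad=True)
targets = torch.randint(0, C, (B, H, W))        # random ground-truth labels

preds = logits.argmax(dim=1)                    # hard per-pixel predictions

counts = {}
for clazz in range(C):
    predict_true = preds == clazz               # boolean mask: predicted as clazz
    real_true = targets == clazz                # boolean mask: labelled as clazz
    tp = (predict_true & real_true).sum()
    fp = (predict_true & ~real_true).sum()
    fn = (~predict_true & real_true).sum()
    tn = (~predict_true & ~real_true).sum()
    counts[clazz] = (tp, fp, fn, tn)
    print(clazz, tp.item(), fp.item(), fn.item(), tn.item())

# The counts are correct integers, but none of them has a path back to `logits`.
```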

The problem is that the indices returned by `torch.max` (the argmax) are NOT differentiable. Is there any workaround that makes this step differentiable, so that I can ultimately calculate TP, TN, FP, and FN without losing differentiability?
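One commonly used workaround, sketched here under stated assumptions rather than as a definitive answer, is to avoid the hard argmax altogether. Option 1 uses the softmax probabilities as "soft" indicators, giving soft TP/FP/FN/TN (the same idea behind soft Dice losses). Option 2 is a straight-through estimator: the forward pass uses the hard one-hot argmax, so the counts equal the integer ones, while the backward pass borrows the softmax gradient. `soft_confusion` is a hypothetical helper name, and the sizes, random data and Dice loss are illustrative.

```python
import torch
import torch.nn.functional as F

def soft_confusion(probs, onehot):
    """Per-class soft TP, FP, FN, TN, summed over batch and pixels.

    probs, onehot: float tensors of shape (B, C, H, W).
    """
    dims = (0, 2, 3)                               # sum over everything except the class axis
    tp = (probs * onehot).sum(dims)
    fp = (probs * (1 - onehot)).sum(dims)
    fn = ((1 - probs) * onehot).sum(dims)
    tn = ((1 - probs) * (1 - onehot)).sum(dims)
    return tp, fp, fn, tn

torch.manual_seed(0)
B, C, H, W = 2, 3, 4, 4                            # hypothetical example sizes
logits = torch.randn(B, C, H, W, requires_grad=True)
targets = torch.randint(0, C, (B, H, W))

probs = F.softmax(logits, dim=1)
onehot = F.one_hot(targets, num_classes=C).permute(0, 3, 1, 2).float()

# Option 1: soft counts computed directly from the probabilities (no argmax).
tp, fp, fn, tn = soft_confusion(probs, onehot)

# Option 2: straight-through estimator. The forward value of `st` is exactly
# the hard one-hot prediction; its gradient is the gradient of `probs`.
hard = F.one_hot(probs.argmax(dim=1), num_classes=C).permute(0, 3, 1, 2).float()
st = hard + probs - probs.detach()
tp_st, fp_st, fn_st, tn_st = soft_confusion(st, onehot)

# Example loss built on Option 1: mean soft Dice over classes (one choice among many).
eps = 1e-6
dice = (2 * tp + eps) / (2 * tp + fp + fn + eps)
loss = 1 - dice.mean()
loss.backward()                                    # gradients reach `logits`
print(loss.item(), tp_st)
```

The soft counts are smooth but only approximate the hard metric; the straight-through counts are exact in the forward pass, but their gradient is a heuristic (biased) estimate. `F.gumbel_softmax(logits, hard=True, dim=1)` applies a similar straight-through trick with added Gumbel noise.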