(I have already read https://discuss.pytorch.org/t/got-no-graph-nodes-that-require-computing-gradients-when-use-torch-max/7952, but it isn't what I'm looking for.)
Hello, in a segmentation task I have logits of shape BxCxHxW and targets of shape BxHxW, and I want to calculate TP, FP, TN and FN (true/false positives and negatives) per class. So I did something like this:
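```python
import torch
import torch.nn.functional as F
normalizeds = F.softmax(logits, dim=1)  # per-pixel class probabilities over the C channels
_, preds = torch.max(normalizeds, 1)    # <--- the indices (argmax) from torch.max are NOT differentiable
```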
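A quick check shows exactly where the gradient is lost. This is a sketch with hypothetical random inputs of shape 2x3x4x4: the values returned by torch.max still carry a gradient, but the indices used as preds form an integer tensor, and integer tensors cannot require gradients.

```python
import torch
import torch.nn.functional as F

# Hypothetical shapes for illustration: B=2 images, C=3 classes, 4x4 pixels.
logits = torch.randn(2, 3, 4, 4, requires_grad=True)

normalizeds = F.softmax(logits, dim=1)     # per-pixel class probabilities
values, preds = torch.max(normalizeds, 1)  # max probability and the index of that class

print(values.requires_grad)              # True: the max values are differentiable
print(preds.dtype, preds.requires_grad)  # int64 indices; requires_grad is False
```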
Then I use code like this:
```python
for clazz in range(logits.shape[1]):  # loop over channels (i.e. classes)
    predict_true = preds == clazz     # pixels predicted as this class
    real_true = targets == clazz      # pixels that truly belong to this class
    TP = real_true.squeeze()[predict_true.squeeze()].int().sum()  # true positives for this class
```
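For reference, here is a self-contained sketch of the hard (argmax-based) per-class counts, extending the loop above to all four quantities. The function name hard_confusion_counts is hypothetical, and it combines the boolean masks with & and ~ instead of the squeeze-and-index approach; the result is still not differentiable, because it depends on the argmax.

```python
import torch
import torch.nn.functional as F

def hard_confusion_counts(logits, targets):
    """Per-class TP, FP, TN, FN from hard (argmax) predictions.

    logits: B x C x H x W float tensor; targets: B x H x W integer labels.
    Returns four length-C int64 tensors. Not differentiable.
    """
    _, preds = torch.max(F.softmax(logits, dim=1), dim=1)  # B x H x W class indices
    tp, fp, tn, fn = [], [], [], []
    for clazz in range(logits.shape[1]):
        predict_true = preds == clazz  # predicted as this class
        real_true = targets == clazz   # truly this class
        tp.append((predict_true & real_true).sum())
        fp.append((predict_true & ~real_true).sum())
        tn.append((~predict_true & ~real_true).sum())
        fn.append((~predict_true & real_true).sum())
    return torch.stack(tp), torch.stack(fp), torch.stack(tn), torch.stack(fn)
```

For a tiny 1x2x1x3 example where the predictions are [0, 0, 1] and the targets are [0, 1, 1], class 0 gets TP=1, FP=1, TN=1, FN=0 and class 1 gets TP=1, FP=0, TN=1, FN=1.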
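The problem is that ‘torch.max’ is NOT differentiable with respect to the indices it returns, and those indices are what I use as predictions. Is there any workaround that gives a differentiable ‘torch.max’, so that I can ultimately calculate TP, FP, TN and FN without losing differentiability?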
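One common workaround, offered only as a sketch and not taken from this thread, is to skip the argmax entirely: use the softmax probabilities as "soft" predictions and compare them against one-hot targets. Per class c this gives TP = Σ p·y, FP = Σ p·(1−y), FN = Σ (1−p)·y and TN = Σ (1−p)·(1−y), summed over batch and pixels, which is the same construction used in soft Dice / Tversky-style losses. The counts become fractional and approach the hard counts as the probabilities saturate to 0 or 1. The function name soft_confusion_counts and the Dice example below are hypothetical illustrations.

```python
import torch
import torch.nn.functional as F

def soft_confusion_counts(logits, targets):
    """Differentiable ("soft") per-class TP, FP, TN, FN.

    logits: B x C x H x W float tensor; targets: B x H x W integer labels.
    Returns four length-C float tensors that keep the autograd graph.
    """
    num_classes = logits.shape[1]
    probs = F.softmax(logits, dim=1)                         # B x C x H x W
    onehot = F.one_hot(targets.long(), num_classes)          # B x H x W x C
    onehot = onehot.permute(0, 3, 1, 2).to(probs.dtype)      # B x C x H x W
    dims = (0, 2, 3)                                         # sum over batch and pixels
    tp = (probs * onehot).sum(dims)
    fp = (probs * (1 - onehot)).sum(dims)
    fn = ((1 - probs) * onehot).sum(dims)
    tn = ((1 - probs) * (1 - onehot)).sum(dims)
    return tp, fp, tn, fn

# Example: a soft Dice loss built from the soft counts (1e-6 avoids division by zero).
logits = torch.randn(2, 3, 4, 4, requires_grad=True)
targets = torch.randint(0, 3, (2, 4, 4))
tp, fp, tn, fn = soft_confusion_counts(logits, targets)
dice = 2 * tp / (2 * tp + fp + fn + 1e-6)
loss = 1 - dice.mean()
loss.backward()  # gradients reach logits, unlike with the argmax version
```

Because the four terms sum to 1 for every pixel and class, TP + FP + TN + FN per class always equals the number of pixels (B·H·W), just as with hard counts.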