How to know whether a custom loss can be used with autograd


(No Name) #1

Hi guys,

I have read the PyTorch tutorial and tried some toy examples, but I am still confused about how to tell whether a piece of code can be differentiated by autograd. All I know is that each variable in the custom loss function (especially those involved in the computation graph) should be of type Variable. If the loss is computed with tensor.sum (or .dot) and the result is a float, so that the function returns a float, can the custom loss still be differentiated by autograd?

Thanks.

I also give a toy example:

import numpy as np
import torch
import torch.nn.functional as F

def myDiceLoss4Organ(input, target):
    '''
    input is a torch Variable of size Batch x nclasses x H x W representing probabilities for each class
    target is also a tensor, of size Batch x 1 x H x W
    '''
    eps = 0.000001

    # NOTE: target_one_hot is not defined in this snippet; it is presumably the
    # one-hot encoding of target, with the same shape as input.
    uniques = np.unique(target_one_hot.data.cpu().numpy())
    assert set(list(uniques)) <= set([0, 1]), "target must only contain zeros and ones"

    probs = F.softmax(input, dim=1)  # softmax over classes; maybe it is not necessary

    # flatten both tensors to 1-D so that torch.dot can be applied
    target = target_one_hot.contiguous().view(-1, 1).squeeze(1)
    result = probs.contiguous().view(-1, 1).squeeze(1)

    intersect = torch.dot(result, target)

    target_sum = torch.sum(target)
    result_sum = torch.sum(result)
    union = result_sum + target_sum + (2 * eps)
    IoU = intersect / union
    dice_total = 1 - 2 * IoU
    return dice_total
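
As a rough way to test the question empirically, here is a minimal sketch of the same soft Dice computation written against current PyTorch, where Variable has been merged into Tensor. The name soft_dice_loss, the explicit dim=1 in softmax, and the random test shapes are assumptions for illustration and are not part of the original post. The target is assumed to be already one-hot encoded with the same shape as the input.

```python
import torch
import torch.nn.functional as F

def soft_dice_loss(logits, target_one_hot, eps=1e-6):
    # logits: Batch x nclasses x H x W raw scores (assumed layout)
    # target_one_hot: same shape, containing only zeros and ones
    probs = F.softmax(logits, dim=1)
    result = probs.contiguous().view(-1)
    target = target_one_hot.contiguous().view(-1).to(result.dtype)
    intersect = torch.dot(result, target)
    denom = result.sum() + target.sum() + 2 * eps
    return 1 - 2 * intersect / denom

torch.manual_seed(0)
# Hypothetical test data: 2 images, 3 classes, 4 x 4 pixels
logits = torch.randn(2, 3, 4, 4, requires_grad=True)
labels = torch.randint(0, 3, (2, 4, 4))
target_one_hot = F.one_hot(labels, num_classes=3).permute(0, 3, 1, 2).float()

loss = soft_dice_loss(logits, target_one_hot)
loss.backward()

# If the loss is attached to the graph, it requires grad, has a grad_fn,
# and backward() fills in a gradient with the shape of the input.
print(loss.requires_grad)
print(loss.grad_fn is not None)
print(logits.grad.shape)
```

If every step stays inside torch operations on tensors, the returned loss is itself a tensor that remembers how it was computed, so backward() can be called on it. Converting an intermediate value to a Python float or a NumPy array would cut that chain.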

(No Name) #2

I figured it out myself. Every computation in the chain rule needs to operate on Variables…
So this code does not work, as torch.dot doesn’t support autograd; it returns a float…


(Marcin Elantkowski) #3

You’re right, we indeed have to use Variables to use autograd.

So this code does not work, as torch.dot doesn’t support autograd; it returns a float…

You can use torch.dot with autograd; it returns a Variable with a single element.

import torch as th

v1 = th.autograd.Variable(th.randn(10), requires_grad=True)
v2 = th.autograd.Variable(th.randn(10), requires_grad=True)

res = th.dot(v1, v2)

res.backward()

print(res)

# prints
# Variable containing:
# 1.1473
# [torch.FloatTensor of size 1]
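
One way to confirm that torch.dot really took part in the graph is to inspect the result and the gradients after backward(). The sketch below assumes current PyTorch, where plain tensors created with requires_grad=True play the role of Variables and the grad_fn attribute records the producing operation. Since the derivative of v1·v2 with respect to v1 is v2, the gradient stored on v1 should match v2, and vice versa.

```python
import torch as th

v1 = th.randn(10, requires_grad=True)
v2 = th.randn(10, requires_grad=True)

res = th.dot(v1, v2)
res.backward()

# res is a 0-dim tensor that remembers the operation that produced it
print(res.requires_grad, res.grad_fn is not None)

# The gradient of a dot product w.r.t. one argument is the other argument
print(th.allclose(v1.grad, v2.detach()), th.allclose(v2.grad, v1.detach()))
```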

(No Name) #4

Thanks. So is the reason torch.dot didn’t work in my case that I didn’t set requires_grad=True for the input Variables?
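
A small experiment may help answer this. In the sketch below (current PyTorch, plain tensors instead of Variables), the result of torch.dot requires grad as soon as at least one input does; when no input does, the result is not attached to any graph and backward() raises a RuntimeError. The tensor names are made up for illustration, with w standing in for a network output.

```python
import torch

a = torch.randn(5)                       # requires_grad defaults to False
b = torch.randn(5)
w = torch.randn(5, requires_grad=True)   # stands in for a network output

no_grad_res = torch.dot(a, b)            # no input requires grad
grad_res = torch.dot(w, b)               # one input requiring grad is enough

print(no_grad_res.requires_grad)  # False
print(grad_res.requires_grad)     # True

# backward() on a result with no graph behind it fails
try:
    no_grad_res.backward()
    backward_failed = False
except RuntimeError:
    backward_failed = True
print(backward_failed)            # True

grad_res.backward()
print(torch.equal(w.grad, b))     # True: d(w.b)/dw = b
print(b.grad is None)             # True: b never asked for a gradient
```

In a typical loss function only the model output needs to carry gradients; the target can stay a plain tensor without requires_grad.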