Hi,
I am trying to implement the following dice loss but get a RuntimeError on loss.backward(). The error says:
RuntimeError: element 0 of variables does not require grad and does not have a grad_fn
Any help is much appreciated.
Thanks
import torch
import torch.nn.functional as F
from torch.nn.modules.loss import _Loss
from torch.autograd import Function, Variable
class DiceCoeff(Function):
"""Dice coeff for individual examples"""
def forward(self, input, target):
_,input = input.max(0)
target=target.long()
#print 'input:'+str(input.size())
#print 'target:'+str(target.size())
self.save_for_backward(input, target)
self.inter = torch.dot(input.view(-1).float(),target.view(-1).float()) + 0.0001
self.union = torch.sum(input) + torch.sum(target) + 0.0001
t = 2*self.inter.float()/self.union.float()
return t
# This function has only a single output, so it gets only one gradient
def backward(self, grad_output):
input, target = self.saved_variables
grad_input = grad_target = None
if self.needs_input_grad[0]:
grad_input = grad_output * 2 * (target * self.union + self.inter) \
/ self.union * self.union
if self.needs_input_grad[1]:
grad_target = None
return grad_input, grad_target
def dice_coeff(input, target):
"""Dice coeff for batches"""
if input.is_cuda:
s = Variable(torch.FloatTensor(1).cuda().zero_())
else:
s = Variable(torch.FloatTensor(1).zero_())
for i, c in enumerate(zip(input, target)):
s = s + DiceCoeff().forward(c[0], c[1])
s = s / (i+1)
return s
class DiceLoss(_Loss):
def forward(self, input, target):
return 1 - dice_coeff(F.sigmoid(input), target)
Traceback (most recent call last):
File “train.py”, line 560, in
main(args)
File “train.py”, line 335, in main
train(train_loader, net, criterion, optimizer, epoch, train_args)
File “train.py”, line 372, in train
loss.backward()
File “/usr/lib64/python2.7/site-packages/torch/autograd/variable.py”, line 167, in backward
torch.autograd.backward(self, gradient, retain_graph, create_graph, retain_variables)
File “/usr/lib64/python2.7/site-packages/torch/autograd/init.py”, line 99, in backward
variables, grad_variables, retain_graph)
RuntimeError: element 0 of variables does not require grad and does not have a grad_fn