Custom Loss Function - No graph nodes that require computing gradients error

Hi everyone,

I implemented the custom loss function shown below, but I keep getting this error (RuntimeError: there are no graph nodes that require computing gradients). Can anyone advise on what I am doing wrong?

class Viewloss_Function(torch.autograd.Function):
    def __init__(self, class_period=360, num_classes=12):
        super(Viewloss_Function, self).__init__()
        self.num_classes = num_classes
        self.class_period = class_period

    def forward(self, preds, labels, obj_classes):
        preds = preds.float()
        labels = labels.float()
        probs = torch.zeros(preds.size()).cuda()
        batch_size = preds.size(0)
        loss     = torch.zeros(1)
        if torch.cuda.is_available():
            loss = loss.cuda()
        loss = torch.autograd.Variable(loss)
        
        for inst_id in range(batch_size):
            start_id = int(obj_classes[inst_id].data[0]) * self.class_period
            end_id   = start_id + self.class_period
            current_preds = preds[inst_id][start_id:end_id].clone()
            probs[inst_id][start_id:end_id] = torch.nn.functional.softmax(current_preds).data[0]
            loss += (labels[inst_id][start_id:end_id].data[0] * probs[inst_id][start_id:end_id]).sum()

        self.save_for_backward(probs, labels)
        return loss

    def backward(self, grad_output):
        probs, labels = self.saved_variables
        grad_preds = grad_labels = grad_obj_classes = None
        grad_preds = probs - labels
        return grad_preds, grad_labels, grad_obj_classes

The code below results in the RuntimeError:

loss_function = Viewloss_Function(num_classes=3, class_period=3)
loss = loss_function.forward(preds, labels, obj_classes)
loss.backward()

The inputs are:
preds: Parameters containing [torch.DoubleTensor of size 3x15]
labels: Variable containing [torch.cuda.DoubleTensor of size 3x15 (GPU 0)]
obj_classes: Variable containing [torch.cuda.LongTensor of size 3 (GPU 0)]
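
If it helps with reproducing the issue, stand-in inputs with those shapes can be built like this (a sketch with random values; the real ones come from my network):

import torch
from torch.autograd import Variable

preds = torch.nn.Parameter(torch.randn(3, 15).double())    # Parameter, DoubleTensor of size 3x15
labels = Variable(torch.rand(3, 15).double().cuda())       # cuda DoubleTensor of size 3x15
obj_classes = Variable(torch.LongTensor([0, 1, 2]).cuda()) # cuda LongTensor of size 3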

Thank you very much for your help!

Update:

I never found out why it wasn’t working; however, here’s a different implementation that uses a Module instead.

import torch
import torch.nn as nn
import torch.nn.functional as F

class ViewpointLoss(nn.Module):
    def __init__(self, num_classes = 12, class_period = 360):
        super(ViewpointLoss, self).__init__()
        self.num_classes = num_classes
        self.class_period = class_period

    def forward(self, preds, labels, obj_classes):
        labels = labels.float()
        batch_size = preds.size(0)
        loss     = torch.zeros(1)

        if torch.cuda.is_available():
            loss = loss.cuda()
        loss = torch.autograd.Variable(loss)

        for inst_id in range(batch_size):
            # Pick out the slice of outputs that belongs to this instance's object class.
            start_index = int(obj_classes[inst_id].data[0]) * self.class_period
            end_index   = start_index + self.class_period
            # Cross-entropy over that slice: -sum(target * log(softmax(pred))).
            loss -= (labels[inst_id, start_index:end_index] * F.softmax(preds[inst_id, start_index:end_index]).log()).sum()

        loss = loss / batch_size
        return loss
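
A minimal usage sketch (assuming the inputs are shaped like the ones above):

criterion = ViewpointLoss(num_classes=3, class_period=3)
loss = criterion(preds, labels, obj_classes)  # calling the module invokes forward() via __call__
loss.backward()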

This should work:

losses = []
for inst_id in range(batch_size):
    start_index = int(obj_classes[inst_id].data[0]) * self.class_period
    end_index   = start_index + self.class_period
    losses.append((-labels[inst_id, start_index:end_index] * F.softmax(preds[inst_id, start_index:end_index]).log()).sum())
loss = torch.sum(torch.cat(losses))
loss /= batch_size

(unverified, may have some typos)
The reason is that as soon as you do loss = torch.zeros(1), you are creating a tensor that is not connected to the computational graph at all. You need to build the loss out of operations on Variables so that it keeps its connection to the computation graph.
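
Here's a minimal illustration of the difference (a sketch using the same pre-0.4 Variable API as the code above):

import torch
from torch.autograd import Variable

x = Variable(torch.randn(5), requires_grad=True)

# Detached: wrapping raw .data in a fresh Variable drops the history,
# so nothing in the resulting graph requires gradients.
bad = Variable(x.sum().data)
# bad.backward()  # -> RuntimeError: there are no graph nodes that require computing gradients

# Connected: composing Variable operations keeps the graph intact.
good = x.sum()
good.backward()  # fills x.grad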