Backward is not being called in custom torch.autograd.Function

Hi, I have a custom torch.autograd.Function defined like this.

class SourceMemory(Function):
    def __init__(self, M, alpha=0.01):
        super(SourceMemory, self).__init__()
        self.M = M
        self.alpha = alpha

    def forward(self, inputs, targets):
        self.save_for_backward(inputs, targets)
        outputs = inputs.mm(self.M.t())
        return outputs

    def backward(self, grad_outputs):
        inputs, targets = self.saved_tensors
        grad_inputs = None
        if self.needs_input_grad[0]:
            grad_inputs = grad_outputs.mm(self.M)
        for x, y in zip(inputs, targets):
            self.M[y] = self.alpha * self.M[y] + (1. - self.alpha) * x
            self.M[y] = F.normalize(self.M[y], p=2, dim=0)
        return grad_inputs, None


class IntraNet(nn.Module):
    def __init__(self, beta=0.05, alpha=0.01, num_classes=0, num_features=0, weight=None):
        super(IntraNet, self).__init__()
        self.beta = beta
        self.alpha = alpha
        self.weight = weight
        self.register_buffer('M', torch.zeros(num_classes, num_features))

    def forward(self, inputs, targets, features=None, epoch=None):
        self.alpha = self.alpha * epoch
        outputs = SourceMemory(self.M, alpha=self.alpha)(features, targets)
        outputs /= self.beta
        loss = F.cross_entropy(inputs, targets, weight=self.weight)
        return loss, outputs

M is a memory containing the feature. The IntraNet Module is used as a criterion in

  for i, inputs in enumerate(data_loader):
            inputs, pids = self._parse_data(inputs)
            outputs, features = self.model(inputs)
            # this will update the buffer or the memory M
            source_pid_loss, _ = self.inter_criterion(outputs, pids, features=features, epoch=epoch)

Where self.inter_criterion is the IntraNet module. and the model returns a output = [batch_size, num_class] and features = [batch_size, 4096].
The problem is that, backward of SourceMemory is never called. Am I doing something wrong here? I want the buffer M to be filled during the backward pass.
Thank you

Hi,

You are using “old style” autograd.Function and this can happen.
Can you change your code to use the new style as described in this doc (basically remove the init and pass arguments directly to the forward, make forward and backward static method, save stuff in ctx. during the forward if you need them in the backward) and see if it solves the issue.