grad_fn is None in forward for a grad-enabled tensor after a differentiable operation

Hello everybody,

There is something I’m not understanding about how autograd works. I am trying to implement a maximum likelihood model in PyTorch. The model maximizes the likelihood of a batch of data under a softmax distribution parameterized by some learnable parameters. My model class looks like this:

class Model(nn.Module):

    def __init__(self, param_size, kernel_size, max_value):

        super(Model, self).__init__()

        self.kernel_size = kernel_size
        self.max_value = max_value
        self.params = torch.ones(param_size, requires_grad=True).float()

    def forward(self, data):

        pot_data, norm_data = data

        ### global potential computation ###

        potential = torch.mul(pot_data, self.params)
        potential = torch.sum(potential, dim=(1, 2))

        ### global normalization computation ###

        normalization = torch.mul(norm_data, self.params)
        normalization = torch.sum(normalization, dim=2)
        normalization = torch.logsumexp(-normalization, dim=1)
        normalization = torch.sum(normalization, dim=1)

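        # log-likelihood of the batch: -potential is the unnormalized
        # log-probability and normalization is the (per-patch, summed)
        # log partition function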
        batch_likelihood = torch.sum(-potential - normalization)

        return batch_likelihood

I have a standard training loop with some data pre-processing:

for index, (img, label) in enumerate(dataloader):
        
        data = F.unfold(torch.unsqueeze(img, 1),
                        kernel_size=model.kernel_size,
                        stride=1,
                        padding=0)
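        # data now has shape (N, kernel_size*kernel_size, nb_patchs):
        # one column per extracted patch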

        batch_size = data.size(0)
        nb_patchs = data.size(-1)
        nb_repeats = model.max_value

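        # index 4 picks the centre pixel of each flattened patch
        # (assumes kernel_size == 3, i.e. 9 values per patch)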
        pot_data = (data != data[:, 4, :].view(batch_size, 1, nb_patchs))
        
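        # duplicate each patch max_value times and overwrite the centre pixel
        # with every possible value, to enumerate the normalization terms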
        dup_data = torch.unsqueeze(data, 1).repeat(1, model.max_value, 1, 1)
        ranges = torch.unsqueeze(torch.unsqueeze(torch.arange(model.max_value), 0), 2)
        ranges = ranges.repeat(batch_size, 1, nb_patchs)
        dup_data[:, :, 4, :] = ranges
        norm_data = (dup_data != dup_data[:, :, 4, :].view(batch_size, nb_repeats, 1, nb_patchs))

        data = (pot_data, norm_data)

        optimizer.zero_grad()
        likelihood = model(data)
        (-likelihood).backward()
        optimizer.step()

At every point in the forward method, self.params.grad_fn is None, and I don’t understand why. Also, if I save my model parameters with torch.save(model.state_dict(), "model.pt") and open the file in something like https://netron.app/, the computation graph doesn’t contain any operations.
However, the negative likelihood does still get minimized, and my parameters get updated at every batch.

This is very confusing to me: I would expect self.params.grad_fn to be something other than None, for example MulBackward0 after each torch.mul in the forward function, but this is not the case. Why is that?

Also, given that grad_fn is None, how does the optimizer compute a gradient to update my parameters at every batch?

I would be grateful for any help in better understanding how autograd works in this case 🙂

It seems you are trying to initialize self.params as the trainable parameter of the model, which is not the case at the moment, since you are not registering the tensor; you would need to use nn.Parameter(torch.ones(...)) instead. This will make sure the parameter is properly registered and e.g. passed to the optimizer via model.parameters().
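
A minimal sketch of that change, keeping the rest of your __init__ as it is (the dtype argument is just one way to keep the tensor float32):

class Model(nn.Module):

    def __init__(self, param_size, kernel_size, max_value):

        super(Model, self).__init__()

        self.kernel_size = kernel_size
        self.max_value = max_value
        # nn.Parameter sets requires_grad=True and registers the tensor,
        # so model.parameters() (and thus the optimizer) will see it
        self.params = nn.Parameter(torch.ones(param_size, dtype=torch.float32))

With this change, optimizer = torch.optim.SGD(model.parameters(), lr=1e-2) (or whichever optimizer you use) will pick up the parameter automatically.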

Even if you had properly registered self.params, it wouldn’t show any grad_fn, since it’s a leaf tensor and was not created by an operation. If you want to check the grad_fns, you could inspect the forward activations instead, e.g. print(potential.grad_fn).
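
A small standalone example of the leaf vs. non-leaf behavior (illustrative only, not tied to your model):

import torch

params = torch.ones(3, requires_grad=True)  # leaf tensor: no grad_fn
out = (params * 2.0).sum()                  # non-leaf: created by operations

print(params.grad_fn)  # None
print(out.grad_fn)     # <SumBackward0 object at ...>

out.backward()
print(params.grad)     # tensor([2., 2., 2.]): gradients accumulate on the leaf

This is also why the optimizer can still update the parameters: backward() accumulates the gradient into the .grad attribute of the leaf, and the optimizer only reads .grad; it never needs a grad_fn on the parameter itself.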

Thanks a lot for your answer, it makes sense now. If I understand correctly, grad_fns are only recorded for operations involving tensors with requires_grad=True, and that’s how autograd knows to differentiate the expression w.r.t. the parameters and not the data (for example)?
Also, this fix got torchviz working, so I could make sure that my model computes precisely what I wanted.
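
For reference, a minimal torchviz sketch along those lines (assuming torchviz is installed and data is a pre-processed batch as in the training loop above):

from torchviz import make_dot

likelihood = model(data)
# render the autograd graph, with the registered parameters labeled by name
make_dot(likelihood, params=dict(model.named_parameters())).render("graph", format="png")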