torch.autograd.grad returns a gradient with None grad_fn even when create_graph=True is set

Hi, I am trying to calculate higher-order derivatives of a custom module. However, the gradients returned by torch.autograd.grad() or backward() have no grad_fn, even when create_graph=True is set. The custom module is a CUDA extension, specifically PreciseRoIPooling.

I created a small example.

import torch
import torch.nn as nn
import torch.nn.functional as F
from ltr.external.PreciseRoIPooling.pytorch.prroi_pool import PrRoIPool2D

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        # PreciseRoIPooling module
        self.pool = PrRoIPool2D(3, 3, spatial_scale=0.5)
        self.f1 = nn.Linear(9, 1)

    def forward(self, features, rois):
        # Run the IoUNet module
        poolfeat = self.pool(features, rois)
        poolfeat = poolfeat.reshape(2, 16, -1)
        iou = self.f1(poolfeat)
        # calculate gradient
        grad = torch.autograd.grad(iou, rois, grad_outputs=torch.ones_like(iou), retain_graph=True, create_graph=True)[0]
        print(grad.requires_grad)
        print(grad.grad_fn)
        # grad1 = torch.autograd.grad(grad, self.f1.parameters(), grad_outputs=torch.ones_like(grad), retain_graph=True, create_graph=True)[0]

        return iou

if __name__ == '__main__':
    model = Model().to('cuda:0')
    features = torch.rand((4, 16, 32, 32), requires_grad=True).cuda()
    rois = torch.tensor([
        [0, 0, 0, 14, 14],
        [1, 14, 14, 28, 28],
    ]).float().cuda()
    rois.requires_grad = True
    iou = model(features, rois)

The output is:

False
None

Assuming that PrRoIPool2D is a pooling layer with no parameters, your whole function is just linear.
And as such, its gradient is independent of the input.
This is why you get no grad_fn. Any grad1 that you would have computed will be filled with 0s.

You can make your network deeper so that it is no longer linear; then the gradient depends on the input, and it will have a grad_fn populated.

In general, if you get a gradient that is None or that has no graph attached, it means the two quantities are independent, so the value is just 0s.
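As a minimal sketch of this with plain tensor ops (no custom extension involved): the gradient of a linear function of x is a constant and carries no grad_fn, while the gradient of a non-linear function of x stays in the graph:

import torch

x = torch.rand(4, 9, requires_grad=True)

# Linear in x: the gradient is the constant 3, so even with
# create_graph=True it has no grad_fn and does not require grad.
y = (3 * x).sum()
g = torch.autograd.grad(y, x, create_graph=True)[0]
print(g.requires_grad, g.grad_fn)  # False None

# Non-linear in x: the gradient is 2 * x, which depends on x,
# so it is part of the graph and can be differentiated again.
y = (x ** 2).sum()
g = torch.autograd.grad(y, x, create_graph=True)[0]
print(g.requires_grad, g.grad_fn)  # True <MulBackward0 ...>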

If it is a custom Function, you need to make sure that its backward is differentiable via autograd.
If you use non-differentiable ops in the backward, you will have to write a second Function whose forward will be the backward of the first one. Like:

from torch.autograd import Function

class MyFn(Function):
    @staticmethod
    def forward(ctx, inp):
        return my_non_diff_forward(inp)

    @staticmethod
    def backward(ctx, gO):
        return MyFnBackward.apply(gO)

class MyFnBackward(Function):
    @staticmethod
    def forward(ctx, inp):
        return my_non_diff_backward(inp)

    @staticmethod
    def backward(ctx, gO):
        return my_diff_double_backward(gO)

Note that if your double backward is not differentiable, you can add a @once_differentiable (from torch.autograd.function import once_differentiable) to its backward to get a nice error if you ever try to backward through that in the future.
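For example, reusing the placeholder names from the sketch above, the decorator goes on the backward of MyFnBackward, just below @staticmethod (my_non_diff_double_backward is just a placeholder here for a double backward that should not be differentiated further):

from torch.autograd import Function
from torch.autograd.function import once_differentiable

class MyFnBackward(Function):
    @staticmethod
    def forward(ctx, inp):
        return my_non_diff_backward(inp)

    @staticmethod
    @once_differentiable  # raises a clear error if a third-order backward is attempted
    def backward(ctx, gO):
        return my_non_diff_double_backward(gO)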

The custom CUDA extension is differentiable

How?
Autograd only tracks PyTorch operations. Your extension uses CUDA directly to compute the result. So it is not differentiable!
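One way to check this empirically, rather than inspecting grad_fn by hand, is torch.autograd.gradcheck and gradgradcheck. A rough sketch, assuming the extension accepts double-precision inputs (it may not, in which case the check cannot be run as-is):

import torch
from torch.autograd import gradcheck, gradgradcheck
from ltr.external.PreciseRoIPooling.pytorch.prroi_pool import PrRoIPool2D

pool = PrRoIPool2D(3, 3, spatial_scale=0.5)
features = torch.rand(1, 4, 8, 8, dtype=torch.double, device='cuda', requires_grad=True)
rois = torch.tensor([[0, 0, 0, 4, 4]], dtype=torch.double, device='cuda', requires_grad=True)

# gradcheck compares the analytic first-order gradients against finite differences;
# gradgradcheck does the same for second-order gradients and will fail
# if the backward of the extension is not itself tracked by autograd.
print(gradcheck(pool, (features, rois)))
print(gradgradcheck(pool, (features, rois)))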


Thanks a lot. You really helped me out.