No gradient in simple function

The following function does not behave as expected when called from my codebase:

import torch

def foo():
    # Try to enable gradient tracking in every way available
    torch.autograd.set_grad_enabled(True)
    torch.set_grad_enabled(True)
    with torch.enable_grad():
        y = torch.randn(10, device='cuda')
        y = torch.nn.Parameter(y.detach(), requires_grad=True)
        print(y.requires_grad, y.mean())

As you might notice, I tried my best to enable gradients in this code. If I call this function by itself, the printed output is as expected, i.e. True tensor(-0.8699, device='cuda:0', grad_fn=<MeanBackward0>).

However, if I run this during an evaluation pass of a Lightning model, the output is True tensor(-0.0105, device='cuda:0'), so the grad_fn is missing even though requires_grad is True. As a result, I cannot do a backward pass if I process y further.

What else could prevent gradients from being computed?

Do you see the same behavior without using Lightning? PyTorch itself won't disable gradient calculation behind your back, so you might want to check what the evaluation loop in Lightning does.
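For example (a minimal sketch of mine; torch.is_inference_mode_enabled() is available in recent PyTorch versions), you can reproduce the behavior without Lightning by wrapping the call in torch.inference_mode() yourself:

import torch

def foo():
    with torch.enable_grad():
        y = torch.nn.Parameter(torch.randn(10).detach(), requires_grad=True)
        # Show whether inference mode is active at the call site
        print(torch.is_inference_mode_enabled(), y.requires_grad, y.mean())

foo()                         # False True tensor(..., grad_fn=<MeanBackward0>)

with torch.inference_mode():  # stand-in for an evaluation loop
    foo()                     # True True tensor(...)  <- grad_fn is gone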


Thanks for the hint towards Lightning. I found this issue on their GitHub: apparently Lightning runs evaluation under torch.inference_mode(True). This now works as expected:

import torch

def foo():
    # Locally disable inference mode so autograd can record operations again
    with torch.inference_mode(False):
        y = torch.randn(10, device='cuda')
        y = torch.nn.Parameter(y.detach(), requires_grad=True)
        print(y.requires_grad, y.mean())
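To double-check that this also fixes the backward pass, here is a quick sketch of my own (not from the linked issue) that mimics the evaluation context with a plain torch.inference_mode() block:

import torch

with torch.inference_mode():            # stand-in for Lightning's evaluation context
    with torch.inference_mode(False):   # locally re-enable autograd
        y = torch.nn.Parameter(torch.randn(10), requires_grad=True)
        loss = y.mean()
        loss.backward()                 # grad_fn is recorded, so this succeeds
        print(y.requires_grad, y.grad)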