Why does calling backward on a loss function inside an autograd function cause an error?

The following autograd function attempts to calculate some auxiliary loss and backprop it. However, I get the error "element 0 of tensors does not require grad and does not have a grad_fn", even though x requires a gradient. Why?

import numpy as np

import torch
from torch import nn, optim, autograd


def run3():
    class my_function(autograd.Function):
        @staticmethod
        def forward(ctx, x):
            ctx.save_for_backward(x)
            return x

        @staticmethod
        def backward(ctx, grad_output):
            (x,) = ctx.saved_variables
            print('x.requires_grad', x.requires_grad)
            aux_loss = x.sum()
            aux_loss.backward(retain_graph=True)
            x.backward(grad_output)


    my_function = my_function.apply
    N = 5
    K = 3

    a = torch.rand(N, K)
    a = nn.Parameter(a)
    out = my_function(a)
    loss = out.sum()
    print('loss', loss)
    loss.backward()


if __name__ == '__main__':
    run3()

Output:

x.requires_grad True
...
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

Edit: it turns out that aux_loss doesn't have requires_grad set. So I reckon this question relates to "A tensor formed by indexing a tensor inside an autograd function does not require grad". It seems something works differently inside autograd functions than it does normally?

Within the forward and backward of an autograd.Function, autograd tracing is disabled by default (similar to when you use with torch.no_grad():), so aux_loss does not require grad. If you wrap the aux_loss computation in with torch.enable_grad():, your code seems to run. (You don't return anything from the backward, though, and you might have funny side effects if you use a backward that propagates to portions of the graph outside the Function's own variables.)
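For concreteness, here is a minimal sketch (my own adaptation, not Thomas's code) of what that suggestion might look like applied to the backward from the question, using ctx.saved_tensors as in current PyTorch. Note that the aux gradient goes straight into the parameter's .grad during the main backward, which is exactly the kind of side effect mentioned above:

import torch
from torch import nn, autograd

class my_function(autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        # grad mode is off here by default, so re-enable it for the aux loss
        with torch.enable_grad():
            aux_loss = x.sum()
            print('aux_loss.requires_grad', aux_loss.requires_grad)  # now True
            aux_loss.backward()  # accumulates d(aux_loss)/dx directly into x.grad
        # a gradient has to be returned for each input of forward
        return grad_output

a = nn.Parameter(torch.rand(5, 3))
out = my_function.apply(a)
out.sum().backward()
print(a.grad)  # each entry is 2: 1 from the main sum loss plus 1 from aux_loss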

Best regards

Thomas


Thanks! Will give that a shot. 🙂

Hi, Thomas

I have one thing to confirm. In PyTorch 0.3, every Variable passed to the forward function is unpacked into a Tensor, yet in the backward, x, = ctx.saved_variables gives x back as a Variable. Whereas, from what you say, in PyTorch >= 0.4 the backward function has autograd tracking disabled by default. Thank you!

I’d say just do something that works with 1.0 and let the old stuff be old. 🙂

I have to be a bit more specific here:

“By default” for the backward actually needs to be qualified: grad tracking stays disabled unless you differentiate with create_graph=True.
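A quick way to see that qualification in action (a toy Probe function I made up just to print the grad mode inside backward):

import torch

class Probe(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        return x.clone()

    @staticmethod
    def backward(ctx, grad_output):
        # off in a plain backward(), on when differentiating with create_graph=True
        print('grad enabled inside backward:', torch.is_grad_enabled())
        return grad_output

x = torch.rand(3, requires_grad=True)
Probe.apply(x).sum().backward()                   # prints False
Probe.apply(x).sum().backward(create_graph=True)  # prints True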

The history as far as I remember is:

  • Originally (0.1.2) both forward and backward operated on “old-style” non-Variable Tensors, with PyTorch doing the “unpacking” of Variables into Tensors.
  • Then in 0.2, to facilitate higher-order derivatives, the backward was switched to use Variables (including using ctx.saved_variables instead of ctx.saved_tensors; Functions also became static methods taking a ctx). The forward still operated on “old-style” non-Variable Tensors and PyTorch did the unpacking.
    This was still the state in PyTorch 0.3.
  • With the 0.4 merger of Tensor and Variable, everything became a Variable but was called a Tensor. Thus the forward operated on Tensors-previously-known-as-Variables, but with an implicit “no_grad” taking the place of the unpacking. The backward worked as before, but ctx.saved_variables was renamed to ctx.saved_tensors.

Best regards

Thomas


Hi, Thomas
That’s a very detailed explanation.
“You don’t return anything from the backward, though, and you might have funny side effects if you use a backward that propagates to portions of the graph outside the Function’s own variables.”

For the case below, I am not sure whether it meets that expectation.

def run3():
    class my_function(autograd.Function):
        @staticmethod
        def forward(ctx, x):
            ctx.save_for_backward(x)
            return x

        @staticmethod
        def backward(ctx, grad_output):
            (x,) = ctx.saved_tensors
            print('x.requires_grad', x.requires_grad)
            x_clone = x.clone().detach().requires_grad(True)
            aux_loss = x_clone.sum()
            aux_loss.backward()

            return grad_output + x_clone.grad()

I think you’re just a couple of typos and a with torch.enable_grad(): away from a working function:

class my_function(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x
    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        print('x.requires_grad', x.requires_grad)
        with torch.enable_grad():
            x_clone = x.clone().detach().requires_grad_(True)
            aux_loss = x_clone.sum()
            aux_loss.backward()
        return grad_output + x_clone.grad
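
As a quick check of my own (with aux_loss = x_clone.sum(), the auxiliary gradient is all ones, so it just adds 1 to each entry of the incoming gradient):

a = torch.nn.Parameter(torch.rand(5, 3))
out = my_function.apply(a)
out.sum().backward()
print(a.grad)  # all twos: ones from the main sum loss plus ones from aux_loss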

Personally, I’d probably use torch.autograd.grad rather than backward, but that’s me.
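For illustration, here is a sketch of that variant (my adaptation of the backward above; torch.autograd.grad returns a tuple with one gradient per input, and nothing is written into x_clone.grad):

import torch

class my_function(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        with torch.enable_grad():
            x_clone = x.clone().detach().requires_grad_(True)
            aux_loss = x_clone.sum()
            # grad() returns a tuple of gradients, one per input passed to it
            (aux_grad,) = torch.autograd.grad(aux_loss, x_clone)
        return grad_output + aux_grad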

Best regards

Thomas