torch.autograd.Function's backward is not triggered when func.apply() is used

Hello. I’m trying to use torch.autograd.Function to define a function with a custom forward and backward pass. I’m following the guide in the PyTorch documentation, which gives the following example:

>>> class Exp(Function):
>>>     @staticmethod
>>>     def forward(ctx, i):
>>>         result = i.exp()
>>>         ctx.save_for_backward(result)
>>>         return result
>>>
>>>     @staticmethod
>>>     def backward(ctx, grad_output):
>>>         result, = ctx.saved_tensors
>>>         return grad_output * result
>>>
>>> # Use it by calling the apply method:
>>> output = Exp.apply(input)

My code looks something like this:

class Exp(Function):
    @staticmethod
    def forward(ctx, i):
        # I do some changes to the input i which work as intended
        return i

    @staticmethod
    def backward(ctx, grad_output):
        grad_output = torch.relu(grad_output)
        print(grad_output.shape)
        return grad_output

I expected this to print grad_output's shape every time the function is called. The function is called in the forward method of a torch.nn.Module subclass (which looks very similar to the Linear layer); let's call this class Cats. Cats is instantiated in the init method of a model.
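For reference, the kind of standalone call that I would expect to trigger the print is something like this (a toy sketch, not my actual model code):

import torch

# with the Exp class defined as above
inp = torch.randn(4, requires_grad=True)
out = Exp.apply(inp)
out.sum().backward()   # this backward() call is what should reach Exp.backward and print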
But the code does not print grad_output's shape. What could be the reason?

Have you called .backward() anywhere in your code?

Also, can you post a minimal reproducible example to replicate the error?

Hello! Thanks for the reply.
I did not explicitly call .backward(); however, I did apply the function as shown in the PyTorch documentation, i.e. Exp.apply(input).
The following is the code for reproducing the error (imports included):

import math

import torch
import torch.nn.functional as F
from torch.nn import Parameter, init


class autograd_func(torch.autograd.Function):
    def forward(self,layer):
        alpha = torch.ones(layer.shape[0])
        Alpha = (alpha)
        return layer,Alpha
    def backward(self,grad_output):
        print("checking backward call for autograd_func")
        return grad_output


class custom_linear(torch.nn.Module):
    def __init__(self, in_features: int, out_features: int, bias: bool = False,
                 device=None, dtype=None):
        factory_kwargs = {'device': device, 'dtype': dtype}
        super(custom_linear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = Parameter(torch.empty((out_features, in_features), **factory_kwargs))
        if bias:
            self.bias = Parameter(torch.empty(out_features, **factory_kwargs))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self) -> None:
        # Setting a=sqrt(5) in kaiming_uniform is the same as initializing with
        # uniform(-1/sqrt(in_features), 1/sqrt(in_features)). For details, see
        # https://github.com/pytorch/pytorch/issues/57109
        init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
            init.uniform_(self.bias, -bound, bound)

    def forward(self,input):
        A,B = autograd_func.apply(self.weight.data)
        self.weight.data = A
        #B does get used in the complete implementation but it does not make a difference
        out = F.linear(input, self.weight, self.bias)
        return out

    def extra_repr(self) -> str:
        return 'in_features={}, out_features={}, bias={}'.format(
            self.in_features, self.out_features, self.bias is not None)


x = torch.randn([1,100])
y_star = torch.ones([1,10])
cust_layer = custom_linear(100,10)
y = cust_layer(x)
# print(y.shape,y_star.shape)
criterion = torch.nn.MSELoss()
loss = criterion(y,y_star)
loss.backward()

I hope this helps!

The forward and backward methods need to be decorated with @staticmethod. Add that above each method and see if that fixes the problem.
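For example, schematically (your method bodies unchanged, just showing where the decorators go):

import torch

class autograd_func(torch.autograd.Function):
    @staticmethod              # <-- add this
    def forward(self, layer):
        ...

    @staticmethod              # <-- add this
    def backward(self, grad_output):
        ...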

I did that; it did not work…

Ok, so the comments about @staticmethod were correct; it's just that you need to add a few more things. For example, change self to ctx, and remove the calls to .data (e.g. in A, B = autograd_func.apply(self.weight.data)), as .data is deprecated. Make sure to follow the exact layout of the arguments for forward and backward in the example below; after it I have also sketched how this could map onto your autograd_func and custom_linear. Do make sure to read through the tutorial (Extending PyTorch — PyTorch 1.11.0 documentation) too.

# Inherit from Function
class LinearFunction(Function):

    # Note that both forward and backward are @staticmethods
    @staticmethod
    # bias is an optional argument
    def forward(ctx, input, weight, bias=None):
        ctx.save_for_backward(input, weight, bias)
        output = input.mm(weight.t())
        if bias is not None:
            output += bias.unsqueeze(0).expand_as(output)
        return output

    # This function has only a single output, so it gets only one gradient
    @staticmethod
    def backward(ctx, grad_output):
        # This is a pattern that is very convenient - at the top of backward
        # unpack saved_tensors and initialize all gradients w.r.t. inputs to
        # None. Thanks to the fact that additional trailing Nones are
        # ignored, the return statement is simple even when the function has
        # optional inputs.
        input, weight, bias = ctx.saved_tensors
        grad_input = grad_weight = grad_bias = None

        # These needs_input_grad checks are optional and there only to
        # improve efficiency. If you want to make your code simpler, you can
        # skip them. Returning gradients for inputs that don't require it is
        # not an error.
        if ctx.needs_input_grad[0]:
            grad_input = grad_output.mm(weight)
        if ctx.needs_input_grad[1]:
            grad_weight = grad_output.t().mm(input)
        if bias is not None and ctx.needs_input_grad[2]:
            grad_bias = grad_output.sum(0)

        return grad_input, grad_weight, grad_bias
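Applying the same pattern to your autograd_func and custom_linear could look something like the sketch below. This is only my guess at the wiring (in particular, using the returned A directly in F.linear instead of writing it back into self.weight.data is an assumption about your intent), so treat it as a starting point rather than a drop-in fix.

import torch
import torch.nn.functional as F

class autograd_func(torch.autograd.Function):
    @staticmethod
    def forward(ctx, layer):
        alpha = torch.ones(layer.shape[0])
        return layer, alpha

    @staticmethod
    def backward(ctx, grad_layer, grad_alpha):
        # one incoming gradient per forward output,
        # one returned gradient per forward input
        print("checking backward call for autograd_func")
        return grad_layer

and inside custom_linear:

    def forward(self, input):
        A, B = autograd_func.apply(self.weight)   # no .data, so the op stays in the autograd graph
        out = F.linear(input, A, self.bias)       # use A so the custom backward is reached
        return out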

This simple code works for me.

import torch

class autograd_func(torch.autograd.Function):
    @staticmethod
    def forward(ctx, layer):
        alpha = torch.ones(layer.shape[0])
        Alpha = (alpha)
        return layer, Alpha

    @staticmethod
    def backward(ctx, grad_output, other):
        # forward returns two outputs, so backward receives two incoming gradients;
        # it returns a single gradient because forward takes a single input
        print("checking backward call for autograd_func")
        return grad_output


if __name__ == '__main__':
    x = torch.randn(5, 6)
    x.requires_grad = True
    y, a = autograd_func.apply(x)
    y.sum().backward()

Thanks for the reply and the links!
Changing it to ctx did not work. I was unaware of the fact that .data is deprecated. I will rewrite that part of the code and try again, as autograd does not track changes made through .data by recording them in the backward graph (correct me if I'm wrong here).
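For example, a quick toy check of that point (my own snippet, not from the tutorial):

import torch

w = torch.randn(3, requires_grad=True)
a = w * 2        # tracked: a.grad_fn is set and a.requires_grad is True
b = w.data * 2   # .data detaches: b.grad_fn is None and b.requires_grad is False
print(a.grad_fn, a.requires_grad)
print(b.grad_fn, b.requires_grad)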
I will try out your suggestions and update you.

Thank you for the reply!
This did work, yes. The torch.randn tensor has requires_grad set, so its computations are tracked by autograd. Using .data is deprecated, and changes made through it are not tracked by autograd. I'll make the necessary changes and update you.

It did work! Thank you for your feedback, InnovArul and AlphaBetaGamma96.