PyTorch "Extending PyTorch" doc missing ctx in forward method?

As the title says, this example from the "Extending PyTorch" doc:

import torch
from torch.autograd import Function
# Inherit from Function
class LinearFunction(Function):

    # Note that forward, setup_context, and backward are @staticmethods
    @staticmethod
    def forward(input, weight, bias):
        output = input.mm(weight.t())
        if bias is not None:
            output += bias.unsqueeze(0).expand_as(output)
        return output

    @staticmethod
    # inputs is a Tuple of all of the inputs passed to forward.
    # output is the output of the forward().
    def setup_context(ctx, inputs, output):
        input, weight, bias = inputs
        ctx.save_for_backward(input, weight, bias)

    # This function has only a single output, so it gets only one gradient
    @staticmethod
    def backward(ctx, grad_output):
        ...

raises an error when I call it:

input=torch.randn(1,1)
weight=torch.randn(1,1,requires_grad=True)
bias=None
y=LinearFunction.apply(input,weight,bias)
# TypeError: forward() takes 3 positional arguments but 4 were given

It looks like you need to add ctx to the forward method:

    @staticmethod
    def forward(ctx, input, weight, bias):
        ...

With that change, the forward call works.

Even with that fix, backward still fails. Am I doing something wrong?

input=torch.randn(1,1)
weight=torch.randn(1,1,requires_grad=True)
bias=None
y=LinearFunction.apply(input,weight,bias)
y.backward()

Traceback (most recent call last):
  File "/home/.../Desktop/pytortto/src/test1.py", line 52, in <module>
    y.backward()
  File "/home/.../cpu_only/lib/python3.9/site-packages/torch/_tensor.py", line 488, in backward
    torch.autograd.backward(
  File "/home/.../cpu_only/lib/python3.9/site-packages/torch/autograd/__init__.py", line 197, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  File "/home/.../cpu_only/lib/python3.9/site-packages/torch/autograd/function.py", line 267, in apply
    return user_fn(self, *args)
  File "/home/.../Desktop/pytortto/src/test1.py", line 30, in backward
    input, weight, bias = ctx.saved_tensors
ValueError: not enough values to unpack (expected 3, got 0)

The example works for me:

import torch
import torch.nn as nn

# LinearFunction is the class from the docs example above.
# Option 1: alias
linear = LinearFunction.apply

x = torch.randn(10, 10)
lin = nn.Linear(10, 10)
out = linear(x, lin.weight, lin.bias)
out.mean().backward()
print(lin.weight.grad.abs().sum())

You might also want to check the "Combined or separate forward() and setup_context()" section of the docs, which explains how setup_context is used starting with PyTorch 2.0.0.
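
If you are not sure which style your install expects, a quick check along these lines may help. This is only a sketch: it assumes that a build supporting the separate setup_context() style exposes setup_context on the Function base class, and the name supports_setup_context is made up for illustration.

```python
import torch
from torch.autograd import Function

# Installed version string, e.g. "1.13.1" or "2.0.0" (possibly with a build suffix).
print(torch.__version__)

# Assumption: builds that support the separate setup_context() style define
# setup_context on the Function base class; older builds do not.
supports_setup_context = hasattr(Function, "setup_context")
print("separate setup_context supported:", supports_setup_context)
```

If this prints False, forward() has to take ctx as its first argument and save tensors itself.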

Yep, I updated PyTorch from 1.13.1 to 2.0.0 and it works. It looks like setup_context is only supported starting from 2.0.0. Thanks!
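
For anyone who cannot upgrade, here is a minimal sketch of the combined style, where forward() takes ctx and saves the tensors itself, so setup_context is not needed. The backward body is my own fill-in of the standard linear-layer gradients, not copied from the thread, and the shapes in the usage part are arbitrary.

```python
import torch
from torch.autograd import Function


class LinearFunction(Function):
    # Combined style: forward receives ctx and saves what backward needs.
    @staticmethod
    def forward(ctx, input, weight, bias):
        ctx.save_for_backward(input, weight, bias)
        output = input.mm(weight.t())
        if bias is not None:
            output += bias.unsqueeze(0).expand_as(output)
        return output

    # One forward output, so backward gets one gradient and returns
    # one gradient (or None) per forward input.
    @staticmethod
    def backward(ctx, grad_output):
        input, weight, bias = ctx.saved_tensors
        grad_input = grad_weight = grad_bias = None
        if ctx.needs_input_grad[0]:
            grad_input = grad_output.mm(weight)
        if ctx.needs_input_grad[1]:
            grad_weight = grad_output.t().mm(input)
        if bias is not None and ctx.needs_input_grad[2]:
            grad_bias = grad_output.sum(0)
        return grad_input, grad_weight, grad_bias


# Usage: 4 samples, 3 input features, 2 output features, no bias.
input = torch.randn(4, 3)
weight = torch.randn(2, 3, requires_grad=True)
y = LinearFunction.apply(input, weight, None)
y.sum().backward()
print(weight.grad.shape)
```

This also explains the "expected 3, got 0" error above: when ctx is added to forward() on 1.13.1 but the saving is left in setup_context, nothing is ever saved, so ctx.saved_tensors is empty in backward.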