PyTorch "Extending PyTorch" doc missing ctx in forward method?

As stated in the Extending PyTorch doc, the example:

import torch
from torch.autograd import Function

# Inherit from Function
class LinearFunction(Function):

    # Note that forward, setup_context, and backward are @staticmethods
    @staticmethod
    def forward(input, weight, bias):
        output = input.mm(weight.t())
        if bias is not None:
            output += bias.unsqueeze(0).expand_as(output)
        return output

    # inputs is a Tuple of all of the inputs passed to forward.
    # output is the output of the forward().
    @staticmethod
    def setup_context(ctx, inputs, output):
        input, weight, bias = inputs
        ctx.save_for_backward(input, weight, bias)

    # This function has only a single output, so it gets only one gradient
    @staticmethod
    def backward(ctx, grad_output):
        input, weight, bias = ctx.saved_tensors
        grad_input = grad_weight = grad_bias = None
        if ctx.needs_input_grad[0]:
            grad_input = grad_output.mm(weight)
        if ctx.needs_input_grad[1]:
            grad_weight = grad_output.t().mm(input)
        if bias is not None and ctx.needs_input_grad[2]:
            grad_bias = grad_output.sum(0)
        return grad_input, grad_weight, grad_bias

will cause an error:

# TypeError: forward() takes 3 positional arguments but 4 were given

Looks like you need to add ctx to the forward method:

    def forward(ctx, input, weight, bias):

It works afterwards.

Even with that fix, though, the backward pass still fails. Am I doing anything wrong?


Traceback (most recent call last):
  File "/home/.../Desktop/pytortto/src/", line 52, in <module>
  File "/home/.../cpu_only/lib/python3.9/site-packages/torch/", line 488, in backward
  File "/home/.../cpu_only/lib/python3.9/site-packages/torch/autograd/", line 197, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  File "/home/.../cpu_only/lib/python3.9/site-packages/torch/autograd/", line 267, in apply
    return user_fn(self, *args)
  File "/home/.../Desktop/pytortto/src/", line 30, in backward
    input, weight, bias = ctx.saved_tensors
ValueError: not enough values to unpack (expected 3, got 0)
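This is consistent with setup_context simply never being called on PyTorch versions before 2.0: ctx.save_for_backward never runs, so ctx.saved_tensors is empty in backward. A minimal sketch of the older combined style that works on 1.x, where forward receives ctx and does the saving itself (gradient formulas follow the docs example):

```python
import torch
from torch.autograd import Function

class LinearFunction(Function):
    # Combined style (pre-2.0): forward takes ctx and saves tensors itself.
    @staticmethod
    def forward(ctx, input, weight, bias):
        ctx.save_for_backward(input, weight, bias)
        output = input.mm(weight.t())
        if bias is not None:
            output += bias.unsqueeze(0).expand_as(output)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input, weight, bias = ctx.saved_tensors
        grad_input = grad_weight = grad_bias = None
        if ctx.needs_input_grad[0]:
            grad_input = grad_output.mm(weight)
        if ctx.needs_input_grad[1]:
            grad_weight = grad_output.t().mm(input)
        if bias is not None and ctx.needs_input_grad[2]:
            grad_bias = grad_output.sum(0)
        return grad_input, grad_weight, grad_bias

x = torch.randn(4, 3, requires_grad=True)
w = torch.randn(5, 3, requires_grad=True)
b = torch.randn(5, requires_grad=True)
out = LinearFunction.apply(x, w, b)
out.sum().backward()
```

Because forward itself receives ctx here, this variant works on both old and new PyTorch versions.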

The example works for me:

import torch
from torch import nn

# Option 1: alias
linear = LinearFunction.apply

x = torch.randn(10, 10)
lin = nn.Linear(10, 10)
out = linear(x, lin.weight, lin.bias)

You might also want to check the "Combined or separate forward() and setup_context()" section of the docs, which explains the usage of setup_context introduced in PyTorch 2.0.0.

Yep, I updated my PyTorch from 1.13.1 to 2.0.0 and it works. Looks like setup_context is only supported starting from 2.0.0. Thanks!
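For anyone who wants to verify the 2.0-style (separate setup_context) version end to end, torch.autograd.gradcheck can confirm the backward formulas. A self-contained sketch, assuming PyTorch >= 2.0 and using double-precision inputs since gradcheck requires them:

```python
import torch
from torch.autograd import Function, gradcheck

class LinearFunction(Function):
    # Separate style (PyTorch >= 2.0): forward does NOT take ctx.
    @staticmethod
    def forward(input, weight, bias):
        output = input.mm(weight.t())
        if bias is not None:
            output += bias.unsqueeze(0).expand_as(output)
        return output

    # Saving happens here instead of in forward.
    @staticmethod
    def setup_context(ctx, inputs, output):
        input, weight, bias = inputs
        ctx.save_for_backward(input, weight, bias)

    @staticmethod
    def backward(ctx, grad_output):
        input, weight, bias = ctx.saved_tensors
        grad_input = grad_weight = grad_bias = None
        if ctx.needs_input_grad[0]:
            grad_input = grad_output.mm(weight)
        if ctx.needs_input_grad[1]:
            grad_weight = grad_output.t().mm(input)
        if bias is not None and ctx.needs_input_grad[2]:
            grad_bias = grad_output.sum(0)
        return grad_input, grad_weight, grad_bias

# gradcheck compares analytical gradients against numerical ones.
inputs = (
    torch.randn(4, 3, dtype=torch.double, requires_grad=True),
    torch.randn(5, 3, dtype=torch.double, requires_grad=True),
    torch.randn(5, dtype=torch.double, requires_grad=True),
)
ok = gradcheck(LinearFunction.apply, inputs)
```

On 1.13.1 this same code raises the TypeError from the question, which matches the version requirement above.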