Error in official documentation code?

I’m following the tutorial here and using this code:

import torch

class MyCube(torch.autograd.Function):
    @staticmethod
    def forward(x):
        # We wish to save dx for backward. In order to do so, it must
        # be returned as an output.
        dx = 3 * x ** 2
        result = x ** 3
        return result, dx

    @staticmethod
    def setup_context(ctx, inputs, output):
        x, = inputs
        result, dx = output
        ctx.save_for_backward(x, dx)

    @staticmethod
    def backward(ctx, grad_output, grad_dx):
        x, dx = ctx.saved_tensors
        # In order for the autograd.Function to work with higher-order
        # gradients, we must add the gradient contribution of `dx`,
        # which is grad_dx * 6 * x.
        result = grad_output * dx + grad_dx * 6 * x
        return result

# Wrap MyCube in a function so that it is clearer what the output is
def my_cube(x):
    result, dx = MyCube.apply(x)
    return result

x = torch.tensor(2.0, requires_grad=True)
y, dx = my_cube(x)  # y = x^3, dx = 3x^2

But I got this error:

File "H:\code\lab\lab.py", line 33, in <module>
    y, dx = my_cube(x)  # y = x^3, dx = 3x^2
  File "H:\code\lab\lab.py", line 29, in my_cube
    result, dx = MyCube.apply(x)
TypeError: MyCube.forward() takes 1 positional argument but 2 were given

Hi Aakira!

This works* for me (using the latest stable version, 2.6.0).

Here is (a tweaked version of) your example script:

import platform
print(platform.python_version())

import torch
print(torch.__version__)

class MyCube(torch.autograd.Function):
    @staticmethod
    def forward(x):
        # We wish to save dx for backward. In order to do so, it must
        # be returned as an output.
        dx = 3 * x ** 2
        result = x ** 3
        return result, dx

    @staticmethod
    def setup_context(ctx, inputs, output):
        x, = inputs
        result, dx = output
        ctx.save_for_backward(x, dx)

    @staticmethod
    def backward(ctx, grad_output, grad_dx):
        x, dx = ctx.saved_tensors
        # In order for the autograd.Function to work with higher-order
        # gradients, we must add the gradient contribution of `dx`,
        # which is grad_dx * 6 * x.
        result = grad_output * dx + grad_dx * 6 * x
        return result

# Wrap MyCube in a function so that it is clearer what the output is
def my_cube(x):
    result, dx = MyCube.apply(x)
    # return result
    return result, dx   # return a tuple containing both result and dx

x = torch.tensor(2.0, requires_grad=True)
y, dx = my_cube(x)  # y = x^3, dx = 3x^2   # this line asks for a tuple, so we modified my_cube()

print(x, y, dx)

And here is its output:

3.12.8
2.6.0+cu126
tensor(2., requires_grad=True) tensor(8., grad_fn=<MyCubeBackward>) tensor(12., grad_fn=<MyCubeBackward>)

Double-check for possible copy-paste errors and check your pytorch (and python)
versions; this code needs pytorch 2.0 or later, which introduced the setup_context()
style, but should work with any version since then.

*) The last line of your example script requires my_cube() to return a sequence
of some sort (e.g., a tuple), so I tweaked my_cube() accordingly.
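
As a quick check of the higher-order-gradient bookkeeping in backward(), you can
append the following to the script above (this snippet is my addition; at x = 2
both derivatives of x^3, namely 3 * x**2 and 6 * x, equal 12):

g, = torch.autograd.grad(y, x, create_graph=True)  # dy/dx = 3 * x**2
print(g)   # tensor(12., grad_fn=...)
gg, = torch.autograd.grad(g, x)                    # d2y/dx2 = 6 * x
print(gg)  # tensor(12.)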

Best.

K. Frank

@KFrank Thanks for taking the time.

My code still fails with the same error:

3.10.9
1.13.1+cu117
Traceback (most recent call last):
  File "H:\code\lab\lab.py", line 38, in <module>
    y, dx = my_cube(x)  # y = x^3, dx = 3x^2   # this line asks for a tuple, so we modified my_cube()
  File "H:\code\lab\lab.py", line 33, in my_cube
    result, dx = MyCube.apply(x)
TypeError: MyCube.forward() takes 1 positional argument but 2 were given

Hi Aakira!

This error means that MyCube.forward() is being called with an extra argument.
In pytorch versions before 2.0, .apply() passes the ctx object as the first
argument to .forward(), so your one-argument .forward(x) is actually invoked
as .forward(ctx, x).

The separate setup_context() staticmethod that the tutorial uses was only
introduced in pytorch 2.0, but you are running pytorch 1.13, which predates it.
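
If you need to stay on 1.13, the older combined style, where forward() takes
ctx directly and does the saving itself, still works on current versions too.
Here is a minimal sketch (I have not run it on 1.13 myself):

import torch

class MyCube(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        # Combined style: no separate setup_context(); forward()
        # receives ctx and saves the tensors itself.
        dx = 3 * x ** 2
        result = x ** 3
        ctx.save_for_backward(x, dx)
        return result, dx

    @staticmethod
    def backward(ctx, grad_output, grad_dx):
        x, dx = ctx.saved_tensors
        # Same higher-order-gradient bookkeeping as in the tutorial.
        return grad_output * dx + grad_dx * 6 * x

x = torch.tensor(2.0, requires_grad=True)
y, dx = MyCube.apply(x)  # no TypeError on pre-2.0 pytorch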

Could you post the entirety of your lab.py, as well as the output you get when
you run it on the latest stable pytorch (2.6.0)?

Best.

K. Frank