PyTorch "Extending PyTorch" doc missing ctx in forward method?

As stated in the Extending PyTorch doc, the example:

import torch
from torch.autograd import Function

# Inherit from Function
class LinearFunction(Function):

    # Note that forward, setup_context, and backward are @staticmethods
    @staticmethod
    def forward(input, weight, bias):
        output = input.mm(weight.t())
        if bias is not None:
            output += bias.unsqueeze(0).expand_as(output)
        return output

    # inputs is a Tuple of all of the inputs passed to forward.
    # output is the output of the forward().
    @staticmethod
    def setup_context(ctx, inputs, output):
        input, weight, bias = inputs
        ctx.save_for_backward(input, weight, bias)

    # This function has only a single output, so it gets only one gradient
    @staticmethod
    def backward(ctx, grad_output):
        input, weight, bias = ctx.saved_tensors
        grad_input = grad_weight = grad_bias = None
        if ctx.needs_input_grad[0]:
            grad_input = grad_output.mm(weight)
        if ctx.needs_input_grad[1]:
            grad_weight = grad_output.t().mm(input)
        if bias is not None and ctx.needs_input_grad[2]:
            grad_bias = grad_output.sum(0)
        return grad_input, grad_weight, grad_bias

will cause an error:

# TypeError: forward() takes 3 positional arguments but 4 were given

Looks like you need to add ctx to the forward method:

    def forward(ctx, input, weight, bias):

It works afterwards.

Even with that fix, though, the backward pass still fails. Am I doing anything wrong?


Traceback (most recent call last):
  File "/home/.../Desktop/pytortto/src/", line 52, in <module>
  File "/home/.../cpu_only/lib/python3.9/site-packages/torch/", line 488, in backward
  File "/home/.../cpu_only/lib/python3.9/site-packages/torch/autograd/", line 197, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  File "/home/.../cpu_only/lib/python3.9/site-packages/torch/autograd/", line 267, in apply
    return user_fn(self, *args)
  File "/home/.../Desktop/pytortto/src/", line 30, in backward
    input, weight, bias = ctx.saved_tensors
ValueError: not enough values to unpack (expected 3, got 0)
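This is consistent with setup_context simply never being called on PyTorch versions before 2.0: ctx.save_for_backward never runs, so ctx.saved_tensors is empty in backward. A minimal sketch of the older combined style that works on 1.x, where forward receives ctx and does the saving itself (gradient formulas follow the docs example):

```python
import torch
from torch.autograd import Function

class LinearFunction(Function):
    # Combined style (pre-2.0): forward takes ctx and saves tensors itself.
    @staticmethod
    def forward(ctx, input, weight, bias):
        ctx.save_for_backward(input, weight, bias)
        output = input.mm(weight.t())
        if bias is not None:
            output += bias.unsqueeze(0).expand_as(output)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input, weight, bias = ctx.saved_tensors
        grad_input = grad_weight = grad_bias = None
        if ctx.needs_input_grad[0]:
            grad_input = grad_output.mm(weight)
        if ctx.needs_input_grad[1]:
            grad_weight = grad_output.t().mm(input)
        if bias is not None and ctx.needs_input_grad[2]:
            grad_bias = grad_output.sum(0)
        return grad_input, grad_weight, grad_bias

x = torch.randn(4, 3, requires_grad=True)
w = torch.randn(5, 3, requires_grad=True)
b = torch.randn(5, requires_grad=True)
out = LinearFunction.apply(x, w, b)
out.sum().backward()
```

Because forward itself receives ctx here, this variant works on both old and new PyTorch versions.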

The example works for me:

import torch
from torch import nn

# Option 1: alias
linear = LinearFunction.apply

x = torch.randn(10, 10)
lin = nn.Linear(10, 10)
out = linear(x, lin.weight, lin.bias)

You might also want to check the "Combined or separate forward() and setup_context()" section of the docs, which explains the usage of setup_context introduced in PyTorch 2.0.0.

Yep, I updated my PyTorch from 1.13.1 to 2.0.0 and it works. Looks like setup_context is only supported starting from 2.0.0. Thanks!
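For anyone who wants to verify the 2.0-style (separate setup_context) version end to end, torch.autograd.gradcheck can confirm the backward formulas. A self-contained sketch, assuming PyTorch >= 2.0 and using double-precision inputs since gradcheck requires them:

```python
import torch
from torch.autograd import Function, gradcheck

class LinearFunction(Function):
    # Separate style (PyTorch >= 2.0): forward does NOT take ctx.
    @staticmethod
    def forward(input, weight, bias):
        output = input.mm(weight.t())
        if bias is not None:
            output += bias.unsqueeze(0).expand_as(output)
        return output

    # Saving happens here instead of in forward.
    @staticmethod
    def setup_context(ctx, inputs, output):
        input, weight, bias = inputs
        ctx.save_for_backward(input, weight, bias)

    @staticmethod
    def backward(ctx, grad_output):
        input, weight, bias = ctx.saved_tensors
        grad_input = grad_weight = grad_bias = None
        if ctx.needs_input_grad[0]:
            grad_input = grad_output.mm(weight)
        if ctx.needs_input_grad[1]:
            grad_weight = grad_output.t().mm(input)
        if bias is not None and ctx.needs_input_grad[2]:
            grad_bias = grad_output.sum(0)
        return grad_input, grad_weight, grad_bias

# gradcheck compares analytical gradients against numerical ones.
inputs = (
    torch.randn(4, 3, dtype=torch.double, requires_grad=True),
    torch.randn(5, 3, dtype=torch.double, requires_grad=True),
    torch.randn(5, dtype=torch.double, requires_grad=True),
)
ok = gradcheck(LinearFunction.apply, inputs)
```

On 1.13.1 this same code raises the TypeError from the question, which matches the version requirement above.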