I need to implement my own function which inherits from autograd.Function, where I mannually calculate some gradient using autograd.grad. But I find this does not work (see below). I am a bit new here. Who can help me?
import torch
def simple_func():
a = torch.tensor(3., requires_grad=True)
c = a + 2
grads = torch.autograd.grad(c, a, grad_outputs=torch.tensor(5.))
print(grads)
class dummy_autograd_func(torch.autograd.Function):
@staticmethod
def forward(ctx, x):
return x + 2
@staticmethod
def backward(ctx, dy):
simple_func()
return dy
# Running normally; it is fine
simple_func()
# Running within autograd.Function; it fails
x = torch.rand(3, 2, requires_grad=True)
y = dummy_autograd_func.apply(x)
s = y.sum()
s.backward()
The error message is
Traceback (most recent call last):
File "/home/yangzh/work/PycharmProjects/rev_test/simple_test.py", line 29, in <module>
s.backward()
File "/home/yangzh/anaconda3/lib/python3.8/site-packages/torch/_tensor.py", line 307, in backward
torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
File "/home/yangzh/anaconda3/lib/python3.8/site-packages/torch/autograd/__init__.py", line 154, in backward
Variable._execution_engine.run_backward(
File "/home/yangzh/anaconda3/lib/python3.8/site-packages/torch/autograd/function.py", line 199, in apply
return user_fn(self, *args)
File "/home/yangzh/work/PycharmProjects/rev_test/simple_test.py", line 19, in backward
simple_func()
File "/home/yangzh/work/PycharmProjects/rev_test/simple_test.py", line 7, in simple_func
grads = torch.autograd.grad(c, a, grad_outputs=torch.tensor(5.))
File "/home/yangzh/anaconda3/lib/python3.8/site-packages/torch/autograd/__init__.py", line 234, in grad
return Variable._execution_engine.run_backward(
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn