How can I call autograd.grad within an autograd.Function?

I need to implement my own function that inherits from autograd.Function, in which I manually compute some gradients using autograd.grad. But I find this does not work (see below). I am a bit new here. Can anyone help?

import torch

def simple_func():
    # Gradient of c = a + 2 w.r.t. a, scaled by grad_outputs
    a = torch.tensor(3., requires_grad=True)
    c = a + 2
    grads = torch.autograd.grad(c, a, grad_outputs=torch.tensor(5.))
    print(grads)  # (tensor(5.),)

class dummy_autograd_func(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        return x + 2

    @staticmethod
    def backward(ctx, dy):
        simple_func()  # calling autograd.grad from here raises the error below
        return dy

# Called directly, simple_func works fine
simple_func()

# Called from within autograd.Function's backward, it fails
x = torch.rand(3, 2, requires_grad=True)
y = dummy_autograd_func.apply(x)
s = y.sum()
s.backward()

The error message is:

Traceback (most recent call last):
  File "/home/yangzh/work/PycharmProjects/rev_test/simple_test.py", line 29, in <module>
    s.backward()
  File "/home/yangzh/anaconda3/lib/python3.8/site-packages/torch/_tensor.py", line 307, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
  File "/home/yangzh/anaconda3/lib/python3.8/site-packages/torch/autograd/__init__.py", line 154, in backward
    Variable._execution_engine.run_backward(
  File "/home/yangzh/anaconda3/lib/python3.8/site-packages/torch/autograd/function.py", line 199, in apply
    return user_fn(self, *args)
  File "/home/yangzh/work/PycharmProjects/rev_test/simple_test.py", line 19, in backward
    simple_func()
  File "/home/yangzh/work/PycharmProjects/rev_test/simple_test.py", line 7, in simple_func
    grads = torch.autograd.grad(c, a, grad_outputs=torch.tensor(5.))
  File "/home/yangzh/anaconda3/lib/python3.8/site-packages/torch/autograd/__init__.py", line 234, in grad
    return Variable._execution_engine.run_backward(
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

By default, gradient computation is disabled inside the backward of a custom autograd.Function (it runs in no-grad mode), so the graph for c = a + 2 is never recorded and autograd.grad has nothing to differentiate. You need to re-enable it, e.g. via torch.set_grad_enabled(True) or the torch.enable_grad() context manager, inside backward.
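
For example, here is a minimal sketch of the fix applied to your code, wrapping the differentiable work in torch.enable_grad() (the context-manager equivalent of torch.set_grad_enabled(True)):

import torch

class dummy_autograd_func(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        return x + 2

    @staticmethod
    def backward(ctx, dy):
        # backward runs with grad mode disabled, so re-enable it
        # before building the graph that autograd.grad will traverse
        with torch.enable_grad():
            a = torch.tensor(3., requires_grad=True)
            c = a + 2  # recorded in the graph now that grad mode is on
            grads = torch.autograd.grad(c, a, grad_outputs=torch.tensor(5.))
            print(grads)  # (tensor(5.),)
        return dy

x = torch.rand(3, 2, requires_grad=True)
y = dummy_autograd_func.apply(x)
y.sum().backward()  # no longer raises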

Thank you. It works now!