How can I make the gradient of a parameter a function itself?
import torch
def fun(q):
    def result(w):
        l = w * q
        l.backward()
        return w.grad
    return result
w = torch.tensor((2.), requires_grad=True)
q = torch.tensor((3.), requires_grad=True)
f = fun(q)
print(f(w))
In the code above, how can I make f(w) return a tensor that has a gradient with respect to q?
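One common approach is sketched below. It assumes that only the first-order gradient dl/dw needs to stay differentiable. With the default create_graph=False, l.backward() computes w.grad outside of autograd, so the stored gradient has no grad_fn and no link back to q. Replacing it with torch.autograd.grad(..., create_graph=True) records the gradient computation itself. The returned value (here equal to q, since dl/dw = q) can then be differentiated again with respect to q. The variable names g and dg_dq are only illustrative.

```python
import torch

def fun(q):
    def result(w):
        l = w * q
        # create_graph=True builds a graph for the gradient computation,
        # so the returned dl/dw stays connected to q.
        (g,) = torch.autograd.grad(l, w, create_graph=True)
        return g
    return result

w = torch.tensor(2., requires_grad=True)
q = torch.tensor(3., requires_grad=True)
f = fun(q)

g = f(w)  # dl/dw = q, so g has the value 3.0 and a grad_fn
print(g, g.requires_grad)

# Differentiate the gradient itself with respect to q: d(dl/dw)/dq = 1
(dg_dq,) = torch.autograd.grad(g, q)
print(dg_dq)
```

An alternative is l.backward(create_graph=True) followed by return w.grad. However, that accumulates into w.grad (and q.grad) on every call, so those would need to be reset between calls. torch.autograd.grad avoids this by returning the gradient directly instead of storing it on the tensors.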