Make parameter gradient a function of another parameter

How can I make the gradient of one parameter a function of another parameter?

import torch

def fun(q):
    def result(w):
        l = w * q
        # No backward pass has run yet, so w.grad is still None here
        return w.grad
    return result

w = torch.tensor(2.0, requires_grad=True)
q = torch.tensor(3.0, requires_grad=True)

f = fun(q)


In the code above, how can I make f(w) have a gradient with respect to q?


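As written, f(w) returns w.grad, which is None because no backward pass has run. If result returns l (that is, w * q) instead, then f(w).sum().backward() will populate the .grad field of both w and q.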
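A minimal sketch of the backward() approach, assuming result is changed to return the product w * q rather than w.grad (that change is an assumption, not part of the original code). Since d(w*q)/dw = q and d(w*q)/dq = w, the populated gradients should be 3.0 and 2.0:

```python
import torch

def fun(q):
    def result(w):
        # Assumed change: return the product itself (not w.grad),
        # so the output stays attached to the autograd graph through q.
        return w * q
    return result

w = torch.tensor(2.0, requires_grad=True)
q = torch.tensor(3.0, requires_grad=True)

f = fun(q)
f(w).sum().backward()  # accumulates into w.grad and q.grad

print(w.grad)  # tensor(3.)  since d(w*q)/dw = q
print(q.grad)  # tensor(2.)  since d(w*q)/dq = w
```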
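Alternatively, you can request that gradient directly with torch.autograd.grad(f(w).sum(), q), which returns it in a tuple without touching the .grad fields.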
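A sketch of the torch.autograd.grad route, again assuming result returns w * q. The second half goes one step further, in case the goal is literally to make w's gradient (which equals q here) a differentiable function of q: passing create_graph=True keeps the graph of the first derivative so it can be differentiated again. This is the standard higher-order-gradient pattern in PyTorch, not something stated in the answer above:

```python
import torch

def fun(q):
    def result(w):
        # Assumed: return the product so the output depends on both w and q.
        return w * q
    return result

w = torch.tensor(2.0, requires_grad=True)
q = torch.tensor(3.0, requires_grad=True)
f = fun(q)

# Gradient of the output w.r.t. q only; .grad fields are left untouched.
(dq,) = torch.autograd.grad(f(w).sum(), q)
print(dq)  # tensor(2.)

# Gradient of the output w.r.t. w, kept differentiable: dw is q, with a graph.
(dw,) = torch.autograd.grad(f(w), w, create_graph=True)

# Differentiate that gradient w.r.t. q: d(dw)/dq = d(q)/dq = 1.
(ddw_dq,) = torch.autograd.grad(dw, q)
print(dw.item(), ddw_dq.item())
```

Without create_graph=True, dw would be a plain tensor with no history, and the second torch.autograd.grad call would fail because dw does not require grad.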