How do I calculate the gradients of a non-leaf variable w.r.t. a loss function?

Hello @Abhai_Kollara

Update:
The proper solution is to use .retain_grad()

import torch

v = torch.autograd.Variable(torch.randn(3), requires_grad=True)
v2 = v + 1            # v2 is a non-leaf (it has a grad_fn)
v2.retain_grad()      # ask autograd to keep the gradient for this non-leaf
v2.sum().backward()
v2.grad               # tensor of ones: d(sum)/d(v2) is 1 per element

Apparently, wanting the gradient of a non-leaf variable is common enough that .retain_grad() exists for exactly this purpose.
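
For what it's worth, here is a sketch of how the same trick looks for an intermediate activation inside a small model (the toy model and names below are just for illustration, not from the original question):

import torch
import torch.nn as nn

# Purely illustrative two-layer model.
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))
x = torch.randn(2, 4)

h = model[0](x)       # non-leaf: output of the first Linear layer
h.retain_grad()       # keep its gradient around after backward()
out = model[2](model[1](h))
out.sum().backward()

print(h.grad.shape)   # torch.Size([2, 8]): gradient of the loss w.r.t. h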

This is what I had posted before I knew better:

How about using hooks, e.g.:

v = torch.autograd.Variable(torch.randn(3), requires_grad=True)

def require_nonleaf_grad(v):
    # register a hook that stashes the incoming gradient on the variable
    def hook(g):
        v.grad_nonleaf = g
    v.register_hook(hook)

v2 = v + 1
require_nonleaf_grad(v2)
v2.sum().backward()
v2.grad_nonleaf       # same tensor of ones as above

I don’t recommend calling the attribute .grad, to avoid colliding with PyTorch internals.
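
As a side note, if you only want to look at the gradient rather than keep it, the hook can do its work directly; a minimal sketch (mine, not from the original post):

v = torch.autograd.Variable(torch.randn(3), requires_grad=True)
v2 = v + 1
v2.register_hook(lambda g: print("grad of v2:", g))  # returning None leaves the gradient unchanged
v2.sum().backward()   # prints a tensor of ones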

Best regards

Thomas
