Thanks @smth
However, at the moment the only way I have been able to extract the gradient is through a global variable. This is because the function I pass to `register_hook` (apparently) accepts only one argument, and that argument is reserved for `yy.grad`. Here is what I mean:
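For context, here is a minimal sketch of the hook contract as I understand it from the PyTorch docs. The function given to `Tensor.register_hook` is called with exactly one argument, the gradient with respect to that tensor. It may return `None` to leave the gradient unchanged, or a new tensor to replace it. The names `double_grad` and `seen` are hypothetical and only for illustration. The sketch assumes a recent PyTorch where `requires_grad=True` can be passed directly to `torch.ones`.

```python
import torch

x = torch.ones(1, 1, requires_grad=True)
y = 3 * x
z = y ** 2

seen = []  # records what the hook receives (illustration only)

def double_grad(grad):
    # The hook receives a single argument: dz/dy for the tensor y.
    seen.append(grad.clone())
    # Returning a tensor replaces the gradient that flows back to x.
    return grad * 2

y.register_hook(double_grad)
z.backward()

print(seen[0].item())  # 6.0, since dz/dy = 2 * y = 2 * 3
print(x.grad.item())   # 36.0, since the doubled 12 is multiplied by dy/dx = 3
```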
```python
import torch
from torch.autograd import Variable

yGrad = torch.zeros(1, 1)

def extract(xVar):
    # Store the incoming gradient in a module-level variable.
    global yGrad
    yGrad = xVar

xx = Variable(torch.randn(1, 1), requires_grad=True)
yy = 3 * xx
zz = yy ** 2
yy.register_hook(extract)

#### Run the backprop:
print(yGrad)  # Shows 0.
zz.backward()
print(yGrad)  # Shows the correct dz/dy
```
So here I am able to extract `yy.grad`, BUT I can only do so with a global variable, which I would rather avoid. Is there a simpler way? Many thanks.
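One possible direction, sketched under the assumption of a recent PyTorch version: since the hook only needs to be a one-argument callable, extra state can be bound in advance instead of read from a global. The sketch below shows a closure, `functools.partial`, and `Tensor.retain_grad()`, which keeps `.grad` on a non-leaf tensor without any hook. The helper names `make_grad_saver`, `save_grad` and the `grads` dict are hypothetical, not part of any PyTorch API.

```python
import functools
import torch

def make_grad_saver(store, key):
    # Returns a one-argument hook that closes over `store`,
    # so no global variable is needed.
    def hook(grad):
        store[key] = grad
    return hook

def save_grad(store, key, grad):
    # Three-argument function; functools.partial binds the first two,
    # leaving the single `grad` argument that register_hook supplies.
    store[key] = grad

x = torch.randn(1, 1, requires_grad=True)
y = 3 * x
z = y ** 2

grads = {}
y.register_hook(make_grad_saver(grads, "closure"))
y.register_hook(functools.partial(save_grad, grads, "partial"))

# Alternative without hooks: ask autograd to keep the non-leaf gradient.
y.retain_grad()

z.backward()

print(grads["closure"])  # dz/dy, i.e. 2 * y
print(grads["partial"])  # same value via functools.partial
print(y.grad)            # same value, populated because of retain_grad()
```

The closure and `partial` variants keep the hook-based approach from the question while scoping the storage; `retain_grad()` is the simplest option when the goal is only to read the intermediate gradient.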