I’ve heard you can somehow use the register_backward_hook method, but I have not found an example of this working. I’ve attached an attempt below, but it seems that the backward hook function, which is supposed to block the gradient computation, is never called.

```python
import torch
import numpy as np
"""
We have the following equations:
x = 3
y = 4*x
loss = (y-10)**2
This should yield
d_loss_d_y == 2*(3*4-10) == 4
d_loss_d_x == d_loss_d_y*4 == 16
Now - we ONLY want to calculate d_loss_d_y. We want to block the gradient backpropagation
so that we don't bother calculating d_loss_d_x.
"""
x = torch.tensor([3.], dtype=torch.float64, requires_grad=True)

class MyOp(torch.nn.Module):
    # Note: overriding __call__ (instead of forward) bypasses nn.Module.__call__,
    # which is what runs registered hooks, so my_hook is never invoked.
    def __call__(self, x):
        return x * 4

op = MyOp()

def my_hook(mod, in_grad, out_grad):
    print('Backward Hook Called (THIS NEVER HAPPENS)')
    return None  # (Should stop further backpropagation?)

op.register_backward_hook(my_hook)
y = op(x)
intermediate_grads = {}
# Record d_loss/d_y when autograd passes through y
y.register_hook(lambda grad: intermediate_grads.setdefault('y', grad))
loss = ((y - 10.) ** 2).sum()
loss.backward()
assert np.array_equal(intermediate_grads['y'].numpy(), [4.])
# The following should raise some kind of exception, because d_loss_d_x shouldn't be computed.
# In practice x.grad is 16, so this assertion fails.
assert not np.array_equal(x.grad.numpy(), [16.])
```
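A likely reason the hook never fires is that `MyOp` overrides `__call__`, while hooks are run by `nn.Module.__call__`. A minimal sketch, assuming a recent PyTorch (1.8 or later, where `register_full_backward_hook` exists), that implements `forward` instead. The hook now runs, but returning `None` only means "leave `grad_input` unchanged", so backpropagation continues and `x.grad` is still 16. The `calls` list is just an illustrative recorder.

```python
import torch

calls = []  # illustrative: records (grad_input, grad_output) seen by the hook

class MyOp(torch.nn.Module):
    def forward(self, x):  # implement forward so nn.Module.__call__ runs the hooks
        return x * 4

def my_hook(module, grad_input, grad_output):
    calls.append((grad_input[0].clone(), grad_output[0].clone()))
    return None  # None = keep grad_input as-is; this does NOT stop backprop

op = MyOp()
op.register_full_backward_hook(my_hook)

x = torch.tensor([3.0], requires_grad=True)
y = op(x)
loss = ((y - 10.0) ** 2).sum()
loss.backward()

print(len(calls))  # the hook fired once
print(x.grad)      # gradient still reached x
```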

Ah, so it appears that you’re both kind of right. `torch.autograd.grad` does seem to calculate only the requested gradient (leaving the other variables’ `.grad` as None). But to avoid ever-growing computation with larger graphs, you still need `detach()`.
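For the toy example from the question, a minimal sketch of both options, assuming a recent PyTorch where plain tensors replace `Variable`: ask `torch.autograd.grad` for d_loss_d_y only, or `detach()` at `y` so that `x` is cut out of the graph entirely.

```python
import torch

x = torch.tensor([3.0], requires_grad=True)
y = x * 4
loss = ((y - 10.0) ** 2).sum()

# Option 1: request only d_loss/d_y; grad() does not accumulate into x.grad.
(d_loss_d_y,) = torch.autograd.grad(loss, (y,))
print(d_loss_d_y, x.grad)

# Option 2: cut the graph at y; y_cut is a new leaf, so x can never get a gradient.
y_cut = (x * 4).detach().requires_grad_()
loss2 = ((y_cut - 10.0) ** 2).sum()
loss2.backward()
print(y_cut.grad, x.grad)
```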

The following snippet demonstrates the need for detach:

```python
import time
import torch
from torch.autograd import grad

torch.manual_seed(1234)
h = torch.zeros(10, 100)
wt = torch.nn.Linear(100, 100)
last_time = time.time()
for i in range(1000):
    h = torch.tanh(wt(h.detach()))  # Maintains a constant speed
    # h = torch.tanh(wt(h))  # Gets slower and slower
    loss = (h ** 2).sum()
    (dl_dh,) = grad(loss, (h,))
    if (i + 1) % 100 == 0:
        this_time = time.time()
        print('Iterations {} to {}: {:.3g}s'.format(i + 1 - 100, i + 1, this_time - last_time))
        last_time = this_time
# Sanity check from the original run (the exact value depends on the seed and PyTorch version):
# assert 103.32 < dl_dh.abs().sum().item() < 103.33
```

When `h.detach()` is used, the time per 100 iterations stays roughly constant:

Iterations 0 to 100: 0.0367s
Iterations 100 to 200: 0.032s
Iterations 200 to 300: 0.0321s
Iterations 300 to 400: 0.032s
Iterations 400 to 500: 0.0319s
Iterations 500 to 600: 0.032s
Iterations 600 to 700: 0.0321s
Iterations 700 to 800: 0.0429s
Iterations 800 to 900: 0.0328s
Iterations 900 to 1000: 0.032s

Without it, every block of 100 iterations takes longer than the last, more than ten times longer by the end:

Iterations 0 to 100: 0.0736s
Iterations 100 to 200: 0.158s
Iterations 200 to 300: 0.279s
Iterations 300 to 400: 0.367s
Iterations 400 to 500: 0.37s
Iterations 500 to 600: 0.467s
Iterations 600 to 700: 0.581s
Iterations 700 to 800: 0.717s
Iterations 800 to 900: 0.74s
Iterations 900 to 1000: 0.793s

What I don’t know is what all this extra computation is doing. I guess it is just overhead from the ever-growing graph.
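One way to test the ever-growing-graph guess is to count the autograd nodes reachable from `loss` through `grad_fn.next_functions`. A minimal sketch with a shorter loop than the timing snippet; `count_graph_nodes` is a hypothetical helper, not a PyTorch API. With `detach()` the count stays fixed, while without it the count grows every iteration, which is consistent with per-step overhead that scales with graph size (the earlier, never-executed steps presumably also keep their saved tensors alive).

```python
import torch

def count_graph_nodes(t):
    """Count autograd nodes reachable from t.grad_fn (hypothetical helper)."""
    seen, stack = set(), [t.grad_fn]
    while stack:
        fn = stack.pop()
        if fn is None or fn in seen:  # None marks inputs that need no gradient
            continue
        seen.add(fn)
        stack.extend(next_fn for next_fn, _ in fn.next_functions)
    return len(seen)

torch.manual_seed(1234)
wt = torch.nn.Linear(100, 100)
results = {}

for use_detach in (True, False):
    h = torch.zeros(10, 100)
    sizes = []
    for i in range(5):
        h = torch.tanh(wt(h.detach() if use_detach else h))
        loss = (h ** 2).sum()
        sizes.append(count_graph_nodes(loss))
    results['detach' if use_detach else 'no detach'] = sizes

print(results)  # constant node counts with detach, growing counts without
```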