Ah, so it appears that you're both partly right. grad does seem to compute only the gradient that is asked for (leaving the other variables' .grad as None). But still, to avoid ever-growing computation as the graph keeps getting larger, you need detach.
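To illustrate the first point, here is a minimal sketch (separate from the snippet below; the variable names are just mine) showing that torch.autograd.grad returns only the gradient you ask for and does not populate any .grad attributes:

import torch
from torch.autograd import grad

x = torch.randn(3, requires_grad=True)
y = torch.randn(3, requires_grad=True)
loss = (x * y).sum()

# Ask only for d(loss)/dx; y's gradient is not requested.
(dl_dx,) = grad(loss, (x,))
print(dl_dx)             # equals y
print(x.grad, y.grad)    # both None: grad() does not write to .grad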
The following snippet demonstrates the need for detach:
import time
import torch
from torch.autograd import grad

torch.manual_seed(1234)
h = torch.zeros(10, 100)
wt = torch.nn.Linear(100, 100)
last_time = time.time()
for i in range(1000):
    h = torch.tanh(wt(h.detach()))  # Maintains a constant speed
    # h = torch.tanh(wt(h))         # Gets slower and slower
    loss = (h ** 2).sum()
    (dl_dh,) = grad(loss, (h,))
    if (i + 1) % 100 == 0:
        this_time = time.time()
        print('Iterations {} to {}: {:.3g}s'.format(i - 99, i + 1, this_time - last_time))
        last_time = this_time
# The exact value below depends on the seed (and possibly the PyTorch version).
assert 103.32 < dl_dh.abs().sum().item() < 103.33
When h.detach() is used, the time per block of 100 iterations stays roughly constant:
Iterations 0 to 100: 0.0367s
Iterations 100 to 200: 0.032s
Iterations 200 to 300: 0.0321s
Iterations 300 to 400: 0.032s
Iterations 400 to 500: 0.0319s
Iterations 500 to 600: 0.032s
Iterations 600 to 700: 0.0321s
Iterations 700 to 800: 0.0429s
Iterations 800 to 900: 0.0328s
Iterations 900 to 1000: 0.032s
Whereas when it is not used, it slows down:
Iterations 0 to 100: 0.0736s
Iterations 100 to 200: 0.158s
Iterations 200 to 300: 0.279s
Iterations 300 to 400: 0.367s
Iterations 400 to 500: 0.37s
Iterations 500 to 600: 0.467s
Iterations 600 to 700: 0.581s
Iterations 700 to 800: 0.717s
Iterations 800 to 900: 0.74s
Iterations 900 to 1000: 0.793s
What I don’t know is where exactly all that extra computation goes (my guess is just bookkeeping overhead from the ever-growing graph).
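As a rough way to see that the graph really does keep growing, here is a small sketch (count_nodes is a helper of my own, not part of the snippet above) that walks the grad_fn / next_functions chain and counts reachable nodes with and without detach:

import torch

def count_nodes(t):
    """Roughly count the autograd graph nodes reachable from tensor t."""
    seen, stack = set(), [t.grad_fn]
    while stack:
        fn = stack.pop()
        if fn is None or fn in seen:
            continue
        seen.add(fn)
        stack.extend(f for f, _ in fn.next_functions)
    return len(seen)

wt = torch.nn.Linear(100, 100)

h = torch.zeros(10, 100)
for i in range(10):
    h = torch.tanh(wt(h))           # no detach: history accumulates
print(count_nodes(h))               # grows with the number of iterations

h = torch.zeros(10, 100)
for i in range(10):
    h = torch.tanh(wt(h.detach()))  # detach: only the last step is kept
print(count_nodes(h))               # stays small and constant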