Disconnected Gradient in PyTorch

Ah, so it appears that you’re both partly right. grad does indeed compute only the requested gradients (leaving every other variable's .grad as None). But to avoid ever-growing computation as the graph gets larger, you still need detach.
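
As a quick standalone check of that first point (separate from the snippet below): torch.autograd.grad returns gradients only for the inputs you ask for, and it does not write anything into .grad:

import torch
from torch.autograd import grad

x = torch.randn(3, requires_grad=True)
w = torch.randn(3, requires_grad=True)
loss = (w * x).sum()

# Ask only for d(loss)/dx; nothing is computed or stored for w.
(dl_dx,) = grad(loss, (x,))
print(dl_dx)   # equals w
print(x.grad)  # None: grad() returns gradients rather than accumulating into .grad
print(w.grad)  # None: no gradient was requested for w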

The following snippet demonstrates the need for detach:

import time
import torch
from torch.autograd import grad

torch.manual_seed(1234)
h = torch.zeros(10, 100)  # Variable is deprecated; a plain tensor works with autograd

wt = torch.nn.Linear(100, 100)
last_time = time.time()

for i in range(1000):
    h = torch.tanh(wt(h.detach()))  # Maintains a constant speed
    # h = torch.tanh(wt(h))  # Gets slower and slower
    loss = (h**2).sum()
    (dl_dh, ) = grad(loss, (h, ))  # gradient of the loss w.r.t. the latest h only
    if (i + 1) % 100 == 0:
        this_time = time.time()
        print('Iterations {} to {}: {:.3g}s'.format(i + 1 - 100, i + 1, this_time - last_time))
        last_time = this_time

assert 103.32 < dl_dh.abs().sum().item() < 103.33

When h.detach() is used, the time per 100 iterations stays roughly constant:

Iterations 0 to 100: 0.0367s
Iterations 100 to 200: 0.032s
Iterations 200 to 300: 0.0321s
Iterations 300 to 400: 0.032s
Iterations 400 to 500: 0.0319s
Iterations 500 to 600: 0.032s
Iterations 600 to 700: 0.0321s
Iterations 700 to 800: 0.0429s
Iterations 800 to 900: 0.0328s
Iterations 900 to 1000: 0.032s

When it is not used, each block of iterations gets slower:

Iterations 0 to 100: 0.0736s
Iterations 100 to 200: 0.158s
Iterations 200 to 300: 0.279s
Iterations 300 to 400: 0.367s
Iterations 400 to 500: 0.37s
Iterations 500 to 600: 0.467s
Iterations 600 to 700: 0.581s
Iterations 700 to 800: 0.717s
Iterations 800 to 900: 0.74s
Iterations 900 to 1000: 0.793s

What I don’t know is what all this extra computation is doing (I guess it is just bookkeeping overhead from the ever-growing graph).
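
If you want to watch the graph grow, one rough way is to walk backwards from h.grad_fn through next_functions and count the nodes you reach. This is just a sketch: count_graph_nodes is an ad-hoc helper written here, not a PyTorch function, although grad_fn and next_functions are real autograd attributes:

import torch

def count_graph_nodes(tensor):
    # Walk the autograd graph reachable from tensor.grad_fn and count unique nodes.
    seen, stack = set(), [tensor.grad_fn]
    while stack:
        fn = stack.pop()
        if fn is None or fn in seen:
            continue
        seen.add(fn)
        for next_fn, _ in fn.next_functions:
            stack.append(next_fn)
    return len(seen)

torch.manual_seed(1234)
h = torch.zeros(10, 100)
wt = torch.nn.Linear(100, 100)
for i in range(5):
    h = torch.tanh(wt(h))           # no detach: h keeps a reference to every previous step
    print(i, count_graph_nodes(h))  # the node count grows every iteration

With the h.detach() version the count stays constant, which lines up with the constant per-iteration time above.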
