Hi all, I have a computational graph x -> y -> loss.

I want to compute d_loss/d_y, but NOT d_loss/d_x

How can I do this in PyTorch?

I’ve heard you can somehow use the register_backward_hook method, but I have not found an example of this working. I’ve attached an attempt below, but it seems that the backward hook function, which is supposed to block the gradient computation, is never called.

```python
import torch

"""
We have the following equations:

x = 3
y = 4*x
loss = (y-10)**2

This should yield
d_loss_d_y == 2*(3*4-10) == 4
d_loss_d_x == d_loss_d_y*4 == 16

Now - we ONLY want to calculate d_loss_d_y.  We want to block the gradient backpropagation
so that we don't bother calculating d_loss_d_x.

"""

class MyOp(torch.nn.Module):

    def __call__(self, x):
        return x*4

op = MyOp()

def my_hook(module, grad_input, grad_output):
    print('Backward Hook Called (THIS NEVER HAPPENS)')
    return None  # (Should stop further backpropagation?)

op.register_backward_hook(my_hook)
x = torch.autograd.Variable(torch.FloatTensor([3.0]), requires_grad=True)
y = op(x)

loss = ((y - 10.) ** 2).sum()
loss.backward()

# The following should raise some kind of exception, because d_loss_d_x shouldn't be computed
print(x.grad)
```
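A likely reason the hook in the attempt never fires: module hooks are dispatched from `torch.nn.Module.__call__`, so a subclass that overrides `__call__` instead of `forward` bypasses them. Below is a minimal sketch, assuming a recent PyTorch where `register_full_backward_hook` is available. The `calls` list is only there for illustration. This only shows the hook firing; returning `None` leaves the gradients unchanged, so it does not by itself skip the computation of d_loss/d_x.

```python
import torch

class MyOp(torch.nn.Module):
    # Define forward (not __call__) so that Module.__call__ can run the hooks.
    def forward(self, x):
        return x * 4

calls = []  # illustrative record of what the hook saw

def my_hook(module, grad_input, grad_output):
    # grad_output holds d_loss/d_y, grad_input holds d_loss/d_x.
    calls.append((grad_output[0].item(), grad_input[0].item()))
    return None  # None means "leave grad_input as it is"

op = MyOp()
op.register_full_backward_hook(my_hook)

x = torch.tensor([3.0], requires_grad=True)
y = op(x)
loss = ((y - 10.0) ** 2).sum()
loss.backward()

print(calls)   # one entry: (d_loss/d_y, d_loss/d_x)
print(x.grad)  # still populated, since the hook did not block anything
```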

Use `torch.autograd.grad` and don’t call `backward()` on `loss`:

```python
>>> x = torch.autograd.Variable(torch.FloatTensor([3.0]), requires_grad=True)
>>> y = 4*x
>>> loss = (y - 10)**2
>>> torch.autograd.grad(loss, y)
(Variable containing:
 4
[torch.FloatTensor of size 1]
,)
>>> x.grad is None
True
```
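For readers on a newer PyTorch, here is a sketch of the same idea written as a script, assuming `torch.tensor(..., requires_grad=True)` in place of the older `Variable` wrapper. It asks `torch.autograd.grad` for the gradient with respect to `y` only and then checks that `x.grad` was never populated.

```python
import torch

x = torch.tensor([3.0], requires_grad=True)
y = 4 * x
loss = ((y - 10) ** 2).sum()

# Request only d_loss/d_y; grad() returns a tuple with one entry per input.
(d_loss_d_y,) = torch.autograd.grad(loss, y)

print(d_loss_d_y)      # gradient with respect to y
print(x.grad is None)  # grad() does not accumulate into .grad
```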

Nice! never heard of that. Thanks

Can’t you simply use `x.detach()`?
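One way to read this suggestion, sketched under the assumption of a recent PyTorch: cut the graph at `y` with `detach()` and make the detached tensor a leaf that requires grad. Detaching `x` itself would leave `y` with no gradient history here, because the op has no parameters, so this sketch detaches `y` instead.

```python
import torch

x = torch.tensor([3.0], requires_grad=True)

# Detach y from x, then mark it as a leaf that should receive a gradient.
y = (4 * x).detach().requires_grad_()
loss = ((y - 10) ** 2).sum()
loss.backward()

print(y.grad)  # d_loss/d_y
print(x.grad)  # backpropagation stopped at y, so nothing reached x
```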


Ah, so it appears that you’re both kind of right. `grad` does seem to calculate only the requested gradient (leaving the other variables’ `.grad` as None). But to avoid ever-growing computation with larger graphs, you still need `detach`.

The following snippet demonstrates the need for detach:

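```python
import time
import torch
from torch.autograd import grad

torch.manual_seed(1234)

wt = torch.nn.Linear(100, 100)
h = torch.randn(1, 100)  # assumed initial state; the original initialization is missing from the post
last_time = time.time()

for i in range(1000):

    h = torch.tanh(wt(h.detach()))  # Maintains a constant speed
    # h = torch.tanh(wt(h))  # Gets slower and slower
    loss = (h**2).sum()
    (dl_dh, ) = grad(loss, (h, ))
    if (i+1)%100==0:
        this_time = time.time()
        print('Iterations {} to {}: {:.3g}s'.format(i+1-100, i+1, this_time - last_time))
        last_time = this_time

# The original run asserted that this sum lies between 103.32 and 103.33; the exact
# value depends on the PyTorch version and on the assumed initialization of h above.
print(dl_dh.abs().sum().item())
```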

When `h.detach()` is used, the time per 100 iterations stays roughly constant:

```
Iterations 0 to 100: 0.0367s
Iterations 100 to 200: 0.032s
Iterations 200 to 300: 0.0321s
Iterations 300 to 400: 0.032s
Iterations 400 to 500: 0.0319s
Iterations 500 to 600: 0.032s
Iterations 600 to 700: 0.0321s
Iterations 700 to 800: 0.0429s
Iterations 800 to 900: 0.0328s
Iterations 900 to 1000: 0.032s
```

Whereas when it is not used, each batch of 100 iterations takes longer than the last:

```
Iterations 0 to 100: 0.0736s
Iterations 100 to 200: 0.158s
Iterations 200 to 300: 0.279s
Iterations 300 to 400: 0.367s
Iterations 400 to 500: 0.37s
Iterations 500 to 600: 0.467s
Iterations 600 to 700: 0.581s
Iterations 700 to 800: 0.717s
Iterations 800 to 900: 0.74s
Iterations 900 to 1000: 0.793s
```

What I don’t know is what all this computation is doing (I guess just some overhead due to the ever-growing graph).
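One way to check the "ever-growing graph" guess is to count the autograd nodes reachable from `h.grad_fn` with and without `detach()`. This is only a sketch, assuming a recent PyTorch. `count_graph_nodes` and `run` are hypothetical helpers written for this illustration, and the initial `h` is an assumption. It shows that the graph grows without `detach()`, not where the time is spent inside autograd.

```python
import torch

def count_graph_nodes(tensor):
    """Count autograd nodes reachable from tensor.grad_fn (illustrative helper)."""
    seen = set()
    stack = [tensor.grad_fn]
    while stack:
        node = stack.pop()
        if node is None or node in seen:
            continue
        seen.add(node)
        # next_functions is a tuple of (node, input_index) pairs; node may be None.
        stack.extend(fn for fn, _ in node.next_functions)
    return len(seen)

def run(steps, use_detach):
    """Apply the recurrent update `steps` times and return the graph size behind h."""
    torch.manual_seed(0)
    wt = torch.nn.Linear(100, 100)
    h = torch.randn(1, 100)  # assumed initial state
    for _ in range(steps):
        h = torch.tanh(wt(h.detach() if use_detach else h))
    return count_graph_nodes(h)

nodes_detached = run(50, use_detach=True)
nodes_chained = run(50, use_detach=False)
print('with detach:', nodes_detached)
print('without detach:', nodes_chained)
```

With `detach()`, the node count should not depend on the number of steps, while without it every step adds nodes that stay reachable from the newest `h`.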
