Why can't I see .grad of an intermediate variable?

Hi Kalamaya,

By default, gradients are only retained for leaf variables. Non-leaf variables’ gradients are not kept around to be inspected later. This was done by design, to save memory.
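A minimal sketch of that default behavior, assuming a recent PyTorch (0.4 or later), where Variable has been merged into Tensor and requires_grad can be passed directly to the tensor constructor. Recent versions also emit a UserWarning when .grad of a non-leaf tensor is read.

```python
import torch

x = torch.randn(1, 1, requires_grad=True)  # leaf: created directly by the user
y = 3 * x                                   # non-leaf: result of an operation
z = y ** 2

z.backward()

print(x.is_leaf, y.is_leaf)  # True False
print(x.grad)                # populated: dz/dx = 18 * x
print(y.grad)                # None: the intermediate gradient was not retained
```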

However, you can inspect and extract the gradients of the intermediate variables via hooks.
You can register a function on a Variable that will be called when the backward of the variable is being processed.

More documentation on hooks is here: http://pytorch.org/docs/autograd.html#torch.autograd.Variable.register_hook

Here’s an example of calling the print function on the variable yy to print out its gradient. (You can also define your own function that, for example, copies the gradient elsewhere or modifies it.)

from __future__ import print_function
from torch.autograd import Variable
import torch

xx = Variable(torch.randn(1,1), requires_grad = True)
yy = 3*xx
zz = yy**2

yy.register_hook(print)
zz.backward()

Output:

Variable containing:
-3.2480
[torch.FloatTensor of size 1x1]
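To illustrate the "copy or modify" variants mentioned above, here is a small sketch, again assuming a recent PyTorch without Variable. If a hook returns a tensor, autograd uses it in place of the original gradient for the rest of the backward pass; if it returns None, the gradient is left unchanged. The names seen and double_grad are just for illustration.

```python
import torch

x = torch.randn(1, 1, requires_grad=True)
y = 3 * x
z = y ** 2

seen = []  # gradients observed at y, before modification

def double_grad(grad):
    seen.append(grad.clone())  # copy the incoming dz/dy elsewhere
    return grad * 2            # returned tensor replaces dz/dy downstream

y.register_hook(double_grad)
z.backward()

# dz/dy = 2 * y = 6 * x; the hook doubles it, then dy/dx = 3 is applied
print(seen[0])  # 6 * x
print(x.grad)   # 36 * x instead of the unhooked 18 * x
```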

Thanks @smth

The only way I have been able to actually extract the gradient, however, is via a global variable. This is because the function I pass in (apparently) only accepts one argument, and that is reserved for yy.grad. Here is what I mean:

import torch
from torch.autograd import Variable

yGrad = torch.zeros(1,1)
def extract(xVar):
    global yGrad
    yGrad = xVar

xx = Variable(torch.randn(1,1), requires_grad = True)
yy = 3*xx
zz = yy**2

yy.register_hook(extract)

#### Run the backprop:
print(yGrad) # Shows 0.
zz.backward()
print(yGrad) # Shows the correct dzdy

So here I am able to extract yy.grad, but only with a global variable, which I would rather avoid. Is there a simpler way? Many thanks.


Might help to take a look at how optimizers update parameters using the gradient. For instance, this line / block of code in SGD. https://github.com/pytorch/pytorch/blob/master/torch/optim/sgd.py#L45

@mrdrozdov I don’t think this applies to this use case, because optimizers always work with leaf Variables.

@Kalamaya Is there any reason why using a closure is not acceptable? If you can give me some more details about your use case, and why you need the intermediate gradient, I could probably suggest some other way.

@apaszke OK, I am not familiar with closures (still learning Python), but from the googling I just did, it sounds like they would be acceptable for my solution. How would we use closures in this case? The examples I saw all have nested functions, and I am still not seeing the connection… many thanks!!


You can think of a closure as a function that also keeps some additional variables from the outer scope. For example, here the inner function hook is a closure that remembers the name given to the outer function:

import torch
from torch.autograd import Variable

grads = {}
def save_grad(name):
    def hook(grad):
        grads[name] = grad
    return hook

x = Variable(torch.randn(1,1), requires_grad=True)
y = 3*x
z = y**2

# In here, save_grad('y') returns a hook (a function) that keeps 'y' as name
y.register_hook(save_grad('y'))
z.register_hook(save_grad('z'))
z.backward()

print(grads['y'])
print(grads['z'])

Many thanks! I will process this and let you know how it goes! :slight_smile:

Just for my own knowledge, am I to understand that, given what I am trying to do, the only ways we have are i) global variables, and ii) closures?

Thanks again.


I’d say that these are the most obvious ways, but you could probably come up with more sophisticated solutions too. As I said, the best one depends on the specific use case, and it’s hard to provide one that fits all. I find using closures like above to be OK; others will find something else better.
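As one example of such an alternative, here is a sketch of a small callable helper that owns the storage, so neither a global nor a hand-written nested function is needed. GradRecorder and hook_for are hypothetical names, and the code assumes a recent PyTorch without Variable. functools.partial binds the name, leaving the gradient as the single argument autograd passes to the hook.

```python
import functools
import torch

class GradRecorder:
    """Hypothetical helper: stores gradients keyed by a name."""

    def __init__(self):
        self.grads = {}

    def save(self, name, grad):
        self.grads[name] = grad  # returning None leaves the gradient unchanged

    def hook_for(self, name):
        # partial binds `name`; autograd supplies `grad` when the hook fires
        return functools.partial(self.save, name)

recorder = GradRecorder()

x = torch.randn(1, 1, requires_grad=True)
y = 3 * x
z = y ** 2

y.register_hook(recorder.hook_for("y"))
z.register_hook(recorder.hook_for("z"))
z.backward()

print(recorder.grads["y"])  # dz/dy = 2 * y
print(recorder.grads["z"])  # dz/dz = 1
```

The design choice is mostly about where the state lives: here it is an attribute of an object you can pass around, rather than a module-level global or a dict captured by a closure.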

Thanks again. I will process it tonight and reply here with my exact use case. Thanks again!


Aren’t the gradients of internal nodes necessary for doing backprop?


Yes, they are, but as soon as they have been used and are no longer needed, they are freed to save memory.

While I understand why this design decision was made, are there any plans to make it easier to save the gradients of intermediate variables? For example, it’d be nice if something like this was supported:

from torch.autograd import Variable
import torch

xx = Variable(torch.randn(1,1), requires_grad = True)
yy = 3*xx
yy.require_grad = True  # <-- Override default behavior
zz = yy**2

zz.backward()

# do something with yy.grad

It seems like it’d be easier to let variables keep track of their own gradients rather than having to keep track of them with my own closures. Then if I want to analyze the gradients of my variables (leaf or not), I can do something like:

do_something_with_data_and_grad_of(xx)
do_something_with_data_and_grad_of(yy)

Also, it might be useful to be able to set requires_grad on intermediate variables. For example, I might want to plot a histogram of an intermediate variable’s gradients without needing gradients for upstream variables. Right now, I’d have to set the requires_grad flag to True on upstream nodes just to make sure that the gradients for this intermediate node are computed, which seems a bit wasteful.
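For the last point, a sketch assuming a recent PyTorch (requires_grad_() is the in-place setter available since 0.4): when nothing upstream requires gradients, the intermediate result has no autograd history, so it is itself a leaf and tracking can be switched on at that point. Its .grad is then retained like any other leaf, and nothing upstream is tracked.

```python
import torch

x = torch.randn(1, 1)   # upstream tensor: no gradient needed
y = 3 * x               # no autograd history, so y is a leaf
y.requires_grad_()      # start tracking from y onward
z = y ** 2

z.backward()

print(y.is_leaf)  # True
print(y.grad)     # dz/dy = 2 * y
print(x.grad)     # None: nothing upstream of y was tracked
```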


Is it possible to get the gradients of a torch.nn.Linear module using the way you suggested or am I limited to capturing gradients by defining Variables? Would this work for convolutions or recurrent layers?
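A sketch of both cases, assuming a recent PyTorch; the layer sizes and the squared-sum loss are arbitrary choices for illustration. Module parameters such as weight and bias are leaves, so after backward their gradients are simply in .grad; the same holds for convolution and recurrent layers. For the gradient with respect to a layer's output (a non-leaf), a hook on the output tensor works as in the examples above. Recent versions also provide nn.Module.register_full_backward_hook for module-level hooks.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
layer = nn.Linear(4, 3)
inp = torch.randn(2, 4)

captured = {}

def save_out_grad(grad):
    captured["out"] = grad  # dLoss/dout for the non-leaf output

out = layer(inp)
out.register_hook(save_out_grad)
loss = out.pow(2).sum()
loss.backward()

print(layer.weight.grad.shape)  # torch.Size([3, 4]): parameters are leaves
print(captured["out"].shape)    # torch.Size([2, 3])
```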


Is it possible to create a (torch.autograd) flag in order to save the gradients of all variables?

Looks like PyTorch 0.2.0 now has Variable.retain_grad(): http://pytorch.org/docs/master/autograd.html?highlight=retain_grad#torch.autograd.Variable.retain_grad

The above could now be done via
yy.retain_grad()
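A minimal sketch of the original example rewritten with retain_grad(), assuming a recent PyTorch where it is called on the tensor directly (in 0.2 it was a method on Variable). It has to be called before backward() runs.

```python
import torch

xx = torch.randn(1, 1, requires_grad=True)
yy = 3 * xx
yy.retain_grad()   # keep the gradient of this non-leaf tensor after backward
zz = yy ** 2

zz.backward()

print(yy.grad)     # dz/dy = 2 * yy
```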


Hi @smth, thanks for your reply. I have another question: suppose there are two heads on top of yy. How can we get the grad_output from just one of them, instead of their sum?

For example, how to get yy’s grad_output from zz1 part?

import torch
from torch.autograd import Variable

xx = Variable(torch.randn(1,1), requires_grad = True)
yy = 3*xx
zz1 = yy**2
zz2 = yy**2

yy.register_hook(print)
(zz1+zz2).backward()

Ha, I recently did exactly this. Not sure if it's the best way, but I did:

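  1. Detach yy before feeding it to the zz heads, e.g. yyy = yy.detach().requires_grad_() (the detached copy must require grad so that autograd.grad can differentiate w.r.t. it).
  2. Manually call autograd.grad to get the gradient of each zz w.r.t. yyy.
  3. Save the one you want.
  4. Call yy.backward(grad_to_yyy_1 + grad_to_yyy_2).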
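The four steps can be sketched as follows, assuming a recent PyTorch. To make the two heads distinguishable, zz2 uses a cube here instead of the square in the question; that change is only for illustration. Each head hangs off its own leaf yyy, so each autograd.grad call frees only its own small subgraph, and the part before yy is traversed exactly once in step 4.

```python
import torch

xx = torch.randn(1, 1, requires_grad=True)
yy = 3 * xx

# 1. cut the graph at yy: yyy is a fresh leaf sharing yy's values
yyy = yy.detach().requires_grad_()
zz1 = yyy ** 2
zz2 = yyy ** 3   # different head (assumption) so the two gradients differ

# 2. gradient of each head w.r.t. yyy, computed separately
grad_to_yyy_1, = torch.autograd.grad(zz1, yyy)
grad_to_yyy_2, = torch.autograd.grad(zz2, yyy)

# 3. keep the one you want (here, the zz1 part)
saved = grad_to_yyy_1

# 4. push the combined gradient through the graph before yy, once
yy.backward(grad_to_yyy_1 + grad_to_yyy_2)

print(saved)    # 2 * yy
print(xx.grad)  # 3 * (2 * yy + 3 * yy**2)
```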

Great, thanks! That’s also what I had in mind: basically, we need a dummy variable.

BTW, is there something like nn.Identity() in torch? I didn’t find it

AFAIK, there is not. You can write one yourself, though it won’t help in this case. The important thing is to detach yy from the graph so that you don’t backpropagate twice through the part of the graph before it.

Interesting stuff! I have 3 related questions (my first ever post here, so please forgive my greenness and the lengthy post :slight_smile:).

I’m trying to implement GradNorm, a strategy to dynamically adapt the weights of the individual loss contributions in a multi-task learning scenario. So, similar to your example and to the post by @hubert0527, think of:

a parent module C, “ending” in a last common conv layer with weights/bias {W}, which feeds into the individual branch modules A and B, such that for each training iteration t

Loss_total(t) = w1(t) * L1(t) + w2(t) * L2(t)

where w1(t) and w2(t) are the dynamically adjusted weights of each single-task loss.

C is a CNN with skip connections (concatenation of feature maps) to module A, which is another CNN; module B is a simple classifier with 3 fully connected layers. I’ll attach a sketch further down.

What one needs to calculate primarily is the norm of the gradient of each of the individual single-task losses w1 * L1 and w2 * L2 with respect to the network parameters {W} of the last common layer (the {W} are leaf variables), which can be done by:

GW_1 = || torch.autograd.grad(w1*L1, W) || (Eq 1 / task 1 / branch A)
GW_2 = || torch.autograd.grad(w2*L2, W) || (Eq 2 / task 2 / branch B)

Those are then further processed to update w1(t+1) and w2(t+1), and then the “normal” backward pass all the way back through module C needs to be performed.
Note that calling Loss_total.backward() does NOT yield the two gradients separately, only their sum, so GW_1 and GW_2 cannot be recovered from it.
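A toy sketch of (Eq 1), (Eq 2) and the final backward, assuming a recent PyTorch. The modules shared, head_a and head_b are hypothetical stand-ins for the last common layer {W} and branches A and B, and w1, w2 are fixed numbers here rather than GradNorm's adaptive weights. The first call needs retain_graph=True because the shared part of the graph (shared layer and ReLU) is traversed again by the second call, and the second needs it because Loss_total.backward() traverses everything once more.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-ins: "shared" plays the last common layer {W},
# head_a / head_b play branch modules A and B.
shared = nn.Linear(8, 8)
head_a = nn.Linear(8, 1)
head_b = nn.Linear(8, 1)
W = shared.weight

x = torch.randn(16, 8)
feat = torch.relu(shared(x))
L1 = head_a(feat).pow(2).mean()
L2 = head_b(feat).pow(2).mean()
w1, w2 = 1.0, 0.5   # fixed here; GradNorm would adapt them each iteration

# Eq 1 and Eq 2: per-task gradients w.r.t. W, keeping the graph alive
g1, = torch.autograd.grad(w1 * L1, W, retain_graph=True)
g2, = torch.autograd.grad(w2 * L2, W, retain_graph=True)
GW_1, GW_2 = g1.norm(), g2.norm()

# The "normal" backward: W.grad receives g1 + g2, not the parts separately
Loss_total = w1 * L1 + w2 * L2
Loss_total.backward()

print(GW_1.item(), GW_2.item())
print(torch.allclose(W.grad, g1 + g2, atol=1e-6))
```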

Here are my questions:

  1. If I execute (Eq 1) WITHOUT setting retain_graph=True, executing (Eq 2) gives the well-known error for trying to backprop through the same (sub)graph twice, EVEN THOUGH the only thing those 2 subgraphs have in common is the parameters {W} at which they “end”, so to speak. Why is that?

  2. The whole thing works if I set retain_graph=True in both the (Eq 1) and (Eq 2) calls and then execute Loss_total.backward(), but the two retain_graph flags in the .grad calls use extra memory, which is kind of a waste.
    Why does the memory usage not peak at the end of the forward pass of the network? Aren’t all the buffers needed in the .grad or .backward call(s) allocated exactly then?

  3. Is there any way of doing this without using any retain_graph=True (to save memory) and without doing several forward passes to regenerate the graph (to save time)?
    Maybe similar to the test example of @blackyang and @SimonW, or the test example of @smth in this post, done in several stages with a dummy variable. The difference in my case is that I need the gradients of zz1 and zz2 w.r.t. the WEIGHTS of the final common layer (which would be the ‘3’ in yy = 3*xx of the test example), NOT w.r.t. the output yy of the final common layer. How could that be done?

Sketch: (image not available)

Long slide! Thanks to anyone who read this far. :wink:
