What determines if torch.autograd.grad output has requires_grad=True?

Hi,
I have a network whose loss depends on both the prediction and the gradients.
In my PyTorch 0.3 code I used to do:

grads = torch.autograd.grad(predV, feats, torch.ones_like(predV))
ipdb> grads[0].requires_grad
True
ipdb> predV.requires_grad
True
ipdb> feats[0].requires_grad
True
ipdb> torch.ones_like(predV).requires_grad
False

Notice how the output has requires_grad=True?

After some refactoring for 0.4 it suddenly has requires_grad=False, and now the network will not backprop the error of the gradients anymore.

The weird thing is that when I try to replicate it with a toy example in PyTorch 0.3, I cannot.

import torch
x = torch.autograd.Variable(torch.ones((3, 4)), requires_grad=True)
y = x+5
torch.autograd.grad(y, x, torch.ones_like(y))[0].requires_grad
# False

So I would like to know: what exactly makes the output of torch.autograd.grad require gradients?

Hi,

You need to set create_graph=True to get the graph constructed when using autograd.grad (and thus gradients that can themselves require gradients). See the doc for more details.

This behaved differently in 0.3.1, as can be seen in the old doc, where this flag was set to True if the gradients provided were not volatile. But volatile has been removed, hence this change.
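Here is a minimal sketch of the difference (my own toy example; I use y = x * x so that the gradient actually depends on x):

import torch

x = torch.ones(3, 4, requires_grad=True)

# Without create_graph, the returned gradient is detached from the graph.
y = x * x + 5
g_plain = torch.autograd.grad(y, x, torch.ones_like(y))[0]
print(g_plain.requires_grad)  # False

# With create_graph=True, autograd records the backward pass itself,
# so the returned gradient (2 * x here) can itself require grad.
y = x * x + 5
g_graph = torch.autograd.grad(y, x, torch.ones_like(y), create_graph=True)[0]
print(g_graph.requires_grad)  # True

# The gradient can now be used in a loss and backpropagated through.
g_graph.pow(2).mean().backward()
print(x.grad.shape)  # torch.Size([3, 4])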

I see, I had not seen this detail about volatile.

This might warrant a new issue, but since it’s the same problem of training with gradients in the loss… On the second batch that I train, I get:
RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.

Each batch loop now does roughly the following (cutting out details):

optimizer.zero_grad()
y = model(x)  # x requires grad
grads = torch.autograd.grad(y, x, torch.ones_like(y), create_graph=True)[0]  # to get the gradients
loss = (realgrad - grads).pow(2).mean()  # scalar loss on the gradient error
loss.backward()  # to backprop the errors; the second time this is called (second loop) it errors
optimizer.step()

I don’t understand how I could get an error on the second batch while it works fine on the first batch.

I would guess that you do operations that require gradients outside of your training loop. That part of the graph is common to every batch, so it will be fine for the first one, but the second one will try to use the same part of the graph again and will fail because you have already backpropagated through it.
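For illustration, a hypothetical minimal reproduction of that situation (the names here are my own, not from your code):

import torch

w = torch.randn(4, requires_grad=True)

# This computation depends on w but lives outside the batch loop,
# so its piece of the graph is shared by every iteration.
shared = w * w  # saves tensors needed for its backward

for step in range(2):
    loss = (shared * torch.randn(4)).sum()
    loss.backward()  # second iteration raises "Trying to backward through the graph a second time"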

So just to be sure, by “training loop” you mean everything between zero_grad and backward, right?

Yes, you are right, I preallocate some tensors before starting training (I really need to preallocate them for this application). What would be the best way of resetting them so that they can be backpropped through again?

In the past I just re-wrapped them in a Variable and that worked.

Yes, the training loop is the loop that iterates over the different batches.
These buffers should not require grad or perform operations linked to things that require grad. If they do and you don’t want those operations to be part of the backward, use the torch.no_grad() context manager.
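A minimal sketch of that pattern, assuming the buffers are refilled from new data each batch (the buffer name, the Linear model and the loss below are my own placeholders, not your code):

import torch

buf = torch.empty(8, 4)  # preallocated once, before the training loop
model = torch.nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

for step in range(3):
    # Fill the buffer without recording anything in the autograd graph.
    with torch.no_grad():
        buf.copy_(torch.randn(8, 4))

    x = buf.clone().requires_grad_(True)  # fresh leaf tensor every batch
    y = model(x)
    grads = torch.autograd.grad(y, x, torch.ones_like(y), create_graph=True)[0]
    loss = grads.pow(2).mean()  # placeholder loss on the gradients
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()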
