What's the purpose of "retain_variables" in the Variable backward function?

How should "retain_variables" in the Variable backward function be used?

I tried the following code:

import torch
from torch.autograd import Variable
x = Variable(torch.ones(2, 2), requires_grad=True)
y = x + 2
y.backward(torch.ones(2, 2), retain_variables=True)
print("first gradient of x is:")
print(x.grad)
z = y * y
gradient = torch.ones(2, 2)
z.backward(gradient)
print("second gradient of x is:")
print(x.grad)

And the same code with retain_variables=False:

import torch
from torch.autograd import Variable
x = Variable(torch.ones(2, 2), requires_grad=True)
y = x + 2
y.backward(torch.ones(2, 2), retain_variables=False)
print("first gradient of x is:")
print(x.grad)
z = y * y
gradient = torch.ones(2, 2)
z.backward(gradient)
print("second gradient of x is:")
print(x.grad)

Both print the same results:
first gradient of x is:
Variable containing:
1 1
1 1
[torch.FloatTensor of size 2x2]

second gradient of x is:
Variable containing:
7 7
7 7
[torch.FloatTensor of size 2x2]


Hi,

According to http://pytorch.org/docs/autograd.html#torch.autograd.Variable.backward, this flag prevents the buffers of the graph from being freed during the backward pass. By default they are freed once they have been used, to reduce memory requirements.
In practice, this means that calling y.backward twice on the same graph is only possible if the first call was made with retain_variables=True.
You can see this behaviour in the sample below by switching retain_variables in the first backward call between True and False:
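A likely reason the two snippets in the question print the same thing: freeing buffers only matters for operations that saved something for their backward pass. Addition does not need its inputs to compute its gradient, so y = x + 2 has nothing to free, whereas y = x ** 2 saves x because its gradient is 2 * x. The sketch below checks this. It assumes a recent PyTorch release (plain tensors with requires_grad instead of Variable), and second_backward_succeeds is a hypothetical helper name used only for illustration.

```python
import torch


def second_backward_succeeds(fn):
    """Return True if y = fn(x) can be backpropagated twice without retaining the graph.

    Hypothetical helper, written only for this illustration.
    """
    x = torch.ones(2, 2, requires_grad=True)
    y = fn(x)
    # First backward with the default settings: tensors saved for backward are freed.
    y.backward(torch.ones(2, 2))
    try:
        # Second backward through the same graph.
        y.backward(torch.ones(2, 2))
    except RuntimeError:
        # Raised when backward needs a saved tensor that was already freed.
        return False
    return True


print(second_backward_succeeds(lambda x: x + 2))   # True: addition saves no tensors
print(second_backward_succeeds(lambda x: x ** 2))  # False: pow saves x to compute 2 * x
```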

import torch
from torch.autograd import Variable
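x = Variable(torch.ones(2, 2), requires_grad=True)
y = x ** 2
# Set this to True and the second backward call below succeeds;
# with False it raises an error because the graph's buffers were freed.
y.backward(torch.ones(2, 2), retain_variables=False)
print("first backward of x is:")
print(x.grad)
y.backward(2 * torch.ones(2, 2), retain_variables=False)
print("second backward of x is:")
print(x.grad)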
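For reference, here is a minimal sketch of the same experiment on a recent PyTorch release. It assumes the flag is spelled retain_graph there (retain_variables was its earlier name) and uses plain tensors instead of Variable. The helper grad_after_two_backwards is hypothetical, written only to run both settings side by side.

```python
import torch


def grad_after_two_backwards(retain_graph):
    """Return x.grad after two backward passes through y = x ** 2, or None if the second fails.

    Hypothetical helper, written only for this illustration.
    """
    x = torch.ones(2, 2, requires_grad=True)
    y = x ** 2  # pow saves x for backward, since dy/dx = 2 * x
    # retain_graph=True keeps the saved x alive for another backward pass.
    y.backward(torch.ones(2, 2), retain_graph=retain_graph)
    try:
        y.backward(2 * torch.ones(2, 2))
    except RuntimeError:
        return None  # the saved x was freed by the first call
    # Gradients accumulate in x.grad: 2*x*1 + 2*x*2, which is 6 where x == 1.
    return x.grad


print(grad_after_two_backwards(retain_graph=False))  # None
print(grad_after_two_backwards(retain_graph=True))   # a 2x2 tensor filled with 6.0
```

With retain_graph=True the two calls accumulate into x.grad: 2 * x from the first call plus 2 * x * 2 from the second, i.e. 6 everywhere since x is all ones.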

I see. Thanks very much.