Normally when we’re doing backprop we would do the following:
loss.backward() # This calculates the gradients
optimizer.step() # This updates the net
However, what if I want to accumulate the gradients? Meaning I want to call loss.backward() multiple times and accumulate the gradients before applying them in optimizer.step().
What do you recommend for storing the gradients? grads = loss.backward()?
How do I feed the accumulated gradients to optimizer.step()?
When you perform loss.backward(), the gradients are accumulated in place in each Variable that requires gradient.
That is why you need to perform optimizer.zero_grad() before each backward.
If you want to accumulate gradients from multiple backwards, you can just backward multiple times without resetting the gradients:
optimizer.zero_grad()
for i in range(minibatch):
    loss = model(batch_data[i])
    loss.backward()
optimizer.step()
Note that loss functions average over the batch size, so if you do multiple backward passes you might need to divide by the number of for-loop iterations.
Yes, it’s always going to be slower, but it’s a tradeoff between performance and memory usage. Try to do as few iterations as you can (you can split each batch into smaller sub-batches, so that they nearly fill up the memory).
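As a rough illustration of that sub-batch idea (the names model, optimizer, criterion, batch_data, batch_target and the chunk count are placeholders, not from the thread):

accum_steps = 4  # how many sub-batches one real batch is split into

optimizer.zero_grad()
for sub_data, sub_target in zip(batch_data.chunk(accum_steps),
                                batch_target.chunk(accum_steps)):
    loss = criterion(model(sub_data), sub_target)
    # divide so the accumulated gradient matches the full-batch average
    (loss / accum_steps).backward()
optimizer.step()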
@apaszke, @albanD, I also tried to achieve this. As you have said, doing backward() for each sample is slow compared to accumulating the loss, averaging once, and then doing a single backward. Here is my code:
num_epoch = 10
real_batchsize = 100  # I want to update the weights every `real_batchsize` batches

for epoch in range(num_epoch):
    total_loss = Variable(torch.zeros(1).cuda(), requires_grad=True)
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = Variable(data.cuda()), Variable(target.cuda())
        output = net(data)
        loss = criterion(output, target)  # criterion is my loss function
        total_loss = total_loss + loss
        if batch_idx % real_batchsize == 0:
            ave_loss = total_loss / real_batchsize
            ave_loss.backward()
            optimizer.step()
            total_loss.data.zero_()
            optimizer.zero_grad()
The above code will produce an error:
RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.
I have looked up this issue, but I am still not clear about it. It just feels like we need a new total_loss after each weight update, so I replace the line
Yes, you can add a Python number to a Variable (or a Tensor), and the output is going to be a Variable (or a Tensor) of the same type as the input (so if it was on GPU, the output will be on GPU).
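A tiny sketch of that point, just for illustration:

import torch

t = torch.ones(1)             # works the same if the tensor lives on the GPU
out = t + 3                   # adding a plain Python number
print(type(out), out.device)  # still a Tensor, still on the same device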
Doing this is not the same as @albanD's first answer.
When you call ave_loss.backward() you propagate errors with respect to your (correct) loss, but those errors are functions of what the activations were when .backward() was called. Since you've thrown away all but the last 10 samples, you are assuming that the first 90 samples were the same as the last 10.
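For reference, here is a rough sketch of how the pattern from the first answer could be applied to that training loop (assuming the same net, optimizer and train_loader, and some loss function criterion; real_batchsize is the number of mini-batches per update):

optimizer.zero_grad()
for batch_idx, (data, target) in enumerate(train_loader):
    data, target = data.cuda(), target.cuda()
    output = net(data)
    loss = criterion(output, target)
    # scale so the accumulated gradient is the average over real_batchsize batches
    (loss / real_batchsize).backward()  # frees this batch's graph immediately
    if (batch_idx + 1) % real_batchsize == 0:
        optimizer.step()
        optimizer.zero_grad()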
Regarding this tradeoff, do you save time/memory by using retain_graph=True in that situation?
For example, my current code looks like this:
x = tensor(x0, requires_grad=True)
loss = 0
for i in range(inputs.numel()):  # For my apps, it's between 5 and 50.
    rec = f(x, i)
    loss += loss_func(inputs[i], rec)
loss.backward()
g = x.grad
My current problem is that the computational graph takes too much memory because the function f does a lot of computation. So a solution would be to do as @albanD suggested:
x = tensor(x0, requires_grad=True)
for i in range(inputs.numel()):
    rec = f(x, i)
    loss = loss_func(inputs[i], rec)
    loss.backward()  # gradients accumulate in x.grad, this iteration's graph is freed
g = x.grad
But I feel like the computational graph for each iteration of that loop is the same; it's just the numbers on which we apply it that change. So maybe we could reuse the previous iteration's graph (by specifying retain_graph=True), could that save some time? If not, what would happen (in terms of time/memory loss/gain)?
x = tensor(x0, requires_grad=True)
for i in range(inputs.numel()):
    rec = f(x, i)
    loss = loss_func(inputs[i], rec)
    loss.backward(retain_graph=True)
g = x.grad
This is the expected use case: the graph structure is mostly the same, only the values change. The whole framework is built to make this use case efficient.
And you cannot reuse the graph: it is associated with the values of each Tensor, so if the values change, you need to recreate it (which is cheap).
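A toy sketch of what that means in practice (not from the thread): every forward pass builds a fresh graph, so per-iteration backward needs no retain_graph:

import torch

x = torch.randn(3, requires_grad=True)
w = torch.randn(3, requires_grad=True)

for step in range(3):
    loss = (w * x).sum()  # a new graph is built by this forward pass
    print(loss.grad_fn)   # a different SumBackward0 node every iteration
    loss.backward()       # fine without retain_graph, the graph is fresh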
I have a question about the params in optimizer.state_dict() and the weights in model.state_dict().
From the source code, I found that the grad is computed during backward() and the weights are updated during the optimizer.step().
I tried printing model.state_dict() and optimizer.state_dict() before and after backward() and optimizer.step() respectively.
If I save a weight with
state_dict = model.state_dict()[key]
it is the same before and after backward(). Does that mean that assigning a tensor shares the underlying memory?
Another question: what do the params in optimizer.state_dict() mean? They do not change before or after backward() and optimizer.step(); are they just references (addresses) to the weights?
.backward() will just populate the .grad fields of the parameters. These gradients are not saved in the state dicts and so nothing will change there.
After opt.step() the values of the parameters are changed in place. So if you want to see the difference before and after, you need to clone the original Tensor.
optimizer.state_dict() is dependent on the optimizer itself. It will contain whatever is needed for this optimizer to continue working as if it had not been stopped (things like momentum terms or statistics).
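A small sketch of the clone-and-compare idea from above (model, criterion, data, target are placeholders):

# snapshot the parameters before the update (clone, don't just keep a reference)
before = {k: v.clone() for k, v in model.state_dict().items()}

optimizer.zero_grad()
loss = criterion(model(data), target)
loss.backward()    # only fills the .grad fields, state_dict values are untouched
optimizer.step()   # updates the parameters in place

after = model.state_dict()
for k in before:
    print(k, (after[k] - before[k]).abs().max())  # now the change is visible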
Oh I see, I guess I was confused by the name retain_graph. I've searched a bit and saw that it was called retain_variables before. So I guess if I use retain_graph=True while putting loss.backward() inside the for loop, it defeats the purpose of saving memory because it will keep in memory the temporary tensors needed for the previous gradient, right?
Sorry for replying to such an old post; I encountered a problem recently:
In PyTorch distributed code, how can I keep the gradient graph (for autograd.grad) while using dist.all_reduce() or dist.all_gather()? I want to avoid having to manually calculate the gradient and then call backward.