How can PyTorch calculate a total loss like TensorFlow?

In TensorFlow, we can accumulate losses and minimize the total like this:

loss = criterion(output, target)
tf.add_loss(loss)
total_loss = tf.losses.get_total_loss()
optimizer = tf.train.MomentumOptimizer(learning_rate=init_lr, momentum=0.9).minimize(total_loss, global_step=global_step)

But in PyTorch, I can't find any methods like add_loss or get_total_loss, so I tried to use a custom loss function like this:

class Loss_func(nn.Module):
    def __init__(self):
        super(Loss_func, self).__init__()
        self.totalLoss = 0

    def forward(self, output, target):
        # criterion is defined outside the class
        temp_loss = criterion(output, target)
        # accumulate the loss across forward calls
        self.totalLoss = self.totalLoss + temp_loss
        return self.totalLoss

I then got this error:

RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.

I changed the backward call like this:

loss.backward(retain_graph=True)

A new error occurred:

Warning: Error detected in AddmmBackward. Traceback of forward call that caused the error:
  File "/vol/research/Jigsaw/jigsaw_torch.py", line 231, in train
    features, outputs = model(batchImg.cuda())
  File "/lib/python3.8/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/vol/research/Jigsaw/jigsaw_torch.py", line 86, in forward
    x = self.fc(x)
  File "/lib/python3.8/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/lib/python3.8/site-packages/torch/nn/modules/linear.py", line 93, in forward
    return F.linear(input, self.weight, self.bias)
  File "/lib/python3.8/site-packages/torch/nn/functional.py", line 1690, in linear
    ret = torch.addmm(bias, input, weight.t())
 (function _print_stack)
Traceback (most recent call last):
  File "/vol/research/Jigsaw/jigsaw_torch.py", line 439, in <module>
    main()
  File "/vol/research/Jigsaw/jigsaw_torch.py", line 435, in main
    train(trainSet, testSet, tripletSet, model, opt, lossFunc, epoch)
  File "/vol/research/sketch/Jigsaw/jigsaw_torch.py", line 234, in trainAndValidate
    loss.backward(retain_graph=True)
  File "/lib/python3.8/site-packages/torch/tensor.py", line 221, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph)
  File "/lib/python3.8/site-packages/torch/autograd/__init__.py", line 130, in backward
    Variable._execution_engine.run_backward(
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [1024, 81]], which is output 0 of TBackward, is at version 2; expected version 1 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!

The model is GoogLeNet. Does this error mean that the self.fc = nn.Linear(1024, num_classes) layer in GoogLeNet doesn't support calling backward twice?

How can I call backward on a total loss?

Hi,

If we assume tf.losses.get_total_loss just computes the sum of multiple loss terms, then there are a couple of scenarios to consider.

But the error you are getting comes from calling loss.backward() multiple times on the same graph.

Let's say you have the following situation (scenario #1):

input = ...          # some input tensor
ground_truth = ...   # its target
output = model(input)
loss1 = criterion(output, ground_truth)
loss2 = criterion(changed_output, another_ground_truth)  # some other output/target pair
# etc.
total_loss = loss1 + loss2 + ...
total_loss.backward()  # this works

Now scenario #2:

input = ...          # some input tensor
ground_truth = ...   # its target
output = model(input)
loss1 = criterion(output, ground_truth)
loss1.backward()  # now the graph buffers have been freed; you cannot backward through the same forward pass again
loss2 = criterion(changed_output, another_ground_truth)
loss2.backward()  # error: trying to backward through the graph a second time
# etc.
total_loss = loss1 + loss2 + ...
total_loss.backward()  # this obviously does not work

You have not shown how you call loss.backward or how many times you call it, but if you pass retain_graph=True to the earlier backward calls, such as those for loss1 and loss2 in scenario #2, then you should be able to do another backward through total_loss.
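
For concreteness, here is a minimal, self-contained sketch of that retain_graph pattern; the linear model, the shapes, and the output * 2 stand-in for changed_output are arbitrary placeholders:

import torch
import torch.nn as nn

model = nn.Linear(4, 1)            # placeholder model
criterion = nn.MSELoss()

x = torch.randn(8, 4)
y = torch.randn(8, 1)

output = model(x)
loss1 = criterion(output, y)
loss2 = criterion(output * 2, y)   # stand-in for "changed_output"

# keep the graph alive so later backward calls through the same forward pass do not fail
loss1.backward(retain_graph=True)
loss2.backward(retain_graph=True)

total_loss = loss1 + loss2
total_loss.backward()              # works because the graph was retained

# note that the gradients from all three backward calls accumulate in the parameters' .grad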

The TF docs are not clear, but I think they are using the first scenario.
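
If you actually want the tf.add_loss / tf.losses.get_total_loss style of bookkeeping, a hand-rolled list is enough. This is only a sketch under that assumption; add_loss and get_total_loss here are my own helpers, not a PyTorch API:

import torch
import torch.nn as nn

_losses = []

def add_loss(loss):
    _losses.append(loss)

def get_total_loss():
    total = sum(_losses)
    _losses.clear()   # do not carry losses (and their graphs) into the next step
    return total

model = nn.Linear(4, 1)   # placeholder model
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

x, y = torch.randn(8, 4), torch.randn(8, 1)
output = model(x)

add_loss(criterion(output, y))
add_loss(0.1 * output.pow(2).mean())   # e.g. an extra penalty term

optimizer.zero_grad()
total_loss = get_total_loss()
total_loss.backward()   # one backward over the summed graph, as in scenario #1
optimizer.step()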

Bests

I have a question.

input1 = ...   # some input
input2 = ...   # some other input
label1 = ...   # target for input1
label2 = ...   # target for input2
output1 = model(input1)
output2 = model(input2)
loss1 = criterion(output1, label1)
loss2 = criterion(output2, label2)
total_loss = loss1 + loss2
total_loss.backward()

Does the total loss consider the effect of loss1? I mean, will

output2 = model(input2)

overwrite the gradients or any information about input1?

The short answer is yes, the total loss considers the effect of loss1; the second forward pass does not overwrite it. Autograd will accumulate gradients unless the user explicitly zeroes them.
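
A minimal sketch of what "accumulate" means here (the model, shapes, and data are arbitrary):

import torch
import torch.nn as nn

model = nn.Linear(3, 1)
criterion = nn.MSELoss()

x1, y1 = torch.randn(5, 3), torch.randn(5, 1)
x2, y2 = torch.randn(5, 3), torch.randn(5, 1)

loss1 = criterion(model(x1), y1)
loss2 = criterion(model(x2), y2)   # second forward pass; it does not discard loss1's graph

total_loss = loss1 + loss2
total_loss.backward()              # .grad now holds d(loss1)/dw + d(loss2)/dw for each parameter

print(model.weight.grad)

# zero the gradients before the next iteration, otherwise the next backward() adds on top of them
model.zero_grad()                  # or optimizer.zero_grad()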

I think the tutorial "A Gentle Introduction to torch.autograd" (PyTorch Tutorials 1.8.1+cu102 documentation) is a good demonstration. You can think of your problem as a simple equation c = a + b, where a and b each have their own computational graph (i.e. operations).
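
As a tiny illustration of that c = a + b picture (the values are arbitrary):

import torch

a = torch.tensor([2.0, 3.0], requires_grad=True)
b = torch.tensor([6.0, 4.0], requires_grad=True)

c = (a + b).sum()   # c depends on both a and b through their own graph paths
c.backward()

print(a.grad)   # tensor([1., 1.]) -- the contribution from a is kept
print(b.grad)   # tensor([1., 1.]) -- and so is the one from b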


Thanks for your help~ :+1: