In my setup, I have two networks, and I am trying to call backward() on the loss of each:
my_losses = [net1.get_loss(), net2.get_loss()]
for loss in my_losses:
    loss.backward()
When I run this, the first iteration of the for loop works fine, but the second call to backward() throws the following error:
RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.
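The message suggests passing retain_graph=True on the first call. My reading of that suggestion would look roughly like the snippet below (this is just how I interpret the message, not something I have confirmed is the right approach; net1 and net2 are the two networks from my setup above):

my_losses = [net1.get_loss(), net2.get_loss()]
for i, loss in enumerate(my_losses):
    # retain the graph on every backward call except the last one
    loss.backward(retain_graph=(i < len(my_losses) - 1))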
If I wrap the loss in a torch.autograd.Variable and set requires_grad=True, this works. But I am not sure whether that is actually doing the right thing, and I also do not understand why I am seeing the error in the first place.