Multi-loss backprop for multiple actions

I’m writing a reinforcement learning program, and I’m implementing multiple action heads sharing a common network.


I want to propagate loss1, loss2, and loss3 back to the output of the shared network, and then MSE(loss1, loss2, loss3) from the output of the shared network all the way back. The shared network and each of the action-head networks have their own network classes and their own torch.optim optimizers.

I’ve tried

loss1 = loss_fn(x1, y1)
loss1.backward()

loss2 = loss_fn(x2, y2)
loss2.backward()

loss3 = loss_fn(x3, y3)
loss3.backward()

But I’m getting this error during the second backward() call

“Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.”

How can I do this? What am I not understanding about how PyTorch performs backprop? Thanks.

You can combine the losses from all heads and backprop the result using a single optimizer that is initialized with all the parameters in your model (the shared network and each of the action heads). Read more about backpropagating loss in multi-task learning architectures.

So your code might look like

optim = optimizer(full_model.parameters(), ...)
combined_loss = weighted_sum(loss1, ..., lossN)
combined_loss.backward()
optim.step()

Action_head1 will be updated by gradients w.r.t. loss1 only (and so on), and the shared network will be updated by gradients w.r.t. all of the losses.
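As a minimal runnable sketch of that pattern (the module sizes, head names, and loss weights here are all made up for illustration, not taken from your code):

```python
import torch
import torch.nn as nn

# Hypothetical shared trunk and two action heads (names and sizes are illustrative).
shared = nn.Linear(4, 8)
head1 = nn.Linear(8, 2)
head2 = nn.Linear(8, 3)

# One optimizer over ALL parameters: the shared trunk plus every head.
params = list(shared.parameters()) + list(head1.parameters()) + list(head2.parameters())
optim = torch.optim.Adam(params, lr=1e-3)

x = torch.randn(16, 4)
y1 = torch.randn(16, 2)
y2 = torch.randn(16, 3)

features = shared(x)
loss1 = nn.functional.mse_loss(head1(features), y1)
loss2 = nn.functional.mse_loss(head2(features), y2)

# The weighted sum builds ONE graph rooted at combined_loss, so a single
# backward() call sends gradients through both heads and the shared trunk.
combined_loss = 1.0 * loss1 + 1.0 * loss2
optim.zero_grad()
combined_loss.backward()
optim.step()
```

Because only head1’s parameters sit on the path from loss1, head1’s gradients come from loss1 alone, while the shared trunk accumulates gradients from both losses.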

Thank you sir. I had tried that before, but I couldn’t make my agent learn properly, so that’s why I was trying something different. Is there a good reference that explains how .backward() and .step() actually do their magic? This is the most difficult aspect of PyTorch for me to understand. Thanks.

You’re welcome! It can be a bit tricky to wrap your head around what’s happening under the hood at first, but very simply:

  1. autograd records a graph of all your model’s operations
  2. when you call backward(), it traverses this graph and computes gradients (using backpropagation). once it reaches the end of the graph, the graph is freed from memory; that’s why you get the error on the second backward() call
  3. when you call step(), every parameter is updated using the gradients accumulated during the backward() pass.
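Point 2 can be seen in a tiny standalone example (toy tensors, not your model): once backward() frees the graph, a second call over the same graph raises exactly the error from the question, unless retain_graph=True was passed on the first call.

```python
import torch

# A leaf tensor whose operations autograd will record.
x = torch.tensor([2.0], requires_grad=True)

y = (x ** 2).sum()
y.backward(retain_graph=True)  # keep the graph alive for another pass
y.backward()                   # works, because the graph was retained

z = (x ** 3).sum()
z.backward()  # the graph for z is freed after this call
second_pass_failed = False
try:
    z.backward()  # backward through an already-freed graph
except RuntimeError:
    second_pass_failed = True

print(second_pass_failed)  # True: this is the error from the question
```

Note also that gradients accumulate in x.grad across backward() calls; that is why training loops call zero_grad() before each backward().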

More detail on how autograd works