How to combine multiple criterions into a loss function?

Yeah, you can optimize a variable number of losses without a problem. I’ve had an application like that, and I just used total_loss = sum(losses), where losses is a list of all the individual costs; calling .backward() on that is enough. Note that you can’t pass a list to torch.sum - it’s a method for Tensors. As I pointed out above, you can use the Python builtin sum (it just calls the + operator on all the elements, effectively adding up all the losses into a single one).
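A minimal sketch of that pattern (the model and the two criterions are just placeholders for illustration):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(10, 2)
mse = nn.MSELoss()
l1 = nn.L1Loss()

x = torch.randn(4, 10)
target = torch.randn(4, 2)

out = model(x)
# Collect each criterion's loss in a plain Python list...
losses = [mse(out, target), l1(out, target)]
# ...and combine them with the builtin sum(), which chains the
# + operator over the tensors (torch.sum would not accept a list).
total_loss = sum(losses)
total_loss.backward()  # gradients flow through every term
```

Since autograd tracks the additions, the gradient of total_loss is just the sum of the gradients of the individual losses.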

Iterating over a batch and summing up the losses works too, but it’s unnecessary, since the criterions already support batched inputs. It probably doesn’t work for you because criterions average the loss over the batch instead of summing it (this can be changed via the size_average constructor argument; in current PyTorch versions, the reduction argument). So if you have very large batches and sum instead, your gradients get multiplied by the batch size and will likely blow up your network. Just divide the loss by the batch size and it should be fine.
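To make the averaging behavior concrete, here is a small sketch comparing the criterion applied to a whole batch against a manual per-sample loop; the shapes and the MSELoss choice are just for illustration:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
criterion = nn.MSELoss()  # averages over the batch by default
pred = torch.randn(8, 3)
target = torch.randn(8, 3)
batch_size = pred.shape[0]

# One call on the batched inputs: a single averaged scalar.
batched_loss = criterion(pred, target)

# Summing per-sample losses grows with the batch size, so divide
# by it to recover the same averaged value.
per_sample = [criterion(pred[i], target[i]) for i in range(batch_size)]
manual_loss = sum(per_sample) / batch_size
```

After the division, manual_loss matches batched_loss, so the gradients no longer scale with the batch size.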
