Epoch becomes much slower after adding a third term to the loss function

I have a network that computes a loss on each of three output streams, and I add the three losses together to get the final loss, like this:

loss = nn.BCEWithLogitsLoss(reduction='sum')
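
A minimal sketch of this kind of summed three-term loss is shown below, for reference. The tensor names (`outputs1` to `outputs3`, `labels1` to `labels3`), their shapes, and the random data are hypothetical stand-ins, not the original network.

```python
import torch
import torch.nn as nn

# One criterion object can be reused for every stream; reduction='sum'
# adds up the per-element losses instead of averaging them.
criterion = nn.BCEWithLogitsLoss(reduction='sum')

# Hypothetical stand-ins for the three output streams (logits) and their 0/1 targets.
torch.manual_seed(0)
outputs1, outputs2, outputs3 = (torch.randn(8, 4, requires_grad=True) for _ in range(3))
labels1, labels2, labels3 = (torch.randint(0, 2, (8, 4)) for _ in range(3))

# BCEWithLogitsLoss expects floating-point targets, hence the .float() casts.
loss1 = criterion(outputs1, labels1.float())
loss2 = criterion(outputs2, labels2.float())
loss3 = criterion(outputs3, labels3.float())

# The final loss is the plain sum of the three terms; a single backward()
# propagates gradients into all three streams.
loss = loss1 + loss2 + loss3
loss.backward()
```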

With only two of the loss terms, each epoch takes close to 5 minutes. As soon as I add the third term, an epoch takes about 10 minutes. To use two terms I didn't change anything else; I just commented out the third term.
Any idea why this is happening?

Is the size of outputs3 significantly different from the size of outputs1 and outputs2? Also, are you using the CPU or a GPU for these computations? (And is there any reason you’re passing in labels1.float() instead of just labels1?)
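
One way to check the size and device questions is to time each loss term in isolation. The sketch below uses a hypothetical helper, `profile_loss_terms` (not a PyTorch API), that detaches each stream's logits and reports the shape, device, and forward+backward time of that term alone. Because the logits are detached, it measures only the loss computation itself, not the upstream layers of the network. On a GPU it calls `torch.cuda.synchronize()` so the timings are not distorted by asynchronous kernel launches. The tensor shapes in the usage lines are made up for illustration.

```python
import time

import torch
import torch.nn as nn


def profile_loss_terms(criterion, pairs):
    """Time each loss term on its own (hypothetical helper, not a PyTorch API).

    `pairs` maps a stream name to a (logits, targets) tuple.
    """
    report = {}
    for name, (out, lab) in pairs.items():
        # Detach so only this loss term is measured, not the network that produced it.
        out = out.detach().requires_grad_()
        target = lab.float()  # BCEWithLogitsLoss needs floating-point targets
        if out.is_cuda:
            torch.cuda.synchronize()  # finish any pending GPU work first
        start = time.perf_counter()
        term = criterion(out, target)
        term.backward()
        if out.is_cuda:
            torch.cuda.synchronize()  # wait for this term's kernels to finish
        report[name] = {
            "shape": tuple(out.shape),
            "device": str(out.device),
            "seconds": time.perf_counter() - start,
        }
    return report


criterion = nn.BCEWithLogitsLoss(reduction='sum')
pairs = {
    "stream1": (torch.randn(8, 4), torch.randint(0, 2, (8, 4))),
    "stream3": (torch.randn(8, 4, 128, 128), torch.randint(0, 2, (8, 4, 128, 128))),
}
for name, info in profile_loss_terms(criterion, pairs).items():
    print(name, info)
```

If the third stream's tensor turns out to be much larger than the other two, or lives on a different device, that alone could account for a good part of the extra time per epoch.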

You could also try running your script under torch.utils.bottleneck (`python -m torch.utils.bottleneck your_script.py`) to see exactly where the slowdown happens.
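
As an alternative to `torch.utils.bottleneck`, the `torch.profiler` API in recent PyTorch releases can wrap a single training step, and `record_function` can label each loss term so it shows up as its own row in the results. The sketch below assumes a hypothetical three-headed model (`ThreeHeadNet`) with random inputs, and profiles only the CPU; `ProfilerActivity.CUDA` can be added to `activities` when training on a GPU.

```python
import torch
import torch.nn as nn
from torch.profiler import profile, record_function, ProfilerActivity


class ThreeHeadNet(nn.Module):
    """Hypothetical stand-in: a shared trunk feeding three output heads."""

    def __init__(self, in_dim=32, hidden=64, out_dim=4):
        super().__init__()
        self.trunk = nn.Sequential(nn.Linear(in_dim, hidden), nn.ReLU())
        self.heads = nn.ModuleList([nn.Linear(hidden, out_dim) for _ in range(3)])

    def forward(self, x):
        h = self.trunk(x)
        return [head(h) for head in self.heads]


model = ThreeHeadNet()
criterion = nn.BCEWithLogitsLoss(reduction='sum')
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
x = torch.randn(16, 32)
labels = [torch.randint(0, 2, (16, 4)) for _ in range(3)]

with profile(activities=[ProfilerActivity.CPU]) as prof:
    optimizer.zero_grad()
    outputs = model(x)
    loss = 0
    for i, (out, lab) in enumerate(zip(outputs, labels), start=1):
        # Label each term so its cost appears as a separate row in the table.
        with record_function(f"loss_term_{i}"):
            loss = loss + criterion(out, lab.float())
    with record_function("backward"):
        loss.backward()
    optimizer.step()

# Aggregate events by name and list the most expensive ones first.
table = prof.key_averages().table(sort_by="cpu_time_total", row_limit=20)
print(table)
```

Comparing the `loss_term_3` row (and the backward time) against the other terms should show whether the third term itself is the expensive part.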