Use different loss functions for different elements of a batch

Hello,

Assume I have model outputs of shape (N, D), where N is the batch size. Based on some flag, I want to use different loss functions for different elements of the batch; i.e. I want to apply one loss function to rows 1, 4, and 5 of the outputs, and a different loss function to the remaining rows.

Clearly I could reduce batch size to 1, so that I can determine which loss function to use on a case-by-case basis. But that would be too inefficient.

How can I achieve this? Thank you!

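You can index the rows directly:

A = output[[1, 4, 5]]                      # rows 1, 4 and 5
mask = torch.ones(len(output), dtype=torch.bool, device=output.device)
mask[[1, 4, 5]] = False                    # exclude those rows
B = output[mask]                           # all remaining rows

and then apply a different loss function to A and to B.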
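A fuller sketch of the same idea, driven by a per-row boolean flag instead of a hard-coded index list. The targets, the flag, and the choice of MSE for the flagged rows and L1 for the rest are illustrative assumptions, not part of the original question; both losses are summed so that a single backward() call propagates gradients to every row.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

N, D = 6, 3
output = torch.randn(N, D, requires_grad=True)  # stand-in for model outputs, shape (N, D)
target = torch.randn(N, D)                      # hypothetical targets

# Hypothetical per-row flag: True -> first loss, False -> second loss.
flag = torch.zeros(N, dtype=torch.bool)
flag[[1, 4, 5]] = True

# Boolean-mask indexing keeps the whole batch in one tensor op (no batch size of 1).
A, B = output[flag], output[~flag]
target_A, target_B = target[flag], target[~flag]

loss_a = F.mse_loss(A, target_A)  # assumed loss for rows 1, 4, 5
loss_b = F.l1_loss(B, target_B)   # assumed loss for the remaining rows
loss = loss_a + loss_b
loss.backward()                   # gradients reach all N rows of output
```

Note that each loss is averaged over its own subset, so the two groups are weighted equally regardless of how many rows each contains. If you would rather average over the whole batch, one option is `reduction="sum"` on both losses and dividing the total by N.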


Thank you! I was trying to do it with torch.index_select and it kept raising errors, but this way works! :)
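The actual error isn't shown, so this is only a guess: one common cause is that `torch.index_select(input, dim, index)` requires `index` to be a 1-D integer tensor, and passing a plain Python list fails. A minimal sketch of the same row split with `index_select`, using a toy tensor in place of real model outputs:

```python
import torch

output = torch.arange(12.0).reshape(6, 2)  # toy (N, D) tensor

# index must be a 1-D integer tensor, not a Python list.
idx = torch.tensor([1, 4, 5])
A = torch.index_select(output, 0, idx)     # rows 1, 4 and 5

# Complement: every row index not in idx.
keep = torch.ones(output.size(0), dtype=torch.bool)
keep[idx] = False
rest = keep.nonzero(as_tuple=True)[0]      # 1-D tensor of remaining row indices
B = torch.index_select(output, 0, rest)    # remaining rows
```

Plain advanced indexing (`output[idx]`) gives the same result here; `index_select` is simply the explicit-function form.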