Assume I have model outputs of shape (N, D), where N is the batch size. Based on a per-sample flag, I want to use different loss functions for different samples within the batch, i.e. I want to apply one loss function to rows 1, 4 and 5 of the outputs and a different loss function to the remaining rows.
I could reduce the batch size to 1 and pick the loss function case by case, but that would be far too inefficient.
How can I achieve this? Thank you!
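The question does not name a framework, so this is a minimal sketch assuming PyTorch. It turns the per-sample flag into an (N,) boolean mask, uses it to select whole rows, computes each loss only on its own rows in one batched call, and adds the results. The function name `mixed_loss` and the choice of MSE for flagged rows and L1 for the rest are hypothetical placeholders for whatever two losses you actually need. Rows 1, 4 and 5 are treated as 0-based indices here.

```python
import torch
import torch.nn.functional as F


def mixed_loss(outputs: torch.Tensor, targets: torch.Tensor, flag: torch.Tensor) -> torch.Tensor:
    """Choose the loss per row without splitting the batch.

    outputs, targets: (N, D) tensors.
    flag: (N,) bool tensor; True -> loss A (MSE here), False -> loss B (L1 here).
    Returns the per-element losses summed over D and averaged over the N rows.
    """
    n = outputs.shape[0]
    # Boolean indexing selects whole rows; gradients flow back only to those rows.
    # reduction="sum" makes an empty selection contribute 0 instead of NaN.
    loss_a = F.mse_loss(outputs[flag], targets[flag], reduction="sum")
    loss_b = F.l1_loss(outputs[~flag], targets[~flag], reduction="sum")
    # Normalize by the full batch size so the scale does not depend on the split.
    return (loss_a + loss_b) / n


# Usage: rows 1, 4 and 5 (0-based) get loss A, all other rows get loss B.
torch.manual_seed(0)
N, D = 8, 3
outputs = torch.randn(N, D, requires_grad=True)
targets = torch.randn(N, D)
flag = torch.zeros(N, dtype=torch.bool)
flag[[1, 4, 5]] = True
loss = mixed_loss(outputs, targets, flag)
loss.backward()
```

Because boolean indexing is differentiable, `loss.backward()` sends each row's gradient only through the loss that row was assigned. If both losses are cheap, an alternative is to compute both with `reduction="none"`, reduce over D to get two (N,) vectors, and pick per row with `torch.where(flag, per_a, per_b)`; note that `torch.where` can still propagate NaN gradients from the unselected branch if that loss is undefined on some rows.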