Custom Loss Function for a Batch

Hi, I have to implement a custom loss function in PyTorch. The network outputs a single tensor and my batch size is 1000. I have to take each tensor from the output, do some operations with it and return a value. My question is: do I need a for loop to iterate over all the tensors in the output batch, compute the loss for each tensor and then take the mean, or is there a more efficient way?

It depends on the loss formula, but you would want to vectorize the loss instead of using for loops if possible. Most PyTorch functions lend themselves to it quite naturally through broadcasting.
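A minimal sketch of the difference, assuming (purely hypothetically) that the per-sample loss is a weighted squared error between each scalar output and a target. The names `loss_with_loop`, `loss_vectorized`, `target` and `weight` are made up for illustration. Both functions compute the same value, but the second writes the arithmetic on whole tensors so each PyTorch operation processes all 1000 batch items at once:

```python
import torch

# Hypothetical per-sample loss (NOT the poster's formula): a weighted squared
# error between each scalar network output and its target.
def loss_with_loop(output, target, weight):
    # One Python iteration per batch item: 1000 tiny ops on single numbers
    losses = []
    for i in range(output.shape[0]):
        losses.append(weight[i] * (output[i, 0] - target[i]) ** 2)
    return torch.stack(losses).mean()

def loss_vectorized(output, target, weight):
    # Same arithmetic written on whole tensors; each operation covers the
    # entire batch in a single call
    y = output.squeeze(1)  # [B, 1] -> [B]
    return (weight * (y - target) ** 2).mean()

torch.manual_seed(0)
batch_size = 1000
output = torch.randn(batch_size, 1, requires_grad=True)  # stand-in for the network output
target = torch.randn(batch_size)
weight = torch.rand(batch_size)

loss = loss_vectorized(output, target, weight)
loss.backward()  # autograd works on the vectorized loss as usual
print(torch.allclose(loss, loss_with_loop(output, target, weight)))  # True
```

The vectorized version also builds a much smaller autograd graph (a handful of nodes instead of thousands), which helps the backward pass as well.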

Best regards

Thomas

Thanks for the reply.
The nature of my loss is such that I have to iterate over a list of numbers for each output of my network and then calculate the total loss. So, since my network has 1 output and a batch size of 1000, is there any way to do it without a for loop?

At this level of generality, the answer necessarily is “maybe”, but a typical approach is to keep the loop over your list of numbers and have each iteration operate on all batch items at once.
In my courses I always give Thomas’s rule of thumb: a for loop is OK as long as the things you operate on in each iteration have 100s of items, but if you are down to 10s of items or single numbers, you are in trouble for efficiency.
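A minimal sketch of "loop over the list, batched over the items", assuming a made-up per-number term `relu(y - c) ** 2` in place of the poster's actual (unstated) formula; the list `numbers` and both function names are hypothetical. The first version loops over every batch item and every number; the second loops only over the K numbers, and each iteration updates all 1000 outputs at once:

```python
import torch

# Hypothetical loss (illustrative only): for every output y_i, walk over a
# list of K numbers c_k, accumulate relu(y_i - c_k) ** 2, then average over
# the batch.
def loss_double_loop(output, numbers):
    # Slow: loops over every batch item AND every number (single-number ops)
    per_item = []
    for y in output:
        s = torch.zeros(())
        for c in numbers:
            s = s + torch.relu(y - c) ** 2
        per_item.append(s)
    return torch.stack(per_item).mean()

def loss_batched_loop(output, numbers):
    # Keep the loop over the K numbers, but each step updates all B items at once
    s = torch.zeros_like(output)
    for c in numbers:
        s = s + torch.relu(output - c) ** 2
    return s.mean()

torch.manual_seed(0)
output = torch.randn(1000)              # stand-in for the [B] network output
numbers = [-1.0, -0.5, 0.0, 0.5, 1.0]   # hypothetical list of K numbers

print(torch.allclose(loss_double_loop(output, numbers),
                     loss_batched_loop(output, numbers)))  # True
```

In the batched version the Python loop runs only K times, and each iteration works on 1000 items, which is the regime where a for loop is acceptable by the rule of thumb above.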

While there are cases where this is hard or even impossible without custom kernels, many cases lend themselves to a batch strategy.
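As a sketch of one such batch strategy, again assuming a hypothetical per-number term `relu(y - c) ** 2`, the loop over the list can often be removed entirely by giving the list its own tensor dimension and letting broadcasting form all B×K combinations. The names here are illustrative, not from the original thread:

```python
import torch

# Hypothetical loss (illustrative only): per item, sum relu(y - c) ** 2 over a
# list of K numbers, then average over the batch.
def loss_broadcast(output, numbers):
    # output: [B] -> [B, 1]; numbers: [K] -> [1, K]; difference broadcasts to [B, K]
    diff = output.unsqueeze(1) - numbers.unsqueeze(0)
    per_item = (torch.relu(diff) ** 2).sum(dim=1)  # reduce over the K numbers -> [B]
    return per_item.mean()

def loss_reference(output, numbers):
    # Plain loop over the K numbers, batched over B, for comparison
    s = torch.zeros_like(output)
    for c in numbers.tolist():
        s = s + torch.relu(output - c) ** 2
    return s.mean()

torch.manual_seed(0)
output = torch.randn(1000, requires_grad=True)       # stand-in for the network output
numbers = torch.tensor([-1.0, -0.5, 0.0, 0.5, 1.0])  # hypothetical list of K numbers

loss = loss_broadcast(output, numbers)
loss.backward()  # gradients flow through the broadcast ops as usual
print(torch.allclose(loss, loss_reference(output, numbers)))  # True
```

The trade-off is memory: the intermediate `diff` has B×K elements, so for a very long list the batched loop over the numbers may be the more practical choice.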

Best regards

Thomas
