CrossEntropyLoss on subset of output

Hi all,

I’m working on what is effectively an image segmentation problem. However, the majority of the pixels belong to an ‘I don’t care’ category: I don’t want the weights to be affected one way or the other, no matter which category those pixels are sorted into.

The pixels that I actually care about are disjoint and represent only a small percentage of the total pixels in the output.

Is there a way to apply CrossEntropyLoss to only those scattered, disjoint pixels, while ignoring all of the others? So far, the only sure-fire method I’ve thought of is to visit each pixel of interest, calculate its loss individually, and then add all of the losses together.

Is there a better way?

Thanks in advance!

If you know the positions of the irrelevant pixels, you could create a mask with zeros for them (and ones for the valid locations) and multiply your unreduced loss by it (pass reduction='none' when initializing the loss function). Afterwards you could calculate the mean (or normalize the loss some other way) and call .backward() as usual.
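
For concreteness, here’s a minimal sketch of that approach (the shapes, the random mask, and the random logits below are placeholders for illustration):

```python
import torch
import torch.nn as nn

# Placeholder shapes: logits [N, C, H, W], target [N, H, W] of class indices.
N, C, H, W = 2, 5, 8, 8
logits = torch.randn(N, C, H, W, requires_grad=True)
target = torch.randint(0, C, (N, H, W))

# Mask: 1.0 at pixels you care about, 0.0 at "I don't care" pixels
# (random here purely for illustration).
mask = (torch.rand(N, H, W) < 0.1).float()

criterion = nn.CrossEntropyLoss(reduction='none')
loss = criterion(logits, target)      # per-pixel loss, shape [N, H, W]
loss = loss * mask                    # zero out the irrelevant pixels

# Normalize by the number of valid pixels rather than by all pixels.
loss = loss.sum() / mask.sum().clamp(min=1)
loss.backward()
```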


I tried something similar already, except I applied the mask to the candidate output before putting it through the loss calculation. (The target already had zeros at those locations.)

It failed to improve with training, perhaps due to the non-differentiable nature of the operation. I’ll try applying the mask to the loss result as you suggest and see if that works better.

Thanks,

Hi Avi!

In addition to @ptrblck’s suggestion, you could try using the class-weights feature of CrossEntropyLoss. This lets you give each class its own weight in the overall loss. For your use case, I would imagine setting the weight of your “I don’t care” class to zero and the weights of your “real” classes to one.
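
A minimal sketch of this idea, assuming class 0 is the “I don’t care” class (the shapes are placeholders):

```python
import torch
import torch.nn as nn

C = 5                           # number of classes (placeholder)
weights = torch.ones(C)
weights[0] = 0.0                # zero weight for the "I don't care" class

criterion = nn.CrossEntropyLoss(weight=weights)

logits = torch.randn(2, C, 8, 8, requires_grad=True)
target = torch.randint(0, C, (2, 8, 8))

# Pixels labeled 0 contribute nothing to the loss, and the default
# 'mean' reduction normalizes by the sum of the per-pixel weights,
# i.e. by the number of pixels belonging to the "real" classes.
loss = criterion(logits, target)
loss.backward()
```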

Note that both this and @ptrblck’s approach entail calculating the loss for all of the pixels and then discarding most of the calculated losses (by multiplying them by zero). This may seem inefficient if most of your pixels are “I don’t care,” but presumably, when an image is pumped through your network, (almost) all of the input pixels affect the per-pixel predictions for the “real” pixels, so the overall network processing dominates the cost of calculating the loss.

Good luck.

K. Frank

Thanks. Another great idea.

Is that similar to using the ignore_index option (which I just discovered)?
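
(For reference, a minimal sketch of the ignore_index option: pixels whose target equals ignore_index are excluded from both the loss and the 'mean' normalization, so the effect is similar. The value -100 below is PyTorch’s default; the shapes and mask are placeholders.)

```python
import torch
import torch.nn as nn

criterion = nn.CrossEntropyLoss(ignore_index=-100)

C = 5
logits = torch.randn(2, C, 8, 8, requires_grad=True)
target = torch.randint(0, C, (2, 8, 8))

# Label the "I don't care" pixels with the sentinel value; they are
# skipped entirely, in both the loss and its normalization.
dont_care = torch.rand(2, 8, 8) < 0.9   # pretend 90% are irrelevant
target[dont_care] = -100

loss = criterion(logits, target)
loss.backward()
```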