[solved] Custom Loss Function Woes

Time for your regularly scheduled custom loss-function question :slight_smile: !

I'm trying to create a variant of MSE. My target and my prediction are RGB color images. What I'd like to do is weight certain target colors differently, so that pixels that were originally black don't get penalized as much as non-black pixels.

For grayscale images, this can be done rather easily:

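```python
import torch

def weighted_mse(y_pred, y_true, scale=0.5):
    # Pixels originally black should only contribute 50% (scale = 0.5)
    y_true_f = y_true.view(-1)
    y_pred_f = y_pred.view(-1)
    diff     = y_true_f - y_pred_f

    # 1.0 where the (flattened) target pixel is black (0), 0.0 elsewhere
    zero     = y_true_f.eq(0).float()
    # Per-pixel weight on the squared error: `scale` for black targets, 1.0 otherwise
    weight   = 1.0 - (1.0 - scale) * zero
    return torch.mean(weight * diff**2, dim=0)
```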

The problem: My images are normalized + standardized, per channel, before being fed into the net. So the new 0-value for red might be -0.3, the 0-value for green might be 0.01, and the 0-value for blue might be -0.1.
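If the per-channel standardization is the usual `(x - mean) / std` (which is what `torchvision.transforms.Normalize` computes), then a raw 0 turns into `-mean / std` in each channel. A small sketch, where the `mean` and `std` numbers are made-up placeholders (not the poster's actual statistics), chosen so the result matches the example values above:

```python
import torch

# Hypothetical per-channel statistics (R, G, B) used for standardization.
mean = torch.tensor([0.15, -0.0025, 0.02])
std = torch.tensor([0.5, 0.25, 0.2])

# Standardization maps x -> (x - mean) / std, so a raw 0 maps to -mean / std.
transformed_zero = (0.0 - mean) / std
print(transformed_zero)  # roughly [-0.3, 0.01, -0.1]
```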

How can I alter the above so that, instead of checking for 0 directly, I check for the 'transformed' 0-value in the R, G, and B channels simultaneously? y_true comes in as [batch, channel, y, x], so I can use indexing to split it out. But I don't want to apply the above when only the red value equals its transformed 0, or only the green value does; I want it applied only when all 3 RGB channels equal their respective transformed zeros.
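One way to express the "all three channels at once" condition is to broadcast a per-channel reference against `y_true` and then reduce over the channel dimension with `all`. This is a sketch under stated assumptions: `transformed_zero` is a hypothetical length-3 tensor holding the per-channel transformed 0-values (e.g. -0.3, 0.01, -0.1), the function names are illustrative, and a small tolerance is used instead of exact float equality:

```python
import torch


def black_pixel_mask(y_true, transformed_zero, atol=1e-6):
    """Return a [batch, 1, H, W] float mask: 1.0 where *every* channel of
    y_true is within `atol` of its own transformed zero, else 0.0."""
    # Reshape to [1, C, 1, 1] so it broadcasts against [batch, C, H, W],
    # matching dtype/device of y_true.
    ref = transformed_zero.view(1, -1, 1, 1).to(y_true)
    # Per-channel comparison with a tolerance (exact == is fragile for floats).
    per_channel = (y_true - ref).abs() <= atol
    # A pixel counts as "black" only if all channels match.
    return per_channel.all(dim=1, keepdim=True).float()


def rgb_weighted_mse(y_pred, y_true, transformed_zero, scale=0.5):
    """MSE in which originally black pixels contribute `scale` as much."""
    mask = black_pixel_mask(y_true, transformed_zero)  # [B, 1, H, W]
    weight = 1.0 - (1.0 - scale) * mask                # broadcasts over C
    return torch.mean(weight * (y_true - y_pred) ** 2)
```

Because the mask has a singleton channel dimension, the same weight is applied to all three channels of a pixel.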

I hope the question is clear :sweat_smile:


Figured this out.

Although, as a human, I see the 3 channel components as inherently tied together, for the machine they need not be. If my target is (0,0,0) and I have two predictions, (100,100,100) and (50,100,100), the loss assigned to the latter will of course be lower than that of the former. TL;DR: handle each channel individually, since each channel has its own transformed minimum value.
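A minimal sketch of that per-channel approach, assuming a hypothetical length-3 tensor `transformed_zero` with each channel's transformed 0-value; the function name and tolerance are illustrative, not the poster's actual code. Each channel value is down-weighted on its own whenever it sits at its channel's transformed zero, with no `all` across channels:

```python
import torch


def per_channel_weighted_mse(y_pred, y_true, transformed_zero, scale=0.5, atol=1e-6):
    """MSE where each channel value at its own transformed zero contributes
    `scale` as much, independently of the other channels."""
    # [1, C, 1, 1] so each channel is compared with its own reference value.
    ref = transformed_zero.view(1, -1, 1, 1).to(y_true)
    # 1.0 wherever *this channel* is at its minimum, else 0.0; shape [B, C, H, W].
    mask = ((y_true - ref).abs() <= atol).float()
    weight = 1.0 - (1.0 - scale) * mask
    return torch.mean(weight * (y_true - y_pred) ** 2)
```

For a target pixel like (-0.3, 0.01, 0.5) with transformed zeros (-0.3, 0.01, -0.1), the red and green errors are down-weighted while the blue error keeps full weight.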