[solved] Custom Loss Function Woes

Time for your regularly scheduled custom loss-function question :slight_smile: !

I’m trying to create a variant of MSE. My target and my prediction are RGB color images. What I’d like to do is weigh certain target colors differently, so that pixels that were originally black don’t get penalized as much as non-black pixels.

For grayscale images, this can be done rather easily:

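import torch

# Pixels originally black should have their error scaled down to 50%
scale    = 0.5

y_true_f = y_true.view(-1)
y_pred_f = y_pred.view(-1)
diff     = y_true_f - y_pred_f

# 1.0 where the target value is exactly 0 (black), 0.0 elsewhere
zero     = y_true_f.eq(0).float()
# Weight is `scale` at black pixels and 1.0 everywhere else
weight   = 1.0 - (1.0 - scale) * zero
result   = diff * weight
mse      = torch.mean(result**2)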

The problem: My images are normalized + standardized, per channel, before being fed into the net. So the new 0-value for red might be -0.3, the 0-value for green might be 0.01, and the 0-value for blue might be -0.1.

How can I alter the above so that instead of checking for 0 directly, I check for the ‘transformed’ 0-value in the R, G, and B channels simultaneously? y_true comes in as [batch, channel, y, x], so I can use indexing to split it out. But I don’t want to apply the above when only the R value equals its transformed 0, or only the G value equals its transformed 0, but rather only when all 3 RGB channels equal their respective transformed zeros.
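For reference, here is a minimal sketch of what the "all three channels at once" check could look like. The names `black_pixel_mask`, `weighted_mse`, `zero_vals`, and `atol` are hypothetical and not from the original post. It assumes the transformed zero for each channel is known (for example `(0 - mean) / std` for that channel) and compares with a small tolerance, since exact equality on normalized floats may be fragile.

```python
import torch

def black_pixel_mask(y_true, zero_vals, atol=1e-6):
    """Bool mask of shape [batch, 1, y, x]: True where every channel
    matches its own transformed zero (assumed known per channel)."""
    # Reshape the per-channel zeros so they line up with [batch, channel, y, x]
    zeros = torch.as_tensor(zero_vals, dtype=y_true.dtype,
                            device=y_true.device).view(1, -1, 1, 1)
    per_channel = torch.isclose(y_true, zeros.expand_as(y_true),
                                rtol=0.0, atol=atol)
    # A pixel counts as black only if all channels match simultaneously
    return per_channel.all(dim=1, keepdim=True)

def weighted_mse(y_pred, y_true, zero_vals, scale=0.5):
    """MSE where the residual of originally-black pixels is scaled by `scale`."""
    mask = black_pixel_mask(y_true, zero_vals)
    # Weight is `scale` at black pixels, 1.0 elsewhere; broadcasts over channels
    weight = 1.0 - (1.0 - scale) * mask.to(y_true.dtype)
    return torch.mean(((y_true - y_pred) * weight) ** 2)
```

With the example values above, this would be called as `weighted_mse(y_pred, y_true, zero_vals=[-0.3, 0.01, -0.1])`. Note that the scale is applied to the residual before squaring, as in the grayscale version.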

I hope the question is clear :sweat_smile:

Figured this out.

Though as a human I understand all 3 channel components as being inherently tied together, for the machine they need not be. If my target is (0,0,0) and I have two predictions, (100,100,100) and (50,100,100), the loss assigned to the latter would of course be less than that assigned to the former. Tl;dr - handle each channel individually, since each has its own target minimum value (its transformed zero).
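One possible reading of "handle each channel individually" is sketched below. This is an illustration, not necessarily the exact code that was used. The function name `per_channel_weighted_mse` and the parameters `zero_vals` and `atol` are hypothetical. Each channel value is compared against that channel's own transformed zero, and the weight is applied element by element, so no cross-channel check is needed.

```python
import torch

def per_channel_weighted_mse(y_pred, y_true, zero_vals, scale=0.5, atol=1e-6):
    """MSE where each channel value equal to its own transformed zero
    has its residual scaled by `scale`, independently of other channels."""
    # One transformed-zero value per channel, aligned to [batch, channel, y, x]
    zeros = torch.as_tensor(zero_vals, dtype=y_true.dtype,
                            device=y_true.device).view(1, -1, 1, 1)
    # Tolerance-based comparison, assuming exact float equality may be fragile
    is_zero = torch.isclose(y_true, zeros.expand_as(y_true),
                            rtol=0.0, atol=atol)
    # Weight is `scale` where the channel value is its zero, 1.0 elsewhere
    weight = 1.0 - (1.0 - scale) * is_zero.to(y_true.dtype)
    return torch.mean(((y_true - y_pred) * weight) ** 2)
```

Compared with an all-channels mask, a pixel whose red value alone sits at its transformed zero still gets the reduced weight on the red channel here, which matches the idea that each channel has its own target.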