Time for your regularly scheduled custom loss-function question!
I’m trying to create a variant of MSE. My target and my prediction are RGB color images. What I’d like to do is weigh certain target colors differently, so that pixels that were originally black don’t get penalized as much as non-black pixels.
For grayscale images, this can be done rather easily:
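    import torch

    # Errors at pixels that were originally black are scaled down by 50%
    scale = 0.5
    y_true_f = y_true.view(-1)
    y_pred_f = y_pred.view(-1)
    diff = y_true_f - y_pred_f
    # 1.0 where the target pixel is black, 0.0 elsewhere (same shape as diff)
    zero = y_true_f.eq(0).float()
    # diff - scale * zero * diff, i.e. diff * (1 - scale * zero)
    result = diff.addcmul(zero, diff, value=-scale)
    mse = torch.mean(result**2, dim=0)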
The problem: My images are normalized + standardized, per channel, before being fed into the net. So the new 0-value for red might be -0.3, the 0-value for green might be 0.01, and the 0-value for blue might be -0.1.
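To make that concrete: the transformed 0-value for a channel is just (0 - mean) / std for that channel. Here is a minimal sketch; the `mean` and `std` numbers are made up purely so that the result matches the example values above, and `black` is just a name I picked for illustration.

```python
import torch

# Hypothetical per-channel statistics (R, G, B), chosen only so the output
# matches the example values in this post; real values come from the dataset.
mean = torch.tensor([0.15, -0.004, 0.04])
std = torch.tensor([0.5, 0.4, 0.4])

# The value an original 0 maps to in each channel after standardization.
black = (torch.zeros(3) - mean) / std
print(black)
```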
How can I alter the above so that instead of checking for 0 directly, I check for the ‘transformed’ 0-value in the R,G,B channels…simultaneously?
y_true comes in as [batch, channel, y, x], so I can use indexing to split out the channels. But I don’t want to apply the above when only the R value equals its transformed 0, or only the G value does; I want to apply it only when all three RGB channels equal their respective transformed zeros.
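Something along these lines is roughly what I have in mind, though I’m not sure it is the right or most efficient way. It is only a sketch: `black_mask`, `weighted_mse`, `black` (the per-channel transformed zeros) and `tol` are names I made up, and I am assuming the weight should be applied to the error before squaring.

```python
import torch

def black_mask(y_true, black, tol=1e-6):
    """Float mask of shape [batch, 1, y, x]: 1.0 where every channel matches `black`."""
    # Reshape the per-channel values to [1, C, 1, 1] so they broadcast over batch, y, x.
    ref = black.view(1, -1, 1, 1)
    # Compare with a small tolerance, since exact float equality is fragile
    # after normalization and standardization.
    close = (y_true - ref).abs() <= tol
    # A pixel counts as black only if all of its channels match.
    return close.all(dim=1, keepdim=True).float()

def weighted_mse(y_pred, y_true, black, scale=0.5):
    """MSE in which the error at originally-black pixels is scaled by (1 - scale)."""
    mask = black_mask(y_true, black)
    weight = 1.0 - scale * mask  # broadcasts across the channel dimension
    diff = (y_true - y_pred) * weight
    return torch.mean(diff ** 2)
```

Usage would be `loss = weighted_mse(y_pred, y_true, black)`, with `black` being a 3-element tensor on the same device as `y_true`. Applying the weight to the squared error instead would change how strongly black pixels are discounted.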
I hope the question is clear.