MSE (L2) loss on two masked images

I’m working on an inpainting project.
I’m trying to calculate the L2 reconstruction loss over only the regions that have been masked.
I can mask the output with:

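import torch
import torch.nn as nn

criterion = nn.MSELoss()
masked_pred = torch.mul(prediction, mask)    # zero the prediction outside the mask
loss = criterion(masked_pred, masked_input)  # masked_input: target multiplied by the same mask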

But the problem is that MSELoss averages the loss over all pixels, and since the mask is zero for most pixels, the loss becomes very small.
Is there a way to calculate the MSELoss only over the pixels selected by the mask?

Just as a side note, the mask I use is randomly generated (not fixed).
Thanks in advance!
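To make the dilution concrete, here is a small sketch with random tensors and a hypothetical mask that keeps 25% of the elements (the shapes, the seed and the mask pattern are made up for illustration). With the default reduction='mean', the loss on the masked tensors equals the MSE over the ON elements multiplied by the fraction of ON elements:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
prediction = torch.rand(1, 3, 8, 8)
target = torch.rand(1, 3, 8, 8)

# Hypothetical mask: only the top-left 4x4 patch (25% of the elements) is ON
mask = torch.zeros_like(target)
mask[..., :4, :4] = 1.0

criterion = nn.MSELoss()  # default reduction='mean' divides by ALL elements
diluted = criterion(prediction * mask, target * mask)

# MSE over the ON elements only
on = mask.bool()
masked_mse = ((prediction[on] - target[on]) ** 2).mean()

ratio = mask.mean()  # fraction of ON elements
print(diluted.item(), (masked_mse * ratio).item())  # the two numbers match
```

The smaller the hole, the smaller `ratio`, and the more the reported loss shrinks.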


In case the question wasn’t clear:
I’m looking for a feature similar to ignore_index in nn.CrossEntropyLoss() so that I can ignore the loss on certain pixel values.

You can multiply the loss by masked_input.numel() / mask.sum() to correct for the number of ON pixels (this assumes mask has the same shape as masked_input).
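A minimal sketch of that correction, assuming prediction, target and mask all share one shape (the function name rescaled_masked_mse is hypothetical). Multiplying the mean by numel() / mask.sum() turns "divide by every element" into "divide by the ON elements":

```python
import torch
import torch.nn as nn


def rescaled_masked_mse(prediction, target, mask):
    """MSE over the ON elements of `mask` (mask has the same shape as the inputs)."""
    criterion = nn.MSELoss()  # 'mean' divides by every element
    masked_input = target * mask
    masked_pred = prediction * mask
    loss = criterion(masked_pred, masked_input)
    # Rescale: sum / numel  ->  sum / (number of ON elements).
    # An all-zero mask would divide by zero here.
    return loss * masked_input.numel() / mask.sum()


torch.manual_seed(0)
prediction = torch.rand(2, 3, 16, 16)
target = torch.rand(2, 3, 16, 16)
mask = (torch.rand(2, 3, 16, 16) > 0.7).float()  # random mask, as in the question
loss = rescaled_masked_mse(prediction, target, mask)
```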

Thanks for the reply.
I was actually going to reply that I had solved the issue.
Basically the same solution, but I set the MSELoss reduction parameter to 'sum'
and then divided by the total number of nonzero pixels in the mask.
Your method works too.
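The 'sum' variant described here could look like the following sketch (masked_mse_sum is a hypothetical name; it assumes the mask has the same shape as the images, otherwise the nonzero count would also have to be multiplied by the number of channels). torch.count_nonzero gives the number of ON entries, and the clamp only guards against an empty mask:

```python
import torch
import torch.nn as nn


def masked_mse_sum(prediction, target, mask):
    criterion = nn.MSELoss(reduction='sum')  # sum of squared errors, no averaging
    sse = criterion(prediction * mask, target * mask)
    # Number of nonzero mask entries; clamp(min=1) avoids 0/0 for an empty mask
    n_on = torch.count_nonzero(mask).clamp(min=1)
    return sse / n_on


torch.manual_seed(0)
prediction = torch.rand(4, 3, 32, 32)
target = torch.rand(4, 3, 32, 32)
mask = (torch.rand(4, 3, 32, 32) > 0.5).float()
loss = masked_mse_sum(prediction, target, mask)
```

This gives the same value as the rescaled 'mean' version; it just skips the intermediate division by numel().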


Hi, does this method work with a batch on the GPU, or do I have to do something else? I get different results with batch size 1 and a larger batch size.
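The post doesn't show the code, so this is only a guess: if the loss is normalized by the ON-pixel count of the whole batch, images with larger masks get more weight, so the value can differ from averaging per-image losses computed with batch size 1. A sketch contrasting the two normalizations (both function names and the toy data are hypothetical):

```python
import torch


def masked_mse_global(prediction, target, mask):
    # One normalizer for the whole batch: images with bigger masks weigh more
    sq = (prediction - target) ** 2 * mask
    return sq.sum() / mask.sum()


def masked_mse_per_sample(prediction, target, mask):
    # Normalize each image by its own ON count, then average over the batch
    sq = (prediction - target) ** 2 * mask
    dims = tuple(range(1, sq.dim()))  # every dim except the batch dim
    per_sample = sq.sum(dim=dims) / mask.sum(dim=dims).clamp(min=1)
    return per_sample.mean()


prediction = torch.zeros(2, 3, 8, 8)
target = torch.ones(2, 3, 8, 8)
target[1] = 0.5                  # image 0 has squared error 1.0, image 1 has 0.25
mask = torch.zeros(2, 3, 8, 8)
mask[0, :, :2, :2] = 1           # small hole in image 0 (12 elements)
mask[1, :, :6, :6] = 1           # large hole in image 1 (108 elements)
print(masked_mse_global(prediction, target, mask))
print(masked_mse_per_sample(prediction, target, mask))
```

Here the global version gives (12 * 1.0 + 108 * 0.25) / 120 = 0.325, while the per-sample version gives (1.0 + 0.25) / 2 = 0.625. Only the per-sample version matches what you would get by averaging batch-size-1 losses.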

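Did you try setting reduce=False in torch.nn.MSELoss(size_average=None, reduce=None, reduction='mean')? (reduce is deprecated; reduction='none' is the equivalent.) It will then return a tensor of per-element losses instead of a single number, and you can build the mask yourself. For example, if your output is [1,2,3,4] and you don’t want to calculate the loss for 1 and 4, you can set mask = [0,1,1,0], compute loss_matrix = per_element_loss * mask (mask the losses, not the output), and then loss = loss_matrix.sum() / mask.sum(). Taking the plain average of loss_matrix would still divide by all four elements, which is the original problem.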

Here is the doc:
reduce (bool, optional) – Deprecated (see reduction). By default, the losses are averaged or summed over observations for each minibatch depending on size_average. When reduce is False, returns a loss per batch element instead and ignores size_average. Default: True
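The [1,2,3,4] example written with reduction='none', the non-deprecated equivalent of reduce=False. The all-zero target is hypothetical, chosen so the per-element losses are just the squared outputs. Note that the final average divides by mask.sum(), the number of ON elements, since averaging loss_matrix over all four entries would dilute the loss again:

```python
import torch
import torch.nn as nn

criterion = nn.MSELoss(reduction='none')  # per-element losses, same shape as the input

output = torch.tensor([1.0, 2.0, 3.0, 4.0])
target = torch.tensor([0.0, 0.0, 0.0, 0.0])  # hypothetical target
mask = torch.tensor([0.0, 1.0, 1.0, 0.0])    # ignore the 1st and 4th elements

loss_matrix = criterion(output, target) * mask  # [0, 4, 9, 0]
loss = loss_matrix.sum() / mask.sum()           # (4 + 9) / 2 = 6.5
```

For comparison, loss_matrix.mean() would give 13 / 4 = 3.25.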


Could you please post your solution?

Hey numpee,

I suggest that instead of multiplying the mask tensor with the prediction, you either use fancy (boolean) indexing or torch.gather to get rid of the unwanted regions. Your suggested solution is a total waste of computation!
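A sketch of the indexing approach using a boolean mask (masked_mse_indexing is a hypothetical name, and the mask is assumed to have the same shape as the images). Indexing with mask.bool() returns a 1-D tensor of only the ON elements, so the default 'mean' reduction already divides by the right count. An all-zero mask leaves nothing to average, and whether this is faster than the element-wise multiply on a GPU is worth timing rather than assuming:

```python
import torch
import torch.nn.functional as F


def masked_mse_indexing(prediction, target, mask):
    keep = mask.bool()  # True where the loss should be computed
    # prediction[keep] is a 1-D tensor holding only the ON elements,
    # so the default reduction='mean' divides by the number of ON elements
    return F.mse_loss(prediction[keep], target[keep])


torch.manual_seed(0)
prediction = torch.rand(2, 3, 16, 16)
target = torch.rand(2, 3, 16, 16)
mask = (torch.rand(2, 3, 16, 16) > 0.7).float()
loss = masked_mse_indexing(prediction, target, mask)
```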