# MSE (L2) loss on two masked images

I’m working on an inpainting project.
I’m trying to calculate the L2 reconstruction loss of only the regions that have been masked.
I can mask the output, and I compute the loss with:

```python
import torch.nn as nn

criterion = nn.MSELoss()
```

But the problem is that MSELoss averages the loss over all pixels, and since most of my pixels are masked the loss becomes very small.
Is there a way to calculate the MSELoss only on non-masked pixels?

As a side note, the mask I use is randomly generated (not fixed).
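To make the dilution concrete, here is a small sketch with random tensors and a hypothetical random binary mask (`output`, `target` and `mask` are made-up names, not from the post). It checks that the default `'mean'` reduction on masked tensors equals the masked squared-error sum divided by *all* elements, while the quantity actually wanted divides by the number of kept pixels only.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
output = torch.rand(2, 3, 8, 8)
target = torch.rand(2, 3, 8, 8)
# Hypothetical random binary mask (1 = pixel should count toward the loss)
mask = (torch.rand(2, 3, 8, 8) < 0.1).float()

criterion = nn.MSELoss()  # default reduction='mean'
diluted = criterion(output * mask, target * mask)

# Zeroed-out pixels add 0 to the sum but still count in the denominator,
# so 'mean' here is sum(masked squared errors) / numel.
masked_sq_err_sum = ((output - target) ** 2 * mask).sum()
print(torch.allclose(diluted, masked_sq_err_sum / output.numel()))  # True

# What is actually wanted: the average over the kept pixels only
wanted = masked_sq_err_sum / mask.sum()
print(diluted.item(), wanted.item())
```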


In case the question wasn’t clear:
I’m looking for a feature similar to `ignore_index` in `nn.CrossEntropyLoss()` so that I can ignore the loss on certain pixel values.

You can multiply the `loss` by `masked_input.numel() / mask.sum()` to correct for the number of ON (nonzero) pixels in the mask.
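A minimal sketch of the rescaling trick, assuming a binary `mask` with the same shape as the output (the function name `rescaled_mse` is hypothetical). Here `masked_input` is taken to be `output * mask`; multiplying the default mean by `numel()` recovers the sum, and dividing by `mask.sum()` turns it into an average over the ON pixels.

```python
import torch
import torch.nn as nn

def rescaled_mse(output, target, mask):
    # Assumes `mask` is binary with the same shape as `output`; otherwise
    # mask.sum() would not count every element that survives the multiplication.
    masked_input = output * mask
    loss = nn.MSELoss()(masked_input, target * mask)  # averaged over all numel()
    # Undo the division by numel() and divide by the number of ON pixels instead
    return loss * masked_input.numel() / mask.sum()

output = torch.tensor([1.0, 2.0, 3.0, 4.0])
target = torch.zeros(4)
mask = torch.tensor([0.0, 1.0, 1.0, 0.0])
print(rescaled_mse(output, target, mask))  # tensor(6.5000)
```

The mean over all four elements is (4 + 9) / 4 = 3.25; rescaling by 4 / 2 gives 6.5, the average over the two kept elements.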

I was actually about to reply that I solved the issue.
It's basically the same solution, but I set the `reduction` parameter of `MSELoss` to `'sum'`
and then divided by the total number of nonzero pixels in the mask.
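The `reduction='sum'` approach described above could look roughly like this. It is a sketch, not the poster's exact code: the function name is hypothetical, the mask is assumed binary, and two additions are labeled assumptions. `expand_as` makes the count correct when a single-channel `(N, 1, H, W)` mask is broadcast over a multi-channel image, and `clamp(min=1)` guards against a randomly generated mask that happens to be empty.

```python
import torch
import torch.nn as nn

def masked_mse_sum(output, target, mask):
    criterion = nn.MSELoss(reduction='sum')
    # Sum of squared errors; pixels zeroed by the mask contribute 0
    loss = criterion(output * mask, target * mask)
    # Count kept elements after broadcasting, e.g. an (N, 1, H, W) mask
    # applied to an (N, C, H, W) image counts once per channel.
    num_kept = mask.expand_as(output).sum()
    # clamp guards against an all-zero randomly generated mask
    return loss / num_kept.clamp(min=1)

output = torch.rand(4, 3, 16, 16)
target = torch.rand(4, 3, 16, 16)
mask = (torch.rand(4, 1, 16, 16) > 0.5).float()  # hypothetical random mask
print(masked_mse_sum(output, target, mask))
```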


Hi, does this method work with a batch on the GPU, or do I have to do something else? I get different results with batch size 1 and with larger batch sizes.
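One plausible cause (an assumption, since the code behind the question isn't shown): dividing the whole batch's sum by the whole batch's mask count averages over pixels, so images with larger masks weigh more. Averaging per image first and then over the batch gives every image equal weight. With batch size 1 the two coincide; with larger batches they generally differ. The sketch below, with hypothetical function names and a mask of the same shape as the output, shows both reductions.

```python
import torch

def masked_mse_global(output, target, mask):
    # One average over every kept pixel in the batch: images with larger
    # masks carry more weight.
    sq_err = (output - target) ** 2 * mask
    return sq_err.sum() / mask.sum()

def masked_mse_per_sample(output, target, mask):
    # Average within each image first, then over the batch, so every image
    # carries equal weight regardless of its mask size.
    sq_err = (output - target) ** 2 * mask
    dims = tuple(range(1, output.dim()))
    per_sample = sq_err.sum(dim=dims) / mask.sum(dim=dims)
    return per_sample.mean()

# Batch of two 1x2x2 "images": all pixels of image 0 kept, one pixel of image 1
target = torch.stack([torch.full((1, 2, 2), 1.0), torch.full((1, 2, 2), 2.0)])
output = torch.zeros_like(target)
mask = torch.zeros_like(target)
mask[0] = 1.0
mask[1, 0, 0, 0] = 1.0

print(masked_mse_global(output, target, mask))      # tensor(1.6000)
print(masked_mse_per_sample(output, target, mask))  # tensor(2.5000)
# With batch size 1 both reductions agree
print(masked_mse_global(output[:1], target[:1], mask[:1]))      # tensor(1.)
print(masked_mse_per_sample(output[:1], target[:1], mask[:1]))  # tensor(1.)
```

Neither choice is wrong; which one fits depends on whether each image or each pixel should count equally in the loss.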

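Did you try setting `reduction='none'` in `torch.nn.MSELoss(size_average=None, reduce=None, reduction='mean')`? (`reduce=False` has the same effect but is deprecated.) The loss then returns a tensor of losses instead of a scalar, so you can apply a mask yourself. For example, if your output is `[1, 2, 3, 4]` and you don’t want to calculate the loss for `1` and `4`, set `mask = [0, 1, 1, 0]`, compute `loss_matrix = elementwise_loss * mask`, and then `loss = loss_matrix.sum() / mask.sum()`. Averaging `loss_matrix` over all of its elements would bring back the dilution problem.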

Here is the relevant part of the docs:
> `reduce` (bool, optional) – Deprecated (see `reduction`). By default, the losses are averaged or summed over observations for each minibatch depending on `size_average`. When `reduce` is `False`, returns a loss per batch element instead and ignores `size_average`. Default: `True`
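A minimal sketch of the `reduction='none'` approach using the `[1, 2, 3, 4]` example from the reply. The target values and the function name are hypothetical. The key detail is that the mask is applied to the element-wise loss, and the sum is divided by the number of kept elements rather than by the total element count.

```python
import torch
import torch.nn as nn

def masked_mse_elementwise(output, target, mask):
    # reduction='none' returns one squared error per element (same shape as output)
    criterion = nn.MSELoss(reduction='none')
    loss_matrix = criterion(output, target) * mask  # zero out ignored elements
    # Divide by the number of kept elements, not by loss_matrix.numel()
    return loss_matrix.sum() / mask.expand_as(loss_matrix).sum()

output = torch.tensor([1.0, 2.0, 3.0, 4.0])
target = torch.tensor([1.0, 1.0, 1.0, 1.0])  # hypothetical target
mask = torch.tensor([0.0, 1.0, 1.0, 0.0])    # ignore the first and last element
print(masked_mse_elementwise(output, target, mask))  # tensor(2.5000)
```

The squared errors are `[0, 1, 4, 9]`; after masking, only `1` and `4` remain, and their average is 2.5.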


Could you please post your solution?

Hey numpee

I suggest that, instead of multiplying the mask tensor with the prediction, you either use fancy indexing or `torch.gather` to get rid of unwanted regions. Your suggested solution is a total waste of computation!
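The fancy-indexing option could be sketched with boolean mask indexing as follows. The function name is hypothetical, and `torch.gather` is not shown. Because indexing returns only the kept elements, the default `'mean'` reduction already divides by the right count, and gradients flow back only to the selected positions. Whether this is actually faster than multiplying by the mask is worth measuring: the number of selected elements is data-dependent, which on a GPU typically requires a synchronization with the host.

```python
import torch
import torch.nn.functional as F

def masked_mse_indexing(output, target, mask):
    # Boolean mask with the same shape as output (expanded if it broadcasts)
    keep = mask.bool().expand_as(output)
    # Boolean indexing returns a 1-D tensor holding only the kept elements,
    # so the default 'mean' reduction divides by the right count.
    # Note: an all-zero mask selects nothing and the mean becomes nan.
    return F.mse_loss(output[keep], target[keep])

output = torch.tensor([1.0, 2.0, 3.0, 4.0], requires_grad=True)
target = torch.tensor([1.0, 1.0, 1.0, 1.0])
mask = torch.tensor([0.0, 1.0, 1.0, 0.0])
loss = masked_mse_indexing(output, target, mask)
loss.backward()
print(loss.item())  # 2.5
print(output.grad)  # tensor([0., 1., 2., 0.])
```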