Masked loss on a batch of images

I would like to implement a masked MSE loss function for training a model. I am trying to compute the Hadamard (element-wise) product of a mask image and the original image. Is there a torch function for this?


Got it.

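import torch.nn as nn

class MaskedMSE(nn.Module):
    def __init__(self):
        super().__init__()
        # Default reduction="mean" averages over all elements, masked or not.
        self.criterion = nn.MSELoss()

    def forward(self, input, target, mask_a, mask_b):
        # `*` is the element-wise (Hadamard) product; masks broadcast against the images.
        return self.criterion(input * mask_a, target * mask_b)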
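
One thing to watch with the module above: `nn.MSELoss` with its default `reduction="mean"` divides by the total number of elements, including the ones the mask zeroed out, so the loss shrinks as the mask covers less of the image. A possible variant, sketched here under the assumption that a single mask applies to both input and target, divides by the number of unmasked elements instead. The name `masked_mse_normalized` and the `eps` parameter are hypothetical, not part of torch.

```python
import torch


def masked_mse_normalized(pred, target, mask, eps=1e-8):
    """MSE averaged over unmasked elements only (hypothetical helper)."""
    # Cast the mask to the prediction dtype and broadcast it to the full
    # shape, so mask.sum() counts exactly the elements that contribute.
    mask = mask.to(pred.dtype).expand_as(pred)
    sq_err = (pred - target) ** 2 * mask
    # eps guards against division by zero when the mask is entirely empty.
    return sq_err.sum() / (mask.sum() + eps)


if __name__ == "__main__":
    pred = torch.zeros(2, 1, 2, 2)
    target = torch.ones(2, 1, 2, 2)
    mask = torch.tensor([[1.0, 0.0], [1.0, 0.0]])  # keep the left column only
    print(masked_mse_normalized(pred, target, mask))
```

Whether this normalization is what you want depends on the task; if every sample has the same mask coverage, the two versions differ only by a constant factor.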

Does this work for you in batches?
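
As far as I can tell it should, since `*` is element-wise and broadcasts over the batch dimension. Here is a quick check, assuming images of shape (N, C, H, W) and one mask per image of shape (N, 1, H, W); the sizes are made up for illustration. It compares the loss on the whole batch with the average of the per-image losses.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical batch: 4 images, 3 channels, 8x8 pixels.
pred = torch.randn(4, 3, 8, 8)
target = torch.randn(4, 3, 8, 8)
# One binary mask per image, broadcast across the channel dimension.
mask = (torch.rand(4, 1, 8, 8) > 0.5).float()

criterion = nn.MSELoss()

# Loss computed on the whole batch at once (a scalar tensor).
batch_loss = criterion(pred * mask, target * mask)

# The same loss computed image by image, then averaged.
per_image = torch.stack(
    [criterion(pred[i] * mask[i], target[i] * mask[i]) for i in range(pred.shape[0])]
)

# Every image has the same number of elements, so the two should agree.
print(torch.allclose(batch_loss, per_image.mean(), atol=1e-6))
```

If you need one loss value per image rather than a single scalar, `nn.MSELoss(reduction="none")` returns the unreduced element-wise errors, which you can then sum or average over the non-batch dimensions yourself.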