I would like to implement a masked MSE loss function for training a model. I am trying to compute the Hadamard (element-wise) product between a mask image and the original image. Is there a torch function for this?
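For reference, the Hadamard product in PyTorch is plain element-wise multiplication, available as the `*` operator or `torch.mul`. A minimal sketch follows; the tensor names (`image`, `mask`) and values are made up for illustration.

```python
import torch

# Hypothetical 2x2 "image" and binary mask, for illustration only.
image = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
mask = torch.tensor([[1.0, 0.0], [0.0, 1.0]])

# Element-wise (Hadamard) product; the two forms are equivalent.
masked_a = image * mask
masked_b = torch.mul(image, mask)
```

Both forms follow the usual broadcasting rules, so the mask does not need exactly the same shape as the image as long as the shapes are broadcastable.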
Got it.
import torch.nn as nn

class MaskedMSE(nn.Module):
    def __init__(self):
        super().__init__()
        self.criterion = nn.MSELoss()

    def forward(self, input, target, mask_a, mask_b):
        # Apply each mask element-wise (Hadamard product) before computing the MSE.
        self.loss = self.criterion(input * mask_a, target * mask_b)
        return self.loss
Does this work for you in batches?
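On the batch question: element-wise multiplication works on batched tensors as long as the mask has the same shape as the input or broadcasts to it. One thing to keep in mind is that `nn.MSELoss` with its default `reduction='mean'` divides by the total number of elements, including the masked-out ones, so the loss shrinks as more of the image is masked. Below is a rough sketch of a variant that averages over the unmasked elements only. The helper name `masked_mse` and the use of a single mask shared by input and target are assumptions for illustration, not the code posted above.

```python
import torch

def masked_mse(input, target, mask):
    # Hypothetical helper: squared error averaged over unmasked elements only.
    mask = mask.to(input.dtype).expand_as(input)
    sq_err = (input - target) ** 2 * mask
    # clamp avoids division by zero when the mask is all zeros
    return sq_err.sum() / mask.sum().clamp(min=1)

# Batch of 4 single-channel 8x8 images (random data, for illustration only).
pred = torch.randn(4, 1, 8, 8)
target = torch.randn(4, 1, 8, 8)

# One mask of shape (1, 1, 8, 8) that broadcasts over the batch dimension;
# here only the top half of each image contributes to the loss.
mask = torch.zeros(1, 1, 8, 8)
mask[..., :4, :] = 1.0

loss = masked_mse(pred, target, mask)
```

The result is a scalar tensor, so it can be used with `loss.backward()` in the same way as the output of `nn.MSELoss`.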