Hi, I’m new to PyTorch and I’d like some help with something. I’m doing segmentation and need to apply a mask to the loss function nn.MSELoss(). I have the following function:
loss = mask.numel()/mask.sum() * self.criterion(output, target)
where:
mask: a binary mask of 0s and 1s
output: the output of my network
target: the ground truth the network should match
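A minimal sketch of the single-sample case, to make explicit what the formula above computes. It assumes `output`, `target` and `mask` are float tensors of the same shape (N, C, H, W) with N = 1, that the criterion is `nn.MSELoss()` with its default `reduction='mean'`, and that output and target are multiplied by the mask before the criterion is called (the post does not say whether that happens). Under those assumptions the scaled loss equals the squared error averaged over masked pixels only.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

criterion = nn.MSELoss()  # default reduction='mean': sum of squared errors / numel

# Assumed shapes: (N, C, H, W) with a binary float mask of the same shape
output = torch.randn(1, 1, 4, 4)
target = torch.randn(1, 1, 4, 4)
mask = (torch.rand(1, 1, 4, 4) > 0.5).float()
mask[0, 0, 0, 0] = 1.0  # make sure at least one pixel is masked

# Assumption: the error outside the mask is zeroed before calling the criterion
masked_output = output * mask
masked_target = target * mask

# The formula from the post: rescale the mean over all elements
loss = mask.numel() / mask.sum() * criterion(masked_output, masked_target)

# The same quantity written out: squared error averaged over masked pixels only
explicit = ((output - target) ** 2 * mask).sum() / mask.sum()
```

If the error is not zeroed outside the mask, `criterion(output, target)` still includes the unmasked pixels, and the rescaling no longer matches the explicit masked average.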
When I apply this with a batch size of 1, it works correctly.
But with a batch size greater than 1, I don’t get the desired results.
I tried defining a custom loss function, but that doesn’t seem to solve it either:
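import torch
import torch.nn as nn

class MaskedMSE(nn.Module):
    def __init__(self):
        super(MaskedMSE, self).__init__()
        self.criterion = nn.MSELoss()

    def forward(self, input_a, target, density_mask):
        # Elements per sample divided by masked elements per sample
        # (summing over dims 1, 2, 3 assumes a 4D (N, C, H, W) mask), averaged over the batch
        per_sample_elems = density_mask.numel() / density_mask.shape[0]
        mask_f = torch.mean(per_sample_elems / density_mask.sum(dim=1).sum(dim=1).sum(dim=1))
        self.loss = mask_f * self.criterion(input_a, target)
        return self.loss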
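For comparison, a sketch of a masked MSE that applies the mask to the squared error itself and normalizes each sample by its own number of masked pixels before averaging over the batch. `MaskedMSEPerSample` and `eps` are hypothetical names, and the mask is assumed to have the same shape as the output (a mask broadcast over channels would be counted only once per pixel in the denominator).

```python
import torch
import torch.nn as nn

class MaskedMSEPerSample(nn.Module):
    """Hypothetical masked MSE: mean squared error over masked pixels per sample, then mean over the batch."""

    def __init__(self, eps: float = 1e-8):
        super().__init__()
        self.eps = eps  # guards against division by zero for an all-zero mask

    def forward(self, output, target, mask):
        mask = mask.to(output.dtype)
        sq_err = (output - target) ** 2 * mask  # zero out the error on unmasked pixels
        dims = tuple(range(1, output.dim()))    # every dimension except the batch dimension
        per_sample = sq_err.sum(dim=dims) / (mask.sum(dim=dims) + self.eps)
        return per_sample.mean()                # average the per-sample losses over the batch
```

With a batch size of 1 this reduces to the explicit masked average (up to `eps`); with larger batches every sample contributes equally, regardless of how many pixels its mask covers.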
I hope you can give me some guidance.