Custom loss function does not work as expected on large batches

Hi, I’m new to PyTorch, and I’d like some help with something. I’m doing segmentation and need to apply a mask to the loss function nn.MSELoss(), where:

I have the following function:

loss = mask.numel()/mask.sum() * self.criterion(output, target)

mask: a binary mask of 0s and 1s
output: the output of my network
target: the target of the network

When I apply the function with a batch size of 1, it works correctly.
But when I apply it with a batch size greater than 1, I don’t get the desired results.

I tried defining a custom loss function, but it doesn’t seem to be the solution either:

class MaskedMSE(nn.Module):
    def __init__(self):
        super(MaskedMSE,self).__init__()
        self.criterion = nn.MSELoss()

    def forward(self,input_a,target,density_mask):
        mask_f = torch.mean((density_mask.numel()/density_mask.shape[0])/density_mask.sum(dim=1).sum(dim=1).sum(dim=1))

        self.loss = mask_f*self.criterion(input_a,target)
        return self.loss

I hope you can give me some guidance.

I’m not sure what the desired result is, but maybe this manual approach might help:

# Single sample
N = 1
mask = torch.randint(0, 2, (N, 10, 10)).float()
output = torch.empty(N, 10, 10).uniform_(0, 1)
target = torch.randint(0, 2, (N, 10, 10)).float()
criterion = nn.MSELoss()
loss = mask.numel()/mask.sum() * criterion(output, target)

# Batch of 10 samples
N = 10
mask = torch.randint(0, 2, (N, 10, 10)).float()
output = torch.empty(N, 10, 10).uniform_(0, 1)
target = torch.randint(0, 2, (N, 10, 10)).float()
criterion = nn.MSELoss(reduction='none')
loss = (mask.size(1)*mask.size(2))/mask.sum([1, 2]) * criterion(output, target).mean([1, 2])
loss = loss.mean()
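If it helps, the per-sample scaling above could also be wrapped into a small nn.Module so it can be used like a regular criterion. This is just a sketch of the same manual computation; the class name PerSampleMaskedMSE is made up for illustration, and it assumes every mask in the batch has at least one nonzero entry (otherwise the division produces inf):

```python
import torch
import torch.nn as nn

class PerSampleMaskedMSE(nn.Module):
    """Sketch: MSE scaled per sample by the fraction of active mask entries."""
    def __init__(self):
        super().__init__()
        # reduction='none' keeps the elementwise losses so we can
        # reduce per sample before applying the per-sample scale
        self.criterion = nn.MSELoss(reduction='none')

    def forward(self, output, target, mask):
        # elements per sample / active mask entries per sample, shape (N,)
        scale = (mask.size(1) * mask.size(2)) / mask.sum(dim=[1, 2])
        # mean squared error per sample, shape (N,)
        per_sample = self.criterion(output, target).mean(dim=[1, 2])
        # scale each sample's loss, then average over the batch
        return (scale * per_sample).mean()

criterion = PerSampleMaskedMSE()
N = 10
mask = torch.randint(0, 2, (N, 10, 10)).float()
output = torch.empty(N, 10, 10).uniform_(0, 1)
target = torch.randint(0, 2, (N, 10, 10)).float()
loss = criterion(output, target, mask)
```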

Thank you. It works as expected.