Hi, I’m new to PyTorch and would like some help. I’m working on segmentation and need to apply a mask to the loss function `nn.MSELoss()`.

I’m currently computing the loss like this:

```python
# scale the MSE by (total elements / number of ones in the mask)
loss = mask.numel() / mask.sum() * self.criterion(output, target)
```
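Here is a minimal sketch of what this formula computes. It assumes, as an illustration, that `output` and `target` have already been multiplied by `mask`, so positions outside the mask contribute zero error. Under that assumption, scaling the mean over all elements by `mask.numel() / mask.sum()` turns it into the mean squared error over the masked elements only. The tensor shapes and values are hypothetical.

```python
import torch
import torch.nn as nn

criterion = nn.MSELoss()  # default reduction="mean": averages over every element

# Hypothetical single sample of shape (1, 1, 4, 4); the top two rows are masked in
mask = torch.zeros(1, 1, 4, 4)
mask[..., :2, :] = 1.0

torch.manual_seed(0)
# Assumption: output and target are already zero outside the mask
output = torch.rand(1, 1, 4, 4) * mask
target = torch.rand(1, 1, 4, 4) * mask

# The formula from the post
loss = mask.numel() / mask.sum() * criterion(output, target)

# Squared error summed over masked positions, divided by their count
reference = ((output - target) ** 2 * mask).sum() / mask.sum()

print(torch.allclose(loss, reference))  # True
```

If `output` is not zeroed outside the mask, the unmasked positions still contribute to the MSE and the two values no longer match.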

`mask`: a binary mask of 0s and 1s

`output`: the output of my network

`target`: the ground truth the network should predict

With a batch size of 1, this works correctly.

With a batch size greater than 1, however, I don’t get the expected results.
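One possible explanation, offered as an assumption rather than a diagnosis: with a batch, `mask.numel() / mask.sum()` pools the mask over all samples, so the result is the squared error averaged over every masked pixel in the batch, and samples with larger masked regions carry more weight. If the expected result is instead the average of each sample’s own masked MSE, the two differ whenever samples have different mask coverage. The toy batch below is hypothetical, and its `target` is zeroed outside the mask.

```python
import torch
import torch.nn as nn

criterion = nn.MSELoss()

# Hypothetical batch of shape (N, C, H, W) = (2, 1, 2, 2)
mask = torch.tensor([[[[1., 1.], [1., 0.]]],   # sample 0: 3 masked pixels
                     [[[1., 0.], [0., 0.]]]])  # sample 1: 1 masked pixel
output = torch.zeros(2, 1, 2, 2)
target = torch.tensor([[[[1., 1.], [1., 1.]]],
                       [[[3., 0.], [0., 0.]]]]) * mask  # zero outside the mask

# Formula from the post on the whole batch:
# total squared error / total number of masked pixels in the batch
batch_loss = mask.numel() / mask.sum() * criterion(output, target)

# Alternative: masked MSE for each sample, then averaged over the batch
sq_err = (output - target) ** 2 * mask
per_sample = sq_err.sum(dim=(1, 2, 3)) / mask.sum(dim=(1, 2, 3))
sample_mean_loss = per_sample.mean()

print(batch_loss.item())        # 3.0  -> (3 + 9) / 4
print(sample_mean_loss.item())  # 5.0  -> (3/3 + 9/1) / 2
```

Neither number is wrong in itself; which one is "desired" depends on whether each masked pixel or each image should carry equal weight.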

I tried defining a custom loss function, but that doesn’t seem to solve it either:

```python
import torch
import torch.nn as nn

class MaskedMSE(nn.Module):
    def __init__(self):
        super(MaskedMSE, self).__init__()
        self.criterion = nn.MSELoss()

    def forward(self, input_a, target, density_mask):
        # Per-sample factor (elements per sample / ones per sample), averaged
        # over the batch; assumes density_mask has shape (N, C, H, W)
        mask_f = torch.mean(
            (density_mask.numel() / density_mask.shape[0])
            / density_mask.sum(dim=(1, 2, 3))
        )
        self.loss = mask_f * self.criterion(input_a, target)
        return self.loss
```
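For comparison, the class above multiplies the batch mean of the per-sample factors by a single batch-wide MSE, which in general is not the same as averaging each sample’s masked error. If the goal is for every image to count equally regardless of how many pixels its mask covers (an assumption about what the expected result is), one option is to apply the mask inside the loss and normalize per sample. `PerSampleMaskedMSE` is a hypothetical name; it assumes inputs, target, and mask share the shape (N, C, H, W), and the `clamp(min=1)` guard, which gives a sample with an empty mask zero loss, is a design choice of this sketch.

```python
import torch
import torch.nn as nn


class PerSampleMaskedMSE(nn.Module):
    """Hypothetical variant: masked MSE per sample, averaged over the batch."""

    def forward(self, input_a, target, density_mask):
        mask = density_mask.to(input_a.dtype)  # also accepts bool masks
        # Squared error kept only at masked positions
        sq_err = (input_a - target) ** 2 * mask
        dims = tuple(range(1, input_a.dim()))  # every dim except the batch dim
        # Number of masked elements per sample; clamp avoids division by zero
        count = mask.sum(dim=dims).clamp(min=1)
        per_sample = sq_err.sum(dim=dims) / count
        return per_sample.mean()


# Usage with a hypothetical batch of shape (2, 1, 2, 2)
criterion = PerSampleMaskedMSE()
mask = torch.tensor([[[[1., 1.], [1., 0.]]],
                     [[[1., 0.], [0., 0.]]]])
output = torch.zeros(2, 1, 2, 2)
target = torch.tensor([[[[1., 1.], [1., 1.]]],
                       [[[3., 5.], [5., 5.]]]])  # unmasked values are ignored
loss = criterion(output, target, mask)
print(loss.item())  # 5.0 -> (3/3 + 9/1) / 2
```

Because the mask is applied inside `forward`, this version does not rely on `output` and `target` being zeroed outside the mask beforehand.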

I hope you can give me some guidance.