How to write a loss function with mask?

What I want is fairly simple: an MSE loss function that can mask out some items:

def masked_mse_loss(a, b, mask):
    # Average the squared error over the items where mask == 1
    sum2 = 0.0
    num = 0
    for i in range(len(a)):
        if mask[i] == 1:
            sum2 += (a[i] - b[i]) ** 2.0
            num += 1
    return sum2 / num
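In formula form, the loop computes the mean of the squared errors over only the items whose mask value is 1 (writing the mask as $m_i \in \{0, 1\}$):

$$
\mathrm{MSE}_{\text{masked}}(a, b, m) = \frac{\sum_i m_i \, (a_i - b_i)^2}{\sum_i m_i}
$$

Because masked-out items have $m_i = 0$, they drop out of both the numerator and the count in the denominator.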

Because of the backward pass, I suspect such a straightforward implementation would not work with PyTorch, but I have no idea how to do it the correct way.

I was under the impression that PyTorch would accept this kind of loss function. What error do you get?

I wrote the loss class:

import torch


class MaskedMSELoss(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, input, target, mask):
        # Element-wise squared error on the flattened tensors
        diff2 = (torch.flatten(input) - torch.flatten(target)) ** 2.0
        sum2 = 0.0
        num = 0

        flat_mask = torch.flatten(mask)
        assert len(flat_mask) == len(diff2)
        # Accumulate only the elements whose mask value is 1
        for i in range(len(diff2)):
            if flat_mask[i] == 1:
                sum2 += diff2[i]
                num += 1

        return sum2 / num
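For what it's worth, the loop version does backpropagate: autograd records the tensor indexing and additions, so the Python `for` and `if` are not a problem, just slow for large tensors. A minimal check, assuming the same toy values used later in this thread (`loop_masked_mse` is a hypothetical name for the forward logic above pulled out into a function):

```python
import torch


def loop_masked_mse(input, target, mask):
    # Same logic as MaskedMSELoss.forward above, written as a plain function
    diff2 = (torch.flatten(input) - torch.flatten(target)) ** 2.0
    flat_mask = torch.flatten(mask)
    sum2 = 0.0
    num = 0
    for i in range(len(diff2)):
        if flat_mask[i] == 1:
            sum2 += diff2[i]  # sum2 becomes a tensor tracked by autograd
            num += 1
    return sum2 / num


predict = torch.tensor([1.0, 2.0, 3.0, 4.0], dtype=torch.float64, requires_grad=True)
target = torch.tensor([1.0, 1.0, 1.0, 1.0], dtype=torch.float64)
mask = torch.tensor([1.0, 0.0, 0.0, 1.0], dtype=torch.float64)

loss = loop_masked_mse(predict, target, mask)
loss.backward()
print(loss.item())   # 4.5
print(predict.grad)  # gradient is zero at the masked-out positions
```

Note that if no element has mask value 1, `num` stays 0 and the final division raises `ZeroDivisionError`.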

What would the default backward look like?

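You don't need to write a backward yourself: autograd derives it from the operations in the forward computation. Just call out.backward() to backpropagate the error: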

import torch

predict = torch.tensor([1.0, 2, 3, 4], dtype=torch.float64, requires_grad=True)
target = torch.tensor([1.0, 1, 1, 1], dtype=torch.float64)
mask = torch.tensor([1, 0, 0, 1], dtype=torch.float64)  # 1 = keep, 0 = ignore
# Masked items contribute 0 to the numerator; the denominator counts the kept items
out = torch.sum(((predict - target) * mask) ** 2.0) / torch.sum(mask)
out.backward()  # fills predict.grad
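An alternative that avoids multiplying by the mask is boolean indexing: select only the kept elements and pass them to the built-in `torch.nn.functional.mse_loss`, whose default `reduction='mean'` averages over exactly those elements. A sketch using the same toy tensors; converting the float mask with `.bool()` assumes the mask holds only 0s and 1s:

```python
import torch
import torch.nn.functional as F

predict = torch.tensor([1.0, 2.0, 3.0, 4.0], dtype=torch.float64, requires_grad=True)
target = torch.tensor([1.0, 1.0, 1.0, 1.0], dtype=torch.float64)
mask = torch.tensor([1, 0, 0, 1], dtype=torch.float64)

keep = mask.bool()  # True where the item should contribute to the loss
out = F.mse_loss(predict[keep], target[keep])  # mean over the kept items only
out.backward()
print(out.item())    # 4.5
print(predict.grad)  # zero gradient for the masked-out items
```

Unlike the multiply-by-mask form, indexing never touches the masked-out values, so a NaN or inf stored at a masked position cannot leak into the loss (in IEEE arithmetic, NaN * 0 is still NaN).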

I rewrote it in a simpler form without the for loop and the if, and it works properly:

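def forward(self, input, target, mask):
    # Zero out the masked-out squared errors, then average over the kept items
    diff2 = (torch.flatten(input) - torch.flatten(target)) ** 2.0 * torch.flatten(mask)
    result = torch.sum(diff2) / torch.sum(mask)
    return result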
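A complete, self-contained version of the vectorized module might look like the sketch below. Two details are assumptions, not part of the original code: the mask is cast to the input's dtype so bool or integer masks also work, and the denominator is clamped to at least 1 so an all-zero mask returns 0 instead of 0/0 = NaN.

```python
import torch


class MaskedMSELoss(torch.nn.Module):
    """Mean squared error over the elements where mask == 1 (vectorized)."""

    def forward(self, input, target, mask):
        mask = torch.flatten(mask).to(input.dtype)
        diff2 = (torch.flatten(input) - torch.flatten(target)) ** 2.0 * mask
        # Assumption: an all-zero mask should give a loss of 0 rather than NaN,
        # so the count of kept elements is clamped to at least 1.
        return torch.sum(diff2) / torch.clamp(torch.sum(mask), min=1.0)


criterion = MaskedMSELoss()
predict = torch.tensor([[1.0, 2.0], [3.0, 4.0]], requires_grad=True)
target = torch.ones(2, 2)
mask = torch.tensor([[1, 0], [0, 1]])  # integer mask, cast inside forward

loss = criterion(predict, target, mask)
loss.backward()
print(loss.item())  # 4.5

empty = criterion(predict, target, torch.zeros(2, 2))
print(empty.item())  # 0.0
```

Clamping is only one way to handle an empty mask; raising an error or deliberately returning NaN are equally valid choices, depending on whether an empty batch should be treated as a bug.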

Thanks!