How to implement weighted mean square error?

You can compute this loss yourself with a combination of basic tensor operations.
For example:

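import torch

def mse_loss(input, target):
    # Mean (not sum) of the squared element-wise errors.
    return torch.mean((input - target) ** 2)

def weighted_mse_loss(input, target, weight):
    # weight must be broadcastable to the shape of input.
    return torch.mean(weight * (input - target) ** 2)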
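
As a rough sketch of how such a loss might be used, here is a variant with a small example. The helper name `normalized_weighted_mse` and the sample tensors are hypothetical, not part of PyTorch; the helper assumes `weight` can be expanded to the shape of `input`, and it divides by the total weight instead of the element count, which is one common convention for a weighted mean.

```python
import torch

def normalized_weighted_mse(input, target, weight):
    # Hypothetical helper: weighted average of the squared errors,
    # normalized by the total weight rather than the element count.
    weight = weight.expand_as(input)
    return torch.sum(weight * (input - target) ** 2) / torch.sum(weight)

# Made-up example values, chosen only for illustration.
pred = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
target = torch.tensor([1.0, 0.0, 0.0])
weight = torch.tensor([1.0, 1.0, 2.0])

loss = normalized_weighted_mse(pred, target, weight)
loss.backward()  # autograd works because only differentiable tensor ops are used

print(loss.item())  # (1*0 + 1*4 + 2*9) / 4 = 5.5
print(pred.grad)    # 2 * weight * (pred - target) / sum(weight)
```

Because the loss is built from ordinary tensor operations, no custom backward pass is needed. Whether to normalize by the element count or by the sum of the weights depends on how you want the loss scale to react to changes in the weights.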