Pixelwise weights for MSELoss

You can write your own loss function that multiplies the squared error by the weight map:

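def weighted_mse_loss(input, target, weights):
    # element-wise squared error (MSE needs the square, not just the difference)
    out = (input - target) ** 2
    # expand_as lets a single weight map, e.g. of shape [1, C, H, W],
    # be shared by every sample in the mini-batch
    out = out * weights.expand_as(out)
    # mean over all elements, like nn.MSELoss's default reduction;
    # use .sum() or sum over whatever dimensions you need instead
    loss = out.mean()
    return loss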
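A minimal usage sketch, assuming image-like tensors of shape [N, C, H, W]. The shapes, the seed, and the "5x weight on the central 4x4 region" weight map are made up for illustration. It also checks the function against the built-in nn.MSELoss(reduction="none") multiplied by the same weights, and confirms that all-ones weights give the ordinary MSELoss.

```python
import torch
import torch.nn as nn


def weighted_mse_loss(input, target, weights):
    # element-wise squared error
    out = (input - target) ** 2
    # broadcast the per-pixel weight map over the whole mini-batch
    out = out * weights.expand_as(out)
    # average over every element, like nn.MSELoss(reduction="mean")
    return out.mean()


torch.manual_seed(0)
# hypothetical shapes: a batch of 4 single-channel 8x8 images
pred = torch.randn(4, 1, 8, 8, requires_grad=True)
target = torch.randn(4, 1, 8, 8)

# hypothetical weight map: pixels in the central 4x4 region count 5x more
weights = torch.ones(1, 1, 8, 8)
weights[..., 2:6, 2:6] = 5.0

loss = weighted_mse_loss(pred, target, weights)
loss.backward()  # the loss is a scalar, so backward() needs no gradient argument

# same value via the built-in loss with reduction="none"
per_pixel = nn.MSELoss(reduction="none")(pred, target)
builtin = (per_pixel * weights).mean()
print(torch.allclose(loss, builtin))

# with all-ones weights it reduces to the ordinary MSELoss
plain = weighted_mse_loss(pred, target, torch.ones(1, 1, 8, 8))
print(torch.allclose(plain, nn.MSELoss()(pred, target)))
```

Reducing to a scalar (mean or sum over everything) matters: `out.sum(0)` would leave a tensor of shape [C, H, W], and calling `backward()` on a non-scalar tensor requires passing an explicit gradient.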