Pixelwise weights for MSELoss

I would like to use a weighted MSELoss function for image-to-image training. I want to specify a weight for each pixel in the target. Is there a quick/hacky way to do this, or do I need to write my own MSE loss function from scratch?


you can do this:

def weighted_mse_loss(input, target, weights):
    out = input - target
    out = out * weights.expand_as(out)
    # expand_as because weights are prob not defined for mini-batch
    loss = out.sum(0)  # or sum over whatever dimensions
    return loss
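To make the `expand_as` comment above concrete, here is a small sketch (shapes are my own illustrative choice) showing a per-pixel weight map defined without a batch dimension being expanded across a mini-batch:

```python
import torch

# hypothetical shapes: a batch of 2 single-channel 4x4 predictions and targets
out = torch.rand(2, 1, 4, 4) - torch.rand(2, 1, 4, 4)

# one weight per pixel, no batch dimension
weights = torch.rand(1, 1, 4, 4)

# expand_as repeats the weight map across the batch (a view, no copy)
weighted = out * weights.expand_as(out)
print(weighted.shape)  # torch.Size([2, 1, 4, 4])
```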

Oh, I see. There’s no magic to the loss functions. You just calculate whatever loss you want using predefined Torch functions and then call backward on the loss. That’s super easy. Thanks!


What if I want an L2 loss? Just do it like this?

def weighted_mse_loss(input, target, weights):
    out = (input - target) ** 2
    out = out * weights.expand_as(out)
    loss = out.sum(0)  # or sum over whatever dimensions
    return loss

Right?




@liygcheng Yes, that's correct.

As long as all the computations are done on Variables, the gradients can be computed.
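For reference, the same squared-error version can be written with the built-in elementwise loss in `torch.nn.functional` by passing `reduction='none'`; normalizing by the total weight (one common convention, not the only one) makes uniform weights reduce to plain MSE:

```python
import torch
import torch.nn.functional as F

def weighted_mse_loss(input, target, weights):
    # elementwise squared error, no reduction
    out = F.mse_loss(input, target, reduction='none')
    w = weights.expand_as(out)
    # normalize by the total weight so uniform weights give plain MSE
    return (out * w).sum() / w.sum()

pred, target = torch.rand(2, 1, 4, 4), torch.rand(2, 1, 4, 4)
uniform = torch.ones(1, 1, 4, 4)
# with uniform weights this matches the unweighted mean squared error
assert torch.allclose(weighted_mse_loss(pred, target, uniform),
                      F.mse_loss(pred, target))
```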


Should the “weights” also be wrapped as a Variable in order for auto-grad to work?

Same question here. What are your thoughts now?

Yes, the "weights" should also be wrapped as a Variable.

Thanks. Actually, wrapping them as Variables is also required if we need them to run on GPU(s).
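Note that since PyTorch 0.4 the `Variable` wrapper was merged into `Tensor`, so today plain tensors are enough; only tensors that need gradients take `requires_grad=True`. A minimal check that gradients flow through the weighted loss (shapes are illustrative):

```python
import torch

pred = torch.rand(2, 1, 4, 4, requires_grad=True)
target = torch.rand(2, 1, 4, 4)
weights = torch.rand(1, 1, 4, 4)  # fixed weights need no requires_grad

loss = ((pred - target) ** 2 * weights.expand_as(pred)).sum()
loss.backward()

# a gradient of the same shape as the prediction now exists
print(pred.grad.shape)
```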

Thanks for this!
Can the Weights be learned as well? Would the model need to output the weights in that case? How would the code look?

If the weights are learned I would be concerned about the overall training, since the model could try to set them to zero to create a zero loss (or generally lower them).
What is your use case to try to train the weights and how would you avoid a trained zero weight matrix?

I am trying an unsupervised approach to identify "important" areas of an image without explicitly pre-defining the weight map. For the other question, I am not sure yet. I am thinking of a simple form like L*W + 1/W, where L*W is the elementwise multiplication of the learned weights and the MSE/MAE loss, and the second term prevents the weights from going too close to 0.
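A minimal sketch of that idea, purely as an assumption of how the regularized objective could look: the learned per-pixel weights are kept positive via `softplus` (my choice, since the 1/W term blows up as W approaches zero), and gradients flow into both the prediction and the weight map.

```python
import torch
import torch.nn.functional as F

# hypothetical learnable per-pixel weight map, kept positive with softplus
raw_w = torch.zeros(1, 1, 4, 4, requires_grad=True)

pred = torch.rand(2, 1, 4, 4, requires_grad=True)
target = torch.rand(2, 1, 4, 4)

W = F.softplus(raw_w) + 1e-6              # W > 0
L = (pred - target) ** 2                  # per-pixel squared error
loss = (L * W).mean() + (1.0 / W).mean()  # the proposed L*W + 1/W objective
loss.backward()
# both the prediction and the weight map now have gradients
```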