How to implement weighted mean square error?

(Mohammad Mehdi Derakhshani) #1

Hello guys, I would like to implement the loss function below, which is a weighted mean squared error loss:

How can I implement such a loss function in PyTorch? In other words, is there any way to use nn.MSELoss to achieve the loss function I mentioned?

(Francisco Massa) #2

You can probably use a combination of tensor operations to compute your loss.
For example:

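import torch

def mse_loss(input, target):
    # sums the squared errors; use torch.mean(...) to average like nn.MSELoss does by default
    return torch.sum((input - target) ** 2)

def weighted_mse_loss(input, target, weight):
    # weight must be broadcastable to the shape of (input - target)
    return torch.sum(weight * (input - target) ** 2)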
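If you specifically want to reuse nn.MSELoss, here is a sketch assuming a PyTorch version that has the `reduction` argument (0.4.1 and later): request per-element losses with `reduction='none'`, multiply them by the weights, then reduce yourself. The shapes and the `weights` values are made up for illustration.

```python
import torch
import torch.nn as nn

# reduction='none' returns the per-element squared errors instead of a scalar
criterion = nn.MSELoss(reduction='none')

def weighted_mse_loss(input, target, weight):
    # weight must broadcast against input/target, e.g. shape (num_outputs,)
    return (weight * criterion(input, target)).mean()

input = torch.randn(4, 3, requires_grad=True)
target = torch.randn(4, 3)
weights = torch.tensor([0.5, 0.25, 0.25])  # hypothetical per-output weights

loss = weighted_mse_loss(input, target, weights)
loss.backward()  # autograd propagates gradients back to input
```

With all weights equal to 1 this reduces to the default `nn.MSELoss()`.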

(Mohammad Mehdi Derakhshani) #3

From your snippet above, I understand that I should create an nn.Module and put this code in its forward function. Am I right? I ask because I would like to use autograd for backprop!

(Francisco Massa) #4

You don’t need to write an nn.Module for that; a simple function is enough.
Note that backpropagation is handled by Variables (in current PyTorch, by any tensor with requires_grad=True), not by nn.Module.

So this is perfectly fine:

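import torch

def pow(input, n):
    return input ** n

# Variable is deprecated; since PyTorch 0.4 a tensor with requires_grad=True is enough
t = torch.rand(1, requires_grad=True)
t2 = pow(t, 2)   # t ** 2
t6 = pow(t2, 3)  # (t ** 2) ** 3 == t ** 6
t6.backward()    # autograd computes d(t ** 6)/dt = 6 * t ** 5 into t.grad
print(t.grad)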
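To make the point concrete, here is a sketch of one training step that calls a plain weighted MSE function. No custom nn.Module is involved for the loss, yet autograd still computes gradients for the model's parameters. The toy `nn.Linear` model, the data and the weights are hypothetical, and the code uses current PyTorch tensors rather than Variable.

```python
import torch
import torch.nn as nn

def weighted_mse_loss(input, target, weight):
    return torch.mean(weight * (input - target) ** 2)

torch.manual_seed(0)
model = nn.Linear(3, 2)  # hypothetical toy model
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

x = torch.randn(8, 3)
y = torch.randn(8, 2)
w = torch.tensor([0.8, 0.2])  # hypothetical weights, one per output column

optimizer.zero_grad()
loss = weighted_mse_loss(model(x), y, w)
loss.backward()   # autograd fills model.weight.grad and model.bias.grad
optimizer.step()  # update the parameters using those gradients
```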

(Mohammad Mehdi Derakhshani) #5

Thank you, I appreciate your response!

Is unpacking with var[0,:] compatible with autograd-based backprop or not?

(Will) #6

Were you able to solve this? When I try to set this function as my criterion, I get an error saying the function requires 3 inputs, which obviously aren’t available until training. Maybe this is why it needs to go in an nn.Module?
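One likely cause (an assumption, since the exact error isn't shown) is calling the three-argument function with only `(output, target)`. Instead of writing an nn.Module, the weights can be bound ahead of time with `functools.partial`, leaving a two-argument callable that behaves like a criterion. The weights below are hypothetical.

```python
import functools
import torch

def weighted_mse_loss(input, target, weight):
    return torch.mean(weight * (input - target) ** 2)

weights = torch.tensor([0.5, 0.3, 0.2])  # hypothetical per-output weights

# Bind `weight` now; the returned callable only needs (input, target)
criterion = functools.partial(weighted_mse_loss, weight=weights)

output = torch.randn(5, 3, requires_grad=True)
target = torch.randn(5, 3)
loss = criterion(output, target)  # same call signature as nn.MSELoss()
loss.backward()
```

A closure such as `lambda out, tgt: weighted_mse_loss(out, tgt, weights)` works the same way.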

(Will) #7

This is the ugly hack I created that works for this problem: 16 outputs, with the first output weighted 0.5 (8/16) and each of the remaining 15 outputs weighted 0.5/15.

I’m sure there’s a better way to do this but if you’re in a hurry this works.

import torch

def weighted_mse_loss(input, target):
    # alpha of 0.5 means half the weight goes to the first output,
    # and the remaining half is split evenly across the other 15
    # (torch.tensor replaces Variable(...).cuda(); weights follow target's device)
    weights = torch.tensor([0.5] + [0.5 / 15] * 15, device=target.device)
    pct_var = (input - target) ** 2
    out = pct_var * weights.expand_as(target)
    loss = out.mean()
    return loss
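Rather than typing out sixteen values, the same weighting scheme can be generated for any number of outputs. This is a sketch, not Will's code: `make_weights`, `alpha` and `num_outputs` are hypothetical names, `num_outputs` is assumed to be at least 2, and the weights are created on the target's device instead of being hard-coded to `.cuda()`. Broadcasting also makes `expand_as` unnecessary.

```python
import torch

def make_weights(num_outputs, alpha=0.5, device=None):
    # First output gets `alpha`; the other num_outputs - 1 share (1 - alpha) equally.
    # Assumes num_outputs >= 2 (otherwise the division below fails).
    weights = torch.full((num_outputs,), (1 - alpha) / (num_outputs - 1), device=device)
    weights[0] = alpha
    return weights

def weighted_mse_loss(input, target, alpha=0.5):
    weights = make_weights(target.shape[-1], alpha, device=target.device)
    # weights has shape (num_outputs,) and broadcasts over the batch dimension
    return torch.mean(weights * (input - target) ** 2)

w = make_weights(16)
```

With `alpha=0.5` and 16 outputs this reproduces the weights above: 0.5 for the first output and 0.5/15 for each of the rest, summing to 1. Note that `.mean()` additionally divides by the batch size times the number of outputs.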