Question about MSELoss

Here is some very simple test code:

import torch
import torch.nn as nn
from torch.autograd import Variable

target = Variable(torch.tensor([
    [1, 0],
    [0, 1],
]).float(), requires_grad=False)

output = Variable(torch.tensor([
    [0, 1],
    [0, 1],
]).float(), requires_grad=True)

# loss = nn.L1Loss()(output, target)
loss = nn.MSELoss()(output, target)
print('LOSS\n', loss)

print(type(loss.grad_fn))

# call the backward node directly, passing an upstream gradient of 1
grad = loss.grad_fn(torch.tensor(1).float())
print('GRAD\n', grad)

where the output I got is

LOSS
 tensor(0.5000, grad_fn=<MseLossBackward>)
<class 'MseLossBackward'>
GRAD
 tensor([[-0.5000,  0.5000],
        [ 0.0000,  0.0000]], grad_fn=<MseLossBackwardBackward>)

Based on the documentation (https://pytorch.org/docs/stable/nn.html#mseloss) and the gradient definitions given here (https://stats.stackexchange.com/a/312997), I thought the right value for the loss should be 1, because the sum of squared differences is 2 and N (the batch size) is 2. However, it is 0.5.
Similarly, I was expecting to see [[-2, 2], [0, 0]] for the gradients, but the result is different.

I suspect that, due to the averaging operation, grad_fn does some rescaling based on the gradient I provide (1 in this case). I would like to understand that logic as well.
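
To make the comparison concrete, this is the formula check I had in mind (a sketch, assuming the gradient of the mean-reduced MSE with respect to output is 2 * (output - target) / N):

import torch

target = torch.tensor([[1., 0.], [0., 1.]])
output = torch.tensor([[0., 1.], [0., 1.]])

# 2 * (output - target) / N -- but which N?
print(2 * (output - target) / 2)  # [[-1., 1.], [0., 0.]] with N = batch size (2)
print(2 * (output - target) / 4)  # [[-0.5, 0.5], [0., 0.]] with N = number of elements (4); matches grad_fn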

Does anyone know the details of how these loss functions are implemented?
Thank you

As far as the loss value is concerned, I believe the printed value (0.5) is correct.
If you compute the loss (output[i] - target[i])^2 element by element, you get a tensor with the values 1, 1, 0, 0 (read row-wise).
Summing these gives 2, which you then divide by the total number of elements (4), because you implicitly chose reduction='mean' when you created the MSELoss instance. So in the end you get 0.5.
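
For example, here is a minimal sketch comparing the reduction modes:

import torch
import torch.nn as nn

target = torch.tensor([[1., 0.], [0., 1.]])
output = torch.tensor([[0., 1.], [0., 1.]])

sq_diff = (output - target) ** 2                     # tensor([[1., 1.], [0., 0.]])
print(sq_diff.sum() / sq_diff.numel())               # tensor(0.5000), the manual mean
print(nn.MSELoss(reduction='mean')(output, target))  # tensor(0.5000), the default
print(nn.MSELoss(reduction='sum')(output, target))   # tensor(2.)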

I do not quite understand the meaning of the last line of your code, though:

grad = loss.grad_fn(torch.tensor(1).float())


Thank you for your response.
Yeah, I suspected it was the number of elements, but the documentation says it is the batch size, so I got confused.

grad = loss.grad_fn(torch.tensor(1).float())

This seems to compute the gradient by calling backward and multiplying by the value I provide (which is 1 in this example).
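
Here is a quick sketch that seems to confirm this, using the public backward() API with an explicit upstream gradient instead of calling grad_fn directly:

import torch
import torch.nn as nn

target = torch.tensor([[1., 0.], [0., 1.]])
output = torch.tensor([[0., 1.], [0., 1.]], requires_grad=True)

loss = nn.MSELoss()(output, target)

# Passing 2 instead of 1 as the upstream gradient doubles the result
loss.backward(torch.tensor(2.0))
print(output.grad)
# tensor([[-1.,  1.],
#         [ 0.,  0.]])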