Confused about the MSE loss implementation

Looking at the implementation of MSE loss: https://pytorch.org/docs/stable/_modules/torch/nn/functional.html#mse_loss

I’m pretty confused about this bit here:

    if target.requires_grad:
        ret = (input - target) ** 2
        if reduction != 'none':
            ret = torch.mean(ret) if reduction == 'mean' else torch.sum(ret)
    else:
        expanded_input, expanded_target = torch.broadcast_tensors(input, target)
        ret = torch._C._nn.mse_loss(expanded_input, expanded_target, _Reduction.get_enum(reduction))

Why exactly is the torch._C mse_loss being used if the tensor doesn’t need a gradient (so, presumably, if .backward() doesn’t need to be called on the loss)?

If so, is it faster, or rather, fast enough to justify the added complexity, even given that the extra broadcast operation is needed?

Note that the different paths are triggered if the target requires gradients, not the model output.
Most likely the backward pass with respect to the target is not yet implemented in the dispatched torch._C path, so in that case the loss is calculated using the Python frontend instead.
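To see both code paths in action, here is a quick sketch (behavior may differ in newer PyTorch versions, where the backend kernel might support target gradients directly):

```python
import torch
import torch.nn.functional as F

input = torch.randn(3, requires_grad=True)

# Usual case: target is plain data, so the call is dispatched
# to the C++ backend kernel.
target = torch.randn(3)
loss = F.mse_loss(input, target)
loss.backward()  # gradient flows into input only

# Target requiring gradients: the Python (input - target) ** 2 path
# is taken, so the gradient w.r.t. the target is also available.
target_g = torch.randn(3, requires_grad=True)
loss = F.mse_loss(input, target_g)
loss.backward()
print(target_g.grad)
```

For the default 'mean' reduction, the gradient w.r.t. the target works out to -2 * (input - target) / n, which is what shows up in target_g.grad above.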

Note that the different paths are triggered if the target requires gradients, not the model output

Ah, true… but why would the targets require gradients?

As in, shouldn’t the gradient be computed on the outputs of the model by comparing them to the targets (in this case via the MSE loss)? What’s the point of having a gradient for the target vector, since it’s already “correct”?

You don’t need the gradients in the usual use case, which is also why they weren’t implemented in the backend in the first place.
However, there still seem to be some use cases, as shown e.g. here.
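For instance, a hypothetical setup (names made up for illustration) where the target itself is produced by a trainable module, so gradients have to flow into it as well:

```python
import torch
import torch.nn.functional as F

# Both modules are being trained; the "target" is not fixed data
# but the output of a second trainable head.
student = torch.nn.Linear(4, 2)
teacher_head = torch.nn.Linear(4, 2)

x = torch.randn(8, 4)
pred = student(x)
target = teacher_head(x)  # requires grad via teacher_head's parameters

loss = F.mse_loss(pred, target)
loss.backward()

# Gradients reached the parameters of both modules.
print(student.weight.grad is not None, teacher_head.weight.grad is not None)
```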


Ah, alright, that makes more sense… it also potentially explains why a version I implemented using a mask seemed significantly slower than the standard MSELoss.
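For reference, a masked variant along those lines might look like this (a rough sketch, not the exact code):

```python
import torch
import torch.nn.functional as F

def masked_mse(input, target, mask):
    # Manual masked MSE in Python: builds several intermediate tensors
    # instead of the single fused backend kernel behind F.mse_loss,
    # which can make it noticeably slower.
    diff = (input - target) ** 2
    return (diff * mask).sum() / mask.sum()

input = torch.randn(10)
target = torch.randn(10)
mask = torch.ones(10)  # an all-ones mask reduces to plain MSE

manual = masked_mse(input, target, mask)
builtin = F.mse_loss(input, target)
print(torch.allclose(manual, builtin))  # should match for a full mask
```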

Thanks a lot for the clarifications 🙂