nn.MSELoss() produces different output than a custom mse_loss function

Hi All,

I am running into the following issue. I calculate the loss with nn.MSELoss() and also with a separate function of my own. When I check a single sample, the two give different results. What could be the problem? (I have commented the line where the issue appears.)

from networks.auto_encoder import AutoEncoder
from layers.encoder import Encoder
from networks.decoder import Decoder
import torch
from dataset.DamadicEvaluatorDataset import DamadicEvaluatorDataset
import torch.nn as nn
from torch.utils.data import DataLoader
from math import sqrt
import matplotlib.pyplot as plt

def mse_loss(inputs, target):
    # sum of squared errors over all elements
    return torch.sum((inputs - target) ** 2)

def get_loss_list(d_p, e_p):
    ae_path = 'models/auto-encoder/auto_encoder_f.h5'
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    encoder = Encoder()
    decoder = Decoder()
    model = AutoEncoder(encoder=encoder, decoder=decoder)
    model.load_state_dict(torch.load(ae_path, map_location=torch.device('cpu')))
    model = model.to(device)
    batch_size = 128
    ds = DamadicEvaluatorDataset(path_list_scaler=d_p, path_list=e_p)
    data_loader = DataLoader(ds, shuffle=True, batch_size=batch_size)

    criterion = nn.MSELoss()
    loss_t = []
    outs = []

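    with torch.no_grad():  # evaluation only, no gradients needed
        for i, data in enumerate(data_loader):
            data = data.float().to(device)  # move the batch to the same device as the model
            out = model(data)
            for index in range(out.size(0)):
                loss = mse_loss(out[index], data[index])
                loss_1 = criterion(out[index], data[index])
                print(loss.item(), loss_1.item())  # these two should be the same, but I get different values
                loss_t.append(loss_1.item())
            outs.append(out.cpu())
    return loss_t, outs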
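For comparison, here is a minimal sketch of the hand-written helper with a `reduction` argument modelled on `nn.MSELoss`. The `reduction` parameter on `mse_loss` is a hypothetical addition, not part of the original code; the sketch checks each mode against `torch.nn.functional.mse_loss` on random tensors.

```python
import torch
import torch.nn.functional as F


def mse_loss(inputs: torch.Tensor, target: torch.Tensor, reduction: str = "mean") -> torch.Tensor:
    """Squared-error loss with a (hypothetical) reduction argument mirroring nn.MSELoss."""
    squared_error = (inputs - target) ** 2
    if reduction == "mean":
        return squared_error.mean()  # sum divided by the number of elements
    if reduction == "sum":
        return squared_error.sum()   # what the original helper computed
    if reduction == "none":
        return squared_error         # per-element losses, no reduction
    raise ValueError(f"unsupported reduction: {reduction!r}")


torch.manual_seed(0)
out = torch.randn(4, 10)
target = torch.randn(4, 10)

for red in ("mean", "sum", "none"):
    ours = mse_loss(out, target, reduction=red)
    ref = F.mse_loss(out, target, reduction=red)
    print(red, torch.allclose(ours, ref))
```

Defaulting to `"mean"` keeps the helper consistent with `nn.MSELoss()`, so swapping one for the other no longer changes the scale of the loss.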



I think the issue is that your mse_loss function reduces the squared errors with a sum, while nn.MSELoss averages them by default.
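Written out for a single sample, with x = out[index], y = data[index] and N = out[index].numel() elements, the two reductions differ only by the factor N:

```latex
\mathcal{L}_{\text{sum}} = \sum_{i=1}^{N} (x_i - y_i)^2,
\qquad
\mathcal{L}_{\text{mean}} = \frac{1}{N} \sum_{i=1}^{N} (x_i - y_i)^2 = \frac{\mathcal{L}_{\text{sum}}}{N}
```

So in the question's print statement, loss / loss_1 should equal N for every sample, up to floating-point rounding.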

nn.MSELoss takes a reduction argument (see the nn.MSELoss documentation) whose default is 'mean', so for one sample its result is your sum divided by the number of elements in out[index]. To get the same value as your function, set reduction='sum' explicitly:

criterion = nn.MSELoss(reduction='sum')
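A quick way to confirm this is a minimal sketch on random stand-in tensors (the shape 128 x 17 is arbitrary and not taken from the DAMADIC dataset): `reduction='sum'` reproduces the custom function, and the default `'mean'` equals that sum divided by the per-sample element count.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
out = torch.randn(128, 17)   # stand-in for a batch of reconstructions (arbitrary shape)
data = torch.randn(128, 17)  # stand-in for the corresponding inputs


def mse_loss(inputs, target):
    # the question's helper: sum of squared errors
    return torch.sum((inputs - target) ** 2)


criterion_mean = nn.MSELoss()               # default reduction='mean'
criterion_sum = nn.MSELoss(reduction="sum")  # matches the helper

index = 0
custom = mse_loss(out[index], data[index])
default = criterion_mean(out[index], data[index])
summed = criterion_sum(out[index], data[index])

n = out[index].numel()  # number of elements in one sample (17 here)
print(custom.item(), summed.item())       # equal up to float rounding
print(custom.item() / n, default.item())  # default is the sum divided by n
```

Alternatively, keep the default criterion and change the helper to `torch.mean(...)`; which one to pick depends on whether the per-sample loss should scale with the number of features.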



Yay. Thanks for your time. That is correct. I fixed the issues.