Issue with Custom Loss

Hello everyone,

I am currently trying to port my existing (working) Keras BNN (Bayesian neural network) code to PyTorch.

To this end, I have to write a custom negative log-likelihood (NLL) loss function for a regression setting. My unit test for the loss passes (i.e., for fixed network weights I get the same results and gradients as in my old, working Keras code), but in a simple dummy example (fitting a sinc function) the loss only gives acceptable results for batch_size == 1: for larger batch sizes the network fails to fit sinc properly, no matter how many training iterations I use. Using nn.MSELoss instead works perfectly fine, so I assume there is an issue with my loss computation; maybe my gradients are off in some subtle way?
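For reference, the per-batch objective my forward pass below is meant to compute is (my own summary of the code, with $\mu_i$ and $\log \sigma_i^2$ the two network outputs, $B$ the batch size, $\theta$ the network weights, and $N$ the total number of training points):

$$
\mathcal{L}(\theta) = -\frac{1}{B} \sum_{i=1}^{B} \left( -\frac{(y_i - \mu_i)^2}{2 \sigma_i^2} - \frac{1}{2} \log \sigma_i^2 \right) - \frac{1}{N} \log p(\log \sigma^2) - \frac{1}{N} \log p(\theta)
$$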

```python
import matplotlib.pyplot as plt
from itertools import islice
try:
    from tqdm import tqdm
except ImportError:
    tqdm = lambda x, total=None: x  # fallback: plain iterator if tqdm is missing
import numpy as np
import torch
from torch.utils import data as data_utils
import torch.nn as nn


class NLLLoss(torch.nn.modules.loss._Loss):

    def __init__(self, parameters, num_datapoints, size_average=False, reduce=True):
        super().__init__(size_average, reduce)
        # Materialize the parameter generator; stored under a name that does
        # not shadow the inherited nn.Module.parameters() method.
        self.model_parameters = tuple(parameters)
        self.num_datapoints = num_datapoints

    def log_variance_prior(self, log_variance, mean=1e-6, variance=0.01):
        # Gaussian log-prior (up to an additive constant) on the predicted
        # log-variance, centered at log(mean) with the given variance.
        return torch.mean(
            torch.sum(
                ((-(log_variance - torch.log(torch.tensor(mean))) ** 2) /
                 ((2. * variance))) - 0.5 * torch.log(torch.tensor(variance)),
                dim=1
            )
        )

    def weight_prior(self, parameters, wdecay=1.):
        # Gaussian (weight decay) log-prior over all network weights,
        # normalized by the total number of parameters.
        # (See my note below the listing: I am not sure the torch.tensor(...)
        # calls here keep the computation graph intact.)
        num_parameters = torch.sum(torch.tensor([
            torch.prod(torch.tensor(parameter.size()))
            for parameter in parameters
        ]))

        log_likelihood = torch.sum(torch.tensor([
            torch.sum(-wdecay * 0.5 * (parameter ** 2))
            for parameter in parameters
        ]))

        return log_likelihood / (num_parameters.float() + 1e-16)

    def forward(self, input, target):
        torch.nn.modules.loss._assert_no_grad(target)

        batch_size, *_ = input.shape
        # The network outputs two columns per datapoint: column 0 is the
        # predictive mean, column 1 the predictive log-variance.
        prediction_mean = input[:, 0].view(-1, 1)
        log_prediction_variance = input[:, 1].view(-1, 1)
        prediction_variance_inverse = 1. / (torch.exp(log_prediction_variance) + 1e-16)

        mean_squared_error = torch.pow(target - prediction_mean, 2)

        # Gaussian log-likelihood of the targets, averaged over the batch.
        log_likelihood = torch.sum(
            -mean_squared_error * 0.5 * prediction_variance_inverse -
            0.5 * log_prediction_variance
        )

        log_likelihood /= batch_size

        log_likelihood += (
            self.log_variance_prior(log_prediction_variance) / self.num_datapoints
        )

        log_likelihood += self.weight_prior(self.model_parameters) / self.num_datapoints

        return -log_likelihood


#  Helper Functions {{{ #

def infinite_dataloader(dataloader):
    while True:
        yield from dataloader


def tanh_network(input_dimensionality: int):
    class AppendLayer(nn.Module):
        def __init__(self, bias=True, *args, **kwargs):
            super().__init__(*args, **kwargs)
            if bias:
                self.bias = nn.Parameter(torch.Tensor(1, 1))
            else:
                self.register_parameter('bias', None)

        def forward(self, x):
            # Append the learned bias as a second output column: a single
            # log-variance estimate shared across all inputs.
            return torch.cat((x, self.bias * torch.ones_like(x)), dim=1)

    def init_weights(module):
        if isinstance(module, AppendLayer):
            # Initial log-variance corresponding to a variance of 1e-3.
            nn.init.constant_(module.bias, val=np.log(1e-3))
        elif isinstance(module, nn.Linear):
            nn.init.kaiming_normal_(module.weight, mode="fan_out", nonlinearity="linear")
            nn.init.constant_(module.bias, val=0.0)

    return nn.Sequential(
        nn.Linear(input_dimensionality, 50), nn.Tanh(),
        nn.Linear(50, 50), nn.Tanh(),
        nn.Linear(50, 50), nn.Tanh(),
        nn.Linear(50, 1),
        AppendLayer()
    ).apply(init_weights)
#  }}} Helper Functions #


input_dimensionality, num_datapoints = 1, 100
num_train_steps = 13000

# Set up data
x_train = np.array([
    np.random.uniform(np.zeros(1), np.ones(1), input_dimensionality)
    for _ in range(num_datapoints)
])
y_train = np.sinc(x_train * 10 - 5).sum(axis=1)

# Data normalization (zero mean, unit variance)
x_mean, x_std = np.mean(x_train), np.std(x_train)
x_train_ = (x_train - x_mean) / x_std

y_mean, y_std = np.mean(y_train), np.std(y_train)
y_train_ = (y_train - y_mean) / y_std

model = tanh_network(input_dimensionality=input_dimensionality)

# TODO Why does setting batch_size to 1 work with NLL, but setting it to higher values fails?
batch_size = 20  # setting this to 1 gives okay results.
loss_function = NLLLoss(model.parameters(), num_datapoints=num_datapoints)

# NOTE: Using MSE like this also works:
# loss_function = lambda input, target: nn.MSELoss()(input=input[:, 0], target=target)

train_loader = infinite_dataloader(
    data_utils.DataLoader(
        data_utils.TensorDataset(
            torch.from_numpy(x_train_).float(),
            torch.from_numpy(y_train_).float()
        ), batch_size=batch_size
    )
)

optimizer = torch.optim.Adam(model.parameters())

# Train loop
for step, (x_batch, y_batch) in tqdm(enumerate(islice(train_loader, num_train_steps)), total=num_train_steps):
    optimizer.zero_grad()
    y_pred = model(x_batch)
    loss = loss_function(input=y_pred, target=y_batch)
    loss.backward()
    optimizer.step()

    if step % 100 == 0:
        mse_value = nn.MSELoss()(input=y_pred[:, 0], target=y_batch)
        print("Step: {}, Loss: {}, MSE: {}".format(step, loss.item(), mse_value.item()))

x_test = np.linspace(0, 1, 100)[:, None]
y_test = np.sinc(x_test * 10 - 5).sum(axis=1)

# Data Normalization
x_test_ = np.true_divide(x_test - x_mean, x_std)
x_test_torch = torch.from_numpy(x_test_).float()

# Unnormalize predictions
y_pred = model(x_test_torch).detach().numpy() * y_std + y_mean

plt.plot(x_test[:, 0], y_test, label="true", color="black")
plt.plot(x_train[:, 0], y_train, "ro")

plt.plot(x_test[:, 0], y_pred[:, 0], label="Adam", color="blue")
plt.legend()
plt.show()
```
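One detail I am unsure about: weight_prior (and the num_parameters computation) builds torch.tensor(...) from a list of tensors. If I understand correctly, that copies the values and detaches them from the autograd graph, so the weight prior would contribute no gradient at all. A minimal standalone check of my assumption (not part of my training code):

```python
import torch

w = torch.randn(3, requires_grad=True)

# torch.tensor(...) copies the summed values -> the graph to w is cut:
detached = torch.sum(torch.tensor([torch.sum(-0.5 * w ** 2)]))
print(detached.requires_grad)  # False

# torch.stack(...) keeps the summands connected to w:
attached = torch.sum(torch.stack([torch.sum(-0.5 * w ** 2)]))
print(attached.requires_grad)  # True
```

If that is right, torch.stack(...) would be the fix there, although since this term is identical for every batch size, I do not see how it alone could explain the batch_size dependence.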

I am new to defining custom losses in PyTorch, so any help from more experienced PyTorch users would be much appreciated!

Best,
MFreidank
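
EDIT: One more observation while re-reading my forward implementation (I am not sure whether this matters): target comes in with shape (batch_size,), while prediction_mean has shape (batch_size, 1). If I understand broadcasting correctly, target - prediction_mean then has shape (batch_size, batch_size) rather than (batch_size, 1), and the two only coincide for batch_size == 1. A standalone check of the shapes (hypothetical values, not my training data):

```python
import torch

batch_size = 20
target = torch.randn(batch_size)              # shape: (20,)
prediction_mean = torch.randn(batch_size, 1)  # shape: (20, 1)

# (20,) is broadcast against (20, 1), yielding a (20, 20) matrix of all
# pairwise differences instead of per-datapoint residuals:
print((target - prediction_mean).shape)  # torch.Size([20, 20])

# Reshaping the target makes the subtraction elementwise again:
print((target.view(-1, 1) - prediction_mean).shape)  # torch.Size([20, 1])
```

Could this be the reason the loss only behaves for batch_size == 1?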