Custom loss function not working at all!

Hello, I'm using this loss function:

def loss_eq5(p, alpha, K, glob, ann):
    # Eq. 5 of the paper: expected MSE under Dirichlet(alpha), plus an annealed KL regularizer
    S = torch.sum(alpha, 1, keepdim=True)
    loglikelihood = torch.sum((p - (alpha / S)) ** 2, 1, keepdim=True) \
                    + torch.sum(alpha * (S - alpha) / (S * S * (S + 1)), 1, keepdim=True)
    # min(1, epoch / annealing_step) ramps the regularizer in gradually
    KL_reg = torch.min(torch.tensor(1.0), glob / ann) * KL((alpha - 1) * (1 - p) + 1, K)
    return loglikelihood + KL_reg
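A quick smoke test I use (my own, not from the paper) to check that shapes and dtypes line up — the function should return one value per sample, and p must be float like alpha. It assumes the KL helper from the full code below is already defined:

p = torch.eye(10)[:4]                  # 4 one-hot targets, float32
alpha = torch.rand(4, 10) + 1          # 4 Dirichlet parameter vectors
out = loss_eq5(p, alpha, 10, torch.tensor(1.0), torch.tensor(550.0))
print(out.shape)                       # torch.Size([4, 1])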

to train on MNIST in PyTorch. It comes from a paper that does the same thing in TensorFlow, and I am translating it to PyTorch. The thing is, in the paper it works great, and plain cross-entropy on MNIST works great too! I've inspected the output of every single layer and I can't find the error. The training code is this:

def train(epoch):
    global cur_batch_win
    net.train()
    loss_list, batch_list = [], []
    for i, (images, labels) in enumerate(data_train_loader):
        optimizer.zero_grad()

        # the net returns the intermediate activations as well as the logits
        out1, out2, out3, output = net(images)

        # evidence = non-negative part of the logits; alpha = Dirichlet parameters
        evidence = torch.relu(output)
        alpha = evidence + 1

        # one-hot targets, built in torch so the dtype matches alpha (float32)
        p = torch.zeros(images.shape[0], K)
        p[torch.arange(images.shape[0]), labels] = 1.0

        loss = torch.mean(loss_eq5(p, alpha, K, torch.tensor(float(epoch)), torch.tensor(550.0)))

        loss_list.append(loss.detach().cpu().item())
        batch_list.append(i + 1)

        if i % 10 == 0:
            print('Train - Epoch %d, Batch: %d, Loss: %f' % (epoch, i, loss.detach().cpu().item()))

        loss.backward()
        optimizer.step()
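One debugging check that helped me (generic PyTorch, not from the paper): right after loss.backward(), every parameter the optimizer is stepping should have a gradient. If the optimizer was built from a different model instance than the one producing the loss, these stay None:

for group in optimizer.param_groups:
    for param in group['params']:
        assert param.grad is not None, "optimizer is bound to parameters that get no gradient"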

It is probably a bit messy, but I started working with neural nets and programming in Python just a few months ago.

All of this comes from the following poster, which measures uncertainty in a network via evidence:

https://github.com/atilberk/evidential-deep-learning-to-quantify-classification-uncertainty/blob/master/poster.pdf

The full runnable code:

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from collections import OrderedDict
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
import torchvision.transforms as transforms


### LOSS:
def loss_eq5(p, alpha, K, glob, ann):
    # Eq. 5 of the paper: expected MSE under Dirichlet(alpha), plus an annealed KL regularizer
    S = torch.sum(alpha, 1, keepdim=True)
    loglikelihood = torch.sum((p - (alpha / S)) ** 2, 1, keepdim=True) \
                    + torch.sum(alpha * (S - alpha) / (S * S * (S + 1)), 1, keepdim=True)
    KL_reg = torch.min(torch.tensor(1.0), glob / ann) * KL((alpha - 1) * (1 - p) + 1, K)
    return loglikelihood + KL_reg

def KL(alpha, K):
    # KL divergence between Dirichlet(alpha) and the uniform Dirichlet(1, ..., 1)
    beta = torch.ones((1, K))
    S_alpha = torch.sum(alpha, 1, keepdim=True)

    KL = torch.sum((alpha - beta) * (torch.digamma(alpha) - torch.digamma(S_alpha)), 1, keepdim=True) \
         + torch.lgamma(S_alpha) - torch.sum(torch.lgamma(alpha), 1, keepdim=True) \
         + torch.sum(torch.lgamma(beta), 1, keepdim=True) - torch.lgamma(torch.sum(beta, 1, keepdim=True))
    return KL
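A quick sanity check on the KL term (my own check, not from the paper): the KL divergence between two identical Dirichlet distributions is zero, so alpha = all ones should return (numerically) zero, and any other alpha should give a positive value:

print(KL(torch.ones((1, 10)), 10))        # ~0.0
print(KL(torch.full((1, 10), 5.0), 10))   # > 0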

### CONVNET:
class C1(nn.Module):
    def __init__(self):
        super(C1, self).__init__()

        self.c1 = nn.Sequential(OrderedDict([
            ('c1', nn.Conv2d(1, 20, kernel_size=(5, 5))),
            ('relu1', nn.ReLU()),
            ('s1', nn.MaxPool2d(kernel_size=(2, 2), stride=2))
        ]))

    def forward(self, img):
        output = self.c1(img)
        return output


class C2(nn.Module):
    def __init__(self):
        super(C2, self).__init__()

        self.c2 = nn.Sequential(OrderedDict([
            ('c2', nn.Conv2d(20, 50, kernel_size=(5, 5))),
            ('relu2', nn.ReLU()),
            ('s2', nn.MaxPool2d(kernel_size=(2, 2), stride=2))
        ]))

    def forward(self, img):
        output = self.c2(img)
        return output

class F4(nn.Module):
    def __init__(self):
        super(F4, self).__init__()

        self.f4 = nn.Sequential(OrderedDict([
            ('f4', nn.Linear(1250, 500)),
            ('relu4', nn.ReLU())
        ]))

    def forward(self, img):
        output = self.f4(img)
        return output


class F5(nn.Module):
    def __init__(self):
        super(F5, self).__init__()

        self.f5 = nn.Sequential(OrderedDict([
            ('f5', nn.Linear(500, 10))
        ]))

    def forward(self, img):
        output = self.f5(img)
        return output


class LeNet5(nn.Module):
    """
    Input - 1x32x32
    Output - 10
    """
    def __init__(self):
        super(LeNet5, self).__init__()

        self.c1 = C1()
        self.c2 = C2() 
        self.f4 = F4() 
        self.f5 = F5() 

    def forward(self, img):
        out1 = self.c1(img)
        out2 = self.c2(out1)

        out22 = out2.view(out2.size(0), -1)

        out3 = self.f4(out22)
        output = self.f5(out3)
        return out1, out2, out3, output
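To verify that the flattened feature size really is 1250 (50 channels × 5 × 5 after the two conv/pool stages on a 32×32 input), a quick forward pass with a dummy batch works — a minimal check of my own:

check = LeNet5()
_, out2, _, output = check(torch.randn(2, 1, 32, 32))
print(out2.shape)    # torch.Size([2, 50, 5, 5])  -> flattens to 1250
print(output.shape)  # torch.Size([2, 10])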

### DATA LOADING:

!wget www.di.ens.fr/~lelarge/MNIST.tar.gz
!tar -zxvf MNIST.tar.gz

data_train = MNIST('./', download=True, train=True,
                   transform=transforms.Compose([transforms.Resize((32, 32)), transforms.ToTensor()]))
data_test = MNIST('./', download=True, train=False,
                  transform=transforms.Compose([transforms.Resize((32, 32)), transforms.ToTensor()]))

data_train_loader = DataLoader(data_train, batch_size=256, shuffle=True, num_workers=8)
data_test_loader = DataLoader(data_test, batch_size=1024, num_workers=8)
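A small check of my own that the loaders yield what the loss expects (float images, integer class labels):

images, labels = next(iter(data_train_loader))
print(images.shape, labels.shape, labels.dtype)  # torch.Size([256, 1, 32, 32]) torch.Size([256]) torch.int64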

K = 10
net = LeNet5()
criterion = nn.CrossEntropyLoss()  # unused with loss_eq5; left over from the cross-entropy baseline
# Build the optimizer from the net that is actually trained; constructing a second
# LeNet5 after this line would leave the optimizer updating the wrong parameters.
optimizer = optim.Adam(net.parameters(), lr=2e-3)

cur_batch_win = None
cur_batch_win_opts = {
    'title': 'Epoch Loss Trace',
    'xlabel': 'Batch Number',
    'ylabel': 'Loss',
    'width': 1200,
    'height': 600,
}


def train(epoch):
    global cur_batch_win
    net.train()
    loss_list, batch_list = [], []
    for i, (images, labels) in enumerate(data_train_loader):
        optimizer.zero_grad()

        # the net returns the intermediate activations as well as the logits
        out1, out2, out3, output = net(images)

        # evidence = non-negative part of the logits; alpha = Dirichlet parameters
        evidence = torch.relu(output)
        alpha = evidence + 1

        # one-hot targets, built in torch so the dtype matches alpha (float32)
        p = torch.zeros(images.shape[0], K)
        p[torch.arange(images.shape[0]), labels] = 1.0

        loss = torch.mean(loss_eq5(p, alpha, K, torch.tensor(float(epoch)), torch.tensor(550.0)))

        loss_list.append(loss.detach().cpu().item())
        batch_list.append(i + 1)

        if i % 10 == 0:
            print('Train - Epoch %d, Batch: %d, Loss: %f' % (epoch, i, loss.detach().cpu().item()))

        loss.backward()
        optimizer.step()

train(1)
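Once training behaves, the poster's uncertainty measure falls out of the same quantities the loss already uses: with S = sum(alpha), the expected class probabilities are alpha / S and the total uncertainty is u = K / S. A minimal evaluation sketch under those definitions:

net.eval()
correct, total, unc_sum = 0, 0, 0.0
with torch.no_grad():
    for images, labels in data_test_loader:
        _, _, _, output = net(images)
        alpha = torch.relu(output) + 1
        S = torch.sum(alpha, 1, keepdim=True)
        prob = alpha / S                      # expected class probabilities
        u = K / S                             # total uncertainty, in (0, 1]
        correct += (prob.argmax(1) == labels).sum().item()
        unc_sum += u.sum().item()
        total += labels.size(0)
print('Test accuracy: %.4f, mean uncertainty: %.4f' % (correct / total, unc_sum / total))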