I am trying to apply the Variational Auto-Encoder (see https://arxiv.org/abs/1312.6114) to the MNIST data set. I am following exactly the paper with the Neural Network architecture and all other parameters and formulas. Unfortunately, after one step, I get an explosion of the loss (the loss is the negative ELBO term here) and an explosion of the gradients. A change of the learning rate, mini batch size and network structures did not help to fix the problem. I wonder what I can further do to avoid this behaviour. Could you help me? I really appreciate every support.
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset, TensorDataset
from torch.optim import *
import torchvision
import matplotlib.pyplot as plt
import numpy as np
import sys
#Load the data:
dl = DataLoader(torchvision.datasets.MNIST('/data/mnist', train=True, download=True))
tensor = dl.dataset.data
tensor = tensor.to(dtype=torch.float32)
tr = tensor.reshape(tensor.size(0), -1)
tr = tr/255
targets = dl.dataset.targets
targets = targets.to(dtype=torch.long)
#This will be the data which I use:
Traindata = tr[0:50000]
Labels = targets[0:50000]
#For binary VAEs, we use only binary data:
BTraindata= torch.round(Traindata)
#Define the global variables:
minibatch_size=100
latent_dim=2
dim_inp_dec=latent_dim
dim_hidden_dec=100
dim_out_dec=28*28
dim_inp_enc=28*28
dim_hidden_enc=100
train_size=Traindata.size()[0]
#Dataloader:
trainloader=torch.utils.data.DataLoader(Traindata, batch_size=minibatch_size, shuffle=True)
#This is the Variational Auto Encoder class in the case of binary data:
class BVAE(nn.Module):
def __init__(self,latent_dim):
super(BVAE, self).__init__()
#Define the decoder structural parts:
self.dec1=nn.Linear(dim_inp_dec,dim_hidden_dec,bias=True)
self.dec2=nn.Linear(dim_hidden_dec,dim_out_dec,bias=True)
self.decactiv1=nn.Tanh()
self.decactiv2=nn.Sigmoid()
#Define the encoder structural parts:
self.enc1=nn.Linear(dim_inp_enc,dim_hidden_enc,bias=True)
self.enc21=nn.Linear(dim_hidden_enc,latent_dim,bias=True)
self.enc22=nn.Linear(dim_hidden_enc,latent_dim,bias=True)
self.encactiv=nn.Tanh()
self.latent_dim=latent_dim
def encode(self,x):
mean=self.enc21(self.encactiv(self.enc1(x)))
log_var=self.enc22(self.encactiv(self.enc1(x)))
return mean,log_var
def decode(self,x):
x_one=self.dec1(x)
dec_hidden=self.decactiv1(x_one)
output=self.decactiv2(self.dec2(dec_hidden))
return output
def forward(self,x,normsamples):
mean,log_var=self.encode(x)
var=torch.exp(log_var)
sample=mean+torch.exp(log_var)*normsamples
return(self.decode(sample))
#Instead of sampling after the encoder network, we can also replace it by its mean (i.e. ignoring the variance):
def detforward(self,x):
mean=self.encode(x)[0]
return(self.decode(mean))
#The loss of a VAE is the negative ELBO term:
def loss(self,mini_vector,latent_samples):
L=latent_samples.size()[1]
M=mini_vector.size()[0]
if (M!=latent_samples.size()[0]):
sys.exit("Sample vector size does not fit mini_vector size")
ELBO=torch.tensor(0.0,dtype=torch.float)
for i in range(M):
mean,log_var = self.encode(mini_vector[i])
Entropy=0.5*torch.sum(torch.ones(latent_dim)+log_var-torch.pow(mean,2)-torch.exp(log_var))
sum=torch.tensor(0.0,dtype=torch.float)
for l in range(L):
Bern_par=self.decode(mean+torch.sqrt(torch.exp(log_var))*latent_samples[i][l])
sum+=torch.sum(mini_vector[i]*torch.log(Bern_par)+(1-mini_vector[i])*torch.log(1-Bern_par))
#The second term of the Lower bound is:
Recon_term=sum/L
#Add the two terms:
ELBO+=Entropy+Recon_term
ELBO=(train_size/M)*ELBO
return (-ELBO)
#We now define the train function. nsamples is the number of samples per data point (nsamples_per_point):
def train(self,data_loader,learning_rate=0.001,epochs=10,minibatch_size=100,nsamples=1):
for i,data in enumerate(data_loader):
samples=torch.randn(minibatch_size,nsamples,self.latent_dim)
print("-------------Epoch ", i,"---------------")
#Set gradient to zero:
self.zero_grad()
#Compute loss:
loss = self.loss(data,samples)
print("Epoch: ",i," Loss: ",loss.item())
loss.backward()
for p in self.parameters():
with torch.no_grad():
print("Norm of gradient:", torch.sum(p.grad.data**2))
p.data-=learning_rate*p.grad.data
p.grad.zero_()
autoencoder=BVAE(latent_dim)
autoencoder.train(trainloader)
This is the output I get:
-------------Epoch 0 ---------------
Epoch: 0 Loss: 27568040.0
Norm of gradient: tensor(8.4638e+09)
Norm of gradient: tensor(6.8262e+10)
Norm of gradient: tensor(3.3651e+12)
Norm of gradient: tensor(3.2536e+11)
Norm of gradient: tensor(3.1535e+11)
Norm of gradient: tensor(8.2232e+09)
Norm of gradient: tensor(2.7858e+10)
Norm of gradient: tensor(1.4739e+10)
Norm of gradient: tensor(1.4373e+10)
Norm of gradient: tensor(8.9952e+09)
-------------Epoch 1 ---------------
Epoch: 1 Loss: nan
Norm of gradient: tensor(nan)
Norm of gradient: tensor(nan)
Norm of gradient: tensor(nan)
Norm of gradient: tensor(nan)
Norm of gradient: tensor(nan)
Norm of gradient: tensor(nan)
Norm of gradient: tensor(nan)
Norm of gradient: tensor(nan)
Norm of gradient: tensor(nan)
Norm of gradient: tensor(nan)
-------------Epoch 2 ---------------
Epoch: 2 Loss: nan
Norm of gradient: tensor(nan)
Norm of gradient: tensor(nan)
Norm of gradient: tensor(nan)
Norm of gradient: tensor(nan)
Norm of gradient: tensor(nan)
Norm of gradient: tensor(nan)
Norm of gradient: tensor(nan)
Norm of gradient: tensor(nan)
Norm of gradient: tensor(nan)
Norm of gradient: tensor(nan)
-------------Epoch 3 ---------------
Epoch: 3 Loss: nan
Norm of gradient: tensor(nan)
Norm of gradient: tensor(nan)
Norm of gradient: tensor(nan)
Norm of gradient: tensor(nan)
Norm of gradient: tensor(nan)
Norm of gradient: tensor(nan)
Norm of gradient: tensor(nan)
Norm of gradient: tensor(nan)
Norm of gradient: tensor(nan)
Norm of gradient: tensor(nan)