I am trying to implement a variational autoencoder, but calculating the Kullback Leibler divergence doesn’t work out the way I hoped.
Background: The input is a 1x800 tensor, which will be mapped to a 1x1200 tensor.
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch.utils.data import DataLoader, TensorDataset, IterableDataset
import os
from torch import optim
from model import Network
class Network(nn.Module):
def __init__(self, input_dim, output_dim, latent_dim, layer_dim1, layer_dim2):
"""
Parameter:
input_dim (int): number of inputs
output_dim (int): number of outputs
latent_dim (int): number of latent neurons
Layer_dim (int): number of neurons in hidden layers
"""
super(Network, self).__init__()
self.latent_dim = latent_dim
self.enc1 = nn.Linear(input_dim, layer_dim1)
self.enc2 = nn.Linear(layer_dim1, layer_dim2)
self.latent = nn.Linear(layer_dim2, latent_dim*2)
self.dec1 = nn.Linear(latent_dim, layer_dim2)
self.dec2 = nn.Linear(layer_dim2, layer_dim1)
self.out = nn.Linear(layer_dim1, output_dim)
def encoder(self, x):
z = F.elu(self.enc1(x))
z = F.elu(self.enc2(x))
z = self.latent(z)
self.mu = z[0:self.latent_dim]
self.log_sigma = z[self.latent_dim:]
self.sigma = torch.exp(self.log_sigma)
eps = torch.randn(x.size(0), self.latent_dim)
z_sample = self.mu + self.sigma*eps
self.kl_loss = kl_divergence(self.mu, self.log_sigma, dim=self.latent_dim)
return z_sample
def decoder(self, z):
x = F.elu(self.dec1(z))
x = F.elu(self.dec2(x))
return self.out(x)
def forward(self, batch):
self.latent_rep = self.encoder(batch)
dec_input = self.latent_rep
return self.decoder(dec_input)
def kl_divergence(means, log_sigma, dim, target_sigma=0.1):
"""
Computes Kullback–Leibler divergence for arrays of mean and log(sigma)
"""
target_sigma = torch.Tensor([target_sigma])
out = 1 / 2. * torch.mean(torch.mean(1 / target_sigma**2 * means**2 + torch.exp(2 * log_sigma) / target_sigma**2 - 2 * log_sigma + 2 * torch.log(target_sigma), dim=1) - dim)
out = out
return out
model = Network(800,1200,3,800,200)
SAVE_PATH = "trained/model.dat"
epochs = 5
learning_rate = 0.001
optimizer = optim.Adam(model.parameters(),lr=learning_rate, eps=1e-08)
hist_error = []
hist_loss = []
beta = 0.5
for epoch in range(epochs):
epoch_error = []
epoch_loss = []
for i in x:
optimizer.zero_grad()
print(i)
#i = torch.tensor(i)
#i.unsqueeze_(0)
pred = model.forward(i)
loss = torch.mean(torch.sum((pred - y[i]) ** 2))
loss.backward()
optimizer.step()
error = torch.mean(torch.sqrt((pred - y[i]) ** 2)).detach().numpy()
epoch_error.append(error)
epoch_loss.append(loss.data.detach().numpy())
hist_error.append(np.mean(epoch_error))
hist_loss.append(np.mean(epoch_loss))
print("Epoch %d -- loss %f, RMS error %f " % (epoch+1, hist_loss[-1], hist_error[-1]))
The error message I now get is something I looked into last night, but could not resolve:
Traceback (most recent call last):
File "/home/samim/miniconda3/envs/deep/lib/python3.6/site-packages/IPython/core/interactiveshell.py", line 3343, in run_code
exec(code_obj, self.user_global_ns, self.user_ns)
File "<ipython-input-2-54864ad18480>", line 1, in <module>
runfile('/home/samim/Documents/train.py', wdir='/home/samim/Documents')
File "/home/samim/.local/share/JetBrains/PyCharm2020.3/python/helpers/pydev/_pydev_bundle/pydev_umd.py", line 197, in runfile
pydev_imports.execfile(filename, global_vars, local_vars) # execute the script
File "/home/samim/.local/share/JetBrains/PyCharm2020.3/python/helpers/pydev/_pydev_imps/_pydev_execfile.py", line 18, in execfile
exec(compile(contents+"\n", file, 'exec'), glob, loc)
File "/home/samim/Documents/train.py", line 65, in <module>
pred = model.forward(i)
File "/home/samim/Documents/model.py", line 49, in forward
self.latent_rep = self.encoder(batch)
File "/home/samim/Documents/model.py", line 39, in encoder
self.kl_loss = kl_divergence(self.mu, self.log_sigma, dim=self.latent_dim)
File "/home/samim/Documents/model.py", line 59, in kl_divergence
out = 1 / 2. * torch.mean(torch.mean(1 / target_sigma**2 * means**2 + torch.exp(2 * log_sigma) / target_sigma**2 - 2 * log_sigma + 2 * torch.log(target_sigma), dim=1) - dim)
IndexError: Dimension out of range (expected to be in range of [-1, 0], but got 1)