Hi All
Recently I trained a convolutional beta-VAE, but it does not work at all: I can no longer reconstruct my images. Before that I had trained a classical AE, and that one worked perfectly.
To check the network, I trained the beta-VAE with BETA = 0 and with BETA = 1. With BETA = 0 the reconstruction improved a lot (still not as good as the classical AE), but with BETA = 1 (a standard VAE) the reconstruction again did not work.
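For reference, the objective is the usual beta-VAE loss, i.e. the reconstruction error plus a beta-weighted KL divergence between the approximate posterior N(mu, sigma^2) and the standard normal prior (this matches the loss function in my code below):

    loss = MSE(x_recon, x) + BETA * KL,  with  KL = -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)

so BETA = 0 is essentially a plain autoencoder with a sampled latent, and BETA = 1 is a standard VAE.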
About my data set: I am working with hyperspectral images (every single pixel is a 1-D signal), so the convolutional layers I used are Conv1d.
My training set is ~45,000 pixels. I am a little suspicious that my training set may not be big enough, because if there were a problem with the network itself I would expect to see bad results with BETA = 0 as well, and I did not.
Any comments or similar experiences?
Below you can see my code:
import torch
import torch.nn as nn
import matplotlib.pyplot as plt

class B_VAE(nn.Module):
    def __init__(self, zdims):
        super().__init__()
        self.zdims = zdims
        # Encoder layers
        self.encoder = nn.Sequential(
            nn.Conv1d(1, 5, kernel_size=4, stride=3),
            nn.MaxPool1d(kernel_size=4, stride=3),
            nn.Tanh(),
            nn.Conv1d(5, 10, kernel_size=4, stride=2),
            nn.MaxPool1d(kernel_size=4, stride=2),
            nn.Tanh(),
            nn.Conv1d(10, 15, kernel_size=4, stride=2),
            nn.MaxPool1d(kernel_size=4, stride=1),
            nn.Tanh(),
            nn.Conv1d(15, 20, kernel_size=4, stride=1),
            nn.MaxPool1d(kernel_size=3, stride=1),
            nn.Tanh()
        )
        # Decoder layers; the final Tanh bounds reconstructions to [-1, 1],
        # so the input spectra should be scaled to that range for the MSE loss
        self.decoder = nn.Sequential(
            nn.ConvTranspose1d(20, 15, kernel_size=5, stride=2),
            nn.Tanh(),
            nn.ConvTranspose1d(15, 10, kernel_size=5, stride=3),
            nn.Tanh(),
            nn.ConvTranspose1d(10, 5, kernel_size=12, stride=4),
            nn.Tanh(),
            nn.ConvTranspose1d(5, 1, kernel_size=18, stride=9),
            nn.Tanh()
        )
        # FC layers mapping the flattened encoder output
        # (20 channels x length 2 = 40) to the latent mean and log-variance
        self.fc_mu = nn.Linear(40, self.zdims)
        self.fc_std = nn.Linear(40, self.zdims)
        # FC layer mapping the latent code back to the decoder input
        self.fc_d = nn.Linear(self.zdims, 40)

    def reparameterization_trick(self, mu, logvar):
        std = torch.exp(logvar * 0.5)
        # sample epsilon from N(0, 1)
        eps = torch.randn_like(std)
        # sampling can now be done by scaling eps by the standard deviation
        # and shifting it by (adding) the mean
        z = eps.mul(std).add(mu)
        return z

    def encode(self, imgs):
        output = self.encoder(imgs)
        output = output.view(-1, 40)
        mu = self.fc_mu(output)
        logvar = self.fc_std(output)
        z = self.reparameterization_trick(mu, logvar)
        return mu, logvar, z

    def decode(self, z):
        deconv_input = self.fc_d(z)
        deconv_input = deconv_input.view(-1, 20, 2)
        reconstructed_img = self.decoder(deconv_input)
        return reconstructed_img

    def forward(self, x):
        mu, logvar, z = self.encode(x)
        reconstructed_img = self.decode(z)
        return reconstructed_img, mu, logvar
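For reference, the layer arithmetic: with these kernel/stride choices the encoder output flattens to 20 channels x length 2 = 40, and the decoder maps a latent back to a length-909 signal, so the input spectra need to be 909 bands long for input and reconstruction to line up in the MSE. A quick shape check with a dummy batch (zdims=10 is just an arbitrary placeholder):

    model = B_VAE(zdims=10)
    dummy = torch.randn(4, 1, 909)   # (batch, channels, spectral bands)
    recon, mu, logvar = model(dummy)
    print(recon.shape)               # torch.Size([4, 1, 909]) -- matches the input
    print(mu.shape, logvar.shape)    # torch.Size([4, 10]) for both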
def loss_disentangled_vae(outputs, imgs, mu, logvar, Beta):
    criterion = nn.MSELoss()
    recons_loss = criterion(outputs, imgs)
    # KL divergence between q(z|x) = N(mu, sigma^2) and the prior N(0, 1)
    kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return recons_loss + Beta * kl
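One thing I am not sure about: nn.MSELoss defaults to reduction='mean' (it averages over every element), while my KL term is summed over the whole batch and all latent dimensions, so the two terms are on very different scales and BETA = 1 may weight the KL far more heavily than intended. A per-sample version that puts both terms on the same footing would look roughly like this (a sketch, not the loss I trained with):

    def loss_disentangled_vae_per_sample(outputs, imgs, mu, logvar, Beta):
        # reconstruction: sum squared error over each signal, average over the batch
        recons_loss = ((outputs - imgs) ** 2).sum(dim=(1, 2)).mean()
        # KL: sum over latent dimensions, average over the batch
        kl = (-0.5 * (1 + logvar - mu.pow(2) - logvar.exp()).sum(dim=1)).mean()
        return recons_loss + Beta * kl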
My training loop is here as well:
### training loop; StepLR decays the learning rate by a factor of 0.7 every 280 epochs
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=280, gamma=0.7, last_epoch=-1)
train_error = []
test_error = []
best_error = float('inf')  # running best test loss
for epoch in range(num_epochs):
    print(scheduler.get_last_lr())
    loss_total = 0
    test_loss_total = 0
    model.train()
    for batch_idx, sample in enumerate(train_loader):
        inp = sample
        output, mu, logvar = model(inp)
        loss = loss_disentangled_vae(output, inp, mu, logvar, Beta=Beta)
        loss_total += loss.item()  # .item() so the graph is not kept alive across batches
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    scheduler.step()  # step once per epoch, not per batch
    loss_total = loss_total / (batch_idx + 1)
    train_error.append(loss_total)
    model.eval()
    with torch.no_grad():
        for batch_idx, sample in enumerate(test_loader):
            inp = sample
            output, mu, logvar = model(inp)
            loss = loss_disentangled_vae(output, inp, mu, logvar, Beta=Beta)
            test_loss_total += loss.item()
        test_loss_total = test_loss_total / (batch_idx + 1)
        # compare the epoch-average test loss, not the last batch's loss
        if test_loss_total < best_error:
            best_error = test_loss_total
            best_epoch = epoch
            print('Best loss at epoch', best_epoch)
            model_save_name = '2020-07-20 Li_sample_no_image_processing B=1 Tanh 20 channel (B-VAE)'
            path = f"/content/drive/My Drive/{model_save_name}"
            torch.save(model.state_dict(), path)
        test_error.append(test_loss_total)
    if epoch % 10 == 9:
        print('\r Train Epoch : {}/{} \tLoss : {:.4f}'.format(epoch + 1, num_epochs, loss_total))
        print('\r Test Epoch : {}/{} \tLoss : {:.4f}'.format(epoch + 1, num_epochs, test_loss_total))
print('best loss', best_error, 'at epoch', best_epoch)
plt.plot(train_error, label='train')
plt.plot(test_error, label='test')
plt.xlabel('Epoch')
plt.ylabel('Loss (MSE + Beta*KL)')
plt.legend()
plt.title('Loss vs. epoch')
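To see whether the KL term is collapsing toward zero when BETA = 1 (posterior collapse, where the decoder ignores the latent code), I also log the two loss terms separately after training; a minimal sketch, reusing `model`, `train_loader`, and the same KL expression as above:

    criterion = nn.MSELoss()
    model.eval()
    recon_total, kl_total, n_batches = 0.0, 0.0, 0
    with torch.no_grad():
        for sample in train_loader:
            output, mu, logvar = model(sample)
            recon_total += criterion(output, sample).item()
            kl_total += (-0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())).item()
            n_batches += 1
    # if the KL term is near zero while reconstruction error stays high,
    # the posterior has collapsed to the prior
    print('mean reconstruction term:', recon_total / n_batches)
    print('mean KL term:', kl_total / n_batches)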