Hi all,
Has anyone worked with the “beta-variational autoencoder” (β-VAE)?
What is your problem?
I have no idea where to start.
I looked through the web to see whether someone else had done this in PyTorch, but I could not find anything.
I guess the main difference between the beta version and the regular one would be the loss calculation. Do you have any idea?
Thanks
Here’s an old implementation of mine (PyTorch v1.0 I guess, or maybe 0.4). This is also known as a disentangled variational autoencoder:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torchvision.utils import make_grid

# Disentangled Variational Autoencoders (β-VAE)
# good reads: https://towardsdatascience.com/disentanglement-with-variational-autoencoder-a-review-653a891b69bd
# https://lilianweng.github.io/lil-log/2018/08/12/from-autoencoder-to-beta-vae.html#contractive-autoencoder
# https://openreview.net/forum?id=Sy2fzU9gl
# https://arxiv.org/pdf/1901.09415.pdf
# https://arxiv.org/abs/1606.05579
# The basic idea in a disentangled VAE is that we want the different neurons in our latent
# distribution to be uncorrelated: they each try to learn something different about the
# input data. To implement this, the only thing that needs to be added to the vanilla VAE
# is a β term.
# Previously, for the vanilla VAE we had:
# L = E_q(z|X)[log p(X|z)] - D_KL[q(z|X) || p(z)]
# For the disentangled version (β-VAE) we just add the β like this:
# L = E_q(z|X)[log p(X|z)] - β * D_KL[q(z|X) || p(z)]
# To put it simply, in a disentangled VAE (β-VAE) the autoencoder will only use a latent
# variable if it is important.
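# To make the effect of β concrete, here is a tiny illustrative sketch (the latent
# statistics below are made up, not from a trained model): only dimensions whose
# mu/logvar deviate from the N(0,1) prior pay a KL cost, and β amplifies that cost,
# which is what pushes uninformative dimensions to collapse onto the prior.
demo_mu = torch.tensor([[1.5, -0.8, 0.0, 0.0, 0.0]])       # dims 0-1 carry information
demo_logvar = torch.tensor([[-1.0, -0.5, 0.0, 0.0, 0.0]])  # dims 2-4 match the prior
demo_kl = -0.5 * (1 + demo_logvar - demo_mu.pow(2) - demo_logvar.exp())
print(demo_kl)              # only the informative dims have a nonzero penalty
print(5.0 * demo_kl.sum())  # with β = 5 that penalty is amplified five-fold
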
def fc_batchnorm_act(in_, out_, use_bn=True, act=nn.ReLU()):
    return nn.Sequential(nn.Linear(in_, out_),
                         act,
                         nn.BatchNorm1d(out_) if use_bn else nn.Identity())
class B_VAE(nn.Module):
    def __init__(self, embedding_size=5):
        super().__init__()
        self.embedding_size = embedding_size
        # fully connected encoder: 784 -> 64, followed by the mu/logvar heads
        self.encoder_entry = nn.Sequential(fc_batchnorm_act(28*28, 512),
                                           fc_batchnorm_act(512, 256),
                                           fc_batchnorm_act(256, 128),
                                           fc_batchnorm_act(128, 64))
        self.fc_mu = nn.Linear(64, embedding_size)
        self.fc_std = nn.Linear(64, embedding_size)
        # mirror of the encoder: embedding_size -> 784, with a sigmoid output
        self.decoder = nn.Sequential(fc_batchnorm_act(embedding_size, 64),
                                     fc_batchnorm_act(64, 128),
                                     fc_batchnorm_act(128, 256),
                                     fc_batchnorm_act(256, 512),
                                     fc_batchnorm_act(512, 28*28, False, nn.Sigmoid()))
    def reparameterization_trick(self, mu, logvar):
        # multiply by 0.5 since we want the standard deviation, not the variance
        std = torch.exp(logvar * 0.5)
        # sample epsilon from N(0,1)
        eps = torch.randn_like(std)
        # sampling can now be done by shifting eps by (adding) the mean
        # and scaling it by the standard deviation
        return mu + eps * std

    def encode(self, imgs):
        imgs = imgs.view(imgs.size(0), -1)
        output = self.encoder_entry(imgs)
        # remember, we don't use nonlinearities for mu and logvar!
        mu = self.fc_mu(output)
        logvar = self.fc_std(output)
        z = self.reparameterization_trick(mu, logvar)
        return z, mu, logvar

    def decode(self, z):
        reconstructed_imgs = self.decoder(z)
        reconstructed_imgs = reconstructed_imgs.view(-1, 1, 28, 28)
        return reconstructed_imgs

    def forward(self, x):
        # encoder
        z, mu, logvar = self.encode(x)
        # decoder
        reconstructed_imgs = self.decode(z)
        return reconstructed_imgs, mu, logvar
def loss_disentangled_vae(outputs, imgs, mu, logvar, Beta, reduction='mean', use_mse=False):
    # this loss has two parts: a reconstruction loss, and a KL-divergence loss which
    # measures how much distance exists between two given distributions.
    if reduction == 'mean':
        if use_mse:
            criterion = nn.MSELoss()
        else:
            criterion = nn.BCELoss(reduction='mean')
        recons_loss = criterion(outputs, imgs)
        # rescale the reconstruction loss from a per-pixel mean to a per-image value
        recons_loss *= 28*28
        # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
        # https://arxiv.org/abs/1312.6114
        # -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
        # when using mean reduction, we sum over the last dim to get a per-sample KL
        kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), -1)
        # we multiply the KL term by beta. this is specific to the disentangled VAE
        # and is actually the main reason why the disentanglement works
        return torch.mean(recons_loss + (Beta * kl))
    else:
        criterion = nn.MSELoss(reduction='sum') if use_mse else nn.BCELoss(reduction='sum')
        recons_loss = criterion(outputs, imgs)
        kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        return recons_loss + (Beta * kl)
epochs = 50
embedding_size = 5
interval = 2000
reduction = 'mean'
# beta is a value bigger than 1
Beta = 5.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = B_VAE(embedding_size).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 50)
# dataloader_train / dataloader_test are assumed to be MNIST-style loaders
# yielding (imgs, labels) batches of 1x28x28 images, defined elsewhere
for e in range(epochs):
    for i, (imgs, labels) in enumerate(dataloader_train):
        imgs = imgs.to(device)
        preds, mu, logvar = model(imgs)
        loss = loss_disentangled_vae(preds, imgs, mu, logvar, Beta=Beta, reduction=reduction, use_mse=False)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if i % interval == 0:
            loss = loss / len(imgs) if reduction == 'sum' else loss
            print(f'epoch {e}/{epochs} [{i*len(imgs)}/{len(dataloader_train.dataset)} ({100.*i/len(dataloader_train):.2f}%)]'
                  f'\tloss: {loss.item():.4f}'
                  f'\tlr: {scheduler.get_lr()}')
    scheduler.step()
#%%
# test
test_set_size = len(dataloader_test.dataset)
img_pairs = []
losses = []
interval = 10
with torch.no_grad():
    for i, (imgs, labels) in enumerate(dataloader_test):
        imgs = imgs.to(device)
        preds, mu, logvar = model(imgs)
        loss = loss_disentangled_vae(preds, imgs, mu, logvar, Beta=Beta, reduction=reduction, use_mse=False)
        losses.append({'val_loss': loss.item()})
        print(f'[{i*len(imgs)} / {test_set_size} ({100.*i/len(dataloader_test):.2f}%)]'
              f'\tloss: {loss.item():.4f}')
        if i % interval == 0:
            reconstructeds = preds.cpu().detach().view(-1, 1, 28, 28)
            imgs = imgs[:20].cpu().detach().numpy()
            recons = reconstructeds[:20].numpy()
            # stack each input next to its reconstruction for side-by-side viewing
            pairs = np.array([np.dstack((img1, img2)) for img1, img2 in zip(imgs, recons)])
            img_pairs.append(pairs)
#%%
# import pandas as pd
# pd.DataFrame(losses).plot()
model.eval()
# create sample images from random latent vectors
z = torch.randn(size=(3, model.embedding_size)).to(device)
preds = model.decode(z).cpu().detach()
img = make_grid(preds)
plt.imshow(img.numpy().transpose(1, 2, 0))
#%%
# visualize latent space
n = 1
z = torch.randn(size=(n, model.embedding_size)).to(device)
print(z.shape)
#%%
fig = plt.figure()
ax = fig.add_subplot(111)
preds = model.decode(z).cpu().detach()
img_latent_space = make_grid(preds, nrow=5).numpy().transpose(1, 2, 0)
ax.imshow(img_latent_space)
#%%
def change_latentvariable(z, n=3, count=3, dim=0):
    # replicate z `count` times and sweep the chosen latent dim in steps of 0.2
    z_new = torch.zeros(size=(n, count, z.size(-1)), device=z.device)
    for i in range(count):
        z_new[:, i, :] = z[:, :]
        z_new[:, i, dim] = z_new[:, i, dim] - (0.2 * i)
    return z_new

def show_manifold(z, n, count, dim, device):
    fig = plt.figure(figsize=(5, 5))
    ax = fig.add_subplot(111)
    latent_space_manifold = change_latentvariable(z, n, count, dim).to(device)
    print(latent_space_manifold.shape)
    preds = model.decode(latent_space_manifold.view(-1, model.embedding_size)).cpu().detach()
    img_latent_space_man = make_grid(preds, nrow=count).numpy().transpose(1, 2, 0)
    ax.imshow(img_latent_space_man)

# sweep each of the 5 latent dimensions in turn
show_manifold(z, n=n, count=5, dim=0, device=device)
show_manifold(z, n=n, count=5, dim=1, device=device)
show_manifold(z, n=n, count=5, dim=2, device=device)
show_manifold(z, n=n, count=5, dim=3, device=device)
show_manifold(z, n=n, count=5, dim=4, device=device)
# visualize the 2d manifold
#%%
# variations over the latent variable:
z_dim = model.embedding_size
sigma_mean = 2.0 * torch.ones((z_dim))
mu_mean = torch.zeros((z_dim))
# Save generated variable images:
nbr_steps = 8
gen_images = torch.ones((nbr_steps, 1, 28, 28))
for latent in range(z_dim):
    # var_z0 = torch.stack([mu_mean]*nbr_steps, dim=0)
    var_z0 = torch.zeros(nbr_steps, z_dim)
    val = mu_mean[latent] - sigma_mean[latent]
    step = 2.0 * sigma_mean[latent] / nbr_steps
    print(latent, mu_mean[latent] - sigma_mean[latent], mu_mean[latent], mu_mean[latent] + sigma_mean[latent])
    for i in range(nbr_steps):
        var_z0[i] = mu_mean
        var_z0[i][latent] = val
        val += step
    var_z0 = var_z0.to(device)
    gen_images_latent = model.decode(var_z0)
    gen_images_latent = gen_images_latent.cpu().detach()
    gen_images = torch.cat([gen_images, gen_images_latent], dim=0)
img = make_grid(gen_images)
plt.imshow(img.cpu().numpy().transpose(1, 2, 0))
#%%
def plot_latentspace(num_rows, num_cols=9, figure_width=10.5, image_height=1.5):
    fig = plt.figure(figsize=(figure_width, image_height * num_rows))
    for i in range(num_rows):
        z_i_values = np.linspace(-3.0, 3.0, num_cols)
        z_i = z[0][i].detach().cpu().numpy()
        z_diffs = np.abs(z_i_values - z_i)
        j_min = np.argmin(z_diffs)
        for j in range(num_cols):
            z_i_value = z_i_values[j]
            if j != j_min:
                z[0][i] = z_i_value
            else:
                z[0][i] = float(z_i)
            x = model.decode(z).detach().cpu().numpy()
            ax = fig.add_subplot(num_rows, num_cols, i * num_cols + j + 1)
            ax.imshow(x[0][0], cmap='gray')
            if i == 0 or j == j_min:
                ax.set_title(f'{z[0][i]:.1f}')
            if j == j_min:
                # highlight the cell closest to the original latent value
                ax.set_xticks([])
                ax.set_yticks([])
                color = 'mediumseagreen'
                width = 8
                for side in ['top', 'bottom', 'left', 'right']:
                    ax.spines[side].set_color(color)
                    ax.spines[side].set_linewidth(width)
            else:
                ax.axis('off')
        z[0][i] = float(z_i)
    plt.tight_layout()
    fig.subplots_adjust(wspace=0.04)

num_rows = z.shape[-1]
plot_latentspace(num_rows)
Good luck!
Hello,
How are you?
I had a small question.
Somewhere in this code you wrote embedding_size. What does it mean? Why is it defined?
Thanks
Hi,
embedding_size is kind of obvious when you recall that this is just an autoencoder: it specifies how many features you want the autoencoder to compress your representation into (and recover from), and it directly affects your end result as well.
If this is not what you wanted to know, please be more specific.
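For example, here is a quick sketch (using the B_VAE class from above, with a made-up dummy batch) showing that embedding_size is just the width of the latent bottleneck:

model = B_VAE(embedding_size=5)
model.eval()  # put BatchNorm into eval mode for a quick shape test
x = torch.rand(8, 1, 28, 28)            # a dummy batch of 8 "images"
recon, mu, logvar = model(x)
print(mu.shape)     # torch.Size([8, 5])  -> each image is compressed to 5 numbers
print(recon.shape)  # torch.Size([8, 1, 28, 28]) -> and recovered back to 28x28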
Hi,
I got that.
Thank you.
Actually, I changed your code a little and converted it to a convolutional B-VAE.
Now I have no idea how to plot the latent space.
Below you can see my (your) code:
class B_VAE(nn.Module):
    def __init__(self, zdims):
        super().__init__()
        self.zdims = zdims
        # Encoder layers
        self.encoder = nn.Sequential(nn.Conv1d(1, 4, kernel_size=4, stride=3),
                                     nn.MaxPool1d(kernel_size=4, stride=3),
                                     nn.Tanh(),
                                     nn.Conv1d(4, 8, kernel_size=4, stride=2),
                                     nn.MaxPool1d(kernel_size=4, stride=2),
                                     nn.Tanh(),
                                     nn.Conv1d(8, 12, kernel_size=4, stride=2),
                                     nn.MaxPool1d(kernel_size=4, stride=1),
                                     nn.Tanh(),
                                     nn.Conv1d(12, 16, kernel_size=4, stride=1),
                                     nn.MaxPool1d(kernel_size=3, stride=1),
                                     nn.Tanh())
        # Conv FC layers
        self.fc_mu = nn.Linear(16, self.zdims)
        self.fc_std = nn.Linear(16, self.zdims)
        # Deconv FC layer: 32 = 16 channels x length 2, to match the view in decode()
        self.fc_d = nn.Linear(self.zdims, 32)
        # Decoder layers
        self.decoder = nn.Sequential(nn.ConvTranspose1d(16, 12, kernel_size=5, stride=2),
                                     nn.Tanh(),
                                     nn.ConvTranspose1d(12, 8, kernel_size=5, stride=3),
                                     nn.Tanh(),
                                     nn.ConvTranspose1d(8, 4, kernel_size=12, stride=4),
                                     nn.Tanh(),
                                     nn.ConvTranspose1d(4, 1, kernel_size=18, stride=9),
                                     nn.Tanh())

    def parameterization_trick(self, mu, logvar):
        std = torch.exp(logvar * 0.5)
        # sample epsilon from N(0,1)
        eps = torch.randn_like(std)
        # sampling can now be done by shifting eps by (adding) the mean
        # and scaling it by the standard deviation
        return mu + eps * std

    def encode(self, imgs):
        output = self.encoder(imgs)
        output = output.view(-1, 16)
        mu = self.fc_mu(output)
        logvar = self.fc_std(output)
        z = self.parameterization_trick(mu, logvar)
        return mu, logvar, z

    def decode(self, z):
        deconv_input = self.fc_d(z)
        deconv_input = deconv_input.view(-1, 16, 2)
        reconstructed_img = self.decoder(deconv_input)
        return reconstructed_img

    def forward(self, x):
        mu, logvar, z = self.encode(x)
        reconstructed_img = self.decode(z)
        return reconstructed_img, mu, logvar
For visualization you need to compress your representation to a lower dimension that you can plot.
Try printing the basic β-VAE that you have and understanding how the code works; then you can easily alter it for your own case.
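For instance, here is a minimal sketch (assuming a dataloader_test that yields (inputs, labels) batches matching your Conv1d input shape, and zdims >= 2) that collects the latent means and scatter-plots two of the dimensions:

model.eval()
zs, ys = [], []
with torch.no_grad():
    for inputs, labels in dataloader_test:
        mu, logvar, z = model.encode(inputs.to(device))
        zs.append(mu.cpu())
        ys.append(labels)
zs = torch.cat(zs).numpy()
ys = torch.cat(ys).numpy()
# if zdims > 2 you could first reduce with PCA or t-SNE; here we just take dims 0 and 1
plt.scatter(zs[:, 0], zs[:, 1], c=ys, cmap='tab10', s=5)
plt.xlabel('z[0]')
plt.ylabel('z[1]')
plt.colorbar(label='class')
plt.show()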