# Beta variational autoencoder

Hi All

Has anyone worked with a “Beta-variational autoencoder” (β-VAE)?

I have no idea where to start.
I looked around the web to see whether someone else had done this in PyTorch, but I could not find anything.

I guess the main difference between the β-VAE and the regular VAE would be the loss calculation. Do you have any idea?

Thanks

Here’s an old implementation of mine (PyTorch v1.0 I guess, or maybe 0.4). This is also known as a disentangled variational autoencoder:

``````# Disentangled Variational Autoencoders, or β-VAE
# https://lilianweng.github.io/lil-log/2018/08/12/from-autoencoder-to-beta-vae.html#contractive-autoencoder
# https://openreview.net/forum?id=Sy2fzU9gl
# https://arxiv.org/pdf/1901.09415.pdf
# https://arxiv.org/abs/1606.05579

# The basic idea of a disentangled VAE is that we want the different units of our latent distribution
# to be uncorrelated, so that each of them learns something different about the input data. To implement
# this, the only thing that needs to be added to the vanilla VAE is a β weight on the KL term.
# Previously, for the vanilla VAE, we had:
#     L = E_{q(z|X)}[log p(X|z)] - D_KL[q(z|X) || p(z)]
# Now for the disentangled version (β-VAE) we just add β like this:
#     L = E_{q(z|X)}[log p(X|z)] - β * D_KL[q(z|X) || p(z)]
# To put it simply: in a disentangled VAE (β-VAE) the autoencoder only uses a latent variable if it
# is important.

def fc_batchnorm_act(in_, out_, use_bn=True, act=None):
    """Build a fully connected unit: Linear -> activation -> BatchNorm1d (optional).

    Args:
        in_: number of input features of the ``nn.Linear`` layer.
        out_: number of output features (also the BatchNorm1d width).
        use_bn: if True the last stage is ``nn.BatchNorm1d(out_)``, otherwise
            an ``nn.Identity`` placeholder.
        act: activation module placed right after the linear layer. When
            omitted, a fresh ``nn.ReLU()`` is created for this call.

    Returns:
        ``nn.Sequential(nn.Linear, act, nn.BatchNorm1d | nn.Identity)``.
    """
    # A default of ``act=nn.ReLU()`` is evaluated once, at definition time, so
    # every layer built with the default would share the same module object.
    if act is None:
        act = nn.ReLU()
    linear = nn.Linear(in_, out_)
    norm = nn.BatchNorm1d(out_) if use_bn else nn.Identity()
    return nn.Sequential(linear, act, norm)

class B_VAE(nn.Module):
    """Fully connected β-VAE for 28x28 single-channel images.

    The encoder flattens each image to 784 values and passes it through
    512-256-128-64 hidden units. Two linear heads then produce the mean and the
    log-variance of a diagonal Gaussian over an ``embedding_size``-dimensional
    latent space. The decoder mirrors the encoder back to 784 sigmoid outputs,
    which are reshaped to (-1, 1, 28, 28).

    Args:
        embedding_size: number of latent dimensions.
    """

    def __init__(self, embedding_size=5):
        super().__init__()
        self.embedding_size = embedding_size

        enc_widths = [28 * 28, 512, 256, 128, 64]
        self.encoder_entry = nn.Sequential(
            *[fc_batchnorm_act(w_in, w_out) for w_in, w_out in zip(enc_widths, enc_widths[1:])]
        )
        # Linear heads without a nonlinearity. ``fc_std`` keeps its historical
        # name, but its output is used as a log-variance, not a std.
        self.fc_mu = nn.Linear(64, embedding_size)
        self.fc_std = nn.Linear(64, embedding_size)

        dec_widths = [embedding_size, 64, 128, 256, 512]
        dec_layers = [fc_batchnorm_act(w_in, w_out) for w_in, w_out in zip(dec_widths, dec_widths[1:])]
        # Output layer: no BatchNorm; the sigmoid keeps pixel values in (0, 1).
        dec_layers.append(fc_batchnorm_act(512, 28 * 28, use_bn=False, act=nn.Sigmoid()))
        self.decoder = nn.Sequential(*dec_layers)

    def reparameterization_trick(self, mu, logvar):
        """Sample z = mu + eps * sigma with eps ~ N(0, I).

        exp(0.5 * logvar) is the standard deviation sigma (positive by
        construction). Because the randomness comes only from eps, gradients
        can flow back into ``mu`` and ``logvar``.
        """
        sigma = (0.5 * logvar).exp()
        noise = torch.randn_like(sigma)
        return mu + noise * sigma

    def encode(self, imgs):
        """Encode a batch of images.

        Returns:
            ``(z, mu, logvar)``: a reparameterized latent sample plus the mean
            and log-variance it was drawn from, each with ``embedding_size``
            features per row.
        """
        flat = imgs.view(imgs.size(0), -1)
        features = self.encoder_entry(flat)
        mu, logvar = self.fc_mu(features), self.fc_std(features)
        return self.reparameterization_trick(mu, logvar), mu, logvar

    def decode(self, z):
        """Decode latent vectors into images shaped (-1, 1, 28, 28)."""
        return self.decoder(z).view(-1, 1, 28, 28)

    def forward(self, x):
        """Return ``(reconstruction, mu, logvar)`` for the input images ``x``."""
        z, mu, logvar = self.encode(x)
        return self.decode(z), mu, logvar

def loss_disentagled_vae(outputs, imgs, mu, logvar, Beta, reduction='mean', use_mse=False):
    """Return the β-VAE loss: reconstruction loss + ``Beta`` * KL divergence.

    The KL term is the closed form of KL(N(mu, exp(logvar)) || N(0, I)) from
    Kingma and Welling, "Auto-Encoding Variational Bayes", ICLR 2014
    (https://arxiv.org/abs/1312.6114):
        -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    Weighting that term by ``Beta`` is the only change from a vanilla VAE, and
    it is what drives the disentanglement.

    Args:
        outputs: reconstructions; they must be probabilities in [0, 1] when
            binary cross-entropy is used.
        imgs: targets with the same shape as ``outputs``; dim 0 is the batch.
        mu: latent means, latent dimension last.
        logvar: latent log-variances, same shape as ``mu``.
        Beta: weight of the KL term.
        reduction: ``'mean'`` gives per-sample losses averaged over the batch
            (reconstruction summed over each sample's features). ``'sum'`` sums
            both terms over the whole batch.
        use_mse: use mean squared error instead of binary cross-entropy for
            the reconstruction term. It is honoured for both reductions.

    Returns:
        A scalar tensor.

    Raises:
        ValueError: if ``reduction`` is neither ``'mean'`` nor ``'sum'``.
    """
    if reduction not in ('mean', 'sum'):
        raise ValueError(f"reduction must be 'mean' or 'sum', got {reduction!r}")

    criterion = nn.MSELoss(reduction=reduction) if use_mse else nn.BCELoss(reduction=reduction)
    recons_loss = criterion(outputs, imgs)
    # Element-wise KL contributions of a diagonal Gaussian posterior vs N(0, I).
    kl_terms = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp())

    if reduction == 'mean':
        # The criterion averaged over every element. Multiplying by the number of
        # features per sample turns this into the per-sample sum averaged over
        # the batch. This equals 28*28 for 28x28 single-channel images.
        recons_loss = recons_loss * imgs[0].numel()
        # Sum the KL over latent dims, then average over the batch so the loss
        # is a scalar. A per-sample vector cannot be used with loss.backward().
        kl = kl_terms.sum(dim=-1).mean()
    else:
        kl = kl_terms.sum()
    return recons_loss + Beta * kl

epochs = 50

embeddingsize = 5
interval = 2000
reduction = 'mean'
# Beta is chosen larger than 1: the extra weight on the KL term is what
# distinguishes a β-VAE from a vanilla VAE (Beta = 1).
Beta = 5.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = B_VAE(embeddingsize).to(device)
# NOTE(review): the original snippet used `optimizer` without creating it.
# Adam with lr=1e-3 is an assumed, commonly used choice -- tune as needed.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 50)

for e in range(epochs):
    # BatchNorm layers must use batch statistics while training.
    model.train()
    for i, (imgs, _labels) in enumerate(dataloader_train):
        imgs = imgs.to(device)
        preds, mu, logvar = model(imgs)

        loss = loss_disentagled_vae(preds, imgs, mu, logvar, Beta=Beta, reduction=reduction, use_mse=False)

        # Clear the previous step's gradients; otherwise .backward() adds to them.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if i % interval == 0:
            # With reduction='sum' the loss is summed over the batch; divide by
            # the batch size to report a per-sample value.
            shown_loss = loss / len(imgs) if reduction == 'sum' else loss
            current_lr = [group['lr'] for group in optimizer.param_groups]
            print(f'epoch: {e}/{epochs} step: {i}'
                  f'\tloss: {shown_loss.item():.4f}'
                  f'\tlr: {current_lr}')
    scheduler.step()
#%%
# test
# Evaluation mode: BatchNorm uses its running statistics instead of updating
# them from the test batches.
model.eval()
img_pairs = []
losses = []
interval = 10
# No parameters are updated here, so skip building autograd graphs.
with torch.no_grad():
    for i, (imgs, labels) in enumerate(dataloader_test):
        imgs = imgs.to(device)
        preds, mu, logvar = model(imgs)
        loss = loss_disentagled_vae(preds, imgs, mu, logvar, Beta=Beta, reduction=reduction, use_mse=False)
        losses.append({'val_loss': loss.item()})

        print(f'\tloss: {loss.item():.4f}')

        if i % interval == 0:
            reconstructeds = preds.cpu().detach().view(-1, 1, 28, 28)
            imgs = imgs[:20].cpu().detach().numpy()
            recons = reconstructeds[:20].numpy()
            # np.dstack joins each input and its reconstruction along axis 2,
            # i.e. side by side for (1, 28, 28) images.
            pairs = np.array([np.dstack((img1, img2)) for img1, img2 in zip(imgs, recons)])
            img_pairs.append(pairs)
#%%
# Optional: plot the recorded validation losses with pandas:
# import pandas as pd
# pd.DataFrame(losses).plot()
model.eval()
# Draw three latent vectors from the N(0, I) prior and decode them into images.
z = torch.randn(3, model.embedding_size).to(device)
preds = model.decode(z).detach().cpu()
img = make_grid(preds)
# make_grid returns (C, H, W); imshow wants (H, W, C).
plt.imshow(img.permute(1, 2, 0).numpy())
#%%
# visualize latent space
n = 1
z = torch.randn(size=(n, model.embedding_size)).to(device)
print(z.shape)
#%%
fig = plt.figure()
# The original drew on `ax` without ever creating it (NameError).
ax = fig.add_subplot(1, 1, 1)
preds = model.decode(z).cpu().detach()
img_latent_space = make_grid(preds, nrow=5).numpy().transpose(1, 2, 0)
ax.imshow(img_latent_space)
#%%
def change_latentvariable(z, n=3, count=3, dim=0):
    """Build ``count`` progressively shifted copies of the latent rows in ``z``.

    Copy ``step`` (0-based) equals ``z`` except that column ``dim`` is
    decreased by ``0.2 * step``.

    Args:
        z: 2-D latent tensor whose rows broadcast to ``n`` rows.
        n: number of rows in the result.
        count: number of shifted copies.
        dim: index of the latent dimension to shift.

    Returns:
        A new CPU float tensor of shape ``(n, count, z.size(-1))``.
    """
    grid = torch.zeros(size=(n, count, z.size(-1)))
    for step in range(count):
        grid[:, step, :] = z
        grid[:, step, dim] -= 0.2 * step
    return grid

def show_manifold(z, n, count, dim, device):
    """Plot decoded images while shifting latent dimension ``dim`` of ``z``.

    The shifted latents come from ``change_latentvariable(z, n, count, dim)``
    and are decoded with the module-level ``model``. The result is shown as a
    grid with ``count`` images per row.

    Args:
        z: 2-D latent tensor (rows broadcast to ``n``).
        n: number of latent rows to plot.
        count: number of shifted copies per row.
        dim: latent dimension that is shifted.
        device: device the model lives on.

    Returns:
        The matplotlib ``Figure`` (the original returned nothing).
    """
    fig = plt.figure(figsize=(5, 5))
    # The original called ax.imshow without ever creating `ax` (NameError).
    ax = fig.add_subplot(1, 1, 1)
    latent_space_manifold = change_latentvariable(z, n, count, dim).to(device)
    print(latent_space_manifold.shape)
    with torch.no_grad():
        preds = model.decode(latent_space_manifold.view(-1, model.embedding_size)).cpu()
    img_latent_space_man = make_grid(preds, nrow=count).numpy().transpose(1, 2, 0)
    ax.imshow(img_latent_space_man)
    return fig

# Sweep every latent dimension exactly once. The original hand-written list
# skipped dim 0 and plotted dim 3 twice.
for latent_dim in range(model.embedding_size):
    show_manifold(z, n=n, count=5, dim=latent_dim, device=device)
# visualize the 2d manifold
#%%
# Variations over the latent variables: sweep one latent dimension at a time
# from mu - sigma to mu + sigma (others held at mu) and decode each point.
z_dim = model.embedding_size
sigma_mean = 2.0 * torch.ones((z_dim))
mu_mean = torch.zeros((z_dim))

# Save generated variable images:
nbr_steps = 8
rows = []

with torch.no_grad():
    for latent in range(z_dim):
        low = mu_mean[latent] - sigma_mean[latent]
        high = mu_mean[latent] + sigma_mean[latent]
        print(latent, low, mu_mean[latent], high)
        # Every row starts at mu; only column `latent` is swept.
        var_z0 = mu_mean.repeat(nbr_steps, 1)
        # linspace includes both endpoints. The original step of
        # 2*sigma/nbr_steps stopped one step short of mu + sigma.
        var_z0[:, latent] = torch.linspace(float(low), float(high), nbr_steps)
        rows.append(model.decode(var_z0.to(device)).cpu())

# The original seeded torch.cat with a block of all-ones images, which
# showed up as an extra white row; only decoded images are kept now.
gen_images = torch.cat(rows, dim=0)
# One grid row per latent dimension.
img = make_grid(gen_images, nrow=nbr_steps)
plt.imshow(img.cpu().numpy().transpose(1, 2, 0))
#%%
def plot_latentspace(num_rows, num_cols=9, figure_width=10.5, image_height=1.5, latent=None):
    """Plot a latent-traversal grid with one row per latent dimension.

    Row ``i`` decodes ``num_cols`` copies of a base latent vector in which
    dimension ``i`` is set to evenly spaced values in [-3, 3]. In the column
    whose grid value is closest to the base vector's own value, that actual
    value is used instead, and the cell is framed in green.

    Args:
        num_rows: number of latent dimensions to traverse (figure rows); must
            not exceed the latent size.
        num_cols: number of traversal values per row.
        figure_width: figure width in inches.
        image_height: height in inches allotted to each row.
        latent: base latent vector, either 1-D or 2-D (its first row is used).
            Defaults to the module-level ``z``.

    Returns:
        The matplotlib ``Figure``.
    """
    if latent is None:
        latent = z
    # Private 1-D copy. The original indexed the global `z` as if it were 1-D,
    # although `z` is created with shape (n, embedding_size), and it also
    # wrote into that global.
    base = latent.detach().reshape(-1, latent.shape[-1])[0].clone()
    grid_values = np.linspace(-3.0, 3.0, num_cols)
    fig = plt.figure(figsize=(figure_width, image_height * num_rows))

    for i in range(num_rows):
        z_i = float(base[i])
        j_min = int(np.argmin(np.abs(grid_values - z_i)))
        row_values = [z_i if j == j_min else float(v) for j, v in enumerate(grid_values)]

        # Decode the whole row as one batch; the copies differ only in dim i.
        batch = base.repeat(num_cols, 1)
        batch[:, i] = torch.tensor(row_values, dtype=batch.dtype, device=batch.device)
        with torch.no_grad():
            # NOTE(review): decode runs through BatchNorm layers, so `model`
            # is expected to be in eval mode here.
            images = model.decode(batch).cpu().numpy()

        for j, value in enumerate(row_values):
            ax = fig.add_subplot(num_rows, num_cols, i * num_cols + j + 1)
            # decode() yields (N, 1, H, W); imshow needs a single 2-D image.
            ax.imshow(images[j, 0], cmap='gray')

            if i == 0 or j == j_min:
                ax.set_title(f'{value:.1f}')

            if j == j_min:
                ax.set_xticks([])
                ax.set_yticks([])
                for side in ('top', 'bottom', 'left', 'right'):
                    ax.spines[side].set_color('mediumseagreen')
                    ax.spines[side].set_linewidth(8)
            else:
                ax.axis('off')

    plt.tight_layout()
    return fig
# One figure row per latent dimension of z.
num_rows = z.size(-1)
plot_latentspace(num_rows)
``````

Moafagh bashi!

2 Likes

Salam
Ahvale shoma?
ye soal kochik dashtam.
to ein code ha ye ja zadin embedding_size ein yani chi? chera tarif shode?

Mamnun

Hi,
`embedding_size` is easy to understand once you recall that this is just an autoencoder: it specifies how many features (the size of your latent representation) the autoencoder compresses the input into and reconstructs it from, and it directly affects your end result as well.
If this is not what you wanted to know, please be more specific.

1 Like

Hi,
I got that.
Thank you.
Actually, I changed your code a little and converted it into a convolutional β-VAE.
# Now I have no idea how to plot the latent space; below you can see my (your) code.
class B_VAE(nn.Module):
    """1-D convolutional β-VAE (variant of the fully connected model above).

    Args:
        zdims: size of the latent vectors produced by ``fc_mu`` / ``fc_std``.
    """

    def __init__(self, zdims):
        # The forum's Markdown turned ``__init__`` into ``init`` (bold) and
        # ``super().__init__()`` into ``super().init()``. Without them,
        # nn.Module is never initialised and no layer can be registered.
        super().__init__()
        self.zdims = zdims
        # Encoder: four Conv1d -> MaxPool1d -> Tanh stages, 1 -> 16 channels.
        self.encoder = nn.Sequential(
            nn.Conv1d(1, 4, kernel_size=4, stride=3),
            nn.MaxPool1d(kernel_size=4, stride=3),
            nn.Tanh(),

            nn.Conv1d(4, 8, kernel_size=4, stride=2),
            nn.MaxPool1d(kernel_size=4, stride=2),
            nn.Tanh(),

            nn.Conv1d(8, 12, kernel_size=4, stride=2),
            nn.MaxPool1d(kernel_size=4, stride=1),
            nn.Tanh(),

            nn.Conv1d(12, 16, kernel_size=4, stride=1),
            nn.MaxPool1d(kernel_size=3, stride=1),
            nn.Tanh(),
        )

        # Latent heads on 16 encoder features. ``fc_std`` outputs a
        # log-variance despite its name.
        self.fc_mu = nn.Linear(16, self.zdims)
        self.fc_std = nn.Linear(16, self.zdims)

        # Maps a latent vector back to 16 features for the decoder.
        self.fc_d = nn.Linear(self.zdims, 16)

        # Decoder: four ConvTranspose1d -> Tanh stages, 16 -> 1 channel.
        # Starting from length 2 it produces length 7 -> 23 -> 100 -> 909.
        self.decoder = nn.Sequential(
            nn.ConvTranspose1d(16, 12, kernel_size=5, stride=2),
            nn.Tanh(),

            nn.ConvTranspose1d(12, 8, kernel_size=5, stride=3),
            nn.Tanh(),

            nn.ConvTranspose1d(8, 4, kernel_size=12, stride=4),
            nn.Tanh(),

            nn.ConvTranspose1d(4, 1, kernel_size=18, stride=9),
            nn.Tanh(),
        )

    def parameterization_trick(self, mu, logvar):
        """Sample mu + eps * std with eps ~ N(0, I); std = exp(0.5 * logvar)."""
        std = torch.exp(logvar * 0.5)
        # sample epsilon from N(0, 1)
        eps = torch.randn_like(std)
        # shift by the mean and scale by the standard deviation
        return mu + eps * std

    def encode(self, imgs):
        """Return ``(mu, logvar, z)`` for a batch of 1-channel signals."""
        output = self.encoder(imgs)
        # NOTE(review): for an input of length 909 (the length the decoder
        # produces) the encoder output is (batch, 16, 2), so view(-1, 16)
        # yields 2 * batch rows. That is two latent vectors per input, each
        # built from 8 of the 16 channels; decode() regroups them with
        # view(-1, 16, 2). If one latent vector per input is intended (e.g. to
        # plot the latent space), use output.flatten(1) with 32-feature
        # fc_mu/fc_std and a 32-unit fc_d -- TODO confirm the input length.
        output = output.view(-1, 16)
        mu = self.fc_mu(output)
        logvar = self.fc_std(output)
        z = self.parameterization_trick(mu, logvar)
        return mu, logvar, z

    def decode(self, z):
        """Map latent rows back to signals via fc_d and the transposed convs."""
        deconv_input = self.fc_d(z)
        deconv_input = deconv_input.view(-1, 16, 2)
        reconstructed_img = self.decoder(deconv_input)
        return reconstructed_img

    def forward(self, x):
        """Return ``(reconstruction, mu, logvar)``."""
        mu, logvar, z = self.encode(x)
        reconstructed_img = self.decode(z)
        return reconstructed_img, mu, logvar

For visualization, you need to compress your representation to a low enough dimension that you can plot it.
Try out the basic β-VAE you have and understand how the code works; then you can easily alter it for your own case.

1 Like