 # Beta variational autoencoder

Hi All

Has anyone worked with a “beta-variational autoencoder” (β-VAE)?

I have no idea where to start.

I looked around the web to see whether someone else had done this in PyTorch, but could not find anything.

I guess the main difference between the beta version and the regular one would be the loss calculation. Do you have any idea?

Thanks

Here’s an old implementation of mine (PyTorch v1.0 I guess, or maybe 0.4). This is also known as a disentangled variational autoencoder:

``````# Disentangled Variational Autoencoders (β-VAE)
# https://lilianweng.github.io/lil-log/2018/08/12/from-autoencoder-to-beta-vae.html#contractive-autoencoder
# https://openreview.net/forum?id=Sy2fzU9gl
# https://arxiv.org/pdf/1901.09415.pdf
# https://arxiv.org/abs/1606.05579

# The basic idea in a disentangled VAE is that we want the different neurons in our latent distribution
# to be uncorrelated: they all try to learn something different about the input data. In order to
# implement this, the only thing that needs to be added to the vanilla VAE is a β term.
# Previously, for the vanilla VAE we had:
#     L = E_q(z|X)[log_p(X|z)] - D_KL[q(z|X)||p(z)]
# Now, for the disentangled version (β-VAE), we just add the β like this:
#     L = E_q(z|X)[log_p(X|z)] - β·D_KL[q(z|X)||p(z)]
# So, to put it simply, in a disentangled VAE (β-VAE) the autoencoder will only use a latent
# variable if it is important.

def fc_batchnorm_act(in_, out_, use_bn=True, act=nn.ReLU()):
    """Build a Linear -> activation -> (BatchNorm1d | Identity) block.

    When `use_bn` is False an `nn.Identity` stands in for the norm layer so
    the returned Sequential always has the same three-module layout.
    """
    layers = [nn.Linear(in_, out_), act]
    layers.append(nn.BatchNorm1d(out_) if use_bn else nn.Identity())
    return nn.Sequential(*layers)

class B_VAE(nn.Module):
    """β-VAE over flattened 28x28 images with fully-connected encoder/decoder."""

    def __init__(self, embedding_size=5):
        super().__init__()
        self.embedding_size = embedding_size

        # encoder trunk: 784 -> 512 -> 256 -> 128 -> 64
        self.encoder_entry = nn.Sequential(
            fc_batchnorm_act(28 * 28, 512),
            fc_batchnorm_act(512, 256),
            fc_batchnorm_act(256, 128),
            fc_batchnorm_act(128, 64),
        )
        # no nonlinearities on the mu / logvar heads
        self.fc_mu = nn.Linear(64, embedding_size)
        self.fc_std = nn.Linear(64, embedding_size)

        # decoder mirrors the encoder; the final Sigmoid keeps pixels in (0, 1)
        self.decoder = nn.Sequential(
            fc_batchnorm_act(embedding_size, 64),
            fc_batchnorm_act(64, 128),
            fc_batchnorm_act(128, 256),
            fc_batchnorm_act(256, 512),
            fc_batchnorm_act(512, 28 * 28, False, nn.Sigmoid()),
        )

    def reparameterization_trick(self, mu, logvar):
        """Sample z ~ N(mu, sigma^2) while keeping the graph differentiable."""
        # multiply logvar by 0.5 before exp so we get the (positive) std dev
        std = torch.exp(logvar * 0.5)
        # epsilon drawn from N(0, 1)
        eps = torch.randn_like(std)
        # shift the noise by the mean and scale it by the deviation
        return mu + eps * std

    def encode(self, imgs):
        """Map a batch of images to (z, mu, logvar)."""
        flat = imgs.view(imgs.size(0), -1)
        features = self.encoder_entry(flat)
        mu = self.fc_mu(features)
        logvar = self.fc_std(features)
        z = self.reparameterization_trick(mu, logvar)
        return z, mu, logvar

    def decode(self, z):
        """Map latent codes back to (N, 1, 28, 28) images."""
        reconstructed = self.decoder(z)
        return reconstructed.view(-1, 1, 28, 28)

    def forward(self, x):
        z, mu, logvar = self.encode(x)
        return self.decode(z), mu, logvar

def loss_disentagled_vae(outputs, imgs, mu, logvar, Beta, reduction='mean', use_mse=False):
    """β-VAE loss: reconstruction term + Beta-weighted KL divergence.

    The loss has two parts: a reconstruction loss and a KL-divergence term that
    measures the distance between the approximate posterior q(z|X) and the
    prior p(z). Multiplying the KL term by Beta (> 1) is the only change
    relative to a vanilla VAE and is what drives disentanglement.

    Args:
        outputs: reconstructed images, values in (0, 1) (required by BCE).
        imgs: target images, same shape as `outputs`.
        mu, logvar: (batch, latent_dim) encoder outputs.
        Beta: KL weight; Beta=1 recovers the vanilla VAE loss.
        reduction: 'mean' (per-element mean, rescaled) or anything else for 'sum'.
        use_mse: use MSE instead of BCE for the reconstruction term.

    Returns:
        A scalar tensor suitable for `.backward()`.
    """
    if reduction == 'mean':
        criterion = nn.MSELoss() if use_mse else nn.BCELoss(reduction='mean')
        recons_loss = criterion(outputs, imgs)
        # undo the per-pixel averaging so the two terms have comparable scale
        recons_loss *= 28 * 28
        # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
        # https://arxiv.org/abs/1312.6114
        # -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2), summed over latent dims
        kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), -1)
        # BUG FIX: `kl` above is a per-sample vector; average over the batch so
        # the returned loss is a scalar (otherwise .backward()/.item() fail).
        kl = kl.mean()
    else:
        # BUG FIX: `use_mse` was silently ignored in the 'sum' branch.
        criterion = nn.MSELoss(reduction='sum') if use_mse else nn.BCELoss(reduction='sum')
        recons_loss = criterion(outputs, imgs)
        kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return recons_loss + (Beta * kl)

# ---- training configuration ----
epochs = 50
embeddingsize = 5
interval = 2000          # report the loss every `interval` batches
reduction = 'mean'
# beta is a value bigger than 1
Beta = 5.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = B_VAE(embeddingsize).to(device)
# BUG FIX: the optimizer was referenced but never created.
# NOTE(review): Adam with default lr is an assumption — confirm the original choice.
optimizer = torch.optim.Adam(model.parameters())
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 50)

for e in range(epochs):
    for i, (imgs, labels) in enumerate(dataloader_train):
        imgs = imgs.to(device)
        preds, mu, logvar = model(imgs)

        loss = loss_disentagled_vae(preds, imgs, mu, logvar,
                                    Beta=Beta, reduction=reduction, use_mse=False)

        # BUG FIX: gradients must be cleared each step, otherwise they accumulate
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        if i % interval == 0:
            # BUG FIX: `len(img)` referenced an undefined name; use the batch size
            report = loss / len(imgs) if reduction == 'sum' else loss
            # BUG FIX: restored the lost `print(`; get_last_lr() replaces the
            # deprecated get_lr()
            print(f'\tloss: {report.item():.4f}'
                  f'\tlr: {scheduler.get_last_lr()}')
    # step the epoch-based LR schedule once per epoch
    scheduler.step()
#%%
# test: run the validation set through the model, record losses, and keep a few
# (input, reconstruction) image pairs for visual inspection
img_pairs = []
losses = []
interval = 10
# NOTE(review): consider model.eval() + torch.no_grad() here — the original
# evaluates in train mode and builds autograd graphs it never uses.
for i, (imgs, labels) in enumerate(dataloader_test):
    imgs = imgs.to(device)
    preds, mu, logvar = model(imgs)
    loss = loss_disentagled_vae(preds, imgs, mu, logvar,
                                Beta=Beta, reduction=reduction, use_mse=False)
    losses.append({'val_loss': loss.item()})

    # BUG FIX: restored the lost `print(` (the line was a bare f-string)
    print(f'\tloss: {loss.item():.4f}')

    if i % interval == 0:
        reconstructeds = preds.cpu().detach().view(-1, 1, 28, 28)
        imgs = imgs[:20].cpu().detach().numpy()
        recons = reconstructeds[:20].numpy()
        # stack each input next to its reconstruction for side-by-side viewing
        pairs = np.array([np.dstack((img1, img2)) for img1, img2 in zip(imgs, recons)])
        img_pairs.append(pairs)
#%%
# import pandas as pd
# pd.DataFrame(losses).plot()
model.eval()
# create sample images: decode three random latent codes drawn from the prior
z = torch.randn(size=(3, model.embedding_size)).to(device)
decoded = model.decode(z).cpu().detach()
img = make_grid(decoded)
plt.imshow(img.numpy().transpose(1, 2, 0))
#%%
# visualize latent space: decode a random latent batch and show it as a grid
n = 1
z = torch.randn(size=(n, model.embedding_size)).to(device)
print(z.shape)
#%%
fig = plt.figure()
# BUG FIX: `ax` was used below but never created — add an axes to the figure
ax = fig.add_subplot(111)
preds = model.decode(z).cpu().detach()
img_latent_space = make_grid(preds, nrow=5).numpy().transpose(1, 2, 0)
ax.imshow(img_latent_space)
#%%
def change_latentvariable(z, n=3, count=3, dim=0):
    """Tile `z` `count` times and sweep one latent component.

    Returns a (n, count, latent_dim) tensor in which copy i equals `z` except
    that component `dim` is shifted by -0.2 * i, producing a small sweep along
    a single latent dimension.
    """
    latent_dim = z.size(-1)
    sweep = torch.zeros(size=(n, count, latent_dim))
    for step in range(count):
        sweep[:, step, :] = z
        sweep[:, step, dim] -= 0.2 * step
    return sweep

def show_manifold(z, n, count, dim, device):
    """Decode a sweep over latent dimension `dim` and show the images in a grid.

    Uses the module-level `model`; `z` is the base latent code, `count` the
    number of sweep steps, and `dim` the latent component being varied.
    """
    fig = plt.figure(figsize=(5, 5))
    # BUG FIX: `ax` was used below but never created — add an axes to the figure
    ax = fig.add_subplot(111)
    latent_space_manifold = change_latentvariable(z, n, count, dim).to(device)
    print(latent_space_manifold.shape)
    preds = model.decode(latent_space_manifold.view(-1, model.embedding_size)).cpu().detach()
    img_latent_space_man = make_grid(preds, nrow=count).numpy().transpose(1, 2, 0)
    ax.imshow(img_latent_space_man)

# Sweep several latent dimensions one at a time to inspect disentanglement.
# NOTE(review): dim=3 is swept twice (first and fourth call) while dim=0 is
# never shown — possibly a copy-paste slip; confirm the intended dimensions.
show_manifold(z, n=n, count=5, dim=3, device=device)
show_manifold(z, n=n,  count=5, dim=1, device=device)
show_manifold(z, n=n,  count=5, dim=2, device=device)
show_manifold(z, n=n,  count=5, dim=3, device=device)
show_manifold(z, n=n,  count=5, dim=4, device=device)
# visualize the 2d manifold
#%%
# variations over the latent variable: for each latent dimension, decode a
# series of codes walking that dimension from mean - sigma to mean + sigma
z_dim = model.embedding_size
sigma_mean = 2.0 * torch.ones((z_dim))
mu_mean = torch.zeros((z_dim))

# Save generated variable images (starts with nbr_steps blank placeholder rows)
nbr_steps = 8
gen_images = torch.ones((nbr_steps, 1, 28, 28))

for latent in range(z_dim):
    # var_z0 = torch.stack([mu_mean] * nbr_steps, dim=0)
    var_z0 = torch.zeros(nbr_steps, z_dim)
    val = mu_mean[latent] - sigma_mean[latent]
    step = 2.0 * sigma_mean[latent] / nbr_steps
    print(latent, mu_mean[latent] - sigma_mean[latent],
          mu_mean[latent], mu_mean[latent] + sigma_mean[latent])
    # fill each row with the mean code, then overwrite the swept component
    for i in range(nbr_steps):
        var_z0[i] = mu_mean
        var_z0[i][latent] = val
        val += step

    var_z0 = var_z0.to(device)

    gen_images_latent = model.decode(var_z0)
    gen_images_latent = gen_images_latent.cpu().detach()
    gen_images = torch.cat([gen_images, gen_images_latent], dim=0)

img = make_grid(gen_images)
plt.imshow(img.cpu().numpy().transpose(1, 2, 0))
#%%
def plot_latentspace(num_rows, num_cols=9, figure_width=10.5, image_height=1.5):
    """Plot a grid sweeping latent components over [-3, 3].

    Each row i sweeps z[i] across `num_cols` evenly spaced values; the column
    whose value is closest to the original z[i] keeps the untouched code and is
    highlighted with a green frame. z[i] is restored after each row.

    NOTE(review): relies on the module-level `z` and `model`; assumes `z[i]`
    addresses a single latent component (z effectively 1-D) — confirm, since
    elsewhere in the file z is (n, embedding_size).
    """
    fig = plt.figure(figsize=(figure_width, image_height * num_rows))

    for i in range(num_rows):
        z_i_values = np.linspace(-3.0, 3.0, num_cols)
        z_i = z[i].detach().cpu().numpy()   # original value, restored below
        z_diffs = np.abs((z_i_values - z_i))
        j_min = np.argmin(z_diffs)          # column closest to the real value
        for j in range(num_cols):
            z_i_value = z_i_values[j]
            if j != j_min:
                z[i] = z_i_value
            else:
                z[i] = float(z_i)

            # BUG FIX: decode() returns a 4-D (N, 1, 28, 28) tensor, which
            # imshow cannot display — squeeze down to a 2-D image
            x = model.decode(z).detach().cpu().numpy().squeeze()

            ax = fig.add_subplot(num_rows, num_cols, i * num_cols + j + 1)
            ax.imshow(x, cmap='gray')

            if i == 0 or j == j_min:
                ax.set_title(f'{z[i]:.1f}')

            if j == j_min:
                # frame the column holding the untouched latent value
                ax.set_xticks([], [])
                ax.set_yticks([], [])
                color = 'mediumseagreen'
                width = 8
                for side in ['top', 'bottom', 'left', 'right']:
                    ax.spines[side].set_color(color)
                    ax.spines[side].set_linewidth(width)
            else:
                ax.axis('off')
        z[i] = float(z_i)   # restore this row's original latent value

    plt.tight_layout()