Beta variational autoencoder

Hi All

has anyone worked with “Beta-variational autoencoder”?

what is your problem?

I have no idea where to start.
I looked through the web to see whether someone else had done this in PyTorch; however, I could not find anything.

I guess the main difference between the Beta VAE and the regular one would be the loss calculation. Do you have any idea?

Thanks

Here’s an old implementation of mine (PyTorch v1.0, I guess, or maybe 0.4). This is also known as a disentangled variational autoencoder:

# Disentangled Variational Autoencoders (β-VAE)
	# good reads : https://towardsdatascience.com/disentanglement-with-variational-autoencoder-a-review-653a891b69bd 
	# https://lilianweng.github.io/lil-log/2018/08/12/from-autoencoder-to-beta-vae.html#contractive-autoencoder
	# https://openreview.net/forum?id=Sy2fzU9gl
	# https://arxiv.org/pdf/1901.09415.pdf
	# https://arxiv.org/abs/1606.05579
	
	# The basic idea in a disentangled VAE is that we want the different neurons in our latent distribution
	# to be uncorrelated: each should learn something different about the input data. To implement
	# this, the only thing that needs to be added to the vanilla VAE is a β term.
	# Previously, for the vanilla VAE, we had:
	#     L = E_q(z|X)[log p(X|z)] - D_KL[q(z|X) || p(z)]
	# Now, for the disentangled version (β-VAE), we just add β like this:
	#     L = E_q(z|X)[log p(X|z)] - β D_KL[q(z|X) || p(z)]
	# So, to put it simply, in a disentangled VAE (β-VAE) the autoencoder will only use a latent
	# variable if it is important.
	
	def fc_batchnorm_act(in_, out_, use_bn=True, act=nn.ReLU()):
	    """Build one fully connected stage: linear layer, activation, then normalisation.

	    Args:
	        in_: number of input features of the linear layer.
	        out_: number of output features of the linear layer.
	        use_bn: when true the stage ends with ``nn.BatchNorm1d(out_)``,
	            otherwise with ``nn.Identity()`` so the stage always has
	            three sub-modules.
	        act: activation module placed right after the linear layer.

	    Returns:
	        An ``nn.Sequential`` of ``(Linear, act, BatchNorm1d | Identity)``.
	    """
	    linear = nn.Linear(in_, out_)
	    if use_bn:
	        tail = nn.BatchNorm1d(out_)
	    else:
	        tail = nn.Identity()
	    return nn.Sequential(linear, act, tail)
	                         
	class B_VAE(nn.Module):
	    """Fully connected β-VAE for 28x28 single-channel images.

	    The encoder flattens every image to 784 values and narrows it down to a
	    64-wide hidden vector; two linear heads then produce the mean and the
	    log-variance of the latent Gaussian. The decoder widens a latent code
	    back to 784 values, ends with a sigmoid, and reshapes to (N, 1, 28, 28).
	    """

	    def __init__(self, embedding_size=5):
	        super().__init__()
	        self.embedding_size = embedding_size

	        # Encoder trunk: 784 -> 512 -> 256 -> 128 -> 64.
	        enc_widths = (28*28, 512, 256, 128, 64)
	        self.encoder_entry = nn.Sequential(
	            *[fc_batchnorm_act(w_in, w_out)
	              for w_in, w_out in zip(enc_widths[:-1], enc_widths[1:])])
	        # Two heads on top of the trunk. Despite its name, fc_std produces
	        # the log-variance that encode() hands to the sampling step.
	        self.fc_mu = nn.Linear(enc_widths[-1], embedding_size)
	        self.fc_std = nn.Linear(enc_widths[-1], embedding_size)

	        # Decoder: embedding -> 64 -> 128 -> 256 -> 512 -> 784.
	        dec_widths = (embedding_size, 64, 128, 256, 512)
	        hidden = [fc_batchnorm_act(w_in, w_out)
	                  for w_in, w_out in zip(dec_widths[:-1], dec_widths[1:])]
	        # The output stage has no batch norm and uses a sigmoid activation.
	        self.decoder = nn.Sequential(
	            *hidden,
	            fc_batchnorm_act(dec_widths[-1], 28*28, False, nn.Sigmoid()))

	    def reparameterization_trick(self, mu, logvar):
	        """Draw ``z = mu + eps * std`` with ``eps ~ N(0, 1)``.

	        ``std`` is recovered from the log-variance as ``exp(logvar / 2)``.
	        """
	        std = (logvar * 0.5).exp()
	        noise = torch.randn_like(std)
	        return mu + noise * std

	    def encode(self, imgs):
	        """Return ``(z, mu, logvar)`` for a batch of images."""
	        flat = imgs.view(imgs.size(0), -1)
	        hidden = self.encoder_entry(flat)
	        # Both heads are plain linear layers: no activation on mu or logvar.
	        mu, logvar = self.fc_mu(hidden), self.fc_std(hidden)
	        return self.reparameterization_trick(mu, logvar), mu, logvar

	    def decode(self, z):
	        """Map latent codes to images shaped (N, 1, 28, 28)."""
	        return self.decoder(z).view(-1, 1, 28, 28)

	    def forward(self, x):
	        """Encode, sample and decode; return ``(reconstruction, mu, logvar)``."""
	        z, mu, logvar = self.encode(x)
	        return self.decode(z), mu, logvar
	
	def loss_disentagled_vae(outputs, imgs, mu, logvar, Beta, reduction='mean', use_mse=False):
	    """Return the β-VAE loss: reconstruction error plus ``Beta`` times the KL term.

	    Args:
	        outputs: reconstructions; for the default BCE criterion the values
	            must lie in [0, 1].
	        imgs: reconstruction targets, same shape as ``outputs``.
	        mu: latent means, latent dimension last.
	        logvar: latent log-variances, same shape as ``mu``.
	        Beta: weight of the KL term.
	        reduction: ``'mean'`` averages over the batch; any other value is
	            treated as ``'sum'`` and totals over the batch.
	        use_mse: use mean-squared error instead of binary cross-entropy for
	            the reconstruction term (honoured by both reductions).

	    Returns:
	        A scalar tensor.
	    """
	    # Closed-form KL(q(z|x) || N(0, I)) for each sample
	    # (Kingma and Welling, Auto-Encoding Variational Bayes, ICLR 2014,
	    # https://arxiv.org/abs/1312.6114):
	    #     -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
	    kl_per_sample = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), -1)

	    if reduction == 'mean':
	        criterion = nn.MSELoss() if use_mse else nn.BCELoss(reduction='mean')
	        # The criterion averages over every element of the batch. Scaling by
	        # the number of elements in one sample turns that into the batch mean
	        # of the per-sample summed error, which is on the same scale as the
	        # per-sample KL term (784 for 28x28 single-channel images).
	        elems_per_sample = outputs.numel() // max(outputs.size(0), 1)
	        recons_loss = criterion(outputs, imgs) * elems_per_sample
	        # Beta scales the KL term; this is what distinguishes the β-VAE
	        # objective from the plain VAE one.
	        return torch.mean(recons_loss + (Beta * kl_per_sample))

	    criterion_cls = nn.MSELoss if use_mse else nn.BCELoss
	    recons_loss = criterion_cls(reduction='sum')(outputs, imgs)
	    return recons_loss + (Beta * kl_per_sample.sum())
	
	# Training configuration.
	epochs = 50
	
	embeddingsize = 5
	# Print a progress line every `interval` batches.
	interval = 2000
	reduction='mean'
	# beta is a value bigger than 1
	Beta = 5.
	device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
	model = B_VAE(embeddingsize).to(device)
	optimizer = torch.optim.Adam(model.parameters(), lr =0.001)
	scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 50)
	
	# Make sure BatchNorm uses batch statistics even if this cell is re-run
	# after the model has been switched to eval mode.
	model.train()
	for e in range(epochs):
	    for i, (imgs, labels) in enumerate(dataloader_train):
	        imgs = imgs.to(device)
	        preds, mu, logvar = model(imgs)
	
	        loss = loss_disentagled_vae(preds, imgs, mu, logvar, Beta=Beta, reduction=reduction, use_mse=False)
	
	        optimizer.zero_grad()
	        loss.backward()
	        optimizer.step()
	        if i % interval == 0:
	            # With reduction='sum' the loss is a batch total; divide by the
	            # batch size so the logged number is per-sample either way.
	            shown_loss = loss.item() / len(imgs) if reduction == 'sum' else loss.item()
	            print(f'epoch {e}/{epochs} [{i*len(imgs)}/{len(dataloader_train.dataset)} ({100.*i/len(dataloader_train):.2f}%)]'
	                  f'\tloss: {shown_loss:.4f}'
	                  f'\tlr: {scheduler.get_last_lr()}')
	    # The learning-rate schedule advances once per epoch.
	    scheduler.step()
	#%% 
	# test
	test_set_size = len(dataloader_test.dataset)
	img_pairs = []
	losses = []
	# Keep a batch of input/reconstruction pairs every `interval` batches.
	interval = 10
	# Evaluate in eval mode: in train mode BatchNorm normalises with the
	# statistics of each test batch and also updates its running averages from
	# the test data, even inside torch.no_grad().
	model.eval()
	with torch.no_grad():
	    for i, (imgs, labels) in enumerate(dataloader_test):
	        imgs = imgs.to(device)
	        preds, mu, logvar = model(imgs)
	        loss = loss_disentagled_vae(preds, imgs, mu, logvar, Beta=Beta, reduction=reduction, use_mse=False)
	        losses.append({'val_loss':loss.item()})
	
	        print(f'[{i*len(imgs)} / {test_set_size} ({100.*i/len(dataloader_test):.2f}%)]'
	            f'\tloss: {(loss).item():.4f}')
	
	        if i % interval == 0:
	            reconstructeds = preds.cpu().detach().view(-1, 1, 28, 28)
	            # Only the first 20 images of the batch are kept.
	            originals = imgs[:20].cpu().detach().numpy()
	            recons = reconstructeds[:20].numpy()
	            # np.dstack joins along the last axis, so every pair is the
	            # input and its reconstruction side by side.
	            pairs = np.array([np.dstack((img1, img2)) for img1, img2 in zip(originals, recons)])
	            img_pairs.append(pairs)
	#%% 
	# Optional: plot the recorded validation losses.
	# import pandas as pd 
	# pd.DataFrame(losses).plot()
	model.eval()
	# Decode three latent codes drawn from a standard normal and show them
	# together as one image grid.
	z = torch.randn(size=(3, model.embedding_size))
	z = z.to(device)
	preds = model.decode(z)
	preds = preds.detach().cpu()
	img = make_grid(preds)
	# Move the first axis of the grid to the end before handing it to imshow.
	plt.imshow(np.transpose(img.numpy(), (1, 2, 0)))
	#%%
	# Draw the latent code(s) that the visualisation cells below start from.
	n = 1 
	z = torch.randn(size=(n, model.embedding_size))
	z = z.to(device)
	print(z.shape)
	#%%
	# Decode the drawn code(s) and display the result.
	fig = plt.figure()
	ax = fig.add_subplot(1, 1, 1)
	preds = model.decode(z)
	preds = preds.detach().cpu()
	grid = make_grid(preds, nrow=5)
	img_latent_space = np.transpose(grid.numpy(), (1, 2, 0))
	ax.imshow(img_latent_space)
	#%%
	def change_latentvariable(z, n=3, count = 3, dim=0):
	    """Make ``count`` copies of each latent code with one coordinate stepped down.

	    Copy number ``i`` (starting at 0) has ``0.2 * i`` subtracted from
	    coordinate ``dim``; every other coordinate is left as in ``z``.

	    Args:
	        z: 2-D tensor of latent codes; its first dimension must be ``n``
	            or 1 (it is broadcast).
	        n: number of codes in the result.
	        count: number of variations generated per code.
	        dim: index of the latent coordinate that is varied.

	    Returns:
	        A new tensor of shape ``(n, count, z.size(-1))`` created with
	        ``torch.zeros``; ``z`` itself is not modified.
	    """
	    z_new = torch.zeros(size=(n, count, z.size(-1)))
	    # Fill every variation slot with the unmodified code(s).
	    z_new[:] = z[:, :].unsqueeze(1)
	    # One offset per slot: 0.0, 0.2, 0.4, ...
	    offsets = torch.tensor([0.2 * step for step in range(count)])
	    z_new[:, :, dim] -= offsets
	    return z_new
	
	def show_manifold(z, n, count, dim , device):
	    """Plot decoded images for ``count`` stepped variations of one latent coordinate.

	    The variations come from ``change_latentvariable(z, n, count, dim)``;
	    they are decoded with the module-level ``model`` and drawn as a grid
	    with ``count`` images per row on a new 5x5 figure.
	    """
	    figure = plt.figure(figsize=(5, 5))
	    axes = figure.add_subplot(1, 1, 1)
	    variations = change_latentvariable(z, n, count, dim)
	    variations = variations.to(device)
	    print(variations.shape)
	    # Collapse (n, count, latent) into a flat batch of latent codes.
	    flat_codes = variations.view(-1, model.embedding_size)
	    decoded = model.decode(flat_codes).detach().cpu()
	    grid = make_grid(decoded, nrow=count)
	    axes.imshow(np.transpose(grid.numpy(), (1, 2, 0)))
	
	# One traversal figure for every latent coordinate, starting at index 0.
	for latent_dim in range(model.embedding_size):
	    show_manifold(z, n=n, count=5, dim=latent_dim, device=device)
	# visualize the 2d manifold 
	#%%
	# Sweep each latent coordinate in turn while the others stay at the mean.
	z_dim = model.embedding_size
	sigma_mean = torch.full((z_dim,), 2.0)
	mu_mean = torch.zeros((z_dim))
	
	# Number of images generated per latent coordinate.
	nbr_steps = 8
	# The stack opens with a block of all-ones images; one decoded block per
	# latent coordinate is appended after it.
	image_blocks = [torch.ones((nbr_steps, 1, 28, 28))]
	
	for latent in range(z_dim):
	    centre = mu_mean[latent]
	    spread = sigma_mean[latent]
	    print(latent, centre - spread, centre, centre + spread)
	    # The sweep starts at centre - spread and advances by 2 * spread / nbr_steps,
	    # so the last generated value is one step short of centre + spread.
	    step = 2.0 * spread / nbr_steps
	    val = centre - spread
	    var_z0 = torch.zeros(nbr_steps, z_dim)
	    for i in range(nbr_steps):
	        var_z0[i] = mu_mean
	        var_z0[i, latent] = val
	        val = val + step
	
	    var_z0 = var_z0.to(device)
	    gen_images_latent = model.decode(var_z0).detach().cpu()
	    image_blocks.append(gen_images_latent)
	
	gen_images = torch.cat(image_blocks, dim=0)
	img = make_grid(gen_images)
	plt.imshow(np.transpose(img.cpu().numpy(), (1, 2, 0)))
	#%%
	def plot_latentspace(num_rows,num_cols=9,figure_width=10.5,image_height=1.5):
	    """Plot a grid of decodings that sweeps each latent coordinate over [-3, 3].

	    Row ``i`` varies coordinate ``i`` of the first code in the module-level
	    ``z`` across ``num_cols`` evenly spaced values while the other
	    coordinates keep their values. In each row the grid value closest to the
	    coordinate's actual value is replaced by that actual value and the cell
	    is highlighted with a green frame.

	    Args:
	        num_rows: number of latent coordinates to sweep (one row each).
	        num_cols: number of values per sweep.
	        figure_width: figure width passed to ``plt.figure``.
	        image_height: figure height contributed by each row.

	    Reads the module-level ``z`` and ``model``; ``z`` is left unmodified.
	    """
	    fig = plt.figure(figsize=(figure_width, image_height * num_rows))
	    # Sweep on a private copy so the module-level z is never written to.
	    z_work = z.clone()
	    z_i_values = np.linspace(-3.0, 3.0, num_cols)

	    for i in range(num_rows):
	        # .item() copies the scalar out of the tensor. Going through
	        # .numpy() instead would, for a CPU tensor, return an array sharing
	        # memory with the tensor, so the "saved" value would follow every
	        # overwrite made below.
	        z_i = z_work[0][i].item()
	        j_min = int(np.argmin(np.abs(z_i_values - z_i)))
	        for j in range(num_cols):
	            z_work[0][i] = z_i if j == j_min else float(z_i_values[j])

	            x = model.decode(z_work).detach().cpu().numpy()

	            ax = fig.add_subplot(num_rows, num_cols, i * num_cols + j + 1)
	            ax.imshow(x[0][0], cmap='gray')

	            if i == 0 or j == j_min:
	                ax.set_title(f'{z_work[0][i].item():.1f}')

	            if j == j_min:
	                # Keep the axes visible (without ticks) so the frame shows.
	                ax.set_xticks([])
	                ax.set_yticks([])
	                color = 'mediumseagreen'
	                width = 8
	                for side in ['top', 'bottom', 'left', 'right']:
	                    ax.spines[side].set_color(color)
	                    ax.spines[side].set_linewidth(width)
	            else:
	                ax.axis('off')
	        # Put the coordinate back before sweeping the next one.
	        z_work[0][i] = z_i

	    plt.tight_layout()
	    fig.subplots_adjust(wspace=0.04)
	# One row per latent coordinate.
	num_rows = z.shape[-1]
	plot_latentspace(num_rows)

Moafagh bashi!

2 Likes

Salam
Ahvale shoma?
ye soal kochik dashtam.
to ein code ha ye ja zadin embedding_size ein yani chi? chera tarif shode?

Mamnun :slight_smile:

Hi,
embedding_size is fairly self-explanatory once you recall that this is just an autoencoder: it specifies how many features you want the autoencoder to compress your representation into (and recover it from), and it directly affects your end result as well.
If this is not what you wanted to know, please be more specific.

1 Like

Hi,
I got that.
Thank you.
Actually, I changed your code a little and converted it to a convolutional B-VAE.
Now I have no idea how to plot latent space :slight_smile:
below you can see my (your) codes :smiley:

class B_VAE(nn.Module):
    """1-D convolutional β-VAE.

    The encoder is four Conv1d / MaxPool1d / Tanh stages that end with 16
    channels; two linear heads map a 16-wide vector to the latent mean and
    log-variance. The decoder maps a latent code through a linear layer and
    four ConvTranspose1d / Tanh stages back to a single channel.
    """

    def __init__(self, zdims):
        """Create the layers for a latent space with ``zdims`` dimensions."""
        super().__init__()
        self.zdims = zdims

        # Encoder layer
        self.encoder = nn.Sequential(
            nn.Conv1d(1, 4, kernel_size=4, stride=3),
            nn.MaxPool1d(kernel_size=4, stride=3),
            nn.Tanh(),

            nn.Conv1d(4, 8, kernel_size=4, stride=2),
            nn.MaxPool1d(kernel_size=4, stride=2),
            nn.Tanh(),

            nn.Conv1d(8, 12, kernel_size=4, stride=2),
            nn.MaxPool1d(kernel_size=4, stride=1),
            nn.Tanh(),

            nn.Conv1d(12, 16, kernel_size=4, stride=1),
            nn.MaxPool1d(kernel_size=3, stride=1),
            nn.Tanh(),
        )

        # Conv FC layer: heads for the mean and the log-variance
        # (fc_std's output is used as logvar in encode()).
        self.fc_mu = nn.Linear(16, self.zdims)
        self.fc_std = nn.Linear(16, self.zdims)

        # Deconv FC layer
        self.fc_d = nn.Linear(self.zdims, 16)

        # Decoder layer
        self.decoder = nn.Sequential(
            nn.ConvTranspose1d(16, 12, kernel_size=5, stride=2),
            nn.Tanh(),

            nn.ConvTranspose1d(12, 8, kernel_size=5, stride=3),
            nn.Tanh(),

            nn.ConvTranspose1d(8, 4, kernel_size=12, stride=4),
            nn.Tanh(),

            nn.ConvTranspose1d(4, 1, kernel_size=18, stride=9),
            nn.Tanh(),
        )

    def parameterization_trick(self, mu, logvar):
        """Draw ``z = mu + eps * std`` with ``eps ~ N(0, 1)`` and ``std = exp(logvar / 2)``."""
        std = torch.exp(logvar * 0.5)
        # sample epsilon from N(0, 1)
        eps = torch.randn_like(std)
        # shift the noise by the mean and scale it by the standard deviation
        return mu + eps * std

    def encode(self, imgs):
        """Return ``(mu, logvar, z)`` for a batch shaped (N, 1, length)."""
        output = self.encoder(imgs)
        # Flatten to 16 features per row. This keeps one row per sample only
        # when the encoder has reduced the length axis to 1.
        # NOTE(review): confirm the input length satisfies that.
        output = output.view(-1, 16)
        mu = self.fc_mu(output)
        logvar = self.fc_std(output)
        z = self.parameterization_trick(mu, logvar)
        return mu, logvar, z

    def decode(self, z):
        """Map latent codes to reconstructed signals."""
        deconv_input = self.fc_d(z)
        # NOTE(review): fc_d yields 16 values per code but this view consumes
        # 32 per row, so two codes are folded into one decoder input and the
        # output batch is half the input batch. Either fc_d should produce 32
        # features or this should be view(-1, 16, 1) - confirm which length
        # the decoder is meant to start from.
        deconv_input = deconv_input.view(-1, 16, 2)
        reconstructed_img = self.decoder(deconv_input)
        return reconstructed_img

    def forward(self, x):
        """Encode, sample and decode; return ``(reconstruction, mu, logvar)``."""
        mu, logvar, z = self.encode(x)
        reconstructed_img = self.decode(z)
        return reconstructed_img, mu, logvar

For visualization, you need to compress your representation to a lower dimension that you can plot.
Try printing the basic B-VAE that you have and understand how the code works; then you can easily alter it for your own case.

1 Like