Beta variational autoencoder

Hi all,

Has anyone worked with a “beta variational autoencoder” (β-VAE)?

What problem are you running into exactly?

I have no idea where to start. I looked around the web to see whether someone else had done this in PyTorch, but could not find anything.

I guess the main difference between the beta version and the regular one would be the loss calculation. Do you have any ideas?

Thanks

Here’s an old implementation of mine (PyTorch v1.0 I guess, or maybe 0.4). This is also known as a disentangled variational autoencoder:

# Disentangled Variational Autoencoders (β-VAE)
	import torch
	import torch.nn as nn
	import numpy as np
	import matplotlib.pyplot as plt
	from torchvision.utils import make_grid
	
	# NOTE: dataloader_train and dataloader_test are assumed to be existing DataLoaders
	# over a 28x28 grayscale dataset such as MNIST.
	# good reads : https://towardsdatascience.com/disentanglement-with-variational-autoencoder-a-review-653a891b69bd 
	# https://lilianweng.github.io/lil-log/2018/08/12/from-autoencoder-to-beta-vae.html#contractive-autoencoder
	# https://openreview.net/forum?id=Sy2fzU9gl
	# https://arxiv.org/pdf/1901.09415.pdf
	# https://arxiv.org/abs/1606.05579
	
	# The basic idea in a disentangled VAE is that we want the different neurons in our latent
	# distribution to be uncorrelated: they should each try to learn something different about
	# the input data. To implement this, the only thing that needs to be added to the vanilla
	# VAE is a β term.
	# Previously, for the vanilla VAE we had:
	#     L = E_q(z|X)[log p(X|z)] - D_KL[q(z|X) || p(z)]
	# Now for the disentangled version (β-VAE) we just add the β like this:
	#     L = E_q(z|X)[log p(X|z)] - β * D_KL[q(z|X) || p(z)]
	# So, to put it simply, in a disentangled VAE (β-VAE) the autoencoder will only use a
	# latent variable if it is important.
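	# Note: with β > 1 the KL term is weighted more heavily, so each latent dimension is
	# pushed toward the N(0,1) prior unless deviating from it really pays off in
	# reconstruction quality; β = 1 recovers the vanilla VAE.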
	
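	# Helper block: Linear -> activation -> (optional) BatchNorm1d, used by both the
	# encoder and the decoder.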
	def fc_batchnorm_act(in_, out_, use_bn=True, act=nn.ReLU()):
	    return nn.Sequential(nn.Linear(in_,out_),
	                         act,
	                         nn.BatchNorm1d(out_) if use_bn else nn.Identity())
	                         
	class B_VAE(nn.Module):
	    def __init__(self, embedding_size=5):
	        super().__init__()
	        self.embedding_size = embedding_size
	        
	        # self.fc1 = nn.Linear(28*28, 512)
	        self.encoder_entry = nn.Sequential(fc_batchnorm_act(28*28,512),
	                                           fc_batchnorm_act(512,256),
	                                           fc_batchnorm_act(256,128),
	                                           fc_batchnorm_act(128,64))
	        self.fc_mu = nn.Linear(64, embedding_size)
	        self.fc_logvar = nn.Linear(64, embedding_size)
	
	        self.decoder = nn.Sequential(fc_batchnorm_act(embedding_size, 64),
	                                     fc_batchnorm_act(64,128),
	                                     fc_batchnorm_act(128,256),
	                                     fc_batchnorm_act(256,512),
	                                     fc_batchnorm_act(512, 28*28,False,nn.Sigmoid()))
	        # self.decoder = nn.Sequential(nn.Linear(embedding_size, 512),
	        #                             nn.ReLU(),
	        #                             nn.Linear(512, 28*28),
	        #                             nn.Sigmoid())
	
	    def reparameterization_trick(self, mu, logvar):
	        # logvar = log(sigma^2), so multiplying by 0.5 and exponentiating gives sigma
	        std = torch.exp(logvar * 0.5)
	        # sample epsilon from N(0,1)
	        eps = torch.randn_like(std)
	        # a sample from N(mu, sigma^2) is obtained by scaling eps by the standard
	        # deviation and shifting it by the mean; the randomness lives entirely in eps,
	        # so gradients can flow through mu and std back to the encoder
	        return mu + eps * std
	
	    def encode(self, imgs):
	        imgs = imgs.view(imgs.size(0), -1)
	        # output = F.relu(self.fc1(imgs))
	        output = self.encoder_entry(imgs)
	        # remember, we don't use nonlinearities for mu and logvar!
	        mu = self.fc_mu(output)
	        logvar = self.fc_logvar(output)
	        z = self.reparameterization_trick(mu, logvar)
	        return z, mu, logvar
	
	    def decode(self, z):
	        reconstructed_imgs = self.decoder(z)
	        reconstructed_imgs = reconstructed_imgs.view(-1, 1, 28, 28)
	        return reconstructed_imgs
	
	    def forward(self, x):
	        # encoder 
	        z, mu, logvar = self.encode(x)
	        # decoder
	        reconstructed_imgs = self.decode(z)
	        return reconstructed_imgs, mu, logvar
	
	def loss_disentangled_vae(outputs, imgs, mu, logvar, Beta, reduction='mean', use_mse=False):
	    # this loss has two parts: a reconstruction loss and a KL divergence loss, which
	    # measures how far apart two given distributions are.
	    if reduction=='mean':
	        if use_mse:
	            criterion = nn.MSELoss()
	        else:
	            criterion = nn.BCELoss(reduction='mean')
	        recons_loss = criterion(outputs, imgs)
	        # BCELoss with reduction='mean' averages over every pixel in the batch;
	        # multiplying by 28*28 scales it back to a per-image sum over pixels
	        recons_loss *= 28*28
	        # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
	        # https://arxiv.org/abs/1312.6114
	        # KL(q(z|X) || N(0, 1)) = -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
	        # summed over the latent dimensions, giving one KL value per sample
	        kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), -1)
	        # beta weights the KL term; this is the one change relative to the vanilla
	        # VAE and is what actually drives the disentanglement
	        return torch.mean(recons_loss + (Beta * kl))
	    else:
	        criterion = nn.BCELoss(reduction='sum')
	        recons_loss = criterion(outputs, imgs)
	        kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
	        return recons_loss + (Beta*kl)    
	
	epochs = 50
	
	embeddingsize = 5
	interval = 2000
	reduction='mean'
	# beta is a value bigger than 1
	Beta = 5.
	device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
	model = B_VAE(embeddingsize).to(device)
	optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
	scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 50)
	
	for e in range(epochs):
	    for i, (imgs, labels) in enumerate(dataloader_train):
	        imgs = imgs.to(device)
	        preds, mu, logvar = model(imgs)
	
	        loss = loss_disentangled_vae(preds, imgs, mu, logvar, Beta=Beta, reduction=reduction, use_mse=False)
	        
	        optimizer.zero_grad()
	        loss.backward()
	        optimizer.step() 
	        if i % interval == 0:
	            loss = loss / len(imgs) if reduction == 'sum' else loss
	            print(f'epoch {e}/{epochs} [{i*len(imgs)}/{len(dataloader_train.dataset)} ({100.*i/len(dataloader_train):.2f}%)]'
	                  f'\tloss: {loss.item():.4f}'
	                  f'\tlr: {optimizer.param_groups[0]["lr"]}')
	    scheduler.step()
	#%% 
	# test
	test_set_size = len(dataloader_test.dataset)
	img_pairs = []
	losses = []
	interval = 10
	model.eval()  # put BatchNorm layers in eval mode before measuring test loss
	with torch.no_grad():
	    for i, (imgs, labels) in enumerate(dataloader_test):
	        imgs = imgs.to(device)
	        preds, mu, logvar = model(imgs)
	        loss = loss_disentangled_vae(preds, imgs, mu, logvar, Beta=Beta, reduction=reduction, use_mse=False)
	        losses.append({'val_loss':loss.item()})
	        
	        print(f'[{i*len(imgs)} / {test_set_size} ({100.*i/len(dataloader_test):.2f}%)]'
	              f'\tloss: {loss.item():.4f}')
	
	        if i % interval == 0:
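	            # stack each input next to its reconstruction so the pairs can be
	            # viewed side by side later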
	            reconstructeds = preds.cpu().detach().view(-1, 1, 28, 28)
	            imgs = imgs[:20].cpu().detach().numpy()
	            recons = reconstructeds[:20].numpy()
	            pairs = np.array([np.dstack((img1,img2)) for img1, img2 in zip(imgs,recons)])
	            img_pairs.append(pairs)
	#%% 
	# import pandas as pd 
	# pd.DataFrame(losses).plot()
	model.eval()
	# create sample image
	z = torch.randn(size=(3, model.embedding_size)).to(device)
	preds = model.decode(z).cpu().detach()
	img = make_grid(preds)
	plt.imshow(img.numpy().transpose(1,2,0))
	#%%
	# visualize latent space
	n = 1 
	z = torch.randn(size=(n,model.embedding_size)).to(device)
	print(z.shape)
	#%%
	fig = plt.figure()
	ax = fig.add_subplot(111)
	preds = model.decode(z).cpu().detach()
	img_latent_space = make_grid(preds,nrow=5).numpy().transpose(1,2,0)
	ax.imshow(img_latent_space)
	#%%
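	# Sweep one latent dimension: make `count` copies of z and shift dimension `dim`
	# by -0.2 per step, keeping all other dimensions fixed.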
	def change_latentvariable(z, n=3, count=3, dim=0):
	    z_new = torch.zeros(size=(n, count, z.size(-1)))
	    for i in range(count):
	        z_new[:, i, :] = z[:, :]
	        z_new[:, i, dim] = z_new[:, i, dim] - (0.2 * i)
	    return z_new
	
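	# Decode the swept latent codes and show them in a grid: each row corresponds to
	# one z, each column to one step along the chosen dimension.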
	def show_manifold(z, n, count, dim, device):
	    fig = plt.figure(figsize=(5,5))
	    ax = fig.add_subplot(111)
	    latent_space_manifold = change_latentvariable(z, n, count, dim).to(device)
	    print(latent_space_manifold.shape)
	    preds = model.decode(latent_space_manifold.view(-1,model.embedding_size)).cpu().detach()
	    img_latent_space_man = make_grid(preds,nrow=count).numpy().transpose(1,2,0)
	    ax.imshow(img_latent_space_man)
	
	show_manifold(z, n=n, count=5, dim=0, device=device)
	show_manifold(z, n=n, count=5, dim=1, device=device)
	show_manifold(z, n=n, count=5, dim=2, device=device)
	show_manifold(z, n=n, count=5, dim=3, device=device)
	show_manifold(z, n=n, count=5, dim=4, device=device)
	# visualize the 2d manifold 
	#%%
	# variations over the latent variable :
	z_dim = model.embedding_size
	sigma_mean = 2.0*torch.ones((z_dim))
	mu_mean = torch.zeros((z_dim))
	
	# Save generated variable images :
	nbr_steps = 8
	gen_images = torch.ones( (nbr_steps, 1, 28, 28) )
	
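	# For each latent dimension, sweep its value from mu - sigma to mu + sigma in
	# nbr_steps equal steps while holding the other dimensions at their mean.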
	for latent in range(z_dim):
	    # var_z0 = torch.stack([mu_mean]*nbr_steps, dim=0)
	    var_z0 = torch.zeros(nbr_steps, z_dim)
	    val = mu_mean[latent] - sigma_mean[latent]
	    step = 2.0 * sigma_mean[latent] / nbr_steps
	    print(latent, mu_mean[latent] - sigma_mean[latent], mu_mean[latent], mu_mean[latent] + sigma_mean[latent])
	    for i in range(nbr_steps):
	        var_z0[i] = mu_mean
	        var_z0[i][latent] = val
	        val += step
	
	    var_z0 = var_z0.to(device)
	
	
	    gen_images_latent = model.decode(var_z0)
	    gen_images_latent = gen_images_latent.cpu().detach()
	    gen_images = torch.cat( [gen_images, gen_images_latent], dim=0)
	
	img = make_grid(gen_images)
	plt.imshow(img.cpu().numpy().transpose(1,2,0))
	#%%
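	# For each latent dimension i, decode z with z[0][i] swept over [-3, 3] while the
	# other dimensions keep their sampled values; the column closest to the original
	# value of z[0][i] is highlighted with a green border.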
	def plot_latentspace(num_rows, num_cols=9, figure_width=10.5, image_height=1.5):
	    fig = plt.figure(figsize=(figure_width, image_height * num_rows))
	    
	    for i in range(num_rows):
	        z_i_values = np.linspace(-3.0, 3.0, num_cols)
	        z_i = z[0][i].detach().cpu().numpy()
	        z_diffs = np.abs((z_i_values - z_i))
	        j_min = np.argmin(z_diffs)
	        for j in range(num_cols):
	            z_i_value = z_i_values[j]
	            if j != j_min:
	                z[0][i] = z_i_value
	            else:
	                z[0][i] = float(z_i)
	                
	            x = model.decode(z).detach().cpu().numpy()
	            
	            ax = fig.add_subplot(num_rows, num_cols, i * num_cols + j + 1)
	            ax.imshow(x[0][0], cmap='gray')
	            
	            if i == 0 or j == j_min:
	                ax.set_title(f'{z[0][i]:.1f}')
	            
	            if j == j_min:
	                ax.set_xticks([])
	                ax.set_yticks([])
	                color = 'mediumseagreen'
	                width = 8
	                for side in ['top', 'bottom', 'left', 'right']:
	                    ax.spines[side].set_color(color)
	                    ax.spines[side].set_linewidth(width)
	            else:
	                ax.axis('off')
	        z[0][i] = float(z_i)
	        
	    plt.tight_layout()
	    fig.subplots_adjust(wspace=0.04)
	num_rows = z.shape[-1]
	plot_latentspace(num_rows)

Good luck!
