Hi All
I trained a variational autoencoder, but I don't know how to plot its latent space. Below is the code I used to define the autoencoder class (I would like to plot z):
class B_VAE(nn.Module):
    """Beta-VAE for single-channel 1-D signals built from Conv1d/ConvTranspose1d.

    The encoder maps ``(batch, 1, length)`` input to ``(batch, 16, L_enc)``.
    ``encode`` reshapes that with ``view(-1, 16)`` and ``decode`` undoes it
    with ``view(-1, 16, 2)``, so the model is self-consistent when
    ``L_enc == 2``. In that case each input sample contributes TWO rows to
    ``mu``, ``logvar`` and ``z``: row ``2*b`` is built from encoder channels
    0-7 and row ``2*b + 1`` from channels 8-15. With the kernel/stride values
    below, a length-909 input gives ``L_enc == 2``, and the decoder emits
    length 909. Reconstructions lie in (-1, 1) because of the final Tanh.
    Targets are presumably scaled to that range; confirm against the data
    pipeline.

    To plot the latent space, use ``embed`` (posterior means grouped per
    sample). If that has more than two columns, project it to 2-D first,
    e.g. with PCA or t-SNE.

    Args:
        zdims: Size of each latent row produced by ``fc_mu`` / ``fc_std``.
    """

    def __init__(self, zdims):
        # Must be ``__init__`` (not ``init``) so that ``B_VAE(zdims)`` runs
        # this constructor and ``nn.Module`` registers the submodules below.
        super().__init__()
        self.zdims = zdims
        # Encoder: four Conv1d -> MaxPool1d -> Tanh stages, 1 -> 16 channels.
        self.encoder = nn.Sequential(
            nn.Conv1d(1, 4, kernel_size=4, stride=3),
            nn.MaxPool1d(kernel_size=4, stride=3),
            nn.Tanh(),
            nn.Conv1d(4, 8, kernel_size=4, stride=2),
            nn.MaxPool1d(kernel_size=4, stride=2),
            nn.Tanh(),
            nn.Conv1d(8, 12, kernel_size=4, stride=2),
            nn.MaxPool1d(kernel_size=4, stride=1),
            nn.Tanh(),
            nn.Conv1d(12, 16, kernel_size=4, stride=1),
            nn.MaxPool1d(kernel_size=3, stride=1),
            nn.Tanh(),
        )
        # Heads applied to each 16-feature row: posterior mean and log-variance.
        # ``fc_std`` actually outputs the LOG-VARIANCE (see parameterization_trick);
        # the attribute name is kept so existing state_dicts still load.
        self.fc_mu = nn.Linear(16, self.zdims)
        self.fc_std = nn.Linear(16, self.zdims)
        # Maps each latent row back to 16 features before the decoder.
        self.fc_d = nn.Linear(self.zdims, 16)
        # Decoder: four ConvTranspose1d -> Tanh stages, 16 -> 1 channel.
        self.decoder = nn.Sequential(
            nn.ConvTranspose1d(16, 12, kernel_size=5, stride=2),
            nn.Tanh(),
            nn.ConvTranspose1d(12, 8, kernel_size=5, stride=3),
            nn.Tanh(),
            nn.ConvTranspose1d(8, 4, kernel_size=12, stride=4),
            nn.Tanh(),
            nn.ConvTranspose1d(4, 1, kernel_size=18, stride=9),
            nn.Tanh(),
        )

    def parameterization_trick(self, mu, logvar):
        """Sample ``z = mu + eps * std`` with ``eps ~ N(0, I)``.

        This is the reparameterization trick: the randomness lives in ``eps``,
        so gradients flow through ``mu`` and ``logvar``.
        """
        # exp(0.5 * log(var)) = sqrt(var) = standard deviation.
        std = torch.exp(logvar * 0.5)
        # Sample epsilon from N(0, 1), one value per element of std.
        eps = torch.randn_like(std)
        # Shift eps by the mean and scale it by the standard deviation.
        return mu + eps * std

    def encode(self, imgs):
        """Encode ``imgs`` into posterior parameters and a sampled latent.

        Returns:
            ``(mu, logvar, z)``, each of shape ``(rows, zdims)`` where
            ``rows = batch * L_enc`` (see the class docstring). These rows are
            NOT one-per-sample when ``L_enc > 1``.
        """
        output = self.encoder(imgs)
        # Splits (batch, 16, L_enc) into consecutive 16-element rows. This is
        # not a per-sample flatten, but trained weights depend on this layout.
        output = output.view(-1, 16)
        mu = self.fc_mu(output)
        logvar = self.fc_std(output)
        z = self.parameterization_trick(mu, logvar)
        return mu, logvar, z

    def decode(self, z):
        """Decode latent rows back to signals of shape ``(batch, 1, length)``.

        Every two consecutive rows of ``z`` form one ``(16, 2)`` decoder input,
        mirroring the ``view`` in ``encode``. ``z`` must therefore have an even
        number of rows.
        """
        deconv_input = self.fc_d(z)
        deconv_input = deconv_input.view(-1, 16, 2)
        reconstructed_img = self.decoder(deconv_input)
        return reconstructed_img

    def forward(self, x):
        """Return ``(reconstruction, mu, logvar)`` for input ``x``."""
        mu, logvar, z = self.encode(x)
        reconstructed_img = self.decode(z)
        return reconstructed_img, mu, logvar

    @torch.no_grad()
    def embed(self, imgs):
        """Return posterior means grouped per input sample, e.g. for plotting.

        ``encode`` yields ``batch * L_enc`` rows of ``mu``. This regroups them
        into shape ``(batch, L_enc * zdims)`` so that each row corresponds to
        exactly one input sample. Runs without building an autograd graph.
        """
        mu, _, _ = self.encode(imgs)
        return mu.reshape(imgs.shape[0], -1)
def loss_disentangled_vae(outputs, imgs, mu, logvar, Beta):
    """Beta-VAE objective: mean-squared reconstruction error plus Beta times KL.

    Args:
        outputs: Reconstructions produced by the model.
        imgs: Reconstruction targets, same shape as ``outputs``.
        mu: Posterior means returned by the encoder.
        logvar: Posterior log-variances returned by the encoder.
        Beta: Weight applied to the KL term.

    Returns:
        Scalar tensor ``MSE(outputs, imgs) + Beta * KL``.
    """
    # Element-wise mean squared error (the default 'mean' reduction).
    reconstruction = nn.functional.mse_loss(outputs, imgs)
    # Closed-form KL(N(mu, exp(logvar)) || N(0, I)), SUMMED over every row and dim.
    # NOTE(review): the reconstruction term is averaged but the KL term is summed,
    # so the KL's effective weight grows with batch size and zdims. Confirm that
    # Beta was tuned with this scaling in mind.
    kl_divergence = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp()).sum()
    return reconstruction + Beta * kl_divergence
Thanks