How to properly get the latent vector of a VAE during training?

Hi all,

I have a question. I found a way to brute-force the problem, but it isn't very pretty code, so I was wondering if you had a better solution for the following problem.

I have made a VAE that’s in the following form:

def forward(self, input):
    latent_vector = self.encoder(input)
    output = self.decoder(latent_vector)
    return output

My loss currently has the following form:

predicted = self.model(inputs).float()
loss = self.criterion(predicted, expected)

I would like to force the latent vector to have a particular form, so I would like to add a term in the loss like:

loss = self.criterion(predicted, expected) + some_function(latent_vector)

Is there a proper way to get my latent vector in order to do that?
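For what it's worth, one straightforward option is to return the latent vector alongside the output. A minimal runnable sketch, with stand-in linear layers and a hypothetical L2 penalty playing the role of `some_function` (the sizes and the penalty are assumptions, not your actual model):

```python
import torch
import torch.nn as nn

class VAE(nn.Module):
    def __init__(self):
        super().__init__()
        # stand-in layers; replace with your real encoder/decoder
        self.encoder = nn.Linear(4, 2)
        self.decoder = nn.Linear(2, 4)

    def forward(self, x):
        latent_vector = self.encoder(x)
        output = self.decoder(latent_vector)
        return output, latent_vector  # return both

model = VAE()
criterion = nn.MSELoss()

inputs = torch.randn(8, 4)
expected = torch.randn(8, 4)

predicted, latent = model(inputs)
# hypothetical `some_function`: an L2 penalty on the latent vector
loss = criterion(predicted, expected) + latent.pow(2).mean()
loss.backward()  # gradients flow through both terms
```

Because the latent term is part of the graph, `backward()` propagates through the encoder via both the reconstruction and the penalty.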



If I have understood your question properly, you can just return both output and latent_vector and do the rest as you mentioned. But if you do not want to return a tuple, I think it is fine to add an attribute to your Model class, update it on every forward pass, and then read it by calling model.latent_vector.

class Model(nn.Module):
  def __init__(self):
    super().__init__()
    self.latent_vector = None

  def forward(self, x):
    self.latent_vector = x  # cache the latent on every forward pass
    output = x * 2
    return output

model = Model()
model(2)              # returns 4
model.latent_vector   # now holds 2
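Applied to the original question, the cached attribute can then feed the extra loss term after the forward pass. A runnable sketch with stand-in layers and a hypothetical L2 penalty in place of `some_function` (both are assumptions):

```python
import torch
import torch.nn as nn

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        # stand-in layers; replace with your real encoder/decoder
        self.encoder = nn.Linear(4, 2)
        self.decoder = nn.Linear(2, 4)
        self.latent_vector = None

    def forward(self, x):
        self.latent_vector = self.encoder(x)  # cached for the loss
        return self.decoder(self.latent_vector)

model = Model()
criterion = nn.MSELoss()
inputs = torch.randn(8, 4)
expected = torch.randn(8, 4)

predicted = model(inputs)
# hypothetical `some_function`: an L2 penalty on the cached latent
loss = criterion(predicted, expected) + model.latent_vector.pow(2).mean()
loss.backward()
```

The cached tensor still carries its autograd history, so the penalty term backpropagates into the encoder as usual.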

Good luck

Okay, thank you so much for your answer. In fact I just found a problem with this method for my use case:
I don't want to look only at the latent vector of the current input.

I have different classes; let's say, for example, animals: cats, dogs and birds.
Suppose my current input is a cat, cat_i, and I have n cats in my dataset: {cat_1, cat_2, …, cat_n}.
I want the latent vector to learn the distribution among the cats, the dogs and the birds separately.

So I would like to code the Loss this way:

Loss = self.criterion(current_reconstructed_cat_i, cat_i)
       + sum_{j=1..n, j != i} some_function(current_latent_vector_of_cat_j, current_latent_vector_of_cat_i)

current_latent_vector_of_cat_j is the current forward pass stopped at the latent_vector.
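One way to sketch a loss like this for a single batch (rather than the whole dataset) is to group the batch's latents by class label and sum a pairwise term over same-class pairs. The `pairwise_latent_loss` helper and its mean-squared-distance term below are illustrative placeholders, not the thread's actual `some_function`:

```python
import torch

def pairwise_latent_loss(latents, labels):
    """Sum a hypothetical pairwise term over same-class pairs in the batch.

    latents: (B, D) latent vectors; labels: (B,) integer class ids.
    """
    loss = latents.new_zeros(())
    for c in labels.unique():
        z = latents[labels == c]  # all latents of class c in this batch
        if z.size(0) < 2:
            continue  # no pairs to compare
        # illustrative pairwise term: mean squared distance between pairs
        diff = z.unsqueeze(0) - z.unsqueeze(1)   # (k, k, D) via broadcasting
        loss = loss + diff.pow(2).sum(-1).mean()
    return loss

latents = torch.randn(6, 3)
labels = torch.tensor([0, 0, 1, 1, 1, 2])
reg = pairwise_latent_loss(latents, labels)
```

This term could then be added to the reconstruction loss; because it only uses the current batch, it sidesteps the whole-dataset recomputation discussed below.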


So, based on your explanation, you need the latent vectors of the whole dataset to calculate that sum. That may actually be tricky because of memory issues, since you need to pass every sample of a class (e.g. every cat) through the encoder.

There is another point here: the encoder generates latent_vector, so if you want the latent vectors of the whole dataset in order to calculate the loss for a single batch, you have to prepare those latent vectors on the go. I mean, when you pass a batch through the network, you need all the latent vectors to compute that sum, and then you optimize your model one step. So for the next batch, the latent vectors you already have are no longer valid and you have to calculate them again. This will be a slow process. But let's say that is not a problem. To calculate latent vectors without updating the weights, you can do the following in your training procedure:

latent_vectors = []
with torch.no_grad():
    for data in data_loader:
        latent_vector = model.encoder(data)
        latent_vectors.append(latent_vector)

The key point is to use torch.no_grad so that no graph is built and the model's state is preserved. But I am still concerned about the performance.
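Putting the pieces together, a full training step along these lines might recompute the frozen dataset latents before each batch and combine them with the reconstruction loss. This is only a sketch under assumptions: the toy model, the list standing in for a DataLoader, and the "distance to the dataset's mean latent" term are all placeholders for the poster's real model, loader, and `some_function`:

```python
import torch
import torch.nn as nn

# Toy stand-ins (assumptions, not the poster's code)
class ToyVAE(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Linear(4, 2)
        self.decoder = nn.Linear(2, 4)

    def forward(self, x):
        z = self.encoder(x)
        return self.decoder(z), z

model = ToyVAE()
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
data_loader = [torch.randn(3, 4) for _ in range(4)]  # stand-in for a DataLoader

for batch in data_loader:
    # 1) recompute the dataset's latent vectors, frozen (no graph, no updates)
    with torch.no_grad():
        all_latents = torch.cat([model.encoder(d) for d in data_loader], dim=0)

    # 2) forward the current batch WITH gradients
    reconstructed, z = model(batch)

    # 3) hypothetical latent term: squared distance to the dataset's mean latent
    latent_term = (z - all_latents.mean(dim=0)).pow(2).mean()
    loss = criterion(reconstructed, batch) + latent_term

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
```

As noted above, step 1 runs once per batch, which is the slow part; gradients flow only through the current batch's latents, while the frozen dataset latents act as constants.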