So, based on your explanation, you need the latent vectors of the whole dataset to compute the for loop. This may be tricky because of memory, since you have to pass the entire subset of a class (such as `cat`) through the encoder.
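As a rough sanity check on the memory side, the stored latent vectors themselves are often small. The larger cost is usually running the encoder over the whole subset, and, if gradients were tracked, keeping the autograd graph for every one of those forward passes. The sketch below only estimates the storage for the final latent tensor. The sample count and latent size are hypothetical numbers, not taken from the question.

```python
def latent_storage_mb(num_samples: int, latent_dim: int, bytes_per_value: int = 4) -> float:
    """Approximate size in megabytes (10**6 bytes) of a (num_samples, latent_dim) tensor.

    bytes_per_value=4 assumes float32 latent vectors.
    """
    return num_samples * latent_dim * bytes_per_value / 1e6


# Hypothetical example: 5,000 "cat" images encoded into 128-dimensional float32 vectors
print(latent_storage_mb(5_000, 128))  # 2.56
```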
There is another point here. The encoder generates `latent_vector`, so if you need the latent vectors of the whole dataset to compute the loss for only one batch, you have to prepare them on the go. When you pass a batch through the network to compute the loss in your for loop, you need all the latent vectors, and then you optimize your model by one step. So for the next batch, the latent vectors you already have are no longer valid (they were produced by the old weights), and you have to compute them again. This will be a slow process. But let's say that isn't a problem. To compute the latent vectors without building a computation graph for them (so they don't take part in the weight update), you can do the following in your training procedure:
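```python
import torch

def compute_latent_vectors(model, data_loader):
    # Encode every batch without tracking gradients
    latent_vectors = []
    with torch.no_grad():
        for data in data_loader:
            latent_vector = model.encoder(data)
            latent_vectors.append(latent_vector)
    return latent_vectors
```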
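The key point is to use `torch.no_grad()`: it disables gradient tracking, so no computation graph is stored for these forward passes (which saves memory), and the resulting latent vectors are treated as constants when you backpropagate the batch loss. Note that it does not switch layers such as dropout or batch norm to inference mode; call `model.eval()` for that if your encoder has them. Still, I am concerned about the performance.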
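To make the cost concrete, here is a minimal sketch of the per-step recomputation described above. It assumes the model exposes `model.encoder` and that the data loader yields `(x,)` tuples. `ToyAutoencoder`, `compute_all_latents` and `hypothetical_loss` are hypothetical names; the loss is only a placeholder for the for loop in the question, which isn't shown here.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


class ToyAutoencoder(nn.Module):
    # Hypothetical stand-in for the real model; only `encoder` is used here
    def __init__(self, in_dim: int = 16, latent_dim: int = 4):
        super().__init__()
        self.encoder = nn.Linear(in_dim, latent_dim)


def compute_all_latents(model, full_loader):
    # One full pass over the subset without gradients; the result acts as a constant
    with torch.no_grad():
        return torch.cat([model.encoder(x) for (x,) in full_loader])


def hypothetical_loss(batch_latents, all_latents):
    # Placeholder for the loop over all latent vectors:
    # mean squared distance from each batch latent to every dataset latent
    diff = batch_latents[:, None, :] - all_latents[None, :, :]
    return (diff ** 2).sum(dim=-1).mean()


def train_one_epoch(model, loader, full_loader, optimizer):
    losses = []
    for (x,) in loader:
        # Latents from the previous step are stale, so recompute them before every step
        all_latents = compute_all_latents(model, full_loader)
        loss = hypothetical_loss(model.encoder(x), all_latents)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
    return losses


torch.manual_seed(0)
cats = TensorDataset(torch.randn(64, 16))  # hypothetical "cat" subset
loader = DataLoader(cats, batch_size=16, shuffle=True)
full_loader = DataLoader(cats, batch_size=64)  # used only for the no-grad pass
model = ToyAutoencoder()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
losses = train_one_epoch(model, loader, full_loader, optimizer)
print(len(losses))  # 4
```

Each optimizer step here costs one extra forward pass over the whole subset, which is where the slowdown comes from. One compromise (an assumption on my part, not something from the question) is to refresh `all_latents` less often, for example once per epoch, and accept slightly stale vectors.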