Efficient implementation of Monte Carlo sampling in VAEs

Dear all,

I am implementing a VAE, where the expectation values are approximated by Monte Carlo sampling. In the original Kingma & Welling paper they used just one sample per data point, but in my case it turned out that I need several samples to get a good estimate.
The problem is that drawing 10 Monte Carlo samples instead of 1 increases the computation time per epoch by more than 100%. I'm doing this sequentially in a for loop: in each iteration I draw a Monte Carlo sample, plug it into the loss function, and compute the loss.
My guess is that the algorithm then also backpropagates 10 times through the same networks. Otherwise I couldn't explain why this increases the running time so much, since drawing samples and calculating the loss isn't that big of a time factor.
Do you have any ideas how this could be implemented more efficiently?

I don’t know if you are still interested in this issue. You can draw all the samples at once, storing them in an extra dimension. You then flatten this dimension into the batch dimension and pass the enlarged batch through the decoder in a single call. Afterwards you reshape the output to reinsert the extra dimension (it should be of size L, following the notation in Kingma & Welling’s paper). Finally, you compute the likelihood and average across this dimension.
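
Here is a minimal sketch of what I mean, assuming PyTorch (you didn't say which framework you use). The toy linear encoder and decoder, the sizes B, D, Z, L, and the Bernoulli likelihood are all made up for illustration, so treat it as a sketch rather than a reference implementation:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical sizes, chosen only for illustration.
B, D, Z, L = 8, 20, 4, 10   # batch size, data dim, latent dim, MC samples

encoder = nn.Linear(D, 2 * Z)   # toy encoder: outputs mean and log-variance
decoder = nn.Linear(Z, D)       # toy decoder: outputs Bernoulli logits

x = torch.rand(B, D).round()    # fake binary data

# The encoder runs once per batch, regardless of L.
mu, logvar = encoder(x).chunk(2, dim=-1)           # each (B, Z)

# Draw L reparameterized samples at once in an extra leading dimension.
eps = torch.randn(L, B, Z)
z = mu + torch.exp(0.5 * logvar) * eps             # (L, B, Z) via broadcasting

# Flatten the sample dimension into the batch, decode once, restore it.
# (nn.Linear would also accept (L, B, Z) directly; the explicit flatten
# is what a convolutional decoder would need.)
logits = decoder(z.reshape(L * B, Z)).reshape(L, B, D)

# log p(x | z) for every sample and data point.
log_px = -F.binary_cross_entropy_with_logits(
    logits, x.expand(L, B, D), reduction="none"
).sum(dim=-1)                                       # (L, B)

# Average over the L samples; the Gaussian KL term is analytic.
recon = log_px.mean(dim=0)                          # (B,)
kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=-1)  # (B,)
loss = (kl - recon).mean()                          # negative ELBO

loss.backward()   # one backward pass covers all L samples
```

The point is that the decoder is called once on a batch of size L*B, and there is a single backward pass, instead of L separate forward and backward passes.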

In principle, it sounds reasonable to expect that the running time goes up quite a bit with a sequential loop. Imagine that in each step the computations before the sampling take time a, and the computations after the sampling take time b. For L = 1, the total time is a + b. For L > 1, the total time is a + L*b. If b >> a, the total time is roughly multiplied by L, which for L = 10 means about 10 times the original running time (an increase of roughly 900%).
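
To put numbers on the a + L*b model, here is a tiny Python calculation. The function names and the values of a and b are hypothetical, picked only to show the two regimes:

```python
def total_time(a, b, L):
    """Time per step when the post-sampling part runs L times in a loop."""
    return a + L * b


def slowdown(a, b, L):
    """Ratio of the running time with L samples to the time with 1 sample."""
    return total_time(a, b, L) / total_time(a, b, 1)


# Pre-sampling work dominates (a = 9, b = 1): less than a 2x slowdown.
print(slowdown(9.0, 1.0, 10))    # 1.9

# Both parts cost the same (a = b = 1).
print(slowdown(1.0, 1.0, 10))    # 5.5

# Post-sampling work dominates (b >> a): the ratio approaches L = 10.
print(slowdown(0.01, 1.0, 10))
```

So how much a sequential loop slows you down depends on how the cost is split between the part before and the part after the sampling.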

By batching the computation as I described, you benefit from the parallel matrix multiplications performed by the GPU. So as long as the effective batch (L times your batch size) is small enough that the GPU is not already saturated and it still fits in memory, the running time should barely change.

Keep in mind I am not an expert.

I have written an implementation of this algorithm, but in my case setting L to anything other than 1 completely breaks the algorithm. It might be a bug, or a feature of the problem I am trying to solve that I don’t understand yet.

Keep me posted.