I am implementing a VAE, where the expectations in the loss are approximated by Monte Carlo sampling. In the original Kingma and Welling paper they used just one sample per datapoint, but in my case it turned out that I need several samples to get a good estimate.
The problem is that drawing 10 Monte Carlo samples instead of 1 increases the computation time per epoch by more than 100%. I'm doing this sequentially in a for loop: in each iteration I draw one Monte Carlo sample, plug it into the loss function, and evaluate the loss.
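The loop described above has roughly the following shape. This is only a minimal NumPy sketch under assumptions, not the actual training code: the linear `encode`/`decode` functions, the squared-error reconstruction term, the dimensions, and all names are hypothetical placeholders, and NumPy does no automatic differentiation, so only the forward computation is shown.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical sizes: data dim, latent dim, batch size, number of MC samples.
D, H, B, K = 8, 2, 4, 10

# Hypothetical stand-ins for the encoder and decoder networks (single linear maps).
W_mu = rng.normal(size=(D, H)) * 0.1
W_logvar = rng.normal(size=(D, H)) * 0.1
W_dec = rng.normal(size=(H, D)) * 0.1

def encode(x):
    """Return the mean and log-variance of q(z|x)."""
    return x @ W_mu, x @ W_logvar

def decode(z):
    """Return the reconstruction mean for latent codes z."""
    return z @ W_dec

def neg_elbo_single_sample(x, mu, logvar, eps):
    """Negative ELBO estimated from one noise draw per datapoint."""
    z = mu + np.exp(0.5 * logvar) * eps                   # reparameterization trick
    recon = 0.5 * np.sum((x - decode(z)) ** 2, axis=1)    # Gaussian term, up to a constant
    kl = -0.5 * np.sum(1 + logvar - mu ** 2 - np.exp(logvar), axis=1)
    return np.mean(recon + kl)

x = rng.normal(size=(B, D))
mu, logvar = encode(x)          # the encoder runs once per batch

# Sequential Monte Carlo estimate: one decoder pass and one loss evaluation per sample.
loss = 0.0
for _ in range(K):
    eps = rng.standard_normal(mu.shape)
    loss += neg_elbo_single_sample(x, mu, logvar, eps)
loss /= K
```

In this sketch the encoder is evaluated once, while the decoder and the loss are evaluated `K` times, once per loop iteration.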
My guess is that the algorithm then also backpropagates 10 times through the same networks; otherwise I couldn't explain why the running time increases so much, since drawing samples and calculating the loss isn't that big of a time factor.
Do you have any ideas how this could be implemented more efficiently?