Hello, I have a memory explosion problem when implementing a Monte Carlo (MC) approximation of a function in a loop that generates tensors on the GPU.
Here is the code:
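```python
import torch

# self.post_mean and self.post_logvar are parameters to optimize;
# post_var, prior_var, self.prior_mean and radius are defined elsewhere
sampler_normal = torch.distributions.Normal(
    loc=torch.zeros_like(self.post_mean),
    scale=torch.ones_like(self.post_logvar))
kl_cross = 0  # running average of the MC approximation
n_iter = 100
for i in range(n_iter):
    # standard-normal draw, rescaled towards the posterior
    weights = sampler_normal.sample((1,))
    weights = self.post_mean + \
        torch.sqrt(post_var) * weights / torch.norm(weights) * radius
    # log-density under the Gaussian prior, averaged over n_iter samples
    # (the full normalizing constant would be log(2 * pi * prior_var))
    kl_cross += (torch.sum(-1 / 2 * torch.log(2 * prior_var) +
                           -1 / 2 / prior_var * (weights - self.prior_mean)**2) /
                 n_iter)
```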
With the code above, GPU memory usage explodes. I believe this happens because `kl_cross.grad_fn` keeps references to the intermediate tensors of all n_iter iterations, so none of them are freed until backward is called. Thus, if I have several such approximations, memory grows very quickly.
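The explanation above can be illustrated with a small sketch. It compares accumulating the MC terms into one loss and calling `backward()` once (as in the question) against calling `backward()` on each term as soon as it is computed, which is plain gradient accumulation: each term's graph is released after its own backward, so only one sample's intermediates are alive at a time, while the summed `.grad` comes out the same. The tensor names mirror the question, but their shapes and values, `post_var = exp(post_logvar)`, the use of `torch.randn` in place of `sampler_normal.sample`, and the 2π constant are assumptions; `mc_term` is a hypothetical helper.

```python
import math
import torch

torch.manual_seed(0)
device = "cuda" if torch.cuda.is_available() else "cpu"

# Hypothetical stand-ins for the question's tensors (shapes and values are assumptions).
dim, n_iter, radius = 1000, 100, 1.0
post_mean = torch.zeros(dim, device=device, requires_grad=True)
post_logvar = torch.zeros(dim, device=device, requires_grad=True)
prior_mean = torch.zeros(dim, device=device)
prior_var = torch.ones(dim, device=device)


def mc_term(noise):
    """Hypothetical helper: one MC sample of the prior log-density term, divided by n_iter."""
    post_var = torch.exp(post_logvar)  # assumption: post_var = exp(post_logvar)
    weights = post_mean + torch.sqrt(post_var) * noise / torch.norm(noise) * radius
    log_p = (-0.5 * torch.log(2 * math.pi * prior_var)
             - 0.5 / prior_var * (weights - prior_mean) ** 2)
    return log_p.sum() / n_iter


# Reuse the same draws in both variants so the results are comparable.
noises = [torch.randn(dim, device=device) for _ in range(n_iter)]

# Variant A (as in the question): the running sum keeps every iteration's
# graph alive until the single backward() call at the end.
total = 0.0
for noise in noises:
    total = total + mc_term(noise)
total.backward()
grads_sum = [post_mean.grad.clone(), post_logvar.grad.clone()]
post_mean.grad = None
post_logvar.grad = None

# Variant B: backward() per sample. Gradients add up in .grad, and each
# sample's graph is released right after its own backward() call.
estimate = 0.0
for noise in noises:
    term = mc_term(noise)
    term.backward()
    estimate += term.item()  # keep a Python float, not a tensor with history
grads_per_sample = [post_mean.grad.clone(), post_logvar.grad.clone()]
```

An optimizer step would follow the loop in variant B. If the MC estimate is only needed as a value (for logging or evaluation), computing it under `torch.no_grad()` avoids building the graph at all.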
What would you suggest for implementing an MC approximation on the GPU?
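One common option is to draw all n_iter samples in a single call and average with `mean()`, which replaces the Python loop with a few batched GPU operations. This is a minimal sketch under the same kind of assumptions: the stand-in tensors, their shapes, `post_var = exp(post_logvar)` and the 2π constant are illustrative, not taken from the original code. Note that batching mainly saves kernel launches; it does not by itself shrink the memory kept for backward, since the graph still holds an `(n_iter, dim)` tensor of samples until `backward()` runs.

```python
import math
import torch

torch.manual_seed(0)
device = "cuda" if torch.cuda.is_available() else "cpu"

# Hypothetical stand-ins for the question's tensors; shapes and values are assumptions.
dim, n_iter, radius = 1000, 100, 1.0
post_mean = torch.zeros(dim, device=device, requires_grad=True)
post_logvar = torch.zeros(dim, device=device, requires_grad=True)
prior_mean = torch.zeros(dim, device=device)
prior_var = torch.ones(dim, device=device)

# All n_iter standard-normal draws at once: shape (n_iter, dim).
noise = torch.randn(n_iter, dim, device=device)

post_var = torch.exp(post_logvar)  # assumption: post_var = exp(post_logvar)
# Normalize each sample separately (norm over dim=1), then scale by radius.
weights = post_mean + torch.sqrt(post_var) * noise / noise.norm(dim=1, keepdim=True) * radius

# Per-sample Gaussian log-density under the prior, summed over dims: shape (n_iter,).
log_p = (-0.5 * torch.log(2 * math.pi * prior_var)
         - 0.5 / prior_var * (weights - prior_mean) ** 2).sum(dim=1)

kl_cross = log_p.mean()  # MC average over the n_iter samples
kl_cross.backward()
```

In the loop, `torch.norm(weights)` normalizes the whole `(1, ...)` sample, which matches the per-row norm here for a 1-D parameter; for multi-dimensional parameters the norm would have to run over all non-sample dimensions.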