I want to optimize the parameters of a probability distribution using rsample()
in the computational graph. Ideally, I would just do the following:
import torch

mu = torch.randn(2, requires_grad=True)
dist = torch.distributions.normal.Normal(loc=torch.sigmoid(mu), scale=torch.ones(2))

for _ in range(3):
    loss = torch.sum(dist.rsample())
    loss.backward()
    with torch.no_grad():
        mu.add_(mu.grad, alpha=-0.1)
On the second iteration, this raises:

RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed.
The torch.sigmoid() transformation inside the distribution's definition is causing this problem: the subgraph through sigmoid(mu) is built once, and its buffers are freed after the first backward(). Is there a nicer way of accomplishing this than reconstructing the distribution on every call of rsample(), as suggested here?
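For reference, this is a minimal sketch of the rebuild-per-iteration workaround I would like to avoid (note it also zeroes mu.grad between steps, which my snippet above omits):

```python
import torch

mu = torch.randn(2, requires_grad=True)

for _ in range(3):
    # Rebuilding the distribution gives a fresh graph through
    # sigmoid(mu) for every backward() call, so no error is raised.
    dist = torch.distributions.normal.Normal(
        loc=torch.sigmoid(mu), scale=torch.ones(2)
    )
    loss = torch.sum(dist.rsample())
    loss.backward()
    with torch.no_grad():
        mu.add_(mu.grad, alpha=-0.1)
        mu.grad.zero_()  # prevent gradients accumulating across steps
```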