I want to optimize the parameters of a probability distribution using
rsample() in the computational graph. Ideally, I would just do the following:
```python
import torch

mu = torch.randn(2, requires_grad=True)
dist = torch.distributions.normal.Normal(loc=torch.sigmoid(mu), scale=torch.ones(2))
for _ in range(3):
    loss = torch.sum(dist.rsample())
    loss.backward()
    with torch.no_grad():
        mu.add_(mu.grad, alpha=-0.1)
```
This raises the following error on the second iteration:

`RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed.`
The `torch.sigmoid()` call inside the distribution's definition is the culprit: its piece of the graph is built once, when the distribution is constructed, and the first `backward()` frees it, so every subsequent `backward()` fails. Is there a nicer way of accomplishing this than reconstructing the distribution on every call of `rsample()`, as suggested here?
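For reference, here is a minimal sketch of the workaround I would like to avoid: rebuilding the distribution inside the loop so each `rsample()` draws from a fresh graph rooted at the current `mu` (zeroing `mu.grad` is also needed so gradients don't accumulate across iterations):

```python
import torch

mu = torch.randn(2, requires_grad=True)
for _ in range(3):
    # Reconstruct the distribution each step so sigmoid(mu) is part of a new graph.
    dist = torch.distributions.normal.Normal(loc=torch.sigmoid(mu), scale=torch.ones(2))
    loss = torch.sum(dist.rsample())
    loss.backward()
    with torch.no_grad():
        mu.add_(mu.grad, alpha=-0.1)  # simple gradient step
        mu.grad.zero_()               # prevent gradient accumulation
```

This runs without the error, but constructing a new `Normal` on every iteration feels wasteful, hence the question.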