Optimizing parameters wrapped in torch.distributions

I want to optimize the parameters of a probability distribution using rsample() in the computational graph. Ideally, I would just do the following:

import torch
mu = torch.randn(2, requires_grad=True)
dist = torch.distributions.normal.Normal(loc=torch.sigmoid(mu), scale=torch.ones(2))

for _ in range(3):
    loss = torch.sum(dist.rsample())
    loss.backward()

    with torch.no_grad():
        mu.add_(mu.grad, alpha=-0.1)

This raises the following error on the second iteration:

RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed.

The torch.sigmoid() transformation inside the distribution's definition is what causes this: it is computed once when the distribution is constructed, every rsample() backward passes through that same node, and its buffers are freed after the first backward(). Is there a nicer way of accomplishing this than reconstructing the distribution on every call of rsample(), as suggested here?
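
For reference, the workaround I mean looks roughly like this (a minimal sketch): the distribution is rebuilt inside the loop, so torch.sigmoid(mu) is recomputed and each backward() gets a fresh graph.

import torch

mu = torch.randn(2, requires_grad=True)

for _ in range(3):
    # rebuilding the distribution recomputes sigmoid(mu) and builds a new graph
    dist = torch.distributions.normal.Normal(loc=torch.sigmoid(mu), scale=torch.ones(2))
    loss = torch.sum(dist.rsample())
    loss.backward()

    with torch.no_grad():
        mu.add_(mu.grad, alpha=-0.1)
        mu.grad.zero_()  # clear the accumulated gradient before the next iteration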

You should try loss.backward(retain_graph=True).

Wouldn’t that just leave me with the complete computational graph and make each iteration slower and slower as a result?

Yes, but as the error message says, you are calling backward through the graph a second time, so you need to keep the computational graph intact to do that.
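
Applied to your snippet, the change would look roughly like this (only the backward() call changes):

dist = torch.distributions.normal.Normal(loc=torch.sigmoid(mu), scale=torch.ones(2))

for _ in range(3):
    loss = torch.sum(dist.rsample())
    # retain the graph through sigmoid(mu) so the next backward() does not fail
    loss.backward(retain_graph=True)

    with torch.no_grad():
        mu.add_(mu.grad, alpha=-0.1)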

In that case I prefer the workaround I mentioned as it does not suffer from this problem. I was hoping there was a more elegant solution.

You could call mu.detach() and pass the result into torch.sigmoid().

I often just use rsample/log_prob code without distribution objects. Like:

def normal_rsample(m, s, shape):
    # reparameterized sample: eps ~ N(0, 1), scaled by s and shifted by m
    return torch.randn(shape) * s + m
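
Applied to the question's example, that could look something like this (a rough sketch): since sigmoid(mu) is recomputed inside the loop, a fresh graph is built on every iteration and the original error does not occur.

mu = torch.randn(2, requires_grad=True)

for _ in range(3):
    # sigmoid(mu) is recomputed here, so each backward() sees a new graph
    loss = torch.sum(normal_rsample(torch.sigmoid(mu), torch.ones(2), (2,)))
    loss.backward()

    with torch.no_grad():
        mu.add_(mu.grad, alpha=-0.1)
        mu.grad.zero_()  # clear the accumulated gradient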