'backward through the graph' error when optimizing a parameter of a torch.distributions object

I am implementing maximum-likelihood estimation using torch.distributions. When the parameter of my distribution is a variable, optimization works fine, but when the parameter is a function of a variable, I get: “RuntimeError: Trying to backward through the graph a second time (or directly access saved variables after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved variables after calling backward.”

Please see below for an example that works and one that does not. My goal is for the parameter that initializes the torch.distributions object to be the output of a function applied to my variable, rather than the variable itself.

############################################################
import torch
from numpy.random import randn

############ This works #####################

# generate a random data sample with mean=5, std=1
data = torch.tensor(randn(1000)+5, dtype=torch.float32)

# create a torch variable for the mean
mu = torch.tensor(0, dtype=torch.float32, requires_grad=True)

# create a torch Normal distribution object
dist = torch.distributions.Normal(loc=mu, scale=1.0)

# create an optimizer to maximize the log-likelihood w.r.t. mu
optimizer = torch.optim.RMSprop([mu], lr=0.1)

for i in range(100):
    loss = -dist.log_prob(data).sum()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

print(mu)
# out: tensor(5.0025, requires_grad=True)
######################################################
######################################################

########### This does not work #####################

# create a torch Normal distribution object, but the location parameter
# is a function of mu. I know that mu > 0 and enforce it by using mu^2
# for the mean.
dist = torch.distributions.Normal(loc=torch.square(mu), scale=1.0)

# create an optimizer to maximize the log-likelihood w.r.t. mu
optimizer = torch.optim.RMSprop([mu], lr=0.1)

for i in range(100):
    loss = -dist.log_prob(data).sum()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

print(mu)
########################################################

The problem is not specific to torch.distributions. In your first example, mu is a leaf tensor, so every call to dist.log_prob(data) builds a fresh graph that backward() can safely free. In your second example, loc=torch.square(mu) is a non-leaf node created once, outside the loop: the first backward() frees the intermediate values saved by the square operation, so the second iteration fails when it tries to backpropagate through that same node again.
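You can reproduce the same error with a bare tensor (a minimal sketch of the failure mode, no distributions involved):

import torch

mu = torch.tensor(1.0, requires_grad=True)
y = mu ** 2            # non-leaf node created once, outside the loop

loss = (y - 3.0) ** 2
loss.backward()        # frees the buffers saved by y = mu ** 2

loss = (y - 3.0) ** 2  # fresh graph built on top of the same stale y
loss.backward()        # RuntimeError: Trying to backward through the graph a second time

Note that retain_graph=True would only silence the error, not fix the model: loc would stay frozen at the initial value of mu**2 and never see the updated mu. The real fix is to construct the distribution inside the loop, so the square(mu) node is rebuilt from the current value of mu on every iteration. Here’s the working code for the second case: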

mu = torch.tensor(1.0, dtype=torch.float32, requires_grad=True)

optimizer = torch.optim.RMSprop([mu], lr=0.1)

for i in range(400):
    # rebuild the distribution so that loc = mu**2 is a fresh graph node
    # reflecting the current value of mu
    dist = torch.distributions.Normal(loc=torch.square(mu), scale=1.0)
    loss = -dist.log_prob(data).sum()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

print(mu * mu)  # the fitted mean, i.e. mu^2

Your code creates a new torch.distributions.Normal instance inside the loop. I am trying to avoid that overhead by creating the distribution object outside the loop and training its parameters inside; this works in my first example but not in the second. My use case has many distribution objects trained simultaneously, and recreating each one on every iteration is expensive.
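If construction cost is really the bottleneck, one way to sidestep it entirely is to drop the distribution object and write the Gaussian negative log-likelihood by hand, so the only thing rebuilt per iteration is the mu**2 node. A minimal sketch, assuming the scale stays fixed at 1.0 as in your example; the additive constant 0.5*log(2*pi) + log(scale) is dropped because it does not depend on mu:

import torch
from numpy.random import randn

data = torch.tensor(randn(1000) + 5, dtype=torch.float32)
mu = torch.tensor(1.0, dtype=torch.float32, requires_grad=True)
optimizer = torch.optim.RMSprop([mu], lr=0.1)

scale = 1.0
for i in range(400):
    loc = torch.square(mu)  # the only graph node rebuilt each iteration
    # hand-written Gaussian negative log-likelihood, up to an additive
    # constant that does not depend on mu and so does not affect the gradient
    loss = 0.5 * ((data - loc) / scale).pow(2).sum()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

print(mu * mu)  # should approach the sample mean (~5)

That said, constructing a Normal is fairly cheap: it mostly just stores the parameter tensors and optionally validates them, so passing validate_args=False may already recover much of the cost of recreating the objects each iteration.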