I am implementing maximum-likelihood estimation using torch.distributions. When the parameter of my distribution is the variable itself, optimization works fine. But when the parameter is a function of the variable, I get:
“RuntimeError: Trying to backward through the graph a second time (or directly access saved variables after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved variables after calling backward.”
Please see below for an example that works and an example that does not. My goal is for the parameter that initializes the torch.distributions object to be the output of a function applied to my variable, rather than the variable itself.
############################################################
import torch
from numpy.random import randn
############ This works #####################
# generate a random data sample with mean=5, std=1
data = torch.tensor(randn(1000)+5,dtype=torch.float32)
# create a torch variable for the mean
mu = torch.tensor(0,dtype=torch.float32,requires_grad=True)
# create a torch Normal distribution object
dist = torch.distributions.Normal(loc=mu,scale=1.0)
# create an optimizer to maximize the log-likelihood w.r.t. mu
optimizer = torch.optim.RMSprop([mu],lr=0.1)
for i in range(100):
    loss = -dist.log_prob(data).sum()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
print(mu)
# out: tensor(5.0025, requires_grad=True)
######################################################
########### This does not work #####################
# create a torch Normal distribution object whose location parameter is a function of mu
# I know the mean is > 0 and enforce this by using mu**2 for the mean
dist = torch.distributions.Normal(loc=torch.square(mu),scale=1.0)
# create an optimizer to maximize the log-likelihood w.r.t. mu
optimizer = torch.optim.RMSprop([mu],lr=0.1)
for i in range(100):
    loss = -dist.log_prob(data).sum()
    optimizer.zero_grad()
    loss.backward()  # raises the RuntimeError on the second iteration
    optimizer.step()
print(mu)
########################################################
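For reference, below is a minimal sketch of a workaround that appears to avoid the error by re-creating the Normal object inside the loop, so the graph through torch.square(mu) is rebuilt before every backward() call. Note that mu is re-initialized away from 0, since the gradient of mu**2 vanishes at mu=0. Is re-creating the distribution on every iteration the intended pattern, or is there a better way?
############ Sketch: rebuild the distribution every iteration ############
# re-initialize mu away from 0 (the gradient of mu**2 is zero at mu=0)
mu = torch.tensor(1.0, dtype=torch.float32, requires_grad=True)
optimizer = torch.optim.RMSprop([mu], lr=0.1)
for i in range(100):
    # re-creating the Normal object rebuilds the graph through torch.square(mu)
    dist = torch.distributions.Normal(loc=torch.square(mu), scale=1.0)
    loss = -dist.log_prob(data).sum()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
print(torch.square(mu))  # expected to approach 5
##########################################################################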