Non-leaf tensor error when optimizing parameters of torch.distributions.Normal

I’m currently using PyTorch to construct and optimize a simple probabilistic graphical model. I’m perplexed by a “can’t optimize a non-leaf Tensor” error when using torch.distributions.Normal, whereas the equivalent example built with torch.distributions.MultivariateNormal optimizes just fine. Here are the two code examples:

import torch

# Example optimizing list of torch normal dist
mu = torch.zeros(1)
mu = torch.tensor(mu, requires_grad=True)

S = torch.abs(torch.randn(1))
S = torch.tensor(S, requires_grad=True)

m = torch.distributions.Normal(mu, S)

dist_list = [m, ]

print(mu)
print(S)
print(m.mean)
print(m.stddev)

opt = torch.optim.Adam([dist_list[0].mean], lr=0.01)

for i in range(100):
    loss = m.log_prob(torch.Tensor([5, 5]))
    opt.zero_grad()
    loss.backward(retain_graph=True)
    opt.step()

This leads to the “can’t optimize a non-leaf Tensor” error. On the other hand, the following code, which optimizes a multivariate normal distribution, runs flawlessly:

# Example optimizing list of torch MVN distributions
mu = torch.zeros(2)
mu = torch.tensor(mu, requires_grad=True)

L = torch.randn(2, 2)
S = torch.mm(L, torch.t(L))
S = torch.add(S, 1e-8 * torch.eye(2))
S = torch.tensor(S, requires_grad=True)

m = torch.distributions.MultivariateNormal(mu, S)

dist_list = [m, ]

opt = torch.optim.Adam([dist_list[0].mean, dist_list[0].covariance_matrix], lr=0.01)

for i in range(100):
    loss = m.log_prob(torch.Tensor([5, 5]))
    opt.zero_grad()
    loss.backward(retain_graph=True)
    opt.step()

Thanks a bunch in advance.

Looking at the code for torch.distributions.Normal, it seems that mean is a property that returns self.loc, which is a broadcasted version of the loc you pass in (the mean you specify for the distribution): the constructor runs loc and scale through broadcast_all(). That extra broadcasting operation is what makes m.mean a non-leaf tensor.
Such an extra operation doesn’t seem to take place inside MultivariateNormal, which I think is why you can optimize that one.
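
You can verify this with is_leaf (a quick check, reusing the mu, S, and m from your first example):

# mu and S are tensors you created directly, so they are leaves;
# m.mean comes out of the broadcasting step, so it is not
print(mu.is_leaf)      # True
print(S.is_leaf)       # True
print(m.mean.is_leaf)  # False -- this is what the optimizer rejects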
You might be better off optimizing the leaf tensors mu and S directly in the case that is giving you trouble.
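
Something along these lines should run (a minimal sketch: the leaf tensors go to Adam, the Normal is rebuilt every step so the graph is fresh and retain_graph isn’t needed, and the log-probs are summed so the loss is a scalar; the names log_s and data are my own, as is parameterizing the stddev through exp() to keep it positive, and I use the negative log-prob so the optimizer pulls the distribution toward the data):

# Optimize the leaf tensors directly instead of m.mean / m.stddev
import torch

mu = torch.zeros(1, requires_grad=True)     # leaf tensor
log_s = torch.zeros(1, requires_grad=True)  # leaf tensor; stddev = exp(log_s) > 0

opt = torch.optim.Adam([mu, log_s], lr=0.01)
data = torch.tensor([5.0, 5.0])

for i in range(100):
    m = torch.distributions.Normal(mu, log_s.exp())  # rebuild from current leaves
    loss = -m.log_prob(data).sum()                   # scalar negative log-likelihood
    opt.zero_grad()
    loss.backward()
    opt.step()

print(mu, log_s.exp())  # mu should have moved toward 5

Because a fresh Normal is constructed each iteration, its internal broadcasting graph is rebuilt too, and the updated leaf values actually flow into the distribution on every step.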
