Defining parameters through a transformation OR retaining sub-graphs, but not the whole graph

Hi all,

I’m coming across an issue I haven’t seen raised before. I work in Bayesian machine learning and so make heavy use of the distributions in PyTorch. A common practice is to parametrise some distribution parameters through their logarithm, so that they cannot go negative during optimisation (e.g. the standard deviation of a Normal distribution).

However, to keep my code distribution-independent, I do not want to recompute the transformed parameter manually. To demonstrate with an example:
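For context, torch.distributions already records the support of each distribution argument in the class attribute arg_constraints, and torch.distributions.transform_to(constraint) returns a transform from unconstrained real values onto that support (an exp-based map for positive parameters such as scale). The following is a minimal sketch of a distribution-independent mapping, assuming every argument's constraint has a registered transform. This holds for Normal's loc and scale, but not for dependent constraints such as the bounds of Uniform. The helper name constrain is hypothetical.

```python
import torch
import torch.distributions as dd


def constrain(dist_cls, raw_params):
    """Hypothetical helper (not part of torch.distributions): map unconstrained
    tensors onto the support declared in dist_cls.arg_constraints."""
    return {
        name: dd.transform_to(dist_cls.arg_constraints[name])(raw)
        for name, raw in raw_params.items()
    }


# "scale" is stored unconstrained (effectively as a log); "loc" is already unconstrained
raw = {"loc": torch.tensor([1.0]), "scale": torch.tensor([-3.0])}
params = constrain(dd.Normal, raw)
print(params["loc"], params["scale"])  # loc is unchanged; scale = exp(-3.0) > 0
dist = dd.Normal(**params)
```

Swapping dd.Normal for another distribution class only changes which constraints are looked up; the helper itself stays the same.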

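The following code will NOT run. std = torch.exp(log_std) is computed only once, before the loop. The first backward pass frees the values saved by that part of the graph (the exponential), and nothing rebuilds it, so the second call to cost.backward() raises a RuntimeError ("Trying to backward through the graph a second time ...").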

import torch
import torch.nn as nn
import torch.distributions as dd

log_std = nn.Parameter(torch.tensor([1.0]))  # Log of the standard deviation; this is the nn.Parameter we want to optimise
std = torch.exp(log_std)  # Transformation applied to the parameter before it is used in the distribution
mean = nn.Parameter(torch.tensor([1.0]))  # An ordinary (unconstrained) parameter
dist = dd.Normal(loc=mean, scale=std)  # Define the distribution. From here on I want to refer ONLY to this, not to the other variables

optim = torch.optim.SGD([log_std, mean], lr=0.01) # Standard optimiser
target = dd.Normal(5,5) # Target distribution to match

for i in range(50):
    optim.zero_grad()

    samples = dist.rsample((1000,))  # Sample from our model; note there is no reference to log_std

    cost = -(target.log_prob(samples) - dist.log_prob(samples)).sum()  # Monte Carlo estimate of KL(dist || target), summed over samples
    cost.backward()
    optim.step()
    print(i)
    print(log_std, mean, cost)
    print()
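The failure can be reproduced with just the transformed parameter. A minimal sketch: std is computed once, the first backward() frees the values saved by the exp node (PyTorch does this unless retain_graph=True is passed), and a second backward() through the same node raises a RuntimeError.

```python
import torch
import torch.nn as nn

log_std = nn.Parameter(torch.tensor([1.0]))
std = torch.exp(log_std)  # exp node created once; it saves its output for the backward pass

std.sum().backward()  # first backward pass succeeds, then frees exp's saved output

msg = ""
try:
    std.sum().backward()  # second pass must go through the same, already freed, exp node
except RuntimeError as err:
    msg = str(err)

print(msg)  # mentions "Trying to backward through the graph a second time"
```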

The following code WILL run, but I have to reference log_std explicitly inside the loop and rebuild the distribution on every iteration. If I wanted to change the distribution type, I would have to handle each case separately.

import torch
import torch.nn as nn
import torch.distributions as dd

log_std = nn.Parameter(torch.tensor([1.0]))  # Log of the standard deviation; this is the nn.Parameter we want to optimise
mean = nn.Parameter(torch.tensor([1.0]))  # An ordinary (unconstrained) parameter

optim = torch.optim.SGD([log_std, mean], lr=0.001) # Standard optimiser
target = dd.Normal(5,5) # Target distribution to match

for i in range(50):
    optim.zero_grad()

    std = torch.exp(log_std)  # Transformation applied to the parameter before it is used in the distribution
    dist = dd.Normal(loc=mean, scale=std)  # Rebuild the distribution from the current parameter values

    samples = dist.rsample((1000,))  # Sample from our model; note there is no reference to log_std on this line

    cost = -(target.log_prob(samples) - dist.log_prob(samples)).sum()  # Monte Carlo estimate of KL(dist || target), summed over samples
    cost.backward()
    optim.step()
    print(i)
    print(mean, std, cost)
    print()
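One way to keep the training loop itself free of log_std, sketched under the assumption that rebuilding the (cheap) distribution object every iteration is acceptable: move the transformation into a factory function, so the loop only calls make_dist(). make_dist is a hypothetical name; changing the distribution type then means editing only the factory, not the loop.

```python
import torch
import torch.nn as nn
import torch.distributions as dd

log_std = nn.Parameter(torch.tensor([1.0]))  # unconstrained parameter we optimise
mean = nn.Parameter(torch.tensor([1.0]))


def make_dist():
    # Hypothetical factory: recomputes exp(log_std) from the current value on every call
    return dd.Normal(loc=mean, scale=torch.exp(log_std))


optim = torch.optim.SGD([log_std, mean], lr=0.001)
target = dd.Normal(5.0, 5.0)

for i in range(50):
    optim.zero_grad()
    dist = make_dist()  # fresh graph each iteration; the loop never touches log_std directly
    samples = dist.rsample((1000,))
    cost = -(target.log_prob(samples) - dist.log_prob(samples)).sum()
    cost.backward()
    optim.step()

final = make_dist()
print(final.loc, final.scale)
```

The design choice here is simply that the distribution object is treated as disposable, while the nn.Parameter objects are the long-lived state.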

The first example does, however, work in TensorFlow, because graphs there are static. Does anyone have ideas on how I might fix this? If it were possible to keep only the part of the graph that defines the relationship std = torch.exp(log_std), this could work. I have also experimented with backward gradient hooks, but unfortunately computing the correct gradient that way requires access to the parameter value and the learning rate.
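Regarding keeping only the std = torch.exp(log_std) sub-graph: a small sketch suggests that this alone would not be enough. Passing retain_graph=True to backward() does avoid the error, but PyTorch evaluates eagerly, so std (and therefore dist.scale) keeps the value computed when that line ran; optim.step() updates log_std in place without recomputing std. The gradient reaching log_std is also computed from the stale exp output. The variable names mirror the first example; the three iterations are only for illustration.

```python
import torch
import torch.nn as nn
import torch.distributions as dd

log_std = nn.Parameter(torch.tensor([1.0]))
mean = nn.Parameter(torch.tensor([1.0]))
std = torch.exp(log_std)  # evaluated once, eagerly, while log_std == 1.0
dist = dd.Normal(loc=mean, scale=std)

optim = torch.optim.SGD([log_std, mean], lr=0.001)
target = dd.Normal(5.0, 5.0)

scale_before = dist.scale.detach().clone()
log_std_before = log_std.detach().clone()

for i in range(3):
    optim.zero_grad()
    samples = dist.rsample((1000,))
    cost = -(target.log_prob(samples) - dist.log_prob(samples)).sum()
    cost.backward(retain_graph=True)  # keeps the exp node's saved output, so no error
    optim.step()  # updates log_std in place, but nothing recomputes std

print(log_std.item(), dist.scale.item())  # log_std has moved, dist.scale has not
```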

Thanks in advance!
Michael