I want to compute the gradients of a distribution’s parameters in two steps, so that the code defining the distribution can be decoupled from the training loop. The following works, but I have to retain the whole graph with retain_graph=True (or manually delete references to cost and samples with del, which may become unfeasible as the computational graph grows more complex).

```python
import torch
import torch.nn as nn
import torch.distributions as dd

log_std = nn.Parameter(torch.tensor([1.0]))
std = torch.exp(log_std)  # computed once, outside the training loop
mean = nn.Parameter(torch.tensor([1.0]))
dist = dd.Normal(loc=mean, scale=std)
optim = torch.optim.SGD([log_std, mean], lr=0.01)
target = dd.Normal(5, 5)
for i in range(50):
    optim.zero_grad()
    samples = dist.rsample((1000,))
    cost = -(target.log_prob(samples) - dist.log_prob(samples)).sum()
    cost.backward(retain_graph=True)
    optim.step()
    print(i, log_std, mean, cost)
```
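A small check of why `retain_graph=True` is needed in the setup above (a sketch; `make_setup` and `run` are hypothetical helpers, and the hyperparameters mirror the snippet): `std = torch.exp(log_std)` runs once, outside the loop, so every iteration’s graph passes through the same `exp` node, and the first plain `backward()` frees the buffers that node saved. It also shows a side effect of computing `std` only once: `dist.scale` keeps its initial value even though `optim.step()` keeps updating `log_std`.

```python
import torch
import torch.nn as nn
import torch.distributions as dd

torch.manual_seed(0)
target = dd.Normal(5.0, 5.0)


def make_setup():
    # Same construction as the snippet above: std is computed once, here.
    log_std = nn.Parameter(torch.tensor([1.0]))
    std = torch.exp(log_std)
    mean = nn.Parameter(torch.tensor([1.0]))
    dist = dd.Normal(loc=mean, scale=std)
    optim = torch.optim.SGD([log_std, mean], lr=0.01)
    return log_std, dist, optim


def run(retain_graph, n_steps=2):
    log_std, dist, optim = make_setup()
    for _ in range(n_steps):
        optim.zero_grad()
        samples = dist.rsample((1000,))
        cost = -(target.log_prob(samples) - dist.log_prob(samples)).sum()
        # Without retain_graph, this frees the saved buffers of the shared exp node.
        cost.backward(retain_graph=retain_graph)
        optim.step()
    return log_std, dist


log_std, dist = run(retain_graph=True)
print(log_std.item(), dist.scale.item())  # log_std has moved, scale is still exp(1)

try:
    run(retain_graph=False)  # the second backward reaches the freed exp node
except RuntimeError as err:
    print("RuntimeError:", err)
```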

I would like to do something like this instead

```python
import torch
import torch.nn as nn
import torch.distributions as dd

log_std = nn.Parameter(torch.tensor([1.0]))
std = torch.exp(log_std)
mean = nn.Parameter(torch.tensor([1.0]))
dist = dd.Normal(loc=mean, scale=std)
optim = torch.optim.SGD([log_std, mean], lr=0.01)
target = dd.Normal(5, 5)
for i in range(50):
    optim.zero_grad()
    samples = dist.rsample((1000,))
    detached_samples = samples.detach()
    detached_samples.retain_grad()
    cost = -(target.log_prob(detached_samples) - dist.log_prob(detached_samples)).sum()
    cost.backward()  # compute gradients up to the detached samples
    # get gradients of the distribution parameters log_std and mean,
    # and retain only this part of the graph
    samples.backward(detached_samples.grad, retain_graph=True)
    optim.step()  # apply update
    print(i, log_std, mean, cost)
```

Even if I detach samples, I still get gradients for log_std and mean when I run cost.backward(). I don’t know whether that’s intended, or whether what I’m doing is smart, dumb, or even possible, but I would like someone to shed some light on this.

That doesn’t solve the problem, though. I think it’s impossible to do this, because there is a direct dependence between log_prob and the parameters of Normal: detaching samples doesn’t actually prevent the gradients from flowing up to log_std and mean.

One other thing: you likely want detached_samples.requires_grad_(). A detached tensor does not require grad, so otherwise no gradient is computed for it.
So you want gradients w.r.t. detached_samples without affecting the parameters? Then I would recommend using torch.autograd.grad instead of .backward() to compute gradients selectively, without side effects.
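A minimal sketch of that suggestion, assuming the same Normal setup as in the question (the function name `two_step_grads` is hypothetical). It cuts the graph at the samples with `requires_grad_()`, uses `torch.autograd.grad` for both steps, and compares the result with the ordinary one-step gradient. Because `dist.log_prob` depends on `mean` and `log_std` directly, the first step also has to return those partial derivatives; the pathwise part from the second step is added on top. The two steps still share the `exp` node, so the first call needs `retain_graph=True`.

```python
import torch
import torch.nn as nn
import torch.distributions as dd

torch.manual_seed(0)
mean = nn.Parameter(torch.tensor([1.0]))
log_std = nn.Parameter(torch.tensor([1.0]))
params = [mean, log_std]
target = dd.Normal(5.0, 5.0)


def two_step_grads():
    dist = dd.Normal(loc=mean, scale=log_std.exp())
    samples = dist.rsample((1000,))

    # Reference: the ordinary one-step gradient through the same samples.
    cost = -(target.log_prob(samples) - dist.log_prob(samples)).sum()
    reference = torch.autograd.grad(cost, params, retain_graph=True)

    # Step 1: cut the graph at the samples. The cut point has to require grad
    # (requires_grad_(), not retain_grad()) to be differentiated against.
    cut = samples.detach().requires_grad_()
    cost_cut = -(target.log_prob(cut) - dist.log_prob(cut)).sum()
    # dist.log_prob also depends on mean/log_std directly, so request those
    # partial derivatives too; retain_graph keeps the exp node shared with
    # `samples` alive for step 2.
    grad_cut, *direct = torch.autograd.grad(cost_cut, [cut] + params, retain_graph=True)

    # Step 2: push the gradient at the cut back through rsample to the params.
    pathwise = torch.autograd.grad(samples, params, grad_outputs=grad_cut)

    two_step = [d + p for d, p in zip(direct, pathwise)]
    return reference, two_step


reference, two_step = two_step_grads()
for r, t in zip(reference, two_step):
    print(r.item(), t.item())
```

Since `torch.autograd.grad` returns the gradients instead of accumulating them, `mean.grad` and `log_std.grad` stay `None` throughout, which is the "no side effects" part.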

This question came from a Stack Overflow post of mine. The ultimate goal is to train the parameters of the distribution without having to worry about
a) the type of distribution used, or
b) how the parameters of the distribution are generated (in this case, the std of the normal distribution is defined in terms of its log to ensure positive values)

In short, this means I do not want to have to write std = e^(log_std) anywhere other than where I first define the distribution.

This is something that would work in TensorFlow, since graphs are static there. The easiest way in PyTorch would seem to be retaining some small parts of the graph, but this does not seem to be possible. The above was my attempt at achieving that. @tom, if you have any suggestions on this, I would greatly appreciate it!

Well, yes, but TF will do the computation over and over again.
What you would do in PyTorch is wrap the parameter, the transformation of the parameter, and the distribution in an nn.Module and use that.
A long while ago, I thought that blurring the boundary between nn.Modules and nn.Parameters was a good idea. I still think it would be, but it obviously doesn’t seem to be imminent.

Thank you @tom. Do you happen to have any quick hints on implementing this? The main issue I see is that if I pass a module in place of a parameter, it will not be accessed in the same way (i.e. self.parameter vs. self.parameter()).

No, don’t use a Module for parameters (that doesn’t work well, or we would not need the feature discussed in the linked abstraction); instead, have an nn.Module wrap your distribution. You can have your module define log_prob and rsample if you want: the only difference between calling model() and model.forward() is that model() runs some hooks you might not need, so forward isn’t otherwise special.
In a way, this is similar to the “What is nn.Module” tutorial: nn.Modules are containers for the parameters and wrap the computation we do with them.
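A minimal sketch of that idea for the Normal from the question, assuming a hypothetical `LogStdNormal` module: the module owns the raw parameters, `forward()` rebuilds the distribution (the only place where `std = exp(log_std)` is written), and `rsample`/`log_prob` delegate to it. Because a fresh graph is built on every iteration, no `retain_graph` is needed. I used Adam rather than SGD: by my rough estimate, with a cost summed over 1000 samples, SGD at lr=0.01 would move `log_std` by several units in the first step once `std` is actually recomputed.

```python
import torch
import torch.nn as nn
import torch.distributions as dd


class LogStdNormal(nn.Module):
    """Holds the raw parameters and rebuilds the Normal on every call."""

    def __init__(self, mean=1.0, log_std=1.0):
        super().__init__()
        self.mean = nn.Parameter(torch.tensor([mean]))
        self.log_std = nn.Parameter(torch.tensor([log_std]))

    def forward(self):
        # The only place where std = exp(log_std) is written down.
        return dd.Normal(loc=self.mean, scale=self.log_std.exp())

    def rsample(self, sample_shape=torch.Size()):
        return self().rsample(sample_shape)

    def log_prob(self, value):
        return self().log_prob(value)


torch.manual_seed(0)
q = LogStdNormal()
target = dd.Normal(5.0, 5.0)
optim = torch.optim.Adam(q.parameters(), lr=0.05)
for i in range(300):
    optim.zero_grad()
    samples = q.rsample((1000,))
    cost = -(target.log_prob(samples) - q.log_prob(samples)).sum()
    cost.backward()  # fresh graph each iteration, so no retain_graph
    optim.step()
print(q.mean.item(), q.log_std.exp().item())
```

The training loop only sees `q`, so swapping in a module with a different parametrization does not require touching the loop.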

Constructing new distribution instances is cheap, and you can use the constraints library to automatically handle the transformation of the std parameter to (0, inf), as below.

Code:

```python
import torch
import torch.distributions as dd
from torch.distributions import biject_to

params = {'loc': torch.tensor([1.], requires_grad=True),
          'scale': torch.tensor([1.], requires_grad=True)}
optim = torch.optim.Adam(params.values(), lr=0.001)
target = dd.Normal(5, 5)


def get_dist():
    return dd.Normal(**get_constrained_params())


def get_constrained_params():
    # apply automatic constraints to distribution parameters.
    constraints = target.arg_constraints
    constrained_params = {}
    for k, v in params.items():
        constrained_params[k] = biject_to(constraints[k])(v)
    return constrained_params


for i in range(10000):
    optim.zero_grad()
    dist = get_dist()  # rebuilt every iteration, so no retain_graph is needed
    samples = dist.rsample((1000,))
    cost = -(target.log_prob(samples) - dist.log_prob(samples)).sum()
    cost.backward()
    optim.step()
    print(i, get_constrained_params())
```
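To see what `biject_to` does for `Normal` in the snippet above, here is a small check using only the public `torch.distributions` API: for `scale` it maps an unconstrained value through `exp` (the same log-std parametrization as in the original question), and for `loc` it is the identity. So the `scale` entry in `params` effectively behaves as a log-std, and its initial value of 1 gives a starting std of e.

```python
import torch
import torch.distributions as dd
from torch.distributions import biject_to

# The same constraints that target.arg_constraints returns above.
print(dd.Normal.arg_constraints)

to_loc = biject_to(dd.Normal.arg_constraints['loc'])
to_scale = biject_to(dd.Normal.arg_constraints['scale'])

raw = torch.tensor([-2.0, 0.0, 1.0])
print(to_loc(raw))                  # loc is unconstrained: values pass through
print(to_scale(raw))                # elementwise exp(raw), always positive
print(to_scale.inv(to_scale(raw)))  # back to raw, up to rounding
```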

Based on @tom’s suggestion, here is a modified attempt that uses nn.Module

```python
import torch
import torch.distributions as dd
import torch.nn as nn
from torch.distributions import biject_to


class DistributionModule(nn.Module):
    def __init__(self, dist, arg_constraints, **params):
        super(DistributionModule, self).__init__()
        self.dist_class = dist
        self.arg_constraints = arg_constraints
        for k, v in params.items():
            self.register_parameter(k, nn.Parameter(v))

    def get_dist(self):
        constrained_params = dict(self.get_constrained_params())
        return self.dist_class(**constrained_params)

    def log_prob(self, sample):
        return self.get_dist().log_prob(sample)

    def rsample(self, sample_shape=()):
        return self.get_dist().rsample(sample_shape)

    def forward(self, value):
        return self.log_prob(value)

    def get_constrained_params(self):
        for name, param in self.named_parameters():
            yield name, biject_to(self.arg_constraints[name])(param)


params = {'loc': torch.tensor([1.]),
          'scale': torch.tensor([1.])}
target = dd.Normal(5, 5)
d = DistributionModule(dd.Normal, target.arg_constraints, **params)
optim = torch.optim.Adam(d.parameters(), lr=0.001)
for i in range(10000):
    optim.zero_grad()
    samples = d.rsample((1000,))
    cost = -(target.log_prob(samples) - d.log_prob(samples)).sum()
    cost.backward()
    optim.step()
    print(i, dict(d.get_constrained_params()))
```

Not tested very well, but I think you should be able to use a generic wrapper like this for most distribution classes! The forward method just delegates to .log_prob (though it is not being used here).
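A quick check of the "most distribution classes" idea with `Gamma`, whose `concentration` and `rate` are both constrained to be positive. This is a condensed variant of the wrapper above so the block runs on its own; unlike the original, it reads `arg_constraints` from the distribution class instead of taking them as an argument (an assumption that holds for `Normal` and `Gamma`, where `arg_constraints` is a class attribute). The target `Gamma(3, 2)` and the hyperparameters are arbitrary choices for illustration.

```python
import torch
import torch.nn as nn
import torch.distributions as dd
from torch.distributions import biject_to


class DistributionModule(nn.Module):
    """Condensed variant of the wrapper above (reads arg_constraints from the class)."""

    def __init__(self, dist_class, **params):
        super().__init__()
        self.dist_class = dist_class
        for k, v in params.items():
            # Each value is stored unconstrained; biject_to maps it into the support.
            self.register_parameter(k, nn.Parameter(v))

    def get_dist(self):
        constraints = self.dist_class.arg_constraints
        constrained = {name: biject_to(constraints[name])(p)
                       for name, p in self.named_parameters()}
        return self.dist_class(**constrained)

    def forward(self, value):
        # As in the original, forward delegates to log_prob.
        return self.get_dist().log_prob(value)


torch.manual_seed(0)
# Unconstrained values: concentration starts at exp(0.5), rate at exp(0) = 1.
g = DistributionModule(dd.Gamma,
                       concentration=torch.tensor([0.5]),
                       rate=torch.tensor([0.0]))
target = dd.Gamma(torch.tensor([3.0]), torch.tensor([2.0]))
optim = torch.optim.Adam(g.parameters(), lr=0.01)
for i in range(100):
    optim.zero_grad()
    samples = g.get_dist().rsample((1000,))
    cost = -(target.log_prob(samples) - g(samples)).sum()
    cost.backward()
    optim.step()
print({name: p.item() for name, p in g.named_parameters()})
```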