# Detach rsamples from distributions parameters

I want to compute the gradient for the parameters of a distribution in two steps, so that the code defining the distribution can be decoupled from the training loop. The following works, but I have to retain the whole graph (or manually delete references to `cost` and `samples` with `del`, which may become infeasible once the computational graph gets more complex).

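```python
import torch
import torch.distributions as dd
import torch.nn as nn

log_std = nn.Parameter(torch.Tensor([1]))
# computed once: later updates to log_std are not reflected in dist.scale
std = torch.exp(log_std)
mean = nn.Parameter(torch.Tensor([1]))
dist = dd.Normal(loc=mean, scale=std)
optim = torch.optim.SGD([log_std, mean], lr=0.01)
target = dd.Normal(5, 5)

for i in range(50):
    optim.zero_grad()  # clear gradients accumulated in the previous iteration
    samples = dist.rsample((1000,))
    cost = -(target.log_prob(samples) - dist.log_prob(samples)).sum()
    # every iteration backpropagates through the same exp(log_std) node,
    # so its part of the graph must not be freed
    cost.backward(retain_graph=True)
    optim.step()
    print(i, log_std, mean, cost)
```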
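A minimal sketch of why `retain_graph=True` is needed in the snippet above: `std = torch.exp(log_std)` is computed once, outside the loop, so every iteration backpropagates through the same `exp` node, and a plain `backward()` frees that node's saved result the first time it runs. The helper `step` is hypothetical and exists only for this illustration.

```python
import torch
import torch.distributions as dd
import torch.nn as nn

log_std = nn.Parameter(torch.tensor([1.0]))
mean = nn.Parameter(torch.tensor([1.0]))
std = torch.exp(log_std)  # computed once, outside any loop
dist = dd.Normal(loc=mean, scale=std)
target = dd.Normal(5.0, 5.0)


def step(d):
    """Hypothetical helper: one forward/backward pass without retain_graph."""
    samples = d.rsample((1000,))
    cost = -(target.log_prob(samples) - d.log_prob(samples)).sum()
    cost.backward()


step(dist)  # works, but frees the saved result of the shared exp node
try:
    step(dist)  # needs that same exp node again
    second_failed = False
except RuntimeError:
    second_failed = True
print("second backward failed:", second_failed)

# Rebuilding the distribution from the parameters creates a fresh exp node
# each time, so no retain_graph is needed.
for _ in range(3):
    step(dd.Normal(loc=mean, scale=torch.exp(log_std)))
rebuilt_ok = True
```

Rebuilding a `Normal` each step is relatively cheap, since the constructor mostly just broadcasts and stores the given tensors.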

I would like to do something like this instead:

```python
import torch
import torch.distributions as dd
import torch.nn as nn

log_std = nn.Parameter(torch.Tensor([1]))
std = torch.exp(log_std)
mean = nn.Parameter(torch.Tensor([1]))
dist = dd.Normal(loc=mean, scale=std)

optim = torch.optim.SGD([log_std, mean], lr=0.01)
target = dd.Normal(5, 5)

for i in range(50):
    samples = dist.rsample((1000,))
    detached_samples = samples.detach()
    cost = -(target.log_prob(samples) - dist.log_prob(samples)).sum()
    cost.backward()  # compute gradients up to detached samples
    samples.backward(detached_samples.grad, retain_graph=True)  # get gradients of distribution parameters log_std, mean and retain only this part of the graph
    optim.step()  # apply update
    print(i, log_std, mean, cost)
```

Even if I detach the samples, I still get gradients for `log_std` and `mean` when I run `cost.backward()`. I don't know whether that's intended, or whether what I'm doing is smart, dumb, or even possible, but I would like someone to shed some light on this.

Well, if you use `samples` rather than `detached_samples` in your cost…

oh my god, that was dumb

Using

```python
cost = -(target.log_prob(detached_samples) - dist.log_prob(detached_samples)).sum()
```

doesn't solve the problem, though. I think it's impossible to do this: because `log_prob` depends directly on the parameters of `Normal`, detaching the samples doesn't actually stop gradients from flowing back to `log_std` and `mean`.
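A small sketch of the point above (tensor values are arbitrary): even with fully detached samples, `dist.log_prob` still reads `loc` and `scale`, so `backward()` populates gradients for `mean` and `log_std` through that direct dependence.

```python
import torch
import torch.distributions as dd
import torch.nn as nn

log_std = nn.Parameter(torch.tensor([1.0]))
mean = nn.Parameter(torch.tensor([1.0]))
dist = dd.Normal(loc=mean, scale=torch.exp(log_std))
target = dd.Normal(5.0, 5.0)

# No autograd path from these samples back to mean or log_std.
detached_samples = dist.rsample((1000,)).detach()

cost = -(target.log_prob(detached_samples) - dist.log_prob(detached_samples)).sum()
cost.backward()

# Both are populated: dist.log_prob(x) uses loc and scale directly.
print(mean.grad, log_std.grad)
```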

Thanks for taking a look!

One other thing: you likely want `detached_samples.requires_grad_()`.
So you want gradients w.r.t. `detached_samples` without affecting the params? Then I would recommend using `torch.autograd.grad` instead of `.backward()` to selectively calculate gradients without side effects.
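A sketch of how `requires_grad_()` and `torch.autograd.grad` could be combined to split the gradient into the two steps from the question, assuming the distribution is freshly built. By the chain rule, the full gradient is the cost's direct dependence on the parameters (samples held fixed) plus the part flowing back through `rsample`; the variable names (`g_samples`, `g_mean_direct`, and so on) are illustrative only.

```python
import torch
import torch.distributions as dd
import torch.nn as nn

torch.manual_seed(0)
log_std = nn.Parameter(torch.tensor([1.0]))
mean = nn.Parameter(torch.tensor([1.0]))
target = dd.Normal(5.0, 5.0)

dist = dd.Normal(loc=mean, scale=torch.exp(log_std))
samples = dist.rsample((1000,))  # differentiable w.r.t. mean and log_std

# Reference: ordinary one-step gradient through the samples.
# retain_graph=True because the rsample graph is reused below.
ref_cost = -(target.log_prob(samples) - dist.log_prob(samples)).sum()
ref_grads = torch.autograd.grad(ref_cost, [mean, log_std], retain_graph=True)

# Step 1: make the samples a new leaf; differentiate the cost w.r.t. that leaf
# and w.r.t. the parameters' direct use inside dist.log_prob.
# retain_graph=True: the exp(log_std) node is shared with the graph used in step 2.
detached_samples = samples.detach().requires_grad_()
cost = -(target.log_prob(detached_samples) - dist.log_prob(detached_samples)).sum()
g_samples, g_mean_direct, g_log_std_direct = torch.autograd.grad(
    cost, [detached_samples, mean, log_std], retain_graph=True
)

# Step 2: push d(cost)/d(samples) back through rsample to the parameters.
g_mean_path, g_log_std_path = torch.autograd.grad(
    samples, [mean, log_std], grad_outputs=g_samples
)

two_step = (g_mean_direct + g_mean_path, g_log_std_direct + g_log_std_path)
for a, b in zip(two_step, ref_grads):
    print(a, b)  # two-step result next to the one-step reference
```

Since `torch.autograd.grad` returns gradients instead of accumulating them, `mean.grad` and `log_std.grad` stay `None` here; you would assign or accumulate the sums yourself before calling `optim.step()`.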

So this question came from a Stack Overflow post of mine. The ultimate goal is to train the parameters of the distribution without having to worry about
a) the type of distribution used, or
b) how the parameters of the distribution are generated (in this case, the std of the normal distribution is defined in terms of its log to ensure it stays positive).

In short, in this example that means I do not want to compute `std = torch.exp(log_std)` anywhere other than where I first define the distribution.

This would work in TensorFlow, since graphs are static there. The easiest way in PyTorch would seem to be retaining some small part of the graph, but that does not seem to be possible; the code above was my attempt at it. @tom, if you have any suggestions on this, I would greatly appreciate it!

Well yeah, but TF will do the computation over and over again.
So what you would do in PyTorch is to wrap parameter, transformation of the parameter and distribution in a `nn.Module` and use that.
A long while ago, I thought that blurring the boundary between `nn.Module`s and `nn.Parameter`s was a good idea. I still think it would be, but it obviously doesn't seem to be immediately pending.

Best regards

Thomas

Thank you @tom, do you happen to have any quick hints on implementing this? The main issue I see is that if I pass a module in place of a parameter, it will not be accessed in the same way (i.e. `self.parameter` vs. `self.parameter()`).

Thanks,
Michael

No, don't use the Module for parameters (that doesn't work well, or we would not need the feature discussed in the linked abstraction); instead, have the `nn.Module` wrap your distribution. You can have your module define `log_prob` and `rsample` if you want: `forward` isn't special, and apart from calling some hooks (which you might not need), there is no difference between `model()` and `model.forward()`.
In a way, this is similar to the "what is nn.Module" tutorial: `nn.Module`s are containers for the parameters and wrap the calculation we do with them.
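A minimal sketch of that suggestion for the Normal-with-`log_std` case from the question. The class name `GaussianModule`, its `dist` method, the initial values and the learning rate are illustrative assumptions, not an existing API; the point is that `torch.exp(log_std)` is written in exactly one place.

```python
import torch
import torch.distributions as dd
import torch.nn as nn


class GaussianModule(nn.Module):
    """Hypothetical wrapper: holds raw parameters, rebuilds the distribution on demand."""

    def __init__(self, loc=1.0, log_std=1.0):
        super().__init__()
        self.loc = nn.Parameter(torch.tensor([loc]))
        self.log_std = nn.Parameter(torch.tensor([log_std]))

    def dist(self):
        # The only place where std = exp(log_std) is written down.
        return dd.Normal(self.loc, self.log_std.exp())

    def rsample(self, sample_shape=torch.Size()):
        return self.dist().rsample(sample_shape)

    def log_prob(self, value):
        return self.dist().log_prob(value)

    def forward(self, value):
        return self.log_prob(value)


model = GaussianModule()
target = dd.Normal(5.0, 5.0)
# Small lr because the cost is a sum over 1000 samples (an assumption; tune as needed).
optim = torch.optim.SGD(model.parameters(), lr=1e-4)

for i in range(2000):
    optim.zero_grad()
    samples = model.rsample((1000,))
    cost = -(target.log_prob(samples) - model.log_prob(samples)).sum()
    cost.backward()  # fresh graph every step, so no retain_graph
    optim.step()

print(model.loc.item(), model.log_std.exp().item())
```

The training loop never touches `exp` or the distribution class directly, which is the decoupling the question asks for.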

Best regards

Thomas

Constructing new distribution instances is cheap, and you can use the `constraints` library to automatically handle the transformation of the `std` parameter to `(0, inf)`, as below.

Code:

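```python
import torch
import torch.distributions as dd
from torch.distributions import biject_to

target = dd.Normal(5, 5)

# Unconstrained parameters; biject_to maps each one into the support required
# by the corresponding argument (e.g. scale > 0). Initial values, optimizer
# and learning rate are assumptions for illustration.
params = {'loc': torch.tensor([1.], requires_grad=True),
          'scale': torch.tensor([1.], requires_grad=True)}
optim = torch.optim.SGD(list(params.values()), lr=1e-4)

def get_dist():
    return dd.Normal(**get_constrained_params())

def get_constrained_params():
    # apply automatic constraints to distribution parameters.
    constraints = target.arg_constraints
    constrained_params = {}
    for k, v in params.items():
        constrained_params[k] = biject_to(constraints[k])(v)
    return constrained_params

for i in range(10000):
    optim.zero_grad()
    dist = get_dist()  # cheap: rebuilt from the current parameters every step
    samples = dist.rsample((1000,))
    cost = -(target.log_prob(samples) - dist.log_prob(samples)).sum()
    cost.backward()
    optim.step()
    if i % 1000 == 0:
        print(i, get_constrained_params())
```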

Does this satisfy what you had in mind?
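To see what `biject_to` does in the snippet above, here is a small sketch with arbitrary values: it looks up a bijective transform from the unconstrained reals onto a constraint's support, so for `scale` any real number is mapped to a positive one, and `.inv` maps it back.

```python
import torch
import torch.distributions as dd
from torch.distributions import biject_to, constraints

print(dd.Normal.arg_constraints)  # per-argument constraints, e.g. scale must be positive

to_positive = biject_to(constraints.positive)
u = torch.tensor([-2.0, 0.0, 3.0])  # unconstrained values of any sign
scale = to_positive(u)              # positive values usable as a Normal scale
recovered = to_positive.inv(scale)  # inverse transform back to the reals
print(scale, recovered)

to_real = biject_to(dd.Normal.arg_constraints['loc'])
print(to_real(u))  # loc is unconstrained, so this is the identity map
```

This is why the optimizer can work on plain unconstrained tensors while the distribution always receives valid arguments.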

That's awesome, @neerajprad! I'd probably put these into an `nn.Module` to hold the parameters.

Best regards

Thomas

Based on @tom's suggestion, here is a modified attempt that uses `nn.Module`:

```python
import torch
import torch.distributions as dd
import torch.nn as nn
from torch.distributions import biject_to

class DistributionModule(nn.Module):
    def __init__(self, dist, arg_constraints, **params):
        super(DistributionModule, self).__init__()
        self.dist_class = dist
        self.arg_constraints = arg_constraints
        for k, v in params.items():
            self.register_parameter(k, nn.Parameter(v))

    def get_dist(self):
        constrained_params = dict(self.get_constrained_params())
        return self.dist_class(**constrained_params)

    def log_prob(self, sample):
        return self.get_dist().log_prob(sample)

    def rsample(self, sample_shape=()):
        return self.get_dist().rsample(sample_shape)

    def forward(self, value):
        return self.log_prob(value)

    def get_constrained_params(self):
        for name, param in self.named_parameters():
            yield name, biject_to(self.arg_constraints[name])(param)

params = {'loc': torch.tensor([1.]),
          'scale': torch.tensor([1.])}
target = dd.Normal(5, 5)
d = DistributionModule(dd.Normal, target.arg_constraints, **params)

# Training step as in the previous example (optimizer and lr are assumptions).
optim = torch.optim.SGD(d.parameters(), lr=1e-4)
for i in range(10000):
    optim.zero_grad()
    samples = d.rsample((1000,))
    cost = -(target.log_prob(samples) - d.log_prob(samples)).sum()
    cost.backward()
    optim.step()
    if i % 1000 == 0:
        print(i, dict(d.get_constrained_params()))
```

Not tested very well, but I think you should be able to use a generic wrapper like this for most distribution classes! The forward method just delegates to `.log_prob` (though it is not being used here).