Detach rsamples from distribution parameters

I want to compute the gradient for the parameters of a distribution in two steps, so that the code defining the distribution can be decoupled from the training loop. The following works, but I have to retain the whole graph (or manually delete the references to cost and samples with del, which may become unfeasible as the computational graph grows more complex).

import torch
import torch.nn as nn
import torch.distributions as dd

log_std = nn.Parameter(torch.Tensor([1]))
std = torch.exp(log_std)  # computed once, outside the training loop
mean = nn.Parameter(torch.Tensor([1]))
dist = dd.Normal(loc=mean, scale=std)
optim = torch.optim.SGD([log_std, mean], lr=0.01)
target = dd.Normal(5, 5)

for i in range(50):
    optim.zero_grad()
    samples = dist.rsample((1000,))
    cost = -(target.log_prob(samples) - dist.log_prob(samples)).sum()
    cost.backward(retain_graph=True)  # the graph through std must survive across iterations
    optim.step()
    print(i, log_std, mean, cost)

I would like to do something like this instead:

log_std = nn.Parameter(torch.Tensor([1]))
std = torch.exp(log_std)
mean = nn.Parameter(torch.Tensor([1]))
dist = dd.Normal(loc=mean, scale=std)

optim = torch.optim.SGD([log_std, mean], lr=0.01)
target = dd.Normal(5, 5)

for i in range(50):
    optim.zero_grad()
    samples = dist.rsample((1000,))
    detached_samples = samples.detach()
    detached_samples.retain_grad()
    cost = -(target.log_prob(samples) - dist.log_prob(samples)).sum()
    cost.backward()  # compute gradients up to the detached samples
    samples.backward(detached_samples.grad, retain_graph=True)  # push gradients on to log_std and mean, retaining only this part of the graph
    optim.step()  # apply the update
    print(i, log_std, mean, cost)

Even if I detach samples, when I run cost.backward() I still get gradients for log_std and mean. I don’t know whether that’s intended, or whether what I’m doing is smart, dumb, or even possible, but I would like someone to shed some light on this :slight_smile:

Well, if you use samples rather than detached_samples in your cost…

oh my god, that was dumb :laughing:

using

cost = -(target.log_prob(detached_samples) - dist.log_prob(detached_samples)).sum()

doesn’t solve the problem though :frowning: I think it’s impossible to do this, because there is a direct dependence between log_prob and the parameters of Normal: detaching the samples doesn’t actually prevent gradients from flowing up to log_std and mean through dist.log_prob.
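
For instance, rebuilding the setup from above and checking the gradients (a minimal self-contained sketch, just to illustrate the point):

import torch
import torch.nn as nn
import torch.distributions as dd

log_std = nn.Parameter(torch.Tensor([1]))
mean = nn.Parameter(torch.Tensor([1]))
dist = dd.Normal(loc=mean, scale=torch.exp(log_std))
target = dd.Normal(5, 5)

samples = dist.rsample((1000,))
detached_samples = samples.detach()
cost = -(target.log_prob(detached_samples) - dist.log_prob(detached_samples)).sum()
cost.backward()
print(mean.grad, log_std.grad)  # both are populated, via dist.log_prob's direct use of the parameters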

Thanks for taking a look!

So one other thing is that you likely want detached_samples.requires_grad_() rather than retain_grad(): after .detach() the tensor no longer requires grad, so retain_grad() has no effect.
So you want gradients w.r.t. detached_samples without affecting the parameters? Then I would recommend using torch.autograd.grad instead of .backward to selectively calculate gradients without side effects.
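
To illustrate, here is a minimal sketch of that pattern (it rebuilds the setup from the snippets above; the variable names just mirror yours):

import torch
import torch.nn as nn
import torch.distributions as dd

log_std = nn.Parameter(torch.Tensor([1]))
mean = nn.Parameter(torch.Tensor([1]))
dist = dd.Normal(loc=mean, scale=torch.exp(log_std))
target = dd.Normal(5, 5)

samples = dist.rsample((1000,))
detached_samples = samples.detach().requires_grad_()
cost = -(target.log_prob(detached_samples) - dist.log_prob(detached_samples)).sum()

# gradient w.r.t. the detached samples only; mean.grad and log_std.grad
# stay None because torch.autograd.grad does not accumulate into .grad
grad_samples, = torch.autograd.grad(cost, detached_samples)

# the pathwise part of the gradient can then be pushed back through
# rsample into mean.grad and log_std.grad
samples.backward(grad_samples)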

So this question came from a Stack Overflow post of mine. The ultimate goal is to train the parameters of the distribution without having to worry about
a) the type of distribution used, or
b) how the parameters of the distribution are generated (in this case, the std of the normal distribution is defined in terms of its log to ensure it stays positive).

In short, this means I do not want to have to compute std = exp(log_std) anywhere other than where I first define the distribution.

This is something that works in TensorFlow, as the graphs are static there. It would seem the easiest way in PyTorch would be to retain some small part of the graph, but this does not seem to be possible. The above was an attempt to achieve that. @tom, if you have any suggestions on this, I would greatly appreciate it!

Well yeah, but TF will do the computation over and over again.
So what you would do in PyTorch is to wrap the parameter, the transformation of the parameter, and the distribution in an nn.Module and use that.
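
A minimal sketch of what that could look like (NormalModule and its dist method are just illustrative names, not a fixed recipe):

import torch
import torch.nn as nn
import torch.distributions as dd

class NormalModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.mean = nn.Parameter(torch.Tensor([1]))
        self.log_std = nn.Parameter(torch.Tensor([1]))

    def dist(self):
        # the transformation is recomputed on every call, so a fresh
        # graph from log_std to the distribution is built each iteration
        return dd.Normal(loc=self.mean, scale=torch.exp(self.log_std))

model = NormalModule()
optim = torch.optim.SGD(model.parameters(), lr=0.01)
target = dd.Normal(5, 5)

for i in range(50):
    optim.zero_grad()
    dist = model.dist()
    samples = dist.rsample((1000,))
    cost = -(target.log_prob(samples) - dist.log_prob(samples)).sum()
    cost.backward()  # no retain_graph needed, the graph is rebuilt each step
    optim.step()
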
A long while ago I thought that blurring the boundary between nn.Modules and nn.Parameters would be a good idea. I still think it would be, but it doesn’t seem to be happening anytime soon.

Best regards

Thomas