Detach rsamples from distribution parameters

I want to compute the gradient for the parameters of a distribution in two steps, so that the code defining the distribution can be decoupled from the training loop. The following works, but I have to retain the whole graph (or manually delete references to cost and samples with del, which may become infeasible as the computational graph grows more complex). The retain_graph=True is needed because std = torch.exp(log_std) is computed once, outside the loop, so that part of the graph is reused on every iteration and would otherwise be freed by the first backward.

import torch
import torch.distributions as dd
import torch.nn as nn

log_std = nn.Parameter(torch.Tensor([1]))
std = torch.exp(log_std)  # transform built once, outside the loop
mean = nn.Parameter(torch.Tensor([1]))
dist = dd.Normal(loc=mean, scale=std)
optim = torch.optim.SGD([log_std, mean], lr=0.01)
target = dd.Normal(5, 5)

for i in range(50):
    optim.zero_grad()
    samples = dist.rsample((1000,))
    cost = -(target.log_prob(samples) - dist.log_prob(samples)).sum()
    cost.backward(retain_graph=True)  # the exp node must survive until the next iteration
    optim.step()
    print(i, log_std, mean, cost)

I would like to do something like this instead

log_std = nn.Parameter(torch.Tensor([1]))
std = torch.exp(log_std)
mean = nn.Parameter(torch.Tensor([1]))
dist = dd.Normal(loc=mean, scale=std)

optim = torch.optim.SGD([log_std, mean], lr=0.01)
target = dd.Normal(5,5)

for i in range(50):
    optim.zero_grad()
    samples = dist.rsample((1000,))
    detached_samples = samples.detach()
    detached_samples.retain_grad()
    cost = -(target.log_prob(samples) - dist.log_prob(samples)).sum()
    cost.backward() # compute gradients up to detached samples
    samples.backward(detached_samples.grad, retain_graph=True)  # gradients of log_std and mean; retain only this part of the graph
    optim.step()  # apply update
    print(i, log_std, mean, cost)

Even if I detach the samples, when I run cost.backward() I still get gradients for log_std and mean. I don’t know whether that’s intended, or whether what I’m doing is smart, dumb, or even possible, but I would like someone to shed some light on this :slight_smile:

Well, if you use samples rather than detached_samples in your cost…

oh my god, that was dumb :laughing:

using

cost = -(target.log_prob(detached_samples) - dist.log_prob(detached_samples)).sum()

doesn’t solve the problem though :frowning: I think it’s impossible to do this, because log_prob depends directly on the parameters of Normal; detaching the samples doesn’t actually stop gradients from flowing back to log_std and mean.
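
To make that direct dependence concrete, a quick check (a sketch reusing dist, mean, and log_std from the first snippet, before any backward pass):

x = dist.rsample((3,)).detach()            # x carries no graph at all
lp = dist.log_prob(x).sum()
print(lp.grad_fn is not None)              # True: the graph runs through mean and std
print(torch.autograd.grad(lp, [mean, log_std]))  # non-zero gradients for both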

Thanks for taking a look!

So one other thing is that you likely want detached_samples.requires_grad_(), since .detach() returns a tensor with requires_grad=False and retain_grad() won’t work on it.
So you want gradients w.r.t. detached_samples without affecting the params? Then I would recommend using torch.autograd.grad instead of .backward to selectively compute gradients without side effects.
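
Roughly like this (a sketch reusing dist, target, mean, and log_std from the snippets above; just an illustration of torch.autograd.grad, not a final recommendation):

samples = dist.rsample((1000,))
detached_samples = samples.detach().requires_grad_()
cost = -(target.log_prob(detached_samples) - dist.log_prob(detached_samples)).sum()

# gradient w.r.t. the detached samples only; unlike cost.backward(), this
# does not accumulate anything into mean.grad or log_std.grad
(grad_wrt_samples,) = torch.autograd.grad(cost, detached_samples)

# push that gradient through the rsample path to reach mean and log_std;
# retain_graph=True because std = exp(log_std) was built outside the loop
samples.backward(gradient=grad_wrt_samples, retain_graph=True)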

This question came from a Stack Overflow post of mine. The ultimate goal is to train the parameters of the distribution without having to worry about
a) the type of distribution used, or
b) how the parameters of the distribution are generated (in this case, the std of the normal distribution is defined in terms of its log to ensure non-negative results).

In short, this means I do not want to have to compute std = e^(log_std) at any point other than when I first define the distribution.

This is something that works in TensorFlow, as the graphs there are static. The easiest way here would seem to be to retain some small part of the graph in PyTorch, but this does not seem to be possible. The above was an attempt at achieving that. @tom, if you have any suggestions on this I would greatly appreciate it!

Well yeah, but TF will redo that computation over and over again.
So what you would do in PyTorch is wrap the parameter, the transformation of the parameter, and the distribution in an nn.Module and use that.
A long while ago I thought that blurring the boundary between nn.Modules and nn.Parameters would be a good idea; I still think it would be, but it doesn’t seem to be happening any time soon.
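
A minimal sketch of what I mean, assuming the Normal setup from your question (NormalModule is just a made-up name):

import torch
import torch.distributions as dd
import torch.nn as nn

class NormalModule(nn.Module):
    def __init__(self):
        super().__init__()
        # raw parameters live in the module
        self.mean = nn.Parameter(torch.tensor([1.0]))
        self.log_std = nn.Parameter(torch.tensor([1.0]))

    def dist(self):
        # a fresh Normal (and a fresh exp) per call, so the graph is rebuilt
        # every iteration and no retain_graph is needed
        return dd.Normal(loc=self.mean, scale=torch.exp(self.log_std))

    def rsample(self, sample_shape=()):
        return self.dist().rsample(sample_shape)

    def log_prob(self, value):
        return self.dist().log_prob(value)

m = NormalModule()
optim = torch.optim.SGD(m.parameters(), lr=0.01)  # optimize the raw parameters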

Best regards

Thomas

Thank you @tom, do you happen to have any quick hints on implementing this? The main issue I see is that if I pass a module in the place of a parameter, it will not be accessed in the same way (i.e. self.parameter vs self.parameter()).

Thanks,
Michael

No, don’t use a Module for the parameters (that doesn’t work well, or we would not need the feature discussed in the linked abstraction), but have the nn.Module wrap your distribution. You can have your module define log_prob and rsample if you want; except for calling some hooks that you might not need, there is no difference between model() and model.forward(), so forward isn’t special.
In a way, this is similar to the “what is nn.Module” tutorial: nn.Modules are containers for the parameters and wrap the calculation we do with them.

Best regards

Thomas

Constructing new distribution instances is cheap, and you can use the constraints library to automatically handle the transformation of the std parameter to (0, inf), as below. (For a Normal’s scale, biject_to(constraints.positive) is an exp transform, i.e. exactly the std = exp(log_std) parameterization from the question.)

Code:

import torch
import torch.distributions as dd
from torch.distributions import biject_to

params = {'loc': torch.tensor([1.], requires_grad=True),
          'scale': torch.tensor([1.], requires_grad=True)}

optim = torch.optim.Adam(params.values(), lr=0.001)
target = dd.Normal(5, 5)


def get_dist():
    return dd.Normal(**get_constrained_params())


def get_constrained_params():
    # apply automatic constraints to distribution parameters.
    constraints = target.arg_constraints
    constrained_params = {}
    for k, v in params.items():
        constrained_params[k] = biject_to(constraints[k])(v)
    return constrained_params


for i in range(10000):
    optim.zero_grad()
    dist = get_dist()
    samples = dist.rsample((1000,))
    cost = -(target.log_prob(samples) - dist.log_prob(samples)).sum()
    cost.backward()
    optim.step()
    print(i, get_constrained_params())

Does this satisfy what you had in mind?

That’s awesome, @neerajprad! I’d probably stick the things into an nn.Module to hold the parameters. :slight_smile:

Best regards

Thomas

Based on @tom’s suggestion, here is a modified attempt that uses nn.Module :smile:

import torch
import torch.distributions as dd
import torch.nn as nn
from torch.distributions import biject_to


class DistributionModule(nn.Module):
    def __init__(self, dist, arg_constraints, **params):
        super(DistributionModule, self).__init__()
        self.dist_class = dist
        self.arg_constraints = arg_constraints
        for k, v in params.items():
            self.register_parameter(k, nn.Parameter(v))

    def get_dist(self):
        constrained_params = dict(self.get_constrained_params())
        return self.dist_class(**constrained_params)

    def log_prob(self, sample):
        return self.get_dist().log_prob(sample)

    def rsample(self, sample_shape=()):
        return self.get_dist().rsample(sample_shape)

    def forward(self, value):
        return self.log_prob(value)

    def get_constrained_params(self):
        for name, param in self.named_parameters():
            yield name, biject_to(self.arg_constraints[name])(param)


params = {'loc': torch.tensor([1.]),
          'scale': torch.tensor([1.])}
target = dd.Normal(5, 5)
d = DistributionModule(dd.Normal, target.arg_constraints, **params)
optim = torch.optim.Adam(d.parameters(), lr=0.001)


for i in range(10000):
    optim.zero_grad()
    samples = d.rsample((1000,))
    cost = -(target.log_prob(samples) - d.log_prob(samples)).sum()
    cost.backward()
    optim.step()
    print(i, dict(d.get_constrained_params()))

Not tested very well, but I think you should be able to use a generic wrapper like this for most distribution classes! The forward method just delegates to .log_prob (though it is not being used here).
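
Since forward just delegates to .log_prob, the wrapper can also be called like any other module, e.g. (hypothetical usage of the d defined above):

samples = d.rsample((10,))
log_p = d(samples)  # equivalent to d.log_prob(samples), via forward()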
