Detach rsamples from distribution parameters

Based on @tom’s suggestion, here is a modified attempt that uses nn.Module :smile:

import torch
import torch.distributions as dd
import torch.nn as nn
from torch.distributions import biject_to


class DistributionModule(nn.Module):
    # Wraps a torch.distributions class so that its parameters become learnable
    # nn.Parameters, stored in unconstrained space.
    def __init__(self, dist, arg_constraints, **params):
        super().__init__()
        self.dist_class = dist
        self.arg_constraints = arg_constraints
        # Each value is registered as an unconstrained parameter; biject_to maps
        # it back onto the distribution's support in get_constrained_params().
        for k, v in params.items():
            self.register_parameter(k, nn.Parameter(v))

    def get_dist(self):
        # Build a fresh distribution instance from the current parameter values.
        constrained_params = dict(self.get_constrained_params())
        return self.dist_class(**constrained_params)

    def log_prob(self, sample):
        return self.get_dist().log_prob(sample)

    def rsample(self, sample_shape=()):
        # Reparameterized samples, so gradients flow back to the parameters.
        return self.get_dist().rsample(sample_shape)

    def forward(self, value):
        # forward just delegates to log_prob, so the module can be called directly.
        return self.log_prob(value)

    def get_constrained_params(self):
        # Map each raw parameter into the space required by the distribution
        # (e.g. positive for Normal's scale).
        for name, param in self.named_parameters():
            yield name, biject_to(self.arg_constraints[name])(param)


# Initial values are in unconstrained space: 'scale' is passed through
# biject_to(positive), so the starting scale is not literally 1.
params = {'loc': torch.tensor([1.]),
          'scale': torch.tensor([1.])}
target = dd.Normal(5, 5)
d = DistributionModule(dd.Normal, target.arg_constraints, **params)
optim = torch.optim.Adam(d.parameters(), lr=0.001)


for i in range(10000):
    optim.zero_grad()
    samples = d.rsample((1000,))
    # Monte Carlo estimate of KL(d || target), scaled by the number of samples;
    # gradients flow through the samples thanks to rsample.
    cost = -(target.log_prob(samples) - d.log_prob(samples)).sum()
    cost.backward()
    optim.step()
    print(i, dict(d.get_constrained_params()))

I haven't tested this very thoroughly, but a generic wrapper like this should work for most distribution classes (see the sketch below)! The forward method just delegates to .log_prob, though it isn't actually used in the example above.
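
For example, here is a quick, untested sketch of reusing the wrapper for a different pair of distributions; the Gamma target, the Exponential approximation, the initial rate, the learning rate, and the iteration count are all arbitrary choices for illustration, not something from the snippet above:

# Hypothetical example: fit an Exponential approximation to a Gamma target.
gamma_target = dd.Gamma(2., 3.)
q = DistributionModule(dd.Exponential, dd.Exponential.arg_constraints,
                       rate=torch.tensor([0.]))  # unconstrained initial value
optim = torch.optim.Adam(q.parameters(), lr=0.01)
for i in range(2000):
    optim.zero_grad()
    samples = q.rsample((1000,))
    # q(samples) goes through forward, which delegates to log_prob
    cost = (q(samples) - gamma_target.log_prob(samples)).sum()
    cost.backward()
    optim.step()
print(dict(q.get_constrained_params()))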
