Based on @tom’s suggestion, here is a modified attempt that uses nn.Module:
import torch
import torch.distributions as dd
import torch.nn as nn
from torch.distributions import biject_to


class DistributionModule(nn.Module):
    def __init__(self, dist, arg_constraints, **params):
        super(DistributionModule, self).__init__()
        self.dist_class = dist
        self.arg_constraints = arg_constraints
        # The registered parameters live in unconstrained space.
        for k, v in params.items():
            self.register_parameter(k, nn.Parameter(v))

    def get_dist(self):
        constrained_params = dict(self.get_constrained_params())
        return self.dist_class(**constrained_params)

    def log_prob(self, sample):
        return self.get_dist().log_prob(sample)

    def rsample(self, sample_shape=()):
        return self.get_dist().rsample(sample_shape)

    def forward(self, value):
        return self.log_prob(value)

    def get_constrained_params(self):
        # biject_to maps each unconstrained parameter into the support
        # required by the distribution (e.g. positive reals for scale).
        for name, param in self.named_parameters():
            yield name, biject_to(self.arg_constraints[name])(param)


# Note: these initial values are in unconstrained space, so the initial
# constrained scale is not exactly 1.
params = {'loc': torch.tensor([1.]),
          'scale': torch.tensor([1.])}
target = dd.Normal(5, 5)
d = DistributionModule(dd.Normal, target.arg_constraints, **params)
optim = torch.optim.Adam(d.parameters(), lr=0.001)

for i in range(10000):
    optim.zero_grad()
    samples = d.rsample((1000,))
    # Monte Carlo estimate of KL(d || target), up to the sample-count factor
    cost = -(target.log_prob(samples) - d.log_prob(samples)).sum()
    cost.backward()
    optim.step()
    print(i, dict(d.get_constrained_params()))
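After the loop, a quick sanity check (hypothetical, not part of the snippet above) is to compare the fitted distribution with the target:

# Hypothetical check: the fitted mean and stddev should approach
# the target's 5 and 5.
fitted = d.get_dist()
print(fitted.mean, fitted.stddev)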
Not tested very thoroughly, but I think you should be able to use a generic wrapper like this for most distribution classes! The forward method just delegates to .log_prob (though it is not used in the example above).
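For example, reusing the same class with a Gamma variational distribution might look like the sketch below (untested; the Gamma target and the initial unconstrained parameter values are made up for illustration, and q(samples) goes through forward):

# Untested sketch reusing DistributionModule with a Gamma approximation.
gamma_target = dd.Gamma(3., 2.)
q = DistributionModule(dd.Gamma, dd.Gamma.arg_constraints,
                       concentration=torch.tensor([0.]),
                       rate=torch.tensor([0.]))
opt = torch.optim.Adam(q.parameters(), lr=0.001)
for i in range(10000):
    opt.zero_grad()
    samples = q.rsample((1000,))
    # q(samples) calls forward, which delegates to q.log_prob(samples)
    cost = (q(samples) - gamma_target.log_prob(samples)).sum()
    cost.backward()
    opt.step()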