Creating an Inverse Gamma distribution with torch.distributions

I’m looking to define an inverse gamma distribution using torch.distributions, similar to writing:

gamma_dist = torch.distributions.Gamma(alpha, beta)

I see that there is a transforms module from which one can build a TransformedDistribution (see https://pytorch.org/docs/stable/_modules/torch/distributions/transforms.html#PowerTransform), but there is no reciprocal (1/X) distribution class. How should one go about creating the inverse gamma?

I am new to PyTorch, but I was surprised that the code below fails:

import torchvision
from torch.distributions.transformed_distribution import TransformedDistribution
from torch.distributions import Gamma

def construct_inverse_gamma(alpha, beta):
	"""
	If X ~ Gamma(alpha, beta), then 1/X ~ InvGamma(alpha, beta)
	"""
	reciprocal = torchvision.transforms.Lambda(lambda x: 1/x)
	InverseGamma = TransformedDistribution(Gamma(alpha, beta), [reciprocal])
	return InverseGamma

InverseGamma=construct_inverse_gamma(1.0, 1.0)

The error message is:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-67-d61ee5ba5c44> in <module>()
     11         return InverseGamma
     12 
---> 13 InverseGamma=construct_inverse_gamma(1.0, 1.0)

<ipython-input-67-d61ee5ba5c44> in construct_inverse_gamma(alpha, beta)
      8 	"""
      9         reciprocal = torchvision.transforms.Lambda(lambda x: 1/x)
---> 10         InverseGamma = TransformedDistribution(Gamma(alpha, beta), [reciprocal])
     11         return InverseGamma
     12 

/redacted_path/py2-env/lib/python2.7/site-packages/torch/distributions/transformed_distribution.pyc in __init__(self, base_distribution, transforms, validate_args)
     47         elif isinstance(transforms, list):
     48             if not all(isinstance(t, Transform) for t in transforms):
---> 49                 raise ValueError("transforms must be a Transform or a list of Transforms")
     50             self.transforms = transforms
     51         else:

ValueError: transforms must be a Transform or a list of Transforms

I need to do some experiments with the inverse gamma in PyTorch. Have there been any updates on this? Any help would be greatly appreciated.

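The post above used a torchvision transform, which is not a torch.distributions Transform; that is why TransformedDistribution raised the ValueError. Here is a ready-made implementation: pyro/inverse_gamma.py at 9f67c43deab7f33216b426843292e5c5ce73df78 in pyro-ppl/pyro on GitHub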
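
For anyone who wants to stay within plain torch.distributions, here is a minimal sketch of the same idea using the PowerTransform mentioned in the original question, with an exponent of -1 in place of the torchvision Lambda. It assumes beta is the rate parameter of the Gamma, and the parameter values (2.0, 3.0) and the test point 0.5 are arbitrary choices for illustration. It is not the Pyro implementation linked above, just one way to get the reciprocal transform past the isinstance check. Recent PyTorch releases also appear to ship a built-in torch.distributions.InverseGamma, so it is worth checking your installed version first.

```python
import torch
from torch.distributions import Gamma, TransformedDistribution
from torch.distributions.transforms import PowerTransform


def construct_inverse_gamma(alpha, beta):
    """If X ~ Gamma(alpha, beta) with rate beta, then 1/X ~ InvGamma(alpha, beta)."""
    alpha = torch.as_tensor(alpha, dtype=torch.float32)
    beta = torch.as_tensor(beta, dtype=torch.float32)
    # PowerTransform is a torch.distributions Transform, so the isinstance
    # check inside TransformedDistribution passes; exponent -1 maps x to 1/x.
    reciprocal = PowerTransform(torch.tensor(-1.0))
    return TransformedDistribution(Gamma(alpha, beta), [reciprocal])


# Illustrative parameters (assumed values, not from the thread)
inv_gamma = construct_inverse_gamma(2.0, 3.0)
samples = inv_gamma.sample((5,))

# Sanity check against the closed-form inverse gamma log density:
# log p(y) = a*log(b) - lgamma(a) - (a + 1)*log(y) - b/y
y = torch.tensor(0.5)
a, b = torch.tensor(2.0), torch.tensor(3.0)
closed_form = a * torch.log(b) - torch.lgamma(a) - (a + 1) * torch.log(y) - b / y
print(inv_gamma.log_prob(y), closed_form)
```

The log_prob of the transformed distribution includes the Jacobian correction for y = 1/x automatically, which is why it should agree with the closed-form expression.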