I am a n00b to pytorch, but I was surprised that the code below fails:
import torchvision
from torch.distributions.transformed_distribution import TransformedDistribution
from torch.distributions import Gamma
def construct_inverse_gamma(alpha, beta):
"""
If X ~ Gamma(alpha, beta), then 1/X ~ InvGamma(alpha, beta)
"""
reciprocal = torchvision.transforms.Lambda(lambda x: 1/x)
InverseGamma = TransformedDistribution(Gamma(alpha, beta), [reciprocal])
return InverseGamma
InverseGamma=construct_inverse_gamma(1.0, 1.0)
The error message is:
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
<ipython-input-67-d61ee5ba5c44> in <module>()
11 return InverseGamma
12
---> 13 InverseGamma=construct_inverse_gamma(1.0, 1.0)
<ipython-input-67-d61ee5ba5c44> in construct_inverse_gamma(alpha, beta)
8 """
9 reciprocal = torchvision.transforms.Lambda(lambda x: 1/x)
---> 10 InverseGamma = TransformedDistribution(Gamma(alpha, beta), [reciprocal])
11 return InverseGamma
12
/redacted_path/py2-env/lib/python2.7/site-packages/torch/distributions/transformed_distribution.pyc in __init__(self, base_distribution, transforms, validate_args)
47 elif isinstance(transforms, list):
48 if not all(isinstance(t, Transform) for t in transforms):
---> 49 raise ValueError("transforms must be a Transform or a list of Transforms")
50 self.transforms = transforms
51 else:
ValueError: transforms must be a Transform or a list of Transforms