Saving torch.distributions objects with state_dict()

I am using the torch.distributions package inside my Module.
However, since the distribution classes do not inherit from torch.nn.Module, their parameters are not saved when calling my_module.state_dict().

Is there an easy way to save the distributions?

import torch

class GmmNet(torch.nn.Module):
    """Gaussian mixture whose distribution object is built once in ``__init__``.

    This is the problematic variant: ``self.gmm`` is a plain attribute holding
    a ``torch.distributions`` object rather than a submodule, parameter or
    buffer, so nothing is registered on the module and ``state_dict()`` comes
    back empty.
    """

    def __init__(self):
        super().__init__()
        # Example, random parameters
        mixture_weights = torch.nn.Parameter(torch.ones(5))
        components = torch.distributions.Normal(torch.randn(5), torch.rand(5))
        # The classes are reached through ``torch.distributions`` because the
        # snippet only imports ``torch``; the former ``D`` alias was never
        # defined and raised NameError.
        self.gmm = torch.distributions.MixtureSameFamily(
            torch.distributions.Categorical(mixture_weights),
            components,
        )

    def forward(self, x):
        """Return the log-density of ``x`` under the stored mixture."""
        return self.gmm.log_prob(x)
        
# Instantiate the module and show what it would save.
model = GmmNet()
print(model.state_dict())
# --> OrderedDict()

A bigger issue is that Distribution objects are immutable, so you can’t push nn.Parameters inside them in __init__. Keep the parameters in the module, create transient Distribution objects in forward, and then saving and training will work correctly.

Thanks!

For anyone else, this is the working code:

import torch

class GmmNet(torch.nn.Module):
    """Gaussian mixture whose parameters live on the module itself.

    ``cat`` and ``cov`` are registered as ``nn.Parameter`` attributes, so they
    appear in ``state_dict()``. The distribution objects are rebuilt from them
    on every ``forward`` call instead of being stored.
    """

    def __init__(self):
        super().__init__()
        self.cat = torch.nn.Parameter(torch.ones(5))
        self.cov = torch.nn.Parameter(torch.ones(5))

    def forward(self, x):
        """Return the log-density of ``x`` under a freshly built mixture."""
        # The classes are reached through ``torch.distributions`` because the
        # snippet only imports ``torch``; the former ``D`` alias was never
        # defined and raised NameError.
        mixing = torch.distributions.Categorical(self.cat)
        # NOTE(review): ``cov`` is passed as both the mean and the scale of the
        # components, exactly as in the original snippet; a separate mean
        # parameter is probably intended -- confirm before real use.
        components = torch.distributions.Normal(self.cov, self.cov)
        gmm = torch.distributions.MixtureSameFamily(mixing, components)
        return gmm.log_prob(x)
        
# Instantiate the module and show what it would save.
model = GmmNet()
print(model.state_dict())
# --> OrderedDict([('cat', tensor([1., 1., 1., 1., 1.])), ('cov', tensor([1., 1., 1., 1., 1.]))])