I am using the `torch.distributions` package inside my Module.
However, since the distribution classes do not inherit from `torch.nn.Module`, they are not included when I call `my_module.state_dict()`.
Is there an easy way to save the distributions?
```python
import torch
import torch.distributions as D


class GmmNet(torch.nn.Module):
    def __init__(self):
        super(GmmNet, self).__init__()
        # Example, random parameters
        self.gmm = D.MixtureSameFamily(
            D.Categorical(torch.nn.parameter.Parameter(torch.ones(5,))),
            D.Normal(torch.randn(5,), torch.rand(5,)),
        )

    def forward(self, x):
        return self.gmm.log_prob(x)


print(GmmNet().state_dict())
# --> OrderedDict()
```
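A bigger issue is that `Distribution` objects are effectively immutable: they compute and store derived tensors when they are constructed (for example, `Categorical` normalizes the `probs` you pass in), so you can't put `nn.Parameter`s inside them in `__init__` and expect training to keep them up to date. Keep the parameters in modules, create transient `Distribution` objects in `forward`, and then both saving and training will work correctly.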
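A minimal sketch of that pattern, assuming a 1-D mixture of 5 Gaussians as in the question. The parameter names `logits`, `loc` and `log_scale` and the `distribution()` helper are hypothetical choices for illustration; storing a log standard deviation and exponentiating it is just one common way to keep the scale positive, not something the answer prescribes. Because the `MixtureSameFamily` is rebuilt on every call, each optimizer step sees the current parameter values, and `state_dict()` holds everything needed to restore the model.

```python
import io

import torch
import torch.distributions as D


class GmmNet(torch.nn.Module):
    def __init__(self, n_components: int = 5):
        super().__init__()
        # Trainable parameters are registered on the module, so state_dict()
        # and model.parameters() both see them.
        # Hypothetical parameterization: unnormalized mixture logits,
        # component means, and log standard deviations.
        self.logits = torch.nn.Parameter(torch.zeros(n_components))
        self.loc = torch.nn.Parameter(torch.randn(n_components))
        self.log_scale = torch.nn.Parameter(torch.zeros(n_components))

    def distribution(self) -> D.MixtureSameFamily:
        # Build a fresh, transient Distribution from the current parameters.
        mix = D.Categorical(logits=self.logits)
        comp = D.Normal(self.loc, self.log_scale.exp())  # exp keeps scale > 0
        return D.MixtureSameFamily(mix, comp)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.distribution().log_prob(x)


torch.manual_seed(0)
model = GmmNet()
print(list(model.state_dict().keys()))
# ['logits', 'loc', 'log_scale']

# A few maximum-likelihood steps on toy data.
data = torch.randn(256) * 2.0 + 1.0
opt = torch.optim.Adam(model.parameters(), lr=0.05)
for _ in range(3):
    opt.zero_grad()
    loss = -model(data).mean()
    loss.backward()  # safe every step: forward rebuilds the graph
    opt.step()

# Save and restore through the usual state_dict round trip
# (an in-memory buffer here; a file path works the same way).
buffer = io.BytesIO()
torch.save(model.state_dict(), buffer)
buffer.seek(0)
restored = GmmNet()
restored.load_state_dict(torch.load(buffer))
```

The toy data and the three Adam steps are only there to show that `backward()` can be called repeatedly; what ends up in the saved file are the module's `nn.Parameter`s, from which an equivalent distribution is rebuilt on demand.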