How to use MultivariateNormal as nn.Module

I would like to use torch.distributions.MultivariateNormal as an nn.Module, with self.loc and self.scale_tril registered as Parameters.

Is there a clean way to achieve this?

I would like to do something like this:

```python
import torch
from torch import Tensor
from torch.nn import Module, Parameter

class MultivariateNormal(Module, torch.distributions.MultivariateNormal):
    def __init__(self, loc: Tensor, scale_tril: Tensor):
        Module.__init__(self)  # needs to be called before adding Parameter()s
        self._loc = Parameter(loc)
        self._scale_tril = Parameter(scale_tril)
        torch.distributions.MultivariateNormal.__init__(self, self._loc, scale_tril=self._scale_tril)
```
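A quick way to see what the class above actually registers is to instantiate it and inspect the result. This is a sketch against recent PyTorch releases; the attribute names loc, scale_tril and _unbroadcasted_scale_tril are internals of torch.distributions.MultivariateNormal and may differ between versions. The class is repeated here so the snippet runs on its own.

```python
import torch
from torch import Tensor
from torch.nn import Module, Parameter


class MultivariateNormal(Module, torch.distributions.MultivariateNormal):
    def __init__(self, loc: Tensor, scale_tril: Tensor):
        Module.__init__(self)  # must run before any Parameter is assigned
        self._loc = Parameter(loc)
        self._scale_tril = Parameter(scale_tril)
        torch.distributions.MultivariateNormal.__init__(
            self, self._loc, scale_tril=self._scale_tril
        )


m = MultivariateNormal(torch.zeros(2), torch.eye(2))

# The parent __init__ assigns the scale_tril Parameter to an internal attribute,
# so nn.Module registers it a second time under that internal name.
print(sorted(m.state_dict().keys()))

# self.loc is the output of expand(): a plain Tensor derived from the Parameter,
# not the Parameter object itself.
print(type(m.loc).__name__, m.loc is m._loc)
```

Iterating over m.parameters() still yields each tensor once, because nn.Module deduplicates by identity, but state_dict() and load_state_dict() see the extra key.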

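This is not possible on the current master branch: MultivariateNormal.__init__ sets self.loc (and self.scale_tril) with expand(), which returns a new view tensor rather than the Parameter itself, and it sets self._unbroadcasted_scale_tril by plain attribute assignment, which on an nn.Module registers the same Parameter a second time under another name.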
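One common workaround, which is not from the original post, drops the multiple inheritance entirely: keep loc and scale_tril as Parameters on a plain nn.Module and build a fresh torch.distributions.MultivariateNormal whenever it is needed. A minimal sketch, where the class name MultivariateNormalModule is hypothetical:

```python
import torch
from torch import nn
from torch.distributions import MultivariateNormal


class MultivariateNormalModule(nn.Module):
    """Hypothetical wrapper: owns loc and scale_tril as Parameters and
    returns a new MultivariateNormal built from them on every call."""

    def __init__(self, loc: torch.Tensor, scale_tril: torch.Tensor):
        super().__init__()
        self.loc = nn.Parameter(loc)
        self.scale_tril = nn.Parameter(scale_tril)

    def forward(self) -> MultivariateNormal:
        # torch.tril keeps the factor lower triangular even if the optimizer
        # moves entries above the diagonal; positivity of the diagonal is NOT
        # enforced here.
        return MultivariateNormal(self.loc, scale_tril=torch.tril(self.scale_tril))


m = MultivariateNormalModule(torch.zeros(2), torch.eye(2))
print([name for name, _ in m.named_parameters()])

dist = m()  # a fresh distribution reflecting the current parameter values
lp = dist.log_prob(torch.zeros(2))
lp.backward()  # gradients flow back to m.loc and m.scale_tril
```

Because the distribution object is rebuilt on each call, quantities that MultivariateNormal computes lazily and caches (such as covariance_matrix) cannot go stale after an optimizer step. The sketch does not keep the diagonal of scale_tril positive, so with argument validation enabled a bad step can make construction fail; a real model would likely parametrize the diagonal, for example through exp or softplus.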
