I would like to use `torch.distributions.MultivariateNormal` as an `nn.Module`, with `self.loc` and `self.scale_tril` as parameters. Is there a clean way to achieve this?

I would like to do something like this:
```python
import torch
from torch import Tensor
from torch.nn import Module, Parameter

class MultivariateNormal(Module, torch.distributions.MultivariateNormal):
    def __init__(self, loc: Tensor, scale_tril: Tensor):
        Module.__init__(self)  # needs to be called before adding Parameter()s
        self._loc = Parameter(loc)
        self._scale_tril = Parameter(scale_tril)
        torch.distributions.MultivariateNormal.__init__(self, self._loc, scale_tril=self._scale_tril)
```
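This is not possible in the current master branch: `MultivariateNormal.__init__` sets `self.loc` to `loc.expand(...)`, a new view tensor rather than the `Parameter` itself, and it assigns the `Parameter` directly to `self._unbroadcasted_scale_tril`, which `nn.Module.__setattr__` registers as an additional parameter.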
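A minimal sketch to check both side effects, assuming a recent PyTorch where `MultivariateNormal.__init__` behaves as described above. `ModuleMVN` is just the class from the question under a hypothetical name (to avoid shadowing the import). For contrast, `MVNWrapper` (also hypothetical) is a commonly used alternative, not the mixin the question asks for: it stores only the `Parameter`s and builds a fresh distribution on every `forward()` call. The wrapper does not keep `scale_tril` lower-triangular with a positive diagonal during training, so with argument validation on, a real model would need an extra parametrization for that.

```python
import torch
from torch import Tensor
from torch.distributions import MultivariateNormal
from torch.nn import Module, Parameter


class ModuleMVN(Module, MultivariateNormal):
    """The mixed class from the question (hypothetical name), reproduced to inspect it."""

    def __init__(self, loc: Tensor, scale_tril: Tensor):
        Module.__init__(self)  # must run before any Parameter is assigned
        self._loc = Parameter(loc)
        self._scale_tril = Parameter(scale_tril)
        MultivariateNormal.__init__(self, self._loc, scale_tril=self._scale_tril)


class MVNWrapper(Module):
    """Hypothetical alternative: keep only the Parameters, rebuild the distribution per call."""

    def __init__(self, loc: Tensor, scale_tril: Tensor):
        super().__init__()
        self.loc = Parameter(loc)
        self.scale_tril = Parameter(scale_tril)

    def forward(self) -> MultivariateNormal:
        # Assumption: scale_tril stays a valid lower-Cholesky factor; nothing here enforces it.
        return MultivariateNormal(self.loc, scale_tril=self.scale_tril)


mixed = ModuleMVN(torch.zeros(2), torch.eye(2))
# loc is an expand() view derived from the Parameter, not the Parameter itself.
print(isinstance(mixed.loc, Parameter), mixed.loc is mixed._loc)
# parameters() de-duplicates by identity, but state_dict() lists the extra registration.
print(len(list(mixed.parameters())), list(mixed.state_dict().keys()))

wrapper = MVNWrapper(torch.zeros(2), torch.eye(2))
dist = wrapper()  # a new MultivariateNormal built from the current Parameters
loss = -dist.log_prob(torch.randn(5, 2)).mean()
loss.backward()  # gradients reach wrapper.loc and wrapper.scale_tril
print(list(wrapper.state_dict().keys()))
```

Because the wrapper rebuilds the distribution on each call, every forward pass gets its own autograd graph and the state dict contains exactly the two learnable tensors.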