Parametrized Gaussians sharing the same trainable parameter

Hi all, I want to parametrize two Gaussian distributions whose parameters are related, but it seems that the variable gets freed during training. What can I do to fix this problem?
Here is a minimal example. Thank you!

import torch
import torch.nn as nn
import torch.distributions as D
class GaussianModel(nn.Module):
    def __init__(self):
        super(GaussianModel, self).__init__()
        self.v1 = nn.Parameter(torch.zeros(1))
        self.v2 = self.v1*2
        self.g1=D.Normal(torch.tensor([4.0]), self.v1)
        self.g2=D.Normal(torch.tensor([4.0]), self.v2)
    def forward(self,x):
        return torch.mean(self.g1.log_prob(x)+self.g2.log_prob(x))
model=GaussianModel()
opt=torch.optim.SGD(params=model.parameters(), lr=1e-3)
for i in range(10):
    opt.zero_grad()
    x=torch.randn(10)
    loss=model(x)
    loss.backward()
    opt.step()

And this is the error message from the terminal:

RuntimeError                              Traceback (most recent call last)
<ipython-input-92-46e3175cfeff> in <module>
      3     x=torch.randn(10)
      4     loss=model(x)
----> 5     loss.backward()
      6     opt.step()
      7 #     model.p1.detach()

~/.virtualenvs/ddp-pytorch/lib/python3.6/site-packages/torch/tensor.py in backward(self, gradient, retain_graph, create_graph)
    219                 retain_graph=retain_graph,
    220                 create_graph=create_graph)
--> 221         torch.autograd.backward(self, gradient, retain_graph, create_graph)
    222 
    223     def register_hook(self, hook):

~/.virtualenvs/ddp-pytorch/lib/python3.6/site-packages/torch/autograd/__init__.py in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables)
    130     Variable._execution_engine.run_backward(
    131         tensors, grad_tensors_, retain_graph, create_graph,
--> 132         allow_unreachable=True)  # allow_unreachable flag
    133 
    134 

RuntimeError: Trying to backward through the graph a second time, but the saved intermediate results have already been freed. Specify retain_graph=True when calling backward the first time.

Hi Iam!

The problem is that self.v2 = self.v1*2 creates a new tensor
for v2, and does so only once at __init__() time. Subsequent
updates to v1 (made by the optimizer) never get reflected in v2
(nor in g2).

(I don’t really understand how this leads to the specific autograd
RuntimeError you see, but this is the key error in your code.)
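A plausible account of the RuntimeError itself (an assumption, not checked against the PyTorch version used in the question) is that v1*2 in __init__() records an autograd node only once. Every forward() builds on top of that node, the first backward() frees the intermediate values it saved, and the second iteration then tries to backpropagate through it again. The same error can be reproduced with any derived tensor that is computed once, outside the training loop. In the sketch below, torch.exp is a hypothetical stand-in for v1*2, chosen because exp saves its output for the backward pass:

```python
import torch

w = torch.nn.Parameter(torch.ones(1))
# Stand-in for `self.v2 = self.v1 * 2`: a derived tensor computed once,
# outside the loop. exp() saves its output for use in the backward pass.
v = torch.exp(w)

errors = []
for i in range(2):
    loss = (v * torch.randn(1)).sum()
    try:
        loss.backward()  # iteration 1 re-enters the exp node freed by iteration 0
    except RuntimeError as e:
        errors.append((i, str(e)))

# Fix: recompute the derived tensor inside the loop on every iteration.
for i in range(2):
    v_fresh = torch.exp(w)
    (v_fresh * torch.randn(1)).sum().backward()  # no error
```

Here errors ends up with a single entry, for iteration 1, while the second loop completes because each iteration builds its own exp node.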

I would probably just do:

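import torch
import torch.nn as nn
import torch.distributions as D
class GaussianModel (nn.Module):
    def __init__ (self):
        super (GaussianModel, self).__init__()
        self.v1 = nn.Parameter (torch.zeros (1))
    def forward (self, x):
        # build both distributions afresh on every call so that they see the current v1
        g1 = D.Normal (torch.tensor ([4.0]), self.v1)
        g2 = D.Normal (torch.tensor ([4.0]), 2 * self.v1)
        return torch.mean (g1.log_prob (x) + g2.log_prob (x))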

I believe you could also implement this as:

class GaussianModel (nn.Module):
    def __init__ (self):
        super (GaussianModel, self).__init__()
        self.v1 = nn.Parameter (torch.zeros (1))
        self.v2 = self.v1 * 2   # these initial values will be overwritten
        self.g1=D.Normal (torch.tensor ([4.0]), self.v1)
        self.g2=D.Normal (torch.tensor ([4.0]), self.v2)
    def forward (self, x):
        self.v2.copy_ (self.v1 * 2)
        return torch.mean (self.g1.log_prob (x)+self.g2.log_prob (x))

but, to me, my first version is more readable.
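For completeness, here is a minimal runnable sketch of the first version inside the question's training loop. It differs from the thread in a few labeled assumptions: the scale starts at 1.0 rather than 0.0 (Normal needs a positive scale), the constant location is stored with register_buffer, init_scale is a hypothetical constructor argument, and the loop minimizes the negative log-likelihood, whereas the question's loop minimizes the log-likelihood itself.

```python
import torch
import torch.nn as nn
import torch.distributions as D

class GaussianModel(nn.Module):
    def __init__(self, init_scale=1.0):
        super().__init__()
        # Assumption: start from a positive scale, since Normal rejects scale <= 0.
        self.v1 = nn.Parameter(torch.full((1,), init_scale))
        # Constant location stored as a buffer (not trained, moves with .to()).
        self.register_buffer("loc", torch.tensor([4.0]))

    def forward(self, x):
        # Both distributions are rebuilt on every call, so autograd records
        # the shared dependence on v1 afresh in each iteration.
        g1 = D.Normal(self.loc, self.v1)
        g2 = D.Normal(self.loc, 2 * self.v1)
        return torch.mean(g1.log_prob(x) + g2.log_prob(x))

torch.manual_seed(0)
model = GaussianModel()
opt = torch.optim.SGD(model.parameters(), lr=1e-3)
losses = []
for i in range(10):
    opt.zero_grad()
    x = torch.randn(10)
    # Assumption: minimize the negative log-likelihood
    # (the question's code minimizes the log-likelihood itself).
    loss = -model(x)
    loss.backward()  # succeeds on every iteration
    opt.step()
    losses.append(loss.item())
```

Because forward() constructs both Normal objects from the current self.v1, every backward() only passes through graph nodes created in that same iteration, and the optimizer's updates to v1 show up in both distributions automatically.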

As an aside, you are initializing g1 and g2 with a scale (standard
deviation) of 0.0. I would expect this to lead to NaNs in your first
forward pass. Up-to-date versions of PyTorch will flag this:

>>> import torch
>>> torch.__version__
'1.9.0'
>>> d = torch.distributions.Normal (torch.tensor ([4.0]), torch.zeros (1), validate_args = False)
>>> d
Normal(loc: tensor([4.]), scale: tensor([0.]))
>>> d.log_prob (torch.tensor ([1.0]))
tensor([nan])
>>> d = torch.distributions.Normal (torch.tensor ([4.0]), torch.zeros (1))
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "C:\Users\LisaBrown\Documents\admin\programs\Miniconda3\lib\site-packages\torch\distributions\normal.py", line 50, in __init__
    super(Normal, self).__init__(batch_shape, validate_args=validate_args)
  File "C:\Users\LisaBrown\Documents\admin\programs\Miniconda3\lib\site-packages\torch\distributions\distribution.py", line 53, in __init__
    raise ValueError("The parameter {} has invalid values".format(param))
ValueError: The parameter scale has invalid values
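A positive initial value fixes the first forward pass, but an optimizer step can still push v1 to zero or below later on. One common remedy, which is not from this thread, is to optimize an unconstrained parameter and map it to a positive scale, for example with torch.nn.functional.softplus. A minimal sketch; PositiveScaleModel and raw_v1 are hypothetical names:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributions as D

class PositiveScaleModel(nn.Module):
    """Hypothetical variant of GaussianModel with an unconstrained parameter."""

    def __init__(self):
        super().__init__()
        # raw_v1 may take any real value; zero is a harmless starting point here.
        self.raw_v1 = nn.Parameter(torch.zeros(1))
        self.register_buffer("loc", torch.tensor([4.0]))

    def scale(self):
        # softplus(z) = log(1 + exp(z)) maps reals to positive values
        # (up to float underflow for very negative z); softplus(0) = log(2).
        return F.softplus(self.raw_v1)

    def forward(self, x):
        s = self.scale()
        g1 = D.Normal(self.loc, s)
        g2 = D.Normal(self.loc, 2 * s)
        return torch.mean(g1.log_prob(x) + g2.log_prob(x))

model = PositiveScaleModel()
loss = model(torch.randn(10))
loss.backward()  # gradients reach raw_v1 through softplus
```

The shared-parameter structure is the same as before; only the quantity the optimizer updates has changed, so default argument validation in Normal no longer rejects the starting point.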

Best.

K. Frank


Thank you, Frank! It works!