How to use transformed parameters in a PyTorch model (Module)?

(This question is also posted on https://datascience.stackexchange.com/q/80352/60226. With permission, I will propagate here any answers I get there.)

Sometimes we might want to limit the domain of a parameter from the whole real line to, e.g., the positive reals. This may be accomplished by using a transform like

par -> par**2 + epsilon

where epsilon is a constant governing how close to zero the value can get during training.

(Similarly, one can use the torch.sigmoid function to bound the domain of the parameter from both ends.)
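
For illustration, a minimal sketch of both kinds of transforms (the function names are just for this example):

import torch

epsilon = 0.001

def to_positive(par):
    # squares the raw parameter, so the result lives in [epsilon, +inf)
    return par**2 + epsilon

def to_unit_interval(par):
    # sigmoid bounds the raw parameter into (0, 1)
    return torch.sigmoid(par)

Either way, the optimizer keeps updating the unconstrained raw parameter, and the constrained value is whatever the transform returns.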

The problem is that, logically, the transformation belongs to the definition (initialization) of the parameter, whereas the actual function that performs the transform must be called inside the forward method; otherwise we get

RuntimeError: Trying to backward through the graph a second time, but the saved intermediate results have already been freed. Specify retain_graph=True when calling backward the first time.

when forward is called for the second time. For example:

import torch

class GenericModule(torch.nn.Module):
    def __init__(self, weight):
        super(GenericModule, self).__init__()
        self.weight = weight

    def forward(self, X):
        return X * self.weight

class SpecificModule1(torch.nn.Module):
    def __init__(self):
        super(SpecificModule1, self).__init__()
        epsilon = 0.001
        self.specific_coef = torch.nn.Parameter(torch.Tensor([0.]))
        # the transform is evaluated only once, at construction time
        self.worker = GenericModule(self.specific_coef**2 + epsilon)

    def forward(self, X):
        return self.worker(X)

>>> model = SpecificModule1()
>>> list(model.named_parameters())
[('specific_coef',
  Parameter containing:
  tensor([0.], requires_grad=True))]
>>> X = 1.
>>> optim = torch.optim.RMSprop(model.parameters())
>>> loss = model(X)
>>> loss.backward()
>>> loss = model(X)
>>> loss.backward()
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-114-52a0569421b1> in <module>
----> 1 loss.backward()

~/venv/lib/python3.8/site-packages/torch/tensor.py in backward(self, gradient, retain_graph, create_graph)
    183                 products. Defaults to ``False``.
    184         """
--> 185         torch.autograd.backward(self, gradient, retain_graph, create_graph)
    186 
    187     def register_hook(self, hook):

~/venv/lib/python3.8/site-packages/torch/autograd/__init__.py in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables)
    123         retain_graph = create_graph
    124 
--> 125     Variable._execution_engine.run_backward(
    126         tensors, grad_tensors, retain_graph, create_graph,
    127         allow_unreachable=True)  # allow_unreachable flag

RuntimeError: Trying to backward through the graph a second time, but the saved intermediate results have already been freed. Specify retain_graph=True when calling backward the first time.

The usual way of dealing with such a situation is to hardwire the transformation into the forward function of GenericModule. A minimal sketch of that approach, assuming the same square-plus-epsilon transform as above:
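
class GenericModule(torch.nn.Module):
    # variant with the transform hardwired into forward
    def __init__(self, raw_weight, epsilon=0.001):
        super().__init__()
        self.raw_weight = raw_weight
        self.epsilon = epsilon

    def forward(self, X):
        # the transform is re-evaluated on every forward call,
        # so each backward pass gets a fresh graph
        return X * (self.raw_weight**2 + self.epsilon)

Unfortunately, this has two problems.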

  1. Hardwiring the transform into the child’s forward function violates the DRY (Don’t Repeat Yourself) principle, because you may have to spawn an additional GenericModuleWithRightBoundedPar that cannot reuse the code in GenericModule.
  2. It prevents you from sharing (or simply forwarding) the transformed parameter with another module that is a child of GenericModule, unless you pass the transformed value through GenericModule’s forward method down to that child (see the sketch after this list). This is very awkward, because the child must know beforehand that the parameter it uses may be transformed, so it has to accept it as an argument of its forward method instead of initializing it in the constructor, as the typical design patterns for writing Modules would suggest.
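
A minimal sketch of the workaround described in point 2 (the names InnerModule and SpecificModule2 are just for illustration); the transformed value has to be threaded through every forward call:

import torch

class InnerModule(torch.nn.Module):
    # must accept the (possibly transformed) parameter in forward,
    # instead of receiving it in its constructor
    def forward(self, X, weight):
        return X * weight

class GenericModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.inner = InnerModule()

    def forward(self, X, weight):
        # merely forwards the transformed value to its child
        return self.inner(X, weight)

class SpecificModule2(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.epsilon = 0.001
        self.specific_coef = torch.nn.Parameter(torch.Tensor([0.]))
        self.worker = GenericModule()

    def forward(self, X):
        # the transform has to be applied here, on every call
        return self.worker(X, self.specific_coef**2 + self.epsilon)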

The questions:

  1. Did I miss some simple solution to this problem?
  2. Has this problem already been solved by someone, so that I can just import a library that derives its own classes from torch.nn.Parameter and torch.nn.Module with support for transformed parameters?