I develop differentiable simulations with PyTorch and I want to share weights (torch.nn.Parameter's) between multiple components (torch.nn.Module's). The crux is that one component should use the inverse (negative) weights of another. Consider the following Module class:
class Component(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.attr = torch.tensor(1.)

    def forward(self, x):
        return self.attr * x
Now if I have two such components c1 and c2, I could do "normal" weight sharing easily via
c1.attr = c2.attr = torch.nn.Parameter(c1.attr)
However I need that c2.attr == -c1.attr at all times. If I do
c1.attr = torch.nn.Parameter(c1.attr)
c2.attr = -c1.attr
then the operation -c1.attr creates a new node in the graph and doesn't refer to the original parameter anymore. If I now use the two tensors c1.attr and c2.attr during multiple epochs, the value of c2.attr won't update since it is "stuck" at the initially captured value of -c1.attr.
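A minimal snippet (reusing the Component class from above) that demonstrates the staleness:

c1, c2 = Component(), Component()
c1.attr = torch.nn.Parameter(c1.attr)
c2.attr = -c1.attr                      # negation is evaluated once, right here

with torch.no_grad():
    c1.attr += 1.0                      # stand-in for an optimizer update

print(c1.attr.item(), c2.attr.item())   # 2.0 -1.0 -> c2.attr kept the stale value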
What I do right now is update all these weight-sharing relationships after each call to optimizer.step(). This, however, moves part of the model's logic outside of its own definition and requires the user to remember to call that update method every epoch; otherwise the weight sharing is ineffective.
Hence I would like to ask whether there exists a way to specify some kind of deferred operation, that is, an operation that is only evaluated when it is used elsewhere in another computation. So for example:
c2.attr = defer_mul(-1, c1.attr)
would compute -1 * c1.attr only when requested in another computation (and not cache any results). So x * c2.attr would result in x * (-1 * c1.attr), using the up-to-date value of c1.attr every time c2.attr is used.
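To make the desired behaviour concrete, here is a rough property-based sketch that achieves it for this specific case (NegatedView is a name I made up purely for illustration); what I am really asking for is a more general mechanism along the lines of defer_mul above:

class NegatedView(torch.nn.Module):
    # Hypothetical illustration only: always uses the negation of the
    # wrapped component's parameter, recomputed lazily on every access.
    def __init__(self, source):
        super().__init__()
        self.source = source          # reference to the other component, no value copy

    @property
    def attr(self):
        return -self.source.attr      # evaluated each time attr is read

    def forward(self, x):
        return self.attr * x

c1 = Component()
c1.attr = torch.nn.Parameter(c1.attr)
c2 = NegatedView(c1)

with torch.no_grad():
    c1.attr += 1.0
print(c2.attr)                        # -2.0 -- always tracks the current value of c1.attr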
Here is a more complete code example to illustrate the situation:
import torch


class Component(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.attr = torch.tensor(1.)

    def forward(self, x):
        return self.attr * x


class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.config = torch.tensor(-2.)
        self.c1 = Component()
        self.c2 = Component()
        self.c1.attr = torch.nn.Parameter(self.c1.attr)
        # self.c2.attr = -self.c1.attr  # stays at `attr = -1` forever
        self.update()  # update manually instead

    def forward(self):
        return self.c2(self.c1(self.config))

    def update(self):
        self.c2.attr = -self.c1.attr


model = Model()
optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
expected = torch.tensor(32.)
for epoch in range(25):
    optimizer.zero_grad()
    value = model()
    loss = (value - expected) ** 2
    loss.backward()
    optimizer.step()
    model.update()  # need to update weight sharing relationships at every epoch
    print(f'c1 = {model.c1.attr}, c2 = {model.c2.attr}')