Deferred operations for advanced weight sharing

I develop differentiable simulations with PyTorch and I want to share weights (torch.nn.Parameter objects) between multiple components (torch.nn.Module objects). The crux is that one component should use the negated weights (the additive inverse) of another.

Consider the following Module class:

import torch


class Component(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.attr = torch.tensor(1.)

    def forward(self, x):
        return self.attr * x

Now, if I have two such components c1 and c2, I can easily do “normal” weight sharing via

c1.attr = c2.attr = torch.nn.Parameter(c1.attr)
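A quick check of why this works: after the chained assignment both modules hold the very same Parameter object, so an in-place update (which is what optimizer.step() does) is visible through both. A minimal sketch, using add_ under no_grad as a stand-in for an optimizer step:

```python
import torch


class Component(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.attr = torch.tensor(1.)

    def forward(self, x):
        return self.attr * x


c1, c2 = Component(), Component()
c1.attr = c2.attr = torch.nn.Parameter(c1.attr)

# Both attributes are one and the same Parameter object.
print(c1.attr is c2.attr)  # True

# In-place update, standing in for optimizer.step().
with torch.no_grad():
    c1.attr.add_(0.5)
print(c2.attr.item())  # 1.5
```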

However, I need c2.attr == -c1.attr to hold at all times. If I do

c1.attr = torch.nn.Parameter(c1.attr)
c2.attr = -c1.attr

then the expression -c1.attr is evaluated once, at assignment time. It creates a new (non-leaf) tensor in the graph that holds a snapshot of the negated value instead of following the parameter. Since the optimizer updates c1.attr in place, if I now use c1.attr and c2.attr over multiple epochs, c2.attr won’t update: it is “stuck” at the value -c1.attr had when it was assigned.
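The snapshot behaviour can be reproduced in a few lines. The in-place add_ below stands in for optimizer.step(), which also modifies parameters in place:

```python
import torch

p = torch.nn.Parameter(torch.tensor(1.))
neg = -p  # evaluated right here, exactly once

with torch.no_grad():
    p.add_(0.5)  # in-place update, like an optimizer step

print(p.item(), neg.item())  # 1.5 -1.0  (neg still holds the old value)
```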

What I do right now is re-establish all these weight-sharing relationships after each call to optimizer.step(). However, this moves part of my model’s logic outside of its own definition and relies on the user remembering to call that update method after every step; otherwise the weight sharing is ineffective.

Hence I would like to ask whether there is a way to specify some kind of deferred operation, that is, an operation that is evaluated only when its result is used in another computation. For example:

c2.attr = defer_mul(-1, c1.attr)

would compute -1 * c1.attr only when requested by another computation (without caching the result). So x * c2.attr would evaluate to x * (-1 * c1.attr), using the up-to-date value of c1.attr every time c2.attr is used.
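In plain Python, the described semantics roughly correspond to a read-only property. The sketch below is only an illustration under that assumption; NegatedComponent and _source are hypothetical names, and it requires swapping c2 for a dedicated subclass rather than keeping it a plain Component:

```python
import torch


class Component(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.attr = torch.tensor(1.)

    def forward(self, x):
        return self.attr * x


class NegatedComponent(torch.nn.Module):
    """Hypothetical variant of Component: attr is -source.attr, recomputed on every read."""

    def __init__(self, source):
        super().__init__()
        # Plain tuple, so `source` is not registered a second time as a submodule.
        self._source = (source,)

    @property
    def attr(self):
        # Deferred: evaluated on each access and never cached, so it reflects the
        # current value of source.attr and gradients flow back to that parameter.
        return -self._source[0].attr

    def forward(self, x):
        return self.attr * x


c1 = Component()
c1.attr = torch.nn.Parameter(c1.attr)
c2 = NegatedComponent(c1)

with torch.no_grad():
    c1.attr.add_(0.5)  # stands in for optimizer.step()
print(c2.attr.item())  # -1.5

c2(torch.tensor(3.)).backward()
print(c1.attr.grad.item())  # -3.0
```

The trade-off is that the sharing is hard-wired into a separate class, so c2 can no longer be an ordinary Component.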

Here is a more complete code example to illustrate the situation:

import torch


class Component(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.attr = torch.tensor(1.)

    def forward(self, x):
        return self.attr * x


class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.config = torch.tensor(-2.)
        self.c1 = Component()
        self.c2 = Component()

        self.c1.attr = torch.nn.Parameter(self.c1.attr)
        # self.c2.attr = -self.c1.attr  # stays at `attr = -1` forever
        self.update()  # update manually instead

    def forward(self):
        return self.c2(self.c1(self.config))

    def update(self):
        self.c2.attr = -self.c1.attr


model = Model()
optimizer = torch.optim.SGD(model.parameters(), lr=0.001)

expected = torch.tensor(32.)
for epoch in range(25):
    optimizer.zero_grad()
    value = model()
    loss = (value - expected) ** 2
    loss.backward()
    optimizer.step()
    model.update()  # need to update weight sharing relationships at every epoch

    print(f'c1 = {model.c1.attr}, c2 = {model.c2.attr}')
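For reference, recent PyTorch releases include torch.nn.utils.parametrize, whose register_parametrization recomputes a tensor from an underlying stored tensor every time the attribute is read, which appears to match the deferred semantics described above. The sketch below assumes that registering a parametrization on a Parameter shared with c1 keeps that same Parameter object as the stored original (so there is still only one trainable tensor); Negate is a hypothetical name for the parametrization module:

```python
import torch
import torch.nn.utils.parametrize as parametrize


class Component(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.attr = torch.tensor(1.)

    def forward(self, x):
        return self.attr * x


class Negate(torch.nn.Module):
    # Hypothetical parametrization: maps the stored tensor to its negation.
    def forward(self, x):
        return -x


class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.config = torch.tensor(-2.)
        self.c1 = Component()
        self.c2 = Component()

        self.c1.attr = torch.nn.Parameter(self.c1.attr)
        # Share the same Parameter, then expose it in c2 through Negate.
        # From now on, reading self.c2.attr evaluates -self.c1.attr on the spot.
        self.c2.attr = self.c1.attr
        parametrize.register_parametrization(self.c2, "attr", Negate())

    def forward(self):
        return self.c2(self.c1(self.config))


model = Model()
optimizer = torch.optim.SGD(model.parameters(), lr=0.001)

expected = torch.tensor(32.)
for epoch in range(25):
    optimizer.zero_grad()
    value = model()
    loss = (value - expected) ** 2
    loss.backward()
    optimizer.step()  # no manual update() call needed

print(f'c1 = {model.c1.attr.item()}, c2 = {model.c2.attr.item()}')
```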