Share parameter between two modules in a nn.Sequential

I’m trying to share a single parameter between two modules used in a nn.Sequential. Each module uses the parameter differently. For example, imagine a function, parametrized by w:

F(x) = x+w

Let’s say I’m trying to find w so that x + w - w/2 is minimal. I can express this as Sequential(F(w), F(-w/2)).

This is what I tried:

import torch
import torch.nn as nn
import torch.optim as optim

class F(nn.Module):
    def __init__(self, w):
        super().__init__()
        self.w = w

    def forward(self, x):
        return x + self.w

w = nn.Parameter(torch.tensor([5.0]), requires_grad=True)

model = nn.Sequential(F(w), F(-w/2))
optimizer = optim.Adam([w], lr=1e-1)

for epoch in range(10):
    loss = torch.pow(model(torch.tensor([2.0])) - 10.0, 2)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    
    print(epoch, loss.item())

and I’m getting:

RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.

Is there a trick to make this kind of parameter sharing work? Of course this is a minimal example, my actual function is much more complex.

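The issue is caused by performing a differentiable operation on w in F(-w/2). The self.w attribute of the second module is no longer a trainable nn.Parameter but the output of the negation and division, i.e. a non-leaf tensor that carries its own piece of the autograd graph. Because -w/2 is evaluated only once, when the model is constructed, every iteration reuses that same piece of graph. The first loss.backward() frees its saved intermediate values, so the backward pass in the second iteration fails with the error above.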
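To check this, one can inspect the tensor that the second module receives. This is a minimal sketch; the name `half` and the stripped-down `F` are only used here for illustration:

```python
import torch
import torch.nn as nn

class F(nn.Module):
    def __init__(self, w):
        super().__init__()
        self.w = w

    def forward(self, x):
        return x + self.w

w = nn.Parameter(torch.tensor([5.0]))
half = -w / 2  # evaluated once, outside the training loop

# w is a leaf parameter; half is an ordinary tensor produced by autograd ops
print(isinstance(w, nn.Parameter), w.is_leaf)
print(isinstance(half, nn.Parameter), half.is_leaf)
print(half.grad_fn is not None)

# nn.Module only registers nn.Parameter attributes, so F(half) has no parameters
print(len(list(F(w).parameters())), len(list(F(half).parameters())))
```

The second module therefore holds a precomputed result rather than the parameter itself, which is why its graph is not rebuilt in each iteration.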

You could work around this issue by, e.g., passing a scale argument to __init__ and applying it inside forward, so that the operation on w is recomputed in every iteration:

class F(nn.Module):
    def __init__(self, w, scale=1.0):
        super().__init__()
        self.w = w
        self.scale = scale

    def forward(self, x):
        # The scaling runs on every forward pass, so a fresh graph is built each iteration
        return x + self.w * self.scale

w = nn.Parameter(torch.tensor([5.0]), requires_grad=True)

model = nn.Sequential(F(w), F(w, scale=-0.5))
optimizer = optim.Adam([w], lr=1e-1)

for epoch in range(10):
    loss = torch.pow(model(torch.tensor([2.0])) - 10.0, 2)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    
    print(epoch, loss.item())
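As a quick sanity check that both modules really share the one parameter, the following sketch reuses the class above and looks at what the model registers and at the gradient that reaches w. The input value 2.0 is taken from the example; the name `params` is only illustrative:

```python
import torch
import torch.nn as nn

class F(nn.Module):
    def __init__(self, w, scale=1.0):
        super().__init__()
        self.w = w
        self.scale = scale

    def forward(self, x):
        return x + self.w * self.scale

w = nn.Parameter(torch.tensor([5.0]))
model = nn.Sequential(F(w), F(w, scale=-0.5))

# parameters() skips duplicates, so the shared w shows up exactly once
params = list(model.parameters())
print(len(params), params[0] is w)

# out = 2 + w - 0.5 * w, so d(out)/dw = 1 - 0.5 = 0.5
out = model(torch.tensor([2.0]))
out.sum().backward()
print(out.item(), w.grad.item())
```

Since the shared parameter is registered in both modules but reported only once, passing model.parameters() to the optimizer should behave the same as passing [w] here.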