How to backprop through a network that synthesizes another network dynamically?

I’m trying to build a synthesizer network that outputs the weights and biases of another network, given some input-output examples. The examples are generated by a reference network and the idea is to train a synthesizer that can infer the reference distribution.

Here’s a diagram that shows how the component networks are connected:

The challenge lies in how to perform backpropagation from the synthesized network through the synthesizer. I can’t find a way to fill in that gap between the synthesizer and synthesized networks during backpropagation. My current attempt that fails with the message:

---> ref_net.fc1.weight.set_(nn.Parameter(pred_net[:, 0:2].view(2, 1)))
     ref_net.fc1.bias.set_(nn.Parameter(pred_net[:, 2:4]))
     ref_net.fc2.weight.set_(nn.Parameter(pred_net[:, 4:6].view(1, 2)))

RuntimeError: a leaf Variable that requires grad is being used in an in-place operation.

I tried other variants besides set_ like direct assignments of the weights and biases. However, they all cut off the backpropagation process such that training cannot proceed and the loss remains constant. Here’s the entire code snippet:

NUM_EXAMPLES = 10
RANGE = 10

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

class ReferenceNet(nn.Module):
    def __init__(self):
        super(ReferenceNet, self).__init__()
        self.fc1 = nn.Linear(1, 2, bias=True)  # First layer: 1 input, 2 hidden nodes
        self.fc2 = nn.Linear(2, 1, bias=True)  # Second layer: 2 hidden nodes, 1 output

    def forward(self, x):
        x = F.relu(self.fc1(x))  # ReLU activation on the first layer
        x = self.fc2(x)  # No activation on the output layer
        return x

# Instantiate the test network
ref_net = ReferenceNet()

# Initialize weights randomly (this is done by default in PyTorch but shown here explicitly for clarity)
for param in ref_net.parameters():
    nn.init.normal_(param, mean=0.0, std=0.1)

class SynthesizerNet(nn.Module):
    def __init__(self):
        super(SynthesizerNet, self).__init__()
        self.fc1 = nn.Linear(NUM_EXAMPLES * 2, NUM_EXAMPLES * 1)
        self.fc2 = nn.Linear(NUM_EXAMPLES * 1, 7)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# Instantiate the reason network
syn_net = SynthesizerNet()

# Training loop
optimizer = optim.Adam(syn_net.parameters(), lr=0.001)
criterion = nn.MSELoss()

# Create random input-output pairs from ReferenceNet
inputs = torch.rand(NUM_EXAMPLES, 1) * RANGE * 2 - RANGE
print(f"inputs.shape: {inputs.shape}")
with torch.no_grad():
    outputs = ref_net(inputs)
print(f"outputs.shape: {outputs.shape}")

# Prepare data by concatenating inputs and outputs into a flat vector
train_data = torch.cat((inputs, outputs), dim=0)
print(f"train_data.shape: {train_data.shape}")

# Re-initialize reference network for training
ref_net = ReferenceNet()

print(f"outputs: {outputs}")

for epoch in range(500):
    optimizer.zero_grad()
    # synthesize a network based on input output pairs
    pred_net = syn_net(train_data.view(1, NUM_EXAMPLES * 2))

    print(f"pred_net: {pred_net}")

    # set the weights of the reference network to the synthesized network
    ref_net.fc1.weight.set_(nn.Parameter(pred_net[:, 0:2].view(2, 1)))
    ref_net.fc1.bias.set_(nn.Parameter(pred_net[:, 2:4]))
    ref_net.fc2.weight.set_(nn.Parameter(pred_net[:, 4:6].view(1, 2)))
    ref_net.fc2.bias.set_(nn.Parameter(pred_net[:, 6:7]))
    
    inputs = train_data[:NUM_EXAMPLES].view(NUM_EXAMPLES, 1)
    outputs_pred = ref_net(inputs)

    print("outputs_pred: ", outputs_pred)

    loss = criterion(outputs, outputs_pred)
    loss.backward()
    optimizer.step()
    print(f'Epoch [{epoch+1}/500], Loss: {loss.item():.4f}')

I solved this problem in a not-so-elegant way by repeating the ReferenceNet layout using F.linear and F.relu:

class FusedNet(nn.Module):
    def __init__(self, syn_net):
        super(FusedNet, self).__init__()
        self.syn_net = syn_net

    def forward(self, train_data):
        inputs = train_data[:NUM_EXAMPLES].view(NUM_EXAMPLES, 1)
        
        # Use syn_net to obtain the weights and biases
        pred_net = self.syn_net(train_data.view(1, NUM_EXAMPLES * 2))
        
        # Manually apply these weights to the layers in ref_net
        x = F.linear(inputs, weight=pred_net[:, 0:2].view(2, 1), bias=pred_net[:, 2:4].view(2))
        x = F.relu(x)
        x = F.linear(x, weight=pred_net[:, 4:6].view(1, 2), bias=pred_net[:, 6:7].view(1))

        return x