Generating one model's parameters with another model

I’m trying to generate one model’s parameters (ActualModel) with another model (ParameterModel), but I’m running into problems with autograd when I backpropagate multiple times.

Here’s an example ActualModel, but this is supposed to be generic:

import torch

class ActualModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(1, 1, 3)

    def forward(self, x):
        return self.conv(x)

ParameterModel wraps ActualModel and freezes its parameters. Its forward() method first computes the parameters, writes them into ActualModel, and then calls ActualModel.forward():

class ParameterModel(torch.nn.Module):
    def __init__(self, actual_model):
        super().__init__()
        # Save actual_model in a list so its parameters aren't auto-registered
        self.actual_model = [actual_model]
        # Compute total number of parameters of the actual_model
        num_parameters = sum(p.numel() for p in actual_model.parameters())
        # Freeze parameters of actual_model
        for p in actual_model.parameters():
            p.requires_grad = False
        # Simple model to generate num_parameters parameters
        self.linear = torch.nn.Linear(1, num_parameters)

    def forward(self, param_input, actual_input):
        # Compute the parameters for the actual model
        param_output = self.linear(param_input)
        # Set the parameters in the actual model
        idx = 0
        for p in self.actual_model[0].parameters():
            # p.set_(x) derivative is not implemented, use p *= 0, p += x instead
            p *= 0
            p += param_output[idx:idx + p.numel()].view_as(p)
            idx += p.numel()
        # Compute the output of the actual model
        return self.actual_model[0](actual_input)
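The idx-based loop in forward() maps consecutive chunks of the flat param_output vector onto each parameter's shape. A minimal standalone sketch of that mapping, using torch.split and view_as (split_flat_vector is a hypothetical helper name, not a PyTorch API), which keeps the chunks attached to the generator's graph without touching any parameter:

```python
import torch

def split_flat_vector(flat: torch.Tensor, params) -> list:
    """Split a 1-D tensor into chunks shaped like each tensor in `params`.

    Hypothetical helper: mirrors the idx-based slicing in
    ParameterModel.forward() but does not modify any parameter.
    """
    params = list(params)
    sizes = [p.numel() for p in params]
    if flat.numel() != sum(sizes):
        raise ValueError(f"expected {sum(sizes)} values, got {flat.numel()}")
    chunks = torch.split(flat, sizes)  # differentiable views into `flat`
    return [chunk.view_as(p) for chunk, p in zip(chunks, params)]


conv = torch.nn.Conv2d(1, 1, 3)           # 9 weights + 1 bias = 10 values
generator = torch.nn.Linear(1, 10)
flat = generator(torch.ones(1))           # shape (10,)
pieces = split_flat_vector(flat, conv.parameters())
print([tuple(t.shape) for t in pieces])   # [(1, 1, 3, 3), (1,)]
```

Because the chunks are views into the generator's output, gradients from any loss computed on them flow back into the generating Linear layer.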

However, when I use this in a simplified training loop, I get an error on the second epoch about backpropagating through the graph a second time, even though each epoch should build a new graph.

Training loop:

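actual_model = ActualModel()
wrapped_model = ParameterModel(actual_model)
optim = torch.optim.Adam(wrapped_model.parameters(), lr=0.07)

for i in range(2):
    print(f'Epoch {i}')
    optim.zero_grad()
    output = wrapped_model(torch.ones(1), torch.ones(1, 1, 3, 3))
    # Reduce to a scalar so backward() needs no gradient argument
    loss = (output - torch.ones(1, 1, 3, 3)).sum()
    loss.backward()  # raises the RuntimeError below on epoch 1
    optim.step()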

Output with error message:

Epoch 0
Epoch 1
RuntimeError: Trying to backward through the graph a second time, 
but the saved intermediate results have already been freed. Specify 
retain_graph=True when calling backward the first time.

Any help on what I’m doing wrong here would be much appreciated.


This is toy code that does exactly the same thing as your example, but is much simpler for the sake of sanity!

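import torch
import torch.nn as nn

class TmpModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Kept in a list so it is not registered as a parameter, and frozen
        self.x = [nn.Parameter(torch.randn(1))]
        self.x[0].requires_grad = False
        self.y = nn.Parameter(torch.randn(1))
        self.z = nn.Parameter(torch.randn(1))

    def forward(self, i1):
        o1 = self.y * i1
        # Overwrite x in place, like p *= 0; p += x in ParameterModel
        self.x[0].zero_()
        self.x[0] += o1
        return self.x[0] * self.z

model_new = TmpModel()
for i in range(2):
    i1 = torch.ones(1)
    out = model_new(i1)
    out.backward()  # fails on iteration 1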

If you traverse the autograd graph (by repeatedly following out.grad_fn.next_functions), you will find the graphs below: the first image is the computation graph for iteration 0 and the second is the computation graph for iteration 1.

Computation graph 1

Computation graph 2

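The dashed red box contains the old nodes for i1 and self.y from iteration 0. Why? An in-place operation on a tensor that already has autograd history is recorded on top of that history instead of replacing it (see this post: link). Because computation graph 2 still contains those old nodes, whose saved tensors were freed by the first backward() call, you get the RuntimeError.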
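The mechanism can be reproduced with just a few tensors. In this sketch (w, three and b are illustrative names, not from the thread), zero_() is applied to a tensor whose graph has already been backpropagated; the node recorded for zero_() still points at the old MulBackward0 node, so the next backward() walks into the freed part of the graph and raises the same RuntimeError:

```python
import torch

w = torch.ones(1, requires_grad=True)
three = torch.full((1,), 3.0)

b = w * three          # MulBackward0 saves `three` to compute dL/dw
b.sum().backward()     # first backward frees that saved tensor
old_node = b.grad_fn

b.zero_()              # in-place op on a tensor with history: recorded on top of it
still_linked = b.grad_fn.next_functions[0][0] is old_node
print("old node still reachable:", still_linked)

try:
    (b * 2).sum().backward()   # reaches old_node, whose saved tensor is gone
    message = "no error"
except RuntimeError as err:
    message = str(err)
print(message.splitlines()[0])
```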

For reference, to debug this I followed .grad_fn.next_functions and ran the code inside a with torch.autograd.set_detect_anomaly(True): block. You can ping me if you need more details about the debugging. :)
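The next_functions traversal can also be scripted instead of done by hand. A rough sketch, using a hypothetical helper collect_nodes that lists the class names of every reachable autograd node; applied to the toy pattern (rewritten with plain tensors), the iteration-1 graph comes out longer than the iteration-0 graph because it still contains the old nodes:

```python
import torch

def collect_nodes(out: torch.Tensor) -> list:
    """Depth-first walk over out.grad_fn.next_functions; returns node class names."""
    seen, names = set(), []
    stack = [out.grad_fn]
    while stack:
        fn = stack.pop()
        if fn is None or fn in seen:   # None marks inputs that need no gradient
            continue
        seen.add(fn)
        names.append(type(fn).__name__)
        # next_functions is a tuple of (node, input_index) pairs
        stack.extend(next_fn for next_fn, _ in fn.next_functions)
    return names


# Toy pattern with plain tensors: x is frozen and overwritten in place each iteration
y = torch.randn(1, requires_grad=True)
z = torch.randn(1, requires_grad=True)
x = torch.zeros(1)
graphs = []
for i in range(2):
    o1 = y * torch.ones(1)
    x.zero_()
    x += o1
    out = x * z
    graphs.append(collect_nodes(out))
    if i == 0:
        out.backward()

print(graphs[0])
print(graphs[1])  # also contains the iteration-0 AddBackward0/MulBackward0 nodes
```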

@Rahul_Chand Thank you very much for the detailed analysis and for explaining how you were able to debug the issue! I understand now where the problem comes from: the in-place zero_ operation does not make autograd “forget” the previous history, even though that history no longer contributes to the gradient.

In your simplified example, I could fix this problem by replacing the in-place operation with the following: self.x[0] = o1
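For reference, a sketch of the toy model with that change (TmpModelFixed is a hypothetical name). forward() now rebinds self.x[0] to the freshly computed o1 instead of mutating the old tensor, so each iteration builds an independent graph and both backward() calls go through:

```python
import torch
import torch.nn as nn

class TmpModelFixed(nn.Module):
    def __init__(self):
        super().__init__()
        # Frozen and kept in a list so it is not registered as a parameter
        self.x = [nn.Parameter(torch.randn(1), requires_grad=False)]
        self.y = nn.Parameter(torch.randn(1))
        self.z = nn.Parameter(torch.randn(1))

    def forward(self, i1):
        o1 = self.y * i1
        # Rebind instead of zero_() / +=: the list now holds a tensor
        # whose history starts fresh in this iteration.
        self.x[0] = o1
        return self.x[0] * self.z


model_new = TmpModelFixed()
for i in range(2):
    out = model_new(torch.ones(1))
    out.backward()
print(model_new.y.grad, model_new.z.grad)
```

After the first call self.x[0] is a plain tensor rather than an nn.Parameter; since it was frozen and kept out of parameters() anyway, that should not matter here.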

However, I have not been able to transfer this to my version. There, the wrapped model has multiple Parameters whose exact structure I don’t know (I want to keep this generic), so I can only access them through actual_model.parameters() or actual_model.state_dict().

Do you have any idea how I could achieve this?
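One possible direction, not taken from this thread: torch.func.functional_call (available in PyTorch 2.0 and later) runs a module's forward() with tensors taken from a dictionary keyed by the names in named_parameters(), so the wrapped model never has to be mutated in place and each call builds a fresh graph, much like the self.x[0] = o1 fix. The sketch below assumes the ActualModel from the question (repeated so the block runs on its own); ParameterModelFunctional is a hypothetical name.

```python
import torch
from torch.func import functional_call

class ActualModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(1, 1, 3)

    def forward(self, x):
        return self.conv(x)


class ParameterModelFunctional(torch.nn.Module):
    """Generates all parameters of a wrapped module without mutating it."""

    def __init__(self, actual_model):
        super().__init__()
        # Kept in a list so its own parameters are not registered here
        self.actual_model = [actual_model]
        # Names and shapes are read generically from named_parameters()
        self.shapes = {name: p.shape for name, p in actual_model.named_parameters()}
        num_parameters = sum(p.numel() for p in actual_model.parameters())
        self.linear = torch.nn.Linear(1, num_parameters)

    def forward(self, param_input, actual_input):
        flat = self.linear(param_input)
        sizes = [shape.numel() for shape in self.shapes.values()]
        chunks = torch.split(flat, sizes)
        # New tensors on every call: no in-place writes, so no stale history
        params = {name: chunk.view(shape)
                  for (name, shape), chunk in zip(self.shapes.items(), chunks)}
        return functional_call(self.actual_model[0], params, (actual_input,))


wrapped_model = ParameterModelFunctional(ActualModel())
optim = torch.optim.Adam(wrapped_model.parameters(), lr=0.07)
losses = []
for i in range(2):
    optim.zero_grad()
    output = wrapped_model(torch.ones(1), torch.ones(1, 1, 3, 3))
    loss = (output - torch.ones(1, 1, 3, 3)).sum()
    loss.backward()  # succeeds on every epoch
    optim.step()
    losses.append(loss.item())
```

Since only named_parameters() is used to discover names and shapes, the wrapper does not need to know the structure of the wrapped model. Buffers (for example BatchNorm running statistics) are not generated here and would fall back to the module's own values.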