Computing Gradients of Weighted State Averages in PyTorch Models

Hi,

I have two trained models with the same architecture but different performance. I want to construct a new model whose state is a weighted average of their states, with weights s1 and s2 (s1 + s2 = 1), and then compute the gradients of the new model’s loss with respect to s1 and s2. Here’s my current approach:

  1. Create a simple model:

    import torch
    import torch.nn as nn
    
    class MultiplyByWeight(nn.Module):
        def __init__(self, input_size, output_size):
            super(MultiplyByWeight, self).__init__()
            self.linear = nn.Linear(input_size, output_size)
            
        def forward(self, x):
            return self.linear(x)

    model = MultiplyByWeight(1, 1)  # one input feature, one output
    
  2. Define the states:

    state_1 = {
        'linear.weight': 2.4,
        'linear.bias': 0.1
    }
    state_2 = {
        'linear.weight': 0,
        'linear.bias': -0.5,
    }
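    # With real checkpoints, state_1 and state_2 would be full tensor state
    # dicts rather than scalars; a sketch, with model_a and model_b standing
    # in for the two trained models:
    # state_1 = {k: v.clone() for k, v in model_a.state_dict().items()}
    # state_2 = {k: v.clone() for k, v in model_b.state_dict().items()}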
    
  3. Set the weights of the states:

    state_weights = torch.tensor([0.5, 0.5], requires_grad=True)
    
  4. Set the weighted average of the states as the new model’s state. This is where I run into a problem: whether I load a new state_dict or modify the tensors from named_parameters in place, the values end up copied into the parameters, which are leaf tensors, so the computation graph needed to calculate gradients with respect to state_weights is lost.

    • With load_state_dict:

      d = model.state_dict()
      for k in d:
          # this sum is connected to state_weights through the graph...
          d[k] = d[k] + state_1[k] * state_weights[0] + state_2[k] * state_weights[1]
      # ...but load_state_dict copies the values into the parameters under
      # torch.no_grad(), so the graph is discarded
      model.load_state_dict(d)
      
    • With named_parameters:

      for k, param in model.named_parameters():
          # .data bypasses autograd entirely, so no graph is recorded at all
          param.data.add_(state_1[k] * state_weights[0])
          param.data.add_(state_2[k] * state_weights[1])
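
    In both cases the break is easy to confirm, reusing d from the load_state_dict attempt above: the blended tensor still carries a grad_fn, while the loaded parameter does not.

      print(d['linear.weight'].grad_fn)   # <AddBackward0 ...>: graph to state_weights exists
      print(model.linear.weight.grad_fn)  # None: load_state_dict copied the values under no_grad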
      
  5. Calculate loss and gradients:

    def loss_pt(y_hat, y):
        return (y_hat - y) ** 2
    
    r = model(torch.tensor([1.]))
    l = loss_pt(r, 2)
    l.backward()
    state_weights.is_leaf, state_weights.grad
    # (True, None): state_weights is still a graph leaf, but no gradient ever reached it
    

I would greatly appreciate any help or suggestions on how to tackle this issue. I tried searching the PyTorch source code for a quick workaround, but there doesn’t seem to be a straightforward one.

You might be able to use the functional_call approach as explained here.
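
For reference, here is a minimal sketch of that idea with torch.func.functional_call (torch.nn.utils.stateless.functional_call on versions before 2.0), using the same toy model as in the question. The blended tensors are built outside the module and passed in for the forward pass, so the graph back to state_weights stays intact:

    import torch
    import torch.nn as nn
    from torch.func import functional_call

    # same toy model as in the question
    class MultiplyByWeight(nn.Module):
        def __init__(self, input_size, output_size):
            super().__init__()
            self.linear = nn.Linear(input_size, output_size)

        def forward(self, x):
            return self.linear(x)

    model = MultiplyByWeight(1, 1)

    # the states as tensors shaped like the parameters
    state_1 = {'linear.weight': torch.tensor([[2.4]]), 'linear.bias': torch.tensor([0.1])}
    state_2 = {'linear.weight': torch.tensor([[0.0]]), 'linear.bias': torch.tensor([-0.5])}

    state_weights = torch.tensor([0.5, 0.5], requires_grad=True)

    # blend outside the module so the graph to state_weights is preserved
    blended = {k: state_1[k] * state_weights[0] + state_2[k] * state_weights[1]
               for k in state_1}

    # forward pass with the blended tensors standing in for the parameters
    r = functional_call(model, blended, (torch.tensor([1.]),))
    l = ((r - 2) ** 2).sum()
    l.backward()

    print(state_weights.grad)  # tensor([-5., 1.]) for these values

Because functional_call only substitutes the tensors for the duration of the call, nothing is copied into leaf parameters and backward() can reach state_weights.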
