# Computing Gradients of Weighted State Averages in PyTorch Models

Hi,

I have two trained models with the same architecture but different performance. I want to construct a new model whose state is the weighted average of their states, with weights s1 and s2 (s1 + s2 = 1), and compute the gradients of the new model's loss with respect to s1 and s2. Here's my current approach:

1. Create a simple model:

```python
import torch
import torch.nn as nn

class MultiplyByWeight(nn.Module):
    def __init__(self, input_size, output_size):
        super(MultiplyByWeight, self).__init__()
        self.linear = nn.Linear(input_size, output_size)

    def forward(self, x):
        return self.linear(x)

# One input and one output feature, matching the scalar states below
model = MultiplyByWeight(1, 1)
```
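Because the states in step 2 hold single scalars, the sketch below assumes `input_size = output_size = 1` and prints the parameter names and shapes that any averaged state has to match.

```python
import torch
import torch.nn as nn

class MultiplyByWeight(nn.Module):
    def __init__(self, input_size, output_size):
        super().__init__()
        self.linear = nn.Linear(input_size, output_size)

    def forward(self, x):
        return self.linear(x)

# Assumption: one input feature and one output feature, matching the scalar states.
model = MultiplyByWeight(1, 1)
shapes = {name: tuple(p.shape) for name, p in model.named_parameters()}
print(shapes)  # {'linear.weight': (1, 1), 'linear.bias': (1,)}

out = model(torch.tensor([1.0]))
print(out.shape)  # torch.Size([1])
```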
2. Define the states:

```python
state_1 = {
    'linear.weight': 2.4,
    'linear.bias': 0.1
}
state_2 = {
    'linear.weight': 0,
    'linear.bias': -0.5,
}
```
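`nn.Linear` stores `weight` as a 2-D tensor and `bias` as a 1-D tensor, so the scalar entries have to be broadcast before they can stand in for parameters. A sketch using a hypothetical helper `to_tensor_state`, assuming the 1-in/1-out model; `torch.full_like` also takes the parameter's float dtype, so the integer `0` in `state_2` does not produce an integer tensor.

```python
import torch
import torch.nn as nn

class MultiplyByWeight(nn.Module):
    def __init__(self, input_size, output_size):
        super().__init__()
        self.linear = nn.Linear(input_size, output_size)

    def forward(self, x):
        return self.linear(x)

model = MultiplyByWeight(1, 1)
state_1 = {'linear.weight': 2.4, 'linear.bias': 0.1}
state_2 = {'linear.weight': 0, 'linear.bias': -0.5}

def to_tensor_state(state, model):
    """Hypothetical helper: fill a tensor shaped like each matching parameter."""
    params = dict(model.named_parameters())
    # full_like copies shape and dtype (float32) but not requires_grad
    return {k: torch.full_like(params[k], v) for k, v in state.items()}

t1 = to_tensor_state(state_1, model)
t2 = to_tensor_state(state_2, model)
print({k: tuple(v.shape) for k, v in t1.items()})
```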
3. Set the weights of the states:

```python
state_weights = torch.tensor([0.5, 0.5], requires_grad=True)
```
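As written, `state_weights` holds two independent leaf entries, so nothing keeps s1 + s2 = 1 once they are updated. One option (an assumption, not part of the original setup) is to make only s1 a leaf and derive s2 = 1 - s1; by the chain rule, autograd then reports dL/ds1 - dL/ds2 as the gradient of s1. A toy loss that is linear in the weights stands in for the model loss here.

```python
import torch

# Assumption: only s1 is a free parameter; s2 = 1 - s1 by construction.
s1 = torch.tensor(0.5, requires_grad=True)
state_weights = torch.stack([s1, 1 - s1])  # sums to 1 for any value of s1

# Toy loss, linear in the weights: dL/ds1 = 3 and dL/ds2 = 5 if they were independent.
loss = 3.0 * state_weights[0] + 5.0 * state_weights[1]
loss.backward()
print(s1.grad)  # tensor(-2.), i.e. 3 - 5
```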
4. Set the weighted average of the states as the new model's state. This is where I run into a problem: whether I update the `state_dict` or modify `named_parameters`, the model's parameters remain detached leaf tensors, so the computation graph needed to compute gradients with respect to `state_weights` is lost.

With `state_dict`:

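```python
d = model.state_dict()
for k in d:
    # Weighted average of the two states, broadcast to the parameter's shape
    d[k] = (state_1[k] * state_weights[0] + state_2[k] * state_weights[1]).expand_as(d[k])
model.load_state_dict(d)  # copies the values in place; the graph to state_weights is lost
```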
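The loss of the graph can be checked directly. This sketch repeats the `state_dict` attempt with a 1-in/1-out model and shows that after `load_state_dict` the parameters are still leaves, so `state_weights.grad` stays `None` after `backward()`. It relies on `load_state_dict` copying values into the existing parameters rather than keeping the blended tensors.

```python
import torch
import torch.nn as nn

class MultiplyByWeight(nn.Module):
    def __init__(self, input_size, output_size):
        super().__init__()
        self.linear = nn.Linear(input_size, output_size)

    def forward(self, x):
        return self.linear(x)

model = MultiplyByWeight(1, 1)
state_1 = {'linear.weight': 2.4, 'linear.bias': 0.1}
state_2 = {'linear.weight': 0.0, 'linear.bias': -0.5}
state_weights = torch.tensor([0.5, 0.5], requires_grad=True)

d = model.state_dict()
for k in d:
    blended = state_1[k] * state_weights[0] + state_2[k] * state_weights[1]
    d[k] = blended.expand_as(d[k])  # has a grad_fn pointing back to state_weights
model.load_state_dict(d)  # values are copied in under no_grad; the blended tensors are dropped

print(model.linear.weight.is_leaf, model.linear.weight.grad_fn)  # True None
loss = ((model(torch.tensor([1.0])) - 2) ** 2).sum()
loss.backward()
print(state_weights.grad)  # None: state_weights is not part of the loss graph
```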
With `named_parameters`:

```python
for k, param in model.named_parameters():
    ...
```
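Going through the module attributes runs into the same wall. Assigning a plain tensor to a registered parameter raises a `TypeError`, and wrapping the blended tensor in `nn.Parameter` creates a new leaf with no history. A sketch, where `blended_w` is a hypothetical name for the averaged weight:

```python
import torch
import torch.nn as nn

class MultiplyByWeight(nn.Module):
    def __init__(self, input_size, output_size):
        super().__init__()
        self.linear = nn.Linear(input_size, output_size)

    def forward(self, x):
        return self.linear(x)

model = MultiplyByWeight(1, 1)
state_weights = torch.tensor([0.5, 0.5], requires_grad=True)
# Hypothetical blended weight: 0.5 * 2.4 + 0.5 * 0.0, shaped like linear.weight
blended_w = (2.4 * state_weights[0] + 0.0 * state_weights[1]).expand(1, 1)
print(blended_w.grad_fn is not None)  # True: still connected to state_weights

# Registered parameters only accept nn.Parameter (or None).
raised = False
try:
    model.linear.weight = blended_w
except TypeError:
    raised = True

# nn.Parameter(...) detaches its input, so the new parameter is a fresh leaf.
model.linear.weight = nn.Parameter(blended_w)
print(model.linear.weight.is_leaf, model.linear.weight.grad_fn)  # True None
```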

5. Compute the loss and backpropagate:

```python
def loss_pt(y_hat, y):
    return (y_hat - y) ** 2

r = model(torch.tensor([1.]))
l = loss_pt(r, 2)
l.backward()
```

You might be able to use the `functional_call` approach as explained here.
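A minimal sketch of that approach, assuming PyTorch 2.x where `torch.func.functional_call` is available (older releases expose `torch.nn.utils.stateless.functional_call`). The blended tensors replace the module's parameters for a single call only, so they remain non-leaf nodes connected to `state_weights`. The expected gradient in the comment follows from the chain rule for x = 1, y = 2: the blended model gives r = 1.2 - 0.2 = 1, dL/dr = -2, dr/ds1 = 2.4 + 0.1 and dr/ds2 = 0 - 0.5.

```python
import torch
import torch.nn as nn
from torch.func import functional_call

class MultiplyByWeight(nn.Module):
    def __init__(self, input_size, output_size):
        super().__init__()
        self.linear = nn.Linear(input_size, output_size)

    def forward(self, x):
        return self.linear(x)

def loss_pt(y_hat, y):
    return (y_hat - y) ** 2

model = MultiplyByWeight(1, 1)
state_1 = {'linear.weight': 2.4, 'linear.bias': 0.1}
state_2 = {'linear.weight': 0, 'linear.bias': -0.5}

# Broadcast the scalar states to the parameter shapes (float dtype via full_like).
params = dict(model.named_parameters())
t1 = {k: torch.full_like(params[k], v) for k, v in state_1.items()}
t2 = {k: torch.full_like(params[k], v) for k, v in state_2.items()}

state_weights = torch.tensor([0.5, 0.5], requires_grad=True)

# Blended parameters are ordinary non-leaf tensors, so they stay in the graph.
blended = {k: state_weights[0] * t1[k] + state_weights[1] * t2[k] for k in t1}

# Run the model with the blended tensors substituted for its own parameters.
r = functional_call(model, blended, (torch.tensor([1.0]),))
l = loss_pt(r, 2).sum()
l.backward()
print(state_weights.grad)  # approximately [-5.0, 1.0]
```

Since the model's own parameters are never overwritten, the same `model` object can be reused for any choice of weights, and `state_weights` can also be built from a single leaf s1 if the constraint s1 + s2 = 1 should hold during optimization.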