I want to create a learnable parameter when aggregating several models together.

What I want is to update alpha (a tensor with one entry per client) after each epoch. However, alpha is not updated when I multiply each parameter by its weight (computed from the softmax of alpha) and create a new model using load_state_dict(new_state_dict). Is there any way I can update alpha?

import torch
import torch.nn.functional as F
import torch.optim as optim

# Define the models and the trainable parameter alpha
model1 = torch.nn.Linear(10, 1)
model2 = torch.nn.Linear(10, 1)
model3 = torch.nn.Linear(10, 1)
model4 = torch.nn.Linear(10, 1)
alpha = torch.ones(4, requires_grad=True)

# Define the optimizer and pass in alpha as a parameter to optimize
optimizer = optim.SGD([{'params': [alpha]}], lr=1)

# Define a function to calculate the weighted average of the models
def weighted_average(alpha, models):
    weights = F.softmax(alpha, dim=0)
    new_state_dict = {}
    for i, model in enumerate(models):
        for name, param in model.state_dict().items():
            if i == 0:
                new_state_dict[name] = weights[i] * param.clone()
            else:
                new_state_dict[name] += weights[i] * param.clone()
    avg_model = model1.__class__(10, 1)
    avg_model.load_state_dict(new_state_dict)
    return avg_model

# Train the model and update alpha during training
for i in range(1000):
    # Calculate the weighted average of the models
    models = [model1, model2, model3, model4]
    y = weighted_average(alpha, models)(torch.randn(1, 10))

    # Calculate the loss and perform backpropagation
    loss = F.mse_loss(y, torch.randn(1, 1))
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Print the loss and alpha
    if i % 100 == 0:
        print("Iteration {}: Loss = {}, Alpha = {}".format(i, loss.item(), alpha.detach().numpy()))

The end goal is to implement the L2C algorithm from the paper "Learning To Collaborate in Decentralized Learning of Personalized Models". I am struggling with updating alpha (the L2C parameter) using the loss on the validation dataset, which is line 18 of the algorithm.

In your current approach, the use of the weights is not captured in a differentiable operation, since you are manipulating the state_dict directly and loading it afterwards.
You could check this post to see if a parametrization would work.
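
To make the point about load_state_dict concrete, here is a minimal sketch (an illustration, not code from the thread) that checks whether alpha receives a gradient once the averaged tensors have gone through load_state_dict. The two-model setup and the names avg_state, state_requires_grad and alpha_grad_after_load are assumptions made for the example.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
models = [torch.nn.Linear(10, 1) for _ in range(2)]
alpha = torch.ones(2, requires_grad=True)
weights = F.softmax(alpha, dim=0)

# Weighted average of each entry; these tensors still depend on alpha.
avg_state = {
    name: sum(w * m.state_dict()[name] for w, m in zip(weights, models))
    for name in models[0].state_dict()
}
state_requires_grad = all(t.requires_grad for t in avg_state.values())

# load_state_dict copies the values into the module's own parameters,
# so the link back to alpha is not kept.
avg_model = torch.nn.Linear(10, 1)
avg_model.load_state_dict(avg_state)

loss = avg_model(torch.randn(1, 10)).pow(2).mean()
loss.backward()
alpha_grad_after_load = alpha.grad  # no gradient ever reached alpha
```

Since alpha.grad is never populated here, optimizer.step() has nothing to apply, which matches the behavior described in the question.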

Thank you for your comment! I tried it, and it works when I parametrize a single key. I am now trying to apply it to all parameters in a model (i.e. both “weight” and “bias” in nn.Linear). How can I do that?

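import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.utils.parametrize as parametrize

class Average(nn.Module):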
    def __init__(self):
        super().__init__()
        self.num_clients = 2
        self.alpha = torch.ones(self.num_clients)
        self.alpha.requires_grad = True
        
    def set_reprs(self,reprs):
        self.reprs = reprs
        
    def forward(self, X):
        self.weight = F.softmax(self.alpha, dim=0)
        
        for i in range(self.num_clients):
            if i == 0:
                result = self.weight[0]*self.reprs[0]
            else:
                result += self.weight[i]*self.reprs[i]
        
        return result


layer1 = nn.Linear(10, 10)
layer2 = nn.Linear(10, 10)
weight = Average()
models = [layer1,layer2]
layer3 = nn.Linear(10, 10)
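# Register the parametrization and create the optimizer once, outside the loop;
# registering inside the loop would stack one more parametrization per iteration
weight.set_reprs([layer1.weight,layer2.weight])
parametrize.register_parametrization(layer3,"weight", weight)
#parametrize.register_parametrization(layer3,"bias", weight)
optimizer = torch.optim.Adam([weight.alpha], lr=1.)
for i in range(10):
    optimizer.zero_grad()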

    x = torch.randn(1, 10)
    out = layer3(x)
    out.mean().backward()

    optimizer.step()

    print(i)
    print(weight.alpha)
    print(weight.weight)

Here is what I want to do, but it didn’t work:

class Average(nn.Module):
    def __init__(self):
        super().__init__()
        self.num_clients = 2
        self.alpha = torch.ones(self.num_clients)
        self.alpha.requires_grad = True
        
    def set_reprs(self,reprs):
        self.reprs = reprs
        
    def forward(self, X):
        self.weight = F.softmax(self.alpha, dim=0)
        
        for i in range(self.num_clients):
            if i == 0:
                result = self.weight[0]*self.reprs[0]
            else:
                result += self.weight[i]*self.reprs[i]
        
        return result


layer1 = nn.Linear(10, 10)
layer2 = nn.Linear(10, 10)
weight = Average()
models = [layer1,layer2]
layer3 = nn.Linear(10, 10)
for i in range(10):

    weight.set_reprs([layer1.weight,layer2.weight])
    parametrize.register_parametrization(layer3,"weight", weight)
    weight.set_reprs([layer1.bias,layer2.bias])
    parametrize.register_parametrization(layer3,"bias", weight)
    #parametrize.register_parametrization(layer3,"bias", weight)
    optimizer = torch.optim.Adam([weight.alpha], lr=1.)

    x = torch.randn(1, 10)
    out = layer3(x)
    out.mean().backward()

    optimizer.step()

    print(i)
    print(weight.alpha)
    print(weight.weight)

RuntimeError: mat2 must be a matrix, got 1-D tensor

It seems you are overriding self.reprs with the bias so you might need to pass all trainable parameters (weight and bias) together.
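
One possible reading of this suggestion, sketched for a single nn.Linear: give every parameter name its own Average instance that holds that name's client tensors, while all instances share the same alpha tensor. This is an assumption about how to structure it, not code from the thread; the constructor arguments (alpha, reprs) are hypothetical, and the client tensors are detached on the assumption that only alpha should be trained.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.utils.parametrize as parametrize

class Average(nn.Module):
    def __init__(self, alpha, reprs):
        super().__init__()
        self.alpha = alpha  # shared plain tensor, optimized outside the module
        self.reprs = reprs  # client tensors for ONE parameter name

    def forward(self, X):
        # X (the original parameter) is ignored; the output is the mix of clients.
        weights = F.softmax(self.alpha, dim=0)
        return sum(w * r for w, r in zip(weights, self.reprs))

torch.manual_seed(0)
layer1, layer2 = nn.Linear(10, 10), nn.Linear(10, 10)
layer3 = nn.Linear(10, 10)
alpha = torch.ones(2, requires_grad=True)

# Snapshot the names first: registering a parametrization renames the parameters.
for name in [n for n, _ in layer3.named_parameters()]:
    reprs = [getattr(layer1, name).detach(), getattr(layer2, name).detach()]
    parametrize.register_parametrization(layer3, name, Average(alpha, reprs))

optimizer = torch.optim.Adam([alpha], lr=0.1)
alpha_before = alpha.detach().clone()

out = layer3(torch.randn(1, 10))
optimizer.zero_grad()
out.mean().backward()
optimizer.step()
```

Because each Average keeps its own reprs, registering "bias" no longer overwrites the tensors used for "weight", and both contributions accumulate into the gradient of the shared alpha.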

Thank you so much for your reply. How should I do that? I searched the internet but couldn’t find any similar use case. I don’t think I can pass both “weight” and “bias” in the tensor_name (str) argument of register_parametrization at the same time, especially when there are many trainable parameters, as in ResNet.
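
For models with many parameters, one alternative to per-name parametrizations (a different technique from the one suggested above, offered only as a sketch) is torch.func.functional_call, available in PyTorch 2.0 and later. It runs a module's forward pass with a dictionary of tensors in place of its own parameters, so the averaged tensors stay connected to alpha. The helper averaged_params, the template module and the random data are assumptions made for the example, and the client parameters are detached on the assumption that only alpha is trained.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.func import functional_call

torch.manual_seed(0)
clients = [nn.Linear(10, 1) for _ in range(4)]
template = nn.Linear(10, 1)  # supplies the architecture only
alpha = torch.ones(4, requires_grad=True)
optimizer = torch.optim.SGD([alpha], lr=0.1)

def averaged_params(alpha, clients):
    """Return {name: softmax(alpha)-weighted mix of the clients' tensors}."""
    weights = F.softmax(alpha, dim=0)
    client_params = [dict(c.named_parameters()) for c in clients]
    return {
        name: sum(w * p[name].detach() for w, p in zip(weights, client_params))
        for name in client_params[0]
    }

x, target = torch.randn(8, 10), torch.randn(8, 1)
alpha_before = alpha.detach().clone()

params = averaged_params(alpha, clients)
y = functional_call(template, params, (x,))  # forward pass with the mixed tensors
loss = F.mse_loss(y, target)

optimizer.zero_grad()
loss.backward()   # gradients flow through the mixed tensors into alpha
optimizer.step()
```

The dictionary is built from whatever named_parameters() yields, so the same helper should apply to a deeper model such as a ResNet without listing names by hand. Buffers (for example BatchNorm running statistics) are not averaged in this sketch and would be taken from the template module.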