What I want is to update alpha (a vector with one entry per client) after each epoch. However, alpha is not updated when I multiply each parameter by its weight (computed as the softmax of alpha) and create a new model with `load_state_dict(new_state_dict)`. Is there any way I can update alpha?
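A minimal sketch that isolates the step described above. The merged tensors still depend on alpha, but `load_state_dict` copies their values into the target module's own parameters inside `torch.no_grad()`, so the loss never reaches alpha. The two tiny `nn.Linear` layers and the input sizes are arbitrary choices for illustration.

```python
import torch

torch.manual_seed(0)

alpha = torch.ones(2, requires_grad=True)
weights = torch.softmax(alpha, dim=0)

src1 = torch.nn.Linear(3, 1)
src2 = torch.nn.Linear(3, 1)

# Weighted sum of the two state dicts; these tensors still depend on alpha
merged = {
    name: weights[0] * p1 + weights[1] * p2
    for (name, p1), (_, p2) in zip(src1.state_dict().items(), src2.state_dict().items())
}
print(merged["weight"].requires_grad)  # True

# load_state_dict copies the values into target's own parameters under
# torch.no_grad(), so the link back to alpha is not recorded
target = torch.nn.Linear(3, 1)
target.load_state_dict(merged)

loss = target(torch.randn(4, 3)).pow(2).mean()
loss.backward()
print(target.weight.grad is not None)  # True: the new module's leaf tensors get gradients
print(alpha.grad)  # None: nothing flows back to alpha
```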
```python
import torch
import torch.nn.functional as F
import torch.optim as optim

# Define the models and the trainable parameter alpha
model1 = torch.nn.Linear(10, 1)
model2 = torch.nn.Linear(10, 1)
model3 = torch.nn.Linear(10, 1)
model4 = torch.nn.Linear(10, 1)
alpha = torch.ones(4, requires_grad=True)

# Define the optimizer and pass in alpha as a parameter to optimize
optimizer = optim.SGD([{'params': [alpha]}], lr=1)

# Define a function to calculate the weighted average of the models
def weighted_average(alpha, models):
    weights = F.softmax(alpha, dim=0)
    new_state_dict = {}
    for i, model in enumerate(models):
        for name, param in model.state_dict().items():
            if i == 0:
                new_state_dict[name] = weights[i] * param.clone()
            else:
                new_state_dict[name] += weights[i] * param.clone()
    avg_model = model1.__class__(10, 1)
    avg_model.load_state_dict(new_state_dict)
    return avg_model

# Train the model and update alpha during training
for i in range(1000):
    # Calculate the weighted average of the models
    models = [model1, model2, model3, model4]
    y = weighted_average(alpha, models)(torch.randn(1, 10))
    # Calculate the loss and perform backpropagation
    loss = torch.nn.functional.mse_loss(y, torch.randn(1, 1))
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    # Print the loss and alpha
    if i % 100 == 0:
        print("Iteration {}: Loss = {}, Alpha = {}".format(i, loss.item(), alpha.detach().numpy()))
```
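For the `nn.Linear` models in this code, a minimal sketch of one way to keep alpha in the autograd graph: build the averaged weight and bias as ordinary tensors and call `F.linear` directly instead of loading them into a new module. It keeps the random inputs and targets from the loop above; the client models are detached on the assumption that only alpha is being learned at this step.

```python
import torch
import torch.nn.functional as F
import torch.optim as optim

torch.manual_seed(0)

models = [torch.nn.Linear(10, 1) for _ in range(4)]
alpha = torch.ones(4, requires_grad=True)
optimizer = optim.SGD([alpha], lr=1)

def weighted_forward(alpha, models, x):
    """Apply the alpha-weighted average of several nn.Linear layers to x.

    The averaged weight and bias are computed from softmax(alpha), so
    autograd can trace the loss back to alpha. The client parameters are
    detached (assumption: only alpha is optimized here).
    """
    weights = F.softmax(alpha, dim=0)
    weight = sum(w * m.weight.detach() for w, m in zip(weights, models))
    bias = sum(w * m.bias.detach() for w, m in zip(weights, models))
    return F.linear(x, weight, bias)

for i in range(1000):
    x = torch.randn(1, 10)
    target = torch.randn(1, 1)
    loss = F.mse_loss(weighted_forward(alpha, models, x), target)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if i % 100 == 0:
        print(f"Iteration {i}: loss = {loss.item():.4f}, alpha = {alpha.detach().tolist()}")
```

Because no new `nn.Module` is created, nothing copies values out of the graph, and `alpha.grad` is populated on every `backward()` call.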
The end goal is to implement the L2C algorithm from the paper "Learning To Collaborate in Decentralized Learning of Personalized Models". Here is the algorithm; I am struggling with line 18, which updates alpha (the L2C parameter) using the loss on the validation dataset.
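If line 18 amounts to a gradient step on alpha that lowers the validation loss of the alpha-weighted model (an assumption; check it against the paper), the same idea carries over to arbitrary architectures with `torch.func.functional_call` (PyTorch 2.0+), which runs a module using a dictionary of replacement tensors instead of its own parameters. The two-layer network, the random `val_x`/`val_y` tensors, the helper `mixed_params`, and the Adam optimizer are hypothetical choices for illustration. The local training of each client (the other lines of the algorithm) is omitted, and the client parameters are detached so that only alpha is updated.

```python
import torch
import torch.nn.functional as F
from torch.func import functional_call

torch.manual_seed(0)

num_clients = 4
# Hypothetical client models sharing one architecture
models = [
    torch.nn.Sequential(torch.nn.Linear(10, 16), torch.nn.ReLU(), torch.nn.Linear(16, 1))
    for _ in range(num_clients)
]
alpha = torch.zeros(num_clients, requires_grad=True)
alpha_opt = torch.optim.Adam([alpha], lr=0.05)

# Hypothetical validation data for this client
val_x = torch.randn(32, 10)
val_y = torch.randn(32, 1)

def mixed_params(alpha, models):
    """Return {name: sum_k softmax(alpha)_k * param_k} as tensors that depend on alpha."""
    weights = F.softmax(alpha, dim=0)
    param_dicts = [dict(m.named_parameters()) for m in models]
    return {
        name: sum(w * p[name].detach() for w, p in zip(weights, param_dicts))
        for name in param_dicts[0]
    }

for step in range(50):
    params = mixed_params(alpha, models)
    # models[0] only supplies the architecture; its own weights are swapped out
    val_pred = functional_call(models[0], params, (val_x,))
    val_loss = F.mse_loss(val_pred, val_y)
    alpha_opt.zero_grad()
    val_loss.backward()
    alpha_opt.step()

print(F.softmax(alpha, dim=0).tolist())
```

The mixed parameters are rebuilt from alpha on every step, so the graph from the validation loss back to alpha is never broken the way `load_state_dict` breaks it.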