Hello there,
My case is a bit complex. I have 8 networks, one of which is the global server for all of them.
My aim is to train my network for one epoch in those 7 networks, then try to take the weighted average of their gradients, and replace the aggregated gradients for the 8th network and make one step with aggregated gradients, also called Federated Learning. However, it does not work the way it is supposed to. My main questions are:
1- Can I take store each worker’s gradients using:
grads = []
for p in net.parameters():
grads.append(p.grad)
2- Also, can I aggregate gradients by using sbs_grad = copy.deepcopy(temp_grad)
to the first one, then adding on top of that using
for j in range(len(sbs_grad)):
sbs_grad[j] += temp_grad[j]
then dividing it by the number of workers?
3- Finally can I transfer gradients using:
loss.backward(retain_graph=True)
index = 0
for p in net.parameters():
p.grad = cluster_grads[index]
index = index + 1
optimizer.step()
Any help would be much appreciated.
Here is my trial for those who are interested:
On workers:
def HFL_worker_train(worker, device, optimizer, epoch, criterion, i, h, c):
print('\n FL Epoch: %d, Cluster %d, Local Iteration %d, User %d, ' % (epoch+1, c+1, h+1, i+1))
net = worker.worker_net
trainloader = worker.train_DataLoader
grads = []
net.train()
train_loss = 0
correct = 0
total = 0
for batch_idx, (inputs, targets) in enumerate(trainloader):
inputs, targets = inputs.to(device), targets.to(device)
optimizer.zero_grad()
outputs = net(inputs)
loss = criterion(outputs, targets)
loss.backward()
loss.retain_grad()
for p in net.parameters():
grads.append(p.grad)
optimizer.step()
train_loss += loss.item()
_, predicted = outputs.max(1)
total += targets.size(0)
correct += predicted.eq(targets).sum().item()
progress_bar(batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
% (train_loss/(batch_idx+1), 100.*correct/total, correct, total))
tr_acc = 100.*correct/total
Aggregating gradients:
def hfl_localtraining(H, hfl_network, device, epoch, criterion, learning_rate, hfl_worker_optimizer_list, hfl_sbs_optimizer_list, no_of_clusters, no_workers_in_cluster, c):
for i in range(no_workers_in_cluster):
for h in range(H):
temp_grad = HFL_worker_train(hfl_network.Clusters[c].Workers[i], device, hfl_worker_optimizer_list[c][i], epoch, criterion, i, h, c)
if(H - h == 1):
if(i == 0):
sbs_grad = copy.deepcopy(temp_grad)
else:
for j in range(len(sbs_grad)):
sbs_grad[j] += temp_grad[j]/no_workers_in_cluster
return sbs_grad
After training all workers:
def HFL_SBS_train(worker, device, optimizer, epoch, criterion, c, cluster_grads):
print('\n FL Epoch: %d, Cluster %d SBS' % (epoch+1, c+1))
net = worker.worker_net
net.train()
train_loss = 0
correct = 0
total = 0
trainloader = worker.train_DataLoader
for batch_idx, (inputs, targets) in enumerate(trainloader):
inputs, targets = inputs.to(device), targets.to(device)
optimizer.zero_grad()
outputs = net(inputs)
loss = criterion(outputs, targets)
loss.retain_grad()
loss.backward(retain_graph=True)
index = 0
for p in net.parameters():
p.grad = cluster_grads[index]
index = index + 1
optimizer.step()
train_loss += loss.item()
_, predicted = outputs.max(1)
total += targets.size(0)
correct += predicted.eq(targets).sum().item()
progress_bar(batch_idx, len(trainloader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'
% (train_loss/(batch_idx+1), 100.*correct/total, correct, total))
tr_acc = 100.*correct/total
worker.train_loss.append(train_loss/trainloader.batch_size)
worker.train_acc.append(100.*correct/total)
return worker.worker_net.state_dict()