After finishing a single-process version of FedAVG (a federated-learning algorithm), I tried to apply torch.multiprocessing.Pool() and use apply_async to share and update the model. However, I got a much lower accuracy (from 0.8 down to 0.1) in each client's local model.
My original code is like this:
def update_server(self, T):
    """Run T communication rounds of FedAVG with serial client updates.

    Args:
        T: number of communication rounds (outer loop of the algorithm).

    Returns:
        The mean client training accuracy of the final round.

    Side effects:
        Updates ``self._global_model`` in place after every round.
    """
    client_accs = []
    # 2: for t = 0, ..., T-1 do
    for rnd in range(T):  # `rnd`: avoid shadowing the builtin `round`
        print(f"Round {rnd+1} started...")
        client_models = []
        client_losses = []
        client_accs = []
        # 3: TODO: select a random subset S of clients instead of all of them
        # 4: x_{i,0}^t = x_t -- snapshot the current global weights
        x_t = self._global_model.state_dict()
        # 5: for each client i in S do
        # NOTE(review): this loop is serial. To parallelize with
        # torch.multiprocessing, submit all apply_async() calls first and
        # only then .get() the results -- calling .get() right after each
        # submit blocks and serializes the work again.
        for i in range(self._num_clients):
            # Give each client its own deep copy so in-place mutation of
            # the state dict cannot leak between clients.
            x_t_copy = copy.deepcopy(x_t)
            # 6-8: x_{i,K}^t = CLIENTOPT(x_{i,0}^t)
            x_i_K_t, client_loss, client_acc = self._clients[i].client_update(
                epoch=rnd + 1, id=i, global_model=x_t_copy)
            client_models.append(x_i_K_t)
            client_losses.append(client_loss)
            client_accs.append(client_acc)
        # x_{t+1} = server-side aggregation (FedAVG mean) of client models
        global_state_dict = self.server_update(client_models)
        self._global_model.load_state_dict(global_state_dict)
        avg_loss = sum(client_losses) / len(client_losses)
        avg_acc = sum(client_accs) / len(client_accs)
        print(f"Round {rnd+1} finished, global loss: {avg_loss:.4f}, "
              f"global accuracy: {avg_acc: .4f}")
    return sum(client_accs) / len(client_accs)
My updated version changed only this line:
pool.apply_async(self._clients[i].client_update, args=(round+1, i, x_t_temp)).get()
(I called the .get() method immediately for simplicity.)
My client_update function stays the same between the two versions:
class ClientBase():
‘’‘base class for FL learning’‘’
def init(self, dataloader, model, optimizer, device, E, B):
self._dataloader = dataloader
self._model = model
self._optimizer = optimizer
self._device = device
self._E = E
self._B = B
def client_update(self, epoch, id, global_model):
'''ClientUpdate in FedAVG;'''
# print(f'client {id+1} is started to run.')
self._model.train()
self._model.load_state_dict(global_model)
self._model.to(device=self._device)
criterion = nn.CrossEntropyLoss()
running_loss = 0
num_euqal = 0
acc = None
for _ in range(self._E):
for inputs, labels in self._dataloader:
inputs, labels = inputs.to(self._device), labels.to(self._device)
self._optimizer.zero_grad()
outputs = self._model(inputs)
loss = criterion(outputs, labels)
loss.backward()
self._optimizer.step()
running_loss += loss.item() * inputs.size(0)
self._model.eval()
acc_num = 0
total_num = 0
for inputs, labels in self._dataloader:
inputs, labels = inputs.to(self._device), labels.to(self._device)
test_output = self._model(inputs)
pred_y = torch.max(test_output, 1)[1].data.squeeze()
num_equal = (pred_y == labels).sum().item()
acc_num += num_equal
total_num += labels.size()[0]
print(f"Client {id+1} Ended-loss: \
{running_loss / len(self._dataloader.dataset)*self._E:.4f}, \
accuracy {acc_num/total_num: .4f} ")
return copy.deepcopy(self._model.cpu()).state_dict(), \
running_loss / len(self._dataloader.dataset)*self._E, acc_num/total_num
My whole project can be found here:
I'm new to PyTorch and this may seem like a silly question, but being stuck on it for a month made me write this post. I'll be grateful if anyone can help me solve it. Thanks for reading.