Issues when trying to implement FedAVG using torch.multiprocessing.Pool

After finishing a single-process version of FedAVG (federated averaging), I tried to use torch.multiprocessing.Pool() with apply_async to share and update the model across clients. However, the accuracy of each client's local model dropped drastically (from about 0.8 to about 0.1).
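
For reference, this is the standard apply_async pattern I mean (a minimal, self-contained example of plain Pool usage, not my training code; torch.multiprocessing re-exports the standard multiprocessing API, and square is just a placeholder function):

import torch.multiprocessing as mp

def square(x):
    return x * x

if __name__ == "__main__":
    # apply_async returns an AsyncResult immediately;
    # .get() blocks until the worker process has produced the result.
    with mp.Pool(processes=4) as pool:
        async_results = [pool.apply_async(square, args=(x,)) for x in range(4)]
        print([r.get() for r in async_results])  # [0, 1, 4, 9]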

My original code looks like this:
def update_server(self, T):
    '''
    FedAVG server loop
    '''
    client_acc = []
    # 2: for t = 0, ..., T-1 do
    for round in range(T):
        print(f"Round {round+1} started...")
        client_models = []
        client_losses = []
        client_accs = []
        delta_is_t = []

        # -- multiprocessing pool
        pool = mp.Pool(self._num_clients)
        results = []
        # --

        # 3: TODO 2: select a subset S of clients
        # 4: x_{i,0}^t = x_t
        x_t = self._global_model.state_dict()
        # 5: for each client i in S, in parallel do
        for i in range(self._num_clients):
            # TODO 1: do this in parallel

            # x_{i,0}^t = x_t : x_t is the global model from the last round
            x_t_temp = copy.deepcopy(x_t)

            # 6-8: x_{i,K}^t = ClientOpt
            # -- multiprocessing attempt:
            # result = pool.apply_async(self._clients[i].client_update, args=(round+1, i, x_t_temp))
            # x_i_K_t, client_loss, client_acc = \
            #     pool.apply_async(self._clients[i].client_update, args=(round+1, i, x_t_temp)).get()
            # --
            x_i_K_t, client_loss, client_acc = \
                self._clients[i].client_update(epoch=round+1, id=i, global_model=x_t_temp)
            # 9: delta_i^t = x_{i,K}^t - x_t
            # delta_i_t = {}
            # for key in x_i_K_t.keys():
            #     delta_i_t[key] = x_i_K_t[key] - x_t[key]
            # --
            # results.append(result)
            # --
            client_models.append(x_i_K_t)
            client_losses.append(client_loss)
            client_accs.append(client_acc)
            # delta_is_t.append(delta_i_t)
            # print(f"Client {i+1} loss: {client_loss:.4f}, accuracy: {client_acc:.4f}")
        # --

        # -- collect the async results (multiprocessing attempt):
        # for result in results:
        #     client_state_dict, client_loss, client_acc = result.get()
        #     client_models.append(client_state_dict)
        #     client_losses.append(client_loss)
        #     client_accs.append(client_acc)
        # --

        # 10: x_{t+1} = (1/|S|) * sum_{i in S} x_{i,K}^t  (averaging models directly, not deltas)
        global_state_dict = self.server_update(client_models)
        self._global_model.load_state_dict(global_state_dict)
        print(f"Round {round+1} finished, global loss: "
              f"{sum(client_losses)/len(client_losses):.4f}, "
              f"global accuracy: {sum(client_accs)/len(client_accs):.4f}")
        # --
        pool.close()
        pool.join()
        # --

    return sum(client_accs)/len(client_accs)
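
server_update is not shown above; it is meant to do plain FedAVG averaging of the returned client state dicts. A simplified sketch of that averaging (assuming equal client weights and floating-point parameters) is:

def server_update(self, client_models):
    '''Element-wise average of the client state dicts (equal weights).'''
    avg_state_dict = copy.deepcopy(client_models[0])
    for key in avg_state_dict.keys():
        # sum the parameter tensors from all clients, then divide by the number of clients
        for other_state_dict in client_models[1:]:
            avg_state_dict[key] = avg_state_dict[key] + other_state_dict[key]
        avg_state_dict[key] = avg_state_dict[key] / len(client_models)
    return avg_state_dict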

My updated version only changed that call to:

    x_i_K_t, client_loss, client_acc = \
        pool.apply_async(self._clients[i].client_update, args=(round+1, i, x_t_temp)).get()

(I called .get() immediately for simplicity.)
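
For completeness, the fully asynchronous variant I also tried (reconstructed from the commented-out lines in the code above: submit every client first, then collect the results with .get()) looks like this:

results = []
for i in range(self._num_clients):
    x_t_temp = copy.deepcopy(x_t)
    # submit all client updates to the pool before waiting on any of them
    results.append(pool.apply_async(self._clients[i].client_update,
                                    args=(round+1, i, x_t_temp)))

for result in results:
    client_state_dict, client_loss, client_acc = result.get()
    client_models.append(client_state_dict)
    client_losses.append(client_loss)
    client_accs.append(client_acc)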

My client_update function stays the same between the two versions:
class ClientBase():
    '''Base class for FL clients.'''
    def __init__(self, dataloader, model, optimizer, device, E, B):
        self._dataloader = dataloader
        self._model = model
        self._optimizer = optimizer
        self._device = device
        self._E = E  # number of local epochs
        self._B = B  # local batch size

    def client_update(self, epoch, id, global_model):
        '''ClientUpdate in FedAVG: local training followed by evaluation.'''
        # print(f'client {id+1} started running.')
        self._model.train()
        self._model.load_state_dict(global_model)

        self._model.to(device=self._device)
        criterion = nn.CrossEntropyLoss()
        running_loss = 0

        # local training: E epochs over this client's data
        for _ in range(self._E):
            for inputs, labels in self._dataloader:
                inputs, labels = inputs.to(self._device), labels.to(self._device)
                self._optimizer.zero_grad()
                outputs = self._model(inputs)
                loss = criterion(outputs, labels)
                loss.backward()
                self._optimizer.step()
                running_loss += loss.item() * inputs.size(0)

        # evaluation on the same dataloader
        self._model.eval()
        acc_num = 0
        total_num = 0
        for inputs, labels in self._dataloader:
            inputs, labels = inputs.to(self._device), labels.to(self._device)
            test_output = self._model(inputs)
            pred_y = torch.max(test_output, 1)[1].data.squeeze()
            num_equal = (pred_y == labels).sum().item()
            acc_num += num_equal
            total_num += labels.size(0)
        avg_loss = running_loss / (len(self._dataloader.dataset) * self._E)
        print(f"Client {id+1} ended - loss: {avg_loss:.4f}, "
              f"accuracy: {acc_num/total_num:.4f}")
        return copy.deepcopy(self._model.cpu()).state_dict(), avg_loss, acc_num/total_num

My whole project can be found here:

I'm new to PyTorch and this may be a silly question, but I've been stuck on it for a month, which is why I'm writing this post. I'd be grateful if anyone can help me solve it. Thanks for reading.