Train 10 models using 4 GPUs

Hi,

I want to train 10 models simultaneously using multiple GPUs. First, I split the training data into 10 parts, so at each iteration each model randomly draws a batch from its own part and updates its parameters independently with that data. Then, after this local update, I compute the average of all models' parameters and load this average back into every model. So compared with traditional training, the only difference is that we average the parameters of all models after each local update. Writing the code for a single GPU is not very difficult, but since we train 10 models here it takes a lot of time. Does anyone know how to use multiple GPUs to accelerate the training?

Here I attach my code, which works well on a single GPU:

import collections
import copy

import torch
import torch.nn as nn
import torch.optim as opt

# Net, device, train_loader, lr and runs are defined elsewhere (omitted here).

K = 10  # number of models, here it should be 10
models = [Net().to(device) for _ in range(K)]
opts = [opt.SGD(models[i].parameters(), lr) for i in range(K)]

for _ in range(runs):  # epochs
    for data in train_loader:
        for i in range(K):
            # gradient descent step for the i-th model on its own batch
            x, y = data[i][0].to(device), data[i][1].to(device)  # data for the i-th model
            models[i].train()
            yp = models[i](x)
            loss = nn.CrossEntropyLoss()(yp, y)
            opts[i].zero_grad()  # SGD step
            loss.backward()
            opts[i].step()
            models[i].eval()

        with torch.no_grad():
            # combine the parameters of all models;
            # state_dict() returns references to the live parameters, so take a
            # deep copy (the w_{i-1} snapshot) before any model is overwritten
            net_comb_state_dict_temp = [net.state_dict() for net in models]
            net_comb_state_dict = copy.deepcopy(net_comb_state_dict_temp)
            weight_keys = list(net_comb_state_dict[0].keys())
            for j in range(K):
                # note: the same average is recomputed here for every model
                updated_state_dict = collections.OrderedDict()
                for key in weight_keys:
                    key_sum = 0
                    for n in range(len(net_comb_state_dict)):
                        key_sum = key_sum + 1 / K * net_comb_state_dict[n][key]
                    updated_state_dict[key] = key_sum
                models[j].load_state_dict(updated_state_dict)
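(Side note on the combining step: since every model receives the same average, the averaged state dict only needs to be built once per iteration rather than once per model. An equivalent, shorter sketch of that block; the names state_dicts and avg_state_dict are just local to this sketch:)

with torch.no_grad():
    # build the average once; the averaged tensors are new tensors, so no
    # deepcopy is needed before loading them back into the models
    state_dicts = [net.state_dict() for net in models]
    avg_state_dict = collections.OrderedDict(
        (key, sum(sd[key] for sd in state_dicts) / K) for key in state_dicts[0]
    )
    for net in models:
        net.load_state_dict(avg_state_dict)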

We have existing distributed solutions for this.

If you are simply doing data parallel training, you may want to look into DistributedDataParallel. This does training in an SPMD style on multiple GPUs; see the "Getting Started with Distributed Data Parallel" tutorial in the PyTorch documentation.
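For reference, here is a minimal sketch along the lines of that tutorial. Since each of your models starts every iteration from the same averaged parameters and the update is plain SGD, averaging the parameters after one local step gives the same result as averaging the gradients, which is exactly what DDP does during backward(). The sketch assumes Net and runs from your code plus a dataset object to shard (the dataset underlying your train_loader), uses placeholder values for the learning rate and batch size, and runs one process per GPU:

import torch
import torch.nn as nn
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler

def main():
    # torchrun sets RANK / WORLD_SIZE / MASTER_ADDR for us
    dist.init_process_group(backend="nccl")
    rank = dist.get_rank()
    device = torch.device(f"cuda:{rank % torch.cuda.device_count()}")
    torch.cuda.set_device(device)

    # each process sees a disjoint shard of the data, similar to your 10 parts
    sampler = DistributedSampler(dataset)          # dataset: assumed defined elsewhere
    loader = DataLoader(dataset, batch_size=64, sampler=sampler)  # batch size: placeholder

    model = Net().to(device)                       # Net as in your code
    ddp_model = DDP(model, device_ids=[device.index])
    optimizer = torch.optim.SGD(ddp_model.parameters(), lr=0.01)  # lr: placeholder
    loss_fn = nn.CrossEntropyLoss()

    for epoch in range(runs):                      # runs as in your code
        sampler.set_epoch(epoch)                   # reshuffle shards each epoch
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            loss = loss_fn(ddp_model(x), y)
            loss.backward()                        # gradients are averaged across processes here
            optimizer.step()                       # every replica applies the same update

    dist.destroy_process_group()

if __name__ == "__main__":
    main()

You would launch it with one process per GPU, e.g. torchrun --nproc_per_node=4 train_ddp.py (the file name is just an example). With 4 GPUs this gives 4 replicas instead of 10, but every iteration is still an averaged update over the per-replica gradients.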