How to save a single final model from distributed multi-GPU training?

Hello,

There is something I seem to struggle to understand about how to use DistributedDataParallel correctly. Below is a reproducible example of my code (I tried to make it as short and general as possible, and removed the evaluation step from the training).

I’m running the code on a machine with two GPUs, and my problem is that the code saves two separate torch models, one for each GPU process I’ve spawned. My assumption was that distributed multiprocessing would eventually bring everything back together into a single model, so only one model should be saved after training. Could someone please tell me what I’m doing wrong in the code below?

Thanks in advance

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data.distributed import DistributedSampler
import torch.multiprocessing as mp
from argparse import ArgumentParser
import os

class MyModel(nn.Module):
    def __init__(self, input_dim, inner_layer_1, inner_layer_2, output_dim):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, inner_layer_1)
        self.fc2 = nn.Linear(inner_layer_1, inner_layer_2)
        self.fc3 = nn.Linear(inner_layer_2, output_dim)

    def forward(self, x):
        x = self.fc1(x) 
        x = F.relu(x)
        x = self.fc2(x)
        x = F.relu(x)
        x = self.fc3(x)
        x = F.softmax(x, dim=1)
        return x

def train(gpu_number, n_epochs, model, train_data, optimizer, loss_fn, log_interval=2):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '29500'
    torch.distributed.init_process_group(
        backend='nccl',
        init_method='env://',
        world_size=2,  # total number of gpus
        rank=gpu_number
    )

    sampler = DistributedSampler(train_data, num_replicas=2, rank=gpu_number)
    trainloader = DataLoader(train_data, batch_size=8, sampler=sampler)

    #torch.cuda.set_device(gpu_number)
    model = model.cuda(gpu_number)
    model = DDP(model, device_ids=[gpu_number], output_device=gpu_number)

    for epoch in range(n_epochs):
        for i, batch in enumerate(trainloader):
            inputs, labels = batch[:,:8].cuda(gpu_number), batch[:,-2:].cuda(gpu_number)
            optimizer.zero_grad()
            outputs = model(inputs)

            loss = loss_fn(outputs, labels)
            loss.backward()
            optimizer.step()
    torch.save(model.state_dict(), f"model_{gpu_number}.pt")

if __name__ == "__main__":
    train_data = torch.rand(100, 10)
    n_epochs = 3
    learning_rate = 0.001
    model = MyModel(8, 800, 300, 2)
    loss_fn = nn.MSELoss() # use nn.CrossEntropyLoss() for binary classification
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    mp.spawn(train, nprocs=2, args=(n_epochs, model, train_data, optimizer, loss_fn, 2))

Hi @AndreaSottana

When you train in distributed mode, one of the processes should be treated as the main process, and you should save the model only from that main process.

Check one of torchvision’s reference examples, which should give you a good idea of how to handle this.

Say you’re using two GPUs in distributed training mode; there will then be two processes running your script.
Both processes will call the save_on_master function from the above example at some point, but only one of them will actually save the model (i.e., the one for which is_main_process() returns True).
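
Roughly, those helpers boil down to something like this (a paraphrased sketch of the idea, not the exact torchvision source):

import torch
import torch.distributed as dist

def is_main_process():
    # rank 0 is conventionally treated as the main process
    return not dist.is_initialized() or dist.get_rank() == 0

def save_on_master(*args, **kwargs):
    # only the main process writes the checkpoint to disk
    if is_main_process():
        torch.save(*args, **kwargs)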

Hi @yoshitomo-matsubara
Many thanks for your prompt reply. This is the first time I’ve heard of save_on_master; I couldn’t find it in any torch.distributed documentation I’ve seen so far, so it would be good to make it more visible, I think.
Anyway, I have 3 follow up questions.

  1. If I understand correctly, the functions you linked simply save the model only if the rank is zero. Is this arbitrary, i.e. could I have chosen rank one instead, for example?
  2. The code you linked does not contain any explicit step that brings the two spawned processes back together, so I’m assuming this has to happen in the background. Does this mean that, at any point after calling optimizer.step(), the weights on both GPUs are synchronised automatically? If so, the two models my original code saved would be completely identical and would each reflect the training from both GPUs, rather than being two different models each trained on a different GPU with a different subset of the data. It would also imply that the only advantage of the save_on_master function you linked is not saving the model multiple times: even if I left my code as it was, it would just save multiple copies of the same model with exactly the same weights, and I could use any of them, so it wouldn’t make any real difference. Please correct me if I’m wrong.
  3. Given that the two (or more) GPUs are independent, one might run the code slightly faster than the other. If you save the model when the rank is zero and the GPU for rank 1 hasn’t finished running yet, is there a chance that you might save the model before the whole training has actually completed?

Thank you very much for your kind help!

Hi @AndreaSottana

  1. Yes, technically you can choose whichever rank you like as the main process. Since rank numbers start from zero, I’d suggest using rank 0 as the main process.
  2. With DDP, the two processes synchronize gradients across processes. You can confirm the details here. If you don’t control the timing of saving your model and allow both processes to save, it may cause a file-handling error, e.g. one process attempting to overwrite a file whose save operation has not yet completed.
  3. This can be controlled with torch.distributed.barrier(), which makes each process wait for the others. For instance, I’d put torch.distributed.barrier() right before torch.save in your code above, so that all the iterations in the epoch are done on every process before the model is saved (see the sketch below). The same torchvision example uses it in evaluation to compute global accuracy across the processes: synchronize_between_processes() calls torch.distributed.barrier() as dist.barrier().
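
For instance, the end of your train() function could look like this (just a sketch based on the code you posted, keeping your variable names; the output filename is a placeholder):

# ... the training loop stays exactly as it is ...
torch.distributed.barrier()  # wait until every process has finished all of its iterations
if gpu_number == 0:          # treat rank 0 as the main process
    torch.save(model.state_dict(), "model.pt")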

Hi @yoshitomo-matsubara
Many thanks for the very detailed reply.
I will try to study the code you linked and implement your changes, and will let you know if I have further queries down the line.
Thanks again

Hi @yoshitomo-matsubara
I have a small follow-up question on this.

I’ve noticed that you don’t call dist.barrier() before save_on_master (for example here), yet you suggested I do that in my code. Is there a reason why you don’t need to do this in your code?

I’ve also noticed that when you save the model, you save model_without_ddp.state_dict() instead of simply model.state_dict(). I’ve not seen this done before. Is there any advantage in doing this over simply saving the DDP model’s state dict?

Thanks again

Hi @AndreaSottana

I’m not sure which code you’re referring to as “your code”, but I suggested calling dist.barrier() before torch.save in the code you showed above, as follows:

For instance, I’d put torch.distributed.barrier() right before torch.save in your code above to wait for all the iterations in the epoch done before saving the model.

The example code does actually call dist.barrier() before save_on_master; it happens inside the evaluate function, which is called before save_on_master.

I’ve also noticed that when you save the model, you save model_without_ddp.state_dict() instead of simply model.state_dict(). I’ve not seen this done before. Is there any advantage in doing this over simply saving the DDP model’s state dict?

Since DDP and DP wrap your model, a state_dict saved via model.state_dict() cannot be directly loaded into a model that is not wrapped in DDP/DP.
Try this minimal example:

import os
import torch
import torch.distributed as dist
from torchvision import models

model = models.resnet18()
print(model.state_dict().keys())  # plain keys, e.g. 'conv1.weight', 'bn1.weight', ...

# DDP needs an initialized process group; a single-process gloo group is enough for this demo
os.environ.setdefault('MASTER_ADDR', 'localhost')
os.environ.setdefault('MASTER_PORT', '29500')
dist.init_process_group(backend='gloo', rank=0, world_size=1)

model = torch.nn.parallel.DistributedDataParallel(model)
# or use DP instead of DDP (no process group needed)
# model = torch.nn.DataParallel(model)
print(model.state_dict().keys())  # same keys, now prefixed with 'module.'

From the first print statement you’ll see the plain module paths used in ResNet-18. From the second print statement you can confirm that every one of those module paths now has a “module.” prefix.

You can confirm in the DP and DDP implementations that your wrapped model is stored under the attribute module.
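
That is also why the example keeps a reference to the unwrapped model before wrapping it. A minimal sketch of that pattern (my own illustration, assuming the process group is already initialized and that model and gpu_number are as in your train() function above; the filename is a placeholder):

model_without_ddp = model                      # keep a handle to the unwrapped module
model = DDP(model, device_ids=[gpu_number])

# ... training loop ...

torch.distributed.barrier()
if gpu_number == 0:
    # no 'module.' prefix, so this loads directly into a plain MyModel
    torch.save(model_without_ddp.state_dict(), "model.pt")
    # equivalently, reach the unwrapped module through the DDP wrapper:
    # torch.save(model.module.state_dict(), "model.pt")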

Thanks again!
By “your code” I meant the code here, but I didn’t realise that it actually calls the evaluate function just before save_on_master, which is what calls dist.barrier(); it makes sense now. I should have said “the torchvision code that you linked”.
I will convert my torch.save to the save_on_master function and call dist.barrier() before it in my code as well, then.

Thanks for the small ResNet example as well, all clear now. I’m still new to multi-GPU training, but this helped me a lot.

Hi @AndreaSottana @yoshitomo-matsubara, I had a similar doubt and this discussion really helped. Could you please share your final code that saves the model only on the master process?

Thanks in advance!