Right ways to serialize and load DDP model checkpoints

I have trained a model using DistributedDataParallel. After training, I serialized the model, which was still wrapped in DistributedDataParallel, like so:

torch.save(model.state_dict(), 'model.pt')

Note that this serialization was performed in the launcher function, which is typically passed to spawn() of torch.multiprocessing. My training setup consists of 4 GPUs.

Now, when I try to load the checkpoint in my local inference setup (single GPU), the state-dict keys don't match. The model, in this case, is not wrapped in DistributedDataParallel. Any pointers would be useful.


Save your DDP model after unwrapping it from DistributedDataParallel, like so:

torch.save(model.module.state_dict(), 'model.pt')

Here, model.module is your original model (before DDP wrapping).
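If it helps, here is a minimal sketch of the load side on a single GPU. MyModel is just a placeholder for your actual network class, and the prefix-stripping branch is only needed if the checkpoint was saved from the still-wrapped model, as in your original torch.save call:

import torch
from collections import OrderedDict

# build the plain (non-DDP) model for inference; MyModel is a placeholder
model = MyModel().cuda()

# load the checkpoint onto the single inference GPU
state_dict = torch.load('model.pt', map_location='cuda:0')

# a checkpoint saved from the DDP-wrapped model has every key prefixed with
# "module."; strip the prefix so the keys match the plain model
if any(k.startswith('module.') for k in state_dict):
    state_dict = OrderedDict((k[len('module.'):], v) for k, v in state_dict.items())

model.load_state_dict(state_dict)
model.eval()

On the save side, it is also common to write the checkpoint from rank 0 only, e.g. if rank == 0: torch.save(model.module.state_dict(), 'model.pt'), so that the four processes don't all write the same file.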


Thanks @sio277.

Also, a slightly unrelated question: what is the best way to log/print the loss and other metrics in this setting? Currently, I am only logging the loss and other metrics from the master process, i.e. when rank == 0.

To get the mean of the metrics across all ranks, I use an all-reduce, something like this:

import torch
import torch.distributed as dist

def global_meters_all_avg(args, *meters):
    """meters: scalar values of loss/accuracy calculated in each rank"""
    tensors = [torch.tensor(meter, device=args.gpu, dtype=torch.float32) for meter in meters]
    for tensor in tensors:
        # all_reduce sums each tensor across all ranks, in place (the default op is SUM)
        dist.all_reduce(tensor)

    return [(tensor / args.world_size).item() for tensor in tensors]

See Distributed communication package - torch.distributed — PyTorch 1.8.1 documentation for more details about the collective communication operations.
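One thing worth noting: dist.all_reduce is a collective operation, so every rank must call global_meters_all_avg at the same point; if only rank 0 calls it, the call will hang. A minimal sketch of the call site, assuming args.gpu and args.world_size are set per process and the default process group has been initialized:

# each rank computes its own scalar metrics (placeholder values here)
local_loss = 0.523
local_acc = 0.871

# every rank reaches this call; the returned values are identical on all ranks
avg_loss, avg_acc = global_meters_all_avg(args, local_loss, local_acc)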


Thank you very much. I suppose meter is the metric quantity?

Yes, such as loss and accuracy.

I see. Could you also share a minimal example of where global_meters_all_avg() should be placed inside the training loop? Let's say the following is my loop (part of the launcher train() function called by mp.spawn()):

for batch in pbar:
    # load image and mask into device memory
    image = batch['image'].cuda(rank, non_blocking=True)
    mask = batch['mask'].cuda(rank, non_blocking=True)

    # pass images into model
    pred = model(image)

    # get loss
    loss = criteria(pred, mask)

    # update the model
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

I usually call that function after each epoch (not inside the batch loop), to avoid slowing down training. During the batch iterations I accumulate the loss values, and after each epoch I call global_meters_all_avg with the accumulated loss value as input.

Note that for accumulating the loss in each rank, I use this class:

class AvgMeter:
    """Keeps a running sum, count, and average of a scalar metric on one rank."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.sum = 0
        self.count = 0
        self.avg = 0

    def update(self, val, n=1):
        # `val` is the batch-averaged value, `n` the number of samples in the batch
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

In your code, you can use this class something like this:

losses = AvgMeter()
for batch in pbar:
    # load image and mask into device memory
    image = batch['image'].cuda(rank, non_blocking=True)
    mask = batch['mask'].cuda(rank, non_blocking=True)

    # pass images into model
    pred = model(image)

    # get loss
    loss = criteria(pred, mask)
    losses.update(loss.item(), image.size(0))  # accumulate the loss

    # update the model
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

# after each epoch
loss = losses.avg
global_loss = global_meters_all_avg(args, loss) 

The global_loss is the value all-reduced (averaged) across the ranks. Note that global_meters_all_avg returns a list, so global_loss here is a one-element list containing the averaged loss.


Thanks so much. Really appreciate your help here.


One more doubt: when logging or printing global_loss, I think it should be printed from only one rank to prevent duplicate entries. Something like if rank == 0: print(global_loss).

It depends on your preference :slightly_smiling_face:. After the all-reduce, the value is identical on every rank, so printing from rank 0 only is enough to avoid duplicate log lines.
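For reference, a minimal sketch of that pattern, assuming rank is the process index passed by mp.spawn to the training function:

# global_meters_all_avg returns a list, so take the first element
global_loss = global_meters_all_avg(args, losses.avg)[0]

# log from the master process only, to avoid one line per GPU
if rank == 0:
    print(f'global loss: {global_loss:.4f}')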