Wrapping with DDP changes the weights in Half Precision

Wrapping a model with DDP when the model’s weights are half-precision seems to change the weights. I noticed this issue when I was loading checkpoints and I’ve boiled it down to this minimal example.

In the code below, I instantiate a simple model with a single nn.Linear layer and then wrap it with the DDP constructor. This shouldn’t change the weights, right? I check the norm before and after. When the model is full-precision, the norm of the weights changes a little, but not significantly. When the weights are half-precision, the norm of the weights change massively.

Any idea what’s going on?
P.S: I just checked in PyTorch 2.0 and the issue still exists.

import os
import sys
import tempfile
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
import torch.multiprocessing as mp
import numpy as np

from torch.nn.parallel import DistributedDataParallel as DDP

import random
import  torch.nn as nn
import torch.nn.functional as F



def set_seed(s, reproducible=False):
    "Set random seed for `random`, `torch`, and `numpy` (where available)"
    torch.manual_seed(s)
    torch.cuda.manual_seed_all(s)
    np.random.seed(s%(2**32-1))
    random.seed(s)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

def setup(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'

    # initialize the process group
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

def cleanup():
    dist.destroy_process_group()

def get_model(half_prec=True):
    set_seed(42)
    dtype = torch.float16 if half_prec else torch.float32
    simple_lin_model = nn.Linear(512, 512, bias=False, dtype=dtype)

    return simple_lin_model


def demo_basic(rank, world_size, half_prec=False):
    setup(rank, world_size)

    simple_lin_model = get_model(half_prec)

    print("\nBefore wrapping with DDP, here are the weights:")
    print("Rank: {} | simple_lin_model.weight norm: {}".format(rank, torch.norm(simple_lin_model.weight.data)))

    print("\nLet's try compiling the model, here are the weights:")
    simple_lin_model = torch.compile(simple_lin_model)
    print("Rank: {} | simple_lin_model.weight norm: {}".format(rank, torch.norm(simple_lin_model.weight.data)))

    simple_lin_model = DDP(simple_lin_model.to(rank), device_ids=[rank])

    print("\nAfter wrapping with DDP, here are the weights:")
    print("Rank: {} | simple_lin_model.weight norm: {}".format(rank, torch.norm(simple_lin_model.module.weight.data)))

    cleanup()


def run_demo(demo_fn, world_size, half_prec=False):
    mp.spawn(demo_fn,
             args=(world_size, half_prec),
             nprocs=world_size,
             join=True)


if __name__ == "__main__":
    print("PyTorch version: {}".format(torch.__version__))
    print("Experiment with 1 GPU Full-prec")
    run_demo(demo_basic, 1, half_prec=False)
    print("--------------------------------------------")

    print("\n\nExperiment with 2 GPUs Full-prec")
    run_demo(demo_basic, 2, half_prec=False)
    print("--------------------------------------------")

    print("\n\nExperiment with 1 GPU half-prec")
    run_demo(demo_basic, 1, half_prec=True)
    print("--------------------------------------------")

    print("\n\nExperiment with 2 GPUs half-prec")
    run_demo(demo_basic, 2, half_prec=True)
    print("--------------------------------------------")

    print("why pytorch. why.")

The output of running the above script when using half-prec looks like this

Experiment with 1 GPU half-prec

Before wrapping with DDP, here are the weights:
Rank: 0 | simple_lin_model.weight norm: 5.65625

Let's try compiling the model, here are the weights:
Rank: 0 | simple_lin_model.weight norm: 5.65625

After wrapping with DDP, here are the weights:
Rank: 0 | simple_lin_model.weight norm: 13.0625

The issue is not caused by DDP, but by using float16 on the CPU, which might not use a wider accumulation dtype for norm. Here is a small example:

lin = nn.Linear(512, 512, bias=False, dtype=torch.float16)
lin.weight.norm()
# tensor(5.6562, dtype=torch.float16, grad_fn=<LinalgVectorNormBackward0>)

w = lin.weight.clone()
w_cuda = w.cuda()
w_cuda.norm()
# tensor(13.0547, device='cuda:0', dtype=torch.float16,
#        grad_fn=<LinalgVectorNormBackward0>)
w.float().norm()
# tensor(13.0520, grad_fn=<LinalgVectorNormBackward0>)
1 Like