Distributed model saves weights that differ from the current state dict?

I’m using torch ‘1.13.1+cu117’. If I run multi-GPU torch code with

NUM_GPU=2
CUDA_VISIBLE_DEVICES=0,1 python -W ignore -m torch.distributed.launch --nproc_per_node=$NUM_GPU .... 

and if I do

import torch.distributed as dist
local_rank = dist.get_rank()
if local_rank == 0:
    torch.save(model.state_dict(), "temp.pth")
    for name, param in model.state_dict().items():
        print(name, param.mean().item())

I get values like

mlp.fc1_g.weight 2.249192903036601e-07
mlp.fc1_g.bias -0.00013029106776230037
mlp.fc1_x.weight -1.5291587942556362e-06
mlp.fc1_x.bias 0.00014961455599404871
mlp.norm.weight 1.0
mlp.norm.bias 0.0
mlp.fc2.weight -8.485401536972859e-08
mlp.fc2.bias 2.6364039513282478e-05

Now, when I load the same weights on a single GPU

checkpoint = torch.load("temp.pth", map_location='cpu')
model.load_state_dict(checkpoint, strict=True)
for name, param in model.state_dict().items():
    print(name, param.mean().item())

I get

mlp.fc1_g.weight 2.249192903036601e-07
mlp.fc1_g.bias -0.00013029106776230037
mlp.fc1_x.weight -1.5291587942556362e-06
mlp.fc1_x.bias 0.00014961455599404871
mlp.norm.weight 1.0
mlp.norm.bias 0.0
mlp.fc2.weight -8.485401536972859e-08
mlp.fc2.bias 2.6364039513282478e-05

There is a difference in the later decimal places of weights such as mlp.fc2.weight, mlp.fc1_x.weight, and mlp.fc1_g.weight.

As a result, the validation accuracy seen during training and the accuracy when later testing the same saved model differ by up to 2-3%.

What am I doing wrong? Note: I have only posted a few weights of the entire transformer, but the weights differ across all layers.
After training the model for 1 epoch, I compared the values across all weights here
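For reference, a comparison along these lines can be done with something like the following sketch; the filenames "multi_gpu.pth" and "single_gpu.pth" are placeholders for the two checkpoints being compared:

import torch

ckpt_a = torch.load("multi_gpu.pth", map_location="cpu")   # placeholder filename
ckpt_b = torch.load("single_gpu.pth", map_location="cpu")  # placeholder filename

assert ckpt_a.keys() == ckpt_b.keys()
for name in ckpt_a:
    # Cast to float so buffers like num_batches_tracked compare cleanly
    a, b = ckpt_a[name].float(), ckpt_b[name].float()
    if not torch.allclose(a, b):
        print(f"{name}: max abs diff {(a - b).abs().max().item()}")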

This usually happens due to model wrapping (e.g., DistributedDataParallel); use model.module.state_dict() when saving or loading weights.
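For example (a sketch only; save_checkpoint is a hypothetical helper, and ddp_model stands for the DistributedDataParallel-wrapped model from your training script):

import torch
import torch.distributed as dist

def save_checkpoint(ddp_model, path="temp.pth"):
    # Save only from rank 0; ddp_model.module is the underlying (unwrapped)
    # model, so the saved keys carry no "module." prefix.
    if dist.get_rank() == 0:
        torch.save(ddp_model.module.state_dict(), path)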

Turns out I didn't wrap the model to begin with :sweat_smile:

doing

model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank], find_unused_parameters=True)

Fixed the problem!
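For anyone else running into this, the setup now looks roughly like the sketch below (a minimal sketch, assuming the launcher exports the LOCAL_RANK environment variable; the Linear layer is just a placeholder for the real model):

import os
import torch
import torch.distributed as dist

# Minimal sketch, not the full training script. Assumes the launcher sets
# LOCAL_RANK in the environment (torchrun / recent torch.distributed.launch).
local_rank = int(os.environ["LOCAL_RANK"])
torch.cuda.set_device(local_rank)
dist.init_process_group(backend="nccl")

model = torch.nn.Linear(10, 10).cuda(local_rank)  # placeholder for the real model
model = torch.nn.parallel.DistributedDataParallel(
    model, device_ids=[local_rank], find_unused_parameters=True
)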

Apparently the weights don't stay synchronized across GPUs unless the model is wrapped in DistributedDataParallel (torch.cuda.synchronize() alone doesn't do it), so the GPUs ended up with different weights, and whichever rank saved the model (during validation) determined the weights that were later used for testing.

Yes, in distributed training the saved weights may differ in format (e.g., keys prefixed with module.) when using DataParallel or DistributedDataParallel, which requires adjusting the state_dict keys when loading.
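If you already have a checkpoint that was saved from the wrapped model, a minimal sketch of that key adjustment could look like this (the Linear layer stands in for the real model):

import torch

model = torch.nn.Linear(10, 10)  # placeholder for the real model
checkpoint = torch.load("temp.pth", map_location="cpu")

# Drop a leading "module." from every key, if present
cleaned = {
    (k[len("module."):] if k.startswith("module.") else k): v
    for k, v in checkpoint.items()
}
model.load_state_dict(cleaned, strict=True)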