ZeRO optimizer.consolidate_state_dict(to=0) hangs

This is similar to the earlier issue "ZeRO optimizer.consolidate_state_dict() will hang".

However, when I call optimizer.consolidate_state_dict(to=0), the code appears to hang or freeze.

    def after_epoch(self):
        if isinstance(self.optimizer, torch.distributed.optim.ZeroRedundancyOptimizer):
            comm.synchronize()
            print(
                f"Consolidating Optimizer State Dict ... on {comm.get_rank()}/{comm.get_world_size()}"
            )
            self.optimizer.consolidate_state_dict(to=0)
            print(f"Optimizer Synced from {comm.get_rank()}/{comm.get_world_size()}")

log prints:

[2025-08-29 15:40:34,612 INFO misc.py line 184 174457] Train: [1/16][1/748] Data 1.046 (1.046) Batch 17.165 (17.165) Remain 57:03:38 loss: 2.4071 

[2025-08-29 15:40:35,832 INFO misc.py line 184 174457] Train: [1/16][2/748] Data 0.005 (0.005) Batch 1.220 (1.220) Remain 04:03:16 loss: 3.1259 

[2025-08-29 15:40:37,822 INFO misc.py line 184 174457] Train: [1/16][3/748] Data 0.003 (0.003) Batch 1.990 (1.990) Remain 06:36:55 loss: 3.3515 

[2025-08-29 15:40:53,812 INFO misc.py line 184 174457] Train: [1/16][4/748] Data 0.003 (0.003) Batch 15.990 (15.990) Remain 53:08:22 loss: 2.7074 

[2025-08-29 15:41:14,647 INFO misc.py line 184 174457] Train: [1/16][5/748] Data 0.002 (0.003) Batch 20.835 (18.413) Remain 61:11:09 loss: 2.8574 

[2025-08-29 15:41:19,463 INFO misc.py line 184 174457] Train: [1/16][6/748] Data 0.003 (0.003) Batch 4.815 (13.880) Remain 46:07:13 loss: 4.1689 

[2025-08-29 15:41:20,957 INFO misc.py line 184 174457] Train: [1/16][7/748] Data 0.004 (0.003) Batch 1.494 (10.784) Remain 35:49:42 loss: 3.1682 

[2025-08-29 15:41:25,497 INFO misc.py line 184 174457] Train: [1/16][8/748] Data 3.785 (0.759) Batch 4.540 (9.535) Remain 31:40:37 loss: 2.1350 

[2025-08-29 15:41:30,370 INFO misc.py line 184 174457] Train: [1/16][9/748] Data 2.112 (0.985) Batch 4.873 (8.758) Remain 29:05:36 loss: 2.2668 

[2025-08-29 15:41:49,157 INFO misc.py line 184 174457] Train: [1/16][10/748] Data 8.303 (2.030) Batch 18.787 (10.191) Remain 33:50:58 loss: 1.7625 

Consolidating Optimizer State Dict ... on 3/4
Consolidating Optimizer State Dict ... on 0/4
Consolidating Optimizer State Dict ... on 2/4
Consolidating Optimizer State Dict ... on 1/4


I have already consulted similar issues and introductory documentation such as the tutorial "Shard Optimizer States with ZeroRedundancyOptimizer" (PyTorch Tutorials 2.8.0+cu128 documentation).

Is there a recommended way to fix this, or to work around it?

Without a repro I'd be guessing at best. I just tried a minimal repro with consolidate_state_dict and it seems to be working fine. What is comm.synchronize? Do you notice anything odd if you print the local optimizers? Another guess would be a different number of steps per rank.
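For reference, a minimal repro along those lines might look like the sketch below. It is not the code from this training setup: it uses a single-process gloo group purely so the collectives inside consolidate_state_dict() can complete, and demo() is a hypothetical name.

```python
import tempfile

import torch
import torch.distributed as dist
from torch.distributed.optim import ZeroRedundancyOptimizer


def demo():
    # Single-process "world": the collectives inside
    # consolidate_state_dict() complete immediately here, so a hang
    # in a real run usually means not all ranks reached the call.
    init_file = tempfile.NamedTemporaryFile(delete=False).name
    dist.init_process_group(
        backend="gloo",
        init_method=f"file://{init_file}",
        rank=0,
        world_size=1,
    )
    model = torch.nn.Linear(8, 8)
    opt = ZeroRedundancyOptimizer(
        model.parameters(),
        optimizer_class=torch.optim.AdamW,
        lr=1e-3,
    )
    model(torch.randn(2, 8)).sum().backward()
    opt.step()
    # This is a collective: every rank in the group must call it.
    opt.consolidate_state_dict(to=0)
    state = opt.state_dict()  # full state, valid only on rank 0
    dist.destroy_process_group()
    return state
```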

Original:

And being used currently in a fork of

I don't think there is a different number of steps per rank, but I can find a way to have every rank report its logs just to be sure. I'll also inspect the optimizer.
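A rank-tagged logger along these lines could do it (a hypothetical sketch; build_rank_logger is an illustrative name, and the format mirrors the "rank: r/W" tag in the logs below):

```python
import logging
import sys


def build_rank_logger(rank: int, world_size: int) -> logging.Logger:
    # Give every rank its own handler, tagged "rank: r/W", so
    # per-rank step counts can be compared line by line.
    logger = logging.getLogger(f"train.rank{rank}")
    if not logger.handlers:  # avoid duplicate handlers on re-init
        handler = logging.StreamHandler(sys.stdout)
        handler.setFormatter(logging.Formatter(
            f"[%(asctime)s rank: {rank}/{world_size} %(levelname)s] %(message)s"
        ))
        logger.addHandler(handler)
    logger.setLevel(logging.INFO)
    return logger
```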

I need some time to spin up a spare node. Is there anything else you'd recommend looking at as well?

Hi Mark, I ran a 2-GPU, 1-node test to grab the logs below. The sync does happen, but it is quite slow. Is this to be expected? For our larger trainings the process hung for much longer.

As far as I understand, I should trigger consolidate_state_dict(to=0) on all ranks. Is there anything I can do to speed it up?
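Concretely, the pattern I have in mind is the sketch below (save_zero_checkpoint is just an illustrative name, not from this codebase): every rank calls the collective, the call is timed, and only rank 0 saves.

```python
import time

import torch
import torch.distributed as dist


def save_zero_checkpoint(optimizer, path, rank):
    # consolidate_state_dict() is a collective, so every rank must
    # call it, even though only rank 0 ends up saving to disk.
    t0 = time.monotonic()
    optimizer.consolidate_state_dict(to=0)
    if rank == 0:
        torch.save(optimizer.state_dict(), path)
    dist.barrier()  # keep non-zero ranks from racing into the next epoch
    print(f"rank {rank}: consolidation + save took {time.monotonic() - t0:.1f}s")
```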

[2025-08-31 08:49:12,732 rank: 0/2 WARNING train.py line 142] Trainer will only use first 10 samples from the loader!
[2025-08-31 08:49:12,733 rank: 0/2 INFO train.py line 145] => Loading config ...
[2025-08-31 08:49:12,733 rank: 0/2 INFO train.py line 147] Save path: runs/${datapath1}/debug/bs16_xc_${datapoint}_${model_key}_mirror_zero_1Mx6t_long
[2025-08-31 08:49:12,733 rank: 1/2 WARNING train.py line 142] Trainer will only use first 10 samples from the loader!
[2025-08-31 08:49:12,733 rank: 1/2 INFO train.py line 145] => Loading config ...
[2025-08-31 08:49:12,733 rank: 1/2 INFO train.py line 147] Save path: runs/${datapath1}/debug/bs16_xc_${datapoint}_${model_key}_mirror_zero_1Mx6t_long
[2025-08-31 08:49:13,288 rank: 0/2 INFO train.py line 148] Config:
weight = None
resume = False
evaluate = True
test_only = False
seed = 52
save_path = 'runs/${datapath1}/debug/bs16_xc_${datapoint}_${model_key}_mirror_zero_1Mx6t_long'
num_worker = 2
batch_size = 2
batch_size_val = None
batch_size_test = None
epoch = 43
eval_epoch = 43
clip_grad = None
sync_bn = False
enable_amp = True
amp_dtype = 'float16'
empty_cache = True
empty_cache_per_epoch = False
find_unused_parameters = False
enable_wandb = False
wandb_project = 'pointcept'
wandb_key = None
mix_prob = 0
param_dicts = [
    dict(keyword='encoder', lr=0),
    dict(keyword='projector', lr=0),
    dict(keyword='decoder', lr=0)
]
optimizer = dict(
    type='ZeroROTorch',
    lr=0.0,
    weight_decay=0.05,
    optimizer_class='AdamW',
    betas=(0.9, 0.99))
scheduler = dict(
    type='OneCycleLR',
    max_lr=[0.0001, 0.0001, 0.0001, 0.0001],
    pct_start=0.03,
    anneal_strategy='cos',
    div_factor=100.0,
    final_div_factor=1000.0)
*********
deleted
*********
max_sample = 10
num_worker_per_gpu = 1
batch_size_per_gpu = 1
batch_size_val_per_gpu = 1
batch_size_test_per_gpu = 1

[2025-08-31 08:49:13,289 rank: 0/2 INFO train.py line 149] => Building model ...
[2025-08-31 08:49:13,299 rank: 1/2 INFO train.py line 148] Config:
weight = None
resume = False
evaluate = True
test_only = False
seed = 52
save_path = 'runs/${datapath1}/debug/bs16_xc_${datapoint}_${model_key}_mirror_zero_1Mx6t_long'
num_worker = 2
batch_size = 2
batch_size_val = None
batch_size_test = None
epoch = 43
eval_epoch = 43
clip_grad = None
sync_bn = False
enable_amp = True
amp_dtype = 'float16'
empty_cache = True
empty_cache_per_epoch = False
find_unused_parameters = False
enable_wandb = False
wandb_project = 'pointcept'
wandb_key = None
mix_prob = 0
param_dicts = [
    dict(keyword='encoder', lr=0),
    dict(keyword='projector', lr=0),
    dict(keyword='decoder', lr=0)
]
optimizer = dict(
    type='ZeroROTorch',
    lr=0.0,
    weight_decay=0.05,
    optimizer_class='AdamW',
    betas=(0.9, 0.99))
scheduler = dict(
    type='OneCycleLR',
    max_lr=[0.0001, 0.0001, 0.0001, 0.0001],
    pct_start=0.03,
    anneal_strategy='cos',
    div_factor=100.0,
    final_div_factor=1000.0)
*********
deleted
*********
max_sample = 10
num_worker_per_gpu = 1
batch_size_per_gpu = 1
batch_size_val_per_gpu = 1
batch_size_test_per_gpu = 1

[2025-08-31 08:49:13,299 rank: 1/2 INFO train.py line 149] => Building model ...
[2025-08-31 08:49:38,124 rank: 0/2 INFO trainers.py line 190] Num params: 1275537664
[2025-08-31 08:49:38,279 rank: 1/2 INFO trainers.py line 190] Num params: 1275537664
[2025-08-31 08:49:40,550 rank: 0/2 INFO train.py line 151] => Building writer ...
[2025-08-31 08:49:40,550 rank: 1/2 INFO train.py line 151] => Building writer ...
[2025-08-31 08:49:40,550 rank: 1/2 INFO train.py line 280] Tensorboard writer logging dir: runs/${datapath1}/debug/bs16_xc_${datapoint}_${model_key}_mirror_zero_1Mx6t_long
[2025-08-31 08:49:40,551 rank: 1/2 INFO train.py line 153] => Building train dataset & dataloader ...
[2025-08-31 08:49:40,552 rank: 0/2 INFO train.py line 280] Tensorboard writer logging dir: runs/${datapath1}/debug/bs16_xc_${datapoint}_${model_key}_mirror_zero_1Mx6t_long
[2025-08-31 08:49:40,552 rank: 0/2 INFO train.py line 153] => Building train dataset & dataloader ...
[2025-08-31 08:49:40,557 rank: 1/2 INFO defaults.py line 70] Totally 2995 x 1 samples in ${datapath2} train set.
[2025-08-31 08:49:40,559 rank: 0/2 INFO defaults.py line 70] Totally 2995 x 1 samples in ${datapath2} train set.
[2025-08-31 08:49:43,264 rank: 1/2 INFO train.py line 155] => Building val dataset & dataloader ...
[2025-08-31 08:49:43,265 rank: 1/2 INFO defaults.py line 70] Totally 250 x 1 samples in ${datapath2} val set.
[2025-08-31 08:49:43,318 rank: 0/2 INFO train.py line 155] => Building val dataset & dataloader ...
[2025-08-31 08:49:43,319 rank: 0/2 INFO defaults.py line 70] Totally 250 x 1 samples in ${datapath2} val set.
[2025-08-31 08:49:45,862 rank: 1/2 INFO train.py line 157] => Building optimize, scheduler, scaler(amp) ...
[2025-08-31 08:49:45,865 rank: 1/2 INFO optimizer.py line 68] Params Group 1 - lr: 0.0; Params: [].
[2025-08-31 08:49:45,865 rank: 1/2 INFO optimizer.py line 68] Params Group 2 - lr: 0; Params: ['module.encoder'] ****deleted****
[2025-08-31 08:49:45,865 rank: 1/2 INFO optimizer.py line 68] Params Group 3 - lr: 0; Params: ['module.projector'] ****deleted****
[2025-08-31 08:49:45,865 rank: 1/2 INFO optimizer.py line 68] Params Group 4 - lr: 0; Params: ['module.decoder'] ****deleted****
[2025-08-31 08:49:45,868 rank: 1/2 INFO train.py line 418] Total steps: 9124 * 43 = 392332
[2025-08-31 08:49:45,869 rank: 1/2 INFO train.py line 161] => Building hooks ...
[2025-08-31 08:49:45,869 rank: 1/2 INFO misc.py line 390] => Loading checkpoint & weight ...
[2025-08-31 08:49:45,869 rank: 1/2 INFO misc.py line 435] No weight found at: None
[2025-08-31 08:49:45,869 rank: 1/2 INFO train.py line 168] >>>>>>>>>>>>>>>> Start Training >>>>>>>>>>>>>>>>
[2025-08-31 08:49:45,947 rank: 0/2 INFO train.py line 157] => Building optimize, scheduler, scaler(amp) ...
[2025-08-31 08:49:45,949 rank: 0/2 INFO optimizer.py line 68] Params Group 1 - lr: 0.0; Params: [].
[2025-08-31 08:49:45,949 rank: 0/2 INFO optimizer.py line 68] Params Group 2 - lr: 0; Params: ['module.encoder'] ****deleted****
[2025-08-31 08:49:45,950 rank: 0/2 INFO optimizer.py line 68] Params Group 3 - lr: 0; Params: ['module.projector'] ****deleted****
[2025-08-31 08:49:45,950 rank: 0/2 INFO optimizer.py line 68] Params Group 4 - lr: 0; Params: ['module.decoder'] ****deleted****
[2025-08-31 08:49:45,952 rank: 0/2 INFO train.py line 418] Total steps: 9124 * 43 = 392332
[2025-08-31 08:49:45,953 rank: 0/2 INFO train.py line 161] => Building hooks ...
[2025-08-31 08:49:45,953 rank: 0/2 INFO misc.py line 390] => Loading checkpoint & weight ...
[2025-08-31 08:49:45,953 rank: 0/2 INFO misc.py line 435] No weight found at: None
[2025-08-31 08:49:45,953 rank: 0/2 INFO train.py line 168] >>>>>>>>>>>>>>>> Start Training >>>>>>>>>>>>>>>>
[2025-08-31 08:50:08,431 rank: 0/2 INFO misc.py line 184] Train: [1/43][1/9124] Data 1.220 (1.220) Batch 4.239 (4.239) Remain 461:58:14 loss: 2.1847 
	
[2025-08-31 08:50:08,575 rank: 1/2 INFO misc.py line 184] Train: [1/43][1/9124] Data 1.860 (1.860) Batch 4.727 (4.727) Remain 515:07:28 loss: 3.8498 
	
[2025-08-31 08:50:09,162 rank: 1/2 INFO misc.py line 184] Train: [1/43][2/9124] Data 0.001 (0.001) Batch 0.588 (0.588) Remain 64:01:46 loss: 2.9965 
	
[2025-08-31 08:50:09,164 rank: 0/2 INFO misc.py line 184] Train: [1/43][2/9124] Data 0.004 (0.004) Batch 0.733 (0.733) Remain 79:49:45 loss: 3.1200 
	
[2025-08-31 08:50:09,686 rank: 0/2 INFO misc.py line 184] Train: [1/43][3/9124] Data 0.003 (0.003) Batch 0.523 (0.523) Remain 56:58:29 loss: 2.1029 
	
[2025-08-31 08:50:09,701 rank: 1/2 INFO misc.py line 184] Train: [1/43][3/9124] Data 0.001 (0.001) Batch 0.539 (0.539) Remain 58:41:48 loss: 4.2658 
	
[2025-08-31 08:50:10,300 rank: 1/2 INFO misc.py line 184] Train: [1/43][4/9124] Data 0.001 (0.001) Batch 0.599 (0.599) Remain 65:18:42 loss: 2.6473 
	
[2025-08-31 08:50:10,302 rank: 0/2 INFO misc.py line 184] Train: [1/43][4/9124] Data 0.002 (0.002) Batch 0.615 (0.615) Remain 67:00:42 loss: 3.2209 
	
[2025-08-31 08:50:10,798 rank: 0/2 INFO misc.py line 184] Train: [1/43][5/9124] Data 0.003 (0.003) Batch 0.497 (0.556) Remain 60:34:46 loss: 2.8445 
	
[2025-08-31 08:50:10,804 rank: 1/2 INFO misc.py line 184] Train: [1/43][5/9124] Data 0.001 (0.001) Batch 0.503 (0.551) Remain 60:04:47 loss: 2.8488 
	
[2025-08-31 08:50:12,433 rank: 0/2 INFO misc.py line 184] Train: [1/43][6/9124] Data 0.880 (0.295) Batch 1.634 (0.915) Remain 99:44:40 loss: 2.3001 
	
[2025-08-31 08:50:12,436 rank: 1/2 INFO misc.py line 184] Train: [1/43][6/9124] Data 0.244 (0.082) Batch 1.632 (0.911) Remain 99:19:41 loss: 3.2092 
	
[2025-08-31 08:50:12,938 rank: 1/2 INFO misc.py line 184] Train: [1/43][7/9124] Data 0.002 (0.062) Batch 0.503 (0.809) Remain 88:12:05 loss: 2.4592 
	
[2025-08-31 08:50:12,941 rank: 0/2 INFO misc.py line 184] Train: [1/43][7/9124] Data 0.072 (0.239) Batch 0.505 (0.813) Remain 88:33:54 loss: 2.2741 
	
[2025-08-31 08:50:13,992 rank: 1/2 INFO misc.py line 184] Train: [1/43][8/9124] Data 0.346 (0.119) Batch 1.054 (0.858) Remain 93:31:34 loss: 2.3312 
	
[2025-08-31 08:50:13,994 rank: 0/2 INFO misc.py line 184] Train: [1/43][8/9124] Data 0.433 (0.278) Batch 1.056 (0.861) Remain 93:52:38 loss: 1.6145 
	
[2025-08-31 08:50:17,035 rank: 1/2 INFO misc.py line 184] Train: [1/43][9/9124] Data 0.001 (0.099) Batch 3.043 (1.222) Remain 133:12:53 loss: 2.0611 
	
[2025-08-31 08:50:17,058 rank: 0/2 INFO misc.py line 184] Train: [1/43][9/9124] Data 2.285 (0.613) Batch 3.065 (1.229) Remain 133:53:43 loss: 2.2524 
	
[2025-08-31 08:50:19,111 rank: 1/2 INFO misc.py line 184] Train: [1/43][10/9124] Data 0.001 (0.085) Batch 2.075 (1.344) Remain 146:29:18 loss: 2.2984 
	
[2025-08-31 08:50:19,111 rank: 1/2 INFO train.py line 182] Epoch: 0 max samples:10 reached.
[2025-08-31 08:50:19,130 rank: 0/2 INFO misc.py line 184] Train: [1/43][10/9124] Data 1.347 (0.718) Batch 2.072 (1.349) Remain 147:01:15 loss: 2.2064 
	
[2025-08-31 08:50:19,133 rank: 0/2 INFO train.py line 182] Epoch: 0 max samples:10 reached.
[2025-08-31 08:50:19,134 rank: 0/2 INFO train.py line 250] ZeroRedundancyOptimizer (
Parameter Group 0
    amsgrad: False
    base_momentum: 0.85
    betas: (0.9499999127109967, 0.99)
    capturable: False
    decoupled_weight_decay: True
    differentiable: False
    eps: 1e-08
    foreach: None
    fused: None
    initial_lr: 1e-06
    lr: 1.0000864161131812e-06
    max_lr: 0.0001
    max_momentum: 0.95
    maximize: False
    min_lr: 9.999999999999999e-10
    weight_decay: 0.05

Parameter Group 1
    amsgrad: False
    base_momentum: 0.85
    betas: (0.9499999127109967, 0.99)
    capturable: False
    decoupled_weight_decay: True
    differentiable: False
    eps: 1e-08
    foreach: None
    fused: None
    initial_lr: 1e-06
    lr: 1.0000864161131812e-06
    max_lr: 0.0001
    max_momentum: 0.95
    maximize: False
    min_lr: 9.999999999999999e-10
    weight_decay: 0.05

Parameter Group 2
    amsgrad: False
    base_momentum: 0.85
    betas: (0.9499999127109967, 0.99)
    capturable: False
    decoupled_weight_decay: True
    differentiable: False
    eps: 1e-08
    foreach: None
    fused: None
    initial_lr: 1e-06
    lr: 1.0000864161131812e-06
    max_lr: 0.0001
    max_momentum: 0.95
    maximize: False
    min_lr: 9.999999999999999e-10
    weight_decay: 0.05

Parameter Group 3
    amsgrad: False
    base_momentum: 0.85
    betas: (0.9499999127109967, 0.99)
    capturable: False
    decoupled_weight_decay: True
    differentiable: False
    eps: 1e-08
    foreach: None
    fused: None
    initial_lr: 1e-06
    lr: 1.0000864161131812e-06
    max_lr: 0.0001
    max_momentum: 0.95
    maximize: False
    min_lr: 9.999999999999999e-10
    weight_decay: 0.05
)
[2025-08-31 08:50:19,134 rank: 1/2 INFO train.py line 250] ZeroRedundancyOptimizer (
Parameter Group 0
    amsgrad: False
    base_momentum: 0.85
    betas: (0.9499999127109967, 0.99)
    capturable: False
    decoupled_weight_decay: True
    differentiable: False
    eps: 1e-08
    foreach: None
    fused: None
    initial_lr: 1e-06
    lr: 1.0000864161131812e-06
    max_lr: 0.0001
    max_momentum: 0.95
    maximize: False
    min_lr: 9.999999999999999e-10
    weight_decay: 0.05

Parameter Group 1
    amsgrad: False
    base_momentum: 0.85
    betas: (0.9499999127109967, 0.99)
    capturable: False
    decoupled_weight_decay: True
    differentiable: False
    eps: 1e-08
    foreach: None
    fused: None
    initial_lr: 1e-06
    lr: 1.0000864161131812e-06
    max_lr: 0.0001
    max_momentum: 0.95
    maximize: False
    min_lr: 9.999999999999999e-10
    weight_decay: 0.05

Parameter Group 2
    amsgrad: False
    base_momentum: 0.85
    betas: (0.9499999127109967, 0.99)
    capturable: False
    decoupled_weight_decay: True
    differentiable: False
    eps: 1e-08
    foreach: None
    fused: None
    initial_lr: 1e-06
    lr: 1.0000864161131812e-06
    max_lr: 0.0001
    max_momentum: 0.95
    maximize: False
    min_lr: 9.999999999999999e-10
    weight_decay: 0.05

Parameter Group 3
    amsgrad: False
    base_momentum: 0.85
    betas: (0.9499999127109967, 0.99)
    capturable: False
    decoupled_weight_decay: True
    differentiable: False
    eps: 1e-08
    foreach: None
    fused: None
    initial_lr: 1e-06
    lr: 1.0000864161131812e-06
    max_lr: 0.0001
    max_momentum: 0.95
    maximize: False
    min_lr: 9.999999999999999e-10
    weight_decay: 0.05
)
[2025-08-31 08:50:19,134 rank: 0/2 INFO train.py line 251] Consolidating Optimizer State Dict ... on 0/2
[2025-08-31 08:50:19,134 rank: 1/2 INFO train.py line 251] Consolidating Optimizer State Dict ... on 1/2
[2025-08-31 08:58:51,023 rank: 1/2 INFO train.py line 255] Optimizer Synced from 1/2
[2025-08-31 08:58:51,023 rank: 1/2 INFO misc.py line 267] Train result: loss: 2.8967 
[2025-08-31 08:58:51,023 rank: 1/2 INFO ***_eval.py line 110] >>>>>>>>>>>>>>>> Start Evaluation >>>>>>>>>>>>>>>>
[2025-08-31 08:59:04,805 rank: 0/2 INFO train.py line 255] Optimizer Synced from 0/2
[2025-08-31 08:59:04,807 rank: 0/2 INFO misc.py line 267] Train result: loss: 2.4121 
[2025-08-31 08:59:04,808 rank: 0/2 INFO ***_eval.py line 110] >>>>>>>>>>>>>>>> Start Evaluation >>>>>>>>>>>>>>>>
[2025-08-31 08:59:09,572 rank: 1/2 INFO ***_eval.py line 124] Test: [1/10] Loss 2.2650 
[2025-08-31 08:59:09,880 rank: 1/2 INFO ***_eval.py line 124] Test: [2/10] Loss 1.2580 
[2025-08-31 08:59:10,045 rank: 1/2 INFO ***_eval.py line 124] Test: [3/10] Loss 1.6889 
[2025-08-31 08:59:10,424 rank: 1/2 INFO ***_eval.py line 124] Test: [4/10] Loss 1.5754 
[2025-08-31 08:59:10,610 rank: 1/2 INFO ***_eval.py line 124] Test: [5/10] Loss 1.2158 
[2025-08-31 08:59:10,773 rank: 1/2 INFO ***_eval.py line 124] Test: [6/10] Loss 2.0492 
[2025-08-31 08:59:10,932 rank: 1/2 INFO ***_eval.py line 124] Test: [7/10] Loss 2.2129 
[2025-08-31 08:59:11,551 rank: 1/2 INFO ***_eval.py line 124] Test: [8/10] Loss 1.2615 
[2025-08-31 08:59:11,756 rank: 1/2 INFO ***_eval.py line 124] Test: [9/10] Loss 2.5478 
[2025-08-31 08:59:11,911 rank: 1/2 INFO ***_eval.py line 124] Test: [10/10] Loss 1.8812 
[2025-08-31 08:59:13,484 rank: 1/2 INFO ***_eval.py line 67] Eval result: ****deleted****