Multi-GPU training does not converge

Hello,

I am training a ResNeXt model on CIFAR-100. The loss has two parts: a supervised loss and a consistency-regularization loss. I wrap the model with DataParallel. Training on a single GPU works well; however, training does not converge when using multiple GPUs. Here is the code that computes the loss:

    import torch
    import torch.nn.functional as F

    # Supervised loss on the clean (non-augmented) forward pass
    logits_clean = model(input, aug=False)
    loss = F.cross_entropy(logits_clean, target)

    # Two augmented forward passes for the consistency term
    logits_aug1 = model(input, aug=True)
    logits_aug2 = model(input, aug=True)

    p_clean = F.softmax(logits_clean, dim=1)
    p_aug1 = F.softmax(logits_aug1, dim=1)
    p_aug2 = F.softmax(logits_aug2, dim=1)

    # Jensen-Shannon-style consistency between the three predictive distributions
    p_mixture = torch.clamp((p_clean + p_aug1 + p_aug2) / 3., 1e-7, 1).log()
    consist_loss = (F.kl_div(p_mixture, p_clean, reduction='batchmean') +
                    F.kl_div(p_mixture, p_aug1, reduction='batchmean') +
                    F.kl_div(p_mixture, p_aug2, reduction='batchmean')) / 3.

    loss += config.consist_wt * consist_loss

In the above code, when aug=True the forward pass applies a feature augmentation that exchanges some information between the samples in the batch. To make that concrete, here is a minimal sketch of this kind of batch-level feature mixing (illustrative only; the BatchFeatureMix name and the mixing rule are made up for this post, my actual augmentation differs):
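    import torch
    import torch.nn as nn

    class BatchFeatureMix(nn.Module):
        # Illustrative stand-in: mix each sample's intermediate features with
        # those of another randomly chosen sample from the *same batch*, so each
        # sample's output depends on which samples happen to share its batch.
        def forward(self, feats):
            perm = torch.randperm(feats.size(0), device=feats.device)
            lam = torch.rand(1, device=feats.device)        # random mixing weight
            return lam * feats + (1.0 - lam) * feats[perm]  # cross-sample mixing

Note that under DataParallel the forward pass runs independently on each replica, so mixing like this only ever sees that GPU's sub-batch rather than the full batch.

For single-GPU training, the log for the first few iterations is: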

Supervised Loss 5.8061 (5.8061) Consistency Loss 0.0218 (0.0218)        Loss 6.0676 (6.0676)
Supervised Loss 15.2160 (10.6238)       Consistency Loss 0.0039 (0.0234)        Loss 15.2623 (10.9045)
Supervised Loss 7.7359 (9.3208) Consistency Loss 0.0279 (0.0194)        Loss 8.0703 (9.5532)
Supervised Loss 5.1153 (8.2851) Consistency Loss 0.0031 (0.0159)        Loss 5.1523 (8.4765)
Supervised Loss 4.7728 (7.5176) Consistency Loss 0.0093 (0.0137)        Loss 4.8840 (7.6816)
Supervised Loss 7.2397 (7.0988) Consistency Loss 0.0024 (0.0119)        Loss 7.2690 (7.2417)
Supervised Loss 4.5411 (6.7195) Consistency Loss 0.0022 (0.0106)        Loss 4.5676 (6.8472)
Supervised Loss 4.5578 (6.4211) Consistency Loss 0.0034 (0.0096)        Loss 4.5989 (6.5362)
Supervised Loss 4.5433 (6.1923) Consistency Loss 0.0021 (0.0087)        Loss 4.5685 (6.2971)
Supervised Loss 4.5238 (6.0134) Consistency Loss 0.0022 (0.0080)        Loss 4.5497 (6.1093)
Supervised Loss 4.4789 (5.8675) Consistency Loss 0.0055 (0.0074)        Loss 4.5453 (5.9567)
Supervised Loss 4.5648 (5.7460) Consistency Loss 0.0015 (0.0070)        Loss 4.5828 (5.8299)
Supervised Loss 4.4963 (5.6442) Consistency Loss 0.0025 (0.0066)        Loss 4.5262 (5.7232)
Supervised Loss 4.4717 (5.5555) Consistency Loss 0.0022 (0.0062)        Loss 4.4985 (5.6303)
Supervised Loss 4.4500 (5.4778) Consistency Loss 0.0023 (0.0060)        Loss 4.4775 (5.5498)
Supervised Loss 4.4184 (5.4073) Consistency Loss 0.0039 (0.0058)        Loss 4.4650 (5.4770)
Supervised Loss 4.3696 (5.3446) Consistency Loss 0.0041 (0.0057)        Loss 4.4183 (5.4130)
Supervised Loss 4.4169 (5.2873) Consistency Loss 0.0061 (0.0056)        Loss 4.4900 (5.3547)
Supervised Loss 4.3870 (5.2368) Consistency Loss 0.0041 (0.0055)        Loss 4.4365 (5.3029)
Supervised Loss 4.4100 (5.1898) Consistency Loss 0.0042 (0.0055)        Loss 4.4606 (5.2556)
Supervised Loss 4.3863 (5.1467) Consistency Loss 0.0032 (0.0054)        Loss 4.4247 (5.2121)
Supervised Loss 4.2698 (5.1084) Consistency Loss 0.0044 (0.0054)        Loss 4.3226 (5.1731)
Supervised Loss 4.3833 (5.0720) Consistency Loss 0.0072 (0.0054)        Loss 4.4695 (5.1364)
Supervised Loss 4.3285 (5.0383) Consistency Loss 0.0049 (0.0054)        Loss 4.3868 (5.1031)
Supervised Loss 4.4569 (5.0073) Consistency Loss 0.0063 (0.0054)        Loss 4.5322 (5.0723)
Supervised Loss 4.2775 (4.9790) Consistency Loss 0.0072 (0.0054)        Loss 4.3634 (5.0443)
Supervised Loss 4.2218 (4.9525) Consistency Loss 0.0067 (0.0054)        Loss 4.3022 (5.0179)
Supervised Loss 4.2382 (4.9265) Consistency Loss 0.0084 (0.0055)        Loss 4.3385 (4.9923)
Supervised Loss 4.1960 (4.9017) Consistency Loss 0.0056 (0.0056)        Loss 4.2627 (4.9684)
Supervised Loss 4.2202 (4.8789) Consistency Loss 0.0064 (0.0056)        Loss 4.2970 (4.9459)
Supervised Loss 4.2757 (4.8566) Consistency Loss 0.0090 (0.0057)        Loss 4.3840 (4.9244)
Supervised Loss 4.2316 (4.8360) Consistency Loss 0.0102 (0.0057)        Loss 4.3545 (4.9042)
Supervised Loss 4.1364 (4.8154) Consistency Loss 0.0083 (0.0057)        Loss 4.2358 (4.8842)
Supervised Loss 4.1381 (4.7951) Consistency Loss 0.0082 (0.0058)        Loss 4.2369 (4.8650)
Supervised Loss 4.1155 (4.7771) Consistency Loss 0.0066 (0.0059)        Loss 4.1951 (4.8473)

We can see that the supervised loss keeps decreasing and the consistency term stays small, which is normal for the first epoch. However, for multi-GPU training the log for the first few iterations is:

Supervised Loss 4.7842 (4.7842) Consistency Loss 0.0353 (0.0353)        Loss 5.1374 (5.1374)
Supervised Loss 6.0271 (5.3342) Consistency Loss 0.0584 (0.0661)        Loss 6.1633 (5.8588)
Supervised Loss 7.4594 (5.7474) Consistency Loss 0.0405 (0.0593)        Loss 5.9195 (6.3001)
Supervised Loss 5.5253 (5.8246) Consistency Loss 0.0409 (0.0562)        Loss 5.9347 (6.3425)
Supervised Loss 5.3308 (5.7481) Consistency Loss 0.0391 (0.0508)        Loss 5.7216 (6.1887)
Supervised Loss 5.0278 (5.6837) Consistency Loss 0.0311 (0.0488)        Loss 5.3383 (6.1297)
Supervised Loss 5.0018 (5.5859) Consistency Loss 0.0373 (0.0467)        Loss 5.3744 (6.0418)
Supervised Loss 4.8198 (5.4986) Consistency Loss 0.0163 (0.0432)        Loss 4.9823 (5.9141)
Supervised Loss 4.7773 (5.4078) Consistency Loss 0.0132 (0.0412)        Loss 4.8723 (5.8413)
Supervised Loss 4.6783 (5.3305) Consistency Loss 0.0119 (0.0383)        Loss 4.7971 (5.7232)
Supervised Loss 4.6450 (5.2683) Consistency Loss 0.0095 (0.0355)        Loss 4.7403 (5.6284)
Supervised Loss 4.6154 (5.2142) Consistency Loss 0.0099 (0.0338)        Loss 4.7147 (5.5745)
Supervised Loss 4.5999 (5.1677) Consistency Loss 0.0056 (0.0305)        Loss 4.6554 (5.4698)
Supervised Loss 4.6428 (5.1269) Consistency Loss 0.0052 (0.0289)        Loss 4.6832 (5.4191)
Supervised Loss 4.6359 (5.0916) Consistency Loss 0.0036 (0.0267)        Loss 4.6950 (5.3522)
Supervised Loss 4.6198 (5.0608) Consistency Loss 0.0032 (0.0248)        Loss 4.6431 (5.2956)
Supervised Loss 4.6359 (5.0345) Consistency Loss 0.0023 (0.0234)        Loss 4.6584 (5.2554)
Supervised Loss 4.6362 (5.0108) Consistency Loss 0.0020 (0.0216)        Loss 4.6440 (5.2054)
Supervised Loss 4.6214 (4.9892) Consistency Loss 0.0016 (0.0201)        Loss 4.6397 (5.1623)
Supervised Loss 4.6154 (4.9701) Consistency Loss 0.0013 (0.0194)        Loss 4.6426 (5.1409)
Supervised Loss 4.6383 (4.9525) Consistency Loss 0.0019 (0.0185)        Loss 4.6342 (5.1157)
Supervised Loss 4.6218 (4.9369) Consistency Loss 0.0016 (0.0177)        Loss 4.6368 (5.0931)
Supervised Loss 4.6255 (4.9224) Consistency Loss 0.0012 (0.0168)        Loss 4.6376 (5.0686)
Supervised Loss 4.6349 (4.9095) Consistency Loss 0.0015 (0.0160)        Loss 4.6496 (5.0472)
Supervised Loss 4.6260 (4.8975) Consistency Loss 0.0012 (0.0154)        Loss 4.6375 (5.0305)
Supervised Loss 4.6169 (4.8863) Consistency Loss 0.0007 (0.0151)        Loss 4.6483 (5.0211)
Supervised Loss 4.6296 (4.8761) Consistency Loss 0.0011 (0.0145)        Loss 4.6244 (5.0033)
Supervised Loss 4.6121 (4.8666) Consistency Loss 0.0006 (0.0141)        Loss 4.6354 (4.9925)

We can see that the consistency loss goes to zero, while the supervised loss barely decreases and stays flat if training continues. I suspect the problem is related to how the information exchange between samples behaves during the forward pass once DataParallel splits the batch across the GPUs. I can still train on a single GPU with a smaller batch size, but it would be much better if multi-GPU training worked.
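In case it clarifies things, this is the kind of quick check I am planning to run (a sketch only; it assumes model here is the plain, unwrapped module and accepts the aug keyword as in the snippet above, and the outputs will not match exactly because the augmentation is random, so I would only compare rough statistics):

    import torch
    import torch.nn as nn

    # Sanity check sketch: run the augmented forward on the full batch vs.
    # under DataParallel, where each replica only receives a sub-batch.
    model.eval()
    with torch.no_grad():
        out_single = model(input, aug=True)               # whole batch on one GPU
        out_dp = nn.DataParallel(model)(input, aug=True)  # batch split across GPUs
    print('augmented logit std, single GPU  :', out_single.std().item())
    print('augmented logit std, DataParallel:', out_dp.std().item())

If the statistics differ consistently, that would support the batch-splitting explanation. It would be great if anybody could give some hints on how to fix the issue. Thanks in advance!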

I am having the same problem with my code. Did you figure out a solution?