A confusing problem with DDP

I’m running a Mean Teacher setup where the teacher is updated via an EMA of the student’s weights.
The teacher’s state_dicts from 10 consecutive EMA updates are saved (marked as Tn-1, Tn-2, …, Tn-10), and I want to check the distance between the outputs of Tn and Tn-10.
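
For reference, the EMA update itself is the standard exponential moving average with decay 0.999 (a minimal sketch; teacher_state and student_state are placeholder names for the two state_dicts):

# One EMA step: the teacher slowly tracks the student
alpha = 0.999
for k, v in student_state.items():
    teacher_state[k] = alpha * teacher_state[k] + (1 - alpha) * v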

When I use traditional DataParallel, the value is usually 0.3~0.7.
But when I use AMP+DDP on a single node with multiple GPUs, the value is 0.01~0.1.
I also tried AMP+DataParallel, and AMP+DDP on a single node with a single GPU; in both cases the value goes back to 0.3~0.7.
So it looks like something behaves differently when DDP is applied across multiple GPUs.
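
For context, the multi-GPU DDP runs use the standard launch pattern (a simplified sketch of my assumed setup; the exact launcher flags aren’t shown in the snippet below):

import os
import torch
import torch.distributed as dist

# launched with: torchrun --nproc_per_node=<num_gpus> train.py
local_rank = int(os.environ["LOCAL_RANK"])  # set by the launcher
torch.cuda.set_device(local_rank)
dist.init_process_group(backend="nccl")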

My code looks like this:

import torch
import torch.nn as nn
from torch.cuda import amp
from torch.nn.parallel import DistributedDataParallel as DDP

# Model initialization
scaler = amp.GradScaler()
student_model = Model().cuda()
teacher_model = Model().cuda()
# clone() so the teacher's tensors don't alias the student's parameters
teacher_dict = {k: v.clone() for k, v in student_model.state_dict().items()}
student_model = nn.SyncBatchNorm.convert_sync_batchnorm(student_model)
# student_model = nn.DataParallel(student_model, device_ids=gpu_ids)
student_model = DDP(student_model, device_ids=[local_rank])
teacher_model = nn.SyncBatchNorm.convert_sync_batchnorm(teacher_model)
# teacher_model = nn.DataParallel(teacher_model, device_ids=gpu_ids)
teacher_model = DDP(teacher_model, device_ids=[local_rank])

loss_fn = nn.SmoothL1Loss().cuda()

opt = torch.optim.Adam(
    [p for p in student_model.parameters() if p.requires_grad],
    lr=lr, betas=(0.5, 0.999))

# Training
teacher_dict_list = []
teacher_dict_prev = None  # Tn-10; stays None until 10 updates are stored
for epoch in range(num_epochs):
    for data in trainloader:
        images, labels = data
        images = images.cuda()
        labels = labels.cuda()

        with amp.autocast():
            preds = student_model(images)
            loss = loss_fn(preds, labels)

        scaler.scale(loss).backward()
        scaler.step(opt)
        scaler.update()  # must run every iteration, so it belongs right after step()
        opt.zero_grad()

        # store Tn-1 on CPU memory, or the CUDA memory would overflow
        t = {k: v.cpu() for k, v in teacher_dict.items()}
        teacher_dict_list.append(t)
        # EMA-update teacher_dict from the student's weights
        for k, v in student_model.module.state_dict().items():
            if 'num_batches_tracked' not in k:
                teacher_dict[k] = 0.999 * teacher_dict[k] + 0.001 * v
        # once 10 state_dicts are stored, the oldest one is Tn-10
        if len(teacher_dict_list) == 10:
            teacher_dict_prev = teacher_dict_list.pop(0)

        if teacher_dict_prev is not None:
            teacher_model.module.load_state_dict(teacher_dict)  # load the newest state_dict (Tn)
            with torch.no_grad():
                with amp.autocast():
                    output_cur = teacher_model(images)

            teacher_model.module.load_state_dict(teacher_dict_prev)  # load the old state_dict (Tn-10)
            with torch.no_grad():
                with amp.autocast():
                    output_prev = teacher_model(images)
            print(torch.abs(output_cur - output_prev).max())

The code above is a simplified version; the structure is the same as my real code. For some reason, I can’t post the full code for you to reproduce the problem.
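
One caveat about reading the printed values: under DDP each process runs this loop on its own shard of the data, so each rank prints its own number. A small sketch of how the prints can be tagged per rank (assuming the process group is initialized as usual):

import torch.distributed as dist

rank = dist.get_rank() if dist.is_initialized() else 0
diff = torch.abs(output_cur - output_prev).max()
print(f"rank {rank}: max abs output diff = {diff.item():.4f}")
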
So, the printed value is 0.3~0.7 (which I think is the correct behavior) in these settings:

  • DataParallel without AMP, single node, multiple GPUs
  • DataParallel with AMP, single node, multiple GPUs
  • DDP with AMP, single node, single GPU

but it becomes much smaller (0.01~0.1) with DDP + AMP on a single node with multiple GPUs.
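
For completeness, the same Tn vs Tn-10 comparison can also be done directly in parameter space rather than on the outputs (a minimal sketch; in the code above teacher_dict lives on GPU and teacher_dict_prev on CPU, so everything is moved to CPU first):

# max absolute weight difference between Tn and Tn-10
max_diff = 0.0
for k, v in teacher_dict.items():
    if 'num_batches_tracked' in k:
        continue
    d = (v.float().cpu() - teacher_dict_prev[k].float()).abs().max().item()
    max_diff = max(max_diff, d)
print(f"max parameter diff between Tn and Tn-10: {max_diff:.6f}")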

Hey, did you solve your problem? If yes, could you explain the solution?