Hi, sync batch normalization has been bothering me for a long time. I posted a question about it here before but did not get much of a response.
The situation is that when I train with sync batch normalization, the training process hangs (it just stops making progress). The author of YOLOv5 can reproduce this issue.
The same thing happens when I use PyTorch 1.6.0.
I am not sure how many people have run into this, and whether the PyTorch dev team is aware of it.
Below is the code I use to set up DDP training:
import os
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

args = build_args()  # parses --local_rank among other options
local_rank = args.local_rank
args.world_size = int(os.environ["WORLD_SIZE"]) if "WORLD_SIZE" in os.environ else 1

# only applies in DDP mode (one process per GPU, launched by the distributed launcher)
if local_rank != -1:
    args.total_batch_size = args.batch_size
    torch.cuda.set_device(local_rank)
    device = torch.device("cuda", local_rank)
    dist.init_process_group(backend="nccl", init_method="env://")
    assert args.batch_size % args.world_size == 0, \
        "batch size should be a multiple of the number of CUDA devices"
    args.batch_size = args.total_batch_size // args.world_size

model = MODEL()
if local_rank != -1:
    # sync BN: convert all BatchNorm layers to SyncBatchNorm, then wrap with DDP
    model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model).to(device)
    model = DDP(model, device_ids=[local_rank], output_device=local_rank)
else:
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model.to(device)
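In case it matters, a launch command consistent with the code above (setting WORLD_SIZE and passing --local_rank) would look roughly like this; the script name train.py and the GPU count are just placeholders:

python -m torch.distributed.launch --nproc_per_node=2 train.py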
And the forward / backward step looks like this:
# note: optimizer.zero_grad() is not shown here
with torch.set_grad_enabled(True):
    with torch.cuda.amp.autocast():
        outputs = model(inputs)[0]
        _, preds = torch.max(outputs, 1)
        loss = criterion(outputs, labels)
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
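For anyone who wants to try to reproduce this, below is a stripped-down sketch of the same pattern (toy model, random data; every name here is a placeholder, not my actual training code), to be launched with the same command as above:

import argparse
import os
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--local_rank", type=int, default=0)
    local_rank = parser.parse_args().local_rank

    torch.cuda.set_device(local_rank)
    dist.init_process_group(backend="nccl", init_method="env://")
    device = torch.device("cuda", local_rank)

    # toy conv net with BatchNorm so convert_sync_batchnorm has something to convert
    model = nn.Sequential(
        nn.Conv2d(3, 8, 3, padding=1), nn.BatchNorm2d(8), nn.ReLU(),
        nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(8, 10),
    )
    model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model).to(device)
    model = DDP(model, device_ids=[local_rank], output_device=local_rank)

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    scaler = torch.cuda.amp.GradScaler()

    for step in range(1000):
        inputs = torch.randn(8, 3, 32, 32, device=device)
        labels = torch.randint(0, 10, (8,), device=device)
        optimizer.zero_grad()
        with torch.cuda.amp.autocast():
            outputs = model(inputs)
            loss = criterion(outputs, labels)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        if local_rank == 0 and step % 100 == 0:
            print(f"step {step} loss {loss.item():.4f}")

if __name__ == "__main__":
    main()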
Hope someone can help me out. Thanks in advance!