Hi, guys. I am trying to train a ResNet model on a single node with 8 GPUs using DistributedDataParallel. Everything is fine during the first epoch. However, the script shuts down without any error report when the second epoch starts. I have traced the code and found that it stops at:
for batch_idx, (data, label) in enumerate(train_loader, 0):
I found a similar topic here which implies that this problem may be caused by the DistributedSampler.
Meanwhile, I created a smaller version of the dataset with only 34 classes, and the error is gone.
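For what it's worth, here is a quick standalone sketch (not from the training script; the dataset size is made up) of how DistributedSampler shards data. Every rank gets the same padded number of samples, so no rank should run out of batches before the others:

from torch.utils.data.distributed import DistributedSampler

# Stand-in for the ImageFolder dataset: only __len__ matters here.
dataset = list(range(100003))  # deliberately not divisible by 8

for rank in range(8):
    sampler = DistributedSampler(dataset, num_replicas=8, rank=rank)
    sampler.set_epoch(1)  # reseeds the shuffle; same value on every rank
    indices = list(iter(sampler))
    # Every rank reports the same (padded) sample count, so a length
    # mismatch between ranks should not be what stalls the loader.
    print(rank, len(indices), indices[:3])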
Here is the code:
import os

import torch
import torch.distributed as dist
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms

import models  # project-local module providing se_resnet34_v3

# `opt` holds the parsed command-line arguments (see the sketch below);
# torch.distributed.launch passes --local_rank to each process.
torch.cuda.set_device(opt.local_rank)
dist.init_process_group(backend='nccl', init_method='env://', world_size=8)

train_dir = os.path.join(opt.data, 'train')
train_dataset = datasets.ImageFolder(
    train_dir,
    transforms.Compose([transforms.ToTensor()])
)

# The sampler shards the dataset across the 8 ranks, so the DataLoader
# itself must not shuffle.
train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=opt.batch_s,
    num_workers=opt.workers,
    pin_memory=True,
    shuffle=False,
    sampler=train_sampler
)

input_size = (opt.batch_s, 3, 128, 128)
num_classes = 9092
model = models.se_resnet34_v3(input_size, opt.grid_s, num_classes)
model.cuda()
model = torch.nn.parallel.DistributedDataParallel(
    model, device_ids=[opt.local_rank], output_device=opt.local_rank)

# get_parameters is a project helper that filters weight/bias/BN parameters.
optimizer = optim.SGD([
    {'params': get_parameters(model, bias=False)},
    {'params': get_parameters(model, bias=True), 'lr': opt.lr * 2, 'weight_decay': 0},
    {'params': get_parameters(model, bn=True), 'lr': opt.lr * 1.00001001358, 'weight_decay': 0}
], lr=opt.lr, momentum=opt.momentum, weight_decay=opt.weight_decay)

if opt.resume:
    if os.path.isfile(opt.resume):
        checkpoint = torch.load(opt.resume)
        optimizer.load_state_dict(checkpoint['optimizer'])

scheduler = optim.lr_scheduler.MultiStepLR(optimizer,
                                           milestones=[8, 15, 24], gamma=0.5)

def train(epoch):
    for batch_idx, (data, label) in enumerate(train_loader, 0):
        optimizer.zero_grad()
        data, label = data.cuda(), label.cuda()
        # Assuming the model returns (log-probabilities, deformation grid);
        # `grid` was undefined otherwise.
        output, grid = model(data)
        nll_loss = F.nll_loss(output, label)
        de_loss = deformation_constraint_loss(grid, opt.grid_s)  # project helper
        loss = nll_loss + de_loss
        loss.backward()
        optimizer.step()

for epoch in range(opt.start_epoch, opt.epoch + 1):
    train_sampler.set_epoch(epoch)  # reseed the sampler's shuffle each epoch
    scheduler.step()
    model.train()
    train(epoch)
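For reference, `opt` above comes from argparse. A minimal sketch of the flags the snippet assumes (the names are taken from the code; the defaults are illustrative only):

import argparse

# Minimal argument parser matching the flags used in the script.
# torch.distributed.launch appends --local_rank for each process.
parser = argparse.ArgumentParser()
parser.add_argument('--local_rank', type=int, default=0)
parser.add_argument('--data', type=str, default='./dataset')
parser.add_argument('--batch_s', type=int, default=64)
parser.add_argument('--workers', type=int, default=4)
parser.add_argument('--grid_s', type=int, default=4)
parser.add_argument('--lr', type=float, default=0.1)
parser.add_argument('--momentum', type=float, default=0.9)
parser.add_argument('--weight_decay', type=float, default=1e-4)
parser.add_argument('--resume', type=str, default='')
parser.add_argument('--start_epoch', type=int, default=1)
parser.add_argument('--epoch', type=int, default=30)
opt = parser.parse_args()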
Command Line:
python -m torch.distributed.launch --nproc_per_node=8 train.py
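In case it helps to localize the hang, here is a variant of the epoch loop I can use for debugging, with per-rank prints and a barrier (the dist.barrier() call and the prints are debugging additions, not part of the original script):

# Debugging sketch: report per-rank progress at each epoch boundary.
for epoch in range(opt.start_epoch, opt.epoch + 1):
    train_sampler.set_epoch(epoch)
    scheduler.step()
    model.train()
    print('rank %d: entering epoch %d' % (opt.local_rank, epoch), flush=True)
    train(epoch)
    print('rank %d: finished epoch %d' % (opt.local_rank, epoch), flush=True)
    dist.barrier()  # if one rank hangs in the loader, the others stop here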
Any help will be appreciated.