Distributed training shuts down at the start of the second epoch

Hi, guys! I'm trying to use DistributedDataParallel to train a ResNet model with the VGGFace2 dataset on 8 GPUs on a single node. Everything is fine during the first epoch. However, the script shuts down without any error message when the second epoch starts. I traced the code and found that it stops at `for batch_idx, (data, label) in enumerate(train_loader):` in the second epoch.
I have also tried a smaller version of VGGFace2 with only 34 classes, and there the error does not occur.
Here are the scripts and the output:

  1. Training code
import argparse,os,time
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.distributed as dist
import torch.utils.data
import torch.utils.data.distributed
import torchvision
from torchvision import datasets, transforms
import numpy as np
import models
from util import *

# Command-line options for the distributed training run.
parser = argparse.ArgumentParser()
parser.add_argument('--start_epoch', type=int, default=1, help='start epoch number')
parser.add_argument('--epoch', type=int, default=25, help='number of epochs to train for')
parser.add_argument('--lr', type=float, default=0.1, help='learning rate, default=0.1')
parser.add_argument('--momentum', type=float, default=0.9, help='momentum, default=0.9')
parser.add_argument('--weight_decay', type=float, default=0.0002, help='weight_decay, default=0.0002')
parser.add_argument('--batch_s', type=int, default=64, help='input batch size')
parser.add_argument('--grid_s', type=int, default=8, help='grid size')
parser.add_argument('--data', type=str, default='../vgg2data', help='data directory')
parser.add_argument('--workers', type=int, default=4, help='number of data loading workers')
parser.add_argument('--output_dir', type=str, default='./output/', help='model_saving directory')
parser.add_argument('--resume', type=str, default='', help='resume')
parser.add_argument("--display_interval", type=int, default=50)
# torch.distributed.launch passes --local_rank explicitly; torchrun (and
# launch with --use_env) only export the LOCAL_RANK environment variable.
# Fall back to it (or 0 for a plain single-process run) instead of None.
parser.add_argument("--local_rank", type=int,
                    default=int(os.environ.get('LOCAL_RANK', 0)))
opt = parser.parse_args()

# Bind this process to its own GPU before any CUDA work happens. Without this,
# every rank's bare .cuda() calls (such as data.cuda() in train()) allocate on
# cuda:0 rather than on the GPU given by --local_rank.
torch.cuda.set_device(opt.local_rank)
dist.init_process_group(backend='nccl', init_method='env://')

train_dir = os.path.join(opt.data, 'train')
train_dataset = datasets.ImageFolder(

train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
train_loader = torch.utils.data.DataLoader(

# --- Model construction -----------------------------------------------------
# NOTE(review): input_size is passed to the project's models.se_resnet34_v3
# factory; the (batch, channels, height, width) layout is assumed from the
# tuple order -- confirm against models.py.
input_size = (opt.batch_s, 3, 128, 128)
# NOTE(review): hard-coded class count; presumably it must equal the number of
# classes ImageFolder finds under opt.data/train. The commented-out 34 is for
# the small 34-class subset.
num_classes = 9092
#num_classes = 34
model = models.se_resnet34_v3(input_size, opt.grid_s, num_classes)
# --- Optional resume: load the checkpoint file (first load) -----------------
if opt.resume:
    if os.path.isfile(opt.resume):
        print("=> loading checkpoint '{}'".format(opt.resume))
        # NOTE(review): torch.load without map_location restores tensors onto
        # the device they were saved from, so all ranks may materialize the
        # checkpoint on the saving GPU. The loaded weights are not applied to
        # `model` in the code shown here -- lines may be missing.
        checkpoint = torch.load(opt.resume)
        print("=> loaded checkpoint '{}' (epoch {})"
            .format(opt.resume, checkpoint['epoch']))
# Wrap for multi-process data parallelism; this process uses the single GPU
# given by --local_rank for both computation and output.
# NOTE(review): no .cuda() call on the model is visible before wrapping; DDP
# expects the module to already live on device_ids[0] -- confirm it is moved.
model = torch.nn.parallel.DistributedDataParallel(model,\
    device_ids=[opt.local_rank], output_device=opt.local_rank)

# SGD with three parameter groups built by util.get_parameters. Presumably
# (from the bias=/bn= flags -- confirm in util): non-bias weights use the base
# lr and weight decay; biases use 2x lr and no weight decay; BN parameters use
# ~1x lr and no weight decay.
# NOTE(review): the 1.00001001358 multiplier is unexplained; it is effectively
# opt.lr.
optimizer = optim.SGD([
        {'params': get_parameters(model, bias=False)},
        {'params': get_parameters(model, bias=True), 'lr':opt.lr * 2, 'weight_decay': 0},
        {'params': get_parameters(model, bn=True), 'lr':opt.lr * 1.00001001358, 'weight_decay':0}
    ], lr=opt.lr, momentum=opt.momentum, weight_decay=opt.weight_decay)
# NOTE(review): reloads the same checkpoint a second time, and nothing from it
# is used below -- presumably optimizer-state restoration was elided here.
if opt.resume:
    if os.path.isfile(opt.resume):
        checkpoint = torch.load(opt.resume)

# Halve every param group's lr (gamma=0.5) at each milestone. Milestones are
# counted in calls to scheduler.step().
scheduler = optim.lr_scheduler.MultiStepLR(optimizer, \
    milestones=[8,10,12,14,15,16,17,18,19,20,21,22,23,24], gamma=0.5)

def train(epoch):
    """Run one training epoch over ``train_loader``; log progress on rank 0.

    For every batch, the function:

    * moves the batch to the current CUDA device;
    * runs the DDP-wrapped ``model``, which returns ``(output, grid)``;
    * computes ``F.nll_loss(output, label)`` plus
      ``deformation_constraint_loss(grid, opt.grid_s)``;
    * back-propagates the sum and takes one ``optimizer`` step.

    Every ``opt.display_interval`` batches, local rank 0 prints batch and
    data-loading times, the current and running-average losses, and an
    estimate of the remaining wall-clock time as ``[hours:minutes]``.

    Args:
        epoch: epoch number from the outer loop, counting up to
            ``opt.epoch`` inclusive. It is used for display and for the
            remaining-time estimate.
    """
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()
    nll_losses = AverageMeter()
    de_losses = AverageMeter()
    end = time.time()
    for batch_idx, (data, label) in enumerate(train_loader):
        # Time spent waiting for the DataLoader to produce this batch.
        data_time.update(time.time() - end)

        data, label = data.cuda(), label.cuda()
        output, grid = model(data)
        # `output` is fed to nll_loss, so the model is expected to emit
        # log-probabilities.
        nll_loss = F.nll_loss(output, label)
        de_loss = deformation_constraint_loss(grid, opt.grid_s)
        loss = nll_loss + de_loss

        # Back-propagate and update the weights. DDP all-reduces gradients
        # across ranks during backward(), so every replica takes the same step.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_size = data.size(0)
        losses.update(loss.item(), batch_size)
        nll_losses.update(nll_loss.item(), batch_size)
        de_losses.update(de_loss.item(), batch_size)

        batch_time.update(time.time() - end)
        end = time.time()

        if opt.local_rank == 0 and batch_idx % opt.display_interval == 0:
            # Batches left in this epoch and all later ones, multiplied by
            # the average batch time, in whole seconds.
            remaining_batches = (opt.epoch - epoch + 1) * len(train_loader) - batch_idx
            total_time = int(remaining_batches * batch_time.avg)
            # Floor division keeps hours/minutes integral on Python 3 as well
            # as Python 2 (plain `/` yields floats on Python 3).
            hours = total_time // 3600
            minutes = total_time % 3600 // 60
            print('Epoch: [{0}][{1}/{2}]\t'
                    'Time {batch_time.val:.3f}\t'
                    'Data {data_time.val:.3f}\t'
                    'RemainTime [{3}:{4}]\t'
                    'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                    'NLL_Loss {nll_loss.val:.3f} ({nll_loss.avg:.3f})\t'
                    'DE_Loss {de_loss.val:.3f} ({de_loss.avg:.3f})'.format(
                    epoch, batch_idx, len(train_loader), hours, minutes, batch_time=batch_time,
                    data_time=data_time, loss=losses, nll_loss=nll_losses, de_loss=de_losses))

# Main loop: train each epoch, then let rank 0 write the checkpoints.
for epoch in range(opt.start_epoch, opt.epoch + 1):
    # DistributedSampler seeds its shuffle from the epoch number. Without
    # set_epoch() every epoch iterates the data in the same order.
    train_sampler.set_epoch(epoch)
    train(epoch)
    # NOTE(review): MultiStepLR is stepped once per epoch, after the optimizer
    # steps (PyTorch >= 1.1 ordering). On resume the scheduler restarts from
    # zero unless its state is restored -- confirm against the resume code.
    scheduler.step()

    if opt.local_rank == 0:
        # Only rank 0 writes files. DDP keeps all replicas' weights in sync,
        # so rank 0's copy is representative.
        checkpoint_name = 'checkpoint' + '_' + str(epoch + 1) + '.pth.tar'
        print("=> saving checkpoint '{}'".format(checkpoint_name))
        torch.save({
            'epoch': epoch + 1,
            'arch': 'resnet34',
            # Save the unwrapped module's weights (no 'module.' key prefix)
            # and the optimizer state so the checkpoint can actually resume.
            'state_dict': model.module.state_dict(),
            'optimizer': optimizer.state_dict(),
        }, checkpoint_name)
        # os.path.join works whether or not output_dir ends with a separator.
        torch.save(model.module.state_dict(),
                   os.path.join(opt.output_dir, 'grid_face.pt'))
        print('Done.')
  2. Command line
python -m torch.distributed.launch --nproc_per_node=8 train.py
  3. Log
Epoch: [1][0/6215]	Time 12.832	Data 9.403	RemainTime [553:49]	Loss 9.2308 (9.2308)	NLL_Loss 9.157 (9.157)	DE_Loss 0.074 (0.074)
Epoch: [1][50/6215]	Time 0.271	Data 0.000	RemainTime [22:45]	Loss 13.9143 (19.6498)	NLL_Loss 9.073 (9.098)	DE_Loss 4.841 (10.552)
Epoch: [1][100/6215]	Time 0.265	Data 0.000	RemainTime [17:22]	Loss 13.1845 (16.5875)	NLL_Loss 8.963 (9.031)	DE_Loss 4.221 (7.556)
Epoch: [1][150/6215]	Time 0.256	Data 0.000	RemainTime [15:31]	Loss 12.5163 (15.3281)	NLL_Loss 8.679 (8.945)	DE_Loss 3.837 (6.383)
Epoch: [1][6150/6215]	Time 0.283	Data 0.000	RemainTime [11:36]	Loss 5.2405 (7.9289)	NLL_Loss 3.552 (5.624)	DE_Loss 1.688 (2.305)
Epoch: [1][6200/6215]	Time 0.285	Data 0.000	RemainTime [11:36]	Loss 5.6334 (7.9077)	NLL_Loss 3.953 (5.607)	DE_Loss 1.680 (2.300)
=> saving checkpoint 'checkpoint_2.pth.tar'

Does anyone have an idea what might be causing this?