Distributed training error when starting second epoch

Hi, guys. I am trying to train a ResNet model on a single node with 8 GPUs using DistributedDataParallel. Everything is fine during the first epoch. However, the script shuts down without any error report when the second epoch starts. I traced the code and found that it stops at:

for batch_idx, (data, label) in enumerate(train_loader, 0): 

I found a similar topic here which implies that this problem may be caused by the DistributedSampler.
Meanwhile, when I create a smaller version of the dataset with only 34 classes, the error is gone.
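In case it helps, here is a minimal sketch of what I think the data-loading part boils down to (a toy standalone script, not my real code; the random TensorDataset and all the sizes are just placeholders), which is what I would run to check whether the hang reproduces without the model:

import torch
import torch.distributed as dist
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

# Same launch setup as below: one process per GPU via torch.distributed.launch.
dist.init_process_group(backend='nccl', init_method='env://')

# Toy dataset standing in for my ImageFolder data.
dataset = TensorDataset(torch.randn(1024, 3, 128, 128),
                        torch.zeros(1024, dtype=torch.long))
sampler = DistributedSampler(dataset)
loader = DataLoader(dataset, batch_size=32, sampler=sampler,
                    num_workers=4, pin_memory=True, shuffle=False)

for epoch in range(2):
    sampler.set_epoch(epoch)
    for batch_idx, (data, label) in enumerate(loader):
        pass  # in my real script, the hang happens here when epoch 2 starts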
Here is the actual code:

import os

import torch
import torch.distributed as dist
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms

# models, get_parameters and deformation_constraint_loss come from my own project code (omitted here).
import models

torch.cuda.set_device(opt.local_rank)

dist.init_process_group(backend='nccl', init_method='env://', world_size=8)

train_dir = os.path.join(opt.data, 'train')
train_dataset = datasets.ImageFolder(
    train_dir,
    transforms.Compose([transforms.ToTensor()])
)
train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=opt.batch_s,
    num_workers=opt.workers,
    pin_memory=True,
    shuffle=False,  # shuffling is handled by the DistributedSampler
    sampler=train_sampler,
)

input_size = (opt.batch_s, 3, 128, 128)
num_classes = 9092
model = models.se_resnet34_v3(input_size, opt.grid_s, num_classes)

model.cuda()
model = torch.nn.parallel.DistributedDataParallel(
    model, device_ids=[opt.local_rank], output_device=opt.local_rank)

optimizer = optim.SGD([
        {'params': get_parameters(model, bias=False)},
        {'params': get_parameters(model, bias=True), 'lr':opt.lr * 2, 'weight_decay': 0},
        {'params': get_parameters(model, bn=True), 'lr':opt.lr * 1.00001001358, 'weight_decay':0}
    ], lr=opt.lr, momentum=opt.momentum, weight_decay=opt.weight_decay)
if opt.resume:
    if os.path.isfile(opt.resume):
        checkpoint = torch.load(opt.resume)
        optimizer.load_state_dict(checkpoint['optimizer'])

scheduler = optim.lr_scheduler.MultiStepLR(
    optimizer, milestones=[8, 15, 24], gamma=0.5)

def train(epoch):
    for batch_idx, (data, label) in enumerate(train_loader, 0):
        optimizer.zero_grad()
        data, label = data.cuda(), label.cuda()

        output = model(data)
        nll_loss = F.nll_loss(output, label)
        de_loss = deformation_constraint_loss(grid, opt.grid_s)  # grid is computed elsewhere in my code (omitted here)
        loss =  nll_loss + de_loss

        loss.backward()
        optimizer.step()


for epoch in range(opt.start_epoch, opt.epoch + 1):
    train_sampler.set_epoch(epoch)
    scheduler.step() 
    model.train()
    train(epoch)

Command Line:

python -m torch.distributed.launch --nproc_per_node=8 train.py
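For completeness, opt comes from argparse, and torch.distributed.launch passes a --local_rank argument to every process it spawns, so the parsing looks roughly like this (the default values below are placeholders, not my real settings):

import argparse

parser = argparse.ArgumentParser()
# --local_rank is supplied automatically by torch.distributed.launch
parser.add_argument('--local_rank', type=int, default=0)
# the remaining opt.* fields used above (data, batch_s, workers, lr, momentum,
# weight_decay, grid_s, resume, start_epoch, epoch) are defined the same way
parser.add_argument('--data', type=str, default='./data')
parser.add_argument('--batch_s', type=int, default=64)
parser.add_argument('--workers', type=int, default=8)
opt = parser.parse_args()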

Any help will be appreciated.