Gradient accumulation leads to dramatic performance degradation

Hi everyone, I’ve implemented gradient accumulation in PyTorch and trained on the CIFAR-100 dataset. Here is my code snippet:

    net.train()
    loss = 0
    acc_batches = args.acc_b
    for batch_index, (images, labels) in enumerate(cifar100_training_loader):
        if epoch <= args.warm:
            warmup_scheduler.step()

        if args.gpu:
            labels = labels.cuda()
            images = images.cuda()

        outputs = net(images)
        # divide the loss by acc_b so the accumulated gradients average over acc_b mini-batches
        scaled_loss = loss_function(outputs, labels) / args.acc_b
        scaled_loss.backward()
        loss += scaled_loss

        print(1)  # debug print: one line per processed mini-batch
        # keep accumulating gradients until acc_b mini-batches have been processed
        if acc_batches > 1:
            acc_batches -= 1
            continue

        # every acc_b mini-batches: update the weights, then clear the accumulated gradients
        optimizer.step()
        optimizer.zero_grad()

        n_iter = (epoch - 1) * len(cifar100_training_loader) + batch_index + 1

        last_layer = list(net.children())[-1]
        for name, para in last_layer.named_parameters():
            if 'weight' in name:
                writer.add_scalar('LastLayerGradients/grad_norm2_weights', para.grad.norm(), n_iter)
            if 'bias' in name:
                writer.add_scalar('LastLayerGradients/grad_norm2_bias', para.grad.norm(), n_iter)

        print('Training Epoch: {epoch} [{trained_samples}/{total_samples}]\tLoss: {:0.4f}\tLR: {:0.6f}'.format(
            loss.item(),
            optimizer.param_groups[0]['lr'],
            epoch=epoch,
            trained_samples=batch_index * args.b + len(images),
            total_samples=len(cifar100_training_loader.dataset)
        ))

        # log the training loss at each optimizer step
        writer.add_scalar('Train/loss', loss.item(), n_iter)

        # reset the counter and the logged loss for the next group of acc_b mini-batches
        if acc_batches == 1:
            acc_batches = args.acc_b
            loss = 0

    for name, param in net.named_parameters():
        layer, attr = os.path.splitext(name)
        attr = attr[1:]
        writer.add_histogram("{}/{}".format(layer, attr), param, epoch)
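
For reference, my understanding of the canonical accumulation pattern is the minimal sketch below (a self-contained toy: `model`, `loader`, `criterion` and `ACC_STEPS` are placeholder names, not from my repo). As far as I can tell, my snippet above does the same thing, just with a countdown counter instead of a modulo check:

    import torch
    import torch.nn as nn

    # toy stand-ins so the sketch runs on its own; not my real net / data loader
    model = nn.Linear(32, 100)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    loader = [(torch.randn(8, 32), torch.randint(0, 100, (8,))) for _ in range(16)]
    ACC_STEPS = 4

    optimizer.zero_grad()
    for step, (inputs, targets) in enumerate(loader):
        outputs = model(inputs)
        # scale so the gradients summed over ACC_STEPS mini-batches match the large-batch average
        loss = criterion(outputs, targets) / ACC_STEPS
        loss.backward()  # gradients accumulate into param.grad across iterations

        if (step + 1) % ACC_STEPS == 0:
            optimizer.step()       # update once every ACC_STEPS mini-batches
            optimizer.zero_grad()  # clear the accumulated gradients before the next group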

Then I run the following commands to train my ResNet-50:

python train.py -net resnet50 -gpu -b 256
python train.py -net resnet50 -gpu -b 256 -acc_b 2
python train.py -net resnet50 -gpu -b 256 -acc_b 4
python train.py -net resnet50 -gpu -b 256 -acc_b 8
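
If my scaling by acc_b is right, each optimizer step in these runs should effectively average gradients over b * acc_b samples (a quick sanity check, assuming -b is the per-forward batch size):

    # effective samples per optimizer step for the runs above (b = 256)
    b = 256
    for acc_b in (1, 2, 4, 8):
        print('acc_b = {}: effective batch size = {}'.format(acc_b, b * acc_b))
    # 256, 512, 1024 and 2048 respectively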

I’ve plotted the test accuracy with TensorBoard, and here are the results:

orange line: args.acc_b = 1
dark blue line: args.acc_b = 2 (loss becomes NaN at epoch 37)

Training Epoch: 37 [3072/50000] Loss: 1.2027    LR: 0.100000
1
1
Training Epoch: 37 [3584/50000] Loss: 5.3613    LR: 0.100000
1
1
Training Epoch: 37 [4096/50000] Loss: 5.7180    LR: 0.100000
1
1
Training Epoch: 37 [4608/50000] Loss: nan       LR: 0.100000
1
1
Training Epoch: 37 [5120/50000] Loss: nan       LR: 0.100000

red line: args.acc_b = 4 (loss suddenly jumps from 0.7734 to 6.5302 at epoch 38)

Training Epoch: 38 [43008/50000]        Loss: 0.7734    LR: 0.100000
1
1
1
1
Training Epoch: 38 [44032/50000]        Loss: 6.5302    LR: 0.100000
1
1
1
1
Training Epoch: 38 [45056/50000]        Loss: 5.4722    LR: 0.100000
1
1
1
1

light blue line: args.acc_b = 8 (loss suddenly jumps from 0.0202 to 6.8125 at epoch 105):

Training Epoch: 105 [6144/50000]        Loss: 0.0202    LR: 0.020000
1
1
1
1
1
1
1
1
Training Epoch: 105 [8192/50000]        Loss: 6.8125    LR: 0.020000

The whole project is at https://github.com/weiaicunzai/pytorch-cifar100/tree/feat/accumulate_grad_batches

I just do not understand why my loss keeps exploding when I set args.acc_b > 1.
Thanks in advance.