PyTorch backward is too slow

Hi, I’m profiling the PyTorch training process (a single V100, training ResNet18 on ImageNet), and I find that the backward pass is really slow (as shown in the figure: about 12 ms for the forward pass but about 300 ms for the backward pass; other batches look the same).

This is my code:

    # assumes `import time` and `import torch.cuda.nvtx as nvtx` earlier in the script
    end = time.time()  # start the timer used for the data/batch timing below
    nvtx.range_push("Epoch: " + str(epoch))
    nvtx.range_push("Batch 0")
    nvtx.range_push("Load Data")
    for i, (input_data, target) in enumerate(train_loader):
        # stop after 300 batches when profiling
        if profile and i == 300:
            exit(0)
        # measure data loading time
        data_time.update(time.time() - end)

        target = target.cuda(non_blocking=True)
        input_data = input_data.cuda(non_blocking=True)
        nvtx.range_pop()

        # compute output
        nvtx.range_push("Forward")
        output = model(input_data)
        nvtx.range_pop()

        # compute gradient and do SGD step
        nvtx.range_push("Backward")
        optimizer.zero_grad()
        loss = criterion(output, target)
        # measure accuracy and record loss
        prec1, prec5 = accuracy(output, target, topk=(1, 5))
        losses.update(loss.item(), input_data.size(0))
        top1.update(prec1[0], input_data.size(0))
        top5.update(prec5[0], input_data.size(0))
        loss.backward()
        nvtx.range_pop()

        nvtx.range_push("SGD")
        optimizer.step()
        nvtx.range_pop()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        nvtx.range_pop()
        if i % print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                  'Prec@5 {top5.val:.3f} ({top5.avg:.3f})\t'.format(
                epoch, i, len(train_loader), batch_time=batch_time,
                data_time=data_time, loss=losses, top1=top1, top5=top5))
        nvtx.range_push("Batch " + str(i+1))
        nvtx.range_push("Load Data")
    nvtx.range_pop()
    nvtx.range_pop()
    nvtx.range_pop()

Other batches show the same slow backward pass.

You could try setting torch.backends.cudnn.benchmark = True in case you are not already using it.
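
For example (a minimal sketch; with fixed input shapes the flag lets cuDNN benchmark the available convolution algorithms once and then reuse the fastest one):

    import torch

    # enable cuDNN autotuning; helps when the input size does not change between iterations
    torch.backends.cudnn.benchmark = True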
Also, to avoid synchronizations, you could remove the synchronizing item() calls; since CUDA operations run asynchronously, the cost often shows up at the first synchronizing call (here inside your "Backward" range) rather than in the op that is actually slow.
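
A rough sketch of that idea, built on the loop above (model, criterion, optimizer, train_loader, and print_freq are from your snippet; running_loss is a new helper tensor): accumulate the loss on the GPU and call .item() only when printing, so the NVTX ranges no longer contain a forced CPU-GPU sync.

    running_loss = torch.zeros(1, device="cuda")

    for i, (input_data, target) in enumerate(train_loader):
        input_data = input_data.cuda(non_blocking=True)
        target = target.cuda(non_blocking=True)

        output = model(input_data)
        loss = criterion(output, target)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.detach()  # stays on the GPU, no synchronization here

        if i % print_freq == 0:
            # the only synchronization point: one .item() per print
            print('Iter {}: avg loss {:.4f}'.format(i, (running_loss / (i + 1)).item()))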