Faster R-CNN extremely slow training

Starting from this tutorial, I am trying to train a Faster R-CNN ResNet50 network on a custom dataset.
The train partition contains 26188 images of size 512x512 that get resized to 240x240 when loaded. Training on the full train set for a single epoch took 14 hours, which works out to roughly two seconds per image.
I’m trying to debug where the bottleneck(s) are.

I’m pretty sure everything is running on the GPU, because when I run nvidia-smi the Volatile GPU-Util is consistently around 99%.

Benchmarking with time.time(), using batch size 10, 8 workers, and a dataset of 80 elements, I get the following for the first epoch:

LOAD TIME 0.6659998893737793
TRAIN TIME 7.46896505355835
LOSS TIME 5.605570554733276
LOAD TIME 4.123662710189819
TRAIN TIME 7.413051128387451
LOSS TIME 0.14319729804992676
LOAD TIME 9.588207483291626
TRAIN TIME 7.420511722564697
LOSS TIME 0.14394235610961914
LOAD TIME 9.581853866577148
TRAIN TIME 7.408280849456787
LOSS TIME 0.14383983612060547

So it takes about 7 seconds to train on 10 images. Correct me if I’m wrong, but that seems like far too much time.
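One caveat about these measurements: CUDA kernels run asynchronously, so unless the GPU is synchronized before each time.time() call, work can be attributed to the wrong phase. A minimal sketch of a synchronized timing for one phase (the torch.cuda.synchronize() calls are the only addition to my loop):

torch.cuda.synchronize()          # wait for pending GPU work before reading the clock
start = time.time()
loss_dict = model(images, targets)
torch.cuda.synchronize()          # make sure the forward pass has actually finished
print("TRAIN TIME", time.time() - start)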

The relevant bit of my code is here:

def main(folder_path, csv_file, attempt_fd, resume=False, fromEpoch=0, num_epochs=2, re_evaluate=False, evaluate=False):
    device = torch.device('cuda')
    num_classes = 9
    dataset = dld.DLDataset(csv_file, folder_path, transforms=get_transform(train=True))
    dataset_test = dld.DLDataset(csv_file, folder_path, transforms=get_transform(train=False))

    # split the dataset into train and test sets
    # (the 80/5-element subsets are for benchmarking; normally train_idx/test_idx are used)
    indices = torch.randperm(len(dataset)).tolist()
    train_idx = np.load('train.npy')
    test_idx = np.load('test.npy')
    dataset = torch.utils.data.Subset(dataset, indices[0:80])            # np.asarray(train_idx)
    dataset_test = torch.utils.data.Subset(dataset_test, indices[0:5])   # np.asarray(test_idx)

    # define training and validation data loaders
    data_loader = torch.utils.data.DataLoader(
        dataset, batch_size=10, shuffle=True, num_workers=8,
        collate_fn=utils.collate_fn)

    data_loader_test = torch.utils.data.DataLoader(
        dataset_test, batch_size=1, shuffle=False, num_workers=8,
        collate_fn=utils.collate_fn)

    # get the model using our helper function
    model = get_model_instance_segmentation(num_classes)
    model.double()
    model.to(device)

    # construct an optimizer
    params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.Adam(params, lr=0.001, weight_decay=0)

    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=2, gamma=0.1)

    mAP = []
    IoU = []

    torch.cuda.empty_cache()

    if resume:
        fromEpoch, model, optimizer, lr_scheduler = loadState(attempt_fd, fromEpoch, model, optimizer, lr_scheduler)
        model.to(device)
        if fromEpoch > 0:
            mAP = np.load(attempt_fd + 'mAP_' + str(fromEpoch) + '.npy')
            IoU = np.load(attempt_fd + 'IoU_' + str(fromEpoch) + '.npy')
        fromEpoch += 1

    if re_evaluate:
        t_mAP, t_iou = custom_evaluate(model, data_loader_test, device)
        print("EPOCH", fromEpoch - 1, "mAP", t_mAP, "iou", t_iou)
        mAP = np.append(mAP, t_mAP)
        IoU = np.append(IoU, t_iou)
        np.save(attempt_fd + 'mAP_' + str(fromEpoch - 1) + '.npy', mAP)
        np.save(attempt_fd + 'IoU_' + str(fromEpoch - 1) + '.npy', IoU)

    for epoch in range(fromEpoch, num_epochs):
        # train for one epoch, printing every iteration
        loss = train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq=1)
        print("LOSS AT EPOCH", epoch, "IS", loss)
        # update the learning rate
        lr_scheduler.step()
        # save progress
        saveState(attempt_fd, epoch, model, optimizer, lr_scheduler, loss)
        # empty cache
        torch.cuda.empty_cache()
        if evaluate:
            # evaluate with the custom routine
            t_mAP, t_iou = custom_evaluate(model, data_loader_test, device)
            print("EPOCH", epoch, "mAP", t_mAP, "iou", t_iou)
            mAP = np.append(mAP, t_mAP)
            IoU = np.append(IoU, t_iou)
            np.save(attempt_fd + 'mAP_' + str(epoch) + '.npy', mAP)
            np.save(attempt_fd + 'IoU_' + str(epoch) + '.npy', IoU)
            torch.cuda.empty_cache()

    print("That's it!")

I also modified the train_one_epoch function:

def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq):
    model.train()
    header = 'Epoch: [{}]'.format(epoch)

    lr_scheduler = None
    if epoch == 0:
        warmup_factor = 1. / 1000
        warmup_iters = min(1000, len(data_loader) - 1)
        lr_scheduler = utils.warmup_lr_scheduler(optimizer, warmup_iters, warmup_factor)

    one = time.time()
    for images, targets in data_loader:
        images = list(image.to(device) for image in images)
        targets = [{k: v.unsqueeze(0).to(device) for k, v in t.items()} for t in targets]
        loaded = time.time()
        print("LOAD TIME", loaded - one)

        loss_dict = model(images, targets)
        train_time = time.time()
        print("TRAIN TIME", train_time - loaded)

        losses = sum(loss for loss in loss_dict.values())

        # reduce losses over all GPUs for logging purposes
        loss_dict_reduced = utils.reduce_dict(loss_dict)
        losses_reduced = sum(loss for loss in loss_dict_reduced.values())
        loss_value = losses_reduced.item()

        if not math.isfinite(loss_value):
            print("Loss is {}, stopping training".format(loss_value))
            print(loss_dict_reduced)
            sys.exit(1)

        optimizer.zero_grad()
        losses.backward()
        optimizer.step()

        # the warmup scheduler step and the MetricLogger updates are disabled here

        loss_time = time.time()
        print("LOSS TIME", loss_time - train_time)
        one = loss_time

    return loss_value
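Note that LOAD TIME above covers both the DataLoader fetch and the host-to-device copies. To separate the two, a sketch that times only the loading path, without the training step:

one = time.time()
for images, targets in data_loader:
    fetched = time.time()   # time spent waiting on the DataLoader workers
    images = [image.to(device) for image in images]
    targets = [{k: v.unsqueeze(0).to(device) for k, v in t.items()} for t in targets]
    moved = time.time()     # time spent on the host-to-device copies
    print("FETCH TIME", fetched - one, "COPY TIME", moved - fetched)
    one = time.time()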

If anyone can give me tips, I would be very grateful. Thank you!

I’m making progress towards solving this issue. I’ve pinpointed the cause of the problem to the inputs I feed the network: if I train with random tensors generated with torch.rand(), it takes 2 seconds to train on a batch of 32 elements.
I have to check what I’m doing wrong in the Dataset. It is still very strange, since I followed the tutorial linked above.
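For reference, the random-input experiment looks roughly like this (a sketch, assuming a float32 model; the 240x240 size matches what the images are resized to, and the box/label values are placeholders):

images = [torch.rand(3, 240, 240, device=device) for _ in range(32)]
targets = [{'boxes': torch.tensor([[10., 10., 100., 100.]], device=device),
            'labels': torch.tensor([1], device=device)} for _ in range(32)]
loss_dict = model(images, targets)   # ~2 s for the whole batch of 32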

The problem is the DataLoader, though I have no idea why. Using a custom generator works fine.

I don’t know how utils.collate_fn is defined, but you might try to profile it or reduce the number of workers.

That being said, the model.double() call will additionally slow down your training, as FP64 is slower on the GPU than FP32. While you get more floating point precision, this is usually not necessary in the deep learning domain (but you might have your reasons to do so).
Also, torch.cuda.empty_cache() will not save any memory but will only slow down your code, as PyTorch would need to reallocate the device memory with synchronizing calls.
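A minimal sketch of the float32 setup, if the extra precision isn't needed (only the dtype handling changes; the rest of the script stays as above):

model = get_model_instance_segmentation(num_classes)
model.to(device)   # parameters stay in the default torch.float32; no model.double()

# inside the training loop, cast the inputs instead of the model
images = [image.float().to(device) for image in images]
# (the 'boxes' entries in the targets should be float32 as well)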

Modifying the number of workers didn’t change much; I tried 2, 4, and 8 workers.
The collate_fn function is this one:

def collate_fn(batch):
    return tuple(zip(*batch))
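It simply transposes the batch: a list of (image, target) pairs becomes a tuple of images and a tuple of targets. For example (with hypothetical samples):

batch = [(img1, tgt1), (img2, tgt2)]   # as produced by __getitem__
images, targets = collate_fn(batch)
# images == (img1, img2), targets == (tgt1, tgt2)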

Since my last comment, I’ve switched to the float32 dtype and removed the torch.cuda.empty_cache() call. I’ve also managed my memory a bit better with del statements, since I was also running out of CUDA memory.

But only replacing the DataLoader with a simple custom generator brought my performance from 3 seconds per image to 2.5 seconds per batch of 32 images.
I suspect it was doing something to the images themselves, but I don’t know what.
Maybe there was something wrong with my dataset implementation that was slowing down the DataLoader?
I have no clue, but I’m also new to the whole PyTorch thing.

The DataLoader will call into the Dataset.__getitem__ method with the current sample index.
Depending on how you’ve implemented the loading and processing logic in __getitem__, it might take some time to load and process the data.
How did you implement your generator that achieves the speedup?
Are you loading the data lazily, and are you processing it at all?

My generator is this:

def get_batch(dataset, device, batch_size=32):
    iterations = math.ceil(len(dataset) / batch_size)
    for i in range(iterations):
        images = []
        targets = []
        for j in range(i * batch_size, min(len(dataset), (i + 1) * batch_size)):
            sample = dataset[j]
            s = sample[0].type(torch.FloatTensor).to(device)
            images.append(s)
            t = {}
            t['boxes'] = sample[1]['boxes'].type(torch.FloatTensor).unsqueeze(0).to(device)
            t['labels'] = sample[1]['labels'].unsqueeze(0).to(device)
            targets.append(t)
        yield images, targets
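Used in place of the DataLoader, the training loop becomes something like this (a sketch; the timing prints are omitted, and the generator already moves everything to the device):

for images, targets in get_batch(dataset, device, batch_size=32):
    loss_dict = model(images, targets)
    losses = sum(loss for loss in loss_dict.values())
    optimizer.zero_grad()
    losses.backward()
    optimizer.step()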

I load the data lazily, I guess, because I only load the actual images when I’m building the batch. I was expecting low loading times with this approach, but it actually takes around 0.5 seconds to load a batch and 2.5 seconds to train on it, so for my task it’s more than fine.

I understood that the DataLoader loads all the data into memory at once, once for each worker, so I would expect slow loading times the first time but not in the following iterations. Still, I don’t understand the extremely slow loading times.

If you want, I can also post the __getitem__ method, but it’s nothing special: it just loads the data and images and converts them to the appropriate types.

Could you nevertheless post the __getitem__ method, so that I could try to reproduce it locally?
Since the slowdown is large, I would like to debug it and see if I can reproduce it.

Here is my whole dataset class.
In __getitem__ I have to convert the images to JPEG for compatibility reasons, but this only happens the first time a picture is seen.
The dataset I’m using is the DeepLesion one.

import os
import numpy as np
import torch
from PIL import Image
import pandas as pd
import cv2

class DLDataset(object):
    def __init__(self, csv_file, root_dir, transforms=None):
        self.root_dir = root_dir
        self.csv_file = csv_file
        self.transforms = transforms
        self.bbox_frame = pd.read_csv(root_dir + csv_file, header=0,
                                      usecols=['File_name', 'Bounding_boxes', 'Image_size', 'Coarse_lesion_type'])
        self.npData = self.bbox_frame.to_numpy()

    def __getitem__(self, idx):
        if torch.is_tensor(idx):
            idx = idx.tolist()
        img_folder, img_name, img_ext = self.formatInputPath(self.npData[idx][0])
        # convert to .jpg once, the first time the image is seen
        if not os.path.exists(self.root_dir + img_folder + img_name + ".jpg"):
            image = cv2.imread(self.root_dir + img_folder + img_name + "." + img_ext)
            cv2.imwrite(self.root_dir + img_folder + img_name + ".jpg", image, [cv2.IMWRITE_JPEG_QUALITY, 100])
        img_name = self.root_dir + img_folder + img_name + '.jpg'

        image = Image.open(img_name).convert("RGB")

        # bounding box: a flat [x1, y1, x2, y2] list (unsqueezed to [1, 4] in the train loop)
        coords = self.getBBoxes(self.npData[idx][1])
        boxes = [float(i) for i in coords]

        # label
        label = self.getLabel(self.npData[idx][2])
        if label == -1:
            label = 9

        # convert everything to tensors
        boxes = torch.as_tensor(boxes, dtype=torch.float64)
        label = torch.tensor(label)

        target = {}
        target['boxes'] = boxes
        target['labels'] = label

        if self.transforms is not None:
            image, target = self.transforms((image, target))

        return image, target

    def __len__(self):
        return len(self.bbox_frame)

    def formatInputPath(self, path):
        index = path.rfind("_")
        point = path.rfind(".")
        return path[:index] + "/", path[index + 1:point], path[point + 1:]

    def getBBoxes(self, bbox):
        bbox = bbox.replace(',', '')
        coord = bbox.split()
        return coord

    def getLabel(self, label):
        return int(label)

    def getOnlyLabels(self):
        all_labels = []
        for i in range(len(self.bbox_frame)):
            all_labels.append(self.npData[i][2])
        all_labels = np.array(all_labels)
        all_labels[all_labels < 0] = 9
        return all_labels
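To check whether __getitem__ itself is the bottleneck, independent of the DataLoader machinery, it can be timed in isolation (a sketch; the paths are placeholders):

import time

ds = DLDataset('DL_info.csv', '/path/to/DeepLesion/', transforms=None)   # placeholder paths
start = time.time()
for i in range(32):
    _ = ds[i]   # the first pass may also trigger the one-time JPEG conversion
print("32 samples in", time.time() - start, "seconds")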

This is my custom ToTensor class:

class ToTensor(object):

    def __call__(self, sample):
        image, bboxes, label = sample[0], sample[1]['boxes'], sample[1]['labels']

        # transforms.ToTensor converts the PIL image (H x W x C)
        # to a C x H x W tensor scaled to [0, 1]
        tr = transforms.ToTensor()
        image = tr(image)
        image = image.double()
        target = sample[1]

        return image, target

When you init the dataset, pass a torchvision.transforms.Compose with my custom ToTensor class in it.
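So the get_transform helper used in main() looks roughly like this (a sketch; the train flag is currently unused):

from torchvision import transforms

def get_transform(train):
    # wrap the custom ToTensor above; train-time augmentations could be added here
    return transforms.Compose([ToTensor()])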

Hey, I was wondering if you could share the code of your custom data loader. I have exactly the same issue and can’t fix it.

I am facing a similar issue while training the fasterrcnn_resnet50_fpn model. My dataset is different, but my workflow is similar to @Wertiz’s. Any solution to the problem is much appreciated.