Inference Code Optimizations + DataLoader

I am trying to increase the inference rate for a pre-trained network. The code for the inference is as follows:

import argparse
import torch
import skimage.transform
from skimage.io import imsave
import torchvision
from PIL import Image
import imageio
import torch.optim
import RedNet_model
from utils import utils
from utils.utils import load_ckpt
from torch import nn
import numpy as np
import os
import glob
from torch.utils.data import DataLoader
import RedNet_data
import time
import torch.backends.cudnn as cudnn

parser = argparse.ArgumentParser(description='Semantic Segmentation')
parser.add_argument('--data-dir', default=None, metavar='DIR',
                    help='path to Data Directory')
parser.add_argument('-o', '--output', default='', metavar='DIR',
                    help='path to output')
parser.add_argument('--cuda', action='store_true', default=False,
                    help='enables CUDA training')
parser.add_argument('--last-ckpt', default='', type=str, metavar='PATH',
                    help='path to latest checkpoint (default: none)')
parser.add_argument('-b', '--batch-size', default=10, type=int,
                    metavar='N', help='mini-batch size (default: 10)')

args = parser.parse_args()
device = torch.device("cuda:0" if args.cuda and torch.cuda.is_available() else "cpu")
image_w = 640
image_h = 480

def inference():
    
    test_data = RedNet_data.InferenceData(phase_train=False, data_dir=args.data_dir)
    
    test_loader = DataLoader(test_data, batch_size = args.batch_size, shuffle=False, num_workers=1, pin_memory=True)
    
    num_test = len(test_data)

    model = RedNet_model.RedNet(pretrained=False)
    model = nn.DataParallel(model).cuda()  # needed because the model was trained on multiple GPUs
    load_ckpt(model, None, args.last_ckpt, device)
    model.eval()
    model.to(device)
    cudnn.benchmark = True
    
    start = time.time()
    
    torch.set_grad_enabled(False)  # a bare torch.no_grad() call has no effect; disable autograd for inference
    
    for batch_idx, (sample, idx, img_paths, depth_paths) in enumerate(test_loader):
        image = sample['image'].numpy()
        depth = sample['depth'].numpy()  
        
        if batch_idx % 1 == 0:
            print('No. of Batches Done: [{0}/{1}]\t'.format(batch_idx,len(test_loader)))

        i = 0
        
        for im, d in zip(image, depth):
            im = skimage.transform.resize(im, (image_h, image_w), order=1,
                                         mode='reflect', preserve_range=True)
            # Nearest-neighbor
            d = skimage.transform.resize(d, (image_h, image_w), order=0,
                                         mode='reflect', preserve_range=True)

            fileName1 = os.path.basename(img_paths[i])
            
            im = im / 255
            im = torch.from_numpy(im).float()
            d = torch.from_numpy(d).float()
            im = im.permute(2, 0, 1)
            d.unsqueeze_(0)
            im = torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])(im)
            d = torchvision.transforms.Normalize(mean=[19050], std=[9650])(d)
            im = im.to(device).unsqueeze_(0)
            d = d.to(device).unsqueeze_(0)
            pred = model(im, d)
            output = utils.color_label(torch.max(pred, 1)[1]+1)[0]
            
            imageio.imsave(args.output+fileName1, output.numpy().transpose((1, 2, 0)))
            i += 1
    
    end = time.time()
    elapsed = end-start
    print("Time elapsed in seconds:", elapsed)
    print("Inference Rate(Images per second): ", num_test/elapsed)
    
if __name__ == '__main__':
    inference()

Now, I noticed something odd when running this inference on the cloud (4x Tesla V100 GPUs) and on my local machine (1x Nvidia GTX 1080 Ti): it actually runs faster on my local machine (4 images/sec on the cloud versus 6.6 images/sec locally). I am still new to deep learning and I can't seem to understand how this is possible.

A follow-up question: if I load a specific batch of images via the DataLoader class, shouldn't it be able to process all those images simultaneously?

And lastly, how can I optimize this further? I need to improve the inference rate by almost 200% so that I can perform the inference in real time. Thanks.


Currently you are performing all the preprocessing of your data in the loop over your DataLoader. It should be faster if you move it to your Dataset's __getitem__ method and use multiple workers to load your data batches.
Could you try that and see if it's faster?
Let me know if that works for you.

Do you mean something like this?

class InferenceData(Dataset):
    def __init__(self, transform=None, phase_train=False, data_dir=None):

        self.phase_train = phase_train
        self.transform = transform
        self.data_dir = data_dir
        
        self.img_dir_test = os.path.join(data_dir, 'data', 'inference', 'images')
        self.depth_dir_test = os.path.join(data_dir, 'data', 'inference', 'depths')

    def __len__(self):
        if self.phase_train:
            img_files_train = os.listdir(self.img_dir_train)
            return len(img_files_train)
        else:
            img_files_test = os.listdir(self.img_dir_test)
            return len(img_files_test)

    def __getitem__(self, idx):
        if self.phase_train:
            img_dir = self.img_dir_train
            depth_dir = self.depth_dir_train
        else:
            img_dir = self.img_dir_test
            depth_dir = self.depth_dir_test
            
        img_paths = glob.glob(os.path.join(img_dir, '*.png'))
        depth_paths = glob.glob(os.path.join(depth_dir, '*.png'))
        
        depth = np.asarray(Image.open(depth_paths[idx]))
        image = np.asarray(Image.open(img_paths[idx]))
        image = skimage.transform.resize(image, (image_h, image_w), order=1,
                                         mode='reflect', preserve_range=True)
        depth = skimage.transform.resize(depth, (image_h, image_w), order=0,
                                         mode='reflect', preserve_range=True)

        image = image / 255
        image = torch.from_numpy(image).float()
        depth = torch.from_numpy(depth).float()
        image = image.permute(2, 0, 1)
        depth.unsqueeze_(0)
        image = torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])(image)
        depth = torchvision.transforms.Normalize(mean=[19050], std=[9650])(depth)
        
        sample = {'image': image, 'depth': depth}
        
        if self.transform:
            sample = self.transform(sample)

        return sample, idx, img_paths[idx], depth_paths[idx]

If yes, I am having problems referencing the individual transformed images.

My inference code looks something like this.

for batch_idx, (sample, idx, img_paths, depth_paths) in enumerate(test_loader):

    image = sample['image']
    depth = sample['depth']

    if batch_idx % 1 == 0:
        print('No. of Batches Done: [{0}/{1}]\t'.format(batch_idx, len(test_loader)))

    i = 0

    for im, d in zip(image, depth):
        print(im[i].shape)
        print(d[i].shape)
        im[i] = im[i].to(device).unsqueeze_(0)
        d[i] = d[i].to(device).unsqueeze_(0)
        pred = model(im, d)
        output = utils.color_label(torch.max(pred, 1)[1]+1)[0]

Also, another follow-up question that I had was:
Looking at the way I am iterating over the inference code, will I always get the associated images for both RGB and depth? (Both folders have the same number of images.)

Basically your approach is right.

Just some minor notes:

  • you could move the image path search (glob.glob) into __init__, as you are currently recreating the path lists in every __getitem__ call, which introduces unnecessary overhead.
  • your current transformations (resize, normalization) can be composed via torchvision.transforms.Compose so that your code is a bit cleaner (a rough sketch of both points follows below).
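
For example, a rough sketch covering both points for the inference phase only (directory layout and normalization values are taken from your snippet above; the exact transform choices are assumptions, not tested code):

import os
import glob
import numpy as np
import skimage.transform
import torch
import torchvision.transforms as transforms
from PIL import Image
from torch.utils.data import Dataset

image_w, image_h = 640, 480

class InferenceData(Dataset):
    def __init__(self, data_dir, phase_train=False):
        img_dir = os.path.join(data_dir, 'data', 'inference', 'images')
        depth_dir = os.path.join(data_dir, 'data', 'inference', 'depths')
        # glob once here instead of in every __getitem__ call;
        # sorting keeps the RGB and depth file lists aligned by name
        self.img_paths = sorted(glob.glob(os.path.join(img_dir, '*.png')))
        self.depth_paths = sorted(glob.glob(os.path.join(depth_dir, '*.png')))
        # compose the per-image transforms once
        self.img_transform = transforms.Compose([
            transforms.Resize((image_h, image_w)),
            transforms.ToTensor(),  # uint8 HWC -> float CHW in [0, 1]
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ])
        self.depth_norm = transforms.Normalize(mean=[19050], std=[9650])

    def __len__(self):
        return len(self.img_paths)

    def __getitem__(self, idx):
        image = self.img_transform(Image.open(self.img_paths[idx]))
        depth = np.asarray(Image.open(self.depth_paths[idx]), dtype=np.float32)
        # nearest-neighbor resize for the depth map, as in your original loop
        depth = skimage.transform.resize(depth, (image_h, image_w), order=0,
                                         mode='reflect', preserve_range=True)
        depth = self.depth_norm(torch.from_numpy(depth).float().unsqueeze(0))
        sample = {'image': image, 'depth': depth}
        return sample, idx, self.img_paths[idx], self.depth_paths[idx]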

What do you mean by “problems referencing the individual transformed images”?

As long as the image paths are sorted in the same way, i.e. the list elements in both lists correspond to each other, you’ll get the right images.

Shifted the image search into __init__. By problems referencing images, I mean that the model takes in images before the transforms I applied in __getitem__ instead of the transformed images. (I am checking that by printing out the shape of the images before I predict.) It should be 1x3x480x640 for the RGB and 1x1x480x640 for the depth when I run it through the model. But I somehow get an input of size 480x640x3, which I think means it doesn't undergo the required transformations.

I have it working now. But this gives me the same inference rate as I was getting earlier. I want to know if there is a way to improve the rate at which it consumes the images.

You could try to increase num_workers of your dataloader for parallelization of data loading. This will decrease the inference time iff your data loading is the bottleneck.
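
For example (the worker count here is just a starting point to tune for your machine):

test_loader = DataLoader(test_data, batch_size=args.batch_size, shuffle=False,
                         num_workers=4, pin_memory=True)  # workers prefetch batches in background processes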

Have you tried to run the inference on a single GPU (without DataParallel)?
Since DataParallel implies some tensor copies, removing it could also increase your speed for shallow networks.
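
If the checkpoint was saved from an nn.DataParallel model, one common workaround is to strip the 'module.' prefix from the parameter names and load into a plain single-GPU model. A minimal sketch, assuming the checkpoint stores a state_dict (possibly under a 'state_dict' key) and bypassing the load_ckpt helper:

state = torch.load(args.last_ckpt, map_location=device)
state_dict = state.get('state_dict', state)  # handle either checkpoint layout
# remove the 'module.' prefix that nn.DataParallel adds to parameter names
state_dict = {k.replace('module.', '', 1): v for k, v in state_dict.items()}

model = RedNet_model.RedNet(pretrained=False)
model.load_state_dict(state_dict)
model.to(device).eval()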

Another minor issue: you should call torch.cuda.synchronize() every time before you call time.time(). This blocks your program until all CUDA ops are finished. Otherwise you might not get a correct timing, since CUDA calls are asynchronous and the CPU side of your script can reach the timing call right after queueing the CUDA operations, before they have actually been executed.
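
Roughly like this, around the timed region:

torch.cuda.synchronize()  # wait for any pending CUDA work before starting the clock
start = time.time()

# ... run the inference loop ...

torch.cuda.synchronize()  # make sure all queued CUDA ops have finished
elapsed = time.time() - start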

Thanks. I tried increasing num_workers and it seems my bottleneck isn't the data loading.
As far as I understand this, I can't run inference without DataParallel if I have trained the model on multiple GPUs? It was throwing errors when loading the model state_dict.
I will check out torch.cuda.synchronize().
Thanks. Also, the model I am using is a modified ResNet-50, so it's not the shallowest of networks.