Why is PyTorch’s GPU utilization so low in production (NOT training)?

I am only getting around 70% utilization, using a Tesla V100 16GB, Ubuntu, and PyTorch 1.0.1.post2.
Is utilization related to the size of the dataset?

Only very loosely. I hate how Nvidia decided to name this metric “utilization” because it confuses the heck out of me and others.

There are basically 2 definitions of utilization:

  1. utilization (what the term is most commonly used for) – a measure of how much of the GPU’s resources a process (kernel) is using. If the GPU has 100GB of DRAM and we only use 10GB, then DRAM utilization would be 10%. If there are 1000 CUDA cores and we use 300 of them on average throughout the process, then CUDA core utilization would be 30%.

  2. utilization (the one shown when you run nvidia-smi) – over a given sampling period, the percentage of time during which at least one kernel was executing on the GPU (regardless of how much of the GPU’s resources that kernel is using). So if we run the process above, with 10% DRAM utilization and 30% CUDA core utilization, then as long as it is doing some computation on the GPU for the whole period, this utilization would be 100%. (Please refer to the link above for more detail; a minimal sketch of reading this number from Python follows below.)
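
For reference, this second number can also be sampled programmatically. A minimal sketch, assuming the pynvml bindings (the library behind nvidia-smi) are installed and GPU 0 is the device of interest:

# Minimal sketch: sample the nvidia-smi-style "utilization" from Python.
# Assumes the pynvml package is installed (pip install pynvml) and GPU 0 is used.
import time
import pynvml

pynvml.nvmlInit()
handle = pynvml.nvmlDeviceGetHandleByIndex(0)

for _ in range(10):
    util = pynvml.nvmlDeviceGetUtilizationRates(handle)
    # util.gpu is definition 2 above: the percent of the sample period in which
    # at least one kernel was executing, NOT how many cores were busy.
    print('gpu util: %d%%  mem util: %d%%' % (util.gpu, util.memory))
    time.sleep(1)

pynvml.nvmlShutdown()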

The utilization I am referring to is the second one. It doesn’t matter how much of the GPU’s resources a process is using; I’m not concerned with that (although that would be a good future goal to tackle). I’m concerned that, even when a PyTorch DNN has been running for minutes, nvidia-smi’s reported utilization is around 30%. THIS MEANS THAT, IN EVERY SECOND, THE GPU IS DOING NOTHING FOR 0.7 SECONDS!! That’s simply unacceptable.

If I add an artificial bottleneck in the data loading part using this Dataset:

import torch
from torch.utils.data import Dataset

class MyDataset(Dataset):
    def __init__(self, size):
        self.size = size

    def __getitem__(self, index):
        # random image-sized input and random class label
        x = torch.randn(3, 224, 224)
        y = torch.randint(0, 1000, (1,)).squeeze()
        
        # Simulate slow data loading
        a = torch.randn(100)
        for _ in range(1000):
            for _ in range(100):
                a = a * a
        return x, y
    
    def __len__(self):
        return self.size

I get a utilization of approx. 13%.
It might be a good idea to time your data loading using the ImageNet example.
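
For what it’s worth, here is a rough, self-contained sketch of that kind of timing, modeled on the ImageNet example’s approach; the tiny Conv2d model and the TensorDataset are only stand-ins for your real model and Dataset, and a CUDA device is assumed:

# Rough sketch: measure data-loading time vs. total iteration time, in the
# spirit of the ImageNet example. The model/dataset here are only stand-ins.
import time
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.randn(256, 3, 224, 224), torch.randint(0, 10, (256,)))
loader = DataLoader(dataset, batch_size=32, num_workers=2)
model = nn.Conv2d(3, 8, 3).to('cuda')

data_time = batch_time = 0.0
end = time.time()
for i, (data, target) in enumerate(loader):
    data_time += time.time() - end              # time spent waiting for the batch
    data = data.to('cuda', non_blocking=True)
    output = model(data)                        # your forward/backward would go here
    torch.cuda.synchronize()                    # include the GPU work in the timing
    batch_time += time.time() - end
    end = time.time()

print('avg data %.4fs / avg batch %.4fs' % (data_time / (i + 1), batch_time / (i + 1)))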

Thanks for your reply.

I have tested the time cost of data loading.
In the training phase, some iterations look as follows:

Epoch: [1][100/296],
Learning_Rate: 0.000500,
Time: 0.4145,       Data:     0.0837,
... ...
Epoch: [2][140/296],
Learning_Rate: 0.000492,
Time: 0.4112,       Data:     0.0885,
... ...

With batch_size=64 and input_size=(256, 256), each iteration costs about 0.4s and data loading costs about 0.08s.
But in the validation phase, data loading costs too much time:

Epoch 1 validation done !
Time: 0.4080,       Data:     0.3324,

Epoch 2 validation done !
Time: 0.4062,       Data:     0.3306,
... ...

Does it seem like there is something wrong with my custom dataset?

import glob
import os
import random

import numpy as np
import torch
import torchvision.transforms as transforms
import torchvision.transforms.functional as TF
from PIL import Image
from torch.utils.data import Dataset

rgb_mean = (0.4353, 0.4452, 0.4131)
rgb_std = (0.2044, 0.1924, 0.2013)

class MyDataset(Dataset):
    def __init__(self,
                 config,
                 subset):
        super(MyDataset, self).__init__()
        assert subset == 'train' or subset == 'val' or subset == 'test'
        self.config = config
        self.root = self.config.root_dir
        self.subset = subset
        self.data = self.config.data_folder_name
        self.target = self.config.target_folder_name
        self.mapping = {
            0: 0,
            255: 1,
        }
        self.data_list = glob.glob(os.path.join(
            self.root,
            subset,
            self.data,
            '*.tif'
        ))
        self.target_list = glob.glob(os.path.join(
            self.root,
            subset,
            self.target,
            '*.tif'
        ))

    def mask_to_class(self, mask):
        for k in self.mapping:
            mask[mask == k] = self.mapping[k]
        return mask

    def train_transforms(self, image, mask):

        resize = transforms.Resize(size=(self.config.input_size, self.config.input_size))
        image = resize(image)
        mask = resize(mask)

        if random.random() > 0.5:
            image = TF.hflip(image)
            mask = TF.hflip(mask)
        if random.random() > 0.5:
            image = TF.vflip(image)
            mask = TF.vflip(mask)
        image = TF.to_tensor(image) # scale 0-1
        image = TF.normalize(image, mean=rgb_mean, std=rgb_std) # normalize
        mask = torch.from_numpy(np.array(mask, dtype=np.uint8))
        mask = self.mask_to_class(mask)
        mask = mask.long()
        return image, mask

    def untrain_transforms(self, image, mask):

        resize = transforms.Resize(size=(self.config.input_size, self.config.input_size))
        image = resize(image)
        mask = resize(mask)
        image = TF.to_tensor(image)
        image = TF.normalize(image, mean=rgb_mean, std=rgb_std)
        mask = torch.from_numpy(np.array(mask, dtype=np.uint8))
        mask = self.mask_to_class(mask)
        mask = mask.long()
        return image, mask

    def __getitem__(self, index):

        datas = Image.open(self.data_list[index])
        targets = Image.open(self.target_list[index])
        if self.subset == 'train':
            t_datas, t_targets = self.train_transforms(datas, targets)
            return t_datas, t_targets
        elif self.subset == 'val':
            t_datas, t_targets = self.untrain_transforms(datas, targets)
            return t_datas, t_targets
        elif self.subset == 'test':
            t_datas, t_targets = self.untrain_transforms(datas, targets)
            return t_datas, t_targets, self.data_list[index]

    def __len__(self):

        return len(self.data_list)

Thank you :smiley:

I have tried to display the data loading cost of each iteration in the training and validation phases; the observations are as follows:

  • the first iteration of the first epoch costs a lot of time: around 8s for the training dataset and around 6s for the validation dataset (the training dataset has 18944 images at 256x256 and the validation dataset has 4144 images at 256x256), which seems normal given the different sizes.
  • the time cost of every other iteration (except the first iteration of the first epoch) is at the same level.
  • I think the comparison between the training and validation phases is fair because the batch_size is the same.

So, what could cause the utilization in the validation phase to be close to zero?
Looking forward to your help, thank you!

Thanks for the detailed analysis.
I assume that the time given in your output corresponds to a single iteration time.

To me it looks like your data loading time is being hidden during training, since each training iteration takes some time and the workers can preload the next batch in the meantime; that’s why it appears quite low at 0.08s. During validation the workload is smaller, since you are just computing the forward pass, so the data loading time is now exposed. This might also be the reason for the low GPU utilization: the small validation workload turns data loading into the bottleneck. Your Dataset implementation looks alright.

The first iteration might take a bit more time, as all workers are loading a batch and need some “warm up time”.

Oh, thank you for your detailed explanation. :wink:

I am sorry to bother you again.

If the GPU utilization (checked with nvidia-smi) is defined as @0xFFFFFFFF mentioned above, this means that data loading costs much more time than the forward pass, so there are periods during which no kernel is present on the GPU and data is only being loaded onto it, right? :thinking:

If so, there really is a ‘gap’ between data loading and processing. Is there any way to avoid it?

Thanks a million.

Yes, that’s usually the case if your actual workload on the GPU is small and thus your CPU code execution cannot be hidden. You could try to play around with the number of workers to possibly speed up the data loading. Also make sure the data is stored on an SSD. If you are using some image preprocessing, you might want to install PIL-SIMD, which is a drop-in replacement for PIL using SIMD instructions.
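
If it helps, here is a rough sketch of sweeping num_workers on the validation loader to see where the loading time stops improving. MyDataset and config refer to the dataset code posted above, and batch_size=64 matches the numbers quoted earlier; this is only a sketch, not a tuned recipe:

# Rough sketch: sweep num_workers / pin_memory on the validation loader and time
# one pass over it, to see where data loading stops being the bottleneck.
# MyDataset and config refer to the dataset code posted earlier in this thread.
import time
from torch.utils.data import DataLoader

val_dataset = MyDataset(config, 'val')

for num_workers in [1, 2, 4, 8]:
    loader = DataLoader(val_dataset, batch_size=64,
                        num_workers=num_workers, pin_memory=True)
    start = time.time()
    for data, target in loader:
        pass                                    # loading only, no GPU work
    print('num_workers=%d: %.2fs per pass' % (num_workers, time.time() - start))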

OK, thank you very much, I will give it a try. Thanks.

My understanding of data loading makes me believe it isn’t relevant for inference (not training but production). The weights, model, and inputs all start in GPU RAM (because they are only a couple of GBs combined and can be pre-loaded onto the device before inference). If the DataLoader’s purpose is to asynchronously copy memory from CPU to GPU while the GPU is doing some work, then it doesn’t help here.
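
To make that concrete, here is a rough sketch of the setup I mean; nn.Linear stands in for the real model (e.g. Tacotron 2) and the inputs are dummy tensors. Everything already lives on the GPU, so there is no host-to-device copy left for a DataLoader to overlap:

# Rough sketch: weights and inputs are pushed to the GPU once, up front,
# so there is no host-to-device copy left to hide during inference.
# nn.Linear and the random inputs are only stand-ins for the real engine.
import torch
import torch.nn as nn

device = torch.device('cuda')
model = nn.Linear(80, 80).to(device).eval()                       # weights live on the GPU
inputs = [torch.randn(1, 80, device=device) for _ in range(100)]  # inputs pre-loaded as well

with torch.no_grad():
    for x in inputs:
        y = model(x)                                              # pure GPU compute, no DataLoader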

If you could somehow copy the input data onto the GPU beforehand (and have the memory to do so), then the DataLoader won’t help, that’s true.

However, if we are talking about production systems, I assume you’ll get the data from some kind of streaming service (in which case the usage of a DataLoader wouldn’t make sense). In that case you would still have to push the data onto the GPU or am I misunderstanding your use case?

The input to Tacotron 2, for example, is a text string, so it only takes a couple of hundred bytes. So yes, we do get input data from a streaming service, but all our engines take inputs of less than a megabyte, which makes the cost of loading them upfront, before running the model, trivial. So data loading is not the cause of the low utilization.

Yeah right, you’ve mentioned Tacotron 2. Have you had a chance to profile it?
I’ll try to get it working on my machine and have a look at it.

@ptrblck I liked the ‘dummy code’ above, so I thought I’d play with it a little bit, since I’ve also been trying to understand some low utilization in PyTorch. Maybe this is getting a little off topic… but maybe not.

I made simple tweaks to support training and testing modes, different numbers of workers, different batch sizes, and Windows. Still resnet50, (3, 224, 224). (Also: Titan XP, pytorch-nightly from Feb 28.)

The timing numbers ignore the first call to the model, since that’s much slower. Utilization (for training mode) is from nvidia-smi, captured by hand.

High ‘utilization’ can be reached in training mode, at high batch sizes (as one would expect).

TRAIN             Rate (Imgs/s)              Utilization (nvidia-smi)
Batch Size     #Wrk=1   #Wrk=2   #Wrk=4     #Wrk=1   #Wrk=2   #Wrk=4
 1              10.51    10.82    10.79        23%      27%      25%
 2              20.81    17.62    20.85        27%      30%      30%
 4              40.86    43.28    42.51        28%      37%      37%
 8              79.31    85.84    82.68        32%      35%      50%
16             142.99    165.7    161.5        46%      50%      78%
32             155.08   183.86   179.83        70%      75%      89%

For test mode, I get:

TEST              Rate (Imgs/s)                        Utilization (nvidia-smi)
Batch Size     #Wrk=1   #Wrk=2   #Wrk=4   #Wrk=8     #Wrk=1   #Wrk=2   #Wrk=4   #Wrk=8
 1              33.16    33.37    33.48    31.83        20%      19%      19%      19%
 2              65.21    65.08    55.71    62.35        29%      22%      24%      22%
 4             125.79   123.84   118.78   116.11        28%      31%      24%      28%
 8             147.84   227.16   226.23   220.73        30%      35%      39%      41%
16             162.01   271.54   393.73   396.41        30%      39%      58%      60%
32             170.08   293.18   475.14   529.37        32%      57%      78%      77%
~Max possible rate: 640 Imgs/s

Max possible is estimated by dividing BS=32 by the best-case GPU time per batch in nvprof (~50 ms): 32 / 0.050 s ≈ 640 images/s. I see a pattern in nvprof for 4 workers, where it’s ~(50, 50, 50, 50) ms and then there is an extra delay. I don’t quite see the same pattern for #workers=8, but I do see interspersed gaps.

[BTW: similar-in-spirit benchmark code in recent MXNet gives 732 images/s, comparable to the 640 number, I believe. If there is more interest in fleshing this out (e.g. with nvprof timings), I can start a new thread here or a new issue in PyTorch.]

Here is my modified code:

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

import torchvision.models as models
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import time

def main():
    mode = 'test'  # or 'train'
    model = models.resnet50()
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
    N = 1280
    dataset = datasets.FakeData(size=N, transform=transforms.ToTensor())
    if mode=='test': # switch to evaluate mode
        model.eval()
    model.to('cuda')
    for num_workers in [1, 2, 4, 8]: # 4 < 2 for test
        for batch_size in [1, 2, 4, 8, 16, 32]:
            loader = DataLoader(dataset, num_workers=num_workers, batch_size=batch_size, pin_memory=True)
            if mode=='test':
                for i, (data, target) in enumerate(loader):
                    if i==1:
                        tm = time.time()  # start timing after the first (warm-up) batch
                    data = data.to('cuda', non_blocking=True)
                    output = model(data)
            else: # mode=='train':
                for i, (data, target) in enumerate(loader):
                    if i==1:
                        tm = time.time()  # start timing after the first (warm-up) batch
                    data = data.to('cuda', non_blocking=True)
                    target = target.to('cuda', non_blocking=True).long()
                    optimizer.zero_grad()
                    output = model(data)
                    loss = criterion(output, target)
                    loss.backward()
                    optimizer.step()
            tm = time.time() - tm  # note: no cuda synchronize here, so queued GPU work at the end may not be counted
            print('Mode=%s: NumWorkers=%2d  BatchSize=%2d  Time=%6.3fs  Imgs/s=%6.2f' % (mode, num_workers, batch_size, tm, N/tm))
            torch.cuda.empty_cache() # doesn't seem to be working...

if __name__ == '__main__':
    main()

Maybe we should create a separate thread about increasing utilization during training. The topic of this thread is mostly inference (production), not training. Therefore the DataLoader, as discussed above, is not relevant, since it is not used at all. I changed the title to make this clearer.

For completeness, I edited the post above, adding the (rough) utilization numbers. They follow a similar trend to training but are generally lower, as expected; the GPU is less busy without back-prop.

Certainly, the DataLoader wouldn’t be used in production. The numbers from nvprof are more telling.

ResNet-50 is not Tacotron 2, so that would have to be benchmarked (and examined in nvprof) separately. But in general, you’ll either need to go data-parallel or model-parallel (if you have the memory) to get the highest utilization, e.g. along the lines of the sketch below.
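
Just to illustrate the data-parallel route (a rough sketch, not a tuned setup; resnet50 is again only a stand-in for whatever model you serve):

# Rough sketch of the data-parallel route for batched inference:
# nn.DataParallel replicates the model and splits each batch across visible GPUs.
# resnet50 is only a stand-in for the model actually being served.
import torch
import torch.nn as nn
import torchvision.models as models

model = models.resnet50().eval().to('cuda')
if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model)

with torch.no_grad():
    batch = torch.randn(64, 3, 224, 224, device='cuda')
    out = model(batch)        # each GPU processes a slice of the batch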

If you're interested in speeding up inference, I'd suggest looking at TensorRT: https://developer.nvidia.com/tensorrt. Even if you don't think it is applicable to your situation, some of the TensorRT documentation has good discussions about inference performance; see here: https://docs.nvidia.com/deeplearning/sdk/pdf/TensorRT-Best-Practices.pdf

Hello, I just tested the inference process. If I have 25 images to handle, the first one seems to take some time, and then the next 24 images take very little time. But the interesting thing is that with 26 images, the 26th image takes about the same time as the first one. Could you help explain that?

1st image Time: 0.4107s

2nd–25th images Time: 0.0010s

26th image Time: 0.4206s

Here is some profiling data I have collected using Nsight Systems during inference of Tacotron 2. When we look at the results, it becomes clear why the utilization is so low when we perform inference with PyTorch.

The problems are:

  1. extremely small kernels (each taking around 5 microseconds) are launched one at a time, so the cost of the kernel launch (on the CPU) is generally higher than the cost of the kernel itself (on the GPU). For example, the time taken on the CPU to launch “gemv2T_kernel_val” is about 15 microseconds, whereas the time the GPU takes to actually complete the computation is about 5 microseconds. (See the profiler sketch after this list for a quick way to observe this per op.)

  2. there seem to be gaps between the CUDA API calls, because PyTorch adds additional wrappers around the tensors.
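
As a rough cross-check that doesn’t require Nsight, the built-in autograd profiler shows the same picture per op. This sketch uses resnet50 only as a stand-in for the model you actually run:

# Rough sketch: compare per-op CPU time (mostly launch/wrapper overhead) with
# CUDA time using PyTorch's built-in profiler. resnet50 is only a stand-in.
import torch
import torchvision.models as models

model = models.resnet50().eval().to('cuda')
x = torch.randn(1, 3, 224, 224, device='cuda')

with torch.no_grad():
    model(x)                                            # warm-up
    with torch.autograd.profiler.profile(use_cuda=True) as prof:
        model(x)

# With a batch size of 1, many ops report more CPU time than CUDA time,
# which is exactly the launch-overhead problem described above.
print(prof.key_averages().table(sort_by='cuda_time_total'))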

This is a portion of the Nsight Systems timeline. The sky-blue bars indicate work happening on the GPU; the absence of a sky-blue or red bar means no computation is happening on the GPU. The “CUDA API” row shows the CPU’s preparation for launching CUDA kernels – it is CPU-side work required to launch the kernels on the GPU.

As we zoom in further, we can see that there are HUGE gaps in the work happening on the device,

and even more so if we zoom in even further.

In this issue, a dev from NVIDIA explains why this problem is occurring. Essentially, the answer is: PyTorch is not optimized well for this case, and this, combined with the nature of Tacotron 2’s network architecture, produces the low nvidia-smi utilization. It is not a bug.
2 Likes