Why is PyTorch's GPU utilization so low in production (NOT training)?

Utilization (which you can check using nvidia-smi), as defined in this link, is not a measure of how well a process is using the GPU's resources. Please read the definition if you aren't sure.

Why is GPU utilization so low for code written in PyTorch (it averages around 30%)? Does PyTorch create unnecessary work for the CPU?

The GPU utilization depends on your DataLoader settings.
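
For example, these are the DataLoader settings I would check first (just a rough sketch; the FakeData dataset, batch size, and worker count are placeholders you would tune for your own setup):

import torch
from torch.utils.data import DataLoader
import torchvision.datasets as datasets
import torchvision.transforms as transforms

# Placeholder dataset; substitute your own.
dataset = datasets.FakeData(size=1000, transform=transforms.ToTensor())

loader = DataLoader(
    dataset,
    batch_size=64,     # larger batches give the GPU more work per step
    num_workers=4,     # worker processes preload the next batches in parallel
    pin_memory=True,   # enables faster, asynchronous host-to-device copies
    shuffle=True,
)

for data, target in loader:
    # non_blocking only has an effect together with pin_memory=True
    data = data.to('cuda', non_blocking=True)
    target = target.to('cuda', non_blocking=True)
    # ... forward / backward / step ...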

Besides a possible data loading bottleneck, small architectures might also result in low GPU utilization.
Do you see peak utilization followed by a small pause, or is the utilization at approx. 30% the whole time?
Could you explain your model architecture a bit?

Background:

  • utilization (from nvidia-smi) is: utilization = (time during which at least one process is running on the GPU / total sampled time) * 100. It has nothing to do with how much of the GPU's resources the running process uses.
  • inference only (with no data loading, since we start with all the weights on the GPU and the input is merely a string of words)
  • tacotron2 link
  • does not peak; stays at 30% the entire time.
  • even if the inference takes longer than 1.5 seconds, which is longer than the total sampled time, utilization stays at 30%.
  • CentOS 7 with PyTorch 0.4.1

I think the NVIDIA Visual Profiler might give you clearer information on what's going on under the hood.
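
Alternatively, PyTorch's built-in autograd profiler can give a quick per-op breakdown without leaving Python (a rough sketch; the resnet18 model and random input are only placeholders for your own inference setup):

import torch
import torchvision.models as models

# Placeholder model and input; substitute your own inference code.
model = models.resnet18().to('cuda').eval()
example_input = torch.randn(1, 3, 224, 224, device='cuda')

with torch.no_grad():
    with torch.autograd.profiler.profile(use_cuda=True) as prof:
        output = model(example_input)

print(prof)                               # per-op CPU and CUDA times
prof.export_chrome_trace('trace.json')    # inspect the timeline in chrome://tracing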

Okay, so this is not an inherent issue with every PyTorch network? We are currently testing more than 10 different models, and all except one display this behaviour, which made us believe that there is something inherently inefficient about PyTorch. Is this unusual?

I will get back to you guys if I find anything useful.

It shouldn't be, as I usually try to get the utilization above 90% on my system while training.
Often I have some bottlenecks or unwanted synchronization points which slow down my code.
Could you test this dummy code and check the utilization?

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

import torchvision.models as models
import torchvision.datasets as datasets
import torchvision.transforms as transforms


model = models.resnet50()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

dataset = datasets.FakeData(
    size=1000,
    transform=transforms.ToTensor())
loader = DataLoader(
    dataset,
    num_workers=1,
    pin_memory=True
)

model.to('cuda')

for data, target in loader:
    data = data.to('cuda', non_blocking=True)
    target = target.to('cuda', non_blocking=True)
    optimizer.zero_grad()
    output = model(data)
    loss = criterion(output, target)
    loss.backward()
    optimizer.step()
    
print('Done')

I'm just curious to see if you can also reach a utilization of approx. 94%, or if something else might be the issue.

Hi ptrblck,

I have encountered a similar problem a few days ago.
The GPU utilization is sometimes as high as 80% and sometimes as low as 20% in the training phase, and the utilization is close to zero in the validation phase.
After I tried different num_workers values in the DataLoader and computed the metrics on the GPU instead of on the CPU in numpy, the utilization stays stably at 80-90%, but in the validation phase at the end of each epoch the utilization is still close to 0. What might cause this strange phenomenon?
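
For reference, this is roughly the kind of change I mean, shown here with a simple accuracy metric (a rough sketch; the random tensors just stand in for real validation outputs):

import torch

# Placeholder predictions and labels standing in for real validation outputs.
outputs = [torch.randn(64, 10, device='cuda') for _ in range(5)]
targets = [torch.randint(0, 10, (64,), device='cuda') for _ in range(5)]

correct = torch.zeros(1, device='cuda')
total = 0
with torch.no_grad():
    for output, target in zip(outputs, targets):
        # The comparison stays on the GPU; no .cpu()/.numpy() call per
        # iteration, so there is no device-to-host sync in every step.
        correct += (output.argmax(dim=1) == target).sum().float()
        total += target.size(0)

# Read the scalar back only once at the end.
accuracy = (correct / total).item()
print(accuracy)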

Thanks in advance:thinking:

Are you transferring your output or loss to the CPU in each iteration in your validation case?
Could you post the validation code so that we can have a look?

Sorry for the late reply.
I uploaded my code to GitHub; here is the link to the BaseTrainer, which includes the training phase and the validation phase (the test phase is in BaseTester). In the beginning it computed the metrics (like mIoU and accuracy) in numpy on the CPU; after the improvement (changing to the GPU), the utilization is higher and stable in the training phase, but I am still not sure whether my metrics are correct.
I think there are still many problems in my code.
Thank you for your help. :smile:

I get this error when I run the code:

Traceback (most recent call last):
  File "pytorch_example.py", line 29, in <module>
    loss = criterion(output, target)
  File "/usr/lib64/python3.6/site-packages/torch/nn/modules/module.py", line 489, in __call__
    result = self.forward(*input, **kwargs)
  File "/usr/lib64/python3.6/site-packages/torch/nn/modules/loss.py", line 904, in forward
    ignore_index=self.ignore_index, reduction=self.reduction)
  File "/usr/lib64/python3.6/site-packages/torch/nn/functional.py", line 1970, in cross_entropy
    return nll_loss(log_softmax(input, 1), target, weight, None, ignore_index, None, reduction)
  File "/usr/lib64/python3.6/site-packages/torch/nn/functional.py", line 1790, in nll_loss
    ret = torch._C._nn.nll_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index)
RuntimeError: Expected object of scalar type Long but got scalar type Float for argument #2 'target'

I used PyTorch versions 1.0.1 and 0.4.1 to run the code and got the error on both.

I fixed the dummy code as follows:

target = target.to('cuda', non_blocking=True).long()

I got the utilization at 98%.

Using a GTX 1080, CentOS 7, and PyTorch 1.0.1, utilization hovered around 77-90% (very rarely did it peak at 97%).

I’m still puzzled why many of the models we have tested barely reach 40%.

Thanks, now I was able to run it.

I just got a utilization of 70%, using a Tesla V100 16GB, Ubuntu, and PyTorch 1.0.1.post2.
Does utilization relate to the size of the dataset?

Only very loosely. I hate how Nvidia decided to name this metric “utilization” because it confuses the heck out of me and others.

There are basically 2 definitions of utilization:

  1. utilization (what the term is most commonly used for) - a measure of how much of the GPU's resources a process (kernel) is using. If a GPU has 100GB of DRAM and we only use 10GB, then DRAM utilization would be 10%. If there are 1000 CUDA cores and we use 300 of them on average throughout the process, then CUDA core utilization would be 30%.

  2. utilization (the one shown when you run nvidia-smi) - for a given sampling period s, the percentage of time during which at least one kernel is present on the GPU (regardless of how much of its resources the process is using). So if we run the process above, where DRAM utilization is 10% and CUDA core utilization is 30%, then as long as it is doing some computation on the GPU, this utilization would be 100%. (Please refer to the link above for more detail.)

The utilization I am referring to is the second one. It doesn't matter how much of the GPU's resources a process is using; I'm not concerned with that (although that would be a good future goal to tackle). I'm concerned that, even when a PyTorch DNN runs for minutes, nvidia-smi's reported utilization is around 30%. THIS MEANS THAT, IN EVERY SECOND, THE GPU IS NOT DOING ANYTHING FOR 0.7 SECONDS! That's simply unacceptable.
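
For what it's worth, you can sample this second definition yourself through NVML instead of watching nvidia-smi (a rough sketch, assuming the pynvml package is installed):

import time
import pynvml

pynvml.nvmlInit()
handle = pynvml.nvmlDeviceGetHandleByIndex(0)  # GPU 0

# utilization.gpu is the same value nvidia-smi reports: the percentage of
# the recent sampling window during which at least one kernel was running.
for _ in range(10):
    util = pynvml.nvmlDeviceGetUtilizationRates(handle)
    print('gpu util: {}%  memory util: {}%'.format(util.gpu, util.memory))
    time.sleep(1)

pynvml.nvmlShutdown()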

If I add an artificial bottleneck in the data loading part using this Dataset:

import torch
from torch.utils.data import Dataset

class MyDataset(Dataset):
    def __init__(self, size):
        self.size = size
        
    def __getitem__(self, index):
        x = torch.randn(3, 224, 224)
        y = torch.randint(0, 1000, (1,)).squeeze()
        
        # Simulate slow data loading
        a = torch.randn(100)
        for _ in range(1000):
            for _ in range(100):
                a = a * a
        return x, y
    
    def __len__(self):
        return self.size

I get a utilization of approx. 13%.
It might be a good idea to time your data loading using the ImageNet example.
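
Roughly what the ImageNet example measures, stripped down to a sketch (model, criterion, optimizer, and loader stand in for your own training objects):

import time
import torch

data_time = 0.0
batch_time = 0.0
end = time.time()

for i, (data, target) in enumerate(loader):
    # time spent waiting for the next batch from the DataLoader
    data_time += time.time() - end

    data = data.to('cuda', non_blocking=True)
    target = target.to('cuda', non_blocking=True).long()
    optimizer.zero_grad()
    output = model(data)
    loss = criterion(output, target)
    loss.backward()
    optimizer.step()

    # total time of the iteration (loading + compute); note that CUDA calls
    # are asynchronous, so add torch.cuda.synchronize() here for exact numbers
    batch_time += time.time() - end
    end = time.time()

print('avg data time: {:.4f}s, avg batch time: {:.4f}s'.format(
    data_time / (i + 1), batch_time / (i + 1)))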

Thanks for your reply.

I have tested the time cost of data loading.
In the training phase, some iterations look as follows:

Epoch: [1][100/296],
Learning_Rate: 0.000500,
Time: 0.4145,       Data:     0.0837,
... ...
Epoch: [2][140/296],
Learning_Rate: 0.000492,
Time: 0.4112,       Data:     0.0885,
... ...

batch_size=64 and input_size=(256, 256);
each iteration costs about 0.4s, of which data loading costs about 0.08s.
But in the validation phase, data loading costs too much time:

Epoch 1 validation done !
Time: 0.4080,       Data:     0.3324,

Epoch 2 validation done !
Time: 0.4062,       Data:     0.3306,
... ...

Does it seem that there is something wrong with the custom Dataset?

import glob
import os
import random

import numpy as np
import torch
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
import torchvision.transforms.functional as TF

rgb_mean = (0.4353, 0.4452, 0.4131)
rgb_std = (0.2044, 0.1924, 0.2013)

class MyDataset(Dataset):
    def __init__(self,
                 config,
                 subset):
        super(MyDataset, self).__init__()
        assert subset == 'train' or subset == 'val' or subset == 'test'
        self.config = config
        self.root = self.config.root_dir
        self.subset = subset
        self.data = self.config.data_folder_name
        self.target = self.config.target_folder_name
        self.mapping = {
            0: 0,
            255: 1,
        }
        self.data_list = glob.glob(os.path.join(
            self.root,
            subset,
            self.data,
            '*.tif'
        ))
        self.target_list = glob.glob(os.path.join(
            self.root,
            subset,
            self.target,
            '*.tif'
        ))

    def mask_to_class(self, mask):
        for k in self.mapping:
            mask[mask == k] = self.mapping[k]
        return mask

    def train_transforms(self, image, mask):

        resize = transforms.Resize(size=(self.config.input_size, self.config.input_size))
        image = resize(image)
        mask = resize(mask)

        if random.random() > 0.5:
            image = TF.hflip(image)
            mask = TF.hflip(mask)
        if random.random() > 0.5:
            image = TF.vflip(image)
            mask = TF.vflip(mask)
        image = TF.to_tensor(image) # scale 0-1
        image = TF.normalize(image, mean=rgb_mean, std=rgb_std) # normalize
        mask = torch.from_numpy(np.array(mask, dtype=np.uint8))
        mask = self.mask_to_class(mask)
        mask = mask.long()
        return image, mask

    def untrain_transforms(self, image, mask):

        resize = transforms.Resize(size=(self.config.input_size, self.config.input_size))
        image = resize(image)
        mask = resize(mask)
        image = TF.to_tensor(image)
        image = TF.normalize(image, mean=rgb_mean, std=rgb_std)
        mask = torch.from_numpy(np.array(mask, dtype=np.uint8))
        mask = self.mask_to_class(mask)
        mask = mask.long()
        return image, mask

    def __getitem__(self, index):

        datas = Image.open(self.data_list[index])
        targets = Image.open(self.target_list[index])
        if self.subset == 'train':
            t_datas, t_targets = self.train_transforms(datas, targets)
            return t_datas, t_targets
        elif self.subset == 'val':
            t_datas, t_targets = self.untrain_transforms(datas, targets)
            return t_datas, t_targets
        elif self.subset == 'test':
            t_datas, t_targets = self.untrain_transforms(datas, targets)
            return t_datas, t_targets, self.data_list[index]

    def __len__(self):

        return len(self.data_list)

Thank you :smiley:

I have tried to display the data loading cost of each iteration in the training phase and the validation phase; the observations are as follows:

  • the first iteration of the first epoch costs a lot of time: around 8s for the training dataset and around 6s for the validation dataset (the training dataset has 18944 images at 256x256 and the validation dataset has 4144 images at 256x256); this seems normal due to the different sizes.
  • the time cost of every other iteration (except the first iteration of the first epoch) is at the same level.
  • I think the comparison between the training phase and the validation phase is fair because the batch_size is the same.

So, what could cause the utilization in the validation phase to be close to zero?
Looking forward to your help, thank you!

Thanks for the detailed analysis.
I assume that the time given in your output corresponds to a single iteration time.

For me it looks like your data loading time is being hidden in the training script: the training itself takes some time, so the workers can preload the next batch, and the measured loading time stays as low as 0.08s. During validation the workload is smaller, since you are only computing the forward pass, so the data loading time now becomes visible. This is probably also the reason for the low GPU utilization: due to the low workload during validation, data loading becomes the bottleneck. Your Dataset implementation looks alright.
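
One quick way to confirm this would be to time a pass over the validation loader without running the model at all (a rough sketch; val_loader stands in for your validation DataLoader):

import time
import torch

# Iterate over the validation loader without calling the model. If this
# alone takes roughly as long as a full validation epoch, the bottleneck
# is data loading rather than the GPU work.
start = time.time()
n_batches = 0
for data, target in val_loader:
    data = data.to('cuda', non_blocking=True)
    target = target.to('cuda', non_blocking=True)
    n_batches += 1
torch.cuda.synchronize()
print('loader-only pass: {:.2f}s over {} batches'.format(
    time.time() - start, n_batches))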

The first iteration might take a bit more time, as all workers are loading a batch and need some “warm up time”.
