Low GPU Usage during Training

Unfortunately, I don’t think I can help much: the files from your GitHub repo run successfully for me with different batch sizes and different numbers of workers. My GPU utilization is around 22%. However, I am on a Linux machine with 32 CPUs.

But I did find an interesting thread explaining that num_workers>0 might not work on Windows, and how to fix it
(Errors when using num_workers>0 in DataLoader - PyTorch Forums).
The solution is:

def train():
    # The whole code that trains the network goes here ...
    pass


if __name__ == '__main__':
    train()

So I tried modifying your files, but as I mentioned, I don’t have a Windows system to test on. Try the code below in your MNIST.py file; if it doesn’t work, please follow the link above.

import numpy as np
import torch
import torch.nn as nn
import torchvision

from torchvision.transforms import ToTensor
from torch.utils.data import Dataset
from torchvision.utils import make_grid
from torch.utils.data import random_split

import matplotlib.pyplot as plt

from model import MnistMLP, MnistCNN
from trainer import TrainerConfig, Trainer
from visualize import Plot

train_set = torchvision.datasets.MNIST(
    root="./data",
    train=True,
    download=True,
    transform=ToTensor()
)

test_set = torchvision.datasets.MNIST(
    root="./data",
    train=False,
    download=True,
    transform=ToTensor()
)

# The guard keeps DataLoader worker processes on Windows from re-running
# the training code when they import this module.
if __name__ == '__main__':
    cnn_train_configs = TrainerConfig(
        ckpt_path="./CNNModel.pt",
        max_epochs=40,
        learning_rate=4.67e-4,
        weight_decay=6.423e-4,
    )
    CNN_model = MnistCNN()
    trainer = Trainer(
        model=CNN_model,
        train_dataset=train_set,
        test_dataset=test_set,
        config=cnn_train_configs,
    )
    model_metrics = trainer.train()

    plotter = Plot(model_metrics=model_metrics)
    plotter.plot()

It works!

My GPU is now being used around 20% during training.

Thank you so much for helping me out.

I had a similar problem with my training. It turned out that the GPUs were waiting for IO and the CPU to finish their work. I could see in my GPU metrics that the GPUs were sometimes at 100% utilization, but most of the time they were idle, waiting for data. I had to change the way I loaded the data.
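
One way to check whether you are in the same situation is to time how long each step waits for the next batch versus how long the step itself takes. Below is a minimal sketch, not something taken from this thread: the helper name measure_data_wait, the train_step callable, and the slow_batches stand-in are all hypothetical, and loader can be any iterable (for example a DataLoader). On a GPU you would also want to call torch.cuda.synchronize() at the end of each step, because CUDA kernels run asynchronously and the compute time would otherwise be under-reported.

```python
import time


def measure_data_wait(loader, train_step, max_batches=50):
    """Return (wait_seconds, compute_seconds) over up to max_batches batches.

    loader     -- any iterable yielding batches (e.g. a DataLoader)
    train_step -- hypothetical callable that runs one training step on a batch
    """
    wait = 0.0
    compute = 0.0
    batches = iter(loader)
    for _ in range(max_batches):
        t0 = time.perf_counter()
        try:
            batch = next(batches)  # time spent here is time the GPU sits idle
        except StopIteration:
            break
        t1 = time.perf_counter()
        train_step(batch)  # forward pass, backward pass, optimizer step
        t2 = time.perf_counter()
        wait += t1 - t0
        compute += t2 - t1
    return wait, compute


def slow_batches(n, delay):
    """Stand-in for a slow input pipeline: each batch takes `delay` seconds."""
    for i in range(n):
        time.sleep(delay)
        yield i


if __name__ == "__main__":
    # Assumed toy setup: slow data source, training step that does nothing.
    wait, compute = measure_data_wait(slow_batches(5, 0.02), lambda batch: None)
    print(f"waiting for data: {wait:.3f}s, computing: {compute:.3f}s")
```

If the wait time dominates, the data pipeline is the likely bottleneck. Things commonly tried in that case are raising num_workers in the DataLoader, setting pin_memory=True, or moving expensive preprocessing out of the dataset’s __getitem__; which of these helps depends on your setup.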

I think I ran into the same problem. May I ask how you solved it?
Thank you!