Low GPU Usage during Training

Unfortunately, I don't think I can help much. The file from your GitHub runs successfully for me with different batch sizes and different numbers of workers. My GPU utilization is around 22%, but I am on a Linux machine with 32 CPUs.
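Since batch size and worker count are the knobs being tried, here is a minimal sketch for checking whether data loading, rather than the GPU, is the bottleneck. It times one full pass over a DataLoader for a few num_workers values. The synthetic TensorDataset of 28×28 tensors and the helper name `time_loader` are assumptions standing in for the real MNIST pipeline; because these tensors are already in memory with no transform, extra workers may not help here the way they can with a real dataset that decodes and transforms each sample.

```python
import time

import torch
from torch.utils.data import DataLoader, TensorDataset


def time_loader(num_workers, batch_size=64, n=2048):
    """Hypothetical helper: time one pass over a DataLoader of synthetic data."""
    # Synthetic stand-in for MNIST-sized samples (assumption, not the real dataset)
    ds = TensorDataset(torch.randn(n, 1, 28, 28), torch.randint(0, 10, (n,)))
    loader = DataLoader(ds, batch_size=batch_size, num_workers=num_workers)
    start = time.perf_counter()
    batches = sum(1 for _ in loader)  # iterate without doing any GPU work
    return batches, time.perf_counter() - start


if __name__ == '__main__':
    # Compare loading time across worker counts
    for workers in (0, 2, 4):
        batches, secs = time_loader(workers)
        print(f"num_workers={workers}: {batches} batches in {secs:.3f}s")
```

If loading alone takes a large share of the time a training epoch takes, the GPU is probably waiting on data, and raising num_workers or the batch size is worth trying.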

I did find an interesting thread explaining that num_workers > 0 may not work on Windows, and how to fix it. The solution (from Errors when using num_workers>0 in DataLoader - PyTorch Forums) is to put the training code inside a function and call it only under the main guard:

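def train():
    # Put all the code that trains the network here ...
    pass

# On Windows, DataLoader workers are started by re-importing this script,
# so the guard keeps each worker from calling train() again
if __name__ == '__main__':
    train()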
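For reference, a minimal self-contained sketch of that pattern: the DataLoader with num_workers=2 is created only inside `train()`, which runs under the `__main__` guard, so workers that re-import the script do not start training again. The random tensors are placeholder data (an assumption, not the MNIST set), and the loop only counts samples instead of training a model.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


def train():
    # Placeholder data standing in for the real training set
    data = torch.randn(64, 1, 28, 28)
    labels = torch.randint(0, 10, (64,))
    loader = DataLoader(TensorDataset(data, labels), batch_size=16, num_workers=2)

    seen = 0
    for x, _ in loader:
        seen += x.shape[0]  # a real loop would run forward/backward passes here
    return seen


if __name__ == '__main__':
    print(train())
```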

So I tried modifying your files, but as I mentioned, I don't have a Windows system. Try my code below in your MNIST.py file; if it doesn't work, please follow the link above.

import numpy as np
import torch
import torch.nn as nn
import torchvision

from torchvision.transforms import ToTensor
from torch.utils.data import Dataset
from torchvision.utils import make_grid
from torch.utils.data import random_split

import matplotlib.pyplot as plt

# Local modules from your repository (not part of PyTorch)
from model import MnistMLP, MnistCNN
from trainer import TrainerConfig, Trainer
from visualize import Plot

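# On Windows, module-level code like this runs again in every spawned worker;
# that is harmless here because download=True skips files that already exist
train_set = torchvision.datasets.MNIST(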
    root="./data",
    train=True,
    download=True,
    transform=ToTensor()
)

test_set = torchvision.datasets.MNIST(
    root="./data",
    train=False,
    download=True,
    transform=ToTensor()
)

if __name__ == '__main__':
    # Everything that starts training (and therefore DataLoader workers) stays inside this guard
    cnn_train_configs = TrainerConfig(
        ckpt_path="./CNNModel.pt",
        max_epochs=40,
        learning_rate=4.67e-4,
        weight_decay=6.423e-4,
    )
    CNN_model = MnistCNN()
    trainer = Trainer(
        model=CNN_model,
        train_dataset=train_set,
        test_dataset=test_set,
        config=cnn_train_configs,
    )
    model_metrics = trainer.train()

    plotter = Plot(model_metrics=model_metrics)
    plotter.plot()