DataLoader num_workers drastically increases run time

PyTorch 2.0.1+cu117, Windows 10, a powerful workstation with an NVIDIA GPU and 16 CPU cores; full code below.

I have a dataset of 240,000 28x28 images, but once I figure out this issue I plan to move to a much larger dataset. I was simply loading these images and running them through the model, but I wanted to augment my data by slightly altering the images at random.

Making this augmentation change (in __getitem__) roughly quadrupled my training time due to the perturbImage call. At that point I looked into using the DataLoader's worker processes to load and transform the data in the background and avoid the slowdown I had just introduced. However, passing any value other than 0 for num_workers results in extreme slowdowns, and if I set num_workers to around 12 and let it run for a few minutes, it freezes my computer until I manage to open Task Manager and kill the Python processes. I've tried playing with settings like pin_memory and persistent_workers with no noticeable improvement. What is going on here?
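For reference, the worker settings I've been varying look like this (a sketch of the combinations tried; the actual call is in the full code below):

trainset = torch.utils.data.DataLoader(
    raw_training_data, batch_size=500, shuffle=True,
    num_workers=12,           # also tried 1 and 3; only 0 behaves sanely
    pin_memory=True,          # no noticeable difference either way
    persistent_workers=True,  # no noticeable difference either way
)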

Time to load data and run 1 epoch:
num_workers 0, no perturbImage: 20 seconds
num_workers 0, with perturbImage: 91 seconds
num_workers 1, with perturbImage: 4 minutes
num_workers 3, with perturbImage: 14 minutes
num_workers 12, with perturbImage: greater than 30 minutes, killed process
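For the "with perturbImage" rows, the only change is uncommenting the augmentation call in __getitem__ (it appears commented out in the full code below):

    def __getitem__(self, idx):
        img = self.data[idx][0]
        label = self.data[idx][1]
        img = perturbImage(img, self.jit, self.blurrer, self.pt)  # augmentation enabled

        return img, label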

Full code, with the perturbImage call commented out and num_workers=0:

from multiprocessing import freeze_support
import random
import torch.nn as nn
import torch.nn.functional as F
import torch
import torchvision
from torch.utils.data import Dataset
import torch.optim as optim
import time
import os

def perturbImage(img, jit, blurrer, pt, transpose=False):
    # Apply exactly one randomly chosen augmentation: color jitter, blur,
    # perspective warp, or a 90/180/270 degree rotation.
    numOptions = 4
    pickedOption = random.randint(1, numOptions)

    if pickedOption == 1:
        img = jit(img)
    elif pickedOption == 2:
        img = blurrer(img)
    elif pickedOption == 3:
        img = pt(img)
    elif pickedOption == 4:
        pickedRot = random.choice([90, 180, 270])
        img = torchvision.transforms.functional.rotate(img, pickedRot)

    if transpose:
        # Reorder dimensions (for a CHW tensor this yields HWC).
        img = torch.transpose(img, 0, 2)
        img = torch.transpose(img, 0, 1)

    return img

class CustomMnistRotateDataset(Dataset):
    def __init__(self, data_dir):
        self.data_dir = data_dir
        self.jit = torchvision.transforms.ColorJitter(brightness=.5, hue=0.1)
        self.blurrer = torchvision.transforms.GaussianBlur(kernel_size=(5, 9), sigma=(0.1, .8))
        self.pt = torchvision.transforms.RandomPerspective(distortion_scale=0.4, p=1.0)
        
        print(f'loading CustomMnistRotateDataset from {data_dir}')
        # data_dir is expected to hold a torch.save'd sequence of (image, label) pairs.
        self.data = torch.load(data_dir)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        img = self.data[idx][0]
        label = self.data[idx][1]
        # Uncommenting this line is the augmentation change that quadruples epoch time:
        #img = perturbImage(img, self.jit, self.blurrer, self.pt)

        return img, label

def loadData(datasetType=None, trainPath=None, testPath=None):
    if datasetType == 'custfashmnist':
        raw_training_data = CustomMnistRotateDataset(trainPath)
        raw_test_data = CustomMnistRotateDataset(testPath)
    else:
        raise ValueError(f'unknown dataset type: {datasetType}')
    return (raw_training_data, raw_test_data)

class BaseNet(nn.Module):
    # A simple fully connected network over flattened images.
    def __init__(self, imgSize, outputSize):
        super().__init__()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(imgSize, 5120),
            nn.ReLU(),
            nn.Linear(5120, 5120),
            nn.ReLU(),
            nn.Linear(5120, 5120),
            nn.ReLU(),
            nn.Linear(5120, outputSize),
        )

    def forward(self, x):
        x = self.linear_relu_stack(x)
        x = F.log_softmax(x, dim=1)
        return x

def train(net, epochs, trainset, losses, start_time, lossRecordEvery=10, imgSize=(28*28), lr=0.001, device="cuda"):
    optimizer = optim.Adam(net.parameters(), lr=lr)
    i = 0
    for epoch in range(epochs):
        for batchImg, batchLabel in trainset:
            net.zero_grad()
            # Flatten each image into a vector for the linear stack.
            forwardOutput = net(batchImg.view(-1, imgSize).to(device))
            loss = F.nll_loss(forwardOutput, batchLabel.to(device))
            if i == lossRecordEvery:
                # Record the loss every lossRecordEvery batches.
                losses.append(loss.item())
                i = 0
            i += 1
            loss.backward()
            optimizer.step()
        print(loss)
        print(f"--- {(time.time() - start_time)} epoch {epoch} time to run ---" )

def go():
    start_time = time.time()

    device = (
        "cuda"
        if torch.cuda.is_available()
        else "mps"
        if torch.backends.mps.is_available()
        else "cpu"
    )
    print(f"Using {device} device")

    dataset = 'custfashmnist'

    if dataset == 'custfashmnist':
        imgSize = 28*28
        imgDimension = 1
        outputSize = 10
        raw_training_data, raw_test_data = loadData('custfashmnist', 'data/expanded-train', 'data/expanded-test')

    net = BaseNet(imgDimension * imgSize, outputSize).to(device)

    epochs = 1
    losses = []

    # Timings for the num_workers values tried are listed above; anything other than 0 is dramatically slower.
    trainset = torch.utils.data.DataLoader(raw_training_data, batch_size=500, shuffle=True, num_workers=0, persistent_workers=False, pin_memory=False)
    print('start train')
    train(net, epochs, trainset, losses, start_time, imgSize=(imgSize*imgDimension), lossRecordEvery=100, lr=0.0001)

if __name__ == '__main__':
    #freeze_support()
    print(torch.__version__) # 2.0.1+cu117
    go()
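In case it helps reproduce this, here's a minimal sketch (my own test harness, not part of the script above; timeLoader is a hypothetical helper) that times one pass over the DataLoader without the model, so the worker overhead can be measured in isolation:

def timeLoader(dataset, workers):
    # Hypothetical helper for timing only: iterate the DataLoader once,
    # discarding batches, and report elapsed time.
    loader = torch.utils.data.DataLoader(dataset, batch_size=500, shuffle=True, num_workers=workers)
    t0 = time.time()
    for _ in loader:
        pass
    print(f'num_workers={workers}: {time.time() - t0:.1f} seconds for one pass')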