Environment: PyTorch 2.0.1+cu117 on Windows 10, running on a powerful workstation with an NVIDIA GPU and 16 CPUs. The full code is below.
I have a dataset of 240,000 28x28 images, and once I figure out this issue I plan to move to a much larger dataset. Originally I was simply loading these images and running them through the model, but I wanted to augment my data by randomly altering the images slightly.
Making this augmentation change (in `__getitem__`) more than quadrupled my training time because of the `perturbImage` call. At that point I looked into using the `DataLoader` class to load and transform the data in the background, to avoid the drastic slowdown I had just introduced. However, passing any value other than 0 for `num_workers` results in extreme slowdowns, and if I set `num_workers` to around 12 and let it run for a few minutes, it freezes my computer until I manage to open Task Manager and kill the Python processes. I've tried adjusting settings such as `pin_memory` and `persistent_workers`, with no noticeable improvement. What is going on here?
Time to load the data and run one epoch:
- `num_workers=0`, without `perturbImage`: 20 seconds
- `num_workers=0`, with `perturbImage`: 91 seconds
- `num_workers=1`, with `perturbImage`: 4 minutes
- `num_workers=3`, with `perturbImage`: 14 minutes
- `num_workers=12`, with `perturbImage`: more than 30 minutes (process killed)
Code (shown without the `perturbImage` call and with 0 workers):
from multiprocessing import freeze_support
import random
import torch.nn as nn
import torch.nn.functional as F
import torch
import torchvision
from torch.utils.data import Dataset
import torch.optim as optim
import time
import os
def perturbImage(img, jit, blurrer, pt, transpose = False):
    """Apply exactly one randomly chosen augmentation to ``img``.

    One of four options is drawn uniformly with ``random.randint(1, 4)``:
    ``jit`` (1), ``blurrer`` (2), ``pt`` (3), or a rotation (4) by a randomly
    chosen 90, 180 or 270 degrees via ``torchvision.transforms.functional.rotate``.

    Args:
        img: the image; handed unchanged to the chosen transform.
        jit: callable used for option 1.
        blurrer: callable used for option 2.
        pt: callable used for option 3.
        transpose: when true, the result is passed through
            ``torch.transpose(.., 0, 2)`` and then ``torch.transpose(.., 0, 1)``;
            for a (C, H, W) tensor that yields (H, W, C).

    Returns:
        The augmented (and optionally transposed) image.
    """
    option = random.randint(1, 4)
    if option == 4:
        angle = random.choice([90, 180, 270])
        img = torchvision.transforms.functional.rotate(img, angle)
    else:
        chosen = {1: jit, 2: blurrer, 3: pt}[option]
        img = chosen(img)

    if transpose:
        img = torch.transpose(torch.transpose(img, 0, 2), 0, 1)
    return img
class CustomMnistRotateDataset(Dataset):
    """Map-style dataset over an object loaded from disk with ``torch.load``.

    Each element of the loaded object is read as ``element[0]`` (image) and
    ``element[1]`` (label). Optionally runs ``perturbImage`` on every image.

    NOTE(review): the concrete type of ``self.data`` (list of pairs,
    TensorDataset, ...) is not visible here; confirm it against whatever wrote
    the file at ``data_dir``.

    NOTE(review): per the PyTorch DataLoader docs, on Windows the worker
    processes are started with ``spawn`` and receive the dataset by pickling.
    That means all of ``self.data`` is serialized to every worker whenever
    workers are (re)started. If ``self.data`` is a large list of small tensors,
    this transfer is a likely contributor to the ``num_workers > 0`` slowdown.
    Consolidating the images and labels into a few large stacked tensors should
    make it much cheaper. TODO: measure.
    """

    def __init__(self, data_dir, augment=False):
        """Load the saved samples and build the augmentation transforms.

        Args:
            data_dir: path given to ``torch.load``.
            augment: when true, ``__getitem__`` passes each image through
                ``perturbImage``. Defaults to False, which matches the previous
                behavior (the augmentation call was commented out).
        """
        self.data_dir = data_dir
        self.augment = augment
        self.jit = torchvision.transforms.ColorJitter(brightness=.5, hue=0.1)
        self.blurrer = torchvision.transforms.GaussianBlur(kernel_size=(5, 9), sigma=(0.1, .8))
        self.pt = torchvision.transforms.RandomPerspective(distortion_scale=0.4, p=1.0)
        print(f'loading CustomMnistRotateDataset from {data_dir}')
        # SECURITY: torch.load unpickles arbitrary objects; only load files
        # from a trusted source.
        self.data = torch.load(data_dir)

    def __len__(self):
        """Return the number of samples in the loaded data."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(img, label)`` for sample ``idx``, augmenting ``img`` if enabled."""
        sample = self.data[idx]  # index once instead of twice
        img = sample[0]
        label = sample[1]
        if self.augment:
            img = perturbImage(img, self.jit, self.blurrer, self.pt)
        return img, label
def loadData(type = None, trainPath = None, testPath = None):
    """Build the ``(train, test)`` datasets for the named dataset type.

    Args:
        type: dataset identifier. Only ``'custfashmnist'`` is supported. The
            name shadows the builtin but is kept so keyword callers still work.
        trainPath: file path for the training split, forwarded to
            ``CustomMnistRotateDataset``.
        testPath: file path for the test split, forwarded likewise.

    Returns:
        ``(raw_training_data, raw_test_data)``.

    Raises:
        ValueError: if ``type`` is not a supported dataset. Without this check
            the function would return ``None``, and the caller would fail later
            with an unrelated unpacking error.
    """
    if type != 'custfashmnist':
        raise ValueError(f'unsupported dataset type: {type!r}')
    raw_training_data = CustomMnistRotateDataset(trainPath)
    raw_test_data = CustomMnistRotateDataset(testPath)
    return (raw_training_data, raw_test_data)
class BaseNet(nn.Module):
    """Fully connected classifier with three hidden ReLU layers and log-softmax output.

    Args:
        imgSize: number of input features per sample.
        outputSize: number of output classes.
        hiddenSize: width of every hidden layer. Defaults to 5120, the value
            that was previously hard-coded.
    """

    def __init__(self, imgSize, outputSize, hiddenSize=5120):
        super().__init__()
        # Layer order is unchanged, so state_dict keys stay
        # linear_relu_stack.{0,2,4,6}.{weight,bias} (saved checkpoints load).
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(imgSize, hiddenSize),
            nn.ReLU(),
            nn.Linear(hiddenSize, hiddenSize),
            nn.ReLU(),
            nn.Linear(hiddenSize, hiddenSize),
            nn.ReLU(),
            nn.Linear(hiddenSize, outputSize),
        )

    def forward(self, x):
        """Return per-class log-probabilities for ``x``.

        ``x`` is expected to be 2-D ``(batch, imgSize)``, so that ``dim=1`` of
        the output is the class dimension.
        """
        x = self.linear_relu_stack(x)
        return F.log_softmax(x, dim=1)
def train(net, epochs, trainset, losses, start_time, lossRecordEvery = 10, imgSize = (28*28), lr = 0.001, device = "cuda"):
    """Train ``net`` with Adam and NLL loss, sampling the loss periodically.

    Args:
        net: model whose output is fed to ``F.nll_loss``, so it should produce
            log-probabilities.
        epochs: number of passes over ``trainset``.
        trainset: iterable of ``(images, labels)`` batches, e.g. a DataLoader.
        losses: list that sampled loss values (Python floats) are appended to.
        start_time: ``time.time()`` reference for the per-epoch timing print.
        lossRecordEvery: a loss is appended every this many batches. The
            counter keeps running across epochs.
        imgSize: number of features each image is flattened to via ``view``.
        lr: Adam learning rate.
        device: device the batches are moved to; it should match ``net``'s.
    """
    optimizer = optim.Adam(net.parameters(), lr = lr)
    i = 0
    for epoch in range(epochs):
        loss = None  # stays None if trainset yields no batches this epoch
        for batchImg, batchLabel in trainset:
            net.zero_grad()
            # non_blocking only makes a difference for pinned host memory
            # (e.g. DataLoader(pin_memory=True)); otherwise it has no effect.
            forwardOutput = net(batchImg.view(-1, imgSize).to(device, non_blocking=True))
            loss = F.nll_loss(forwardOutput, batchLabel.to(device, non_blocking=True))
            if i == lossRecordEvery:
                # .item() forces a device sync, so it is only done periodically.
                losses.append(loss.item())
                i = 0
            i += 1
            loss.backward()
            optimizer.step()
        if loss is not None:
            print(loss)
        print(f"--- {(time.time() - start_time)} epoch {epoch} time to run ---" )
def go(num_workers = 0):
    """Train ``BaseNet`` for one epoch on the custom Fashion-MNIST files.

    Loads ``data/expanded-train`` and ``data/expanded-test``, builds the model
    on the first available device (CUDA, then MPS, then CPU) and calls
    ``train``.

    Args:
        num_workers: number of DataLoader worker processes. Defaults to 0
            (load in the main process), which is the setting the script used
            before this parameter existed.
    """
    start_time = time.time()  # the reported time includes data loading
    device = (
        "cuda"
        if torch.cuda.is_available()
        else "mps"
        if torch.backends.mps.is_available()
        else "cpu"
    )
    print(f"Using {device} device")
    # (pixels per channel, channels, number of classes) per dataset. An unknown
    # name raises KeyError here instead of leaving these names undefined below.
    datasetShapes = {'custfashmnist': (28*28, 1, 10)}
    dataset = 'custfashmnist'
    imgSize, imgDimension, outputSize = datasetShapes[dataset]
    # The test split is loaded but not used yet.
    raw_training_data, raw_test_data = loadData(dataset, 'data/expanded-train', 'data/expanded-test')
    net = BaseNet(imgDimension * imgSize, outputSize).to(device)
    epochs = 1
    losses = []
    # Measured for 1 epoch: num_workers=0 without perturbImage 20 s, with it 91 s;
    # with perturbImage, num_workers=1 took 4 min and num_workers=3 took 14 min.
    # DataLoader rejects persistent_workers=True when num_workers == 0. With
    # workers, keeping them alive avoids re-spawning them (and re-sending the
    # dataset to them) at the start of every epoch.
    trainset = torch.utils.data.DataLoader(
        raw_training_data,
        batch_size=500,
        shuffle=True,
        num_workers=num_workers,
        persistent_workers=num_workers > 0,
        pin_memory=False,
    )
    print('start train')
    train(net, epochs, trainset, losses, start_time, imgSize=(imgSize*imgDimension), lossRecordEvery=100, lr = 0.0001)
if __name__ == '__main__':
    # This guard is what keeps spawned DataLoader workers (the default start
    # method on Windows) from re-running go() when they import this module.
    # freeze_support() is only needed for frozen Windows executables, so it
    # stays disabled.
    torch_version = torch.__version__  # developed against 2.0.1+cu117
    print(torch_version)
    go()