Hi, I’ve seen several posts about num_workers and there are answers to suggest the ideal num_workers is to be 4* num_GPUs but I just can’t get the same speed boost with more num_workers. Being new to deep learning, I plan to open this post with a reproducible code example using Mnist, to understand fully on how to improve the training speed.
I’m using Ubuntu 20.04 LTS and have a RTX 3080, when I don’t use the batch training and just train the whole 60,000 like below, it takes about 6-7 seconds to finish the training and GPU usage at 99-100%.
import torch
import torch.nn as nn
from torchvision import datasets
from torchvision.transforms import ToTensor
from torch.utils.data import DataLoader, TensorDataset
import numpy as np
import random
from matplotlib import pyplot as plt
from tqdm.auto import tqdm
import timeit
# Set Device function (to GPU)
def set_device():
device = "cuda" if torch.cuda.is_available() else "cpu"
if device != "cuda":
print("GPU is not enabled")
else:
print("GPU is enabled")
return device
DEVICE = set_device()
# set seed function
def set_seed(seed=None, seed_torch=True):
if seed is None:
seed = np.random.choice(2 ** 32)
random.seed(seed)
np.random.seed(seed)
if seed_torch:
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.cuda.manual_seed(seed)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
print(f'Random seed {seed} has been set.')
SEED = 2021
# for DataLoader
def seed_worker(worker_id):
worker_seed = torch.initial_seed() % 2**32
np.random.seed(worker_seed)
random.seed(worker_seed)
# Download Mnist datasets
train_data = datasets.MNIST(
root='data',
train=True,
transform=ToTensor(),
download=True,
)
test_data = datasets.MNIST(
root='data',
train=False,
transform=ToTensor()
)
X = train_data.data.reshape(60000, -1).float()
y = train_data.train_labels
X_test = test_data.data.reshape(10000, -1).float()
y_test = test_data.train_labels
# Simple Neural Net
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
# define layers
self.layers = nn.Sequential(
nn.Linear(784, 600),
nn.ReLU(),
nn.Linear(600, 300),
nn.ReLU(),
nn.Linear(300, 100),
nn.ReLU(),
nn.Linear(100, 10)
)
def forward(self, x):
return self.layers(x)
def predict(self, x):
return torch.argmax(self.forward(x), 1)
# simple train
X = X.to(DEVICE)
y = y.to(DEVICE)
X_test = X_test.to(DEVICE)
y_test = y_test.to(DEVICE)
SEED = 2021
set_seed(SEED)
model = Net().to(DEVICE)
loss_function = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.005, momentum=0.9)
loss_list = []
logits = model.forward(X)
loss = loss_function(logits, y)
start1 = timeit.default_timer()
for epoch in range(500):
logits = model.forward(X)
loss = loss_function(logits, y)
loss_list.append(loss.item())
loss.backward()
optimizer.step()
optimizer.zero_grad()
if epoch % 20 == 0:
print(f"epoch {epoch + 1}: loss: {loss:.5f},"
f"train_accuracy: {torch.sum(model.predict(X) == y) / 60000:.3f},"
f"test_accuracy:{torch.sum(model.predict(X_test) == y_test) / 10000:.3f}")
end1 = timeit.default_timer()
print(f"Time: {end1 - start1:.2f} seconds")
But when I use batch training like below, the speed drops significantly, and when num_workers=0, it takes 176 seconds to finish the training, and when num_workers=4, it takes 216 seconds to finish the training. And in both scenarios, the GPU usage hover around 20-30% and sometimes even lower. So my question is: is it normal to expect this time increase when using batch training, and if so, why should we use batch training? Is it to improve the test accuracy?
Secondly, why does increasing the num_workers take longer to train? Is there anything fundamentally wrong in the code? And is it normal to have GPU usage low when doing the batch training?
X = train_data.data.reshape(60000, -1).float()
y = train_data.train_labels
X_test = test_data.data.reshape(10000, -1).float()
y_test = test_data.train_labels
# Dataloader
g_seed = torch.Generator()
g_seed.manual_seed(SEED)
batch_size = 300
test_data = TensorDataset(X_test, y_test)
test_loader = DataLoader(test_data, batch_size=batch_size,
shuffle=False, num_workers=8,
worker_init_fn=seed_worker,
generator=g_seed)
train_data = TensorDataset(X, y)
train_loader = DataLoader(train_data, batch_size=batch_size, drop_last=True,
shuffle=True, num_workers=8,
worker_init_fn=seed_worker,
generator=g_seed)
def train_test_classification(net, criterion, optimizer, train_loader,
test_loader, num_epochs=1, verbose=True,
training_plot=True, device='cuda'):
net.train()
training_losses = []
for epoch in tqdm(range(num_epochs)): # loop over the dataset multiple times
running_loss = 0.0
for (i, data) in enumerate(train_loader, 0):
# get the inputs; data is a list of [inputs, labels]
inputs, labels = data
inputs = inputs.to(device).float()
labels = labels.to(device).long()
# zero the parameter gradients
optimizer.zero_grad()
# forward + backward + optimize
outputs = net(inputs)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
# print statistics
if verbose:
training_losses += [loss.item()]
net.eval()
def test(data_loader):
correct = 0
total = 0
for data in data_loader:
inputs, labels = data
inputs = inputs.to(device).float()
labels = labels.to(device).long()
outputs = net(inputs)
_, predicted = torch.max(outputs, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
acc = 100 * correct / total
return total, acc
train_total, train_acc = test(train_loader)
test_total, test_acc = test(test_loader)
if verbose:
print(f"Accuracy on the {train_total} training samples: {train_acc:0.2f}")
print(f"Accuracy on the {test_total} testing samples: {test_acc:0.2f}")
if training_plot:
plt.plot(training_losses)
plt.xlabel('Batch')
plt.ylabel('Training loss')
plt.show()
return train_acc, test_acc
set_seed(SEED)
net = Net().to(DEVICE)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(net.parameters(), lr=0.005, momentum=0.9)
num_epochs = 500
start = timeit.default_timer()
_, _ = train_test_classification(net, criterion, optimizer, train_loader,
test_loader, num_epochs=num_epochs,
training_plot=True, device=DEVICE)
end = timeit.default_timer()
print(f"Time: {end-start:.2f}")