I have a PyTorch script that trains a ResNet-18 on MNIST (the script is attached at the end of this post). I am running this script inside SGX enclaves with the Gramine libOS. Unfortunately, I am encountering a performance problem: the training time increases epoch after epoch. For example, the first epoch takes 50 seconds, while the 1000th epoch takes 80 seconds.
Now, I know that my “PyTorch workload creates/destroys too many threads – each iteration creates a bunch of threads, and after the iteration these threads are destroyed, and the next iteration creates another bunch of threads, etc. The creation/destroying of enclave threads is very expensive in Gramine (and just generally in SGX enclaves).”
Is there any way to use a “thread pool” or something similar in PyTorch? I want to pre-allocate a pool of N threads, so that the application creates them once and re-uses them in every iteration instead of creating new ones.
My PyTorch script:
from pathlib import Path
import os
import shutil
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
from torch import nn
from torch import optim
import torch.nn.functional as F
import torch
import torchvision.transforms as T
import torchvision
import glob
#import tqdm
#import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
import time
import tracemalloc
# Seed the torch and NumPy global RNGs with the same value before anything
# random happens.
myseed = 0
torch.manual_seed(myseed)
np.random.seed(myseed)

# Directory handed to the torchvision dataset constructors as `root`.
root_dir = './'

from torchvision import transforms

# Single-channel normalisation constants.
normalize = T.Normalize(mean=[0.1307], std=[0.3081])

# The flip + padded 28x28 crop pair is applied as a unit with probability 0.5.
augmentation = T.RandomApply(
    [
        T.RandomHorizontalFlip(),
        T.RandomCrop(28, padding=4),
    ],
    p=.5,
)

# Training pipeline: resize -> tensor -> random augmentation -> normalise.
data_transform = transforms.Compose(
    [
        transforms.Resize(28),
        transforms.ToTensor(),
        augmentation,
        normalize,
    ]
)

# Evaluation pipeline: identical to the training one minus the augmentation.
test_transform = transforms.Compose(
    [
        transforms.Resize(28),
        transforms.ToTensor(),
        normalize,
    ]
)
# Load MNIST with the transforms defined above.
# `trainset` and `valset` both read the train=True split; they differ only in
# the transform (augmented vs. plain). The samplers below decide which indices
# each of them actually serves.
trainset = torchvision.datasets.MNIST(root=root_dir, train=True, download=True, transform=data_transform)
valset = torchvision.datasets.MNIST(root=root_dir, train=True, download=True, transform=test_transform)
testset = torchvision.datasets.MNIST(root=root_dir, train=False, download=True, transform=test_transform)

# Split the train=True indices: the first `valid_size` fraction (after the
# optional shuffle) goes to validation, the remainder to training.
num_train = len(trainset)
indices = list(range(num_train))
valid_size = 0.2
split = int(np.floor(valid_size * num_train))
random_seed = 42
shuffle = True
if shuffle:
    # Re-seed NumPy so the train/validation partition is the same on every run.
    np.random.seed(random_seed)
    np.random.shuffle(indices)
train_idx, valid_idx = indices[split:], indices[:split]
train_sampler = SubsetRandomSampler(train_idx)
valid_sampler = SubsetRandomSampler(valid_idx)

# One seeded generator object is shared by all three loaders.
generator = torch.Generator()
generator.manual_seed(myseed)
train_loader = torch.utils.data.DataLoader(trainset, batch_size=128, sampler=train_sampler, drop_last=True, generator=generator)
val_loader = torch.utils.data.DataLoader(valset, batch_size=128, sampler=valid_sampler, drop_last=False, generator=generator)
test_loader = torch.utils.data.DataLoader(testset, batch_size=128, drop_last=False, shuffle=False, generator=generator)

# Get number of classes (we'll need it in the model)
num_classes = len(trainset.classes)

# Print dataset statistics. Note that len(trainset) is the size of the whole
# train=True split, i.e. it still includes the indices held out for validation.
print(f"Num. classes: {num_classes}")
print(f"Classes:\n {trainset.classes}")
print(f"Num. train samples: {len(trainset)}")
print(f"Num. test samples: {len(testset)}")
'''
def show_example(img, label):
print('Label: ', trainset.classes[label], "("+str(label)+")")
plt.imshow(img.permute(1, 2, 0))
show_example(*trainset[1])
'''
def init_params(net, seed=None):
    """Re-initialise the conv, batch-norm and linear layers of ``net`` in place.

    * ``nn.Conv2d``: Kaiming-normal weights (``mode='fan_out'``); the bias,
      when the layer has one, is set to 0.
    * ``nn.BatchNorm2d``: weight set to 1 and bias set to 0 (skipped when the
      layer was built without affine parameters).
    * ``nn.Linear``: Kaiming-uniform weights (``mode='fan_in'``); the bias is
      left untouched.

    Args:
        net: Module whose sub-modules are visited via ``net.modules()``.
        seed: Value passed to ``torch.manual_seed`` before initialising. When
            ``None`` (the default) the module-level ``myseed`` is used.

    Side effects:
        Reseeds torch's global RNG and overwrites parameters of ``net``.
    """
    if seed is None:
        seed = myseed
    torch.manual_seed(seed)
    for m in net.modules():
        if isinstance(m, nn.Conv2d):
            nn.init.kaiming_normal_(m.weight, mode='fan_out')
            # `if m.bias:` would raise for any bias with more than one
            # element; the question is only whether the layer has a bias.
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.BatchNorm2d):
            # weight/bias are None when the layer was created with affine=False.
            if m.weight is not None:
                nn.init.constant_(m.weight, 1)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.Linear):
            nn.init.kaiming_uniform_(m.weight, mode="fan_in")
            # The linear bias deliberately keeps the value it already had.
import torchvision.models as models
import torch.nn.init as init

# Start from torchvision's ResNet-18 and adapt both ends of the network:
# the stem takes 1 input channel instead of 3, and the classifier head emits
# one logit per dataset class.
resnet18 = models.resnet18()
resnet18.conv1 = nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
resnet18.fc = nn.Linear(in_features=512, out_features=num_classes, bias=True)

# Create the model
model = resnet18
init_params(model)
print(model)

# Training runs on the CPU only.
dev = torch.device('cpu')
print(dev)

num_epochs = 1000

# Define an optimizer
import torch.optim as optim
#optimizer = optim.SGD([x for x in model.parameters() if x.requires_grad], lr = 0.1, momentum=0.9, weight_decay=0.0005)
optimizer = optim.Adam([x for x in model.parameters() if x.requires_grad], lr=0.0001)
#optimizer = optim.AdamW([x for x in model.parameters() if x.requires_grad], lr=0.001, weight_decay=0.02)

# Define a loss
criterion = nn.CrossEntropyLoss()

# Scheduler
#scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)
# NOTE(review): this schedule is sized for epochs * steps_per_epoch
# = 10 * 312 = 3120 scheduler steps, while training is configured for
# num_epochs = 1000 epochs. Confirm that the number of scheduler.step() calls
# made during training matches the intended cycle length.
scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=0.01, epochs=10, steps_per_epoch=312)
def train(net, loaders, optimizer, criterion, epochs=100, dev=torch.device('cpu')):
    """Train ``net`` and evaluate it on the "val" and "test" splits every epoch.

    For each epoch the "train", "val" and "test" loaders are iterated in that
    order. Only "train" batches run backward passes and optimizer steps; the
    other two splits are evaluated with the network in eval mode and with
    gradient tracking disabled, so they neither update batch-norm running
    statistics nor build autograd graphs.

    Per-epoch loss and accuracy are the mean of the per-batch values. One
    summary line per epoch is printed, including wall-clock time and the
    current/peak memory reported by ``tracemalloc`` for that epoch.

    Args:
        net: Model to train; moved to ``dev``.
        loaders: Mapping with the keys "train", "val" and "test", each an
            iterable of ``(inputs, labels)`` batches that supports ``len()``.
        optimizer: Optimizer that updates the parameters of ``net``.
        criterion: Loss called as ``criterion(pred, labels)``.
        epochs: Number of epochs to run.
        dev: Device the model and the batches are moved to.

    Returns:
        ``(history_loss, history_accuracy)``: two dicts keyed by split name,
        each holding one value per completed epoch. After a
        ``KeyboardInterrupt`` they cover only the epochs that finished.

    Note:
        Reads the module-level ``myseed`` (RNG seed) and ``scheduler``
        (learning-rate scheduler) in addition to its arguments.
    """
    splits = ("train", "val", "test")
    torch.manual_seed(myseed)
    start_time = time.time()
    # Initialize history (outside the try block so it can always be returned)
    history_loss = {name: [] for name in splits}
    history_accuracy = {name: [] for name in splits}
    try:
        net = net.to(dev)
        print(net)
        # Process each epoch
        for epoch in range(epochs):
            start_epoch_time = time.time()
            # Tracing is restarted every epoch so the peak is per-epoch.
            tracemalloc.start()
            # Initialize epoch variables
            sum_loss = {name: 0 for name in splits}
            sum_accuracy = {name: 0 for name in splits}
            # Process each split
            for split in splits:
                is_train = split == "train"
                # Train mode for "train", eval mode for "val"/"test".
                net.train(is_train)
                with torch.set_grad_enabled(is_train):
                    # Process each batch
                    for inputs, labels in loaders[split]:
                        # Move to the target device
                        inputs = inputs.to(dev)
                        labels = labels.to(dev)
                        # Compute output
                        pred = net(inputs)
                        loss = criterion(pred, labels)
                        # Update loss
                        sum_loss[split] += loss.item()
                        # Parameters are only updated on the training split
                        if is_train:
                            optimizer.zero_grad()
                            loss.backward()
                            optimizer.step()
                        # Compute accuracy
                        _, pred_labels = pred.max(1)
                        batch_accuracy = (pred_labels == labels).sum().item() / inputs.size(0)
                        # Update accuracy
                        sum_accuracy[split] += batch_accuracy
            # NOTE(review): the scheduler is stepped once per epoch here; the
            # original nesting was lost with the indentation — confirm.
            scheduler.step()
            # Compute epoch loss/accuracy
            epoch_loss = {name: sum_loss[name] / len(loaders[name]) for name in splits}
            epoch_accuracy = {name: sum_accuracy[name] / len(loaders[name]) for name in splits}
            # Update history
            for name in splits:
                history_loss[name].append(epoch_loss[name])
                history_accuracy[name].append(epoch_accuracy[name])
            epoch_time = time.time() - start_epoch_time
            current, peak = tracemalloc.get_traced_memory()
            # Print info (the last five fields are one concatenated string)
            print(f"Epoch {epoch+1}:",
                  f"TrL={epoch_loss['train']:.4f},",
                  f"TrA={epoch_accuracy['train']:.4f},",
                  f"VL={epoch_loss['val']:.4f},",
                  f"VA={epoch_accuracy['val']:.4f},",
                  f"TeL={epoch_loss['test']:.4f},",
                  f"TeA={epoch_accuracy['test']:.4f},",
                  f"LR={optimizer.param_groups[0]['lr']:.5f},"
                  f"et={epoch_time:.2f},"
                  f"tt={time.time()-start_time:.2f},"
                  f"Current memory:{current:0.2f},"
                  f"Peak memory:{peak:0.2f},")
            tracemalloc.stop()
    except KeyboardInterrupt:
        print("Interrupted")
    finally:
        # Make sure tracing is off even when an epoch was cut short.
        tracemalloc.stop()
        print("FINE")
    # Disabled plotting code, kept for reference (needs matplotlib as plt):
    # # Plot loss
    # plt.title("Loss")
    # for split in ["train", "val", "test"]:
    #     plt.plot(history_loss[split], label=split)
    # plt.legend()
    # plt.show()
    # # Plot accuracy
    # plt.title("Accuracy")
    # for split in ["train", "val", "test"]:
    #     plt.plot(history_accuracy[split], label=split)
    # plt.legend()
    # plt.show()
    return history_loss, history_accuracy
# Map each split name to its DataLoader; train() looks the loaders up by
# exactly these keys.
loaders = dict(train=train_loader, val=val_loader, test=test_loader)

# Run training with the model, optimizer, loss and device configured above.
train(
    net=model,
    loaders=loaders,
    optimizer=optimizer,
    criterion=criterion,
    epochs=num_epochs,
    dev=dev,
)