Thread pool in PyTorch

I have a PyTorch script that trains a ResNet-18 on MNIST (the script is attached at the end of this post). I am running it inside an Intel SGX enclave using the Gramine library OS. Unfortunately, I am seeing a performance problem: the training time grows from epoch to epoch. For example, the first epoch takes 50 seconds, while the 1000th epoch takes 80 seconds.

I have since learned that my “PyTorch workload creates/destroys too many threads – each iteration creates a bunch of threads, and after the iteration these threads are destroyed, and the next iteration creates another bunch of threads, etc. The creation/destroying of enclave threads is very expensive in Gramine (and just generally in SGX enclaves).”

Is there a way to use a “thread pool” (or something similar) in PyTorch? I want to pre-allocate a pool of threads so that the application creates N threads once and reuses them for every iteration.

My PyTorch script:
from pathlib import Path
import os
import shutil
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
from torch import nn
from torch import optim
import torch.nn.functional as F
import torch
import torchvision.transforms as T
import torchvision
import glob
#import tqdm
#import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
import time
import tracemalloc


# Seed the global torch and NumPy RNGs.
myseed = 0
torch.manual_seed(myseed)
np.random.seed(myseed)

root_dir = './'
from torchvision import transforms

# ---------------------------------------------------------------------------
# Transforms
# ---------------------------------------------------------------------------
normalize = T.Normalize(mean=[0.1307], std=[0.3081])

# With probability 0.5, apply a (random) horizontal flip followed by a random
# 28x28 crop of the 4-pixel-padded image.
# NOTE(review): mirrored handwritten digits are generally not valid samples of
# their class; confirm the horizontal flip is intended for MNIST.
augmentation = T.RandomApply(
    [T.RandomHorizontalFlip(), T.RandomCrop(28, padding=4)],
    p=.5,
)

# Training images are augmented; validation/test images are only resized,
# converted to tensors and normalised.
data_transform = transforms.Compose(
    [transforms.Resize(28), transforms.ToTensor(), augmentation, normalize]
)
test_transform = transforms.Compose(
    [transforms.Resize(28), transforms.ToTensor(), normalize]
)

# ---------------------------------------------------------------------------
# Datasets: train and validation both read the MNIST *training* split (with
# different transforms); test reads the MNIST test split.
# ---------------------------------------------------------------------------
trainset, valset, testset = (
    torchvision.datasets.MNIST(root=root_dir, train=use_train_split,
                               download=True, transform=tf)
    for use_train_split, tf in (
        (True, data_transform),
        (True, test_transform),
        (False, test_transform),
    )
)

# ---------------------------------------------------------------------------
# Hold out the first 20% of a shuffled index list for validation.
# ---------------------------------------------------------------------------
num_train = len(trainset)
indices = list(range(num_train))
valid_size = 0.2
split = int(np.floor(valid_size * num_train))

random_seed = 42
shuffle = True
if shuffle:
    # Re-seed NumPy so the shuffled order depends only on random_seed.
    np.random.seed(random_seed)
    np.random.shuffle(indices)

valid_idx = indices[:split]
train_idx = indices[split:]
train_sampler = SubsetRandomSampler(train_idx)
valid_sampler = SubsetRandomSampler(valid_idx)

# ---------------------------------------------------------------------------
# Loaders (batch size 128); only the training loader drops its last partial
# batch.
# ---------------------------------------------------------------------------
generator = torch.Generator()
generator.manual_seed(myseed)

train_loader = torch.utils.data.DataLoader(
    trainset, batch_size=128, sampler=train_sampler,
    drop_last=True, generator=generator,
)
val_loader = torch.utils.data.DataLoader(
    valset, batch_size=128, sampler=valid_sampler,
    drop_last=False, generator=generator,
)
test_loader = torch.utils.data.DataLoader(
    testset, batch_size=128, shuffle=False,
    drop_last=False, generator=generator,
)

# Number of classes in the dataset.
num_classes = len(trainset.classes)

# Dataset statistics.
print(f"Num. classes: {num_classes}")
print(f"Classes:\n {trainset.classes}")
print(f"Num. train samples: {len(trainset)}")
print(f"Num. test samples: {len(testset)}")
# Disabled plotting helper, kept as an inert string literal: it needs
# matplotlib, whose import is commented out at the top of the file.
'''
def show_example(img, label):
    print('Label: ', trainset.classes[label], "("+str(label)+")")
    plt.imshow(img.permute(1, 2, 0))

show_example(*trainset[1])
'''

def init_params(net, seed=None):
    '''Initialise the parameters of every layer in ``net`` in place.

    - ``nn.Conv2d``: Kaiming-normal weights (``mode='fan_out'``); the bias,
      if the layer has one, is set to 0.
    - ``nn.BatchNorm2d``: weight set to 1 and bias set to 0 (skipped for
      layers built with ``affine=False``, which have no weight/bias).
    - ``nn.Linear``: Kaiming-uniform weights (``mode='fan_in'``); the bias is
      left at PyTorch's default initialisation.

    Args:
        net: module whose submodules (``net.modules()``) are initialised.
        seed: value passed to ``torch.manual_seed`` before initialising.
            Defaults to the module-level ``myseed``.
    '''
    # Local import: the module-level ``import torch.nn.init as init`` appears
    # later in the file, so don't rely on it having run.
    from torch.nn import init

    torch.manual_seed(myseed if seed is None else seed)
    for m in net.modules():
        if isinstance(m, nn.Conv2d):
            init.kaiming_normal_(m.weight, mode='fan_out')
            # `is not None`, not truthiness: bool() of a multi-element bias
            # tensor raises RuntimeError.
            if m.bias is not None:
                init.constant_(m.bias, 0)
        elif isinstance(m, nn.BatchNorm2d):
            if m.weight is not None:
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
        elif isinstance(m, nn.Linear):
            init.kaiming_uniform_(m.weight, mode="fan_in")





import torchvision.models as models
import torch.nn.init as init

# ResNet-18 with a single-input-channel stem convolution and an output layer
# sized to the dataset's class count.
resnet18 = models.resnet18()
resnet18.conv1 = nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
resnet18.fc = nn.Linear(in_features=512, out_features=num_classes, bias=True)

# Create the model
model = resnet18
init_params(model)
print(model)

# Training runs on the CPU.
dev = torch.device('cpu')
print(dev)

num_epochs = 1000

import torch.optim as optim

# Adam over the parameters that require gradients.
# Other optimisers previously configured here (disabled):
#   optim.SGD(..., lr=0.1, momentum=0.9, weight_decay=0.0005)
#   optim.AdamW(..., lr=0.001, weight_decay=0.02)
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = optim.Adam(trainable_params, lr=0.0001)

criterion = nn.CrossEntropyLoss()

# One-cycle LR policy peaking at 0.01, sized for epochs * steps_per_epoch =
# 10 * 312 scheduler steps. (A CosineAnnealingLR with T_max=num_epochs was
# also tried.)
# NOTE(review): this length is independent of num_epochs and of how often
# train() calls scheduler.step(); OneCycleLR raises if stepped past its total
# step count, so confirm the intended sizing.
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=0.01,
    epochs=10,
    steps_per_epoch=312,
)


def train(net, loaders, optimizer, criterion, epochs=100, dev=torch.device('cpu'),
          lr_scheduler=None, seed=None):
    '''Train ``net`` and evaluate it on the validation and test splits every epoch.

    Each epoch runs the "train", "val" and "test" splits in that order.

    - "train" runs with the model in training mode and gradients enabled:
      zero_grad, forward, backward, optimizer step.
    - "val" and "test" run with the model in eval mode and autograd disabled.
      They therefore neither update BatchNorm running statistics nor build
      autograd graphs.

    After each epoch the following are printed: the average loss/accuracy per
    split, the epoch and total wall-clock time, and tracemalloc's current/peak
    figures.

    Args:
        net: model to train; it is moved to ``dev``.
        loaders: dict with keys "train", "val" and "test"; each value is
            iterated for ``(input, labels)`` batches and must support ``len()``.
        optimizer: optimizer updating ``net``'s parameters.
        criterion: loss, called as ``criterion(pred, labels)``.
        epochs: number of epochs to run.
        dev: device the model and each batch are moved to.
        lr_scheduler: LR scheduler whose ``step()`` is called after each split
            (three times per epoch). Defaults to the module-level ``scheduler``.
        seed: passed to ``torch.manual_seed`` at the start. Defaults to the
            module-level ``myseed``.

    Returns:
        ``{"loss": {...}, "accuracy": {...}}``. Each inner dict maps a split
        name to its list of per-epoch averages. If training is stopped with
        Ctrl-C, the history collected so far is returned. A split whose loader
        yields no batches gets NaN for that epoch.
    '''
    # Fall back to the module-level objects the original signature relied on.
    sched = lr_scheduler if lr_scheduler is not None else scheduler
    torch.manual_seed(myseed if seed is None else seed)

    splits = ["train", "val", "test"]
    history_loss = {split: [] for split in splits}
    history_accuracy = {split: [] for split in splits}
    was_training = net.training
    start_time = time.time()
    try:
        net = net.to(dev)
        print(net)
        for epoch in range(epochs):
            start_epoch_time = time.time()
            # NOTE(review): tracemalloc instruments every Python allocation,
            # which itself costs time; consider disabling it when timing epochs.
            tracemalloc.start()
            sum_loss = {split: 0 for split in splits}
            sum_accuracy = {split: 0 for split in splits}
            for split in splits:
                is_train = split == "train"
                # Eval mode + no autograd for val/test, so evaluation neither
                # changes BatchNorm statistics nor keeps graphs alive.
                net.train(is_train)
                with torch.set_grad_enabled(is_train):
                    for inputs, labels in loaders[split]:
                        inputs = inputs.to(dev)
                        labels = labels.to(dev)
                        if is_train:
                            optimizer.zero_grad()
                        pred = net(inputs)
                        loss = criterion(pred, labels)
                        sum_loss[split] += loss.item()
                        if is_train:
                            loss.backward()
                            optimizer.step()
                        _, pred_labels = pred.max(1)
                        sum_accuracy[split] += (pred_labels == labels).sum().item() / inputs.size(0)
                # NOTE(review): this advances the LR schedule once per split
                # (3x per epoch), including after val/test. It is kept
                # unchanged to preserve the existing LR trajectory. Moving it
                # (e.g. to once per train batch) requires resizing the
                # scheduler accordingly.
                sched.step()

            # Average over batches; NaN when a split produced no batches.
            epoch_loss = {}
            epoch_accuracy = {}
            for split in splits:
                n_batches = len(loaders[split])
                epoch_loss[split] = sum_loss[split] / n_batches if n_batches else float("nan")
                epoch_accuracy[split] = sum_accuracy[split] / n_batches if n_batches else float("nan")
                history_loss[split].append(epoch_loss[split])
                history_accuracy[split].append(epoch_accuracy[split])

            epoch_time = time.time() - start_epoch_time
            current, peak = tracemalloc.get_traced_memory()

            # Print info
            print(f"Epoch {epoch+1}:",
                  f"TrL={epoch_loss['train']:.4f},",
                  f"TrA={epoch_accuracy['train']:.4f},",
                  f"VL={epoch_loss['val']:.4f},",
                  f"VA={epoch_accuracy['val']:.4f},",
                  f"TeL={epoch_loss['test']:.4f},",
                  f"TeA={epoch_accuracy['test']:.4f},",
                  f"LR={optimizer.param_groups[0]['lr']:.5f},"
                  f"et={epoch_time:.2f},"
                  f"tt={time.time()-start_time:.2f},"
                  f"Current memory:{current:0.2f},"
                  f"Peak memory:{peak:0.2f},")

            tracemalloc.stop()

    except KeyboardInterrupt:
        print("Interrupted")
    finally:
        # Don't leave allocation tracing on if an epoch was cut short.
        if tracemalloc.is_tracing():
            tracemalloc.stop()
        # Restore the caller's train/eval mode (val/test switch to eval).
        net.train(was_training)
        print("FINE")
    # The history is returned so callers can plot loss/accuracy curves.
    return {"loss": history_loss, "accuracy": history_accuracy}
# Split name -> DataLoader; train() visits "train", "val", "test" in order.
loaders = dict(train=train_loader, val=val_loader, test=test_loader)

# Run training on the configured device for num_epochs epochs.
train(
    net=model,
    loaders=loaders,
    optimizer=optimizer,
    criterion=criterion,
    epochs=num_epochs,
    dev=dev,
)

Calling torch.set_num_threads(N) is not the same as creating a thread pool, right? This function guarantees that PyTorch uses N threads, but it does not prevent threads from being created and destroyed, right?

According to the docs, torch.set_num_threads sets the number of threads PyTorch uses for intra-op parallelism on the CPU.
Based on your description, it seems you are more concerned about “general” threads being spawned. Can you see in your profile where these threads come from? For example, are they created by a CPU library that runs multi-threaded operations?

So you are confirming that torch.set_num_threads() does not create a thread pool, but only fixes the number of threads. Even so, I think this function has already helped with my problem (for context, the problem concerned PyTorch performance inside Intel SGX; see the linked GitHub thread).

If I do want to create a thread pool, would a solution like the following work?

from concurrent.futures import ThreadPoolExecutor

# NOTE: running train() inside a ThreadPoolExecutor does not give PyTorch a
# thread pool to reuse. The executor only owns the worker thread that calls
# train(); PyTorch's CPU operators use their own internal thread pools.
# Those pools are sized with torch.set_num_threads (intra-op) and
# torch.set_num_interop_threads (inter-op), called once near program start.
# A single training job only ever occupies one worker.
with ThreadPoolExecutor(max_workers=1) as executor:
    # Pass the callable and its arguments separately. `submit(train())` would
    # run train() immediately in this thread (without its required
    # arguments) and then submit its return value, which is not callable.
    future = executor.submit(train, model, loaders, optimizer, criterion,
                             epochs=num_epochs, dev=dev)
    future.result()  # re-raises any exception raised inside train()