Is there a way to train independent models in parallel using the same dataloader?

I’m training multiple models on the same dataset. Currently I write a separate script for each model and train them all on a single GPU. Since they all use the same dataset, I suspect this approach creates a lot of overhead in data loading.

So I’m wondering whether there is a way to train multiple models with the same dataloader. An obvious way is to apply the models sequentially inside the same dataloader iteration, but would that use my GPU efficiently? My naive guess is that if multiple models could run in parallel inside the same dataloader iteration, that would make full use of my single GPU.
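A minimal sketch of the "obvious way" described above, assuming small placeholder `nn.Sequential` models and a synthetic `TensorDataset` (both hypothetical stand-ins for the real models and data). Each batch is loaded and copied to the device once, then fed to every model in turn; each model keeps its own optimizer, so training stays independent.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

device = "cuda:0" if torch.cuda.is_available() else "cpu"
torch.manual_seed(0)

# Hypothetical stand-in data: 1,000 samples, 10 features, 2 classes.
dataset = TensorDataset(torch.randn(1000, 10), torch.randint(0, 2, (1000,)))
loader = DataLoader(dataset, batch_size=64, shuffle=True)

# Several independent models (different random inits), each with its own optimizer.
NUM_MODELS = 4
models = [
    nn.Sequential(nn.Linear(10, 32), nn.ReLU(), nn.Linear(32, 2)).to(device)
    for _ in range(NUM_MODELS)
]
optimizers = [torch.optim.AdamW(m.parameters(), lr=1e-3) for m in models]
criterion = nn.CrossEntropyLoss()

for epoch in range(3):
    for x, y in loader:
        # Load and transfer each batch once, then reuse it for every model.
        x, y = x.to(device, non_blocking=True), y.to(device, non_blocking=True)
        for model, opt in zip(models, optimizers):
            opt.zero_grad()
            loss = criterion(model(x), y)
            loss.backward()
            opt.step()
```

The models still run one after another on the GPU, so this saves only the loading and host-to-device copy work, not compute time.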

If you are worried that host-to-device/device-to-host data copies for one model block computation for another model, you can try using multiple CUDA streams, one stream per model. Operations in different streams can run in parallel.
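A sketch of the one-stream-per-model idea, assuming hypothetical placeholder `nn.Linear` models and random data; it falls back to plain sequential execution when no GPU is available. Whether kernels from different streams actually overlap depends on model size and on how quickly the single Python thread can launch them, so any speedup should be measured rather than assumed.

```python
import contextlib

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if use_cuda else "cpu")
torch.manual_seed(0)

# Hypothetical placeholder data and models.
dataset = TensorDataset(torch.randn(512, 10), torch.randint(0, 2, (512,)))
loader = DataLoader(dataset, batch_size=64, shuffle=True, pin_memory=use_cuda)
models = [nn.Linear(10, 2).to(device) for _ in range(4)]
optimizers = [torch.optim.SGD(m.parameters(), lr=0.1) for m in models]
criterion = nn.CrossEntropyLoss()

# One side stream per model (None on CPU, where streams do not exist).
streams = [torch.cuda.Stream() if use_cuda else None for _ in models]


def train_step(model, opt, x, y):
    opt.zero_grad()
    loss = criterion(model(x), y)
    loss.backward()  # backward kernels run on the same stream as the forward
    opt.step()
    return loss


for x, y in loader:
    # The batch is copied once, on the default stream.
    x, y = x.to(device, non_blocking=True), y.to(device, non_blocking=True)
    losses = []
    for model, opt, stream in zip(models, optimizers, streams):
        if stream is None:
            ctx = contextlib.nullcontext()
        else:
            # The side stream must not read x, y before the copy has finished.
            stream.wait_stream(torch.cuda.current_stream())
            ctx = torch.cuda.stream(stream)
        with ctx:
            losses.append(train_step(model, opt, x, y))
    if use_cuda:
        # Default stream waits for every side stream before x, y are freed or reused.
        for stream in streams:
            torch.cuda.current_stream().wait_stream(stream)
```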

No, you’d only amortize data loading time. It may be worth it if data loading is a notable proportion of the iteration time (data loading + forward + backward).
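To check whether data loading is a notable share of iteration time, one rough approach (a sketch, not the only way) is to time the wait for the next batch separately from the training step. The model, loader, and data in the usage part are hypothetical placeholders. On a GPU, `torch.cuda.synchronize()` is needed so that asynchronous kernels are counted in the time of the step that launched them.

```python
import time

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset


def profile_iterations(model, loader, optimizer, criterion, device):
    """Return (seconds spent loading + transferring data, seconds spent on compute)."""
    sync = torch.cuda.synchronize if device.type == "cuda" else (lambda: None)
    data_time = compute_time = 0.0
    it = iter(loader)
    while True:
        t0 = time.perf_counter()
        try:
            x, y = next(it)
        except StopIteration:
            break
        x, y = x.to(device), y.to(device)
        sync()
        t1 = time.perf_counter()
        optimizer.zero_grad()
        criterion(model(x), y).backward()
        optimizer.step()
        sync()
        t2 = time.perf_counter()
        data_time += t1 - t0
        compute_time += t2 - t1
    return data_time, compute_time


# Usage with placeholder data and model.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = nn.Linear(10, 2).to(device)
loader = DataLoader(
    TensorDataset(torch.randn(256, 10), torch.randint(0, 2, (256,))), batch_size=32
)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
data_s, compute_s = profile_iterations(model, loader, optimizer, nn.CrossEntropyLoss(), device)
print(f"data loading share: {data_s / (data_s + compute_s):.1%}")
```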

Yes, but they won’t run forward() in parallel unless you write code for it. The first problem is that you would need num_models times more GPU memory (probably more, due to increased fragmentation); other complications include the Python GIL and the need for CUDA streams.

I used the code below, and it works, but I’m not happy with it because it appears to create copies of the dataset (1.5 GB) in memory, which gives me OOM if I want to run more than about 16 models in parallel. Any ideas on how to fix that, @ptrblck?

import torch, threading
import torch.nn as nn
from torch_geometric.loader import DataLoader as pygDataLoader
from torch.optim import AdamW
from models.models import WeightedGCN


def trainer(rank, params):

    global DATA
    loader = pygDataLoader(
        DATA,
        batch_size=640,
        num_workers=0,
        shuffle=True,
        pin_memory=False,
    )

    model = WeightedGCN(params).to(params.device)
    optimizer = AdamW(model.parameters())
    Xent = nn.CrossEntropyLoss(reduction='mean')
    
    for j in range(15000):
    
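        # gradients are accumulated over all batches; one optimizer step per pass over the data
        optimizer.zero_grad()
        for batch in loader:

            # do stuff (not shown; `ids` is assumed to be defined here)

            embeds, logits = model(batch.x, batch.edge_index, batch.edge_weight, batch.batch, deterministic=False)
            xent_loss = Xent(embeds, ids)
            xent_loss.backward()
        optimizer.step()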

        

params.device = 'cuda:0'  # params: a config object defined elsewhere (not shown)
NUM_MODEL_COPIES = 16

X, DATA = torch.load(f"../datasets/DATA.pth").values()

# move data to device
for i in range(len(DATA)):
    DATA[i] = DATA[i].to(device=params.device)

if __name__ == '__main__':

    processes = []
    for i in range(NUM_MODEL_COPIES):
        process = threading.Thread(target=trainer, args=(i, params))
        process.start()
        processes.append(process)

    # Wait for all processes to finish
    for process in processes:
        process.join()

Maybe you could use a Queue as described here or a simple implementation of a shared array as given in this example.


Shared arrays are created on the CPU, so your example isn’t quite what I want, which is to have a shared array (or any other structure) in CUDA memory. Another thing I don’t understand: in the following example, which uses torch.multiprocessing, the data tensors are set as shared on CUDA, yet they have different addresses (I checked their memory addresses from each process). nvidia-smi confirms this: I get 10 processes, each of which takes about 1 GB of GPU memory. Shouldn’t all of these processes get their data from the same GPU memory location?

import torch
import torch.multiprocessing as mp
from torch.utils.data import DataLoader, TensorDataset
from termcolor import cprint


# Define your model
class MyModel(torch.nn.Module):

    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(10, 10)
        self.relu = torch.nn.ReLU()
        self.fc2 = torch.nn.Linear(10, 2)

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x


# Define a function to train a single copy of the model
def train_model(rank, DEVICE, seed=0):
    # Set the random seed for reproducibility
    torch.manual_seed(seed)

    # Load your dataset
    dataset = TensorDataset(
        X,
        y,
    )

    print(rank, X.data_ptr())
    cprint(f'Rank: {rank}, X data_ptr: {X.data_ptr()}', color='yellow')

    # Set the device to the current process's device
    model = MyModel().to(DEVICE)
    cprint(f'Rank: {rank}, model data_ptr: {list(model.parameters())[0].data_ptr()}', color='blue')

    # Create a DataLoader for your dataset
    dataloader = DataLoader(dataset, batch_size=32, shuffle=False)

    # Define the loss function and optimizer
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

    # Train the model
    for epoch in range(100):
        for i, (inputs, labels) in enumerate(dataloader):

            optimizer.zero_grad()

            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()

            optimizer.step()

            if (i + 1) % 10 == 0:
                print(
                    f"Process {rank} Epoch [{epoch + 1}/{100}], Step [{i + 1}/{len(dataloader)}], Loss: {loss.item():.4f}"
                )
    cprint(f'{rank} finished!', color='yellow')


NUM_MODEL_COPIES = 10
DEVICE = 'cuda:0'

X = torch.rand(size=(10000, 10)).to(DEVICE).share_memory_()
y = torch.randint(2, size=(10000,)).to(DEVICE).share_memory_()

# Spawn a separate process for each copy of the model
if __name__ == '__main__':
    mp.set_start_method('spawn')  # must be spawn, not fork

    processes = []
    for rank in range(NUM_MODEL_COPIES):
        process = mp.Process(target=train_model, args=(rank, DEVICE, rank))
        process.start()
        processes.append(process)

    # Wait for all processes to finish
    for process in processes:
        process.join()

And as I said, the memory addresses of the data tensors (X) are different in each process. Why is this so?
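One likely explanation is standard multiprocessing behavior: with the `spawn` start method, each child re-imports the main script, so the module-level `X = ...` and `y = ...` lines run again in every child and allocate fresh tensors. In addition, the PyTorch docs describe `share_memory_()` as a no-op for CUDA tensors. A minimal sketch of the alternative, using a hypothetical `worker` function and placeholder data: create the tensors once inside `main()` (only called under the `__main__` guard) and pass them to the children as arguments, which `torch.multiprocessing` transfers as CUDA IPC handles (GPU) or shared memory (CPU) rather than as copies.

```python
import torch
import torch.multiprocessing as mp


def worker(rank, X, out):
    # X and out refer to the parent's memory (CUDA IPC on GPU, shared memory on CPU);
    # nothing is re-created here.
    out[rank] = X.sum() + rank  # each child writes only its own slot
    if X.is_cuda:
        torch.cuda.synchronize()  # make sure the write has finished before exiting


def main(num_procs=4):
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    # Created once, in the parent only: spawned children re-import this module,
    # but they never call main().
    X = torch.ones(1000, 10, device=device)
    out = torch.zeros(num_procs, device=device)
    if not out.is_cuda:
        out.share_memory_()  # needed for CPU tensors; a no-op for CUDA tensors
    ctx = mp.get_context("spawn")  # CUDA tensors can only be shared with spawn/forkserver
    procs = [ctx.Process(target=worker, args=(r, X, out)) for r in range(num_procs)]
    for p in procs:
        p.start()
    for p in procs:
        p.join()  # the parent keeps X and out alive until all children are done
    return out.cpu()


if __name__ == "__main__":
    print(main())
```

Each slot of the returned tensor holds `X.sum() + rank`, written by a different child into the same parent-owned tensor. Each child still creates its own CUDA context, which is a separate per-process cost.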

I also tried the idea with torch.multiprocessing.Queue:

import torch, time, sys, os, copy

os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
sys.path.append('../')

import torch
import torch.multiprocessing as mp
from torch.utils.data import DataLoader, TensorDataset
from termcolor import cprint

queue = mp.Queue()


# Define your model
class MyModel(torch.nn.Module):

    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(10, 10)
        self.relu = torch.nn.ReLU()
        self.fc2 = torch.nn.Linear(10, 2)

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x


# Define a function to train a single copy of the model
def train_model(rank, queue, DEVICE):
    # Set the random seed for reproducibility
    torch.manual_seed(rank)

    X, y = queue.get()
    cprint(f'Rank: {rank}, X data_ptr: {X.data_ptr()}', color='yellow')

    # Load your dataset
    dataset = TensorDataset(
        X,
        y,
    )

    # Set the device to the current process's device
    model = MyModel().to(DEVICE)
    cprint(f'Rank: {rank}, model data_ptr: {list(model.parameters())[0].data_ptr()}', color='blue')

    # Create a DataLoader for your dataset
    dataloader = DataLoader(dataset, batch_size=32, shuffle=False)

    # Define the loss function and optimizer
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

    # Train the model
    for epoch in range(100):
        for i, (inputs, labels) in enumerate(dataloader):

            optimizer.zero_grad()

            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()

            optimizer.step()

            if (i + 1) % 10 == 0:
                print(
                    f"Process {rank} Epoch [{epoch + 1}/{100}], Step [{i + 1}/{len(dataloader)}], Loss: {loss.item():.4f}"
                )
    cprint(f'{rank} finished!', color='yellow')


# Spawn a separate process for each copy of the model
# mp.set_start_method('spawn')  # must be spawn, not fork

NUM_MODEL_COPIES = 10
DEVICE = 'cuda:0'

processes = []
for rank in range(NUM_MODEL_COPIES):
    process = mp.Process(target=train_model, args=(rank, queue, DEVICE))
    process.start()
    processes.append(process)

time.sleep(2)

X = torch.rand(size=(10000, 10)).to(DEVICE)
y = torch.randint(2, size=(10000,)).to(DEVICE)

for rank in range(NUM_MODEL_COPIES):
    queue.put((X, y))

# Wait for all processes to finish
for process in processes:
    process.join()

Now at least the dataset tensors in each process point to the same location in GPU memory, but the models (each created inside a separate process) also point to the same (!) address. I want the models to be independent.

Moreover, even though the dataset tensors are supposed to be at the same memory address in all processes, I still see memory consumed as though each process created its own copy:

That’s confusing. What am I doing wrong @ptrblck ?

UPDATE:
It’s actually even weirder: even though the pointers to the first element of the weight tensors of the models in different processes are the same (according to data_ptr), the actual weight values are different (as they should be). So why are the pointers the same? In the picture below, the last number on each line is element [0, 0] of the first weight tensor of each model copy.

I figured it out (almost everything).

  1. When you create a model inside a separate process with torch.multiprocessing, the model parameters can have the same pointers in different processes, because the pointers are not global but relative to each process’s own memory space. In the example below, I tampered with the weights in Process 0, and the weights did not change in the other processes (I checked Process 8). As an additional check, I passed another shared tensor (shared_bias) to each process. In that case, if I tamper with the shared bias, the change is reflected in all of the subprocesses (as intended). So everything checks out: you CAN share CUDA tensors (e.g. datasets) across processes, each running a different model. Moreover, it is possible to share some of the models’ parameters across the processes. My confusion was about the pointers of the model weight tensors: I now realize that the same pointer values in different processes don’t mean the same underlying data.

One remaining concern is that even a small model allocates about 1 GB of GPU memory in each process. I thought I would be able to train hundreds of small models on one GPU in parallel, but now that seems impossible (or is it, @ptrblck?).
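One way to train many small, architecturally identical models in a single process (and therefore a single CUDA context) is to stack their parameters and vectorize the forward pass with `torch.func`, the approach used in PyTorch's model-ensembling tutorial. A minimal sketch, assuming PyTorch 2.x and reusing the `MyModel` architecture from the posts above; `NUM_MODELS` and the random data are placeholders. Each model's parameters live in their own slice of the stacked tensors, and summing the per-model losses gives each slice exactly the gradient of its own model's loss.

```python
import copy

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.func import functional_call, stack_module_state, vmap


class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(10, 10)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(10, 2)

    def forward(self, x):
        return self.fc2(self.relu(self.fc1(x)))


device = "cuda:0" if torch.cuda.is_available() else "cpu"
NUM_MODELS = 100  # placeholder
models = [MyModel().to(device) for _ in range(NUM_MODELS)]

# Stack each parameter across models, e.g. fc1.weight -> [NUM_MODELS, 10, 10].
# The stacked tensors are new leaf tensors; the original modules are not updated.
params, buffers = stack_module_state(models)
base = copy.deepcopy(models[0]).to("meta")  # provides the architecture only


def fmodel(p, b, x):
    return functional_call(base, (p, b), (x,))


# Map over the model dimension of params/buffers; every model sees the same batch.
ensemble = vmap(fmodel, in_dims=(0, 0, None))
optimizer = torch.optim.AdamW(list(params.values()), lr=1e-3)

X = torch.rand(2048, 10, device=device)  # placeholder data
y = torch.randint(2, (2048,), device=device)

for epoch in range(3):
    for xb, yb in zip(X.split(64), y.split(64)):
        optimizer.zero_grad()
        logits = ensemble(params, buffers, xb)  # [NUM_MODELS, batch, 2]
        per_model = F.cross_entropy(
            logits.flatten(0, 1), yb.repeat(NUM_MODELS), reduction="none"
        ).view(NUM_MODELS, -1).mean(dim=1)  # one mean loss per model
        per_model.sum().backward()  # each slice gets only its own model's gradient
        optimizer.step()

# Optionally copy model 0's trained slice back into an ordinary module.
models[0].load_state_dict({k: v[0].detach() for k, v in params.items()})
```

Optimizers with per-element updates (SGD, Adam, AdamW) keep the models independent; anything that mixes values across a whole stacked tensor, such as global-norm gradient clipping, would couple them.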

Below is the complete code snippet to reproduce:

import torch, time, sys, os, copy

os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
sys.path.append('../')

import torch
import torch.multiprocessing as mp
from torch.utils.data import DataLoader, TensorDataset
from termcolor import cprint

# Spawn a separate process for each copy of the model
# mp.set_start_method('spawn')  # must be spawn, not fork

queue = mp.Queue()


# Define your model
class MyModel(torch.nn.Module):

    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(10, 10)
        self.relu = torch.nn.ReLU()
        self.fc2 = torch.nn.Linear(10, 2)

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x


# Define a function to train a single copy of the model
def train_model(rank, queue, DEVICE):
    # Set the random seed for reproducibility
    torch.manual_seed(rank)

    X, y, bias = queue.get()
    cprint(f'Rank: {rank}, X data_ptr: {X.data_ptr()}', color='yellow')

    # Load your dataset
    dataset = TensorDataset(
        X,
        y,
    )

    # Set the device to the current process's device
    with torch.no_grad():
        model = MyModel().to(DEVICE)
        model.fc1.bias = torch.nn.Parameter(bias)

        if rank == 0:
            # changing weight in one model in a separate process doesn't affect the weights in the model in another process, because the weight tensors are not shared
            model.fc1.weight[0][0] = -33.0

            # but changing bias (which is a shared tensor) should affect biases in the other processes
            model.fc1.bias *= 4

            cprint(f'RANK: {rank} | {list(model.parameters())[0][0,0]}', color='magenta')

        if rank == 8:
            cprint(f'RANK: {rank} | {list(model.parameters())[0][0,0]}', color='red')
            cprint(f'RANK: {rank} | BIAS: {model.fc1.bias}', color='red')

    ptr = model.fc1.weight[0][0].storage().data_ptr()
    cprint(f'Rank: {rank}, model data_ptr: {ptr}', color='blue')

    # Create a DataLoader for your dataset
    dataloader = DataLoader(dataset, batch_size=32, shuffle=False)

    # Define the loss function and optimizer
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

    # Train the model
    for epoch in range(100):
        for i, (inputs, labels) in enumerate(dataloader):

            if rank == 0:
                cprint(f'RANK: {rank} | {list(model.parameters())[0][0,0]}', color='magenta')
                cprint(f'RANK: {rank} | BIAS: {model.fc1.bias}', color='magenta')
            if rank == 8:
                cprint(f'RANK: {rank} | {list(model.parameters())[0][0,0]}', color='red')
                cprint(f'RANK: {rank} | BIAS: {model.fc1.bias}', color='red')

            optimizer.zero_grad()

            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()

            # optimizer.step()

            if (i + 1) % 10 == 0:
                print(
                    f"Process {rank} Epoch [{epoch + 1}/{100}], Step [{i + 1}/{len(dataloader)}], Loss: {loss.item():.4f}"
                )
    cprint(f'{rank} finished!', color='yellow')


NUM_MODEL_COPIES = 10
DEVICE = 'cuda:0'

processes = []
for rank in range(NUM_MODEL_COPIES):
    process = mp.Process(target=train_model, args=(rank, queue, DEVICE))
    process.start()
    processes.append(process)

time.sleep(2)

X = torch.rand(size=(10000, 10)).to(DEVICE)
y = torch.randint(2, size=(10000,)).to(DEVICE)
shared_bias = torch.ones(size=(10,), device=DEVICE)
for rank in range(NUM_MODEL_COPIES):
    queue.put((X, y, shared_bias))

# Wait for all processes to finish
for process in processes:
    process.join()


Hi Roman,

I am working on something similar: using the same training data to train multiple (identical) models in parallel. Regarding your question about the ~1 GB GPU memory allocation per process, it is most likely due to the CUDA context overhead. If you start an interactive Python session, import torch, create a tensor and send it to your device, you should see the overhead of the CUDA context in nvidia-smi. For instance, mine hovers around 500-600 MB.
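A quick way to see the gap between what PyTorch's caching allocator accounts for and what the driver reports, assuming a CUDA-capable machine (the function returns None otherwise). `memory_reserved()` covers only the allocator's blocks, while `mem_get_info()` reports device-wide usage, which also includes the CUDA context and any other processes on that GPU, so the difference is only a rough upper bound on this process's context overhead.

```python
import torch


def context_overhead_estimate(device="cuda:0"):
    """Rough per-process overhead in MiB: device-wide usage minus allocator usage.

    Returns None without a GPU. Other processes on the same GPU inflate the result.
    """
    if not torch.cuda.is_available():
        return None
    torch.ones(1, device=device)  # forces CUDA (and context) initialization
    free, total = torch.cuda.mem_get_info(device)
    used_on_device = total - free
    reserved_by_allocator = torch.cuda.memory_reserved(device)
    return {
        "device_used_MiB": used_on_device / 2**20,
        "allocator_reserved_MiB": reserved_by_allocator / 2**20,
        "outside_allocator_MiB": (used_on_device - reserved_by_allocator) / 2**20,
    }


if __name__ == "__main__":
    print(context_overhead_estimate())
```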

Some pointers on the CUDA context overhead:

Second, I wanted to ask whether you had encountered discrepancies in GPU memory consumption between training processes. For instance, I am training multiple ResNet18 models in parallel. If I use a batch size below a certain threshold, each training instance uses the same amount of memory. Once the batch size exceeds this threshold, the training instances start using different amounts of GPU memory. Did you notice something similar? This is what my nvidia-smi output looks like for three ResNet18 models trained in parallel with a batch size above the threshold:

GPU-2022369e-2f16-0362-7dc3-ea36ded90774, 23441, 2428 MiB
GPU-2022369e-2f16-0362-7dc3-ea36ded90774, 23443, 2172 MiB
GPU-2022369e-2f16-0362-7dc3-ea36ded90774, 23442, 2428 MiB

And if I train the same three models with a smaller batch size, I instead see identical memory consumption for each training process:

GPU-2022369e-2f16-0362-7dc3-ea36ded90774, 51077, 1840 MiB
GPU-2022369e-2f16-0362-7dc3-ea36ded90774, 51075, 1840 MiB
GPU-2022369e-2f16-0362-7dc3-ea36ded90774, 51076, 1840 MiB

Update: Fixed my issue with the memory consumption in another post I made (I can only link to two URLs in a post due to being a new user)
