Train multiple models on a single GPU

Hello,

I am looking for a way to train multiple models on a single GPU (RTX A5000) in parallel. Specifically, I have many (thousands of) small non-linear inversion problems that I want to solve as efficiently as possible. Each problem is independent of the others and has its own input/output and objective (loss) function. From my limited knowledge of this topic, I believe this should be a good candidate for parallelization.

I have already implemented a working version using numpy/scipy, where solving one individual problem takes about 2 seconds, depending on the convergence criteria and so on. If I run them in parallel using Python multiprocessing, the average runtime is reduced to about 0.2 seconds per problem.
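
Roughly, my CPU version looks like this; the exponential forward model below is just a hypothetical stand-in for the real inversion:

from multiprocessing import Pool

import numpy as np
from scipy.optimize import least_squares

def residuals(params, x, y_observed):
    # Hypothetical forward model standing in for the real one
    return y_observed - params[0] * np.exp(-params[1] * x)

def solve_one(args):
    # Solve a single independent inversion problem
    x, y_observed, p0 = args
    return least_squares(residuals, p0, args=(x, y_observed)).x

if __name__ == "__main__":
    # Hypothetical data: M independent problems with N samples each
    M, N = 1000, 50
    rng = np.random.default_rng(0)
    x = np.linspace(0.0, 1.0, N)
    problems = [(x, rng.normal(size=N), np.ones(2)) for _ in range(M)]

    # One worker per CPU core; each worker pulls problems off the list
    with Pool() as pool:
        solutions = pool.map(solve_one, problems)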

The next natural step to make the code run faster is to use GPU acceleration (which I have never done before). I know there are several ways to interact with the GPU, but after reading up on it, I thought PyTorch could be a good option for me since it also has automatic differentiation. I have actually written code that works with PyTorch, but it is not fast (see the sample code below).

The problem I have is that I can’t seem to run the individual optimization problems on the GPU in parallel. This may be a straightforward thing to do, but I have looked around in various places, including asking ChatGPT and Bing Chat, without any luck.

I tried to do it using torch.multiprocessing, but it does not seem to work the way I thought it did. It looks like the processes are started in parallel on the CPU rather than on the GPU (is this correct?).

Anyway, I would greatly appreciate it if anyone could point me in the right direction on this problem.

Here’s a simplified code sample:

from queue import Empty
from time import time

import torch
import torch.optim as optim
import torch.nn as nn
import torch.multiprocessing as mp


class Model(nn.Module):
    def __init__(self, N):
        super().__init__()
        self.model_parameter = nn.Parameter(torch.ones(N))

    def forward(self, input):
        # Placeholder for the real forward model, which uses self.model_parameter
        output = do_something_with_input(input)
        return output

# Objective function
def objective_function(modeled, true_data):
    return torch.sum((modeled-true_data)**2)

def train_model(model, input, true_data):
    # Move data to GPU
    input = input.to('cuda')
    true_data = true_data.to('cuda')
    model = model.to('cuda')

    optimizer = optim.Adam(model.parameters(), lr=0.01)

    # Optimization
    for iteration in range(20):
        optimizer.zero_grad()
        modeled_data = model(input)
        loss = objective_function(modeled_data, true_data)
        loss.backward()
        optimizer.step()
    
def train_process(queue):
    # queue.empty() is unreliable across processes, so use get() with a timeout
    while True:
        try:
            model, input, true_data = queue.get(timeout=1)
        except Empty:
            break
        train_model(model, input, true_data)

if __name__ == "__main__":
    input_data = load_the_input_data()       # Torch tensor with shape (M, N)
    observed_data = load_the_observed_data() # Torch tensor with shape (M, N)
    M, N = input_data.shape                  # M problems, N samples each
    
    num_processes = 4 # I also tried setting this equal to the number of models, which did not work well even when I limited the number of models

    models = [Model(N) for _ in range(M)] # one model per problem

    mp.set_start_method('spawn', force=True)

    queue = mp.Queue()
    for i in range(M):
        queue.put((models[i], input_data[i], observed_data[i]))

    processes = []
    for i in range(num_processes):
        p = mp.Process(target=train_process, args=(queue,))
        p.start()
        processes.append(p)
    
    # Wait for all processes to finish
    for p in processes:
        p.join()

You could use separate CUDA streams for each execution, and depending on the occupancy, kernels could run in parallel. Note, however, that kernels will be blocked if one kernel already uses all compute resources.
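
Something like this minimal sketch (the matmuls are just placeholder workloads):

import torch

device = torch.device('cuda')
a = torch.randn(1024, 1024, device=device)
b = torch.randn(1024, 1024, device=device)

s1 = torch.cuda.Stream()
s2 = torch.cuda.Stream()

# Kernels launched on different streams may overlap if enough resources are free
with torch.cuda.stream(s1):
    c = a @ a
with torch.cuda.stream(s2):
    d = b @ b

# Wait for all streams to finish before using the results
torch.cuda.synchronize()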

Thanks @ptrblck, I tried using separate CUDA streams for each problem as you suggested (code below). I can run the code and it produces the desired result, but unfortunately the runtime seems to scale more or less linearly with the number of models. I also tried simplifying the forward model to see if that would help. I will keep trying, but maybe I have to move to a more low-level implementation…

Simplified code:

...
models = [Model() for _ in range(N_models)]
optimizers = [optim.AdamW(models[i].parameters(), lr=learning_rate) for i in range(N_models)]
streams = [torch.cuda.Stream() for _ in range(N_models)]

for epoch in range(20):
    for i in range(N_models):
        with torch.cuda.stream(streams[i]):
            optimizers[i].zero_grad()
            modeled_data = models[i]()
            loss = objective_function(observed_data[i], modeled_data)
            loss.backward()
            optimizers[i].step()

    # Synchronize once per epoch, after work for all models has been queued;
    # synchronizing inside the inner loop would serialize the streams
    for stream in streams:
        stream.synchronize()

Results from profiling (simplified forward model, sorted by self_cuda_time_total):

print(prof.key_averages().table(sort_by="self_cuda_time_total", row_limit=20))
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg       CPU Mem  Self CPU Mem      CUDA Mem  Self CUDA Mem    # of Calls
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------
                                 aten::_index_put_impl_         5.65%      17.012ms        13.50%      40.664ms     112.956us     160.395ms        26.37%     208.365ms     578.792us           0 b      -1.63 Kb      75.00 Kb      -5.20 Mb           360
                                            aten::index         5.24%      15.769ms         6.30%      18.989ms      26.374us      44.589ms         7.33%      50.169ms      69.679us           0 b           0 b       2.11 Mb       2.11 Mb           720
                                              aten::add         1.73%       5.215ms         1.73%       5.215ms       8.692us      38.890ms         6.39%      38.890ms      64.817us           0 b           0 b     900.00 Kb     900.00 Kb           600
autograd::engine::evaluate_function: CumsumBackward0...         0.23%     681.000us         3.16%       9.528ms      79.400us      35.924ms         5.91%      51.434ms     428.617us           0 b           0 b           0 b    -247.50 Kb           120
                                              aten::mul         5.06%      15.236ms         5.06%      15.236ms       8.464us      27.294ms         4.49%      27.294ms      15.163us        1000 b        1000 b       3.57 Mb       3.57 Mb          1800
                                            aten::copy_         3.62%      10.910ms         3.62%      10.910ms       5.051us      25.350ms         4.17%      25.350ms      11.736us           0 b           0 b           0 b           0 b          2160
                                          ReluBackward0         0.48%       1.459ms         1.32%       3.979ms      16.579us      22.801ms         3.75%      29.753ms     123.971us           0 b           0 b     360.00 Kb           0 b           240
                                            aten::zero_         1.60%       4.814ms         2.54%       7.660ms       7.881us      22.399ms         3.68%      34.390ms      35.381us           0 b           0 b           0 b           0 b           972
     autograd::engine::evaluate_function: ReluBackward0         0.98%       2.964ms         2.31%       6.943ms      28.929us      20.148ms         3.31%      49.901ms     207.921us           0 b           0 b     -58.50 Kb    -418.50 Kb           240
    autograd::engine::evaluate_function: IndexBackward0         1.18%       3.552ms        18.39%      55.394ms     153.872us      16.286ms         2.68%     237.599ms     659.997us           0 b           0 b      -1.32 Mb      -2.37 Mb           360
                                           aten::arange         2.66%       8.015ms         4.49%      13.526ms      14.090us      15.260ms         2.51%      30.421ms      31.689us           0 b           0 b       2.81 Mb     342.00 Kb           960
                                            aten::fill_         1.42%       4.285ms         1.42%       4.285ms       3.535us      13.379ms         2.20%      13.379ms      11.039us           0 b           0 b           0 b           0 b          1212
    autograd::engine::evaluate_function: SliceBackward0         1.30%       3.914ms         7.38%      22.234ms      61.761us      11.876ms         1.95%      52.642ms     146.228us           0 b           0 b    -423.00 Kb    -963.00 Kb           360
                                             aten::sort         1.42%       4.266ms         2.86%       8.617ms      71.808us      11.545ms         1.90%      18.792ms     156.600us           0 b           0 b     720.00 Kb     360.00 Kb           120
                                              aten::sub         2.11%       6.342ms         2.11%       6.342ms       7.550us      11.522ms         1.89%      11.522ms      13.717us           0 b           0 b       1.93 Mb       1.93 Mb           840
                                             aten::add_         2.45%       7.383ms         3.14%       9.444ms       8.744us      10.875ms         1.79%      12.550ms      11.620us         120 b        -360 b           0 b           0 b          1080
                                             aten::div_         1.16%       3.501ms         1.16%       3.501ms       9.725us      10.444ms         1.72%      10.444ms      29.011us       1.01 Kb       1.01 Kb           0 b           0 b           360
                                             aten::flip         1.30%       3.915ms         1.70%       5.109ms      21.288us       9.367ms         1.54%      11.244ms      46.850us           0 b           0 b     360.00 Kb           0 b           240
                                              aten::sum         2.62%       7.893ms         2.63%       7.920ms      16.500us       7.732ms         1.27%       8.701ms      18.127us           0 b           0 b     240.00 Kb     240.00 Kb           480
                                       aten::as_strided         0.38%       1.148ms         0.38%       1.148ms       0.309us       7.474ms         1.23%       7.474ms       2.009us           0 b           0 b           0 b           0 b          3720
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------
Self CPU time total: 301.214ms
Self CUDA time total: 608.263ms