Model parallelism in PyTorch for large (larger than 1 GPU) models?

Hi! I have a model that is too large to fit on a single TITAN X (even with batch size 1). I want to split it over several GPUs so that the memory cost is shared between them. That is, I want to place different parts of the same model on different GPUs and train it end-to-end.

Questions:

  1. Is this possible in PyTorch? If not, is it possible in Torch?
  2. Would inter-GPU communication (say, for transferring activations to later layers) involve GPU->host->GPU type transfers?
  1. Yes, it is possible. Just put some of the layers on GPU 0 (.cuda(0)) and the others on GPU 1 (.cuda(1)). Then, in the forward function, once the part on the first GPU has finished processing, call .cuda(1) on its output. This extends to as many GPUs as you want. See an example below.
  2. No, not when P2P is available. Calling .cuda(i) on a CUDA tensor that's on GPU j (j != i) is a peer-to-peer copy when the GPUs support P2P access, and the host doesn't have to do anything. (Without P2P support, the CUDA driver stages the copy through host memory, but this is still transparent to your code.)
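import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self, split_gpus):
        super().__init__()  # required before assigning submodules
        self.large_submodule1 = ...
        self.large_submodule2 = ...

        self.split_gpus = split_gpus
        if split_gpus:
            self.large_submodule1.cuda(0)
            self.large_submodule2.cuda(1)  # second half lives on GPU 1

    def forward(self, x):
        # x is expected to already be on GPU 0 when split_gpus is True
        x = self.large_submodule1(x)
        if self.split_gpus:
            x = x.cuda(1)  # GPU-to-GPU copy (P2P where supported)
        return self.large_submodule2(x)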
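As a sketch of the end-to-end training this enables, here is a self-contained variant of the example above with small hypothetical stand-in layers (`stage1`, `stage2`, and all sizes are made up). It assumes two GPUs and falls back to the CPU when fewer are visible, so it also runs on a machine without GPUs. What it shows is that the cross-device `.to()` call is recorded by autograd, so a single `loss.backward()` plus an ordinary optimizer updates parameters on both devices.

```python
import torch
import torch.nn as nn

# Assumption: two GPUs when present, otherwise both stages run on the CPU.
if torch.cuda.device_count() >= 2:
    dev0, dev1 = torch.device("cuda:0"), torch.device("cuda:1")
else:
    dev0 = dev1 = torch.device("cpu")


class TwoStageModel(nn.Module):
    def __init__(self):
        super().__init__()
        # Hypothetical stand-ins for large_submodule1 / large_submodule2.
        self.stage1 = nn.Sequential(nn.Linear(32, 64), nn.ReLU()).to(dev0)
        self.stage2 = nn.Linear(64, 10).to(dev1)

    def forward(self, x):
        h = self.stage1(x.to(dev0))
        # .to() is differentiable, so gradients flow back from dev1 to dev0.
        return self.stage2(h.to(dev1))


model = TwoStageModel()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = nn.CrossEntropyLoss()

x = torch.randn(8, 32)
target = torch.randint(0, 10, (8,))

before = model.stage1[0].weight.detach().clone()
optimizer.zero_grad()
out = model(x)
# The loss is computed on the last stage's device, so move the target there.
loss = loss_fn(out, target.to(out.device))
loss.backward()
optimizer.step()
```

Note that the target has to live on the same device as the final output; everything else is handled by autograd.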

This was so easy! I love your work with PyTorch. Minimum fuss! Cheers!


@apaszke Hi, thanks very much for your examples.

I notice that when I split the whole model across 4 GPUs and run forward/backward, the first GPU uses much more memory than it should. For example, if the whole model takes 12GB on a single GPU, then after splitting it across four GPUs the first GPU uses 11GB and the other three together use about 11GB.
Is there an explanation of how GPU memory is allocated when using multiple GPUs for model parallelism?
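A first diagnostic step, offered as a hedged sketch rather than an explanation of the numbers above, is to ask PyTorch's caching allocator how much each device holds. `torch.cuda.memory_allocated` and `torch.cuda.max_memory_allocated` count only memory held by PyTorch tensors; `nvidia-smi` additionally shows each process's CUDA context and blocks the allocator has cached but not released (see `torch.cuda.memory_reserved`), so the two views can disagree noticeably. The helper name `per_device_memory_mib` is hypothetical.

```python
import torch


def per_device_memory_mib():
    """Hypothetical helper: {gpu_index: (allocated_MiB, peak_allocated_MiB)}.

    Counts only memory held by tensors from PyTorch's caching allocator; the
    CUDA context and cached-but-free blocks seen in nvidia-smi are not included.
    """
    report = {}
    for i in range(torch.cuda.device_count()):
        allocated = torch.cuda.memory_allocated(i) / 2**20
        peak = torch.cuda.max_memory_allocated(i) / 2**20
        report[i] = (allocated, peak)
    return report


if __name__ == "__main__":
    # Call after a forward/backward step; prints nothing on a CPU-only machine.
    for idx, (alloc, peak) in per_device_memory_mib().items():
        print(f"cuda:{idx}: {alloc:.0f} MiB allocated, {peak:.0f} MiB peak")
```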

Another question: when running forward with model parallelism, only one GPU shows 100% Volatile GPU-Util; the others are at 0%.
Is there a way to keep all four GPUs busy at the same time?
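With the naive split, each GPU sits idle while another works, because the next stage needs the previous stage's output. A common mitigation, sketched here under the assumption of a simple two-stage model (`PipelinedTwoStage`, its layer sizes, and `micro_batches` are hypothetical), is to split each batch into micro-batches and feed them through the stages in a staggered order. Because CUDA kernel launches are asynchronous, stage 2 of micro-batch k and stage 1 of micro-batch k+1 can run on different GPUs at the same time; how much overlap you actually get depends on the model and on the transfer cost. The sketch falls back to the CPU when fewer than two GPUs are available, where it computes the same result but without any overlap.

```python
import torch
import torch.nn as nn

# Assumption: two GPUs when present, otherwise both stages run on the CPU.
if torch.cuda.device_count() >= 2:
    dev0, dev1 = torch.device("cuda:0"), torch.device("cuda:1")
else:
    dev0 = dev1 = torch.device("cpu")


class PipelinedTwoStage(nn.Module):
    """Hypothetical two-stage model that pipelines micro-batches across devices."""

    def __init__(self, micro_batches=4):
        super().__init__()
        self.stage1 = nn.Sequential(nn.Linear(32, 64), nn.ReLU()).to(dev0)
        self.stage2 = nn.Linear(64, 10).to(dev1)
        self.micro_batches = micro_batches

    def forward(self, x):
        chunks = iter(x.to(dev0).chunk(self.micro_batches, dim=0))
        # Fill the pipeline: stage 1 on the first micro-batch, then hand it to dev1.
        prev = self.stage1(next(chunks)).to(dev1)
        outputs = []
        for chunk in chunks:
            # Stage 2 on the previous micro-batch (dev1) is enqueued right before
            # stage 1 on the next micro-batch (dev0), so the two GPUs can overlap.
            outputs.append(self.stage2(prev))
            prev = self.stage1(chunk).to(dev1)
        # Drain the pipeline.
        outputs.append(self.stage2(prev))
        return torch.cat(outputs, dim=0)
```

More micro-batches mean more overlap but smaller kernels, so the best value has to be found by timing on the actual hardware.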


I am doing late fusion of features extracted by two large ResNets.
After that, can the classifier on the concatenated features run on the second GPU?
Will that be slow?

How can I combine data parallelism and model parallelism
(say, if I have 4 GPUs)?

    def forward(self, x):
        x1 = self.resnet_1(x[:,0:3,:,:]).cuda(0)
        x2 = self.resnet_2(x[:,3:6,:,:]).cuda(1)

        flat  = torch.cat([x1, x2],1).cuda(???)
        logit = self.fc(flat).cuda(???)
        return logit
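A minimal late-fusion sketch, assuming two GPUs (with a CPU fallback) and using small hypothetical convolutional branches in place of the two ResNets. Two details matter: each branch's input must be moved to that branch's device before the branch runs (calling `.cuda()` on the output only moves the result after the computation has happened), and `torch.cat` requires all its inputs on the same device, so one feature tensor has to be copied across. Here the classifier lives on the second device.

```python
import torch
import torch.nn as nn

# Assumption: two GPUs when present, otherwise both branches run on the CPU.
if torch.cuda.device_count() >= 2:
    dev0, dev1 = torch.device("cuda:0"), torch.device("cuda:1")
else:
    dev0 = dev1 = torch.device("cpu")


def make_branch(feat_dim):
    # Hypothetical small CNN standing in for a ResNet feature extractor.
    return nn.Sequential(
        nn.Conv2d(3, feat_dim, kernel_size=3, padding=1),
        nn.ReLU(),
        nn.AdaptiveAvgPool2d(1),
        nn.Flatten(),  # (N, feat_dim, 1, 1) -> (N, feat_dim)
    )


class LateFusion(nn.Module):
    def __init__(self, feat_dim=16, num_classes=5):
        super().__init__()
        self.branch_1 = make_branch(feat_dim).to(dev0)
        self.branch_2 = make_branch(feat_dim).to(dev1)
        self.fc = nn.Linear(2 * feat_dim, num_classes).to(dev1)

    def forward(self, x):
        # Move each branch's input to that branch's device before running it.
        x1 = self.branch_1(x[:, 0:3].to(dev0))
        x2 = self.branch_2(x[:, 3:6].to(dev1))
        # torch.cat needs both tensors on one device; copy the dev0 features over.
        flat = torch.cat([x1.to(dev1), x2], dim=1)
        return self.fc(flat)
```

Since only the pooled feature vectors cross devices, the copy is usually small compared with the work inside each branch, so running the classifier on the second GPU should not add much overhead; timing it on the real model is the only reliable check.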


Did you figure out how to do this?

I am also interested in late fusion and running particular submodels on different GPUs. Has anyone figured out how to do it? I am currently doing something similar to what @Hengck suggested, but it is not working.
Thanks

Dear apaszke,

I am trying to implement inter-GPU communication using PyTorch + MPI + GPU.

Below is my test code, which is meant to make process 0 run on GPU 0 and process 1 run on GPU 1. However, the code does not run successfully. Do you know why?

import os
import socket
import torch
import torch.distributed as dist
from torch.multiprocessing import Process
import platform

def run(rank, size):

    if rank == 0:
        tensor = torch.zeros(1).cuda(0)
        # Send the tensor to process 1
        tensor += 1
        dist.send(tensor=tensor, dst=1)
    else:
        tensor = torch.zeros(1).cuda(1)
        # Receive tensor from process 0
        dist.recv(tensor=tensor, src=0)
    print('Rank ', rank, ' has data ', tensor[0])

def init_processes(fn):
    """ Initialize the distributed environment. """
    dist.init_process_group('mpi')
    rank = dist.get_rank()
    size = dist.get_world_size()
    print('I am rank ', rank, ' on ', platform.node())
    fn(rank, size)

if __name__ == "__main__":
    init_processes(run)

This is the error message:
[osherlab:21377] *** Process received signal ***
[osherlab:21377] Signal: Segmentation fault (11)
[osherlab:21377] Signal code: Invalid permissions (2)
[osherlab:21377] Failing at address: 0x10030800000
[osherlab:21377] [ 0] /lib/x86_64-linux-gnu/libpthread.so.0(+0x11390)[0x7f9cd4e4b390]
[osherlab:21377] [ 1] /lib/x86_64-linux-gnu/libc.so.6(+0x14e04b)[0x7f9cd4bbe04b]
[osherlab:21377] [ 2] /home/osherlab/guanlei/software/openmpi-4.0.0/openmpi/lib/libopen-pal.so.40(opal_convertor_unpack+0x11b)[0x7f9c79c0363b]
[osherlab:21377] [ 3] /home/osherlab/guanlei/software/openmpi-4.0.0/openmpi/lib/openmpi/mca_pml_ob1.so(mca_pml_ob1_recv_frag_callback_match+0x4de)[0x7f9c51f9c1fe]
[osherlab:21377] [ 4] /home/osherlab/guanlei/software/openmpi-4.0.0/openmpi/lib/openmpi/mca_btl_smcuda.so(mca_btl_smcuda_component_progress+0x3b9)[0x7f9c51d6be99]
[osherlab:21377] [ 5] /home/osherlab/guanlei/software/openmpi-4.0.0/openmpi/lib/libopen-pal.so.40(opal_progress+0x2c)[0x7f9c79bf1dac]
[osherlab:21377] [ 6] /home/osherlab/guanlei/software/openmpi-4.0.0/openmpi/lib/libopen-pal.so.40(ompi_sync_wait_mt+0xb5)[0x7f9c79bf86a5]
[osherlab:21377] [ 7] /home/osherlab/guanlei/software/openmpi-4.0.0/openmpi/lib/libmpi.so.40(ompi_request_default_wait+0x20f)[0x7f9ca9f11c9f]
[osherlab:21377] [ 8] /home/osherlab/guanlei/software/openmpi-4.0.0/openmpi/lib/libmpi.so.40(PMPI_Wait+0x4e)[0x7f9ca9f56c2e]
[osherlab:21377] [ 9] /home/osherlab/guanlei/venv/lib/python3.6/site-packages/torch/lib/libtorch_python.so(_ZN4c10d15ProcessGroupMPI9AsyncWork4waitEv+0x6d)[0x7f9cc375191d]
[osherlab:21377] [10] /home/osherlab/guanlei/venv/lib/python3.6/site-packages/torch/lib/libtorch_python.so(+0x5fd98e)[0x7f9cc369298e]
[osherlab:21377] [11] /home/osherlab/guanlei/venv/lib/python3.6/site-packages/torch/lib/libtorch_python.so(+0x112f5d)[0x7f9cc31a7f5d]
[osherlab:21377] [12] python(_PyCFunction_FastCallDict+0x154)[0x563fdb4a0b94]
[osherlab:21377] [13] python(+0x19e7ce)[0x563fdb5307ce]
[osherlab:21377] [14] python(_PyEval_EvalFrameDefault+0x2fa)[0x563fdb552cba]
[osherlab:21377] [15] python(+0x197a94)[0x563fdb529a94]
[osherlab:21377] [16] python(+0x198941)[0x563fdb52a941]
[osherlab:21377] [17] python(+0x19e755)[0x563fdb530755]
[osherlab:21377] [18] python(_PyEval_EvalFrameDefault+0x10ba)[0x563fdb553a7a]
[osherlab:21377] [19] python(+0x19870b)[0x563fdb52a70b]
[osherlab:21377] [20] python(+0x19e755)[0x563fdb530755]
[osherlab:21377] [21] python(_PyEval_EvalFrameDefault+0x2fa)[0x563fdb552cba]
[osherlab:21377] [22] python(+0x197a94)[0x563fdb529a94]
[osherlab:21377] [23] python(+0x198941)[0x563fdb52a941]
[osherlab:21377] [24] python(+0x19e755)[0x563fdb530755]
[osherlab:21377] [25] python(_PyEval_EvalFrameDefault+0x10ba)[0x563fdb553a7a]
[osherlab:21377] [26] python(PyEval_EvalCodeEx+0x329)[0x563fdb52b459]
[osherlab:21377] [27] python(PyEval_EvalCode+0x1c)[0x563fdb52c1ec]
[osherlab:21377] [28] python(+0x2149a4)[0x563fdb5a69a4]
[osherlab:21377] [29] python(PyRun_FileExFlags+0xa1)[0x563fdb5a6da1]
[osherlab:21377] *** End of error message ***
Rank 0 has data tensor(1., device='cuda:0')

Primary job terminated normally, but 1 process returned
a non-zero exit code. Per user-direction, the job has been aborted.


mpirun noticed that process rank 1 with PID 0 on node osherlab exited on signal 11 (Segmentation fault).

Any help would be appreciated. Thank you.


Would this also work on one GPU with two sequential steps somehow? If my model is too large to fit on one GPU, can I run the forward/backward pass sequentially, keeping only one part in GPU memory at a time and caching the other part for the backward pass later?

Something like this:

x = submodule1(x)
#somehow unload intermediate results of submodule1 from gpu here and cache for later backward pass 
#(and then load on gpu again when needed in backward pass of submodule1)
x = submodule2(x)
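In recent PyTorch releases, the "move intermediate results off the GPU until backward" step in the sketch above is available as the `torch.autograd.graph.save_on_cpu()` context manager: tensors that autograd saves for backward inside the context are copied to CPU memory and copied back to their original device when backward needs them. It offloads saved activations only; the submodule parameters stay where they are. The modules and sizes below are hypothetical, and the code uses the CPU when no GPU is present.

```python
import torch
import torch.nn as nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hypothetical stand-ins for submodule1 / submodule2.
submodule1 = nn.Sequential(nn.Linear(8, 16), nn.ReLU()).to(device)
submodule2 = nn.Linear(16, 1).to(device)
x = torch.randn(4, 8, device=device)

# Activations saved for backward inside this block are kept in CPU memory
# and moved back to `device` only when backward reaches submodule1.
with torch.autograd.graph.save_on_cpu():
    h = submodule1(x)

out = submodule2(h)
out.sum().backward()
```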

I can imagine how this would work, but I don't know how I would pass the gradients coming from submodule2 back to submodule1 and start the backward pass of submodule1.
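The gradient hand-off itself can be done by cutting the graph at the boundary: run `submodule1`, detach its output into a new leaf tensor that requires grad, run forward and backward through `submodule2` alone, then call `backward` on `submodule1`'s output with the gradient that landed on the leaf. A minimal sketch with hypothetical `nn.Linear` stand-ins:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Hypothetical stand-ins for the two halves of the model.
submodule1 = nn.Linear(8, 16)
submodule2 = nn.Linear(16, 1)
x = torch.randn(4, 8)

# Forward through submodule1, then cut the autograd graph at its output.
h = submodule1(x)
h_cut = h.detach().requires_grad_()  # new leaf; submodule2's graph starts here

# Forward and backward through submodule2 only; this fills h_cut.grad.
loss = submodule2(h_cut).sum()
loss.backward()

# Resume backprop into submodule1 with the gradient w.r.t. its output.
h.backward(h_cut.grad)
```

Between the two backward calls you are free to move modules between devices. Note, however, that `h` still holds `submodule1`'s saved activations until its backward runs, so on its own this does not reduce activation memory; it has to be combined with offloading or recomputation.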

Same problem here: inter-GPU communication fails with a similar error using PyTorch + MPI + GPU.
Have you found a workaround?

Thanks

You may want to check out gradient checkpointing.
Here is an example repo: csrhddlam/pytorch-checkpoint
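PyTorch also ships gradient checkpointing as `torch.utils.checkpoint`. A minimal sketch with hypothetical stand-in modules: the activations inside `submodule1` are not kept during the forward pass and are recomputed from its input during backward, trading extra compute for memory. `use_reentrant=False` is the variant the current documentation recommends; older releases may not accept this keyword.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

torch.manual_seed(0)
# Hypothetical stand-ins for the two halves of the model.
submodule1 = nn.Sequential(nn.Linear(8, 16), nn.ReLU())
submodule2 = nn.Linear(16, 1)
x = torch.randn(4, 8, requires_grad=True)

# submodule1's intermediate activations are discarded after the forward pass
# and recomputed from x when backward reaches this segment.
h = checkpoint(submodule1, x, use_reentrant=False)
loss = submodule2(h).sum()
loss.backward()
```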