How to use multiple GPUs in PyTorch?


I have a PyTorch model for machine translation, and I rented a machine with 4 GPUs.

The dataset is very large.

How do I train my model using the 4 GPUs?

I see the model uses only 1 GPU.

So, my questions are:

  1. What are the mechanics of training on 4 GPUs?

  2. Is there any way I can make my model run on all 4 GPUs?


Hi, any help will save me a lot of money. 🙂

Have a look at torch.nn.parallel.DistributedDataParallel. 😉

So PyTorch (or a machine with multiple GPUs) does not use the multiple GPUs by itself?

That’s right. You can use a one-liner and wrap your model in nn.DataParallel, or use the recommended DDP approach.
Alternatively, you could use model sharding and split the model across all GPUs if you are working with a huge model.
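A minimal sketch of model sharding (naive model parallelism), assuming at least two GPUs; the TwoStageModel class, pick_devices helper, and layer sizes are hypothetical, and the sketch falls back to CPU so it also runs without GPUs. Each stage's parameters live on one device, the forward pass moves the activation between devices explicitly, and autograd routes the gradients back across devices during backward().

```python
import torch
import torch.nn as nn


class TwoStageModel(nn.Module):
    """Hypothetical model split across two devices."""

    def __init__(self, dev0: torch.device, dev1: torch.device):
        super().__init__()
        self.dev0, self.dev1 = dev0, dev1
        # First part of the network lives on dev0, second part on dev1.
        self.stage0 = nn.Sequential(nn.Linear(32, 64), nn.ReLU()).to(dev0)
        self.stage1 = nn.Linear(64, 10).to(dev1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.stage0(x.to(self.dev0))
        # Move the intermediate activation to the second device explicitly.
        return self.stage1(x.to(self.dev1))


def pick_devices():
    if torch.cuda.device_count() >= 2:
        return torch.device("cuda:0"), torch.device("cuda:1")
    # Fallback so the sketch also runs on a CPU-only machine.
    return torch.device("cpu"), torch.device("cpu")


dev0, dev1 = pick_devices()
model = TwoStageModel(dev0, dev1)
out = model(torch.randn(8, 32))
# Targets must live on the same device as the output.
target = torch.randint(0, 10, (8,), device=out.device)
loss = nn.functional.cross_entropy(out, target)
loss.backward()  # gradients flow back through both devices
print(out.shape, out.device)
```

In this naive form only one GPU is busy at a time, so it mainly helps when the model does not fit on a single GPU, not with speed.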


Thanks. I am having some issues with multi-GPU training. I have multiple encoders, and the DataParallel module is splitting along different dimensions for different encoders, so I get an error in the decoder.

Does the DataParallel module assume the first dimension is the batch? By default, the LSTM module expects the batch in the middle dimension.


Yes, the batch will be chunked along dim0.
You could try to permute the data or use batch_first=True in your LSTM.
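A minimal sketch of both options, assuming a hypothetical Encoder (embedding plus LSTM) with made-up sizes. Note that batch_first only changes the layout of the LSTM's input and output tensors; the returned hidden and cell states keep the (num_layers, batch, hidden) layout, so returning them from a DataParallel-wrapped module would gather them along the wrong dimension. The sketch therefore returns only the output.

```python
import torch
import torch.nn as nn


class Encoder(nn.Module):
    def __init__(self, vocab_size: int = 100, emb: int = 32, hidden: int = 64):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, emb)
        # batch_first=True: the LSTM expects (batch, seq_len, features),
        # matching DataParallel, which scatters inputs along dim 0.
        self.lstm = nn.LSTM(emb, hidden, batch_first=True)

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        output, _ = self.lstm(self.embed(tokens))  # output: (batch, seq_len, hidden)
        return output


device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = Encoder().to(device)
if torch.cuda.device_count() > 1:
    # One-liner: replicate the module on all visible GPUs; outputs are gathered on cuda:0.
    model = nn.DataParallel(model)

tokens = torch.randint(0, 100, (16, 20), device=device)  # (batch=16, seq_len=20)
out = model(tokens)
print(out.shape)  # torch.Size([16, 20, 64])

# Time-major data (seq_len, batch) can instead be permuted before the call.
time_major = torch.randint(0, 100, (20, 16), device=device)
out2 = model(time_major.permute(1, 0).contiguous())
```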



I have been trying to train additional models (or do other work) on the second GPU of a machine, but I am running into issues. I have confirmed that torch.cuda recognizes 2 GPUs, but I cannot switch to the second GPU to train different models in parallel.

I think I solved this by adding:
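For reference, a general sketch (not necessarily the snippet referred to above) of pinning one training run to a particular GPU, assuming the second GPU has index 1; pick_device is a hypothetical helper that falls back to CPU. An alternative is to start each training script with a different CUDA_VISIBLE_DEVICES value (for example CUDA_VISIBLE_DEVICES=1 python train.py), in which case that process sees the selected GPU as cuda:0.

```python
import torch
import torch.nn as nn


def pick_device(index: int) -> torch.device:
    # Use the requested GPU if it exists, otherwise fall back to CPU.
    if torch.cuda.is_available() and index < torch.cuda.device_count():
        return torch.device(f"cuda:{index}")
    return torch.device("cpu")


device = pick_device(1)  # second GPU ("cuda:1") when available
model = nn.Linear(10, 2).to(device)     # parameters moved to that device
x = torch.randn(4, 10, device=device)   # inputs created directly on it
y = model(x)
print(y.device)
```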


Many thanks for your help.
I got the following error message. Could you tell me how to fix it? Thanks a lot.

RuntimeError: expected device cuda:3 and dtype Float but got device cuda:0 and dtype Float

If you are using nn.DataParallel, this error is often raised if new tensors are created in the forward method and pushed to the default device (cuda:0).
Could you post the model definition so that we could have a look?
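A minimal sketch of device-safe tensor creation inside forward, assuming a hypothetical NoisyScale layer. New tensors follow the device and dtype of the input chunk each replica receives, and constants are registered as buffers so DataParallel copies them to every replica; hard-coded devices and plain tensor attributes are shown only in comments as the patterns to avoid.

```python
import torch
import torch.nn as nn


class NoisyScale(nn.Module):
    """Hypothetical layer that needs extra tensors inside forward."""

    def __init__(self):
        super().__init__()
        # A registered buffer is copied to every replica's device by DataParallel.
        self.register_buffer("scale", torch.tensor(2.0))
        # By contrast, a plain attribute such as
        #   self.scale = torch.tensor(2.0, device="cuda:0")
        # stays on cuda:0 in every replica and causes device mismatches.

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Avoid hard-coding a device, e.g. torch.randn(x.shape, device="cuda:0").
        # Instead, follow the device and dtype of this replica's input chunk.
        noise = torch.randn(x.shape, device=x.device, dtype=x.dtype)
        offset = x.new_full(x.shape, 0.5)  # new_* factories inherit x's device and dtype
        return (x + noise + offset) * self.scale


device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = NoisyScale().to(device)
if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model)
out = model(torch.zeros(8, 4, device=device))
print(out.shape, out.device)
```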

Many thanks for your prompt reply. 🙂
Here is my forward definition. The network is quite big, but it is a UNet-based model for semantic segmentation with a shared encoder.

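    def forward(self, x):
        # Shared encoder applied to each half of the input channels
        features_x1 = self.encoder(x[:, : self.in_channels // 2, :, :])
        features_x2 = self.encoder(x[:, self.in_channels // 2 :, :, :])
        # Concatenate the two feature maps of each stage along the channel dim
        features = [torch.cat([x1, x2], 1) for x1, x2 in zip(features_x1, features_x2)]
        # Reduce channels with self.res[i] (x here is the loop variable, not the input)
        features = [self.res[i].to(x.device)(x) for i, x in enumerate(features)]

        decoder_output = self.decoder(*features)
        masks = self.segmentation_head(decoder_output)

        return masks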

Is it necessary to call self.res[i].to(x.device) in the list comprehension?
Based on the code, it should be a no-op, since self.res should already be on the corresponding device, if it’s registered as a module.
Could you post the complete error message, including the line of code that raises this error?
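A minimal sketch of why registration matters, assuming hypothetical PlainList and Registered modules. Layers stored in nn.ModuleList are registered, so model.to() and DataParallel replication already place them on the right device and an extra .to(x.device) is a no-op. Layers kept in a plain Python list are invisible to nn.Module; under DataParallel, calling .to(x.device) on such a shared layer from several replicas would move its weights between GPUs, which is one plausible (unconfirmed) source of device-mismatch errors.

```python
import torch.nn as nn


class PlainList(nn.Module):
    def __init__(self):
        super().__init__()
        # A plain Python list hides these layers from nn.Module:
        # .to(), .parameters(), state_dict() and DataParallel replication skip them.
        self.res = [nn.Conv2d(8, 4, kernel_size=1) for _ in range(3)]


class Registered(nn.Module):
    def __init__(self):
        super().__init__()
        # nn.ModuleList registers each layer as a submodule.
        self.res = nn.ModuleList([nn.Conv2d(8, 4, kernel_size=1) for _ in range(3)])


n_plain = sum(p.numel() for p in PlainList().parameters())
n_registered = sum(p.numel() for p in Registered().parameters())
print(n_plain, n_registered)  # 0 108
```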

Hello (sorry for the late reply)

Regarding your first question about self.res[i].to(x.device):
I do this because I need to reduce the channel size (c*2) of the features concatenated in the previous step.
Here, self.res[i] = nn.Conv2d(c*2, c, ..., stride=1, padding=0), where c is the number of channels of x1 and x2.
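A minimal sketch of this fusion step, assuming a hypothetical SiameseFusion module, made-up channel counts, and kernel_size=1 (the post does not give the kernel size). Each stage's two feature maps are concatenated along the channel dimension and a registered convolution reduces 2c back to c, so no .to() call is needed.

```python
import torch
import torch.nn as nn


class SiameseFusion(nn.Module):
    """Hypothetical fusion: concat two feature maps, then reduce 2c -> c channels."""

    def __init__(self, channels):  # channels: c for each encoder stage
        super().__init__()
        # kernel_size=1 is an assumption; stride 1 and padding 0 keep H and W.
        self.res = nn.ModuleList(
            [nn.Conv2d(2 * c, c, kernel_size=1, stride=1, padding=0) for c in channels]
        )

    def forward(self, feats1, feats2):
        fused = []
        for conv, f1, f2 in zip(self.res, feats1, feats2):
            cat = torch.cat([f1, f2], dim=1)  # (N, 2c, H, W)
            fused.append(conv(cat))           # (N, c, H, W)
        return fused


fusion = SiameseFusion([16, 32])
feats1 = [torch.randn(2, 16, 64, 64), torch.randn(2, 32, 32, 32)]
feats2 = [torch.randn(2, 16, 64, 64), torch.randn(2, 32, 32, 32)]
out = fusion(feats1, feats2)
print([tuple(t.shape) for t in out])  # [(2, 16, 64, 64), (2, 32, 32, 32)]
```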

As for the full error message:

    Epoch: 1
    train:   0%|          | 0/146 [00:12<?, ?it/s]
    Traceback (most recent call last):
      File "", line 112, in <module>
        train_logs, *_ =
      File "/uge_mnt/home/bruno/parallel/codes/", line 125, in run
        loss, y_pred = self.batch_update(x, y)
      File "/uge_mnt/home/bruno/parallel/codes/", line 180, in batch_update
      File "/home/xxx/apps/intelpython3/lib/python3.6/site-packages/torch/", line 118, in backward
        torch.autograd.backward(self, gradient, retain_graph, create_graph)
      File "/home/xxx/apps/intelpython3/lib/python3.6/site-packages/torch/autograd/", line 93, in backward
        allow_unreachable=True)  # allow_unreachable flag
    RuntimeError: expected device cuda:3 and dtype Float but got device cuda:2 and dtype Float

Many thanks in advance for your help. 🙂

Thanks for the follow-up.
Could you post a minimal, executable code snippet to reproduce this error please, as I cannot see any issues in the current code?
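While preparing such a snippet, it can help to check where every tensor of the model actually lives. The helper below, tensors_by_device, is a hypothetical debugging sketch (not a PyTorch API): it groups parameters, buffers and plain tensor attributes by device, and flags plain Python lists of modules, which nn.Module does not register and which are therefore neither moved by .to() nor replicated by DataParallel.

```python
from collections import defaultdict

import torch
import torch.nn as nn


def tensors_by_device(model: nn.Module) -> dict:
    """Hypothetical debugging helper: map each device to the tensors stored on it."""
    found = defaultdict(list)
    for name, p in model.named_parameters():
        found[str(p.device)].append(f"param {name}")
    for name, b in model.named_buffers():
        found[str(b.device)].append(f"buffer {name}")
    for mod_name, mod in model.named_modules():
        for attr, value in vars(mod).items():
            full = f"{mod_name}.{attr}" if mod_name else attr
            if isinstance(value, torch.Tensor):
                # Plain tensor attributes are not moved by .to() or replicated.
                found[str(value.device)].append(f"attribute {full}")
            elif isinstance(value, (list, tuple)) and any(isinstance(v, nn.Module) for v in value):
                # Modules hidden in plain lists/tuples are not registered.
                found["unregistered"].append(f"list {full}")
    return dict(found)


class Demo(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 3, kernel_size=1)
        self.register_buffer("mean", torch.zeros(3))
        self.scale = torch.ones(1)                     # plain attribute
        self.extra = [nn.Conv2d(3, 3, kernel_size=1)]  # plain list


report = tensors_by_device(Demo())
for device, names in report.items():
    print(device, names)
```

On a model moved to cuda:0 before wrapping in DataParallel, any entry that is not on cuda:0 (or listed as unregistered) is a good first suspect.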