How are batches split across multiple GPUs?

Suppose you have 4 GPUs. Are batches split evenly into 4 parts (without changing the order) and then distributed to the different GPUs? Or is each individual image in the batch sent to a random GPU?

I am asking because I have run into some problems training on multiple GPUs for few-shot learning. In few-shot learning, batches are constructed in a specific manner, i.e., with a separation between support images and query images. If the batch is split evenly, the chunks sent to different GPUs will not contain the same classes, and the number of examples per class will differ between them. I have noticed a performance drop in my models, which was mitigated by shuffling the data before passing it through the model and reordering it after getting the required features:

        # we have to shuffle the batch for multi-gpu training purposes
        shuffle = torch.randperm(input.shape[0])
        input = input[shuffle]

        # idx keeps track of how to shuffle back to the original order
        idx = torch.argsort(shuffle)

        # get features
        features, output = model(input, use_fc=False)

        # shuffle back so protonet loss still works
        features = features[idx]
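
To make the concern concrete, here is a minimal sketch of what an even, order-preserving split would do to an episode. It is a CPU-only simulation using torch.chunk, not a multi-GPU run, and the episode layout (5-way, 2 support and 2 query images per class, support images first) and all variable names are hypothetical, chosen only for illustration.

```python
import torch

# Hypothetical 5-way episode: 2 support + 2 query images per class,
# laid out as all support images first, then all query images.
n_way, n_support, n_query = 5, 2, 2
support_labels = torch.arange(n_way).repeat_interleave(n_support)
query_labels = torch.arange(n_way).repeat_interleave(n_query)
labels = torch.cat([support_labels, query_labels])  # 20 labels in total

# Simulate an even, order-preserving split across 4 devices.
n_gpus = 4
chunks = torch.chunk(labels, n_gpus, dim=0)

for gpu_id, chunk in enumerate(chunks):
    classes, counts = torch.unique(chunk, return_counts=True)
    print(f"chunk {gpu_id}: classes={classes.tolist()} counts={counts.tolist()}")
```

Under these assumptions, every chunk holds only three of the five classes, and the per-class counts within a chunk are unequal.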

If batches are split evenly, is there a built-in method to send each image to a random GPU and preserve the order once the output is returned?

A quick test shows that the data is split sequentially:

import torch
import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()

    def forward(self, x):
        print(x, x.device)
        return x

model = MyModel()
model = nn.DataParallel(model)

batch_size = 2
x = torch.arange(batch_size * 8).float().view(batch_size*8, -1)

out = model(x)


tensor([[2.],
        [3.]], device='cuda:1') cuda:1
tensor([[0.],
        [1.]], device='cuda:0') cuda:0
tensor([[10.],
        [11.]], device='cuda:5') cuda:5
tensor([[8.],
        [9.]], device='cuda:4') cuda:4
tensor([[14.],
        [15.]], device='cuda:7') cuda:7
tensor([[4.],
        [5.]], device='cuda:2') cuda:2
tensor([[6.],
        [7.]], device='cuda:3') cuda:3
tensor([[12.],
        [13.]], device='cuda:6') cuda:6

If you need to shuffle the data, you would have to apply the shuffling manually before passing the data to nn.DataParallel (assuming you are using this approach).
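
As a rough sketch of the manual approach, the shuffle and the inverse permutation can be wrapped in a small helper. The name forward_shuffled and the Identity stand-in model are hypothetical, and the helper assumes the model returns a single tensor whose first dimension is the batch dimension; for a model that returns a tuple, the inverse index would have to be applied to each per-sample output.

```python
import torch
import torch.nn as nn


def forward_shuffled(model, x):
    """Run `model` on a randomly permuted batch and return the
    outputs restored to the original sample order.

    Hypothetical helper: assumes `model(x)` returns one tensor
    with the batch along dim 0.
    """
    perm = torch.randperm(x.shape[0], device=x.device)
    inverse = torch.argsort(perm)  # perm[inverse[j]] == j
    out = model(x[perm])
    return out[inverse]


class Identity(nn.Module):
    """Stand-in model so that the ordering is easy to check."""

    def forward(self, x):
        return x


model = Identity()
if torch.cuda.device_count() > 1:
    # Only wrap when several GPUs are present; otherwise run as-is.
    model = nn.DataParallel(model)

x = torch.arange(16).float().view(16, 1)
restored = forward_shuffled(model, x)

# The outputs line up with the inputs again, even though the
# samples were assigned to devices in random order.
print(torch.equal(restored.cpu(), x))
```

Because the permutation is drawn anew on every call, each sample lands in a different chunk from one iteration to the next, while any loss that relies on the original ordering still sees the outputs in that order.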


Shuffling the data manually before passing it to nn.DataParallel worked. Thanks a lot!