How to use nn.DataParallel to split a list of elements into several sublists

In the Data Parallelism Tutorial, the authors show how to train a model on multiple GPUs when each training batch is a tensor. My question is how to extend this to the case where each batch is a Python list. To give a concrete example, I slightly modified the RandomDataset class and the Model in that tutorial as follows.

import torch
import torch.nn as nn
from torch.utils.data import Dataset

# Parameters and DataLoaders
input_size = 5
output_size = 2
batch_size = 30
data_size = 100

class RandomDataset(Dataset):

    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return [self.data[index] for _ in range(12)]  # This line is changed: each item is a list of 12 tensors

    def __len__(self):
        return self.len

class Model(nn.Module):
    # Our model

    def __init__(self, input_size, output_size):
        super(Model, self).__init__()
        self.fc = nn.Linear(input_size, output_size)

    def forward(self, input_list):
        output = [self.fc(ele) for ele in input_list]  # This line is changed
        print("\tIn Model: length of input_list", len(input_list))

        return output

Previously, each batch was simply a tensor of size [30, 5]. With 3 GPUs, it is automatically split into 3 tensors of size [10, 5], so the “input” of the model on each GPU is a tensor of size [10, 5]. This works well.
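For reference, the per-GPU sizes in the tensor case can be sketched in plain Python. This is only an illustration, assuming the batch is cut along dimension 0 into equal chunks of ceil(batch_size / num_devices) with a smaller final chunk when the batch does not divide evenly; `chunk_sizes` is a hypothetical helper, not part of PyTorch.

```python
import math


def chunk_sizes(batch_size, num_devices):
    """Return the sizes of the per-device slices of a batch along dim 0.

    Assumption: chunks have size ceil(batch_size / num_devices), and the
    last chunk is smaller when the batch does not divide evenly.
    """
    if batch_size <= 0 or num_devices <= 0:
        raise ValueError("batch_size and num_devices must be positive")
    size = math.ceil(batch_size / num_devices)
    sizes = []
    remaining = batch_size
    while remaining > 0:
        step = min(size, remaining)
        sizes.append(step)
        remaining -= step
    return sizes


print(chunk_sizes(30, 3))  # full batch of 30 on 3 GPUs
print(chunk_sizes(10, 3))  # last batch when data_size = 100
```

With data_size = 100 and batch_size = 30, the final batch holds only 10 samples, so under this assumption the three GPUs would see slices of unequal size.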

Now each batch is a list of 12 tensors, each of size [30, 5]. I want DataParallel to split this list into 3 sublists, each containing 4 tensors of size [30, 5]. The default nn.DataParallel does not seem to do this, and I am wondering how I can achieve it.
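To make the desired behaviour concrete, here is a minimal sketch of the split itself in plain Python. It only shows how the list would be partitioned, not how to wire that into nn.DataParallel; `split_into_sublists` is a hypothetical helper name, and the string elements stand in for arbitrary objects.

```python
def split_into_sublists(items, num_chunks):
    """Split `items` into at most `num_chunks` contiguous sublists.

    Sublist lengths differ by at most one; empty sublists are dropped
    when there are fewer items than chunks.
    """
    if num_chunks <= 0:
        raise ValueError("num_chunks must be positive")
    base, extra = divmod(len(items), num_chunks)
    sublists = []
    start = 0
    for i in range(num_chunks):
        # The first `extra` sublists take one additional element.
        length = base + (1 if i < extra else 0)
        if length == 0:
            break
        sublists.append(items[start:start + length])
        start += length
    return sublists


# Placeholder objects standing in for the 12 elements of one batch.
batch = [f"obj{i}" for i in range(12)]
parts = split_into_sublists(batch, 3)
print([len(p) for p in parts])  # one sublist of 4 elements per GPU
```

Because the helper slices an ordinary list, it makes no assumption about what the elements are, which matters when they are not tensors. Presumably such a split would have to replace the scatter step that nn.DataParallel performs by default, but that part is an assumption and is not shown here.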

I put the full script/notebook here: https://gist.github.com/mickeystroller/d428af55e3d2afb6b79eb888b139ba31

P.S.: Just to clarify, I know that in this toy example we could simply stack the 12 tensors along a new dimension. In my real application, however, each training batch will be a list of complex objects rather than a list of tensors, so I do need DataParallel to split a list into several sublists.

Thanks for your help and time.
