From what I understand, `torch.nn.DataParallel` works by splitting the input(s) to a module along the batch dimension and feeding the split inputs to copies of the model on each GPU. This imposes a somewhat annoying restriction: the input to my model cannot be, for example, a custom class, because then `DataParallel` just creates a shallow copy that isn't actually distributed among the multiple GPUs.
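To make the restriction concrete, here is roughly what I mean (all the names here are made up for illustration). As far as I can tell, on multiple GPUs this errors out because the tensors inside the object stay on the first device:

```python
import torch
import torch.nn as nn

# Hypothetical container class; DataParallel has no splitting rule for it,
# so every replica receives the same whole object instead of a chunk.
class Batch:
    def __init__(self, features, lengths):
        self.features = features  # shape (batch_size, num_features)
        self.lengths = lengths    # shape (batch_size,)

class MyModel(nn.Module):
    def forward(self, batch):
        return batch.features * batch.lengths.unsqueeze(1)

model = nn.DataParallel(MyModel().cuda())
batch = Batch(torch.randn(8, 16).cuda(), torch.ones(8).cuda())
out = model(batch)  # batch is NOT split along dim 0; replicas beyond
                    # GPU 0 see tensors living on the wrong device
```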
My question is: can I design a custom class in such a way that, when I provide it as input to a module wrapped in `DataParallel`, the input is properly split? It would be ideal if I could just write a class method like `def split(self, num_gpus)` that implements this splitting, and have `DataParallel` internally call it.
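For what it's worth, the closest I've gotten is subclassing `DataParallel` and overriding its `scatter` step myself. Here is a rough sketch of the idea; the `split(num_gpus)` protocol is my own invention, not an existing PyTorch hook, and for simplicity I just repeat the kwargs on every device instead of scattering them:

```python
import torch
from torch.nn.parallel import DataParallel

class SplittableDataParallel(DataParallel):
    """Sketch: let inputs that define split(num_gpus) scatter themselves."""

    def scatter(self, inputs, kwargs, device_ids):
        # Tensors also have a .split() with different semantics, so
        # explicitly exclude them from the custom path.
        if inputs and all(
            hasattr(obj, "split") and not isinstance(obj, torch.Tensor)
            for obj in inputs
        ):
            # Ask each positional input for one chunk per device ...
            per_input = [obj.split(len(device_ids)) for obj in inputs]
            # ... then regroup so device i gets (arg0_chunk_i, arg1_chunk_i, ...).
            scattered_inputs = list(zip(*per_input))
            # Simplification: reuse the same kwargs dict on every device.
            scattered_kwargs = [kwargs for _ in scattered_inputs]
            return scattered_inputs, scattered_kwargs
        # Otherwise fall back to the stock tensor/tuple/list/dict scatter.
        return super().scatter(inputs, kwargs, device_ids)
```

One wrinkle: the stock scatter also copies each chunk onto its target device, so `split` would have to handle that as well; it might make more sense for it to take the `device_ids` list rather than just a count, so that each chunk can be moved with `.to(device)`. Is there a cleaner, officially supported way to do this?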