Making a custom class input compatible with DataParallel

From what I understand, torch.nn.DataParallel works by splitting the input(s) to a module along the batch dimension and feeding the split inputs to copies of the model on each GPU. This imposes a somewhat annoying restriction: the input to my model cannot be, for example, a custom class, because DataParallel then just passes a shallow copy of the object to every replica instead of splitting it across the GPUs.
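For concreteness, with a plain tensor input the split happens automatically, something like this (assuming two GPUs):

```python
import torch
import torch.nn as nn

# With a plain tensor input, DataParallel splits dim 0 across the listed GPUs.
model = nn.DataParallel(nn.Linear(16, 4).cuda(0), device_ids=[0, 1])
x = torch.randn(8, 16)  # 8 samples -> 4 go to each GPU
out = model(x)          # outputs are gathered back onto GPU 0
```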

My question is: can I design a custom class in such a way that, when I provide it as input to a torch module wrapped in DataParallel, the input is properly split? It would be ideal if I could just write a class method like def split(self, num_gpus) that implements this splitting, and have DataParallel internally call it.

Hello. Have you figured out the solution?

For anyone coming here: I solved this by subclassing torch.nn.parallel.DataParallel and writing a custom scatter for the custom input.
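Roughly, the idea is to override DataParallel's scatter method so it knows how to split the custom object itself and fall back to the default behaviour for everything else. The sketch below is a minimal illustration, not my exact code; MyBatch, its fields, and the split()/to() helpers are hypothetical stand-ins for whatever your custom input actually holds:

```python
import torch
from torch.nn.parallel import DataParallel
from torch.nn.parallel.scatter_gather import scatter_kwargs


class MyBatch:
    # Hypothetical custom input: a batch of features plus per-sample lengths.
    def __init__(self, features, lengths):
        self.features = features  # (batch, ...) tensor
        self.lengths = lengths    # (batch,) tensor

    def split(self, num_chunks):
        # Split along the batch dimension into at most num_chunks pieces.
        feats = self.features.chunk(num_chunks, dim=0)
        lens = self.lengths.chunk(num_chunks, dim=0)
        return [MyBatch(f, l) for f, l in zip(feats, lens)]

    def to(self, device):
        return MyBatch(self.features.to(device), self.lengths.to(device))


class MyDataParallel(DataParallel):
    # Override scatter so MyBatch inputs are split via their own split() method.
    def scatter(self, inputs, kwargs, device_ids):
        if len(inputs) == 1 and isinstance(inputs[0], MyBatch):
            chunks = inputs[0].split(len(device_ids))
            # One (MyBatch,) tuple per replica, moved to that replica's device.
            scattered_inputs = tuple(
                (chunk.to(torch.device("cuda", dev)),)
                for chunk, dev in zip(chunks, device_ids)
            )
            scattered_kwargs = tuple({} for _ in scattered_inputs)
            return scattered_inputs, scattered_kwargs
        # Fall back to the default tensor/list/dict scattering.
        return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)
```

With the model wrapped in MyDataParallel instead of nn.DataParallel, you can call model(MyBatch(features, lengths)) directly; DataParallel's forward only replicates the module onto as many devices as scatter returns chunks for, so batches smaller than the number of GPUs still work.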