Is there a way to set a custom scatter function for DataParallel?

Hey, I’ve been trying to use the DataParallel class for training on several local GPUs. The issue is that my input data doesn’t match the kind of input that DataParallel, and in particular its scatter function, expects.

The scatter function in DataParallel recursively looks for Tensor types in the input data and distributes them across your GPUs along the axis you specify. This is all well and good, but it doesn’t cover the use case of minibatches that are lists of Tensors. Is this a use case that will be supported, or is there a way to customize the scatter function myself?
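For concreteness, here’s a minimal sketch of the kind of input I mean (the model and shapes are made up):

```python
import torch
import torch.nn as nn

class ToyModel(nn.Module):
    # Stand-in for my real module, which consumes a list of Tensors.
    def forward(self, batch):
        return torch.stack([t.sum() for t in batch])

model = nn.DataParallel(ToyModel().cuda())

# A minibatch as a list of variable-length samples, which can't be
# stacked into a single Tensor along a batch dimension.
batch = [torch.randn(n, 16).cuda() for n in (100, 250, 80, 310)]

# scatter recurses into the list and chunks each Tensor along dim 0,
# so every replica receives a slice of every sample instead of a
# sub-list of the batch, which is not what I want.
out = model(batch)
```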

Cheers!


Hi, were you able to find a solution for this? I have a similar issue.


As per SimonW’s answer here: Dataparallel chunking for a list of 3d tensors?, you might have luck overriding the scatter function (https://github.com/pytorch/pytorch/blob/v0.3.1/torch/nn/parallel/data_parallel.py#L76-L80) to make it work.
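Untested, but a subclass along these lines might work. The class name, the contiguous chunking strategy, and the single-list-argument assumption are all my own guesses, not anything from the PyTorch API:

```python
import torch.nn as nn

class ListDataParallel(nn.DataParallel):
    """Hypothetical subclass: scatters a minibatch given as a list of
    Tensors by slicing the list itself, rather than letting the default
    scatter chunk each Tensor individually."""

    def scatter(self, inputs, kwargs, device_ids):
        # Assumes the module takes one positional argument:
        # a list of Tensors, one element per sample.
        batch = inputs[0]
        n = len(device_ids)
        size = (len(batch) + n - 1) // n  # ceil division, like torch.chunk

        scattered_inputs, scattered_kwargs = [], []
        for i, device in enumerate(device_ids):
            chunk = batch[i * size:(i + 1) * size]
            if not chunk:  # fewer samples than devices
                break
            # Move every sample in this sub-list to its target GPU.
            scattered_inputs.append(([t.cuda(device) for t in chunk],))
            scattered_kwargs.append(kwargs)  # assumes no Tensors in kwargs
        return scattered_inputs, scattered_kwargs
```

DataParallel.forward only replicates the module onto device_ids[:len(inputs)], so returning fewer chunks than devices should be fine. Note that gather still uses the default implementation, so the module would need to return Tensors that can be concatenated along self.dim.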

Other than that, this person seems to have it working with a list input without issues: nn.DataParallel with input as a list not a tensor

I’m not sure exactly what customizations the function needs, or whether there’s a simpler workaround.

Could you post an exact mock example of the code you are trying to run? That might help.