Is there a way to set a custom scatter function for DataParallel?

As per SimonW’s answer in "Dataparallel chunking for a list of 3d tensors?", you might have luck overriding the scatter function to make it work: https://github.com/pytorch/pytorch/blob/v0.3.1/torch/nn/parallel/data_parallel.py#L76-L80
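Here is a minimal sketch of what such an override might look like. It assumes the module is called with a single list of 3D tensors of varying shapes. The `scatter` method below gives each device a contiguous sublist of the tensors, instead of a slice of every tensor. `chunk_list` and `ListDataParallel` are hypothetical names, not part of PyTorch.

The sketch makes two further assumptions:
- Any keyword arguments are non-tensor values that all replicas can share.
- The module returns one tensor whose first dimension indexes the list items, so the default `gather` can concatenate the per-device results.

It was written against recent PyTorch versions, where `DataParallel.forward` calls `self.scatter(inputs, kwargs, self.device_ids)`.

```python
import torch
import torch.nn as nn


def chunk_list(items, n_chunks):
    """Split a Python list into at most n_chunks contiguous sublists of near-equal length."""
    n_chunks = max(1, min(n_chunks, len(items)))
    size, rem = divmod(len(items), n_chunks)
    chunks, start = [], 0
    for i in range(n_chunks):
        # The first `rem` chunks get one extra item.
        end = start + size + (1 if i < rem else 0)
        chunks.append(items[start:end])
        start = end
    return chunks


class ListDataParallel(nn.DataParallel):
    """Hypothetical DataParallel subclass for a module called as module(tensor_list).

    Assumptions: kwargs hold only non-tensor values shared by every replica, and
    the module returns a single tensor whose dim 0 indexes the list items.
    """

    def scatter(self, inputs, kwargs, device_ids):
        tensor_list = inputs[0]
        chunks = chunk_list(tensor_list, len(device_ids))
        # One positional-args tuple per device, holding that device's sublist.
        scattered = [
            ([t.to(device) for t in chunk],)
            for chunk, device in zip(chunks, device_ids)
        ]
        # Give every replica its own copy of the (non-tensor) kwargs.
        kwargs = kwargs or {}
        scattered_kwargs = [dict(kwargs) for _ in scattered]
        return scattered, scattered_kwargs
```

Usage would be along the lines of `model = ListDataParallel(MyModel()).cuda()` followed by `out = model(list_of_3d_tensors)`, where `MyModel` is a placeholder. If the module returns a list of differently sized tensors instead, `gather` would need a similar override.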

Other than that, the poster in "nn.DataParallel with input as a list not a tensor" seems to have it working with a list input without any issues.
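A likely reason list inputs can work out of the box: the default scatter recurses into lists. It chunks every tensor in the list along the batch dimension and sends device i the list of each tensor's i-th chunk.

The sketch below emulates that behavior on CPU with `torch.chunk` and `zip`. It is an approximation written for illustration, not PyTorch's own code, and `emulate_default_list_scatter` is a hypothetical name. It also suggests why a list of 3D tensors with different batch sizes can misbehave: `zip` stops at the shortest chunk list, so groups can silently disappear.

```python
import torch


def emulate_default_list_scatter(tensors, n_devices, dim=0):
    """CPU-only approximation of how the default scatter splits a list input.

    Each tensor is chunked along `dim`; group i collects every tensor's i-th
    chunk. zip() stops at the shortest chunk list, so a tensor with fewer than
    n_devices entries along `dim` reduces the number of groups.
    """
    per_tensor_chunks = [t.chunk(n_devices, dim) for t in tensors]
    return [list(group) for group in zip(*per_tensor_chunks)]


a = torch.zeros(4, 3)
b = torch.zeros(4, 5)
groups = emulate_default_list_scatter([a, b], n_devices=2)
print([[tuple(t.shape) for t in g] for g in groups])
# → [[(2, 3), (2, 5)], [(2, 3), (2, 5)]]
```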

I'm not sure which specific customizations the function would need, or whether there is a workaround.

Could you post a minimal mock example of the code you are trying to run? That might help.