Avoid tensor split with nn.DataParallel

I am using a model whose forward pass involves some auxiliary tensors (which are independent of the batch). I return them after updating them in the forward pass. When I use nn.DataParallel, PyTorch splits them across 8 GPUs along dim 0. Is there any way I can tell PyTorch not to split those tensors across the batch dim? These auxiliary tensors are not part of the computational graph.
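For context, here is a minimal sketch of the pattern I mean (the model and tensor names below are made up for illustration):

```python
import torch
import torch.nn as nn

# Hypothetical model: updates a batch-agnostic auxiliary tensor in forward
class AuxModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(16, 4)

    def forward(self, x, aux):
        # aux has no batch dimension; it is e.g. a running statistic
        aux = aux + 1.0
        return self.fc(x), aux

model = nn.DataParallel(AuxModel().cuda())
x = torch.randn(32, 16, device='cuda')
aux = torch.zeros(8, 8, device='cuda')   # dim 0 is NOT a batch dim

out, aux_out = model(x, aux)
# DataParallel scatters every tensor argument along dim 0, so with 8 GPUs
# each replica only sees a (1, 8) slice of aux, and the returned aux is
# gathered back along dim 0 as well.
```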


Hi @Vardaan_Pahuja, have you found the solution? I am having the same problem here.


I’m also looking for a solution, as I’m facing the same problem.
Have either of you been able to find one?

Are these auxiliary tensors large or could you repeat them?

Hi @ptrblck, they’re quite large.
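One possible workaround I've seen (a sketch only, not an official API; the `NoSplit` wrapper name is made up): DataParallel's scatter only slices tensors and the elements of lists/tuples/dicts. Wrapping the auxiliary tensor in a plain Python object makes scatter pass the same reference to every replica instead of splitting it along dim 0.

```python
import torch
import torch.nn as nn

class NoSplit:
    """Plain wrapper so DataParallel's scatter replicates the reference
    instead of slicing the tensor along dim 0."""
    def __init__(self, tensor):
        self.tensor = tensor

class AuxModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(16, 4)

    def forward(self, x, aux):
        # aux arrives unsplit; move it to the replica's device if needed
        aux_t = aux.tensor.to(x.device, non_blocking=True)
        return self.fc(x) + aux_t.sum()

model = nn.DataParallel(AuxModel().cuda())
x = torch.randn(32, 16, device='cuda')
aux = NoSplit(torch.randn(1024, 1024, device='cuda'))  # batch-agnostic
out = model(x, aux)
```

Note that the `.to(x.device)` call still copies the full tensor onto each GPU, so if the tensors are very large the memory concern remains; registering them as buffers with `self.register_buffer` is another option, but DataParallel also replicates buffers to every device on each forward pass.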