Avoid tensor split with nn.DataParallel

I am using a model whose forward pass involves some auxiliary tensors (which are batch-agnostic). I update them in the forward pass and return them. When I use nn.DataParallel, PyTorch splits them across 8 GPUs along dim 0. Is there any way I can tell PyTorch not to split those tensors across the batch dim? These auxiliary tensors are not part of the computational graph.
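
To illustrate the behavior being described, here is a minimal sketch (the model, tensor names, and shapes are made up, not taken from the original code): nn.DataParallel scatters every tensor argument of forward along dim 0, so a batch-agnostic tensor passed in this way gets chunked across the GPUs as well.

```python
import torch
import torch.nn as nn

class ToyModel(nn.Module):
    def forward(self, x, aux):
        # With multiple GPUs, DataParallel chunks both x and aux along dim 0,
        # so each replica only sees a slice of the batch-agnostic aux tensor.
        print(f"x: {tuple(x.shape)}, aux: {tuple(aux.shape)}, device: {x.device}")
        return x

model = nn.DataParallel(ToyModel()).cuda()
x = torch.randn(32, 16).cuda()      # batched input, splitting is expected here
aux = torch.randn(100, 64).cuda()   # batch-agnostic tensor, gets split anyway
out = model(x, aux)
```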


Hi @Vardaan_Pahuja, have you found a solution? I am having the same problem here.


I’m also looking for a solution, as I’m facing the same problem.
Have either of you been able to find one?

Are these auxiliary tensors large or could you repeat them?

Hi @ptrblck, they’re quite large.

Hi @tshrjn @ptrblck, it has been a long time, but did you find a proper way to handle this problem?

Let’s say I have an aux. tensor with shape (N, D) and 4 GPU devices. I’m thinking of repeating my aux. tensor so that I get a new (repeated) tensor with shape (4N, D); when this new tensor is scattered, each GPU process will end up with my original tensor (with shape (N, D), obviously).

I’m not sure if this is feasible (I’ll try it), and I’m wondering if there is an easier solution. Thanks.
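
If memory allows, a rough sketch of that repeat-and-scatter idea could look like the following (assuming the repeated dim 0 is divided evenly across the visible GPUs; the model and shapes are only illustrative):

```python
import torch
import torch.nn as nn

num_gpus = torch.cuda.device_count()  # e.g. 4

class AuxModel(nn.Module):
    def forward(self, x, aux):
        # After scattering, each replica sees aux with the original shape (N, D),
        # because the repeated (num_gpus * N, D) tensor is chunked evenly along dim 0.
        return x + aux.sum()

model = nn.DataParallel(AuxModel(), device_ids=list(range(num_gpus))).cuda()

x = torch.randn(8 * num_gpus, 16).cuda()  # regular batched input
aux = torch.randn(5, 16)                  # batch-agnostic tensor of shape (N, D)
aux_rep = aux.repeat(num_gpus, 1).cuda()  # shape (num_gpus * N, D)

out = model(x, aux_rep)                   # each replica receives one (N, D) copy
```

The trade-off is that the full repeated tensor has to be materialized on the source device before the scatter, which may matter if the aux tensors are large.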

Your aux tensors seem to include the batch dimension, so they should be chunked appropriately.
The original post dealt with aux tensors that are independent of the current batch size and would therefore have to be repeated.