Avoid tensor split with nn.DataParallel

I am using a model whose forward pass uses some auxiliary tensors that are independent of the batch; I return them by updating them in the forward pass. When I wrap the model in nn.DataParallel, PyTorch splits them across the 8 GPUs along dim 0. Is there any way to tell PyTorch not to split these tensors along the batch dimension? These auxiliary tensors are not part of the computational graph.
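For anyone landing here, a minimal sketch of the behavior being described, assuming the auxiliary tensor is passed in as a forward argument (the module name Net and all shapes are placeholders): every tensor argument handed to a DataParallel-wrapped module gets chunked along dim 0, including batch-independent ones.

```python
import torch
import torch.nn as nn

class Net(nn.Module):
    def forward(self, x, aux):
        # aux is batch-independent, but DataParallel still chunks every
        # tensor argument along dim 0, so each replica only sees a slice of it
        print(f"{x.device}: x {tuple(x.shape)}, aux {tuple(aux.shape)}")
        return x

model = nn.DataParallel(Net().cuda())       # replicated over all visible GPUs
x = torch.randn(32, 10, device="cuda")      # the actual batch
aux = torch.randn(100, 50, device="cuda")   # batch-independent auxiliary tensor
out = model(x, aux)                         # each replica sees only ~(100 / n_gpus, 50) of aux
```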


Hi @Vardaan_Pahuja, have you found a solution? I am having the same problem here.


I’m also looking for a solution and facing the same problem.
Have either of you managed to find one?


Are these auxiliary tensors large or could you repeat them?

Hi @ptrblck, they’re quite large.


Hi @tshrjn, @ptrblck, it has been a long time, but did you ever find a proper way to handle this problem?

Let’s say I have an aux tensor with shape (N, D) and 4 GPU devices. I’m thinking of repeating my aux tensor so that I get a new (repeated) tensor with shape (4N, D); when this new tensor is scattered, each GPU process should end up with my original tensor (with shape (N, D)).

I’m not sure if this is feasible (I’ll try it) and I’m wondering if there is an easier solution. A sketch of the idea is below. Thanks.
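Here is a minimal sketch of that repeat idea (Net, the shapes, and the variable names are made up for illustration). It assumes the batch is large enough to be split into exactly n_gpus chunks, so the aux chunks and batch chunks line up one-to-one:

```python
import torch
import torch.nn as nn

class Net(nn.Module):
    def forward(self, x, aux):
        # with the repeat trick each replica should see aux in its original (N, D) shape
        print(f"{x.device}: x {tuple(x.shape)}, aux {tuple(aux.shape)}")
        return x

n_gpus = torch.cuda.device_count()
model = nn.DataParallel(Net().cuda())

x = torch.randn(32, 10, device="cuda")     # ordinary batch, split as usual
aux = torch.randn(100, 50, device="cuda")  # batch-independent aux tensor of shape (N, D)

# Tile aux n_gpus times along dim 0; DataParallel's scatter then chunks it
# back into n_gpus pieces of shape (N, D), one full copy per replica.
aux_repeated = aux.repeat(n_gpus, 1)       # (n_gpus * N, D)
out = model(x, aux_repeated)
```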

Your aux tensors seem to include the batch dimension, so they should be chunked appropriately.
The original post dealt with aux tensors that are independent of the current batch size and would thus have to be repeated.

For those of you still looking for a solution: I found that if an input to the DataParallel module’s forward is a NumPy array, it won’t be split, and no additional copies of it are made in CPU memory; every replica receives the same array, and the transfer to the GPU then has to happen inside forward. Also, if you want to avoid the split for only part of your data, note that when the input is a dict, only the items with tensor values are split and the rest are passed through intact.
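A minimal sketch of what that looks like (Net and the shapes are placeholders): the dict’s tensor value is scattered along dim 0, while the NumPy array is handed to every replica as-is and only moved to the GPU inside forward.

```python
import numpy as np
import torch
import torch.nn as nn

class Net(nn.Module):
    def forward(self, batch):
        x = batch["x"]                               # tensor value: chunked along dim 0
        aux = batch["aux"]                           # NumPy array: same full array on every replica
        aux = torch.as_tensor(aux, device=x.device)  # transfer to this replica's GPU here
        return x + aux.sum()

model = nn.DataParallel(Net().cuda())
batch = {
    "x": torch.randn(32, 10, device="cuda"),              # split across GPUs
    "aux": np.random.randn(100, 50).astype(np.float32),   # passed intact, not split
}
out = model(batch)
```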
