Multiple inputs DataParallel

I’m interested in using nn.DataParallel with two inputs that behave differently:


where S would be split across the GPUs but A would not. I didn’t find anything about whether that is possible.

In nn.DataParallel your model will be replicated onto the specified GPUs.
The data will then be split in the batch dimension and pushed to each replica.
What would your use case be, if A shouldn’t be split? The replicas won’t be able to use it in any calculation, since A would still be on your “master” device.
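To illustrate the point above: nn.DataParallel scatters every tensor argument along dim 0, so both inputs would be chunked. A minimal sketch, assuming a toy two-input model (the model and names are illustrative, not from the original post):

```python
import torch
import torch.nn as nn

# Hypothetical model taking two inputs.
class MatMulModel(nn.Module):
    def forward(self, S, A):
        return torch.matmul(S, A)

# DataParallel replicates the module and scatters *both* S and A
# along dim 0 -- there is no built-in option to split only S
# while broadcasting A whole to each replica.
model = nn.DataParallel(MatMulModel())
```

On a machine without GPUs the wrapper simply falls through to the plain module, so the call `model(S, A)` behaves like an ordinary forward pass.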

I have to handle a huge matrix multiplication between S, which is (555, 3), and A, which is (3, 1500000).
Since the result (555, 1500000) cannot fit in memory, I need to split the multiplication in two to avoid running out of memory. I do that by splitting S in two along the first dimension. I then obtain two matrices (SA)1 and (SA)2, compute my loss on them, and merge the results back before calling backward.

In that case nn.DataParallel won’t help, since it would copy the data onto all GPUs.
You could try to push S1 and S2 onto different GPUs and calculate the loss there.
Here is a small pseudo-code for this use case:

S1, S2 = torch.chunk(S, 2, 0)
S1 = S1.to('cuda:0')
S2 = S2.to('cuda:1')
SA1 = torch.matmul(S1, A.to('cuda:0'))
SA2 = torch.matmul(S2, A.to('cuda:1'))
# calculate losses and backward
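Fleshing out that last comment, one way to compute the per-chunk losses and merge them before the backward pass looks like this (shapes shrunk for illustration, devices falling back to CPU when no GPU is present, and the MSE target is only a stand-in):

```python
import torch
import torch.nn.functional as F

# Fall back to CPU so the sketch also runs on a machine without GPUs.
dev0 = 'cuda:0' if torch.cuda.device_count() > 0 else 'cpu'
dev1 = 'cuda:1' if torch.cuda.device_count() > 1 else 'cpu'

S = torch.randn(8, 3, requires_grad=True)
A = torch.randn(3, 10)
target = torch.randn(8, 10)  # stand-in for the (555, 1500000) target

# Split S and the target the same way, push each half to its device.
S1, S2 = torch.chunk(S, 2, dim=0)
T1, T2 = torch.chunk(target, 2, dim=0)
SA1 = torch.matmul(S1.to(dev0), A.to(dev0))
SA2 = torch.matmul(S2.to(dev1), A.to(dev1))

# Compute the loss per chunk on its own device, then merge on one device.
# (Summing per-chunk means differs from the global mean by a constant
# factor when the chunks have equal size.)
loss1 = F.mse_loss(SA1, T1.to(dev0))
loss2 = F.mse_loss(SA2, T2.to(dev1))
loss = loss1 + loss2.to(dev0)

# Autograd routes the gradients back through both devices to S.
loss.backward()
```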

It works great. Thanks for your help!