While using nn.DataParallel only accessing one GPU

I think I have solved this problem.
For model = DataParallel(model), if you pass arguments into forward(), then according to the PyTorch documentation:

Arbitrary positional and keyword inputs are allowed to be passed into DataParallel EXCEPT Tensors. All tensors will be scattered on dim specified (default 0). Primitive types will be broadcasted, but all other types will be a shallow copy and can be corrupted if written to in the model’s forward pass.

which means that if an input argument is a tensor, it is split along dim=0 (the batch dimension). For other types such as Python list/dict/str, DataParallel.forward() automatically copies the argument to N replicas (N being the number of GPUs).
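For illustration, here is a minimal sketch of that documented behaviour for a top-level tensor argument; the toy nn.Linear model, batch size, and feature sizes are made up for this example:

```python
import torch
import torch.nn as nn

# Hypothetical toy model; your real model and forward() signature will differ.
model = nn.Linear(16, 4)
device = "cuda" if torch.cuda.is_available() else "cpu"
if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model)  # replicates the module across visible GPUs
model = model.to(device)

x = torch.randn(32, 16, device=device)  # batch of 32 along dim 0
# x is a top-level tensor argument, so DataParallel scatters it on dim 0:
# with 2 GPUs, each replica's forward() sees a (16, 16) chunk.
out = model(x)
print(out.shape)  # torch.Size([32, 4]) after the outputs are gathered
```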
The key point is that if you pass an argument like
[torch.tensor(...)]
or
{"example": torch.tensor(...)}
then even though these are Python lists/dicts, DataParallel.forward() is not able to handle tensors nested inside them correctly (and it won't raise an error). So the fix is simply to convert those arguments (and all their elements) to plain Python types.
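As a concrete sketch of this fix under stated assumptions: ToyModel and the lengths argument are hypothetical names made up for illustration, and the commented-out call shows the problematic nested-tensor pattern described above:

```python
import torch
import torch.nn as nn

class ToyModel(nn.Module):
    # Hypothetical module whose forward() takes an extra non-tensor argument.
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(16, 4)

    def forward(self, x, lengths=None):
        # lengths is expected to be a plain Python list here, which
        # DataParallel simply copies unchanged to every replica.
        return self.fc(x)

device = "cuda" if torch.cuda.is_available() else "cpu"
model = ToyModel()
if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model)
model = model.to(device)

x = torch.randn(32, 16, device=device)
lengths_tensor = torch.randint(1, 16, (32,))

# Problematic, per the explanation above: a tensor hidden inside a dict/list
# is not handled like a top-level tensor argument.
# out = model(x, lengths={"example": lengths_tensor})

# The fix described above: convert the nested elements to plain Python types.
out = model(x, lengths=lengths_tensor.tolist())
print(out.shape)  # torch.Size([32, 4])
```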
