model.cuda() in PyTorch

If I call model.cuda() in PyTorch, where model is a subclass of nn.Module, and I have four GPUs, how will it use the GPUs, and how do I know which GPUs are being used?

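model.cuda() moves all of the model's parameters and buffers to the current default CUDA device, which is GPU 0 unless you changed it (e.g. with torch.cuda.set_device or the CUDA_VISIBLE_DEVICES environment variable). It does not spread the model across your four GPUs; only that single device is used.
If you print a parameter’s device, you should see which GPU is used: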

print(model.fc.weight.device)  # assumes the model has an `fc` layer
> cuda:0
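A minimal sketch that checks every parameter and buffer at once instead of a single weight. The toy `Net` class and the `module_devices` helper are hypothetical, made up for illustration; `torch.cuda.device_count()` and `torch.cuda.current_device()` are real PyTorch calls. To target a specific GPU instead of the default one, you can pass an index, e.g. `model.cuda(1)` or `model.to("cuda:1")`.

```python
import torch
import torch.nn as nn


def module_devices(module: nn.Module) -> set:
    """Hypothetical helper: the set of devices holding the module's parameters and buffers."""
    tensors = list(module.parameters()) + list(module.buffers())
    return {t.device for t in tensors}


class Net(nn.Module):
    """Hypothetical toy model with an `fc` layer, standing in for the model in the question."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(10, 2)
        self.bn = nn.BatchNorm1d(2)  # BatchNorm also has buffers (running_mean, running_var)

    def forward(self, x):
        return self.bn(self.fc(x))


model = Net()
if torch.cuda.is_available():
    print("visible GPUs:", torch.cuda.device_count())
    print("current device:", torch.cuda.current_device())
    model.cuda()  # no argument: everything goes to the current device only

# A single device in the set means the whole model lives on one GPU (or the CPU).
print(module_devices(model))
```

On a four-GPU machine this should still report just one CUDA device after `model.cuda()`, which is the point of the answer above.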

You can also run nvidia-smi to see which GPUs are being used; it lists memory usage and the processes running on each GPU.
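A small sketch for checking the same thing from inside the Python process. `gpu_memory_report` is a hypothetical helper; `torch.cuda.get_device_name` and `torch.cuda.memory_allocated` are real PyTorch functions. Note that `memory_allocated` only counts memory held by tensors through PyTorch's caching allocator, so nvidia-smi will usually show a larger number (it also includes the CUDA context and cached blocks). Also, PyTorch's `cuda:i` indices are relative to `CUDA_VISIBLE_DEVICES`, so they may not match the GPU numbers nvidia-smi prints.

```python
import torch


def gpu_memory_report() -> list:
    """Hypothetical helper: one line per visible GPU with the tensor memory this process holds there."""
    lines = []
    for i in range(torch.cuda.device_count()):
        name = torch.cuda.get_device_name(i)
        mib = torch.cuda.memory_allocated(i) / 1024**2
        lines.append(f"cuda:{i} ({name}): {mib:.1f} MiB allocated by PyTorch tensors")
    return lines


# On a machine without GPUs the list is simply empty.
for line in gpu_memory_report():
    print(line)
```

After a plain `model.cuda()`, only `cuda:0` (or whichever device is current) should show non-zero memory.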

If you would like to use data parallelism (copy the model onto all GPUs and send a chunk of each batch to each GPU), have a look at this tutorial.
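A minimal sketch of that idea using `nn.DataParallel`, assuming a single machine; the `nn.Linear` model and the batch shape are placeholders for your own. `DataParallel` replicates the module on each visible GPU, splits the input along dimension 0, and gathers the outputs back on `device_ids[0]` (`cuda:0` by default), which is also where the original parameters must live. The PyTorch docs recommend `DistributedDataParallel` over `DataParallel` for better performance, even on one node.

```python
import torch
import torch.nn as nn

model = nn.Linear(10, 2)  # placeholder for your own model

if torch.cuda.device_count() > 1:
    # Replicate the module on every visible GPU; each replica gets a slice of the batch.
    model = nn.DataParallel(model)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)  # parameters must sit on device_ids[0] (cuda:0 by default)

x = torch.randn(64, 10, device=device)
out = model(x)  # with DataParallel, outputs are gathered back onto cuda:0
print(out.shape)  # torch.Size([64, 2])
```

The same script runs unchanged on one GPU or on the CPU; the wrapper is only added when more than one GPU is visible.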
