Are there any good practices for parallelizing model prediction (inference), not training? If we have a pretrained .pth model, how do we run it on a few GPUs? Is there a doc covering this specific topic?
You should have a look at the nn.DataParallel
and nn.parallel.DistributedDataParallel
docs. Both are fairly easy to work with: they basically split your input along the batch dimension across the GPU ids that you pass at initialization.
Note that some model architectures cannot be parallelized across devices.
Also, if you are just running inference, you may not see any benefit from multi-GPU parallelization, or even from using a GPU at all. If you are experiencing performance issues, try running the model after a call to model.eval() and inside a torch.no_grad() block.
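A minimal sketch of that inference setup, using a small stand-in network (the model and shapes here are hypothetical; substitute your own pretrained model):

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for a pretrained model.
model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 2))

# eval() switches off dropout and makes batch norm use its running
# statistics -- both matter for correct predictions.
model.eval()

# no_grad() skips autograd bookkeeping, cutting memory use and
# speeding up pure inference.
with torch.no_grad():
    batch = torch.randn(4, 8)       # 4 samples, 8 features each
    predictions = model(batch)      # shape: (4, 2)
```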
> they basically split your input along the batch dimension across the GPU ids that you pass at initialization.
So this code is enough:
model = MyModel(options)  # your pretrained model, not torch.nn.Module itself
...
if torch.cuda.is_available():
    ids = list(range(torch.cuda.device_count()))
    # Note: CUDA_VISIBLE_DEVICES only takes effect if set before CUDA is initialized.
    os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(str(i) for i in ids)
    model = torch.nn.DataParallel(model, device_ids=ids).cuda()
    print("Using", len(ids), "GPUs!")
...
model_results = model(input)
> Note that some model architectures cannot be parallelized across devices.
What does it depend on?
raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for DataParallel:
Missing key(s) in state_dict: "module.conv_first.weight",
really long list here...
Should we prepare the code somehow before wrapping the model in DataParallel?
nn.DataParallel adds a .module attribute to the model, so you might see these key errors when trying to load a state_dict from a plain PyTorch model. You could either add/remove the .module keys manually, or store and load the state_dict using the plain model, without the nn.DataParallel wrapper.
To store the state_dict you would use torch.save(model.module.state_dict(), path), while to load it you could simply load the state_dict before wrapping the model in nn.DataParallel.
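Both fixes can be sketched in a few lines. The model below is a hypothetical stand-in for your real architecture; the key-stripping dict comprehension is one common way to remove the module. prefix manually:

```python
import torch
import torch.nn as nn

# Hypothetical plain model standing in for your real architecture.
plain = nn.Sequential(nn.Linear(4, 4))

# Wrapping in DataParallel prefixes every state_dict key with "module."
wrapped = nn.DataParallel(plain)
dp_state = wrapped.state_dict()   # keys like "module.0.weight"

# Option 1: strip the prefix so the checkpoint loads into a plain model.
stripped = {k.removeprefix("module."): v for k, v in dp_state.items()}
fresh = nn.Sequential(nn.Linear(4, 4))
fresh.load_state_dict(stripped)   # no missing-key errors now

# Option 2 (preferred): save without the prefix in the first place,
# then load into the plain model *before* wrapping it.
#   torch.save(wrapped.module.state_dict(), path)
#   fresh.load_state_dict(torch.load(path))
#   fresh = nn.DataParallel(fresh)
```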