How to use a predict function when the model was trained with DataParallel?

Hello,

What is the correct way to get predictions when a model is trained with DataParallel?

I’ve trained a model which uses the following to make use of multiple GPUs:
model = nn.DataParallel(model)

I save the model with,
torch.save(model, model_home + 'best_model.pth')

Load and run predictions:
best_model = torch.load(model_home + 'best_model.pth')

predictions = best_model.predict(x_tensor)

I run into the below error:
ModuleAttributeError: 'DataParallel' object has no attribute 'predict'

Thanks

In case your original model provides a predict method, you could use best_model.module.predict.
nn.DataParallel will use the forward method in its data parallel approach and will ignore your custom methods. If you want to use predict in the same data parallel way, you would have to call it from your forward method instead.
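A minimal sketch of this, assuming a hypothetical model whose custom predict wraps forward with a softmax (the class name and predict body are illustrative, not from the original post):

```python
import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        # DataParallel only knows about forward
        return self.fc(x)

    def predict(self, x):
        # hypothetical custom method, e.g. returning probabilities
        with torch.no_grad():
            return torch.softmax(self.forward(x), dim=1)

model = nn.DataParallel(Net())
x = torch.randn(3, 4)

out = model(x)                   # dispatched through DataParallel's forward
probs = model.module.predict(x)  # custom method lives on the wrapped module
```

Note that predict here is only reachable via model.module, since DataParallel itself exposes no such attribute (which is exactly the error above).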


I tried changing it this way, calling the model directly instead of predict, since DataParallel was used while training.

predictions = best_model(x_tensor)

It seems to be predicting, but I'm not sure if this is the right way?

predictions = best_model(x_tensor) would call into __call__ and then into the forward method.
I don’t know how predict is defined and what the difference between it and forward would be.
In case both methods are doing the same, your approach should be fine.

Got it. thank you.
My understanding so far: when training without DataParallel, I can use the predict function; otherwise I should use model(tensor), which goes through __call__ (and then the forward method) to get predictions.

This might work, but note that it’s not a general rule since your predict method is a custom function which is not a standard method in nn.Module.
The usual and required workflow is to override the forward method. You should thus compare your predict method with forward and check what the difference is.
nn.DataParallel doesn't have any knowledge about custom methods (such as predict) and thus will use the standard forward method for its data parallel approach.
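A common way to sidestep the wrapper entirely is to save the underlying module's state_dict and load it into a plain model for inference; the sketch below assumes a hypothetical Net class and file name (not from the original post):

```python
import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        # keep all inference logic in forward so DataParallel can use it
        return torch.softmax(self.fc(x), dim=1)

model = nn.DataParallel(Net())
# ... training would happen here ...

# saving model.module's state_dict avoids pickling the DataParallel wrapper
torch.save(model.module.state_dict(), 'best_model.pth')

# at inference time, load into a plain (unwrapped) model
plain = Net()
plain.load_state_dict(torch.load('best_model.pth'))
plain.eval()
with torch.no_grad():
    predictions = plain(torch.randn(3, 4))
```

This also means the checkpoint can be loaded on a single-GPU or CPU-only machine without any DataParallel-related key-name mismatches.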

Got it. That makes sense. Will keep in mind. Thank you