How to parallelise a pytorch model on GPU?

Just for the sake of completeness:
Based on a short chat, it seems this code base is being used.
Currently the Trainer class provides convenient methods to train the model. However, skimming through the code, it looks like some refactoring would be needed before it can be used with nn.DataParallel, e.g. since the optimizer seems to be embedded in the Trainer class. A rough sketch of one way to restructure this is shown below.
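Here is a minimal sketch of what I mean, assuming a hypothetical `Trainer` that accepts a ready-made model and optimizer (the class name and its signature are assumptions, not taken from the actual code base). The idea is to wrap the model in nn.DataParallel first and build the optimizer from the wrapped model's parameters outside the trainer, rather than letting the trainer create the optimizer from the unwrapped model internally:

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for the real model; the actual code base will differ.
model = nn.Sequential(nn.Linear(10, 10), nn.ReLU(), nn.Linear(10, 2))

if torch.cuda.is_available():
    # Replicates the model across all visible GPUs during forward passes.
    model = nn.DataParallel(model).cuda()

# Build the optimizer from the (wrapped) model's parameters outside the trainer,
# so the trainer only consumes a ready-made model/optimizer pair.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# trainer = Trainer(model=model, optimizer=optimizer)  # hypothetical Trainer API
# trainer.train()
```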

I’m also not sure how these lines of code would be handled by nn.DataParallel, since no GPU id is passed to the cuda() calls; without an id, .cuda() targets the current default device. It’s just a guess at this point, but I think this might also be causing the OOM issue here.
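To illustrate the concern, here is a small, purely illustrative module (not taken from the code base) showing the problematic pattern and a device-agnostic alternative. A bare .cuda() inside forward allocates on the default device (usually GPU 0), so every nn.DataParallel replica would pile its allocations onto that single GPU; using the input's device keeps each replica on its own GPU:

```python
import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(10, 10)

    def forward(self, x):
        # Problematic pattern: .cuda() without a device id always targets the
        # current default device (usually GPU 0), so all DataParallel replicas
        # allocate this tensor there and can run that GPU out of memory.
        # hidden = torch.zeros(x.size(0), 10).cuda()

        # Device-agnostic alternative: allocate on the same device as the input,
        # which is the replica's own GPU inside nn.DataParallel.
        hidden = torch.zeros(x.size(0), 10, device=x.device)
        return self.fc(x) + hidden
```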
