I updated the script below to reflect the interaction between different models.
I am implementing a model with multiple types of forward functions.
An example looks like this:
```python
import torch.nn as nn
from torch.nn.parallel import DataParallel as DP, DistributedDataParallel as DDP

class Model1(nn.Module):
    def __init__(self):
        ...

    def forward(self, x):
        ...

    def forward2(self, x, y):
        ...

    def forward3(self, z, w):
        ...

# model1 = DP(Model1(), ...)
model1 = DDP(Model1(), ...)
model2 = DDP(Model2(), ...)
...

out1 = model1(x)                     # parallelized
out2 = model1.module.forward2(x, y)  # not parallelized in DP and no communication in DDP
z, w = model2(y)                     # model2 is also used somewhere else
out3 = model1.module.forward3(z, w)  # not parallelized in DP and no communication in DDP
```
These forward functions serve different purposes, and all of them are necessary for training and inference.
However, DP- or DDP-wrapped models only parallelize the default forward, which is invoked implicitly when the wrapped model is called; the other functions are not parallelized. How can I parallelize the other forward functions as well?
Here are several candidates I could think of so far:
1. Define one big forward function with an option flag so that it can dispatch to the other functions internally (see the first sketch after this list). I think this would work, but it would lead to a lot of if/else branches, and the input-argument parsing would get cluttered.
2. Register a function with DDP so that it recognizes the other functions. I looked into some hooks but did not find a way to make this work for my case.
3. Create a custom myDDP class that inherits from DDP and implements the other functions the same way as forward (see the second sketch after this list). This might be possible, but I would need to update myDDP every time I define a new function.
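
To make option 1 concrete, here is a rough sketch of what I mean; the `mode` flag and the helper-method names are just illustrative:

```python
import torch.nn as nn

class Model1(nn.Module):
    def forward(self, *args, mode="default"):
        # Single entry point: DP/DDP only intercepts forward, so every
        # code path is dispatched from here via the `mode` flag.
        if mode == "default":
            (x,) = args
            return self._forward_default(x)
        elif mode == "pair":
            x, y = args
            return self._forward2(x, y)
        elif mode == "other":
            z, w = args
            return self._forward3(z, w)
        raise ValueError(f"unknown mode: {mode}")

    def _forward_default(self, x):
        ...

    def _forward2(self, x, y):
        ...

    def _forward3(self, z, w):
        ...

# out1 = model1(x)                   # parallelized, mode="default"
# out2 = model1(x, y, mode="pair")   # now parallelized as well
# out3 = model1(z, w, mode="other")
```

This keeps every call going through `__call__`, so the DP/DDP machinery runs, but it is exactly the if/else and argument-parsing clutter I would like to avoid.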
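And here is a rough sketch of option 3. To avoid editing the subclass for every new function, it temporarily re-points the wrapped module's forward at the requested method and then goes through DDP's own forward, so the gradient-sync hooks still wrap the call. The class name and the generic `run` entry point are mine, and this is a workaround rather than a supported API:

```python
from torch.nn.parallel import DistributedDataParallel as DDP

class MyDDP(DDP):
    def run(self, method_name, *args, **kwargs):
        # Temporarily point the wrapped module's forward at the requested
        # method, then call DDP.forward so its pre/post communication
        # logic wraps the call as usual.
        original_forward = self.module.forward
        self.module.forward = getattr(self.module, method_name)
        try:
            return super().forward(*args, **kwargs)
        finally:
            self.module.forward = original_forward

# Usage sketch:
# model1 = MyDDP(Model1(), device_ids=[local_rank])
# out2 = model1.run("forward2", x, y)   # goes through the DDP machinery
# out3 = model1.run("forward3", z, w)
```

I am not sure how robust this is, though, since it mutates the wrapped module during the call.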
If option 2 is possible, that would be the most elegant solution, and option 1 could be a safe but dirty fallback.
Are there any suggestions?
P.S. I checked this thread, but it does not apply to my case. My actual code is more complex, and a more fundamental solution is necessary.