I updated the script below to reflect the interaction between different models.
Hi,
I am implementing a model with multiple types of forward functions. An example looks like this:
```python
import torch.nn as nn
from torch.nn.parallel import DataParallel as DP, DistributedDataParallel as DDP

class Model1(nn.Module):
    def __init__(self):
        ...

    def forward(self, x):
        ...

    def forward2(self, x, y):
        ...

    def forward3(self, z, w):
        ...

# model1 = DP(Model1(), ...)
model1 = DDP(Model1(), ...)
model2 = DDP(Model2(), ...)
...

out1 = model1(x)                     # parallelized
out2 = model1.module.forward2(x, y)  # not parallelized in DP, and
z, w = model2(y)                     # model2 is also used somewhere else
out3 = model1.module.forward3(z, w)  # no communication in DDP
```
These forward functions are there for different purposes and are all necessary for training and inference.
However, a DP- or DDP-wrapped model only parallelizes the default `forward`, which is invoked implicitly through the wrapper's `__call__`. How can I also parallelize `forward2` and `forward3`?
Here are several candidates that I could think of for now.
1. Define one big `forward` function with a mode flag that dispatches to the other functions. I think this would work, but it would fill `forward` with if/else branches, and the input-argument parsing would get cluttered.
2. Register the extra functions with DDP so that it recognizes them. I looked into some hooks but did not find one that covers my case.
3. Create a custom `myDDP` class that inherits from DDP and implements the other functions the same way as `forward`. This might be possible, but I would need to update `myDDP` every time I define a new function.
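For candidate 1, a minimal sketch of the dispatch pattern might look like the following (the layer sizes, mode names, and `MultiHeadModel` itself are made up for illustration). Because every path goes through the single `forward`, a DP/DDP wrapper would parallelize all of them via its `__call__`:

```python
import torch
import torch.nn as nn

class MultiHeadModel(nn.Module):
    """Hypothetical model: one forward that dispatches on a mode flag."""

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(4, 4)
        self.fc2 = nn.Linear(8, 4)

    def forward(self, *inputs, mode="main"):
        if mode == "main":          # the original forward(x)
            (x,) = inputs
            return self.fc1(x)
        if mode == "pair":          # stands in for forward2(x, y)
            x, y = inputs
            return self.fc2(torch.cat([x, y], dim=-1))
        raise ValueError(f"unknown mode: {mode!r}")

model = MultiHeadModel()  # would be DP(...)/DDP(...) in real training
out_main = model(torch.randn(2, 4))                                  # default path
out_pair = model(torch.randn(2, 4), torch.randn(2, 4), mode="pair")  # dispatched path
```

The downside is exactly what is described above: every new function adds another branch and another input-unpacking case.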
If 2 is possible, that would be the most elegant solution, and 1 could be a safe but dirty fallback.
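To make candidate 3 concrete, here is a rough sketch for the DP case only (`MyDP`, `TwoForwardModel`, and the layer sizes are hypothetical). It mirrors the scatter / replicate / parallel_apply / gather structure of `DataParallel.forward` itself, which is also why every extra method needs its own copy of this boilerplate; a DDP version would additionally have to deal with gradient synchronization, which is tied to the default `forward`:

```python
import torch
import torch.nn as nn
from torch.nn.parallel import DataParallel

class MyDP(DataParallel):
    """Candidate 3 sketched for DP: parallelize forward2 by reusing
    DataParallel's own scatter/replicate/parallel_apply/gather helpers."""

    def forward2(self, *inputs, **kwargs):
        if not self.device_ids:  # CPU-only fallback, as in DataParallel.forward
            return self.module.forward2(*inputs, **kwargs)
        inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
        if len(self.device_ids) == 1:
            return self.module.forward2(*inputs[0], **kwargs[0])
        replicas = self.replicate(self.module, self.device_ids[: len(inputs)])
        outputs = self.parallel_apply(
            [replica.forward2 for replica in replicas], inputs, kwargs)
        return self.gather(outputs, self.output_device)

class TwoForwardModel(nn.Module):
    """Hypothetical stand-in for Model1 with a second forward function."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(8, 4)

    def forward(self, x):
        return self.fc(x)

    def forward2(self, x, y):
        return self.fc(torch.cat([x, y], dim=-1))

m = TwoForwardModel()
if torch.cuda.is_available():
    m = m.cuda()  # DataParallel expects the module on device_ids[0]
model = MyDP(m)
out = model.forward2(torch.randn(2, 4), torch.randn(2, 4))
```

This is only a sketch under the assumption that `DataParallel`'s helper methods behave as they do in its own `forward`; it also confirms the maintenance cost noted above, since each new function needs this block repeated.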
Are there any suggestions?
P.S. I checked this thread, but it does not apply to me: my actual code is more complex, and a more fundamental solution is necessary.