Multiple forward functions in DP and DDP

I have updated the script below to reflect the interaction between the different models.


I am implementing a model with multiple types of forward functions.
An example looks like this:

import torch.nn as nn
from torch.nn.parallel import DataParallel as DP
from torch.nn.parallel import DistributedDataParallel as DDP

class Model1(nn.Module):
    def __init__(self):
        super().__init__()
        ...  # layers omitted

    def forward(self, x):
        ...

    def forward2(self, x, y):
        ...

    def forward3(self, z, w):
        ...

# model1 = DP(Model1(), ...)
model1 = DDP(Model1(), ...)
model2 = DDP(Model2(), ...)            # Model2 is defined elsewhere

out1 = model1(x)                       # parallelized
out2 = model1.module.forward2(x, y)    # bypasses the wrapper: not parallelized in DP

z, w = model2(y)                       # model2 is also used somewhere else
out3 = model1.module.forward3(z, w)    # bypasses the wrapper: no communication in DDP

These forward functions serve different purposes and are all necessary for training and inference.
However, a DP- or DDP-wrapped model only parallelizes the default forward, i.e. the one invoked by calling the wrapper directly without naming it.
How can I parallelize forward2 and forward3 as well?

Here are the candidates I can think of so far.

  1. define a single large forward function with a mode flag so that it can dispatch to the other functions.

I think this would work, but it would require a lot of if/else branches in the function.
The argument parsing would also become cluttered.

  2. register the other functions with DDP so that it recognizes them.

I looked into the available hooks but did not find one that fits my case.

  3. create a custom myDDP class that inherits from DDP and implements the other functions the same way as forward.

This might be possible, but I would need to update myDDP every time I define a new function.
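Option 1 does not have to mean a long if/else chain: the requested function can be looked up by name. Below is a minimal sketch, assuming a hypothetical wrapper `MethodDispatcher` and placeholder bodies for `Model1` (the layer sizes are arbitrary). The wrapper's `forward` takes the method name as its first argument and passes the remaining arguments through, so only `forward` is ever called on the DP/DDP wrapper, the model class itself stays unchanged, and newly added methods need no wrapper changes (the drawback noted for option 3).

```python
import torch
import torch.nn as nn


class MethodDispatcher(nn.Module):
    """Hypothetical wrapper that routes forward() to any method of the inner model.

    Wrap this object (not the model) in DP/DDP, so every call goes through
    the wrapper's forward() and receives DP/DDP's usual handling.
    """

    def __init__(self, model: nn.Module):
        super().__init__()
        # Registered as a submodule, so DP/DDP see all of its parameters.
        self.model = model

    def forward(self, method: str, *args, **kwargs):
        fn = getattr(self.model, method, None)
        if fn is None or not callable(fn):
            raise AttributeError(f"{type(self.model).__name__} has no method {method!r}")
        return fn(*args, **kwargs)


class Model1(nn.Module):
    # Placeholder model; the method bodies are illustrative only.
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)

    def forward(self, x):
        return self.linear(x)

    def forward2(self, x, y):
        return self.linear(x) + y


wrapped = MethodDispatcher(Model1())
x, y = torch.randn(2, 4), torch.randn(2, 4)
out1 = wrapped("forward", x)      # default path
out2 = wrapped("forward2", x, y)  # any other method, selected by name
```

With DDP the calls would become `ddp("forward2", x, y)` with `ddp = DDP(MethodDispatcher(Model1()), ...)`. Any restrictions DDP places on the sequence of forward and backward calls still apply, since each dispatched call is a regular DDP forward.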

If 2 is possible, that would be the most elegant solution; 1 would be the safe but dirty way.
Are there any suggestions?

P.S. I checked this thread, but it does not apply to my case.
My actual code is more complex, so a more fundamental solution is needed.

How do you train your model locally? Do you call your forward functions in a specific order? Do they train different parts of your model?

Hi @cbalioglu

Thank you for looking into this issue:)

  1. training

I am training an invertible network with forward and inverse paths, optionally jointly with other models.
The invertible network has several other functions, too.

  2. function call order

I usually call the functions in the order below, but not always:


model2.func1    # sometimes omitted for different purposes

model1.func3    # also called outside the training loop

I use the intermediate outputs to compute the loss functions.
I sometimes call func3 independently of the flow above.

  3. training parts per function

The trainable weights are partially shared across the functions, and some functions use no weights at all.
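If all functions are routed through one DDP forward (option 1), a given step may leave some parameters without a gradient, because the functions touch different subsets of the weights. A minimal sketch of the relevant DDP setting, assuming a hypothetical two-layer `Model1` and a single-process CPU `gloo` group so it runs without GPUs: the `find_unused_parameters=True` constructor argument makes DDP detect parameters that did not take part in the forward pass instead of waiting for their gradients, which otherwise usually leads to an error on the next iteration.

```python
import os
import tempfile

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP


class Model1(nn.Module):
    # Hypothetical model: "shared" is used by every mode, "head2" only by "forward2".
    def __init__(self):
        super().__init__()
        self.shared = nn.Linear(4, 4)
        self.head2 = nn.Linear(4, 4)

    def forward(self, x, mode="forward1"):
        h = self.shared(x)
        return h if mode == "forward1" else self.head2(h)


def run(modes):
    # Single-process CPU group, only so the sketch runs without GPUs or a launcher.
    init_file = os.path.join(tempfile.mkdtemp(), "pg_init")
    dist.init_process_group("gloo", init_method=f"file://{init_file}", rank=0, world_size=1)
    try:
        # head2 gets no gradient in "forward1" steps, so let DDP look for
        # parameters that were not used in the forward pass.
        ddp = DDP(Model1(), find_unused_parameters=True)
        opt = torch.optim.SGD(ddp.parameters(), lr=0.1)
        losses = []
        for mode in modes:
            loss = ddp(torch.randn(8, 4), mode=mode).pow(2).mean()
            opt.zero_grad()
            loss.backward()
            opt.step()
            losses.append(loss.item())
        return losses
    finally:
        dist.destroy_process_group()


losses = run(["forward1", "forward2", "forward1"])
```

The unused-parameter search adds some overhead per iteration, so it is only worth enabling when the set of used weights really changes between calls.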

Hey @seungjun, yes, I can confirm that your first option should work, since DDP only calls the default forward() function.

However, there is a caveat: DDP has internal state that requires forward and backward calls to alternate. So if you call something like forward, forward, backward, DDP is likely to hang or crash.
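One way to keep that alternation while still using several functions per training step is to run the whole per-step flow inside a single forward call. A minimal sketch, assuming placeholder names (`Model1`, `forward1`, `forward2`) and a single-process CPU `gloo` group so it runs without GPUs; a real job would start one process per GPU (for example with torchrun).

```python
import os
import tempfile

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP


class Model1(nn.Module):
    # Placeholder: forward() runs forward1 and then forward2 inside ONE call,
    # so DDP sees exactly one forward per backward.
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)

    def forward(self, x, y):
        h = self.forward1(x)
        return h, self.forward2(h, y)

    def forward1(self, x):
        return self.linear(x)

    def forward2(self, h, y):
        return self.linear(h) + y


def train(steps=3):
    # Single-process CPU group, only so the sketch runs anywhere.
    init_file = os.path.join(tempfile.mkdtemp(), "pg_init")
    dist.init_process_group("gloo", init_method=f"file://{init_file}", rank=0, world_size=1)
    try:
        torch.manual_seed(0)
        ddp = DDP(Model1())
        opt = torch.optim.SGD(ddp.parameters(), lr=0.1)
        losses = []
        for _ in range(steps):
            x, y = torch.randn(8, 4), torch.randn(8, 4)
            h, out = ddp(x, y)                            # one DDP forward ...
            loss = h.pow(2).mean() + out.pow(2).mean()    # losses on intermediate outputs
            opt.zero_grad()
            loss.backward()                               # ... then exactly one backward
            opt.step()
            losses.append(loss.item())
        return losses
    finally:
        dist.destroy_process_group()


losses = train()
```

Reusing the same parameters several times inside that single forward is fine; the pattern the caveat warns about is calling the DDP wrapper more than once before a backward.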

Hi @mrshenli,

Thank you for confirming the 1st option and pointing to the related part of the DDP source code.
I checked the DDP implementation and it seems that option 1 is the only possible way for now.

forward is the only function for which DDP supports safe parallelization, and going for option 3 would be an adventure.

By the way, I’m not sure I can avoid call patterns like the forward, forward, backward you mentioned.
Thank you very much, and I will post here when I come up with a nice solution.