[DataParallel] How to split the forward function into multiple steps

Hello,

I am implementing a custom network that first computes a latent encoding and then runs several forward passes that reuse it. Will this kind of network work with nn.DataParallel? I assume nn.DataParallel only parallelizes self.forward. Does it handle custom methods too?

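import torch.nn as nn

class MyNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        # Placeholder sizes; `encoder` stands in for my encode() function
        self.encoder = nn.Linear(16, 8)
        self.fc = nn.Linear(16, 8)

    def preprocess(self, a):
        # Compute the latent encoding once and cache it on the module
        self.encoding = self.encoder(a)

    def forward(self, b):
        return self.encoding + self.fc(b)

model = nn.DataParallel(MyNetwork())
# Pre-encoding
model.module.preprocess(a)
# Parallel data passes that reuse the cached encoding
output1 = model(b1)
output2 = model(b2)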

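As you've already assumed, nn.DataParallel will not use the passed GPUs in model.module.preprocess. Only a call to the wrapper itself, model(...), goes through the parallel path: that is where the model is cloned onto each device, the input batch is split, each clone's forward method is run, and the outputs are gathered.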
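A rough, simplified sketch of those steps, built from the helpers in torch.nn.parallel, may make it clearer why only forward is parallelized. The function name manual_data_parallel is hypothetical, and the real wrapper additionally handles keyword arguments, device checks and other details; this only illustrates the replicate/scatter/apply/gather sequence.

```python
import torch
import torch.nn as nn
from torch.nn.parallel import gather, parallel_apply, replicate, scatter


def manual_data_parallel(module, inputs, device_ids):
    """Hypothetical helper: roughly what nn.DataParallel does on each model(...) call."""
    if not device_ids:
        # No GPUs: just call the module's forward directly
        return module(inputs)
    # 1. Split the input batch along dim 0, one chunk per device
    chunks = scatter(inputs, device_ids)
    # 2. Copy the module's parameters and buffers onto every device that got a chunk
    replicas = replicate(module, device_ids[:len(chunks)])
    # 3. Call forward() on each replica with its own chunk, in parallel threads
    outputs = parallel_apply(replicas, chunks)
    # 4. Concatenate the per-device outputs on the first device
    return gather(outputs, device_ids[0])


model = nn.Linear(16, 8)
device_ids = list(range(torch.cuda.device_count()))
if device_ids:
    # replicate() broadcasts from the GPU copy, so the module must live on the first device
    model = model.cuda(device_ids[0])
x = torch.randn(32, 16)
out = manual_data_parallel(model, x, device_ids)
```

Note that a method such as preprocess never appears in these steps: only forward is invoked on the replicas, and an attribute cached on the original module is not split across devices.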
If you want to use all devices, call your other functions (such as preprocess) from inside the forward method.
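A minimal sketch of that restructuring, assuming `a` and `b` share the same batch dimension so that nn.DataParallel can split both along dim 0. The layer sizes and the `encoder` attribute are placeholders for whatever the real encode() does.

```python
import torch
import torch.nn as nn


class MyNetwork(nn.Module):
    def __init__(self, in_features=16, out_features=8):
        super().__init__()
        # Placeholder layers; `encoder` stands in for the original encode()
        self.encoder = nn.Linear(in_features, out_features)
        self.fc = nn.Linear(in_features, out_features)

    def preprocess(self, a):
        # Return the encoding instead of caching it on self
        return self.encoder(a)

    def forward(self, a, b):
        # Both a and b are scattered along dim 0, so each replica
        # encodes its own chunk of `a` on its own device
        encoding = self.preprocess(a)
        return encoding + self.fc(b)


model = nn.DataParallel(MyNetwork())
if torch.cuda.is_available():
    # DataParallel expects the wrapped module on the first device
    model = model.cuda()
a = torch.randn(32, 16)
b = torch.randn(32, 16)
output = model(a, b)
```

The trade-off is that the encoding is recomputed on every call rather than once up front. If the encoding is a single shared tensor rather than one row per sample, it should not be passed in a way that gets split along dim 0.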