Instead of wrapping block2 in DataParallel ahead of time (in the constructor), should I wrap it inside the forward function at run time, like this:
class DataParallelModel(nn.Module):
    """Three stacked linear layers; only the middle one may run data-parallel.

    ``block2`` is stored as a plain ``nn.Linear``. The ``nn.DataParallel``
    wrapper is created inside ``forward`` and only when more than one GPU is
    requested, so the module's registered submodules stay unwrapped.
    """

    def __init__(self):
        super().__init__()
        self.block1 = nn.Linear(10, 20)
        # Deliberately NOT wrapped in DataParallel here; forward() decides
        # per call whether to wrap it.
        self.block2 = nn.Linear(20, 20)
        self.block3 = nn.Linear(20, 20)

    def forward(self, x, Num_GPUs):
        """Run ``block1 -> block2 -> block3`` on ``x``.

        Args:
            x: Input tensor whose last dimension is 10 (``block1``'s input
                size).
            Num_GPUs: Number of GPUs to use. Values greater than 1 route
                ``block2`` through ``nn.DataParallel``; otherwise ``block2``
                is called directly.

        Returns:
            Output of ``block3``; its last dimension is 20.
        """
        x = self.block1(x)
        if Num_GPUs > 1:
            # Wrap at call time so block2 stays a plain Linear on the module.
            # NOTE(review): DataParallel presumably needs block2's parameters
            # on the first visible GPU -- confirm against how callers place
            # the model.
            x = nn.DataParallel(self.block2)(x)
        else:
            # Single-device path: block2 must still be applied; without this
            # branch the layer would be skipped whenever Num_GPUs <= 1.
            x = self.block2(x)
        x = self.block3(x)
        return x