Maybe splitting the original forward() method into two different ones breaks inheritance and/or autograd in some way?
Why do you need to split it in the first place? I think you could retrieve the mid-output in the original forward by keeping an intermediate variable and returning it along with the final output. Something like this:
def forward(self, x):
    x = self.layer1(x)
    midoutput = x  # keep the intermediate result
    x = self.layer2(x)
    return midoutput, x
Thanks, but I need to have 2 different forward functions. It's important for me to understand why doing that breaks the DataParallel setup, and I think it's also important for the community.
Digging a bit into the docs shows why it acts like this: check the source code for DataParallel here.
What happens when you call model(images) is that you call the DataParallel module's forward method, which handles all the parallelism. Inside that forward method, the submodules (your model, replicated onto the GPUs you specified) are called with return self.module(*inputs[0], **kwargs[0]), which in turn calls the module's own forward method… which is not forward1 or forward2 in your case, so those are never split across multiple GPUs.
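To make the dispatch concrete, here is a tiny pure-Python analogy (no torch; Wrapper and Net are made-up stand-ins for DataParallel and your model): the wrapper's logic only runs when you call the wrapper itself, and the wrapper only knows how to invoke the wrapped module's forward, never forward1 or forward2.

```python
class Module:
    """Minimal stand-in for nn.Module: calling the object calls forward()."""
    def __call__(self, *args, **kwargs):
        return self.forward(*args, **kwargs)

class Net(Module):
    def forward1(self, x):
        return x + 1

    def forward2(self, x):
        return x * 2

    def forward(self, x):
        return self.forward2(self.forward1(x))

class Wrapper(Module):
    """Stand-in for DataParallel: does its bookkeeping, then delegates."""
    def __init__(self, module):
        self.module = module
        self.calls = 0  # pretend this counter is the scatter/gather machinery

    def forward(self, x):
        self.calls += 1           # "parallelism" happens here...
        return self.module(x)     # ...then Net.forward runs, never forward1/2

net = Wrapper(Net())
net(3)                   # goes through Wrapper.forward: bookkeeping runs
net.module.forward1(3)   # bypasses the wrapper entirely: calls stays at 1
```

Calling forward1 or forward2 directly is exactly like the second call above: it reaches into the wrapped module and skips the wrapper's forward, so none of the parallel machinery fires.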
Hence my recommendation to keep the original forward method or to use two nested modules.
One solution I can imagine is inheriting from the DataParallel class and overriding its forward method with your own dispatch logic, but doing that without breaking anything will probably be a bit complex…
Thanks for your investigation. I looked at the source code and I think you are right.
I will try your solution of two nested modules; can you give a bit more detail about it?
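Not the original poster, but here is a sketch of the two-nested-modules idea (the layer names and sizes are assumptions): put each half of the network in its own nn.Module with its own forward, and wrap each one in DataParallel. Every call below is then a real forward() call, so DataParallel's own forward runs and splits the batch.

```python
import torch
import torch.nn as nn

class Stage1(nn.Module):
    """First half of the network: what forward1 used to compute."""
    def __init__(self):
        super().__init__()
        self.layer1 = nn.Linear(8, 16)

    def forward(self, x):
        return self.layer1(x)

class Stage2(nn.Module):
    """Second half of the network: what forward2 used to compute."""
    def __init__(self):
        super().__init__()
        self.layer2 = nn.Linear(16, 4)

    def forward(self, x):
        return self.layer2(x)

# Wrap each stage separately; on a CPU-only machine DataParallel just
# falls through to the wrapped module, and with GPUs it scatters the batch.
stage1 = nn.DataParallel(Stage1())
stage2 = nn.DataParallel(Stage2())

x = torch.randn(32, 8)
mid = stage1(x)    # the old "forward1", now parallelized
out = stage2(mid)  # the old "forward2", now parallelized
```

You can still get both the mid-output and the final output, and each stage goes through the normal __call__ → forward dispatch that DataParallel intercepts.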
I can also do something like:
def forward(self, x, flag):
    if flag:
        # forward1 code
        ...
    else:
        # forward2 code
        ...
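For what it's worth, the flag approach should play nicely with DataParallel, since extra keyword arguments are passed through to the module's forward alongside the input. A minimal sketch (the layer names and sizes are made up):

```python
import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer1 = nn.Linear(8, 16)
        self.layer2 = nn.Linear(16, 4)

    def forward(self, x, flag=True):
        if flag:
            # "forward1" path: stop after the first layer
            return self.layer1(x)
        else:
            # "forward2" path: run the full network
            return self.layer2(self.layer1(x))

model = nn.DataParallel(Net())
x = torch.randn(32, 8)
out1 = model(x, flag=True)   # routed through DataParallel.forward
out2 = model(x, flag=False)  # same entry point, different branch
```

Because everything still goes through the single forward method, DataParallel's own forward is the one being called, and the batch is split as usual.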