I also observed a similar result.
This is the thread I created in the discussion.
I also benchmarked this using multiple configurations.
I am not sure whether the code actually runs the two stages concurrently, so I changed it as follows and got reasonably good results.
class PipelineParallelResNet50(ModelParallelResNet50):
    """Pipelined variant of ``ModelParallelResNet50``.

    The input batch is cut into micro-batches of ``split_size`` rows. While
    the second stage (``seq2`` + ``fc``) processes one micro-batch, the first
    stage (``seq1``) is started on the following one from a second worker
    thread, so the two stages can overlap.

    NOTE(review): ``seq1``, ``seq2`` and ``fc`` are presumably defined by the
    parent class, with ``seq1`` living on ``cuda:0`` and ``seq2``/``fc`` on
    ``cuda:1`` -- confirm against ``ModelParallelResNet50``.
    """

    def __init__(self, split_size=20, *args, **kwargs):
        """Build the parent model and remember the micro-batch size.

        Args:
            split_size: Number of rows (along dim 0) per micro-batch.
            *args: Forwarded unchanged to the parent constructor.
            **kwargs: Forwarded unchanged to the parent constructor.
        """
        super(PipelineParallelResNet50, self).__init__(*args, **kwargs)
        self.split_size = split_size

    def taskA(self, s_prev, ret):
        """Run the second stage on ``s_prev`` and append the logits to ``ret``.

        Args:
            s_prev: Output of ``seq1`` that has already been moved to
                ``cuda:1``.
            ret: List collecting the per-micro-batch outputs; mutated in
                place.
        """
        s_prev = self.seq2(s_prev)
        # Flatten everything except the batch dimension before the classifier.
        ret.append(self.fc(s_prev.view(s_prev.size(0), -1)))

    def taskB(self, s_next):
        """Run the first stage on ``s_next`` and move the result to ``cuda:1``.

        Args:
            s_next: The next micro-batch of the input.

        Returns:
            The ``seq1`` activations for ``s_next``, placed on ``cuda:1``.
        """
        return self.seq1(s_next).to('cuda:1')

    def forward(self, x):
        """Compute the model output for ``x`` using a two-stage pipeline.

        Args:
            x: Input batch; it is split along dim 0 into chunks of
                ``self.split_size`` rows.

        Returns:
            The outputs of every micro-batch concatenated along dim 0, in
            the same order as the rows of ``x``.
        """
        splits = iter(x.split(self.split_size, dim=0))

        # Prime the pipeline: the first micro-batch has nothing to overlap
        # with, so its first stage is computed synchronously.
        s_prev = self.taskB(next(splits))
        ret = []

        # A single two-worker pool is reused for the whole batch instead of
        # building and tearing down executors on every iteration.
        with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
            for s_next in splits:
                # Submit BOTH stages before waiting on either one; waiting
                # on A first would serialize the two stages again.
                # A. s_prev runs on cuda:1
                future_a = executor.submit(self.taskA, s_prev, ret)
                # B. s_next runs on cuda:0, which can run concurrently with A
                future_b = executor.submit(self.taskB, s_next)

                # A is awaited before the next iteration so that results are
                # appended to ``ret`` in input order and by one thread at a
                # time.
                future_a.result()
                s_prev = future_b.result()

        # Drain the pipeline: second stage for the last micro-batch.
        self.taskA(s_prev, ret)
        return torch.cat(ret)