Model Parallel Pipelining not working

@mrshenli

I also observed a similar result.

This is the thread I created in the discussion.

I also benchmarked this using multiple configurations.

I am not sure whether the code actually runs concurrently, so I changed it as follows and got reasonably good results.

class PipelineParallelResNet50(ModelParallelResNet50):
    """Model-parallel ResNet50 that pipelines the input batch in micro-batches.

    The input is cut into chunks of ``split_size`` samples along dim 0.
    While the second stage (``seq2`` + ``fc``) processes chunk *i*, the first
    stage (``seq1``) is given chunk *i + 1* from a second worker thread, so
    the two stages can overlap.

    NOTE(review): ``seq1``, ``seq2`` and ``fc`` come from the parent class,
    which is not shown here; presumably ``seq1`` lives on ``cuda:0`` and
    ``seq2``/``fc`` on ``cuda:1`` -- confirm against ModelParallelResNet50.
    """

    def __init__(self, split_size=20, *args, **kwargs):
        """Store the micro-batch size and forward the rest to the parent.

        Args:
            split_size: Number of samples per micro-batch along dim 0.
            *args: Positional arguments passed through to the parent class.
            **kwargs: Keyword arguments passed through to the parent class.
        """
        super().__init__(*args, **kwargs)
        self.split_size = split_size

    def taskA(self, s_prev, ret):
        """Run the second stage on ``s_prev`` and append the output to ``ret``.

        Args:
            s_prev: Output of the first stage for one micro-batch.
            ret: List collecting the per-micro-batch outputs; mutated in place.
        """
        s_prev = self.seq2(s_prev)
        # Flatten everything but the batch dimension before the final layer.
        ret.append(self.fc(s_prev.view(s_prev.size(0), -1)))

    def taskB(self, s_next):
        """Run the first stage on ``s_next`` and move the result to ``cuda:1``.

        Args:
            s_next: One micro-batch of the input.

        Returns:
            The first-stage activations, copied to ``cuda:1``.
        """
        return self.seq1(s_next).to('cuda:1')

    def forward(self, x):
        """Run the pipelined forward pass.

        Args:
            x: Input batch; split along dim 0 into chunks of ``split_size``.

        Returns:
            The per-micro-batch outputs concatenated with ``torch.cat`` in
            their original order.
        """
        splits = iter(x.split(self.split_size, dim=0))
        # Prime the pipeline: the first chunk has nothing to overlap with.
        s_prev = self.taskB(next(splits))
        ret = []

        # One pool for the whole loop: creating an executor per iteration only
        # adds thread start-up/tear-down cost.
        # NOTE(review): autograd grad mode (e.g. ``torch.no_grad()``) is
        # thread-local, so it is presumably not inherited by these worker
        # threads -- confirm if this model is used for inference.
        with concurrent.futures.ThreadPoolExecutor(max_workers=2) as executor:
            for s_next in splits:
                # Submit BOTH stages before waiting on either one; waiting on
                # A first would serialize the two stages again.
                future_a = executor.submit(self.taskA, s_prev, ret)
                future_b = executor.submit(self.taskB, s_next)

                # Only one taskA is in flight at a time and it is joined here,
                # so ``ret`` keeps the original micro-batch order.
                future_a.result()
                s_prev = future_b.result()

        # Drain the pipeline: the last chunk still needs the second stage.
        self.taskA(s_prev, ret)

        return torch.cat(ret)