Run the forward pass on machine A and the backward pass on machine B

Hi, I am currently working on a pipeline-parallelism project based on DeepSpeed. In this project, we are trying to change the pipeline size dynamically and efficiently.
In a simplified setting with 3 machines and a 12-layer model, each node computes 4 layers. Node 1 may die at any moment, and before it dies it has 2 minutes to transfer data out (like spot instances on AWS).
So before node 1 dies:
node 0: layers 0-3, node 1: layers 4-7, node 2: layers 8-11
After node 1 dies:
node 0: layers 0-5, node 2: layers 6-11
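
Concretely, the repartitioning is just an even split of the layers over the surviving nodes, something like this toy helper (the helper name and the even-divisibility assumption are mine, only to make the idea concrete):

def repartition(num_layers, surviving_ranks):
    # Toy version: assumes num_layers divides evenly over the surviving ranks.
    per_rank = num_layers // len(surviving_ranks)
    mapping, start = {}, 0
    for rank in surviving_ranks:
        mapping[rank] = list(range(start, start + per_rank))
        start += per_rank
    return mapping

# repartition(12, [0, 1, 2]) -> node 0: layers 0-3, node 1: layers 4-7, node 2: layers 8-11
# repartition(12, [0, 2])    -> node 0: layers 0-5, node 2: layers 6-11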
However, node 1 may already have computed several forward passes by then. To avoid recomputing them, I want to send the intermediate tensors out to the other nodes and have those nodes finish the backward pass based on the information they receive from node 1. I wonder whether this is possible.
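
For the transfer itself I go through a TCP Store, as the comments in the snippet below indicate: the setTensor / getTensor / setStateDict / getStateDict calls are thin helpers of mine around torch.distributed.TCPStore, roughly along these lines (a simplified sketch, not the exact implementation; it assumes your PyTorch version's TCPStore.set accepts raw bytes as the value):

import io
from datetime import timedelta

import torch
from torch.distributed import TCPStore


class CoordCom:
    """Minimal helper to move tensors and state dicts between ranks via a TCPStore."""

    def __init__(self, host, port, world_size, is_master):
        self.store = TCPStore(host, port, world_size, is_master,
                              timeout=timedelta(minutes=2))

    def setTensor(self, key, tensor):
        buf = io.BytesIO()
        torch.save(tensor, buf)              # only the values travel; the autograd graph is lost
        self.store.set(key, buf.getvalue())

    def getTensor(self, key):
        return torch.load(io.BytesIO(self.store.get(key)))  # get() waits until the key exists

    def setStateDict(self, key, state_dict):
        buf = io.BytesIO()
        torch.save(state_dict, buf)
        self.store.set(key, buf.getvalue())

    def getStateDict(self, key):
        return torch.load(io.BytesIO(self.store.get(key)))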

In case the background is confusing, here is a code snippet. Is it possible to make backward() update the “linear” model on the “to_rank”?

    def remapping_layer(self, from_rank, to_rank):
        if self.global_rank == from_rank:
            # forward pass on from_rank (the node that is about to die)
            linear = nn.Linear(4, 5).to(self.device)
            optimizer = torch.optim.SGD(
                linear.parameters(), lr=0.1, momentum=0.9)
            optimizer.zero_grad()
            x = torch.tensor(
                [[1, 2, 3, 4.]], device=self.device, requires_grad=True)
            y = linear(x)

            self.coord_com.setTensor('y', y) #Send tensor through TCP Store
            self.coord_com.setStateDict("linear", linear.state_dict()) #Send state dict through TCP Store
            exit()  # from_rank exits ("dies") after handing its data off
        if self.global_rank == to_rank:
            # Receive the tensors and continue the backward pass on to_rank
            # Doesn't work: the 'y' received from from_rank is not connected to the parameters of the linear model on this machine
            linear = nn.Linear(4, 5).to(self.device)
            optimizer = torch.optim.SGD(
                linear.parameters(), lr=0.1, momentum=0.9)
            optimizer.zero_grad()

            y = self.coord_com.getTensor('y')  # Get tensor from TCP Store

            state_dict = self.coord_com.getStateDict("linear") # Get state dict of linear
            linear.load_state_dict(state_dict)

            print(f'rank:{to_rank},{y}') #rank:0,tensor([[ 1.8351, -0.8024,  0.7445,  2.3400,  0.4823]], device='cuda:0', requires_grad=True)

            z = torch.tensor([2, 3, 4, 5., 6],
                             device=self.device, requires_grad=True)
            loss = (z-y).sum()/5
            loss.backward()
            print(f"rank: 0, before {list(linear.parameters())}") # let it = statement A

            optimizer.step()
            print(f"rank: 0, {list(linear.parameters())}") # print the same thing as statement A
            exit()
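
For comparison, here is a sketch of the obvious fallback: also ship x, reload everything on to_rank, and redo the forward there so that a fresh autograd graph ties y to the local parameters. This does update the model, but it is exactly the recomputation I am trying to avoid, so I am hoping there is a way to reuse the y that was already computed on from_rank. (Sending x via setTensor('x', ...) is only part of this sketch, not of the snippet above.)

# Variant of the to_rank branch that works, but only because it redoes the forward.
linear = nn.Linear(4, 5).to(self.device)
optimizer = torch.optim.SGD(linear.parameters(), lr=0.1, momentum=0.9)
optimizer.zero_grad()

linear.load_state_dict(self.coord_com.getStateDict("linear"))
x = self.coord_com.getTensor('x').to(self.device)  # requires from_rank to also send x

y = linear(x)                      # forward re-run: y is now connected to the local parameters
z = torch.tensor([2, 3, 4, 5., 6], device=self.device)
loss = (z - y).sum() / 5
loss.backward()                    # linear.weight.grad / linear.bias.grad are populated this time
optimizer.step()                   # and the printed parameters actually change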