Hi, I am currently working on a pipeline-parallelism project based on DeepSpeed. In this project, we are trying to change the pipeline size dynamically and efficiently.

In a simplified setting with 3 machines and a 12-layer model, each node computes 4 layers. Node 1 may die at any moment, but before it dies, it has 2 minutes to transfer data out (like spot instances on AWS).

So before node 1 dies,

node 0: layers 0-3, node 1: layers 4-7, node 2: layers 8-11

After,

node 0: layers 0-5, node 2: layers 6-11

However, node 1 may already have completed several forward passes. To avoid recomputing them, I want to send the intermediate tensors to the other nodes and have those nodes finish the backward pass based on the information they receive from node 1. Is this possible?

In case the background is confusing, here is a code snippet. Is it possible to make backward() update the "linear" model on to_rank?

```python
import torch
import torch.nn as nn

def remapping_layer(self, from_rank, to_rank):
    if self.global_rank == from_rank:
        # Do the forward pass on the departing node.
        linear = nn.Linear(4, 5).to(self.device)
        optimizer = torch.optim.SGD(
            linear.parameters(), lr=0.1, momentum=0.9)
        optimizer.zero_grad()
        x = torch.tensor(
            [[1, 2, 3, 4.]], requires_grad=True).to(self.device)
        y = linear(x)
        self.coord_com.setTensor('y', y)  # Send tensor through TCP store
        self.coord_com.setStateDict("linear", linear.state_dict())  # Send state dict through TCP store
        exit()
    if self.global_rank == to_rank:
        # Receive the tensors and try to continue with the backward pass.
        # This doesn't work: the 'y' received from from_rank is not associated
        # with the parameters of the linear model on this machine.
        linear = nn.Linear(4, 5).to(self.device)
        optimizer = torch.optim.SGD(
            linear.parameters(), lr=0.1, momentum=0.9)
        optimizer.zero_grad()
        y = self.coord_com.getTensor('y')  # Get tensor from TCP store
        state_dict = self.coord_com.getStateDict("linear")  # Get state dict of linear
        linear.load_state_dict(state_dict)
        print(f'rank:{to_rank},{y}')
        # rank:0,tensor([[ 1.8351, -0.8024, 0.7445, 2.3400, 0.4823]], device='cuda:0', requires_grad=True)
        z = torch.tensor([2, 3, 4, 5., 6],
                         requires_grad=True).to(self.device)
        loss = (z - y).sum() / 5
        loss.backward()
        print(f"rank: {to_rank}, before {list(linear.parameters())}")  # call this statement A
        optimizer.step()
        print(f"rank: {to_rank}, {list(linear.parameters())}")  # prints the same thing as statement A
        exit()
```
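One workaround I have been considering (a minimal single-process sketch, not the DeepSpeed API — a plain dict stands in for my `coord_com` TCP store): an autograd graph cannot be serialized and shipped between processes, so the received `y` can never be connected to the local parameters. But if from_rank sends the layer *input* `x` along with the state_dict, to_rank can replay the forward pass of just the transferred layers. That limits recomputation to the moved layers instead of the whole pipeline, and the replayed `y` is attached to the local parameters, so backward() updates them.

```python
import torch
import torch.nn as nn

# --- on from_rank (before it dies): package input + weights ---
linear_src = nn.Linear(4, 5)
x = torch.tensor([[1., 2., 3., 4.]])
# Stand-in for setTensor/setStateDict over the TCP store.
payload = {"x": x, "state_dict": linear_src.state_dict()}

# --- on to_rank: rebuild the layer and replay the forward pass ---
linear = nn.Linear(4, 5)
linear.load_state_dict(payload["state_dict"])
optimizer = torch.optim.SGD(linear.parameters(), lr=0.1, momentum=0.9)
optimizer.zero_grad()

y = linear(payload["x"])  # y is now attached to the *local* parameters
z = torch.tensor([2., 3., 4., 5., 6.])
loss = (z - y).sum() / 5
loss.backward()          # gradients now flow into linear's parameters

before = linear.weight.clone()
optimizer.step()
print(torch.equal(before, linear.weight))  # False: the weights were updated
```

The cost is one extra forward pass over the remapped layers on to_rank, which may be acceptable given the 2-minute eviction window; the activations computed for layers that stay on their original nodes are not redone.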