How to do backward pass when parts of a model are on separate systems?

I’m trying to understand how to do the backward pass when parts of a model are completely separated, in other words when they are not part of one computation graph. I can simplify the problem to this example. Let’s say I have these declarations:

modelPart1 = torch.nn.Sequential(
   torch.nn.Linear(4, 4),
   torch.nn.GELU(),
   torch.nn.Linear(4, 4),
   torch.nn.GELU(),
   torch.nn.Linear(4, 4),
   torch.nn.GELU()
)

modelPart2 = torch.nn.Sequential(
   torch.nn.Linear(4, 4),
   torch.nn.GELU(),
   torch.nn.Linear(4, 4),
   torch.nn.GELU(),
   torch.nn.Linear(4, 1)
)

where modelPart2 is a continuation of modelPart1 (together they form one model), and let’s say they are declared on different systems. Then for the forward pass I do this on the first system:

outputPart1 = modelPart1(inputPart1)

outputPart1 is just a tensor with 4 values, so I can send them to the second system over the network, create a tensor out of them there, and do this on the second system:

outputPart2 = modelPart2(outputPart1)

So far everything is clear: the output of the first part is the input of the second part. Now, let’s say I know the label value (labelPart2) and can apply a loss function:

lossPart2 = torch.nn.functional.mse_loss(outputPart2, labelPart2)

At this point I can run the backward step on the second system

lossPart2.backward()

And at this point I’m lost. What do I have to pass from the second system to the first system to continue this backward pass, how exactly do I call the backward method, and on what object?

@alex.49.98 Welcome to the forums!

Have you tried DDP?

https://pytorch.org/tutorials/intermediate/ddp_tutorial.html

More specific to your question, you can just do something like this if you don’t want to bother with DDP. However, I do recommend giving DDP a go, as it’s much better optimized for speed:

import torch.nn as nn
import torch

modelPart1 = torch.nn.Sequential(
   torch.nn.Linear(4, 4),
   torch.nn.GELU(),
   torch.nn.Linear(4, 4),
   torch.nn.GELU(),
   torch.nn.Linear(4, 4),
   torch.nn.GELU()
)

modelPart2 = torch.nn.Sequential(
   torch.nn.Linear(4, 4),
   torch.nn.GELU(),
   torch.nn.Linear(4, 4),
   torch.nn.GELU(),
   torch.nn.Linear(4, 1)
)

# keep the two halves of the model on different devices
modelPart1.to('cuda:0')
modelPart2.to('cuda:1')

criterion = nn.MSELoss()
optimizer1 = torch.optim.SGD(modelPart1.parameters(), lr=0.001)
optimizer2 = torch.optim.SGD(modelPart2.parameters(), lr=0.001)

dummy_inputs = torch.rand(10, 4, device='cuda:0')
outputs1 = modelPart1(dummy_inputs)
# move the intermediate activation to the second device before feeding modelPart2
outputs2 = modelPart2(outputs1.to(device='cuda:1'))

targets = torch.rand(10, 1, device='cuda:1')

optimizer1.zero_grad()
optimizer2.zero_grad()

loss = criterion(outputs2, targets)

# backward traverses the autograd graph across both devices
loss.backward()

optimizer2.step()
optimizer1.step()

# show gradients are populated in both
for param in modelPart1.parameters():
    print(param.grad)
for param in modelPart2.parameters():
    print(param.grad)

Note that the gradients get stored on the device where each submodel is located. Autograd takes care of the CUDA stream semantics so that the backward pass mirrors the forward pass, and the autograd graph is built across both devices. You can read more here:

https://pytorch.org/docs/stable/notes/cuda.html#stream-semantics-of-backward-passes
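
For example, a quick sanity check after the snippet above (it just inspects what was already computed there):

# each half's gradients live on the device that owns its parameters
print(next(modelPart1.parameters()).grad.device)  # cuda:0
print(next(modelPart2.parameters()).grad.device)  # cuda:1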

As @J_Johnson explained, the computation graph will be created as long as you move the data around using PyTorch ops. If you want to explicitly detach the forward pass between the two systems, feed a detached copy of outputPart1 with requires_grad_() enabled into modelPart2 (call it inputPart2); after running the backward pass on the second system, send inputPart2.grad back and call outputPart1.backward(inputPart2.grad) on the first system.
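
Putting that together with the split from the original question, here is a minimal single-process sketch of that explicit hand-off. The network transfer is only simulated with detach()/clone(), and sent_values / returned_grad are placeholder names for whatever your transport actually delivers:

import torch

# modelPart1 and modelPart2 are the two Sequentials from the question
inputPart1 = torch.rand(1, 4)
labelPart2 = torch.rand(1, 1)

# --- system 1: forward through the first half ---
outputPart1 = modelPart1(inputPart1)
# ship the values of outputPart1 over the network (simulated by detach/clone)
sent_values = outputPart1.detach().clone()

# --- system 2: rebuild the activation as a leaf tensor that records gradients ---
inputPart2 = sent_values.requires_grad_(True)
outputPart2 = modelPart2(inputPart2)
lossPart2 = torch.nn.functional.mse_loss(outputPart2, labelPart2)
lossPart2.backward()                    # fills modelPart2 grads and inputPart2.grad

# ship inputPart2.grad back over the network (simulated again)
returned_grad = inputPart2.grad.detach().clone()

# --- system 1: continue the backward pass through the first half ---
outputPart1.backward(returned_grad)     # fills modelPart1 grads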


That was simpler than I expected and exactly what I needed. I tested it and compared the resulting gradients against the combined model: a perfect match. Thank you.

I suppose DDP is a great thing, but I’m unable to use it. I’m using TorchSharp, a PyTorch port to .NET. Unfortunately, DDP is not part of the port, but everything related to base PyTorch is, and I’m researching my options for implementing parallel training. I have a farm of servers with 2 GPUs per server and would ideally like to do something like the PipeDream algorithm for a particular NN topology. Taking all this into account, please let me know if I might be missing something important in the base PyTorch implementation that would help me implement parallel training. By the way, does the concept of a “device” used in methods like “to” cover only local resources, or can it span the network? Is there a way to extend this concept and create a custom device that refers to a GPU on another node on the network?