I’m trying to understand how to do the backward pass when parts of a model are completely separated, in other words when they are not part of one computation graph. I can reduce the problem to this example. Let’s say I have these declarations
import torch

modelPart1 = torch.nn.Sequential(
    torch.nn.Linear(4, 4),
    torch.nn.GELU(),
    torch.nn.Linear(4, 4),
    torch.nn.GELU(),
    torch.nn.Linear(4, 4),
    torch.nn.GELU()
)
modelPart2 = torch.nn.Sequential(
    torch.nn.Linear(4, 4),
    torch.nn.GELU(),
    torch.nn.Linear(4, 4),
    torch.nn.GELU(),
    torch.nn.Linear(4, 1)
)
where modelPart2 is a continuation of modelPart1: together they form one model, but let’s say they are declared on different systems. Then for the forward pass I have to do this on the first system
outputPart1 = modelPart1(inputPart1)
outputPart1 is just a tensor with 4 values, so I can send those values to the second system over the network, create a tensor out of them, and do this on the second system
outputPart2 = modelPart2(outputPart1)
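Spelled out, the hand-off itself looks roughly like this in my setup; sendValues and recvValues are just placeholders for my actual networking code, and the requires_grad=True on the rebuilt tensor is my assumption about what is needed so the gradient at the boundary gets recorded:

# on the first system: detach so only plain floats cross the network
values = outputPart1.detach().tolist()
sendValues(values)  # placeholder for the real network call

# on the second system: rebuild the tensor from the received floats;
# requires_grad=True is my assumption, so that .grad is populated
# at this boundary after the backward pass
outputPart1 = torch.tensor(recvValues(), requires_grad=True)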
So far everything is clear: the output of the first part is the input of the second part. Now, let’s say I know the label value (labelPart2), so I can apply a loss function
lossPart2 = torch.nn.functional.mse_loss(outputPart2, labelPart2)
At this point I can run the backward step on the second system
lossPart2.backward()
And at this point I’m lost. What do I have to pass from the second system back to the first system to continue this backward pass, and how exactly do I have to call the backward method, and on what object?
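My unconfirmed guess is that the gradient of the loss with respect to outputPart1 is what has to travel back, something like the sketch below; sendGrad and recvGrad are again placeholders for the networking, and I’m not sure this is the right way to stitch the two graphs together:

# on the second system, after lossPart2.backward():
# outputPart1 is a leaf tensor here, so (if my requires_grad=True
# assumption above holds) .grad should now contain dLoss/dOutputPart1
gradValues = outputPart1.grad.tolist()
sendGrad(gradValues)  # placeholder for the real network call

# on the first system: seed the backward pass of the first graph
# with the gradient that arrived from the second system
gradTensor = torch.tensor(recvGrad())
outputPart1.backward(gradient=gradTensor)

Is that the right idea, or is there a more idiomatic mechanism for connecting the two halves?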