How to continue BP using intermediate gradients

Say I have a model that I split into two parts at a so-called cut layer: a client-side part and a server-side part. When I perform backpropagation on the server side, I need to continue the backward pass on the client side as well, so that together they form a single backward pass over the entire model.

```python
import torch
import torch.nn as nn

class ClientModel(nn.Module):

    def __init__(self):
        super(ClientModel, self).__init__()

        num_flat_features = 3 * 224 * 224

        self.layers = nn.Sequential(
            nn.Flatten(),
            nn.Linear(num_flat_features, 128),
            nn.ReLU()
        )

    def forward(self, x):
        x = self.layers(x)

        return x

class ServerModel(nn.Module):

    def __init__(self):
        super(ServerModel, self).__init__()

        self.cut_layer = nn.Linear(128, 64)
        self.relu = nn.ReLU()
        self.cls_layer = nn.Linear(64, 10)

    def forward(self, x):
        x = self.cut_layer(x)
        x = self.relu(x)
        x = self.cls_layer(x)

        return x
```
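For context, here is a minimal sketch of one split training step with models laid out like the ones above. Assumptions: the client "sends" its activations to the server as a detached copy with `requires_grad=True`, the server "sends back" that copy's `.grad`, and the input is shrunk to 3x8x8 (instead of 3x224x224) only to keep the example light. The names `client`, `server`, `train_step`, `server_input` and `grad_activations` are hypothetical.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Same layout as ClientModel / ServerModel, but with a small 3x8x8 input
# (an assumption made only to keep this sketch cheap to run).
client = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 128), nn.ReLU())
server = nn.Sequential(nn.Linear(128, 64), nn.ReLU(), nn.Linear(64, 10))

client_opt = torch.optim.SGD(client.parameters(), lr=0.1)
server_opt = torch.optim.SGD(server.parameters(), lr=0.1)

def train_step(x, y):
    client_opt.zero_grad()
    server_opt.zero_grad()

    # Client side: forward pass up to the cut.
    activations = client(x)

    # "Send" to the server: a detached leaf copy that requires grad, so the
    # server's backward pass fills in its .grad (no retain_grad() needed).
    server_input = activations.detach().requires_grad_()

    # Server side: forward, loss, backward, update.
    logits = server(server_input)
    loss = F.cross_entropy(logits, y)
    loss.backward()
    server_opt.step()

    # "Send back" the gradient w.r.t. the activations; shape (batch, 128).
    grad_activations = server_input.grad

    # Client side: continue the backward pass from the cut and update.
    activations.backward(grad_activations)
    client_opt.step()
    return loss.item(), grad_activations.shape

x = torch.randn(8, 3, 8, 8)          # hypothetical batch of 8 small images
y = torch.randint(0, 10, (8,))
loss, grad_shape = train_step(x, y)
print(grad_shape)  # torch.Size([8, 128])
```

The key point is that the tensor passed to `activations.backward(...)` is the gradient with respect to the activations, so it always has the activations' shape.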

If, after backpropagation on the server side, I obtain the cut-layer gradients via `cut_layer.weight.grad`, how do I then continue the backward pass on the client side using these gradients?

Any help is much appreciated.

Doing two backward calls back to back and passing the intermediate gradient from the first into the second is equivalent to doing a single backward pass: `torch.autograd.backward(intermediate_output, grad_tensors=(cut_layer.weight.grad,))`
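A small sketch to check the equivalence claim numerically. One assumption to flag: the gradient passed in `grad_tensors` here is the gradient with respect to `intermediate_output` itself (taken from a detached copy, `server_input`), which has the same shape as `intermediate_output`. The models and names are hypothetical stand-ins.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

client = nn.Sequential(nn.Linear(16, 128), nn.ReLU())
server = nn.Sequential(nn.Linear(128, 64), nn.ReLU(), nn.Linear(64, 10))
x = torch.randn(32, 16)

# Reference: one backward pass over the whole (unsplit) model.
server(client(x)).sum().backward()
reference = [p.grad.clone() for p in client.parameters()]
client.zero_grad()
server.zero_grad()

# Split version: two backward calls back to back.
intermediate_output = client(x)
server_input = intermediate_output.detach().requires_grad_()
server(server_input).sum().backward()          # first backward (server side)
torch.autograd.backward(intermediate_output,   # second backward (client side)
                        grad_tensors=(server_input.grad,))

split = [p.grad for p in client.parameters()]
match = all(torch.allclose(a, b) for a, b in zip(reference, split))
print(match)  # True
```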

Thanks for your reply!

When I use what you suggested, I get the following error:
`RuntimeError: Mismatch in shape: grad_output[0] has a shape of torch.Size([64, 128]) and output[0] has a shape of torch.Size([512, 128]).`

My batch size is 512, so it looks like the cut-layer weight gradients don't have the same shape as the activations. I suppose that makes sense, but then how should I go about this?
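The mismatch can be reproduced in isolation. This sketch (a hypothetical stand-alone `cut_layer` fed random activations, with an arbitrary squared-sum loss) shows that the weight gradient has one entry per weight, shape `(out_features, in_features) = (64, 128)`, and is summed over the batch, while the gradient the client needs has one entry per activation, shape `(batch_size, 128)`.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
batch_size = 512
cut_layer = nn.Linear(128, 64)

# Stand-in for the client's output; a leaf here, so .grad is kept.
activations = torch.randn(batch_size, 128, requires_grad=True)
out = cut_layer(activations)
out.retain_grad()              # keep d(loss)/d(out) for the comparison below
out.pow(2).sum().backward()    # arbitrary loss, just for illustration

print(cut_layer.weight.grad.shape)  # torch.Size([64, 128]): one entry per weight
print(activations.grad.shape)       # torch.Size([512, 128]): one entry per activation

# Both come from the gradient w.r.t. the cut layer's output (g = out.grad):
#   dL/dW = g^T @ activations   (summed over the batch, so per-sample info is lost)
#   dL/dx = g @ W               (what the client-side backward pass needs)
```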

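Ah, you should call `retain_grad()` on the activation so that its `.grad` is recorded as well (it is not a leaf tensor, so autograd does not keep its gradient by default), and make sure you pass the `.grad` with respect to the activation rather than the weight gradient.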
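A minimal sketch of reading the gradient with respect to a non-leaf activation via `retain_grad()`. The small `client`/`server` stand-ins are hypothetical. Note that in this toy setup the graph is not cut, so `loss.backward()` already reaches the client parameters; the point is only which tensor's `.grad` to use. If the server instead works on a detached copy that requires grad, that copy is a leaf and its `.grad` is populated without `retain_grad()`.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
client = nn.Sequential(nn.Linear(16, 128), nn.ReLU())
server = nn.Sequential(nn.Linear(128, 64), nn.ReLU(), nn.Linear(64, 10))
x = torch.randn(8, 16)

activations = client(x)      # non-leaf: produced by client operations
activations.retain_grad()    # ask autograd to keep its .grad after backward
loss = server(activations).sum()
loss.backward()

# Gradient w.r.t. the activation: same shape as the activation itself,
# unlike server[0].weight.grad, which is (64, 128).
print(activations.grad.shape)  # torch.Size([8, 128])
```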

I'm indeed already calling `activations.retain_grad()`.

How do I explicitly use the `.grad` with respect to the activations? Should I extract those cut-layer gradients in a different way than just `server_model.cut_layer.weight.grad`?

If you have a reference to the activation tensor somewhere, you can access its `.grad` attribute directly. If not, it takes a little more work, but it is still possible.
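One possible way to do that "little more work" (an assumption about what is meant, not necessarily the approach the answer had in mind) is a full backward hook on the server's `cut_layer`: its `grad_input[0]` is the gradient with respect to the layer's input, i.e. the activations. The sketch assumes a hypothetical `server_step` that detaches the received tensor internally and never returns it; `captured` and `save_input_grad` are also hypothetical names.

```python
import torch
import torch.nn as nn

class ServerModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.cut_layer = nn.Linear(128, 64)
        self.relu = nn.ReLU()
        self.cls_layer = nn.Linear(64, 10)

    def forward(self, x):
        return self.cls_layer(self.relu(self.cut_layer(x)))

captured = {}

def save_input_grad(module, grad_input, grad_output):
    # grad_input[0] is d(loss)/d(input of cut_layer), i.e. w.r.t. the activations.
    captured["activations_grad"] = grad_input[0]

def server_step(received):
    # Hypothetical server code: makes its own leaf copy and does not return it.
    loss = server(received.detach().requires_grad_()).sum()
    loss.backward()
    return loss

torch.manual_seed(0)
client = nn.Sequential(nn.Linear(16, 128), nn.ReLU())  # stand-in client
server = ServerModel()
handle = server.cut_layer.register_full_backward_hook(save_input_grad)

x = torch.randn(512, 16)
activations = client(x)
server_step(activations)

# Continue the backward pass on the client with the captured gradient.
activations.backward(captured["activations_grad"])
handle.remove()
print(captured["activations_grad"].shape)  # torch.Size([512, 128])
```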