Say I have an original model that I split into two chunks at a so-called cut layer: a client-side part and a server-side part. When I perform backpropagation on the server side, I need to continue the backward pass on the client side as well to complete a single backward pass over the entire model.
import torch
import torch.nn as nn

class ClientModel(nn.Module):
    def __init__(self):
        super(ClientModel, self).__init__()
        num_flat_features = 3 * 224 * 224  # flattened 3x224x224 input
        self.layers = nn.Sequential(
            nn.Flatten(),
            nn.Linear(num_flat_features, 128),
            nn.ReLU()
        )

    def forward(self, x):
        x = self.layers(x)
        return x
class ServerModel(nn.Module):
    def __init__(self):
        super(ServerModel, self).__init__()
        self.cut_layer = nn.Linear(128, 64)  # first server-side layer, right after the cut
        self.relu = nn.ReLU()
        self.cls_layer = nn.Linear(64, 10)

    def forward(self, x):
        x = self.cut_layer(x)
        x = self.relu(x)
        x = self.cls_layer(x)
        return x
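For context, here is a minimal sketch of how I wire the two halves together in a single forward pass (the stand-in modules and the names smashed, logits, etc. are my own, chosen to match the shapes above):

```python
import torch
import torch.nn as nn

# Stand-ins with the same shapes as the ClientModel / ServerModel above.
client = nn.Sequential(nn.Flatten(), nn.Linear(3 * 224 * 224, 128), nn.ReLU())
server = nn.Sequential(nn.Linear(128, 64), nn.ReLU(), nn.Linear(64, 10))

x = torch.randn(4, 3, 224, 224)  # dummy batch of four images
smashed = client(x)              # client-side activations at the cut layer
# In a real split-learning deployment, `smashed` would be sent over the
# network to the server; here both halves live in one process.
logits = server(smashed)         # server finishes the forward pass
print(logits.shape)              # torch.Size([4, 10])
```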
If, after backpropagation on the server side, I obtain the cut-layer gradients via cut_layer.weight.grad, how do I then continue the backpropagation on the client side using these cut-layer gradients?
Any help is much appreciated.