Why model partition training not converging?

Hello, I am working on model partition training. I deploy same networks in two servers and they connected by wireless. The network structure is below. I deploy fore part of network on server 1 and hind part of network on server 2. Training data are given on server 1 and intermediate forward output is sent to server 2 by socket. Then, the hind part use intermediate forward output to train the hind part and the backward gradient is sent back to server 1 for fore part propagation. I code it but I find that my training does not converge. I want to know why? Thank you for any suggestion.
My code is below

class Net(nn.Module):
def __init__(self, num_classes=10):
    super(Net, self).__init__()
    self.features = nn.Sequential(
        nn.Conv2d(256, 256, kernel_size=3, padding=1),
        nn.MaxPool2d(kernel_size=3, stride=2),
    self.avgpool = nn.AdaptiveAvgPool2d((6, 6))
    self.classifier = nn.Sequential(
        nn.Linear(256 * 6 * 6, 4096),
        nn.Linear(4096, 4096),
        nn.Linear(4096, num_classes),

def forward(self, x, mode):
   if mode==0: # whole network
        x = self.features(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.classifier(x)
  elif mode == 1: # fore part of network
        x = self.features(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
  else: # hind part of network
        x = self.classifier(x)
  return x

server 1

Model_1 = Net()
for param in Model_1 .parameters():
    param.requires_grad = False

for parameter in Model_1 .features.parameters():
    parameter.requires_grad = True

optimizer_1 = optim.SGD(filter(lambda p: p.requires_grad, Model_1.parameters()), lr=0.001, momentum=0.9)

def train(Model_1 , train_dataloader):
    Model_1 .train()
    process = 'train'
    train_loss = 0.0
    train_correct = 0
    for i, data in enumerate(train_dataloader):
        data, target = data[0].to(device), data[1].to(device)

        middle = Model_1 (data, mode = 1)
        intermediate = [process, middle, target]
        msg = pickle.dumps(intermediate)

        s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        s.connect((host, port))
        send_msg(s, msg)

        recv_data = recv_msg(s)
        recv_data = pickle.loads(recv_data)
        gradient = recv_data[0]
        batch_loss = recv_data[1]
        batch_correct = recv_data[2]
        train_loss += batch_loss
        train_correct += batch_correct

        torch.autograd.backward(middle, grad_tensors=gradient)

        optimizer_1 .step()

    train_loss = train_loss / len(train_dataloader.dataset)
    train_correct = 100. * train_correct / len(train_dataloader.dataset)

    print(f'Train loss: {train_loss: .4f}, Train acc: {train_correct: .2f}')
    return train_loss, train_correct

server 2

Model_2 = Net()
for param in Model_2 .parameters():
    param.requires_grad = False

for parameter in Model_2 .classifier.parameters():
    parameter.requires_grad = True

optimizer_2 = optim.SGD(filter(lambda p: p.requires_grad, Model_2.parameters()), lr=0.001, momentum=0.9)

def train(Model_2, data, target):
    output = Model_2(data, mode = 2)

    loss = criterion(output, target)
    train_running_loss = loss.item()
    _, preds = torch.max(output.data, 1)
    train_running_correct = (preds == target).sum().item()

    # data.register_hook(get_grad('middle'))


    gradient = data.grad


    msg = [gradient, train_running_loss, train_running_correct]

    return msg