Hello, I am working on model-partition training. I deploy the same network on two servers connected over a wireless link; the network structure is below. The front part of the network runs on server 1 and the rear part on server 2. The training data lives on server 1, and the intermediate forward output is sent to server 2 over a socket. Server 2 then uses that intermediate output to train the rear part, and the gradient from its backward pass is sent back to server 1 to backpropagate through the front part. I coded this up, but training does not converge, and I want to know why. Thank you for any suggestions.
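My send_msg and recv_msg helpers are omitted below. For reference, here is a minimal sketch of the kind of length-prefixed framing they are assumed to implement (a plain TCP stream does not preserve message boundaries, so the receiver needs to know how many bytes each pickled message occupies):

import struct

def send_msg(sock, msg):
    # Prefix each message with its length as a 4-byte big-endian integer.
    sock.sendall(struct.pack('>I', len(msg)) + msg)

def recvall(sock, n):
    # recv() may return fewer bytes than requested; loop until n bytes arrive.
    buf = b''
    while len(buf) < n:
        chunk = sock.recv(n - len(buf))
        if not chunk:
            return None
        buf += chunk
    return buf

def recv_msg(sock):
    # Read the 4-byte length header, then exactly that many payload bytes.
    raw_len = recvall(sock, 4)
    if raw_len is None:
        return None
    return recvall(sock, struct.unpack('>I', raw_len)[0])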
The rest of my code is below.
import pickle
import socket

import torch
import torch.nn as nn
import torch.optim as optim

class Net(nn.Module):
    def __init__(self, num_classes=10):
        super(Net, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
        )
        self.avgpool = nn.AdaptiveAvgPool2d((6, 6))
        self.classifier = nn.Sequential(
            nn.Dropout(),
            nn.Linear(256 * 6 * 6, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Linear(4096, num_classes),
        )

    def forward(self, x, mode):
        if mode == 0:    # whole network
            x = self.features(x)
            x = self.avgpool(x)
            x = torch.flatten(x, 1)
            x = self.classifier(x)
        elif mode == 1:  # front part of the network (runs on server 1)
            x = self.features(x)
            x = self.avgpool(x)
            x = torch.flatten(x, 1)
        else:            # rear part of the network (runs on server 2)
            x = self.classifier(x)
        return x
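As a quick single-machine sanity check (not part of the deployed code), the front part's output can be fed straight into the rear part and compared against running the whole network at once; the dummy input shape is just an assumption matching the Conv2d(256, ...) layer above:

net = Net().eval()  # eval() disables Dropout so the two paths are comparable
x = torch.randn(2, 256, 13, 13)      # dummy batch with 256 input channels
with torch.no_grad():
    whole = net(x, mode=0)           # full forward pass
    middle = net(x, mode=1)          # front part only
    split = net(middle, mode=2)      # rear part applied to the intermediate
print(torch.allclose(whole, split))  # expected: True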
Server 1:
Model_1 = Net()
# Freeze everything, then unfreeze only the front part (features) on server 1.
for param in Model_1.parameters():
    param.requires_grad = False
for param in Model_1.features.parameters():
    param.requires_grad = True
optimizer_1 = optim.SGD(filter(lambda p: p.requires_grad, Model_1.parameters()), lr=0.001, momentum=0.9)
def train(Model_1, train_dataloader):
    Model_1.train()
    process = 'train'
    train_loss = 0.0
    train_correct = 0
    for i, data in enumerate(train_dataloader):
        data, target = data[0].to(device), data[1].to(device)
        optimizer_1.zero_grad()
        middle = Model_1(data, mode=1)  # forward through the front part only
        # Detach the copy being sent: the autograd graph cannot cross the
        # socket, and pickling a non-leaf tensor that requires grad fails.
        # .cpu() keeps the message loadable whatever device server 2 uses.
        intermediate = [process, middle.detach().cpu(), target.cpu()]
        msg = pickle.dumps(intermediate)
        s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        s.connect((host, port))
        send_msg(s, msg)
        recv_data = pickle.loads(recv_msg(s))
        s.close()  # close the per-batch connection
        gradient, batch_loss, batch_correct = recv_data
        train_loss += batch_loss
        train_correct += batch_correct
        gradient = gradient.to(device)  # .to() is not in-place; reassign it
        # Continue the backward pass through the front part, seeding it with
        # the gradient of the loss w.r.t. the intermediate activation.
        torch.autograd.backward(middle, grad_tensors=gradient)
        optimizer_1.step()
    train_loss = train_loss / len(train_dataloader.dataset)
    train_correct = 100. * train_correct / len(train_dataloader.dataset)
    print(f'Train loss: {train_loss:.4f}, Train acc: {train_correct:.2f}')
    return train_loss, train_correct
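host, port, and device are globals I set up elsewhere; hypothetically, with placeholder values, the driver on server 1 looks like:

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
host, port = '192.168.1.2', 50007  # server 2's address; placeholder values
Model_1 = Model_1.to(device)
for epoch in range(10):            # placeholder epoch count
    train(Model_1, train_dataloader)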
Server 2:
Model_2 = Net()
# Freeze everything, then unfreeze only the rear part (classifier) on server 2.
for param in Model_2.parameters():
    param.requires_grad = False
for param in Model_2.classifier.parameters():
    param.requires_grad = True
optimizer_2 = optim.SGD(filter(lambda p: p.requires_grad, Model_2.parameters()), lr=0.001, momentum=0.9)
def train(Model_2, data, target):
    Model_2.train()
    data = data.to(device)    # .to() is not in-place; reassign the result
    target = target.to(device)
    # The received activation arrives as a leaf tensor; it must require grad,
    # otherwise loss.backward() leaves data.grad as None.
    data.requires_grad_(True)
    optimizer_2.zero_grad()
    output = Model_2(data, mode=2)
    loss = criterion(output, target)
    train_running_loss = loss.item()
    _, preds = torch.max(output, 1)
    train_running_correct = (preds == target).sum().item()
    loss.backward()
    # Gradient of the loss w.r.t. the intermediate activation, to be sent
    # back to server 1 for the front part's backward pass.
    gradient = data.grad
    optimizer_2.step()
    msg = [gradient, train_running_loss, train_running_correct]
    return msg
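The listening loop on server 2 is also omitted; a sketch of what it is assumed to look like, pairing with the per-batch connect in server 1's train loop (the port is a placeholder, and criterion is assumed to be cross-entropy since it is not shown):

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
Model_2 = Model_2.to(device)
criterion = nn.CrossEntropyLoss()  # assumed; the actual loss is not shown

listener = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
listener.bind(('', 50007))         # placeholder port; must match server 1
listener.listen(1)
while True:
    conn, addr = listener.accept()
    process, data, target = pickle.loads(recv_msg(conn))
    if process == 'train':
        gradient, batch_loss, batch_correct = train(Model_2, data, target)
        # Move the gradient to CPU before pickling so server 1 can unpickle
        # it regardless of its own device.
        reply = pickle.dumps([gradient.cpu(), batch_loss, batch_correct])
        send_msg(conn, reply)
    conn.close()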