class Net(nn.Module):
    """Single-layer binary classifier: 30 input features -> 1 sigmoid probability."""

    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(30, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # Bug fix: the original applied F.relu before the sigmoid, which clamps
        # every negative logit to 0 and therefore restricts the output to
        # [0.5, 1) -- the model could never predict a probability below 0.5,
        # making class 0 unlearnable under binary cross-entropy.
        return self.sigmoid(self.fc1(x))
# One replica of the classifier per worker, each paired with its own Adam
# optimizer; the parallel lists are indexed by worker position (bob=0, alice=1).
bobs_model, alices_model = Net(), Net()
bobs_optimizer, alices_optimizer = (
    optim.Adam(m.parameters(), lr=args.lr) for m in (bobs_model, alices_model)
)
models = [bobs_model, alices_model]
optimizers = [bobs_optimizer, alices_optimizer]
def update(data, target, model, optimizer):
    """Run one optimisation step on `model` at the worker holding `data`.

    The model is sent to the remote worker that owns the data shard (PySyft
    pointer API), one forward/backward/step is performed there, and the
    updated (still remote) model is returned.
    """
    # Move the model to the worker that owns this batch.
    model.send(data.location)
    optimizer.zero_grad()
    prediction = model(data)
    # Bug fix: removed a leftover debug `print(prediction.view(-1))` that
    # dumped every batch's predictions to stdout on every training step.
    # Flatten (batch, 1) predictions to match the 1-D target vector.
    loss = F.binary_cross_entropy(prediction.view(-1), target)
    loss.backward()
    optimizer.step()
    return model
def train():
    """One federated epoch: train each worker's model on its shard, then average.

    Returns the federated average of the (retrieved) worker models.
    """
    # Bug fix: the original iterated range(len(remote_dataset[0]) - 1),
    # silently dropping the final batch of every epoch. Iterate the full
    # shard instead. NOTE(review): this assumes both shards have the same
    # number of batches, which the original shared `data_index` already
    # required -- confirm against how remote_dataset is built.
    for data_index in range(len(remote_dataset[0])):
        for remote_index in range(len(compute_nodes)):
            data, target = remote_dataset[remote_index][data_index]
            models[remote_index] = update(
                data, target, models[remote_index], optimizers[remote_index]
            )
    # Pull the trained models back from the remote workers.
    for model in models:
        model.get()
    # Federated averaging of the two worker models.
    return utils.federated_avg({
        "bob": models[0],
        "alice": models[1],
    })
def test(federated_model):
    """Evaluate the averaged model on the global test set.

    Prints the accuracy (percent) and the average binary-cross-entropy loss
    over `test_loader`.
    """
    federated_model.eval()
    correct = 0
    total = 0
    test_loss = 0
    with torch.no_grad():  # evaluation only -- no gradients needed
        for data, target in test_loader:
            output = federated_model(data)
            # Sum per-sample losses so the dataset-level average below is exact.
            test_loss += F.binary_cross_entropy(
                output.view(-1), target, reduction='sum'
            ).item()
            # Bug fix: the model has a single sigmoid output, so the original
            # argmax over dim 1 (`output.data.max(1)` / `torch.max(output, 1)`)
            # always returned class 0, and accuracy degenerated to the fraction
            # of zero targets. Threshold the probability at 0.5 instead.
            # (Also removed the unused, misspelled `predection` variable.)
            predicted = (output.view(-1) > 0.5).float()
            total += target.size(0)
            correct += (predicted == target).sum().item()
    print('Accuracy of the network: %f %%' % (100 * correct / total))
    test_loss /= len(test_loader.dataset)
    print('Test set: Average loss: {:.4f}'.format(test_loss))
# Main driver: one federated round per epoch, timing the full
# train + evaluate cycle (includes network round-trips to the workers).
for epoch in range(args.epochs):
    start_time = time.time()
    print(f"Epoch Number {epoch + 1}")
    # train() returns the federated average of the worker models.
    federated_model = train()
    # Keep a module-level alias to the most recent averaged model.
    model = federated_model
    test(federated_model)
    total_time = time.time() - start_time
    print('Communication time over the network', round(total_time, 2), 's\n')