I am building a framework for decentralized learning over a fully connected graph. A single node with batch size 256 should behave the same as 16 nodes with batch size 16 each (when the full dataset is distributed randomly among the 16 nodes). My code works for a single client. However, once I increase the number of clients, the loss stagnates at almost 2.3 (roughly ln(10), i.e. the cross-entropy of a uniform prediction over 10 classes) and only starts decreasing after a certain number of steps; how long the stagnation lasts depends on how much I increase the number of clients (4, 8, 16, ...), but all experiments stagnate at the same loss value. I do not understand why this is happening.
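To make precise what I mean by "should be the same": with plain SGD (no momentum), a shared initial point, and one local step per aggregation, averaging the weights of k nodes equals one step on the combined batch. Here is a toy, self-contained check of that identity (made-up quadratic loss, nothing from my framework):

import torch

torch.manual_seed(0)
k, d, lr = 16, 8, 0.1
w0 = torch.randn(d)
data = [torch.randn(16, d) for _ in range(k)]  # 16 samples per node

def grad(w, x):
    # gradient of the toy loss 0.5 * ||x @ w||^2 / len(x) w.r.t. w
    return x.t() @ (x @ w) / x.shape[0]

# 16 nodes, batch size 16 each, one SGD step, then weight averaging
w_nodes = [w0 - lr * grad(w0, x) for x in data]
w_avg = torch.stack(w_nodes).mean(0)

# one node, batch size 256
w_big = w0 - lr * grad(w0, torch.cat(data))

print(torch.allclose(w_avg, w_big, atol=1e-5))  # expected: True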
Client class:

import torch


class Client(object):
    def __init__(self, client_id, local_steps, task, learning_rate, batch_size, device, lambd, model,
                 train_loaders, test_loaders):
        self.client_id = client_id
        self.device = device
        self.local_steps = local_steps
        self.task = task
        self.learning_rate = learning_rate
        self.batch_size = batch_size
        # added for regularization
        self.lambd = lambd
        self.model = model.to(self.device)
        self.optimizer = torch.optim.SGD(self.model.parameters(), lr=self.learning_rate, momentum=0.0)
        self.criterion = torch.nn.CrossEntropyLoss()
        # train
        self.train_set = train_loaders
        self.train_loader = iter(self.train_set)
        # test
        self.test_set = test_loaders
        self.test_loader = iter(self.test_set)

    def train(self, model):
        self.model = model.to(self.device)
        ac = []
        los = []
        run_loss = 0
        for completed_steps in range(self.local_steps):
            correct = 0
            total = 0
            try:
                inputs, labels = next(self.train_loader)
            except StopIteration:
                # restart the loader once the local shard is exhausted
                print("except train")
                self.train_loader = iter(self.train_set)
                inputs, labels = next(self.train_loader)
            inputs, labels = inputs.to(self.device), labels.to(self.device)
            # zero the parameter gradients
            self.optimizer.zero_grad()
            # forward + backward + optimize
            outputs = self.model(inputs)
            loss = self.criterion(outputs, labels)
            loss.backward()
            self.optimizer.step()
            run_loss += loss.item()
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
            acc = 100. * correct / total
            ac.append(acc)
            los.append(run_loss)  # note: this appends the running sum, not the per-step loss
        acc = sum(ac) / len(ac)
        run_loss = sum(los) / len(los)
        results = {'clientId': self.client_id, 'update_weight': self.model.state_dict(), 'train_acc': acc}
        results['train_loss'] = run_loss
        return results
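For reference, this is roughly how a single client is driven. The model and data below are throwaway stand-ins just to show the call pattern, not my real datamodule:

from torch.utils.data import DataLoader, TensorDataset

# hypothetical stand-in data and model (CIFAR-10-shaped inputs, 10 classes)
xs, ys = torch.randn(64, 3 * 32 * 32), torch.randint(0, 10, (64,))
loader = DataLoader(TensorDataset(xs, ys), batch_size=16, shuffle=True)
net = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.Linear(3 * 32 * 32, 10))

client = Client(client_id=0, local_steps=1, task="cifar", learning_rate=0.01,
                batch_size=16, device=torch.device("cpu"), lambd=0.0,
                model=net, train_loaders=loader, test_loaders=loader)
results = client.train(net)
print(results["train_loss"], results["train_acc"])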
The class responsible for training clients:
import datetime
import logging
import time
from collections import OrderedDict

import torch
import wandb
from hydra.utils import instantiate
from omegaconf import OmegaConf

# initialize_weights, graph_generation and decentralized_average are defined elsewhere in the project
log = logging.getLogger(__name__)


class trainClients(object):
    def __init__(self, cfg, hydra_cfg):
        # self.device = 'cuda' if args.gpu else 'cpu'
        self.cfg = cfg
        self.device = torch.device('cuda:0') if self.cfg.device else torch.device('cpu')
        self.training_sets = self.test_dataset = None
        self.models = []
        self.optimizers = []
        self.epoch = 0
        self.client_id = 0
        self.hydra_cfg = hydra_cfg
        self.aggregation = False

    def train_clients(self):
        print("################ CIFAR TRAIN FUNCTION")
        # wandb logging
        if self.cfg.use_wandb:
            if str(self.hydra_cfg["split"]) == "non_iid_split":
                missing = self.cfg.split.missing
                a = str(self.cfg.split.a)
            else:
                missing = "_"
                a = "_"
            top_name = self.cfg.split.top_name
            run_name = ("DL_" + str(self.cfg.decentralized)
                        + "_Clt_" + str(self.cfg.num_clients)
                        + "_Rds_" + str(self.cfg.rounds)
                        + "_Stp_" + str(self.cfg.client_training.local_steps)
                        + "_Top_" + str(top_name) + "_a_" + a
                        + "_Opt_" + str(self.hydra_cfg["optim"])
                        + "_Data_" + str(self.hydra_cfg["datamodule"])
                        + "_Splt_" + str(self.hydra_cfg["split"])
                        + "_miss_" + str(missing))
            time_stamp = datetime.datetime.fromtimestamp(time.time()).strftime('%m%d_%H%M%S')
            print(wandb.config)
            wandb_run = instantiate(self.cfg.logger, id=run_name + '_' + "time_stamp= " + time_stamp)
            print(self.cfg.logger)
            self.cfg.client_training.client_id = self.client_id
            wandb.config = OmegaConf.to_container(self.cfg, resolve=True, throw_on_missing=True)
        datamodule = instantiate(self.cfg.datamodule)
        # train and test data
        log.info("load_and_split")
        train_sets, test_sets = datamodule.load_and_split(self.cfg)
        # create one model per client and the client objects
        log.info("Preparing the clients and the models and datasets..")
        clients = []
        self.cfg.model.num_classes = datamodule.num_classes
        for i in range(self.cfg.num_clients):
            print("i", i)
            # model
            model = instantiate(self.cfg.model)
            model.apply(initialize_weights)  # re-initialize this client's model
            self.models.append(model)
            # dataset
            train_loader = datamodule.data_loaders(train_sets)
            test_loader = datamodule.test_loader(test_sets)
            # client object
            self.cfg.client_training.client_id = self.client_id
            client = instantiate(self.cfg.client_training, device=self.device, model=model,
                                 train_loaders=train_loader, test_loaders=test_loader)
            clients.append(client)
            self.client_id += 1
            datamodule.next_client()
        self.client_id = 0
        datamodule.set_client()
        # connections between clients
        adjacency_matrix = graph_generation(self.cfg.num_clients)
        print(adjacency_matrix)
        log.info("############# Start training ###############")
        for t in range(self.cfg.rounds):
            log.info("####### This is ROUND number {} ######".format(t))
            datamodule.set_client()
            self.client_id = 0
            batch = 0
            batches = self.cfg.split.data_points / self.cfg.datamodule.batch_size
            log.info("number of batches per round: {}".format(int(batches)))
            test_acc, train_acc, test_loss, train_loss = [], [], [], []
            while batch < int(batches):
                datamodule.set_client()
                self.client_id = 0
                local_weights = OrderedDict()
                # for client in clients:
                for i in range(len(clients)):
                    results = clients[i].train(self.models[clients[i].client_id])
                    train_acc.append(results["train_acc"])
                    train_loss.append(results["train_loss"])
                    # test
                    test_results = clients[i].test()
                    test_acc.append(test_results["test_acc"])
                    test_loss.append(test_results["test_loss"])
                    if i == 0:
                        wandb.log({'client/train_loss': results["train_loss"],
                                   'client/test_loss': test_results["test_loss"],
                                   'client/train_accuracy': results["train_acc"],
                                   'client/test_accuracy': test_results["test_acc"]},
                                  step=batch + (t * int(batches)))
                # batch % 1 == 0 is always true, so this aggregates on every batch except batch 0 of each round
                if batch % 1 == 0 and batch != 0:
                    log.info("aggregation step ......")
                    c = 0
                    for model in self.models:
                        local_weights[c] = model.state_dict()
                        c = c + 1
                    self.models = decentralized_average(local_weights, self.models)
                train_acc_batch = sum(train_acc) / len(train_acc)
                train_loss_batch = sum(train_loss) / len(train_loss)
                test_acc_batch = sum(test_acc) / len(test_acc)
                test_loss_batch = sum(test_loss) / len(test_loss)
                wandb.log({'Batch/train_loss': train_loss_batch,
                           'Batch/test_loss': test_loss_batch,
                           'Batch/train_accuracy': train_acc_batch,
                           'Batch/test_accuracy': test_acc_batch},
                          step=batch + (t * int(batches)))
                test_acc, train_acc, test_loss, train_loss = [], [], [], []
                batch += 1
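Since the one-node/16-node equivalence assumes the averaging starts from a common point, one thing that can be checked is whether all client models hold identical parameters at a given moment. A small hypothetical helper (not in the framework):

def models_identical(models):
    # True iff every model's state dict matches models[0] exactly
    ref = models[0].state_dict()
    return all(torch.equal(ref[k], m.state_dict()[k])
               for m in models[1:]
               for k in ref)

Calling this right after the model-creation loop and again right after decentralized_average would show at which point the models first coincide.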
Aggregation function:

import copy

import torch


def decentralized_average(w, models):
    """
    Returns the average of the weights.
    This function is correct only for a fully connected graph.
    """
    weights = list(w.values())
    w_avg = copy.deepcopy(weights[0])
    for key in w_avg.keys():  # keys are layer names
        for i in range(1, len(weights)):  # sum this layer's weights over all clients
            w_avg[key] += weights[i][key]
        w_avg[key] = torch.div(w_avg[key], len(weights))  # divide to get the average
    for i in range(len(models)):
        models[i].load_state_dict(w_avg)  # all n models now share the averaged weights
    return models
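The restriction in the docstring matters: for a general topology, the update would instead mix with the row-stochastic matrix that graph_generation returns. A sketch of that general rule (not code I currently run; integer buffers such as BatchNorm's num_batches_tracked would need special-casing before load_state_dict):

def gossip_average(state_dicts, W):
    """Mixing step for an arbitrary topology: client i receives
    sum_j W[i, j] * w_j. With W[i, j] = 1/n (complete graph with
    self-loops, row-normalized) this reduces to decentralized_average."""
    n = len(state_dicts)
    mixed_states = []
    for i in range(n):
        mixed = copy.deepcopy(state_dicts[0])
        for key in mixed:
            # float(W[i, j]) because W comes back as a numpy matrix
            mixed[key] = sum(float(W[i, j]) * state_dicts[j][key]
                             for j in range(n))
        mixed_states.append(mixed)
    return mixed_states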
Topology:

import networkx
import numpy as np


def graph_generation(num_clients):
    seed = 31
    np.random.seed(seed)
    # p=1 makes the binomial (Erdos-Renyi) graph a complete graph, i.e. fully connected
    G = networkx.binomial_graph(num_clients, p=1, seed=42)
    # add self-loops via the identity, then row-normalize to get mixing weights
    adjacency_matrix = networkx.adjacency_matrix(G).todense() + np.identity(num_clients)
    graph_drawing(G, name="binomial_graph")
    return adjacency_matrix / adjacency_matrix.sum(1)
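For example, with num_clients = 4 every row of the returned mixing matrix is uniform, so each client weights every model (including its own) equally (graph_drawing is the plotting helper from my project):

print(graph_generation(4))
# [[0.25 0.25 0.25 0.25]
#  [0.25 0.25 0.25 0.25]
#  [0.25 0.25 0.25 0.25]
#  [0.25 0.25 0.25 0.25]]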
[Training-loss plot: blue = 16 clients, green = 1 client]