Thanks @ptrblck for your helpful suggestions.
I am actually running a hyper-parameter optimizer (Scikit-Optimize). You are right, though: I had not shared the full code.
This is the main training loop:
def _prepare_pair(bg_1, bg_2):
    """Unpack one batch from each modality and install zero feature initializers.

    Each ``bg_*`` is indexed as ``(graphs, labels)``. Only the labels of the
    first modality are returned, because the model is scored against
    ``batch_labels_1``.
    NOTE(review): this assumes both loaders yield samples in the same order
    (both use shuffle=False) -- confirm the two datasets are aligned.
    """
    batch_graphs_1, batch_labels_1 = bg_1[0], bg_1[1]
    batch_graphs_2 = bg_2[0]
    for graphs in (batch_graphs_1, batch_graphs_2):
        graphs.set_e_initializer(dgl.init.zero_initializer)
        graphs.set_n_initializer(dgl.init.zero_initializer)
    return batch_graphs_1, batch_graphs_2, batch_labels_1


# Validation/test loaders put the whole split into a single batch.
train_loader_1 = DataLoader(train_dataset_1, batch_size=args.batch_size, shuffle=False, collate_fn=collate)
val_loader_1 = DataLoader(val_dataset_1, batch_size=len(val_dataset_1), shuffle=False, collate_fn=collate)
test_loader_1 = DataLoader(test_dataset_1, batch_size=len(test_dataset_1), shuffle=False, collate_fn=collate)
model = MyEnsemble(args, device)
train_loader_2 = DataLoader(train_dataset_2, batch_size=args.batch_size, shuffle=False, collate_fn=collate)
val_loader_2 = DataLoader(val_dataset_2, batch_size=len(val_dataset_2), shuffle=False, collate_fn=collate)
test_loader_2 = DataLoader(test_dataset_2, batch_size=len(test_dataset_2), shuffle=False, collate_fn=collate)
trainer = Trainer(model, args)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    trainer.optim,
    mode='min',
    factor=args.lr_reduce_factor,
    patience=args.lr_schedule_patience,
    verbose=False,
)

train_losses, train_accs, val_losses, val_accs = [], [], [], []
for epoch in range(args.n_epochs):
    # ---- training ----
    model.train()
    train_loss, train_acc, n_train_batches = 0, 0, 0
    # zip() stops at the shorter loader, so count the batches actually used
    # instead of assuming len(train_loader_1).
    for bg_1, bg_2 in zip(train_loader_1, train_loader_2):
        batch_graphs_1, batch_graphs_2, batch_labels_1 = _prepare_pair(bg_1, bg_2)
        loss, acc = trainer.iteration(batch_graphs_1, batch_graphs_2, batch_labels_1)
        train_loss += loss
        train_acc += acc
        n_train_batches += 1
    train_loss /= max(n_train_batches, 1)
    train_acc /= max(n_train_batches, 1)
    train_losses.append(train_loss)
    train_accs.append(train_acc)
    trainer.save(epoch, args.save_dir)

    # ---- validation ----
    model.eval()
    val_loss, val_acc, n_val_batches = 0, 0, 0
    # No gradients are needed for evaluation; no_grad() avoids building
    # (and holding on to) autograd graphs for the whole validation split.
    with torch.no_grad():
        # Use this loop's own bg_2 (the original read a stale training batch).
        for bg_1, bg_2 in zip(val_loader_1, val_loader_2):
            batch_graphs_1, batch_graphs_2, batch_labels_1 = _prepare_pair(bg_1, bg_2)
            loss, acc = trainer.iteration(batch_graphs_1, batch_graphs_2, batch_labels_1, train=False)
            val_loss += loss
            val_acc += acc
            n_val_batches += 1
    val_loss /= max(n_val_batches, 1)
    val_acc /= max(n_val_batches, 1)
    val_losses.append(val_loss)
    val_accs.append(val_acc)

    # The plateau scheduler only reduces the LR if it is fed the metric.
    scheduler.step(val_loss)
This is my Trainer:
class Trainer:
    """Pair a two-input model with an Adam optimizer for train/eval steps.

    ``model`` is called as ``model(g_1, g_2)`` and must expose
    ``model.loss(scores, labels)``; ``args`` must provide ``device`` and ``lr``.
    """

    def __init__(self, model, args):
        self.model = model
        self.device = args.device
        self.optim = torch.optim.Adam(self.model.parameters(), lr=args.lr)
        print('Total Parameters:', sum(p.nelement() for p in self.model.parameters()))

    def iteration(self, g_1, g_2, labels, train=True):
        """Run one forward pass and, if ``train`` is true, one optimizer step.

        Returns ``(loss.item(), acc)``, where ``acc`` is the value returned by
        the externally defined ``accuracy(scores, labels)`` helper.
        """
        labels = labels.to(self.device)
        # Build the autograd graph only when we will backprop through it;
        # evaluation steps then don't keep activations alive.
        with torch.set_grad_enabled(train):
            # Call the module itself (not .forward) so registered hooks run.
            scores = self.model(g_1, g_2)
            loss = self.model.loss(scores, labels)
        acc = accuracy(scores, labels)
        if train:
            self.optim.zero_grad()
            loss.backward()
            self.optim.step()
        return loss.item(), acc

    def save(self, epoch, save_dir):
        """Write the model ``state_dict`` to ``<save_dir>/ep<epoch:02>.pkl``."""
        if save_dir:
            # Create the directory on first use instead of failing in torch.save.
            os.makedirs(save_dir, exist_ok=True)
        output_path = os.path.join(save_dir, 'ep{:02}.pkl'.format(epoch))
        torch.save(self.model.state_dict(), output_path)
This is GCNNet:
class GCNNet(nn.Module):
    """Stack of GCN layers followed by a graph-level readout.

    Args:
        in_dim: size of the input node features read from ``g.ndata['feat']``.
        hidden_dims: output size of each GCN layer. Every layer except the
            last uses ReLU; the last layer has no activation.
        readout: ``"sum"``, ``"max"``, ``"mean"`` or ``"attn_pool"``; any
            other value falls back to mean pooling.
        device: device the parameters are moved to and inputs are sent to.

    Raises:
        ValueError: if ``hidden_dims`` is empty.
    """

    def __init__(self, in_dim, hidden_dims, readout, device="cpu"):
        super(GCNNet, self).__init__()
        if len(hidden_dims) == 0:
            raise ValueError("hidden_dims must contain at least one layer size")
        self.readout = readout
        dims = [in_dim] + list(hidden_dims)
        last = len(hidden_dims) - 1
        layers = []
        for i in range(len(hidden_dims)):
            # Hidden layers use ReLU; the final layer is left linear (identity).
            activation = (lambda x: x) if i == last else F.relu
            layers.append(GCN(dims[i], dims[i + 1], activation=activation))
        self.layers = nn.ModuleList(layers)
        if readout == "attn_pool":
            # GlobalAttentionPooling is a module: build it once here, with a
            # learnable gate over the final hidden size, so it is trained and
            # moved together with the rest of the model.
            self.pool = GlobalAttentionPooling(nn.Linear(dims[-1], 1))
        self.device = device
        # .to(device) honours explicit indices such as "cuda:1", whereas
        # .cuda() always uses the current default GPU.
        self.to(self.device)

    def forward(self, g):
        """Return one embedding per graph in the (batched) graph ``g``."""
        # Node features are expected under 'feat' (adjust if the data prep
        # stores them under another key, e.g. 'h').
        h = g.ndata['feat'].to(self.device)
        g = g.to(self.device)
        for conv in self.layers:
            h = conv(g, h)
        # local_scope() discards the temporary 'feat' overwrite on exit, so the
        # caller's graph keeps its input features (g.to() may hand back the
        # caller's own graph object).
        with g.local_scope():
            g.ndata['feat'] = h
            if self.readout == "sum":
                hg = dgl.sum_nodes(g, 'feat')
            elif self.readout == "max":
                hg = dgl.max_nodes(g, 'feat')
            elif self.readout == "attn_pool":
                # Global attention pooling over the final node embeddings.
                hg = self.pool(g, h)
            else:
                # "mean" and any unrecognised readout both use mean pooling.
                hg = dgl.mean_nodes(g, 'feat')
        return hg