Ensemble model in PyTorch: the correct backpropagation

I am trying to build an ensemble of two graph convolution models, but the training performance is very poor.
Is there anything wrong with my code?

import torch
import torch.nn as nn
import torch.nn.functional as F


class MyEnsemble(nn.Module):
    def __init__(self, params, device='cpu'):
        super(MyEnsemble, self).__init__()
        self.model_1 = GCNNet(params.in_dim1, params.hidden_dim1, params.readout, device)
        self.model_2 = GCNNet(params.in_dim2, params.hidden_dim2, params.readout, device)
        # Output layer: its input is the concatenation of both readout vectors
        self.out = nn.Linear(params.hidden_dim1[-1] + params.hidden_dim2[-1], params.n_classes)

        self.device = device
        if self.device != "cpu":
            self.cuda()

    def forward(self, x_1, x_2):  # x_1 and x_2: graph data
        x1 = self.model_1(x_1.clone())  # clone to make sure x is not changed in-place
        x2 = self.model_2(x_2.clone())
        x = torch.cat((x1, x2), dim=1)  # concatenate model_1 and model_2 readout layers
        output = F.softmax(self.out(x), dim=1)
        return output

    def loss(self, pred, label):
        criterion = nn.CrossEntropyLoss()
        loss = criterion(pred, label)
        return loss

nn.CrossEntropyLoss expects raw logits as the model output, so you should remove the F.softmax call in forward.
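
Something like this should work (only forward changes; the rest of MyEnsemble stays as you posted):

    def forward(self, x_1, x_2):  # x_1 and x_2: graph data
        x1 = self.model_1(x_1.clone())
        x2 = self.model_2(x_2.clone())
        x = torch.cat((x1, x2), dim=1)
        return self.out(x)  # raw logits; nn.CrossEntropyLoss applies log-softmax internally

If you need class probabilities at inference time, apply F.softmax(logits, dim=1) outside of the loss computation.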

Thank you so much.
I will remove the softmax and keep the linear output layer.

I have tried your suggestion, but the performance is still the same.
I guess something else is causing this issue.

Try to overfit a small dataset, e.g. just 10 samples, by playing around with the hyperparameters.
Once your model is able to do so, try to scale up the use case again. If your model is not able to perfectly learn these 10 samples, some other bug might be in the code which we haven't found yet.
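
As a rough sketch of what I mean (model, train_dataset, and the learning rate are placeholders for your setup):

    import torch
    from torch.utils.data import Subset, DataLoader

    # Keep only 10 samples and train until the model memorizes them.
    tiny_set = Subset(train_dataset, range(10))  # placeholder dataset
    tiny_loader = DataLoader(tiny_set, batch_size=10, shuffle=False)

    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)  # placeholder model
    criterion = torch.nn.CrossEntropyLoss()

    model.train()
    for epoch in range(500):
        for data, labels in tiny_loader:
            optimizer.zero_grad()
            logits = model(data)
            loss = criterion(logits, labels)
            loss.backward()
            optimizer.step()
        acc = (logits.argmax(dim=1) == labels).float().mean().item()
        if acc == 1.0:
            break  # the 10 samples are memorized, so the training pipeline can learn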

Thanks @ptrblck for your helpful suggestions.

I am actually running a hyper-parameter optimizer (Scikit-Optimize). However, you are right that I have not shared the full code.
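
The outer search is roughly along these lines (a simplified sketch; train_and_validate is a placeholder that wraps the training loop below and returns the best validation loss):

    from skopt import gp_minimize
    from skopt.space import Real, Integer

    # simplified search space; the real one also tunes the GCN hidden dims
    space = [
        Real(1e-4, 1e-2, prior='log-uniform', name='lr'),
        Integer(8, 64, name='batch_size'),
    ]

    def objective(hp):
        lr, batch_size = hp
        # placeholder: runs the training loop below with these settings
        return train_and_validate(lr=lr, batch_size=int(batch_size))

    result = gp_minimize(objective, space, n_calls=20, random_state=0)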

This is the main training loop:

# shuffle=False keeps the two modality loaders aligned sample-for-sample
train_loader_1 = DataLoader(train_dataset_1, batch_size=args.batch_size, shuffle=False, collate_fn=collate)
val_loader_1 = DataLoader(val_dataset_1, batch_size=len(val_dataset_1), shuffle=False, collate_fn=collate)
test_loader_1 = DataLoader(test_dataset_1, batch_size=len(test_dataset_1), shuffle=False, collate_fn=collate)

train_loader_2 = DataLoader(train_dataset_2, batch_size=args.batch_size, shuffle=False, collate_fn=collate)
val_loader_2 = DataLoader(val_dataset_2, batch_size=len(val_dataset_2), shuffle=False, collate_fn=collate)
test_loader_2 = DataLoader(test_dataset_2, batch_size=len(test_dataset_2), shuffle=False, collate_fn=collate)

model = MyEnsemble(args, device)
    
trainer = Trainer(model, args)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(trainer.optim, mode='min',
                                                     factor=args.lr_reduce_factor,
                                                     patience=args.lr_schedule_patience,
                                                     verbose=False)
    
train_losses, train_accs, val_losses, val_accs = [], [], [], []
for epoch in range(args.n_epochs):
    train_loss = 0
    train_acc = 0
    model.train()

    for bg_1, bg_2 in zip(train_loader_1, train_loader_2):
        batch_graphs_1, batch_labels_1, batch_graphs_2, batch_labels_2 = \
            bg_1[0], bg_1[1], bg_2[0], bg_2[1]
        batch_graphs_1.set_e_initializer(dgl.init.zero_initializer)
        batch_graphs_1.set_n_initializer(dgl.init.zero_initializer)
        batch_graphs_2.set_e_initializer(dgl.init.zero_initializer)
        batch_graphs_2.set_n_initializer(dgl.init.zero_initializer)
        loss, acc = trainer.iteration(batch_graphs_1, batch_graphs_2, batch_labels_1)
        train_loss += loss
        train_acc += acc
    train_loss /= len(train_loader_1)
    train_acc /= len(train_loader_1)
    train_losses.append(train_loss)
    train_accs.append(train_acc)
    trainer.save(epoch, args.save_dir)

    val_loss = 0
    val_acc = 0
    model.eval()
    for bg_1, bg_2 in zip(val_loader_1, val_loader_2):
        batch_graphs_1, batch_labels_1, batch_graphs_2, batch_labels_2 = \
            bg_1[0], bg_1[1], bg_2[0], bg_2[1]
        batch_graphs_1.set_e_initializer(dgl.init.zero_initializer)
        batch_graphs_1.set_n_initializer(dgl.init.zero_initializer)
        batch_graphs_2.set_e_initializer(dgl.init.zero_initializer)
        batch_graphs_2.set_n_initializer(dgl.init.zero_initializer)
        loss, acc = trainer.iteration(batch_graphs_1, batch_graphs_2, batch_labels_1, train=False)
        val_loss += loss
        val_acc += acc
    val_loss /= len(val_loader_1)
    val_acc /= len(val_loader_1)
    val_losses.append(val_loss)
    val_accs.append(val_acc)
    scheduler.step(val_loss)  # drive the ReduceLROnPlateau scheduler with the validation loss

This is my Trainer:

import os

import torch


class Trainer:
    def __init__(self, model, args):
        self.model = model
        self.device = args.device
        self.optim = torch.optim.Adam(self.model.parameters(), lr=args.lr)
        print('Total Parameters:', sum([p.nelement() for p in self.model.parameters()]))

    def iteration(self, g_1, g_2, labels, train=True):
        labels = labels.to(self.device)
        # skip gradient tracking during validation
        with torch.set_grad_enabled(train):
            scores = self.model(g_1, g_2)
            loss = self.model.loss(scores, labels)
        acc = accuracy(scores, labels)  # accuracy is a helper defined elsewhere

        if train:
            self.optim.zero_grad()
            loss.backward()
            self.optim.step()
        return loss.item(), acc

    def save(self, epoch, save_dir):
        output_path = os.path.join(save_dir, 'ep{:02}.pkl'.format(epoch))
        torch.save(self.model.state_dict(), output_path)

This is GCNNet:

import dgl
from dgl.nn.pytorch import GlobalAttentionPooling


class GCNNet(nn.Module):
    def __init__(self, in_dim, hidden_dims, readout, device="cpu"):
        super(GCNNet, self).__init__()

        self.readout = readout

        # Stack GCN layers: ReLU on all but the last layer, identity on the last.
        layers = []
        for i in range(len(hidden_dims)):
            layer_in = in_dim if i == 0 else hidden_dims[i - 1]
            activation = F.relu if i < len(hidden_dims) - 1 else (lambda x: x)
            layers.append(GCN(layer_in, hidden_dims[i], activation=activation))
        self.layers = nn.ModuleList(layers)

        if self.readout == "attn_pool":
            # GlobalAttentionPooling is a module: it takes a gate network at
            # construction time and is then called as pool(g, feat).
            self.attn_pool = GlobalAttentionPooling(nn.Linear(hidden_dims[-1], 1))

        self.device = device
        if self.device != "cpu":
            self.cuda()

    def forward(self, g):
        h = g.ndata['feat'].to(self.device)  # replace 'feat' with the key used in data_prep
        g = g.to(self.device)
        for conv in self.layers:
            h = conv(g, h)

        g.ndata['feat'] = h

        if self.readout == "sum":
            hg = dgl.sum_nodes(g, 'feat')
        elif self.readout == "max":
            hg = dgl.max_nodes(g, 'feat')
        elif self.readout == "mean":
            hg = dgl.mean_nodes(g, 'feat')
        elif self.readout == "attn_pool":
            hg = self.attn_pool(g, h)  # global attention pooling over node features
        else:
            hg = dgl.mean_nodes(g, 'feat')  # default readout: mean over nodes
        return hg