Why is my model (using GATv2Conv) stuck on its loss? (Predicting the same value for every node)

I am trying to build a geometric deep learning model using PyTorch (PyTorch Geometric). I have around 5000 graphs, which are split into a training, validation and test set. Every graph is fully connected and has exactly one 'correct' node. Several different node and edge features are used. Before being loaded into the model, the graphs are split into batches.

This is an example of what a batch from the training loader looks like:

DataBatch(x=[1404, 5], edge_index=[2, 14700], edge_attr=[2], y=[1], batch=[1404], ptr=[65])

And here is an example of x and y together with their shapes:

torch.Size([1406, 5])
tensor([[ 0.7833,  0.1309, -0.0708, -0.0496,  0.7143],
    [ 0.9170, -0.0228, -0.2538,  0.0542, -0.0476],
    [ 0.8326, -0.2361, -0.0492, -0.1727, -0.9048],
    ...,
    [ 0.9281,  0.0207, -0.1396,  0.0936,  0.1429],
    [ 0.9427,  0.8991, -0.1633,  0.0857, -1.0000],
    [ 0.8982,  0.0480, -0.4320,  0.0886, -0.2381]])
torch.Size([1406])
tensor([0., 0., 0.,  ..., 1., 0., 0.])
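
For reference, the loaders are created with my own helper functions (create_train_batches / create_val_test_batches, called in the code below). Simplified, they do roughly the following; the batch size of 64 is an assumption based on ptr=[65] in the batch above:

from torch_geometric.loader import DataLoader

# `graph_list` stands for a list of torch_geometric.data.Data objects, one per graph,
# built from the node and edge feature columns listed further down
train_loader = DataLoader(graph_list, batch_size=64, shuffle=True)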

This is the model I am currently using:

import time

import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from sklearn.metrics import f1_score
from torch_geometric.nn import GATv2Conv
from torch_geometric.utils import to_dense_batch

node_features = ['smooth_x', 'smooth_y', 'keeper', 'players_between', 'team']
edge_features = ['same_team', 'distance']

# Create the train, validation and test datasets
train_dataset, train_loader = create_train_batches(corner_graphs_train['Graph'], node_features, edge_features)
val_dataset, test_dataset, val_loader, test_loader = create_val_test_batches(corner_graphs_val['Graph'], corner_graphs_test['Graph'], node_features, edge_features)


class Net(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = GATv2Conv(train_dataset.num_features, 256, heads=8, dropout=0.3)
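        # heads=8 with the default concat=True concatenates the attention heads, so conv1 outputs 8 * 256 features per node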

        self.conv2 = GATv2Conv(8 * 256, 256, heads=8, dropout=0.3)
        
        # Add dropout layer
        self.dropout = torch.nn.Dropout(p=0.3)

        self.conv3 = GATv2Conv(8 * 256, 1)


    def forward(self, x, edge_index, batch):
        # Layer 1
        x = F.leaky_relu(self.conv1(x, edge_index))

        # Apply dropout
        x = self.dropout(x)

        # Layer 2
        x = F.leaky_relu(self.conv2(x, edge_index))

        # Apply dropout
        x = self.dropout(x)

        # Layer 3
        x = F.leaky_relu(self.conv3(x, edge_index))

        # Reshape the flat [num_nodes, 1] node output into [num_graphs, max_nodes_per_graph, 1]
        # (64 graphs per batch), so each row holds the scores of one graph's nodes, then drop the trailing dim
        x, mask = to_dense_batch(x, batch)
        x = x.squeeze(-1)
        return x


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)
model = Net().to(device)
#loss_op = torch.nn.CrossEntropyLoss()
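# BCEWithLogitsLoss applies the sigmoid internally, so it expects the raw per-node scores returned by the model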
loss_op = torch.nn.BCEWithLogitsLoss()
#optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=5e-4)
optimizer = torch.optim.AdamW(model.parameters(), lr=0.001, weight_decay=5e-4)


def train():
    model.train()

    total_loss = 0
    data_progress = 0
    for data in train_loader:
        #print('Data progress: {}/{}'.format(data_progress, len(train_loader)))
        optimizer.zero_grad()
        output = model(data.x.to(device), data.edge_index.to(device), data.batch.to(device))
        if data_progress == 1:
            print('Example prediction')
            print(list(output[0].float()))
        data_progress += 1
        truth, mask = to_dense_batch(data.y.to(device), data.batch.to(device))
        #print(output)
        #print(truth)
        loss = loss_op(output, truth)                 
        total_loss += loss.item() * data.num_graphs
        loss.backward()
        optimizer.step()
    return total_loss / len(train_loader.dataset)


@torch.no_grad()
def val(loader):
    model.eval()

    total_loss_val = 0
    ys, preds = [], []
    for data in loader:
        truth, mask = to_dense_batch(data.y.to(device), data.batch.to(device))
        ys.append(truth.float().cpu())
        out = model(data.x.to(device), data.edge_index.to(device), data.batch.to(device))
        preds.append((out > 0).float().cpu())

        loss = loss_op(out, truth)
        total_loss_val += loss.item() * data.num_graphs

    y, pred = torch.cat(ys, dim=0).numpy(), torch.cat(preds, dim=0).numpy()

    return f1_score(y, pred, average='samples'), total_loss_val / len(loader.dataset)

@torch.no_grad()
def test(loader):
    model.eval()

    ranks = []
    for data in loader:
        out = model(data.x.to(device), data.edge_index.to(device), data.batch.to(device))
        pred_probs = list(torch.softmax(out[0], dim=0).flatten().cpu().numpy())
        pred_ranks = pd.DataFrame({'probabilities': pred_probs, 'y': data.y})
        pred_ranks.sort_values('probabilities', ascending=False, inplace=True)
        pred_ranks['Rank'] = range(1, len(pred_ranks) + 1)
        rank_predicted = pred_ranks[pred_ranks['y'] == 1]['Rank'].values[0]

        ranks.append(rank_predicted)

    return np.mean(ranks)


times = []
temp_loss_list = []  # training loss per epoch, used for the early-stopping check below
for epoch in range(1, 200):
    start = time.time()
    loss = train()
    temp_loss_list.append(loss)
    print(f'Time: {time.time() - start:.4f}s')
    val_f1, loss_val = val(val_loader)
    print(f'Time: {time.time() - start:.4f}s')
    test_avg_rank = test(test_loader)
    print(f'Time: {time.time() - start:.4f}s')
    
    print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Val_f1: {val_f1:.4f}, '
        f'Val Loss: {loss_val:.4f}, Test Avg Rank: {test_avg_rank:.4f}')
    
    # If the training loss has increased five epochs in a row, stop training
    if epoch > 5 and all(temp_loss_list[-i] > temp_loss_list[-i - 1] for i in range(1, 6)):
        break

    times.append(time.time() - start)
print(f"Median time per epoch: {torch.tensor(times).median():.4f}s")
        

This gives the following output, where it seems like the model does not learn anything and, most of the time, predicts the same value for all 22 nodes in a graph.

Epoch: 001, Loss: 0.3710, Val_f1: 0.0000, Val Loss: 0.2000, Test Avg Rank: 11.2676
[-0.0572, -0.0553, -0.0580, -0.0571, -0.0573, -0.0556, -0.0582, -0.0534, -0.0637, -0.0579, -0.0575, -0.0573, -0.0566, -0.0561, -0.0570, -0.0580, -0.0579, -0.0579, -0.0579, -0.0579, -0.0567, -0.0579, -0.0530]

Epoch: 002, Loss: 0.1981, Val_f1: 0.0000, Val Loss: 0.1995, Test Avg Rank: 11.5412
[-2.1348, -2.1348, -2.1348, -2.1348, -2.1348, -2.1348, -2.1348, -2.1348, -2.1348, -2.1348, -2.1348, -2.1348, -2.1348, -2.1348, -2.1348, -2.1348, -2.1348, -2.1348, -2.1348, -2.1348, -2.1348, -2.1348]

Epoch: 003, Loss: 0.1978, Val_f1: 0.0000, Val Loss: 0.1990, Test Avg Rank: 11.9824
[-3.7584, -3.7584, -3.7535, -3.7191, -3.7683, -3.7191, -4.0502, -3.7632, -3.7683, -3.7683, -3.7587, -3.7588, -3.7621, -3.7621, -4.8226, -3.7588, -3.7191, -3.7683, -3.7585, -3.7587, -3.7587, -3.7191]

Epoch: 004, Loss: 0.1973, Val_f1: 0.0000, Val Loss: 0.1989, Test Avg Rank: 12.0324
[-2.0917, -2.0917, -2.0917, -2.0917, -2.0917, -2.0917, -2.0917, -2.0917, -2.0917, -2.0917, -2.0917, -2.0917, -2.0917, -2.0917, -2.8775, -2.0917, -2.0917, -2.0917, -2.0917, -2.0917, -2.0917, -2.0917]

Epoch: 005, Loss: 0.1971, Val_f1: 0.0000, Val Loss: 0.2002, Test Avg Rank: 11.8971
[-2.6774, -2.6774, -2.6774, -2.6774, -2.6774, -2.6774, -2.6774, -2.6774, -2.6774, -2.6774, -2.6774, -2.6774, -2.6774, -2.6774, -3.2843, -2.6774, -2.6774, -2.6774, -3.3673, -2.6774, -2.6774, -2.6774]

Epoch: 006, Loss: 0.1971, Val_f1: 0.0000, Val Loss: 0.1990, Test Avg Rank: 12.1912
[-3.0155, -3.0155, -3.0155, -3.0155, -3.0155, -3.0155, -3.0155, -3.0155, -3.0155, -3.0155, -3.0155, -3.0155, -3.0155, -3.0155, -3.0155, -3.0155, -3.0155, -3.0155, -3.0155, -3.0155, -3.0155, -3.0155]

Epoch: 007, Loss: 0.1961, Val_f1: 0.0000, Val Loss: 0.1986, Test Avg Rank: 12.2324
[-3.1592, -3.1592, -3.1585, -3.1592, -3.1521, -3.1521, -3.1585, -3.1585, -3.1592, -3.1521, -3.1592, -3.1592, -3.1592, -3.1592, -3.1592, -3.1592, -3.1585, -3.1592, -3.1589, -5.4095, -3.1592, -3.1592]

I have tried just about everything: different layer architectures, learning rates, feature combinations and optimizers, but nothing seems to work.

I would really appreciate it if someone could point out what I have done wrong.

If anything should be added to this question (x shapes at specific points, gradients (they are extremely low), or other info), or if anything else would help answer it, please let me know.
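
Since the gradients seem relevant: this is roughly how I inspected them (a minimal sketch; print_grad_norms is just a helper I call right after loss.backward() in train()):

def print_grad_norms(model):
    # Print the L2 norm of the gradient of every parameter tensor
    for name, param in model.named_parameters():
        if param.grad is not None:
            print(f'{name}: grad norm = {param.grad.norm().item():.6f}')

The printed norms are what I mean by "extremely low" above.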

Your batch example says y=[1], but in your shape example the shape is [1406]. Is it possible that the batch loaded by your dataloader only returns a single value for the ground truth y, when it should actually return the one-hot encoding?

Hi Ruben,

Thanks for the response. There was indeed a small error in the preprocessing of the target labels. In the rest of my code I took y[0], which gave the correct targets. I have now fixed it so that I can simply take y. But since the targets were already correct, this does not solve the problem.
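
For completeness, this is roughly the sanity check I ran after the fix (a minimal sketch; it only verifies that there is one target per node and exactly one positive node per graph):

for data in train_loader:
    # one target value per node
    assert data.y.shape[0] == data.x.shape[0]
    # exactly one 'correct' node per graph in the batch
    assert int(data.y.sum().item()) == data.num_graphs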