How to properly use Cross Entropy with weights for node classification

I am programming my first GNN and want to do node classification. However, I feel like my predictions are not being trained properly, and I think it has to do with the Cross Entropy Loss.

I want to perform a binary classification on every node in my Graph. I have two classes, 0 and 1. About 75% of the nodes belong to class 0 and 25% to class 1.
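
A common recipe for turning such a split into per-class weights is inverse class frequency; as a sketch (the labels here are illustrative, not my real data):

import torch

# Illustrative labels with roughly the 75/25 split described above.
y = torch.tensor([0, 0, 0, 1] * 25)               # 75 zeros, 25 ones
counts = torch.bincount(y, minlength=2).double()  # nodes per class: [75., 25.]
class_weight = counts.sum() / (2 * counts)        # inverse frequency: [0.6667, 2.0]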

This is the Network:

import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv

class Network(torch.nn.Module):
    def __init__(self, hidden_channels):
        super(Network, self).__init__()
        self.GCNConv1 = GCNConv(in_channels=6, out_channels=hidden_channels)
        self.GCNConv2 = GCNConv(in_channels=hidden_channels, out_channels=hidden_channels)
        self.GCNConv3 = GCNConv(in_channels=hidden_channels, out_channels=2)  # 2 output classes
        
    def forward(self, x, edge_index, edge_weight):
        edge_weight = 7.91 - edge_weight  # rescale the edge weights

        h = self.GCNConv1(x, edge_index, edge_weight)
        h = F.relu(h)
        h = self.GCNConv2(h, edge_index, edge_weight)
        h = F.relu(h)
        h = self.GCNConv3(h, edge_index, edge_weight)
        return h
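
For context, this is how the network is called; the values below are made up, only the shapes matter:

import torch

net = Network(hidden_channels=32).double()
x = torch.randn(4, 6).double()               # 4 nodes with 6 features each
edge_index = torch.tensor([[0, 1, 2, 3],
                           [1, 2, 3, 0]])    # 4 directed edges
edge_weight = torch.rand(4).double()         # one weight per edge
out = net(x, edge_index, edge_weight)        # logits of shape [4, 2]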

And the training process looks like this, where preds has shape [num_nodes, 2] and batch.y has shape [num_nodes], holding 0 or 1 for the class of each node:

import torch_geometric

dataset = SimulatinData("")
dataset = dataset.shuffle() 
train_set = dataset[:25000]
batch_size = 100

network = Network(hidden_channels = 32).double()
train_loader = torch_geometric.data.DataLoader(train_set, batch_size = batch_size, shuffle = True)
optimizer = torch.optim.Adam(network.parameters(), lr=0.01)

for epoch in range(10):
    for batch in train_loader:
        
        network.train()
        preds = network(batch.x, batch.edge_index, batch.edge_weights)
        
        class_weight = torch.tensor([1, 4]).double()  # intended as per-class weights
        loss = F.binary_cross_entropy_with_logits(preds, F.one_hot(batch.y).double(), weight=class_weight)

        loss.backward() # Calculate Gradients
        optimizer.step() # Update Weights
        optimizer.zero_grad() # Delete Gradients from previous loop
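
To check the hard predictions, I take the argmax of the logits over the class dimension; a sketch of the check I run after each epoch:

# Quick check on the hard predictions for the last batch.
pred_labels = preds.argmax(dim=1)                  # predicted class per node
print(torch.bincount(pred_labels, minlength=2))    # node counts per predicted class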

No matter how many epochs I train my network for, the predictions always look like this:

epoch:0
tensor([[ 0.6972, -1.0346],
        [ 0.6972, -1.0346],
        [ 0.6972, -1.0346],
        ...,
        [ 0.6972, -1.0346],
        [ 0.6972, -1.0346],
        [ 0.6972, -1.0346]], dtype=torch.float64, grad_fn=<AddBackward0>)
epoch:1
tensor([[ 0.7204, -0.9803],
        [ 0.7204, -0.9803],
        [ 0.7204, -0.9803],
        ...,
        [ 0.7204, -0.9803],
        [ 0.7204, -0.9803],
        [ 0.7204, -0.9803]], dtype=torch.float64, grad_fn=<AddBackward0>)
epoch:100
tensor([[ 0.9839, -0.7851],
        [ 0.9839, -0.7851],
        [ 0.9839, -0.7851],
        ...,
        [ 0.9839, -0.7851],
        [ 0.9839, -0.7851],
        [ 0.9839, -0.7851]], dtype=torch.float64, grad_fn=<AddBackward0>)

So it just assumes every node belongs to category 0. The network is “correct” for 75% of the nodes, but that is of course not the prediction I would like to have. No matter what class_weight I choose for the cross entropy, I always get these predictions where the GNN says every node belongs to class 0. Do you know what could be going wrong in my code?

The weight argument in nn.BCEWithLogitsLoss and F.binary_cross_entropy_with_logits is a per-sample weight, not a per-class weight. From the docs:

weight (Tensor, optional) – a manual rescaling weight given to the loss of each batch element. If given, has to be a Tensor of size nbatch.

so you would need to calculate this tensor for each batch using the current targets (or use the pos_weight argument, if applicable).
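
A minimal sketch of both routes, assuming two classes and integer labels like your batch.y (the tensors below are stand-ins with the shapes from your question):

import torch
import torch.nn.functional as F

num_nodes = 6
preds = torch.randn(num_nodes, 2).double()    # stand-in for the [num_nodes, 2] logits
y = torch.randint(0, 2, (num_nodes,))         # stand-in for batch.y

# Route 1: build the per-sample weight tensor from the current targets,
# so every class-1 node contributes 4x to the loss.
class_weight = torch.tensor([1.0, 4.0]).double()
sample_weight = class_weight[y].unsqueeze(1)  # [num_nodes, 1], broadcasts over both columns
loss = F.binary_cross_entropy_with_logits(
    preds, F.one_hot(y, num_classes=2).double(), weight=sample_weight)

# Route 2: switch to a single logit per node and use pos_weight,
# which rescales only the positive (class 1) term of the loss.
logit = torch.randn(num_nodes).double()       # stand-in for a [num_nodes] output
loss = F.binary_cross_entropy_with_logits(
    logit, y.double(), pos_weight=torch.tensor(4.0).double())

Note that Route 2 also means changing the last layer to a single output channel.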

Thanks for your answer. Actually, my code already computes the class_weight for each batch: class_weight = torch.tensor([1, batch.num_nodes/(batch_size*2)]).double(), where batch.num_nodes/(batch_size*2) is roughly 4. But I still have the same problem.

It could be that my model is just bad, and that's fine. But no matter how much I change the weights (for example, class_weight = torch.tensor([1, 99999])), the network always predicts every node as class 0. That makes me think the weights are not being passed properly. Shouldn't the prediction flip to every node being class 1 with such an extreme class_weight?

For clarification:

the predictions of my network are:

In[1]: preds = network(batch.x, batch.edge_index, batch.edge_weights)
      preds

Out[1]:
tensor([[ 0.9925, -1.0165], <-- all nodes predicted as class 0
        [ 0.9925, -1.0165],
        [ 0.9925, -1.0165],
        ...,
        [ 0.9925, -1.0165],
        [ 0.9925, -1.0165],
        [ 0.9925, -1.0165]], dtype=torch.float64, grad_fn=<AddBackward0>)

And my labels are:

In[2]: F.one_hot(batch.y)
Out[2]:
tensor([[1, 0], <-- first node is class 0
        [1, 0],
        [0, 1], <-- third node is class 1
        ...,
        [0, 1],
        [1, 0],
        [1, 0]])

Now I want to calculate the loss, and I want to penalize the network if it makes false predictions on the nodes from category 1. Is this the right way? Because no matter how I change the class_weight, the prediction stays the same…

class_weight = torch.tensor([1, 4]).double()
loss = F.binary_cross_entropy_with_logits(preds, F.one_hot(batch.y).double(), weight=class_weight)
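
For what it's worth, a weight of shape [2] broadcast against a [num_nodes, 2] one-hot target scales both the positive and the negative term of each logit column by the same factor, so it only rescales that column's loss by a constant and never pushes the decision toward class 1; that would explain why even torch.tensor([1, 99999]) leaves the argmax unchanged. If a true per-class weight is the goal with these shapes, one route not used in the thread is F.cross_entropy, whose weight argument really is indexed by class; a minimal sketch with stand-in tensors:

import torch
import torch.nn.functional as F

preds = torch.randn(5, 2).double()   # stand-in for the [num_nodes, 2] logits
y = torch.randint(0, 2, (5,))        # stand-in for batch.y (integer class labels)

# Here `weight` has one entry per class: class 0 -> 1.0, class 1 -> 4.0.
class_weight = torch.tensor([1.0, 4.0]).double()
loss = F.cross_entropy(preds, y, weight=class_weight)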