Multi-label Graph NN not learning when using custom loss function

Hello all, I am trying to build a multi-label graph neural network (NN). The goal is to predict the side-effects a drug causes in the human body. The input is a graph representing the drug, and the labels are the side-effects it causes, encoded as a binary (multi-hot) tensor. I am also using a custom loss function that uses word2vec to produce a numerical value for how similar two words are. The problem I am facing is that my NN is not learning.
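
For context, each drug's label is a multi-hot vector over all known side-effects. A toy version of the encoding looks roughly like this (sideEffectIDMap and the example side-effect names below are placeholders, not my real data):

import torch

# Hypothetical mapping from label index to side-effect name (placeholder values)
sideEffectIDMap = {0: "headache", 1: "nausea", 2: "dry mouth"}
nClasses = len(sideEffectIDMap)

def encode_side_effects(effect_indices):
    # Multi-hot vector: 1.0 at every side-effect the drug causes, 0.0 elsewhere
    y = torch.zeros(nClasses)
    y[list(effect_indices)] = 1.0
    return y

print(encode_side_effects([0, 2]))  # tensor([1., 0., 1.])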

My model:

import torch
import torch.nn.functional as F
from torch.nn import Linear
from torch_geometric.nn import GCNConv, global_mean_pool

class GCN(torch.nn.Module):
    def __init__(self, hidden_channels):
        super(GCN, self).__init__()
        torch.manual_seed(12345)
        self.conv1 = GCNConv(9, hidden_channels)  # 9 input node features
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        self.conv3 = GCNConv(hidden_channels, hidden_channels)
        self.lin = Linear(hidden_channels, nClasses)  # nClasses = number of side-effect labels
        self.classifier = torch.nn.Sequential(
            torch.nn.Sigmoid()  # element-wise sigmoid so each label gets its own probability
        )

    def forward(self, x, edge_index, batch):
        # 1. Obtain node embeddings 
        x = self.conv1(x, edge_index)
        x = x.relu()
        x = self.conv2(x, edge_index)
        x = x.relu()
        x = self.conv3(x, edge_index)

        # 2. Readout layer
        x = global_mean_pool(x, batch)  # [batch_size, hidden_channels]

        # 3. Apply a final classifier
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.lin(x)
        x = self.classifier(x)
        
        return x
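
For completeness, this is roughly how I set things up and run it (hidden_channels=64, lr=0.01 and threshold=0.5 are just placeholder values, and the DataLoader comes from torch_geometric):

# Sketch of the setup (placeholder hyperparameters)
model = GCN(hidden_channels=64)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
loss_fn = custom_loss_function  # defined below
threshold = 0.5

# For one batch `data` from a torch_geometric.loader.DataLoader:
#   data.x          -> [num_nodes, 9] node features
#   data.edge_index -> [2, num_edges]
#   data.batch      -> [num_nodes] graph-assignment vector
# out = model(data.x, data.edge_index, data.batch) has shape [batch_size, nClasses],
# with values in (0, 1) because of the final sigmoid.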

My loss function:

from word2vec_ISOLATED import word2vecModel

def custom_loss_function(output, target):
    # sideEffectIDMap maps a label index to its side-effect name (a string)
    res = []
    for graph_no in range(len(output)): # Iterate through batch
        currOutput = output[graph_no] # Get the output labels
        currTarget = target[graph_no] # Get the target labels

        # For both output and target labels, convert them into the actual strings/words and remove duplicates
        outputWords, targetWords = set(), set()
        for i, val in enumerate(currOutput):
            if val == 0: continue
            sideEffect = sideEffectIDMap[i].split()
            for s in sideEffect: outputWords.add(s)
        for i, val in enumerate(currTarget):
            if val == 0: continue
            sideEffect = sideEffectIDMap[i].split()
            for s in sideEffect: targetWords.add(s)
            
        # For each word in output, compare it to all words in target and save the highest similarity. 
        # 'word2vecModel.wv.similarity(w1,w2)' is the method that uses the word2vec NN to get the similarity.
        currSimilarities = []
        for w1 in outputWords:
            currSimilarities.append(max([word2vecModel.wv.similarity(w1, w2) for w2 in targetWords], default=0))

        # For this graph, save the average similarity of all words        
        res.append(torch.mean(torch.Tensor(currSimilarities)))
    
    # Average across all graphs in the batch, and subtract from 1 as we want the 'loss' instead of the 'similarity'.    
    finalRes = (1-torch.mean(torch.Tensor(res)))**2
    return finalRes
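
For example, calling it on a made-up batch of 2 graphs with nClasses = 3 (the values are made up, and sideEffectIDMap / word2vecModel are assumed to be loaded already):

dummy_output = torch.tensor([[1., 0., 1.],
                             [0., 1., 0.]])  # binarised model outputs
dummy_target = torch.tensor([[1., 1., 0.],
                             [0., 1., 0.]])  # ground-truth multi-hot labels
loss = custom_loss_function(dummy_output, dummy_target)  # returns a 0-dim tensor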

My Training Step:

def train(loader):
    model.train()
    for data in loader:  # Iterate in batches over the training dataset.
        out = model(data.x, data.edge_index, data.batch)  # Perform a single forward pass.
        out.detach().apply_(lambda x: 1 if x >= threshold else 0) # Convert to binary labels depending on whether the sigmoid probability crosses a certain threshold
        loss = loss_fn(out, data.y)  # Compute the loss. loss_fn = custom_loss_function
        loss.requires_grad = True # Is this step correct?
        loss.backward()  # Derive gradients.
        optimizer.step()  # Update parameters based on gradients.
        optimizer.zero_grad()  # Clear gradients.
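
And the outer loop that calls it, roughly (the epoch count is arbitrary, and train_loader is my torch_geometric DataLoader over the training set):

for epoch in range(1, 101):
    train(train_loader)  # one pass over the training set per epoch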

A few possible sources of error that I think may be causing the NN not to learn, but I am unsure about:

  1. I saw in other posts that a custom loss function must be differentiable. But if I am using an entirely different NN, and what's more, converting the integers to strings and back, how do I do this? I'm also not sure whether the step loss.requires_grad = True is correct.
  2. When testing with a standard loss function (e.g. CrossEntropyLoss), the model also does not learn. I have also tried playing around with the hyperparameters, to no avail. Could there be something wrong with my layers?
  3. Instead of doing the binary conversion in the train step (the lambda), I could use the raw sigmoid values as weights in the loss function (roughly as in the sketch after this list). However, I don't think this should be the root cause of the NN not learning, right?
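
A rough sketch of what I mean in point 3: keep the raw sigmoid probabilities and use them as weights instead of hard-thresholding to 0/1 (this still goes through word2vec, so it is still not differentiable; it only illustrates the weighting idea, and the variable names are placeholders):

def weighted_loss_sketch(output, target):
    res = []
    for graph_no in range(len(output)):
        currOutput, currTarget = output[graph_no], target[graph_no]
        targetWords = {s for i, v in enumerate(currTarget) if v != 0
                       for s in sideEffectIDMap[i].split()}
        weightedSims, weights = [], []
        for i, prob in enumerate(currOutput):  # prob = raw sigmoid output for label i
            for w1 in sideEffectIDMap[i].split():
                best = max([word2vecModel.wv.similarity(w1, w2) for w2 in targetWords], default=0)
                weightedSims.append(float(prob) * best)
                weights.append(float(prob))
        res.append(sum(weightedSims) / max(sum(weights), 1e-8))  # weighted average similarity
    return (1 - torch.tensor(res).mean()) ** 2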

I’m still very new to PyTorch, so I apologise in advance for any silly mistakes. Any and all help would be greatly appreciated, thank you :sweat_smile: