Normalized cut loss function, update network weights

Hi

I’m interested in implementing the normalized cut criterion as a loss function, and I want this normalized cut loss to be incorporated into back-propagation. For example, given the logit output of a network pushed through F.softmax(logits), I want to apply the normalized cut criterion to the predicted labels of the softmax output. I’m using Deep Graph Library with the PyTorch backend.

See below for my implementation of the NCut loss. I want to compute the gradient of the loss with respect to the network weights and parameters. Other forum threads mention that PyTorch does not explicitly build a one-hot encoding matrix for its losses, but that the scatter_ method might help here; I’m not sure how to proceed.
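For the scatter_ step in isolation, this is what I understand it to do (a minimal sketch with toy shapes, not tied to my network): scattering 1s along dim=1 at the argmax indices turns the per-node predictions into a one-hot matrix.

import torch

logits = torch.randn(5, 3)                           # 5 nodes, 3 classes (toy values)
max_idx = torch.argmax(logits, dim=1, keepdim=True)  # shape (5, 1)
one_hot = torch.zeros_like(logits)
one_hot.scatter_(1, max_idx, 1.0)                    # dim=1: one column per node set to 1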

import torch
import torch.nn.functional as F

def normalized_cut(graph, logits):
    """
    Parameters:
    - - - - - - - - -
    graph: DGL graph
        graph structure of data
    logits: torch tensor, float
        output of network

    Returns:
    - - - -
    loss: torch float tensor
        ncut loss value
    """

    # sparse adjacency matrix (torch sparse COO) and per-node degrees
    A = graph.adjacency_matrix()
    d = torch.sparse.sum(A, dim=0).to_dense()

    # compute the maximum-probability class for each node
    max_idx = torch.argmax(F.softmax(logits, dim=1), dim=1, keepdim=True)

    # one-hot encoding of the predicted class per node (N x K);
    # scatter along dim=1 so each row gets a single 1 in its class column
    one_hot = torch.zeros_like(logits)
    one_hot.scatter_(1, max_idx, 1.0)

    # within-cluster association (K x K) and cluster degrees (K,)
    assoc = torch.matmul(one_hot.t(), torch.sparse.mm(A, one_hot))
    degree = torch.matmul(d, one_hot)

    # sum over clusters of assoc(k, k) / degree(k); nansum skips empty clusters
    loss = torch.nansum(assoc.diag() / degree)

    return -loss
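For reference, here is a minimal call sketch (assuming a DGL version where graph.adjacency_matrix() returns a torch sparse COO tensor, as the function above expects); it also illustrates the problem I’m stuck on:

import dgl
import torch

# toy graph: 4 nodes, undirected edges added in both directions
src = torch.tensor([0, 1, 1, 2, 2, 3])
dst = torch.tensor([1, 0, 2, 1, 3, 2])
g = dgl.graph((src, dst))

logits = torch.randn(4, 2, requires_grad=True)  # stand-in for a 2-class GNN output
loss = normalized_cut(g, logits)
# loss.backward() fails here: argmax and the freshly created one_hot tensor
# have no grad_fn, so no differentiable path connects logits to the loss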

Any help is appreciated. Thanks.

*** EDIT ***

One work-around would be to relax the discrete one-hot constraint, as noted in this multi-class spectral clustering paper: rather than working with the one-hot encoded matrix, work directly with the softmax probabilities, so that the earlier gradient computation steps are retained?
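Concretely, something like the sketch below is what I have in mind: it mirrors normalized_cut above, but uses the softmax probabilities as soft cluster assignments instead of the one-hot matrix (soft_normalized_cut is just a placeholder name, and the same sparse-adjacency assumption applies):

import torch
import torch.nn.functional as F

def soft_normalized_cut(graph, logits):
    """Relaxed NCut: softmax probabilities as soft cluster assignments."""
    A = graph.adjacency_matrix()               # sparse N x N adjacency
    d = torch.sparse.sum(A, dim=0).to_dense()  # node degrees, shape (N,)

    S = F.softmax(logits, dim=1)               # soft assignments, N x K

    assoc = torch.matmul(S.t(), torch.sparse.mm(A, S))  # K x K soft association
    degree = torch.matmul(d, S)                          # soft cluster degrees, (K,)

    return -torch.nansum(assoc.diag() / degree)

Since every step here is differentiable, loss.backward() should propagate through F.softmax into the network weights, and the hard argmax would only be used at inference time to read off the partition. Does that sound like a reasonable approach?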