Differentiable Loss Function for Indices

Hi, I am trying to build a loss function based on a Gaussian distribution. In the calculation, this function uses indices of the matrix, which (I think) are breaking the graph. Can somebody please help with this? I am attaching the code for the loss function and the output when I print the grad_fn of each operation.

Code

class Gaussian(nn.Module):
    def __init__(self, var: float) -> None:
        """
        Initializes Gaussian metric.

        Args:
            var (float): Variance parameter.
        """

        super(Gaussian, self).__init__()
        self.var = var

    def forward(self, preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """
        Computes the Gaussian metric.

        Args:
            preds (torch.Tensor): Predictions.
            target (torch.Tensor): Target data.

        Returns:
            torch.Tensor: Gaussian metric value.
        """
        preds = torch.softmax(preds, dim=1)
        print('Preds:', preds.grad_fn)
        contact_matrix = preds[:, 1, :, :]
        # print('Contact matrix:', contact_matrix.grad_fn)
        contact_matrix = torch.ge(contact_matrix, 0.5).float()
        # print(contact_matrix.int())
        gaussian_metric_value = self.calculate_distances(contact_matrix, self.var)
        # print('Gaussian metric:', gaussian_metric_value.requires_grad)
        return gaussian_metric_value

    def calculate_distances(self, contact_matrix, var):
        """
        Calculate the distance matrix d_{i,j} for a given contact matrix.
        
        Parameters:
        contact_matrix (torch.Tensor): A torch tensor representing contacts (zeros and ones), here of shape (batch, H, W).
        var (float): The variance parameter for the distance calculation.
        
        Returns:
        torch.Tensor: The omega_sum value.
        """

        device = contact_matrix.device
        contact_matrix = contact_matrix.to(device='cpu')
        indices = torch.nonzero(contact_matrix == 1, as_tuple=False).float()
        # print('Indices:', indices.grad)
        indices.requires_grad = True
        indices.retain_grad()
        indices_diff = indices.unsqueeze(1) - indices.unsqueeze(0)
        # print('Indices:', indices.grad)
        
        indices_diff = indices_diff.pow(2)
        # print('Indices diff:', indices_diff.grad_fn)
        sq_diff = indices_diff.sum(dim=2)  # Shape: (num_contacts, num_contacts)
        # print('Sq diff:', sq_diff.grad_fn)
        d_matrix = torch.exp(-sq_diff/(2 * var))  # Shape: (num_contacts, num_contacts)
        # print('D matrix:', d_matrix.grad_fn)
        sum_exp = d_matrix.sum(dim=1)
        # print('Sum exp:', sum_exp.grad_fn)
        omega = 1 / sum_exp
        # print('Omega:', omega.grad_fn)
        omega_sum = omega.mean()
        omega_sum = -omega_sum.to(device=device) # negative sign for minimization
        return omega_sum

Output

Preds: <SoftmaxBackward0 object at 0x14e07ca72230>
Contact matrix: <SliceBackward0 object at 0x14e07ca72230>
Indices: None
Indices: None
Indices diff: <PowBackward0 object at 0x14e07ca72230>
Sq diff: <SumBackward1 object at 0x14e07ca72230>
D matrix: <ExpBackward0 object at 0x14e07ca72230>
Sum exp: <SumBackward1 object at 0x14e07ca72230>
Omega: <MulBackward0 object at 0x14e07ca72230>

Any insights on this would be beneficial.

Thanks

It is expected that indices are not differentiable. Do you observe that the gradients produced are not what you expect?

(You can use torch.autograd.gradcheck, see the gradcheck page in the PyTorch 2.3 documentation, to validate your gradients against numerically computed ones.)
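
For illustration, here is a minimal, self-contained sketch of how gradcheck can be called. The soft_distance function and the shapes below are made up for the example, not part of the original post; running gradcheck directly on the posted Gaussian loss would flag a problem, because the torch.ge/torch.nonzero steps disconnect the output from preds.

import torch
from torch.autograd import gradcheck

# Toy differentiable analogue of calculate_distances that works on float
# coordinates. gradcheck expects double precision and requires_grad=True inputs.
def soft_distance(coords, var=1.0):
    sq_diff = (coords.unsqueeze(1) - coords.unsqueeze(0)).pow(2).sum(dim=2)
    omega = 1.0 / torch.exp(-sq_diff / (2 * var)).sum(dim=1)
    return omega.mean()

coords = torch.randn(5, 2, dtype=torch.double, requires_grad=True)
print(gradcheck(soft_distance, (coords,)))  # True if analytical and numerical grads agree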

Thanks for the reply, but the problem is that I don’t know what to expect. However, I can still try the method you suggested.

I want the whole process to be differentiable so that backpropagation happens properly, which won’t happen as long as there are some non-differentiable parts.

Hey @ptrblck, can you please suggest something here? I couldn’t find anything about this. Any insights would be helpful.

Indices are not usefully differentiable, since their gradients would be zero everywhere and undefined or Inf at the integer jumps.
I don’t know what your use case is exactly, but keep in mind that the actual values of contact_matrix (assuming these are floating-point numbers) are differentiable, so you can use the indices tensor to index contact_matrix and perform further computations on those values.
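
To make that suggestion concrete, here is a rough sketch of one way to read it. The function name, threshold handling, and the weighting scheme below are my own assumptions, not something confirmed in the thread: the hard indices are used only to select positions, while the gathered soft probabilities carry the gradient back to preds.

import torch

def soft_gaussian_metric(preds: torch.Tensor, var: float = 1.0, thr: float = 0.5) -> torch.Tensor:
    # Soft contact probabilities: this slice keeps its grad_fn.
    probs = torch.softmax(preds, dim=1)[:, 1, :, :]
    # Hard indices are used only to *select* entries; they are constants w.r.t. preds.
    idx = torch.nonzero(probs.detach() > thr, as_tuple=False)  # shape (num_contacts, 3): (batch, i, j)
    if idx.numel() == 0:
        # Keep the graph alive (zero loss, zero grad) when nothing crosses the threshold.
        return probs.sum() * 0.0
    vals = probs[idx[:, 0], idx[:, 1], idx[:, 2]]  # differentiable gather of the selected probabilities
    coords = idx[:, 1:].float()                    # (i, j) coordinates, constant w.r.t. preds
    sq_diff = (coords.unsqueeze(1) - coords.unsqueeze(0)).pow(2).sum(dim=-1)
    d_matrix = torch.exp(-sq_diff / (2 * var))
    # Weight each pairwise Gaussian term by the product of the two soft contact values,
    # so the loss depends on preds and gradients can flow.
    weighted = d_matrix * vals.unsqueeze(0) * vals.unsqueeze(1)
    omega = 1.0 / weighted.sum(dim=1).clamp_min(1e-8)
    return -omega.mean()

The pairwise distances themselves remain constants because they come from integer indices, but since vals is a differentiable gather from probs, backpropagation has a path to the network. Whether weighting the Gaussian terms this way still matches the intended metric is a modelling decision for the original poster.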