# Differentiable Loss Function for Indices

Hi, I am trying to build a loss function based on a Gaussian distribution. The calculation uses the indices of the matrix entries, which (I think) are breaking the computation graph. Could somebody please help with this? I am attaching the code for the loss function and the output I get when I print the `grad_fn` of each operation.

Code

```python
import torch
import torch.nn as nn


class Gaussian(nn.Module):
    def __init__(self, var: float) -> None:
        """
        Initializes Gaussian metric.

        Args:
            var (float): Variance parameter.
        """

        super(Gaussian, self).__init__()
        self.var = var

    def forward(self, preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """
        Computes the Gaussian metric.

        Args:
            preds (torch.Tensor): Predictions.
            target (torch.Tensor): Target data.

        Returns:
            torch.Tensor: Gaussian metric value.
        """
        preds = torch.softmax(preds, dim=1)
        contact_matrix = preds[:, 1, :, :]  # probability of class 1
        # Threshold to hard 0/1 contacts (comparison ops produce no grad_fn)
        contact_matrix = torch.ge(contact_matrix, 0.5).float()
        # print(contact_matrix.int())

        gaussian_metric_value = self.calculate_distances(contact_matrix, self.var)
        return gaussian_metric_value

    def calculate_distances(self, contact_matrix, var):
        """
        Calculate the distance matrix d_{i,j} for a given contact matrix.

        Parameters:
            contact_matrix (torch.Tensor): A 2D torch tensor representing contacts (zeros and ones).
            var (float): The variance parameter for the distance calculation.

        Returns:
            torch.Tensor: The omega_sum value.
        """

        device = contact_matrix.device
        contact_matrix = contact_matrix.to(device='cpu')
        # Integer coordinates of every contact, shape (num_contacts, contact_matrix.dim())
        indices = torch.nonzero(contact_matrix == 1, as_tuple=False).float()
        # Pairwise coordinate differences, shape (num_contacts, num_contacts, contact_matrix.dim())
        indices_diff = indices.unsqueeze(1) - indices.unsqueeze(0)

        indices_diff = indices_diff.pow(2)
        sq_diff = indices_diff.sum(dim=2)  # Shape: (num_contacts, num_contacts)
        d_matrix = torch.exp(-sq_diff / (2 * var))  # Shape: (num_contacts, num_contacts)
        sum_exp = d_matrix.sum(dim=1)
        omega = 1 / sum_exp
        omega_sum = omega.mean()
        omega_sum = -omega_sum.to(device=device)  # negative sign for minimization
        return omega_sum
```

Output

```text
Preds: <SoftmaxBackward0 object at 0x14e07ca72230>
Contact matrix: <SliceBackward0 object at 0x14e07ca72230>
Indices: None
Indices: None
Indices diff: <PowBackward0 object at 0x14e07ca72230>
Sq diff: <SumBackward1 object at 0x14e07ca72230>
D matrix: <ExpBackward0 object at 0x14e07ca72230>
Sum exp: <SumBackward1 object at 0x14e07ca72230>
Omega: <MulBackward0 object at 0x14e07ca72230>
```
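For reference, this is what `calculate_distances` computes when written out (my reading of the code): with $x_1, \dots, x_N$ the contact coordinates returned by `torch.nonzero` and $\sigma^2$ the `var` argument,

$$
d_{ij} = \exp\!\left(-\frac{\lVert x_i - x_j \rVert^2}{2\sigma^2}\right), \qquad
\omega_i = \frac{1}{\sum_{j=1}^{N} d_{ij}}, \qquad
\mathcal{L} = -\frac{1}{N} \sum_{i=1}^{N} \omega_i
$$

Every term depends on `preds` only through the integer coordinates $x_i$, which change in jumps as `preds` changes. $\mathcal{L}$ is therefore piecewise constant in `preds`, so its gradient would be zero almost everywhere even if autograd could follow it.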

Any insights on this would be appreciated.

Thanks

It is expected that indices are not differentiable. Do you observe that the gradients produced are not what you expect?
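A quick way to see where the graph is cut is to check `requires_grad` and `grad_fn` on each intermediate. The sketch below mirrors the first steps of the question's `forward` using hypothetical random logits of shape `(1, 2, 4, 4)`. Note that the graph is already cut one step before `torch.nonzero`, at the `torch.ge(...).float()` threshold:

```python
import torch

# Hypothetical input standing in for the model output: (B, C, H, W)
logits = torch.randn(1, 2, 4, 4, requires_grad=True)

probs = torch.softmax(logits, dim=1)
contact = probs[:, 1, :, :]                 # still connected to `logits`
hard = torch.ge(contact, 0.5).float()       # thresholding: comparison has no gradient
indices = torch.nonzero(hard == 1).float()  # integer positions, also detached

for name, t in [("probs", probs), ("contact", contact),
                ("hard", hard), ("indices", indices)]:
    print(f"{name:8s} requires_grad={t.requires_grad} grad_fn={t.grad_fn}")

# Anything computed from `indices` alone inherits requires_grad=False,
# so calling .backward() on it raises a RuntimeError.
loss = -indices.pow(2).sum()
print(loss.requires_grad)  # False
```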

I don’t know what your use case is exactly, but keep in mind that the actual values of `contact_matrix` (assuming these are floating-point numbers, i.e. the softmax probabilities before the `torch.ge` threshold) are differentiable. So if you can use the `indices` tensor to index `contact_matrix` and perform further computations on the selected values, gradients will flow back through those values.
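A minimal sketch of that suggestion, under my own assumptions: `preds` has shape `(B, 2, H, W)` as in the question, the toy logits below are hypothetical, and `-selected.sum()` is a placeholder objective, not the Gaussian loss. The hard mask only chooses which positions to use; the loss is built from the soft probabilities at those positions, so the logits receive gradients:

```python
import torch

# Hypothetical toy logits of shape (B=1, C=2, H=4, W=4): class 0 is favoured
# everywhere except at two positions where class 1 wins.
logits = torch.zeros(1, 2, 4, 4)
logits[:, 0] = 1.0
logits[0, 1, 0, 0] = 2.0
logits[0, 1, 2, 3] = 2.0
logits.requires_grad_()

probs = torch.softmax(logits, dim=1)[:, 1]  # (B, H, W) soft contact probabilities

# The hard mask only decides *which* positions to use; it is never differentiated.
idx = torch.nonzero(probs.detach() >= 0.5, as_tuple=True)

selected = probs[idx]   # advanced indexing keeps these values in the graph
loss = -selected.sum()  # placeholder objective on the selected probabilities
loss.backward()

print(selected.grad_fn is not None)         # True
print(logits.grad[0, 1, 0, 0].item() < 0)   # True: selected position gets a gradient
print(logits.grad[0, 1, 1, 1].item() == 0)  # True: unselected position gets none
```

Detaching before the comparison is not strictly required (comparisons never carry a `grad_fn`), but it makes explicit that the positions themselves are treated as constants.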