Hi, I am trying to build a loss function based on a Gaussian distribution. In the calculation, the function uses indices of the matrix, which (I think) is breaking the computation graph. Can somebody please help with this? I am attaching the code for the loss function and the output from printing the grad_fn of each operation.
Code
import torch
import torch.nn as nn


class Gaussian(nn.Module):
    def __init__(self, var: float) -> None:
        """
        Initializes Gaussian metric.

        Args:
            var (float): Variance parameter.
        """
        super(Gaussian, self).__init__()
        self.var = var

    def forward(self, preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """
        Computes the Gaussian metric.

        Args:
            preds (torch.Tensor): Predictions.
            target (torch.Tensor): Target data.

        Returns:
            torch.Tensor: Gaussian metric value.
        """
        preds = torch.softmax(preds, dim=1)
        print('Preds:', preds.grad_fn)
        contact_matrix = preds[:, 1, :, :]
        # print('Contact matrix:', contact_matrix.grad_fn)
        contact_matrix = torch.ge(contact_matrix, 0.5).float()
        # print(contact_matrix.int())
        gaussian_metric_value = self.calculate_distances(contact_matrix, self.var)
        # print('Gaussian metric:', gaussian_metric_value.requires_grad)
        return gaussian_metric_value

    def calculate_distances(self, contact_matrix, var):
        """
        Calculate the distance matrix d_{i,j} for a given contact matrix.

        Args:
            contact_matrix (torch.Tensor): A 2D torch tensor representing contacts (zeros and ones).
            var (float): The variance parameter for the distance calculation.

        Returns:
            torch.Tensor: The omega_sum value.
        """
        device = contact_matrix.device
        contact_matrix = contact_matrix.to(device='cpu')
        indices = torch.nonzero(contact_matrix == 1, as_tuple=False).float()
        # print('Indices:', indices.grad)
        indices.requires_grad = True
        indices.retain_grad()
        indices_diff = indices.unsqueeze(1) - indices.unsqueeze(0)
        # print('Indices:', indices.grad)
        indices_diff = indices_diff.pow(2)
        # print('Indices diff:', indices_diff.grad_fn)
        sq_diff = indices_diff.sum(dim=2)  # Shape: (num_contacts, num_contacts)
        # print('Sq diff:', sq_diff.grad_fn)
        d_matrix = torch.exp(-sq_diff / (2 * var))  # Shape: (num_contacts, num_contacts)
        # print('D matrix:', d_matrix.grad_fn)
        sum_exp = d_matrix.sum(dim=1)
        # print('Sum exp:', sum_exp.grad_fn)
        omega = 1 / sum_exp
        # print('Omega:', omega.grad_fn)
        omega_sum = omega.mean()
        omega_sum = -omega_sum.to(device=device)  # negative sign for minimization
        return omega_sum
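For reference, this is the quantity calculate_distances is meant to compute, where x_i is the (row, col) index of the i-th contact and N is the number of contacts (var plays the role of the variance sigma^2):

\[
d_{ij} = \exp\!\left(-\frac{\lVert x_i - x_j \rVert^2}{2\,\mathrm{var}}\right),
\qquad
\omega_i = \frac{1}{\sum_{j=1}^{N} d_{ij}},
\qquad
L = -\frac{1}{N}\sum_{i=1}^{N} \omega_i
\]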
Output
Preds: <SoftmaxBackward0 object at 0x14e07ca72230>
Contact matrix: <SliceBackward0 object at 0x14e07ca72230>
Indices: None
Indices: None
Indices diff: <PowBackward0 object at 0x14e07ca72230>
Sq diff: <SumBackward1 object at 0x14e07ca72230>
D matrix: <ExpBackward0 object at 0x14e07ca72230>
Sum exp: <SumBackward1 object at 0x14e07ca72230>
Omega: <MulBackward0 object at 0x14e07ca72230>
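If it helps narrow things down, here is a minimal standalone sketch (toy shapes, nothing to do with my real data) of where I believe the graph detaches from preds:

import torch

preds = torch.randn(1, 2, 4, 4, requires_grad=True)
probs = torch.softmax(preds, dim=1)
print(probs.grad_fn)                       # SoftmaxBackward0 -> still in the graph
mask = torch.ge(probs[:, 1], 0.5).float()  # hard thresholding
print(mask.grad_fn)                        # None -> detached from preds here
idx = torch.nonzero(mask == 1).float()     # extracting integer indices
print(idx.grad_fn)                         # None -> also detached

Setting indices.requires_grad = True afterwards starts a new graph at indices, which is presumably why the later grad_fn prints are not None even though nothing flows back to preds.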
Any insights on this would be much appreciated.
Thanks