Making a Function Differentiable

I am attempting to build a function that does the following:

• finds the indices of non-zero diagonal values in an input matrix
• finds the indices of non-zero elements in each row where the diagonal is non-zero
• encodes a one-hot vector for each row based on the index
• ensures no two rows in the output are identical.

My code is below. However, it is not differentiable due to `torch.unique` and the in-place index assignment.

Can anyone help me rewrite this relatively simple function in a differentiable way?

``````python
def GetClusters(matrix):
    # Indices (batch, row) of non-zero diagonal entries
    attractors = torch.diagonal(matrix, dim1=1, dim2=2).nonzero()

    batch_size = matrix.shape[0]
    clusters = []

    for b in range(batch_size):
        batch_attractors = attractors[attractors[:, 0] == b, 1]
        batch_matrix = matrix[b]
        batch_cluster = torch.zeros(batch_attractors.shape[0], batch_matrix.shape[-1],
                                    device=matrix.device)

        # One-hot encode the non-zero entries of each attractor row
        for i, att in enumerate(batch_attractors):
            idx_row_nonzero = batch_matrix[att].nonzero()
            batch_cluster[i, idx_row_nonzero] = 1.0

        # Drop duplicate rows -- one of the non-differentiable steps
        batch_cluster = torch.unique(batch_cluster, dim=0)

        clusters.append(batch_cluster)

    return clusters
``````
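To illustrate the problem the post describes, here is a minimal sketch (not the original model code) showing why building a one-hot vector from `nonzero()` indices breaks the gradient path: the indices are integers, and the assignment writes constants that are not connected to the input's autograd graph.

```python
import torch

x = torch.tensor([0.0, 0.5, 0.2], requires_grad=True)
idx = x.nonzero()              # integer indices carry no gradient
one_hot = torch.zeros_like(x)  # a fresh tensor, not connected to x's graph
one_hot[idx] = 1.0             # writes constants, so gradients to x are lost
print(one_hot.requires_grad)   # False: no path back to x
```

Calling `one_hot.sum().backward()` here would raise an error, because no operation in the chain recorded a gradient function.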

Perhaps you could clarify how you are using `clusters` afterward. There might be a simpler or more direct approach, but I would need to know the context in which this function is used. For example, is this within an attention layer?

If you’re selecting indices directly, that’s non-differentiable by definition.

Per @J_Johnson’s suggestion, you could try an attention-based approach, which performs a ‘soft’ argmax selection and would therefore be differentiable.
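As a minimal sketch of that idea (the function name and temperature value are illustrative, not from the thread): a low-temperature softmax over the scores approximates a one-hot selection weight while keeping gradients with respect to the scores.

```python
import torch

def soft_select(scores, values, temperature=0.1):
    # Differentiable surrogate for `values[scores.argmax()]`:
    # a low-temperature softmax over `scores` approximates a one-hot
    # weight vector, and gradients flow back through `scores`.
    weights = torch.softmax(scores / temperature, dim=-1)
    return (weights * values).sum(dim=-1)

scores = torch.tensor([0.1, 0.9, 0.2], requires_grad=True)
values = torch.tensor([10.0, 20.0, 30.0])
out = soft_select(scores, values)
out.backward()  # gradients reach `scores`, unlike a hard argmax
```

Lowering the temperature sharpens the selection toward a true one-hot; raising it smooths the gradients at the cost of a softer selection.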

I am implementing a differentiable Markov Clustering algorithm. Original Implementation

I have made some progress on the matter but am still being hindered by this process, as it seems to kill the gradients during backpropagation.

In essence, given a transition matrix (i.e. a BxNxN matrix with values in [0, 1]), the Markov Clustering algorithm alternates matrix power operations and normalisation until convergence. One key advantage of this method is that it does not assume a particular number of clusters k.
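For reference, the iteration just described can be sketched as follows; the hyperparameter values (`inflation`, `iters`) are illustrative, not taken from the original implementation.

```python
import torch

def markov_cluster(M, inflation=2.0, iters=10, eps=1e-8):
    # Minimal sketch of Markov Clustering on a batch of B x N x N
    # column-stochastic matrices: repeated expansion (matrix power)
    # and inflation (elementwise power) followed by re-normalisation.
    for _ in range(iters):
        M = torch.bmm(M, M)                           # expansion
        M = M.pow(inflation)                          # inflation
        M = M / (M.sum(dim=-2, keepdim=True) + eps)   # column re-normalisation
    return M

T = torch.rand(2, 5, 5)
T = T / T.sum(dim=-2, keepdim=True)  # make columns sum to 1
out = markov_cluster(T)
```

Each step here is differentiable; it is the subsequent extraction of hard cluster assignments (non-zero patterns and unique rows) that cuts the gradient.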

I will be using the clusters from this algorithm to compare against some true labels. As output, I would hope for a CxN matrix containing all the unique rows with a non-zero diagonal, i.e. the clusters obtained by the algorithm.

I have updated the code, but I am still having problems:

``````python
## ORIGINAL
def GetClusters(self, matrix):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    B, N, _ = matrix.shape

    all_non_zero_diagonal_rows = []

    for b in range(B):
        # Get the diagonal elements of the matrix
        diagonal_elements = torch.diagonal(matrix[b], 0, -2, -1)

        # Create a mask for rows with non-zero diagonal elements
        non_zero_mask = diagonal_elements != 0

        # Extract the rows with non-zero diagonal elements
        non_zero_diagonal_rows = matrix[b][non_zero_mask]

        if non_zero_diagonal_rows.size(0) > 0:
            batch_index = torch.full((non_zero_diagonal_rows.size(0), 1), b,
                                     dtype=torch.float, device=device, requires_grad=False)
            batch_rows_with_index = torch.cat([batch_index, non_zero_diagonal_rows], dim=1)

            # Step 1: Remove duplicates within this batch
            batch_rows_with_index_float = batch_rows_with_index.float()
            num_rows = batch_rows_with_index_float.size(0)
            unique_mask = torch.ones(num_rows, dtype=torch.bool, device=device)

            for i in range(num_rows):
                for j in range(i + 1, num_rows):
                    if unique_mask[j] and torch.allclose(batch_rows_with_index_float[i, 1:],
                                                         batch_rows_with_index_float[j, 1:],
                                                         atol=1e-5):
                        unique_mask[j] = False

            unique_batch_rows = batch_rows_with_index[unique_mask]
            all_non_zero_diagonal_rows.append(unique_batch_rows)

    # Concatenate all unique rows from all batches
    if len(all_non_zero_diagonal_rows) > 0:
        all_non_zero_diagonal_rows_tensor = torch.cat(all_non_zero_diagonal_rows, dim=0)
    else:
        all_non_zero_diagonal_rows_tensor = torch.empty((0, N + 1), device=device)  # No non-zero diagonal rows

    # Step 2: Sort by the batch index
    batch_indices = all_non_zero_diagonal_rows_tensor[:, 0]
    sorted_indices = torch.argsort(batch_indices)
    clusters = all_non_zero_diagonal_rows_tensor[sorted_indices]

    # Extract batch indices and cluster data
    batch_indices = clusters[:, 0].long()  # Get batch indices (first column)
    cluster_data = clusters[:, 1:]         # Get cluster elements (rest of the columns)

    # Get unique batch indices directly from batch_indices
    unique_batch_indices = torch.unique(batch_indices)

    # Initialize dictionary using unique batch indices
    grouped_clusters = {batch_idx.item(): [] for batch_idx in unique_batch_indices}

    # Loop through each batch
    for b in unique_batch_indices:
        # Find indices where batch index equals the current batch