I am implementing a differentiable version of the Markov Clustering (MCL) algorithm (Original Implementation).

I have made some progress, but I am still blocked by the cluster-extraction step, which seems to kill the gradients during backpropagation.

In essence, given a transition matrix (i.e. a BxNxN batch of matrices with entries in [0,1]), the Markov Clustering algorithm alternates expansion (matrix powers) and inflation (element-wise powers followed by column normalisation) until convergence. One key advantage of this method is that it does not assume a particular number of clusters k.
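For reference, that iteration can be sketched roughly as follows (a minimal sketch: the function name `mcl_iterate`, the fixed iteration count, and the default expansion/inflation parameters are my own choices, and column-stochastic input is assumed):

```python
import torch

def mcl_iterate(M, e=2, r=2.0, iters=20, eps=1e-12):
    # M: BxNxN batch of column-stochastic transition matrices.
    for _ in range(iters):
        M = torch.linalg.matrix_power(M, e)                  # expansion: matrix power
        M = M.pow(r)                                         # inflation: element-wise power
        M = M / M.sum(dim=-2, keepdim=True).clamp_min(eps)   # renormalise columns
    return M
```

Every operation here is differentiable, so the iteration itself should not be the part that breaks the graph.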

I will be comparing the clusters from this algorithm against some true labels. As output, I would like a CxN matrix containing all the unique rows that have a non-zero diagonal entry; these rows encode the clusters found by the algorithm.
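One thing worth noting about the extraction step: boolean row selection by itself does not cut the autograd graph for the rows that survive; only the thresholding decision is non-differentiable. A minimal check (the matrix values are arbitrary, chosen so rows 0 and 2 pass a 0.5 threshold):

```python
import torch

M = torch.tensor([[[0.9, 0.1, 0.0],
                   [0.0, 0.2, 0.8],
                   [0.1, 0.0, 0.9]]], requires_grad=True)

diag = torch.diagonal(M[0], 0)   # tensor([0.9, 0.2, 0.9])
kept = M[0][diag > 0.5]          # hard selection: rows 0 and 2 survive
kept.sum().backward()
# M.grad is ones in the selected rows and zeros in row 1:
# gradients flow through the kept rows, but not through the threshold itself.
```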

I have updated the code, but I am still having problems:

```python
##ORIGINAL
def GetClusters(self, matrix):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    B, N, _ = matrix.shape
    all_non_zero_diagonal_rows = []
    for b in range(B):
        # Get the diagonal elements of the matrix
        diagonal_elements = torch.diagonal(matrix[b], 0, -2, -1)
        # Create a mask for rows whose diagonal element exceeds the pruning threshold
        non_zero_diagonal_mask = diagonal_elements > self.pruning_threshold
        # Extract the rows with non-zero diagonal elements
        non_zero_diagonal_rows = matrix[b][non_zero_diagonal_mask]
        if non_zero_diagonal_rows.size(0) > 0:
            batch_index = torch.full((non_zero_diagonal_rows.size(0), 1), b,
                                     dtype=torch.float, device=device, requires_grad=False)
            batch_rows_with_index = torch.cat([batch_index, non_zero_diagonal_rows], dim=1)
            # Step 1: Remove duplicates within this batch (O(n^2) pairwise comparison)
            batch_rows_with_index_float = batch_rows_with_index.float()
            num_rows = batch_rows_with_index_float.size(0)
            unique_mask = torch.ones(num_rows, dtype=torch.bool, device=device)
            for i in range(num_rows):
                if unique_mask[i]:
                    for j in range(i + 1, num_rows):
                        if unique_mask[j] and torch.allclose(
                                batch_rows_with_index_float[i, 1:],
                                batch_rows_with_index_float[j, 1:], atol=1e-5):
                            unique_mask[j] = False
            unique_batch_rows = batch_rows_with_index[unique_mask]
            all_non_zero_diagonal_rows.append(unique_batch_rows)
    # Concatenate all unique rows from all batches
    if len(all_non_zero_diagonal_rows) > 0:
        all_non_zero_diagonal_rows_tensor = torch.cat(all_non_zero_diagonal_rows, dim=0)
    else:
        # No non-zero diagonal rows case
        all_non_zero_diagonal_rows_tensor = torch.empty((0, N + 1), device=device)
    # Step 2: Sort by the batch index
    batch_indices = all_non_zero_diagonal_rows_tensor[:, 0]
    sorted_indices = torch.argsort(batch_indices)
    clusters = all_non_zero_diagonal_rows_tensor[sorted_indices]
    # Extract batch indices and cluster data
    batch_indices = clusters[:, 0].long()  # batch indices (first column)
    cluster_data = clusters[:, 1:]         # cluster elements (remaining columns)
    # Get unique batch indices directly from batch_indices
    unique_batch_indices = torch.unique(batch_indices)
    # Initialize dictionary using unique batch indices
    grouped_clusters = {batch_idx.item(): [] for batch_idx in unique_batch_indices}
    # Loop through each batch
    for b in unique_batch_indices:
        # Find indices where batch index equals the current batch
        mask = batch_indices == b.item()
        # Extract cluster data for the current batch
        clusters_for_batch = cluster_data[mask].permute(1, 0)
        # Append to the dictionary
        grouped_clusters[b.item()] = clusters_for_batch
    return grouped_clusters
```
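As an aside, the nested Python loop for de-duplication can be vectorised with broadcasting, which keeps everything on-device and still preserves gradients for the surviving rows (a sketch only; the helper name `unique_rows` is mine, and the max-absolute-difference criterion mirrors the `torch.allclose(atol=...)` test above):

```python
import torch

def unique_rows(rows, atol=1e-5):
    # Pairwise max absolute difference between rows: shape (n, n).
    d = (rows.unsqueeze(0) - rows.unsqueeze(1)).abs().amax(dim=-1)
    # Row j is a duplicate if some earlier row i < j is within atol of it.
    dup = (d <= atol).float().triu(diagonal=1).bool().any(dim=0)
    return rows[~dup]
```

This replaces only the duplicate-removal step; the hard thresholding and the `.long()`/`.item()` index handling are unchanged, and indices carry no gradients anyway.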