Unsupported sparse MM in CUDA

I have this custom layer:

import math

import torch
import torch.nn as nn
import torch.nn.functional as F

class CustomFullyConnectedLayer(nn.Module):
    def __init__(self, in_features, out_features, device=None, sparsity = 0.1, diagPos=[], alphaLR=0.01):
        super(CustomFullyConnectedLayer, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.total_permutations = max(in_features, out_features)
        self.diag_length = min(in_features, out_features)
        
        print("Sparsity is: ", sparsity)    
        num_params = in_features * out_features
        req_params = int((1-sparsity) * num_params)
        K = math.ceil(req_params/min(in_features, out_features))

        self.K = K
        self.topkLR = alphaLR
        self.device = device if device else torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.V = nn.Parameter(torch.empty(self.total_permutations, self.diag_length, device=self.device, dtype=torch.float32, requires_grad=True))
        nn.init.kaiming_uniform_(self.V, a=math.sqrt(5))

        self.alpha = nn.Parameter(torch.empty(self.total_permutations, device=self.device, requires_grad=True))
        nn.init.constant_(self.alpha, 1/self.in_features)
        #pdb.set_trace()
        assert torch.all(self.alpha >= 0)

    def compute_weights(self):
        self.alpha_topk = sparse_soft_topk_mask_dykstra(self.alpha, self.K, l=self.topkLR, num_iter=50).to(self.device)
        non_zero_alpha_indices = torch.nonzero(self.alpha_topk, as_tuple=False).squeeze()
        
        #print("memory after alpha_topk:{}MB ".format(torch.cuda.memory_allocated()/(1024**2)))

        if non_zero_alpha_indices.dim() == 0:
            non_zero_alpha_indices = non_zero_alpha_indices.unsqueeze(0) 
        
        WSum = torch.zeros((self.out_features, self.in_features), device=self.device)
        #print("Memory after WSum:{}MB ".format(torch.cuda.memory_allocated()/(1024**2)))
        
        results = []

        for i in non_zero_alpha_indices:
            #print("Iteration: ", i)
            mask1 = get_mask_pseudo_diagonal_torch((self.out_features, self.in_features), sparsity=0.99967, experimentType="randDiagOneLayer", diag_pos=i)
            #mask1 = mask1.detach()
            #print("Memory after mask1:{}MB ".format(torch.cuda.memory_allocated()/(1024**2)))

            V_scaled = self.V[i] * self.alpha_topk[i]
            #print("Memory after V_scaled:{}MB ".format(torch.cuda.memory_allocated()/(1024**2)))

            # Using V_scaled, build a sparse COO matrix equivalent to torch.diag(V_scaled)
            n = V_scaled.size(0)
            indices = torch.arange(0, n, device=self.device).unsqueeze(0).repeat(2, 1)  # Create (i, i) indices
            values = V_scaled  # Diagonal values are just V_scaled
            V_scaled_sparse = torch.sparse_coo_tensor(indices, values, (n, n), device=self.device)
            
            with torch.cuda.amp.autocast(enabled=False):
                if self.out_features > self.in_features:
                    #WSum += self.alpha_topk[i] * torch.matmul(mask1, torch.diag(self.V[i]).to(self.device))
                    #WSum += torch.matmul(mask1, torch.diag(V_scaled).to(self.device))
                    WSum += torch.sparse.mm(mask1, V_scaled_sparse.float())
                    #print("Memory after WSum:{}MB ".format(torch.cuda.memory_allocated()/(1024**2)))
                    #WSum += self.alpha_topk[i] * torch.einsum('ij,j->ij', mask1, self.V[i])
                else:
                    mask1 = mask1.T
                    #WSum += self.alpha_topk[i] * (torch.matmul(mask1, torch.diag(self.V[i])).T.to(self.device))
                    #WSum += torch.matmul(mask1, torch.diag(V_scaled)).T.to(self.device)
                    WSum += torch.sparse.mm(mask1, V_scaled_sparse.float()).T
                    #print("Memory after WSum:{}MB ".format(torch.cuda.memory_allocated()/(1024**2)))
        
        return WSum

    @property
    def weights(self):
        return self.compute_weights()

    def forward(self, x):
        x = x.to(self.device)
        W = self.weights
        #pdb.set_trace()    

        out = F.linear(x, W)
        return out

    def update_alpha_lr(self, new_alpha_lr):
        self.topkLR = new_alpha_lr
        #print("New learning rate for alpha is: ", self.topkLR) 

My mask is generated using the following function:

def get_mask_random_torch(mask_shape, sparsity, device='cuda'):
    num_elements = mask_shape[0] * mask_shape[1]
    num_ones = int(sparsity * num_elements)

    # Generate random indices
    indices = torch.randperm(num_elements, device=device)[:num_ones]

    # Convert flat indices to 2D indices
    rows = indices // mask_shape[1]
    cols = indices % mask_shape[1]

    # Create a sparse tensor using the computed rows and cols
    sparse_indices = torch.stack([rows, cols], dim=0)
    values = torch.ones(num_ones, device=device)

    # Create the sparse COO tensor
    sparse_mask = torch.sparse_coo_tensor(sparse_indices, values, size=mask_shape, device=device)

    return sparse_mask

For the time being, let us ignore the implementation of sparse_soft_topk_mask_dykstra.
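For this question it can be treated as a black box that takes alpha and returns a vector of the same length in which roughly K entries are non-zero; those entries are what scale the diagonals. If a stand-in is needed to run the layer in isolation, a rough hard top-K placeholder (not the real, differentiable implementation) could look like this:

import torch

def sparse_soft_topk_mask_dykstra_stub(alpha, K, l=0.01, num_iter=50):
    # Placeholder only: keep the K largest entries of alpha and zero out the rest.
    # The real function computes a soft (differentiable) top-K mask; this hard
    # version is just enough to exercise compute_weights() on its own.
    out = torch.zeros_like(alpha)
    topk_idx = torch.topk(alpha, K).indices
    out[topk_idx] = alpha[topk_idx]
    return out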

When I use this layer in a network, I get the following error during the backward pass:

File "/gpfs/fs2/scratch/atyagi2/pytorch-image-models/train.py", line 1067, in _backward
    loss_scaler(
  File "/gpfs/fs2/scratch/atyagi2/pytorch-image-models/timm/utils/cuda.py", line 62, in __call__
    self._scaler.scale(loss).backward(create_graph=create_graph)
  File "/scratch/atyagi2/vitCifar/lib/python3.9/site-packages/torch/_tensor.py", line 521, in backward
    torch.autograd.backward(
  File "/scratch/atyagi2/vitCifar/lib/python3.9/site-packages/torch/autograd/__init__.py", line 289, in backward
    _engine_run_backward(
  File "/scratch/atyagi2/vitCifar/lib/python3.9/site-packages/torch/autograd/graph.py", line 768, in _engine_run_backward
    return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: mat2_.is_sparse() INTERNAL ASSERT FAILED at "../aten/src/ATen/native/sparse/cuda/SparseMatMul.cu":785, please report a bug to PyTorch.

If I understand correctly, torch.sparse.mm supports backpropagation when both matrices are in COO format, so I am unsure why I am getting this error.
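For example, my understanding is that a standalone version of the same pattern should be differentiable. Below is a minimal sketch of what I believe compute_weights is doing (simplified, with a small identity mask standing in for the real pseudo-diagonal mask, so the names and sizes are placeholders):

import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
n = 4

# Sparse COO 0/1 mask, standing in for mask1
mask = torch.eye(n, device=device).to_sparse_coo()

# Sparse diagonal built from a dense tensor that requires grad, like V_scaled_sparse
v = torch.randn(n, device=device, requires_grad=True)
idx = torch.arange(n, device=device).unsqueeze(0).repeat(2, 1)
v_sparse = torch.sparse_coo_tensor(idx, v, (n, n), device=device)

# Same pattern as compute_weights: sparse x sparse matmul accumulated into a dense WSum
w_sum = torch.zeros((n, n), device=device)
w_sum = w_sum + torch.sparse.mm(mask, v_sparse)

w_sum.sum().backward()
print(v.grad)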

My PyTorch version is 2.4.0+cu121, and which nvcc returns /software/cuda/12.1/bin/nvcc.

Am I right to assume that the error is caused by the torch.sparse.mm call? Any suggestions on how it can be avoided?
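For reference, the dense alternative in the commented-out lines of compute_weights (multiplying the mask by torch.diag(V_scaled), or the equivalent einsum) is what I would fall back to, but it materialises the full mask, which is what I was trying to avoid. A quick standalone check of that equivalence (shapes are placeholders):

import torch

out_f, in_f = 6, 4
mask = torch.randint(0, 2, (out_f, in_f), dtype=torch.float32)
v = torch.randn(in_f)

# Scaling column j of the mask by v[j] is the same as mask @ diag(v),
# without building the dense diagonal matrix
a = torch.matmul(mask, torch.diag(v))
b = torch.einsum('ij,j->ij', mask, v)
print(torch.allclose(a, b))  # True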