Poor performance with Masking/Sparsity in linear layers

I have a model that uses linear layers where most of the parameters are zero. I’d like to maintain the structure of the linear layers while reducing the effective number of trainable parameters. I’ve been experimenting with sparse tensors and masked tensors, but I’m not getting the performance boost I’d hoped for in my simple experiments. Does anyone have any suggestions for how to use these properly?

More details: here are a few simple modules that capture the essence of what I need. The first is just a classic NxN linear module. The second uses masked tensors to reduce the number of non-zero parameters to 1/10 of the fully connected layer. The third attempts to do the same thing using sparsity.

import torch
import torch.nn as nn
from torch.masked import masked_tensor

class Classic_Linear_Module(nn.Module):
    # Baseline: a plain fully connected NxN layer.
    def __init__(self, N=256, device="cpu"):
        super().__init__()
        self.device = device
        self.Linear = nn.Linear(N, N, bias=False, device=device)
        nn.init.uniform_(self.Linear.weight, a=0.0, b=1.0)  # in-place init

    def forward(self, x):
        return self.Linear(x)
    
class Masked_Linear_Module(nn.Module):
    # Keeps roughly 10% of the entries via a MaskedTensor parameter.
    def __init__(self, N=256, device="cpu"):
        super().__init__()
        self.device = device
        Linear = torch.rand(N, N, device=device)
        mask = Linear <= 0.1  # ~10% of entries survive
        Masked_Linear = masked_tensor(Linear, mask).to(device)
        self.Masked_Linear = torch.nn.Parameter(Masked_Linear)

    def forward(self, x):
        # Densify (masked-out entries filled with 0) because matmul
        # isn't implemented for MaskedTensor.
        Linear = self.Masked_Linear.to_tensor(0).to(self.device)
        return torch.matmul(x, Linear)
    
class Sparse_Linear_Module(nn.Module):
    # Same sparsity pattern, stored as a sparse COO parameter.
    def __init__(self, N=256, device="cpu"):
        super().__init__()
        Linear = torch.rand(N, N, device=device)
        mask = Linear <= 0.1
        Sparse_Linear = masked_tensor(Linear, mask).to_tensor(0.0).to_sparse()
        self.Sparse_Linear = torch.nn.Parameter(Sparse_Linear)

    def forward(self, x):
        return torch.matmul(x, self.Sparse_Linear)

Note that the masked module first converts its weights to an ordinary tensor before the matrix multiplication, because matmul and other matrix operations don’t appear to be implemented for masked tensors yet (maybe there’s a better way?).
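
One workaround that comes to mind (a minimal sketch, not benchmarked below; the class name is my own) is to skip MaskedTensor entirely and keep a dense weight alongside a fixed boolean mask buffer. Masked entries get zero gradient, so the sparsity pattern survives training, but all NxN parameters are still stored and the matmul stays fully dense:

class Buffer_Masked_Linear_Module(nn.Module):
    # Hypothetical alternative: dense weight times a fixed mask buffer.
    def __init__(self, N=256, device="cpu"):
        super().__init__()
        weight = torch.rand(N, N, device=device)
        # A buffer moves with .to(device) but isn't a trainable parameter.
        self.register_buffer("mask", weight <= 0.1)
        self.weight = torch.nn.Parameter(weight)

    def forward(self, x):
        # Masked entries contribute zero and receive zero gradient,
        # but the multiplication itself is still dense.
        return torch.matmul(x, self.weight * self.mask)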

I tested the performance of the forward pass and a gradient-descent step with a simple loss function by applying each module to a big batch (batch_size=256, N=2048, 100 epochs) and measuring the speed. The masked case was about 50% slower than the fully connected layer despite having one-tenth of the parameters, and the sparse implementation was about 10 times slower:
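
For reference, the timing loop looks roughly like this (a sketch: the MSE loss and plain SGD are illustrative stand-ins for the simple loss and gradient step, and I'm glossing over whether a stock optimizer handles the sparse parameter correctly):

import time

def time_training(module, N=2048, batch_size=256, epochs=100, device="cpu"):
    # Illustrative harness: one forward pass, loss, backward, and step
    # per epoch on a single big batch.
    x = torch.rand(batch_size, N, device=device)
    target = torch.rand(batch_size, N, device=device)
    opt = torch.optim.SGD(module.parameters(), lr=1e-3)
    start = time.time()
    for _ in range(epochs):
        opt.zero_grad()
        loss = ((module(x) - target) ** 2).mean()
        loss.backward()
        opt.step()
    # On GPU, call torch.cuda.synchronize() here before reading the clock.
    return time.time() - start

print("Classic training time:", time_training(Classic_Linear_Module(N=2048)))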

Classic training time: 1.2331788539886475
Masked training time: 1.8363850116729736
Sparse training time: 12.60797905921936

Results are similar on GPU, but even more pronounced: masked is about 10 times slower and sparse about 100 times slower than the fully connected layer.

Any ideas on how this could be implemented better?