Training with sparse semi-structured tensors

Hi,
The other day I was experimenting with sparse semi-structured tensors. While the forward pass (evaluation) works, I could not make training work. Should I be able to make it work? If not, is this a planned feature?

If it should work, please point out where I went wrong! Here’s some context on what I was working with.

MNIST dataloader:

from torchvision.transforms import v2 as T
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
import torch
from pathlib import Path

test_dataloader = DataLoader(
    MNIST(
        root=Path('data'),
        train=False,
        download=True,
        transform=T.Compose([
            T.ToImage(),
            T.ToDtype(torch.float32, scale=True),
            T.Normalize(mean=(0.1307,), std=(0.3081,)),
            T.ToDtype(torch.float16)
        ])
    ),
    batch_size=8
)

Wrapper class for the SparseSemiStructuredTensor:

import torch
from torch.sparse import to_sparse_semi_structured
from torch import nn
from torch.nn import functional as F


class SparseLinear(nn.Module):
    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = True,
        density: float = 0.1,
    ):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features

        # Random 0/1 weight with roughly `density` nonzeros, cast to fp16,
        # moved to the GPU and converted to the semi-structured format.
        self.weight = nn.Parameter(
            to_sparse_semi_structured(
                (torch.rand(self.out_features, self.in_features) < density).half().cuda()
            )
        )
        # Bias is not used in this experiment.
        self.bias = None

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return F.linear(input, self.weight, self.bias)

Initialize model and try to process the first batch:

from torch.optim import SGD
from torch.nn import functional as F

device = torch.device(0)

model = nn.Sequential(
        nn.Flatten(),
        nn.Linear(in_features=28*28, out_features=512, dtype=torch.float16),
        nn.ReLU(),
        SparseLinear(512, 256, density=0.1, bias=False),
        nn.ReLU(),
        nn.Linear(in_features=256, out_features=10, dtype=torch.float16),
        nn.Softmax(dim=1)
)

model = model.to(device)

loss_function = nn.CrossEntropyLoss()
optimizer = SGD(params=model.parameters(), lr=1e-4)

x, y_true = next(iter(test_dataloader))

optimizer.zero_grad()
y_pred = model(x.to(device))
loss = loss_function(y_pred, F.one_hot(y_true, num_classes=10).half().to(device))
loss.backward()

This, however, fails during loss.backward() with an error: NotImplementedError: SparseSemiStructuredTensorCUTLASS matmul: operation is not supported

Hi Bendeguz!

If I understand you correctly, you can perform the forward pass through your
SparseLinear, but not the backward pass.

It would appear that this bit of sparse-semi-structured functionality is not (yet)
implemented, so you won’t be able to backpropagate through SparseLinear
(yet).

Pytorch’s sparsity (and semi-structured sparsity) is a work in progress, with
a lot of missing functionality (and some bugs). If this is important to you, you
could log a feature-request issue to pytorch’s github.

(I would speculate that this would not be a priority. I imagine that the target use
case is that you train non-sparse weights, prune, and convert to semi-structured
sparse weights to gain some speed-up during inference. It’s also not clear to me
how much logical sense it would make to compute gradients and train in a way
that preserves semi-structured sparsity.)
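
(For what it’s worth, a rough, untested sketch of that prune-then-convert,
inference-only workflow might look something like the code below. The 2:4
magnitude pruning, the shapes, and the dtype are my assumptions about what
the kernels expect, not something I have benchmarked.)

import torch
from torch.sparse import to_sparse_semi_structured
from torch.nn import functional as F

# Assume a trained dense fp16 weight on the GPU.
weight = torch.randn(128, 128, dtype=torch.float16, device="cuda")

# Prune to a 2:4 pattern: in every contiguous group of 4 entries, keep the
# 2 largest-magnitude values and zero out the other 2.
groups = weight.reshape(-1, 4)
keep = groups.abs().topk(2, dim=-1).indices
mask = torch.zeros_like(groups, dtype=torch.bool).scatter_(-1, keep, True)
pruned = (groups * mask).reshape_as(weight)

# Convert the pruned dense weight to the semi-structured format and use it
# for inference only (no autograd involved).
sparse_weight = to_sparse_semi_structured(pruned)
x = torch.randn(8, 128, dtype=torch.float16, device="cuda")
with torch.inference_mode():
    y = F.linear(x, sparse_weight)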

You might also want to post which gpu and version of pytorch you are using.

Best.

K. Frank

That is correct, this is exactly my issue; I will log a feature request on github.

My use case is training “ultra-sparse networks” (only ~0.1% of the elements in a tensor are nonzero). As far as I know, this can currently only be done by masking dense tensors. I am hoping to see reduced training time and a smaller memory footprint by using sparse semi-structured tensors instead of masked dense tensors.
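
For reference, the masked-dense approach I am currently using looks roughly
like the sketch below (simplified; the initialization and mask generation are
just illustrative):

import torch
from torch import nn
from torch.nn import functional as F


class MaskedDenseLinear(nn.Module):
    """Dense weight multiplied by a fixed binary mask on every forward pass.
    Gradients and optimizer state stay fully dense, which is what I am
    hoping to avoid."""

    def __init__(self, in_features: int, out_features: int, density: float = 0.001):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(out_features, in_features) * 0.01)
        # Fixed random sparsity pattern; a buffer so it follows .to()/.cuda()
        # but is never updated by the optimizer.
        self.register_buffer("mask", torch.rand(out_features, in_features) < density)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return F.linear(input, self.weight * self.mask)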

Hi Bendeguz!

My guess then is that you do not want to use sparse-semi-structured tensors, but
rather pytorch’s (more mature) sparse_coo tensors. Even here, support is spotty,
but there may be enough for you to implement your SparseLinear layer.

Consider:

>>> import torch
>>> torch.__version__
'2.3.0'
>>> _ = torch.manual_seed (2024)
>>> tf = torch.randn (4, 4)
>>> ts = tf.to_sparse()
>>> ts.requires_grad = True
>>> v = torch.randn (4, 1)
>>> ls = torch.sparse.mm (ts, v).sum()
>>> ls.backward()
>>> ts.grad
tensor(indices=tensor([[0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3],
                       [0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3]]),
       values=tensor([-0.4439,  0.1888,  0.5986,  0.6458, -0.4439,  0.1888,
                       0.5986,  0.6458, -0.4439,  0.1888,  0.5986,  0.6458,
                      -0.4439,  0.1888,  0.5986,  0.6458]),
       size=(4, 4), nnz=16, layout=torch.sparse_coo)

You can backpropagate through torch.sparse.mm(), which is the key piece of
functionality you need.
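
For example (just a minimal, untested sketch; the initialization and the
transposes are assumptions on my part, and you would still need an optimizer
or manual update step that accepts a sparse gradient):

import torch
from torch import nn


class SparseCOOLinear(nn.Module):
    """Bias-free linear layer with a sparse COO weight, multiplied via
    torch.sparse.mm() so that backward() produces a gradient for the
    sparse weight."""

    def __init__(self, in_features: int, out_features: int, density: float = 0.001):
        super().__init__()
        # Random dense weight with most entries zeroed out, stored as a
        # sparse COO parameter (assumed initialization).
        dense = torch.randn(out_features, in_features) * 0.01
        dense[torch.rand_like(dense) >= density] = 0.0
        self.weight = nn.Parameter(dense.to_sparse())

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        # weight is (out, in), input is (batch, in): multiply against the
        # transposed input and transpose back to (batch, out).
        return torch.sparse.mm(self.weight, input.T).T

I haven’t checked which optimizers will accept a sparse parameter with a
sparse gradient, so a manual update (e.g. subtracting lr * weight.grad under
torch.no_grad()) may be the safest place to start.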

Best.

K. Frank

Hi Frank!

Thank you for this; I think I can work with the sparse COO layout. After some experimentation I have found that torch.sparse.mm() is indeed required to produce sparse gradients. It also seems that not all built-in loss functions work: for example, torch.nn.MSELoss() and torch.nn.L1Loss() look promising, but torch.nn.CrossEntropyLoss() produces zero gradients after backpropagation!

Hi Bendeguz!

Works for me:

>>> import torch
>>> print (torch.__version__)
2.3.0
>>>
>>> _ = torch.manual_seed (2024)
>>>
>>> tf = torch.randn (4, 4)
>>> ts = tf.to_sparse()
>>> ts.requires_grad = True
>>>
>>> inp = torch.randn (3, 4)
>>> trg = torch.randint (4, (3,))
>>>
>>> loss = torch.nn.CrossEntropyLoss() (torch.sparse.mm (ts, inp.T).T, trg)
>>>
>>> loss.backward()
>>> ts.grad
tensor(indices=tensor([[0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3],
                       [0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3]]),
       values=tensor([-0.0638,  0.0182,  0.0863,  0.0955,  0.0626, -0.1396,
                      -0.0997, -0.0430, -0.0724,  0.0641,  0.4070, -0.1736,
                       0.0737,  0.0573, -0.3936,  0.1211]),
       size=(4, 4), nnz=16, layout=torch.sparse_coo)

Best.

K. Frank