Sparse torch.topk: can hybrid sparse+dense tensors help?

I’d like to keep only the K largest elements in each row of a tensor (corresponding to the logits/logprobs of the K highest-scoring classes) to save disk space during serialization. I’d like to get a sparse tensor as the output.

Is there a simpler way than this (e.g. directly passing the indices from topk to the torch.sparse.FloatTensor constructor)?

import torch

x = torch.rand(128, 512, 40) # 128 and 512 are batch dimensions, 40 is the class logits dimension

def sparse_topk(x, K, dim = -1):
    # keep the top-K values along `dim`, zero out everything else, then convert to sparse COO
    topk = x.topk(max(1, min(K, x.size(dim))), dim = dim)
    return torch.zeros_like(x).scatter_(dim, topk.indices, topk.values).to_sparse()

torch.save(x, 'x.pt') # 11M     x.pt
torch.save(sparse_topk(x, K = 5), 'x_.pt') # 8.8M    x_.pt
torch.save(sparse_topk(x.flatten(end_dim = 1), K = 5), 'x__.pt') # 6.3M    x__.pt

I’m a bit confused that the file size depends on whether I flatten the batch dimensions or not. Should I keep the sparse dimension first (and thus transpose my tensor)?

Ideally, I’d also like to store indices as a bitset (since my number of classes is smaller than 256).

Any other advice on saving disk space?

Thanks!

We don’t use the type constructors (torch.sparse.FloatTensor and friends) these days… Do you want torch.sparse_coo_tensor?
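
For example, something along these lines builds the COO tensor directly from the topk output (a rough sketch under my own assumptions about the layout; sparse_topk_coo and the batch-index bookkeeping are mine, not an official recipe):

import torch

def sparse_topk_coo(x, K, dim = -1):
    # sketch: assumes the sparse dimension is the last one and there is at least one batch dimension
    assert dim == -1
    topk = x.topk(min(K, x.size(-1)), dim = -1)
    batch_shape = x.shape[:-1]
    # enumerate the leading ("batch") indices once per kept element, in row-major order
    batch_idx = torch.cartesian_prod(*[torch.arange(s) for s in batch_shape])
    batch_idx = batch_idx.reshape(-1, len(batch_shape)).repeat_interleave(topk.indices.size(-1), dim = 0)
    class_idx = topk.indices.reshape(-1, 1)
    indices = torch.cat([batch_idx, class_idx], dim = 1).t()  # (x.dim(), nnz)
    return torch.sparse_coo_tensor(indices, topk.values.reshape(-1), size = x.shape)

Whether that is really simpler than zeros_like + scatter_ + to_sparse is debatable, but it skips the dense intermediate tensor.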

If you save indices for each dimension, it’s not that surprising that you would have some overhead.
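
Back-of-the-envelope (assuming int64 COO indices and float32 values): with nnz = 128 * 512 * 5 = 327,680, the 3-D sparse tensor stores 3 * 327,680 * 8 B ≈ 7.5 MiB of indices plus 327,680 * 4 B ≈ 1.25 MiB of values, i.e. about 8.8M, while the flattened 2-D version only needs two index rows, 2 * 327,680 * 8 B ≈ 5 MiB plus the same 1.25 MiB of values, i.e. about 6.3M. That is why flattening the batch dimensions shrinks the file.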

If you want full control, a good first step would be to save the values and indices yourself. Given that you want a tensor that is “full” in dims 0 and 1 and only sparse in the third dimension, you’d need a more sophisticated sparse format than PyTorch currently offers.
Saving the indices as uint8 and the values as bfloat16, I get 964k.
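
For reference, that is roughly (a sketch, the dict keys and file name are arbitrary):

topk = x.topk(5, dim = -1)
torch.save(dict(indices = topk.indices.to(torch.uint8),   # 40 classes < 256, so uint8 is enough
                values = topk.values.to(torch.bfloat16)), 'x_topk_raw.pt')
# 327,680 * (1 + 2) bytes ≈ 960 KiB of payload, plus a little torch.save overhead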

As far as I know, PyTorch doesn’t support bitsets, so you’d need to do that yourself.
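
If you do go that route, a do-it-yourself packing could look roughly like this (pack_mask/unpack_mask are made-up helper names, not PyTorch APIs):

import torch
import torch.nn.functional as F

_POW2 = torch.tensor([1, 2, 4, 8, 16, 32, 64, 128], dtype = torch.uint8)

def pack_mask(mask):  # (..., C) bool -> (..., ceil(C / 8)) uint8
    C = mask.size(-1)
    padded = F.pad(mask.to(torch.uint8), (0, (-C) % 8))   # pad the class dim to a multiple of 8
    groups = padded.reshape(*mask.shape[:-1], -1, 8)
    return (groups * _POW2).sum(-1).to(torch.uint8)

def unpack_mask(packed, C):
    bits = packed.unsqueeze(-1).bitwise_and(_POW2).ne(0)
    return bits.reshape(*packed.shape[:-1], -1)[..., :C]

# for 40 classes this is 5 bytes per row: 128 * 512 * 5 = 320 KiB for the index bitset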

Best regards

Thomas

Oh, right. The factory function torch.sparse_coo_tensor must be the one I wanted.

I found on https://pytorch.org/docs/stable/sparse.html that some sort of hybrid sparse+dense tensors are supported, but I don’t yet understand the format or whether it helps in this use case. Do you know if those hybrid tensors are what’s needed here?

About bitsets: yes, I’ll add this use case to my feature request: https://github.com/pytorch/pytorch/issues/32867

I thought that even if I implement bitset compression, the practical way to decompress would still be to reconstruct the indices tensor, construct the sparse tensor and then call to_dense().
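
Roughly this (a small self-contained sketch of that path; the mask here stands in for whatever the bitset would unpack to):

import torch

x = torch.rand(2, 3, 5)
topk = x.topk(2, dim = -1)
mask = torch.zeros_like(x, dtype = torch.bool).scatter_(-1, topk.indices, True)

indices = mask.nonzero().t()   # (ndim, nnz), in row-major order
values = x[mask]               # boolean indexing returns elements in the same row-major order
dense = torch.sparse_coo_tensor(indices, values, x.shape).to_dense()
assert torch.equal(dense[mask], x[mask])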

When I make sure that the class dimension is the first one and call to_sparse(1) instead of to_sparse(), _values() still contains a lot of zeros. It seems that PyTorch isn’t able to discover the nnz pattern automatically, so directly using torch.sparse_coo_tensor to construct a hybrid tensor could be the way forward:

>>> a = torch.eye(3)
>>> a
tensor([[1., 0., 0.],
        [0., 1., 0.],
        [0., 0., 1.]])
>>> a.to_sparse()._values()
tensor([1., 1., 1.])
>>> a.to_sparse(1)._values()
tensor([[1., 0., 0.],
        [0., 1., 0.],
        [0., 0., 1.]])
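
For comparison, a hand-built hybrid tensor (sparse_dim = 1, dense_dim = 1) only stores the selected dense slices; a minimal illustration of the format, not my actual use case:

>>> i = torch.tensor([[0, 2]])              # keep rows 0 and 2
>>> v = torch.tensor([[1., 0., 0.],         # each "value" is a whole dense row
...                   [0., 0., 1.]])
>>> torch.sparse_coo_tensor(i, v, size = (3, 3)).to_dense()
tensor([[1., 0., 0.],
        [0., 0., 0.],
        [0., 0., 1.]])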

So far I wrote:

# round-trips exactly only for non-negative tensors: positions outside the top-k come back as zeros
def save_topk(x, k, dim = -1, indices_dtype = None, values_dtype = None):
    topk = x.topk(k, dim = dim)
    return dict(k = k, dim = dim, shape = x.shape, dtype = x.dtype,
                indices = topk.indices.to(dtype = indices_dtype),
                values = topk.values.to(dtype = values_dtype))

def load_topk(saved, **kwargs):
    dense = torch.zeros(saved['shape'], dtype = saved['dtype'], **kwargs)
    return dense.scatter_(saved['dim'], saved['indices'].long(), saved['values'].to(dtype = saved['dtype']))

x = torch.rand(3, 4, 5)
y = save_topk(x, 2, dim = -1, indices_dtype = torch.uint8, values_dtype = torch.bfloat16)
z = load_topk(y)

I guess at this point my question is: can stock hybrid sparse+dense tensors eliminate this custom saving format and loading logic (while preserving the custom uint8/bfloat16 dtypes)?