Updating and removing entries in a sparse tensor

I am using a sparse tensor as a kind of spike queue for my model. I add spikes to the queue by creating new sparse tensors whose indices mark where the new spikes should go and whose values are a list of ones.
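A minimal sketch of this pattern, for concreteness. The helper name add_spikes and the 3x4 uint8 queue are illustrative assumptions, not code from the model itself.

```python
import torch

def add_spikes(queue, rows, cols):
    """Return the queue with a spike (value 1) added at each (row, col) pair.

    Hypothetical helper: the name and signature are assumptions for illustration.
    """
    indices = torch.tensor([rows, cols], dtype=torch.long)
    values = torch.ones(len(rows), dtype=queue.dtype)
    new_spikes = torch.sparse_coo_tensor(indices, values, queue.size())
    # coalesce() merges duplicate indices by summing their values,
    # so a spike added twice at the same location ends up with value 2.
    return (queue + new_spikes).coalesce()

# Start from an empty 3x4 queue and add spikes at (0, 3) and (2, 1).
queue = torch.zeros(3, 4, dtype=torch.uint8).to_sparse()
queue = add_spikes(queue, rows=[0, 2], cols=[3, 1])
print(queue)
```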

However, I am having trouble removing spikes from the queue. When I set the value at the index to be removed to zero and then call coalesce(), I would expect the index to disappear, but it doesn't: the 0 stays in the list of values, and the nnz attribute is unchanged as well. This means my sparse queue keeps growing and fills up with indices whose values are zero. It also means that to know whether a spike exists at a location, I can't just check for the index in the tensor's indices; I also have to check the value stored at that index.
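To illustrate the extra check this forces, here is a small sketch of a lookup that tests both the index and the stored value. The name has_spike is hypothetical, and the example queue is built by hand with one explicit zero to mimic a "removed" spike.

```python
import torch

def has_spike(queue, row, col):
    """True only if (row, col) is stored in the queue with a nonzero value.

    Hypothetical helper for a 2-D sparse COO tensor.
    """
    queue = queue.coalesce()
    idx = queue.indices()
    match = (idx[0] == row) & (idx[1] == col)
    # The index may be present with an explicit zero value,
    # so index presence alone is not enough.
    return bool((queue.values()[match] != 0).any())

# Queue with an explicit zero stored at (0, 0).
indices = torch.tensor([[0, 0, 2],
                        [0, 3, 1]])
values = torch.tensor([0, 1, 1], dtype=torch.uint8)
queue = torch.sparse_coo_tensor(indices, values, (3, 4))

print(has_spike(queue, 0, 0))  # False: index is stored, but its value is 0
print(has_spike(queue, 0, 3))  # True
print(has_spike(queue, 1, 1))  # False: index is not stored at all
```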

I’ve tried replacing the _indices and _values (as well as ._data) with new tensors built from the nonzero values, but I can’t get this to work. The best I can do is replace the entire queue with a new one, but that defeats the point.
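For reference, the replace-the-whole-queue workaround can be sketched roughly as follows. drop_zeros is a hypothetical name, and this builds a new tensor rather than modifying the queue in place, which is exactly the cost I was hoping to avoid.

```python
import torch

def drop_zeros(queue):
    """Return a new sparse COO tensor without entries whose value is 0.

    Hypothetical helper: rebuilds the tensor instead of editing it in place.
    """
    queue = queue.coalesce()
    keep = queue.values() != 0
    return torch.sparse_coo_tensor(
        queue.indices()[:, keep],  # keep only columns of surviving entries
        queue.values()[keep],
        queue.size(),
    )

# Queue with an explicit zero stored at (0, 0).
indices = torch.tensor([[0, 0, 2],
                        [0, 3, 1]])
values = torch.tensor([0, 1, 1], dtype=torch.uint8)
queue = torch.sparse_coo_tensor(indices, values, (3, 4))

pruned = drop_zeros(queue).coalesce()
print(pruned)
```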

If I call zero_(), all of the indices and values are removed, as expected.

import torch

# Random 3x4 mask converted to a sparse COO tensor
a = (torch.rand(3,4) > 0.5).to_sparse()
'''
tensor(indices=tensor([[0, 0, 2, 2, 2],
                       [0, 3, 0, 1, 2]]),
       values=tensor([1, 1, 1, 1, 1]),
       size=(3, 4), nnz=5, dtype=torch.uint8, layout=torch.sparse_coo)
'''

# Zero the first stored value in place; its index is still kept
a.values()[0] = 0
'''
tensor(indices=tensor([[0, 0, 2, 2, 2],
                       [0, 3, 0, 1, 2]]),
       values=tensor([0, 1, 1, 1, 1]),
       size=(3, 4), nnz=5, dtype=torch.uint8, layout=torch.sparse_coo)
'''

# Add a new spike at (1, 1)
b = torch.zeros(3,4)
b[1,1] = 1
b = b.to(torch.uint8).to_sparse()
a += b
'''
tensor(indices=tensor([[0, 0, 1, 2, 2, 2],
                       [0, 3, 1, 0, 1, 2]]),
       values=tensor([0, 1, 1, 1, 1, 1]),
       size=(3, 4), nnz=6, dtype=torch.uint8, layout=torch.sparse_coo)
'''

a.zero_()
'''
tensor(indices=tensor([], size=(2, 0)),
       values=tensor([], size=(0,)),
       size=(3, 4), nnz=0, dtype=torch.uint8, layout=torch.sparse_coo)
'''

Any advice?
