Index_select() for sparse tensors slower on GPU than CPU

Hi all,

When I mask a sparse Tensor with index_select() in PyTorch 1.4, the computation is much slower on a GPU (31 seconds) than on a CPU (~6 seconds).

Does anyone know why there is such a huge difference?

Here is a simplified code snippet for the GPU:

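import time
import torch

n = 2000
# Sparse n x n identity-like matrix: one entry per row on the diagonal
groups = torch.sparse_coo_tensor(indices=torch.stack((torch.arange(n), torch.arange(n))),
                                 values=torch.ones(n, dtype=torch.long))
idx = torch.ones(1999, dtype=torch.long)

idx = idx.cuda()
groups = groups.cuda()

start_time = time.time()
groups.index_select(0, idx)
torch.cuda.synchronize()  # CUDA calls can be asynchronous; wait before reading the clock
print("--- %s seconds for masking---" % (time.time() - start_time))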


I’m afraid the sparse implementation is surprisingly slow :confused:
You can check that converting groups to a dense Tensor makes this run significantly faster, even on CPU: groups = groups.to_dense()
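A minimal sketch of that comparison on CPU, assuming the same diagonal test matrix as in the question. The helper name `time_index_select` is made up for this example, and the timings will vary with the PyTorch version and hardware; the point is only to confirm that both paths return the same rows.

```python
import time
import torch

def time_index_select(tensor, idx):
    """Return (result, elapsed seconds) for tensor.index_select(0, idx)."""
    start = time.perf_counter()
    out = tensor.index_select(0, idx)
    return out, time.perf_counter() - start

n = 2000
diag = torch.arange(n)
groups = torch.sparse_coo_tensor(torch.stack((diag, diag)),
                                 torch.ones(n, dtype=torch.long),
                                 size=(n, n))
idx = torch.ones(n - 1, dtype=torch.long)

# Same selection on the sparse tensor and on its dense copy
sparse_out, sparse_s = time_index_select(groups, idx)
dense_out, dense_s = time_index_select(groups.to_dense(), idx)

same = torch.equal(sparse_out.to_dense(), dense_out)
print(f"sparse: {sparse_s:.4f}s, dense: {dense_s:.4f}s, same result: {same}")
```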

I’ve checked the dense version before and it is indeed so much faster.
The dense representation doesn’t fit into my memory though :frowning:

I just wrote my own index_select function for sparse Tensors instead, which works for my needs.

Thanks for your help though :slight_smile:

If your index_select implementation works for all the Sparse Tensors we have, we would be more than happy to accept a PR that improves the perf of the current implementation! :slight_smile:

To be honest, I haven't been programming for very long, so I kinda doubt that my code would be in any way up to PyTorch standards :joy:

But I’ll just add my function here and you can let me know if that is indeed useful, otherwise it might help other people :slight_smile:
Right now it only works for selecting rows. Values are ignored, since I am clipping them to 1 anyway in my scenario. I could easily add columns and values though.

def myindexrowselect(groups, mask_index):

    index = groups._indices()
    newrowindex = -1

    for ind in mask_index:
        try:
            newrowindex = newrowindex + 1
        except NameError:
            newrowindex = 0

        # Positions of all stored entries whose row equals the selected row
        keptindex = torch.squeeze((index[0] == ind).nonzero())

        if len(keptindex.size()) == 0:
            # Exactly one match: squeeze() produced a 0-dim tensor
            # Get column values from mask, create new row idx
            try:
                newidx = torch.cat((newidx, torch.tensor([newrowindex])), 0)
                newcolval = torch.cat((newcolval, torch.tensor([index[1][keptindex.item()]])), 0)
            except NameError:
                # First entry: newidx and newcolval do not exist yet
                newidx = torch.tensor([newrowindex])
                newcolval = torch.tensor([index[1][keptindex.item()]])

        else:
            # Zero or several matches
            # Get column values from mask, create new row idx
            # Add newrowindex keptindex.size(0) times to the list
            for i in range(list(keptindex.size())[0]):
                try:
                    newidx = torch.cat((newidx, torch.tensor([newrowindex])), 0)
                    newcolval = torch.cat((newcolval, torch.tensor([index[1][keptindex.tolist()[i]]])), 0)
                except NameError:
                    newidx = torch.tensor([newrowindex])
                    newcolval = torch.tensor([index[1][keptindex.tolist()[i]]])

    # All values are set to 1; the original values are not carried over
    groups = torch.sparse_coo_tensor(indices=torch.stack((newidx, newcolval), dim=0),
                                     values=torch.ones(newidx.shape[0], dtype=torch.float),
                                     size=(len(mask_index), groups.shape[1]))
    return groups
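
For anyone who also needs the values, here is a rough vectorized sketch of the same row selection without the Python loop. It is not the PyTorch implementation and not the function above: the name `sparse_row_select` is made up, it assumes a 2-D sparse COO tensor and a 1-D long index tensor on the same device, and it relies on `coalesce()` returning indices sorted by row so that each row's entries form one contiguous run.

```python
import torch

def sparse_row_select(sp, rows):
    """Select rows of a 2-D sparse COO tensor; repeated rows are allowed."""
    sp = sp.coalesce()                      # indices come back sorted by row
    r, c = sp.indices()
    vals = sp.values()
    dev = r.device

    counts = torch.bincount(r, minlength=sp.shape[0])  # stored entries per row
    starts = torch.cumsum(counts, 0) - counts          # offset of each row's run

    sel_counts = counts[rows]                          # entries per selected row
    total = int(sel_counts.sum())

    # New row index, repeated once per entry of the selected row
    new_rows = torch.repeat_interleave(
        torch.arange(len(rows), device=dev), sel_counts)

    # Position of each output entry inside its row's run
    run_starts = torch.cumsum(sel_counts, 0) - sel_counts
    within = torch.arange(total, device=dev) - torch.repeat_interleave(
        run_starts, sel_counts)

    # Index into the coalesced entry list
    entry = torch.repeat_interleave(starts[rows], sel_counts) + within

    return torch.sparse_coo_tensor(
        torch.stack((new_rows, c[entry])), vals[entry],
        size=(len(rows), sp.shape[1]))

# Small usage example: row 1 is empty, row 0 is selected twice
sp = torch.sparse_coo_tensor(torch.tensor([[0, 0, 2], [0, 2, 1]]),
                             torch.tensor([1.0, 2.0, 3.0]), size=(3, 3))
rows = torch.tensor([2, 0, 0, 1])
out = sparse_row_select(sp, rows)
print(out.to_dense())
```

The extra memory is proportional to the number of stored entries plus the size of the output, so no dense copy of the matrix is needed.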
