I’m afraid the sparse implementation is surprisingly slow.
You can verify that converting groups to a dense Tensor with groups = groups.to_dense() makes this run significantly faster (even on CPU).
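A rough way to check this yourself, as a minimal sketch: the matrix sizes and density below are made-up assumptions, and it relies on `torch.index_select` accepting a sparse COO input. Timings depend entirely on hardware and PyTorch version, so no numbers are claimed here. The one-off `to_dense()` conversion is kept outside the timed region.

```python
import time

import torch

# Hypothetical sizes for illustration: a 10000 x 500 sparse matrix, ~1% non-zeros
torch.manual_seed(0)
n_rows, n_cols, nnz = 10_000, 500, 50_000
indices = torch.stack((torch.randint(n_rows, (nnz,)), torch.randint(n_cols, (nnz,))))
groups = torch.sparse_coo_tensor(indices, torch.ones(nnz), (n_rows, n_cols)).coalesce()
mask_index = torch.randint(n_rows, (2_000,))

# Row selection directly on the sparse tensor
start = time.perf_counter()
sparse_result = torch.index_select(groups, 0, mask_index)
sparse_time = time.perf_counter() - start

# Same selection after a one-off dense conversion (not timed)
dense_groups = groups.to_dense()
start = time.perf_counter()
dense_result = torch.index_select(dense_groups, 0, mask_index)
dense_time = time.perf_counter() - start

# Both paths must select exactly the same rows
assert torch.equal(sparse_result.to_dense(), dense_result)
print(f"sparse: {sparse_time:.4f}s  dense: {dense_time:.4f}s")
```

The dense route trades memory for speed: the full `n_rows x n_cols` matrix has to fit in memory.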

If your index_select implementation works for all the Sparse Tensors we have, we would be more than happy to accept a PR that improves the perf of the current implementation!

To be honest, I haven’t been programming for very long, so I kind of doubt my code would be anywhere near PyTorch standards.

But I’ll post my function here anyway. You can let me know whether it is actually useful; if not, it might still help other people.
Right now it only works for selecting rows. Values are ignored, since I clip them to 1 anyway in my use case. Adding column selection and values should be easy, though.

    import torch

    def myindexrowselect(groups, mask_index):
        # Select the rows listed in mask_index from the 2-D sparse COO tensor groups.
        # Values are ignored: every kept entry becomes 1 in the result.
        index = groups._indices()  # shape (2, nnz): row indices, column indices
        # Start from empty tensors so torch.cat works even if nothing is selected
        newidx = [torch.empty(0, dtype=torch.long, device=index.device)]
        newcolval = [torch.empty(0, dtype=torch.long, device=index.device)]
        for newrowindex, ind in enumerate(mask_index):
            # Positions of the non-zero entries that lie in row `ind`
            keptindex = (index[0] == ind).nonzero().flatten()
            # Each of them moves to row `newrowindex` and keeps its column
            newidx.append(torch.full((keptindex.numel(),), newrowindex,
                                     dtype=torch.long, device=index.device))
            newcolval.append(index[1][keptindex])
        newidx = torch.cat(newidx)
        newcolval = torch.cat(newcolval)
        return torch.sparse_coo_tensor(indices=torch.stack((newidx, newcolval), dim=0),
                                       values=torch.ones(newidx.shape[0], dtype=torch.float,
                                                         device=index.device),
                                       size=(len(mask_index), groups.shape[1]))
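For reference, the same row selection can be written without the Python loop. This is only a sketch of one possible approach, not PyTorch's own implementation: `index_rows_vectorized` is a hypothetical helper name. It compares every stored row index against every requested row at once, so memory grows with `len(mask_index) * nnz`. Unlike the loop version, it keeps the original values.

```python
import torch


def index_rows_vectorized(groups, mask_index):
    """Hypothetical helper (not a PyTorch API): select rows `mask_index`
    (in order, duplicates allowed) from a 2-D sparse COO tensor.

    Builds a (len(mask_index), nnz) boolean match matrix, so memory
    scales with len(mask_index) * nnz.
    """
    groups = groups.coalesce()
    rows, cols = groups.indices()  # row indices and column indices of the non-zeros
    values = groups.values()
    mask_index = torch.as_tensor(mask_index, dtype=torch.long, device=rows.device)
    # match[i, k] is True when non-zero k lies in row mask_index[i]
    match = rows.unsqueeze(0) == mask_index.unsqueeze(1)
    new_rows, pos = match.nonzero(as_tuple=True)
    return torch.sparse_coo_tensor(
        torch.stack((new_rows, cols[pos])),
        values[pos],  # use torch.ones_like(values[pos]) instead to clip to 1
        size=(len(mask_index), groups.shape[1]),
    )


dense = torch.tensor([[0., 2., 0.],
                      [1., 0., 3.],
                      [0., 0., 4.]])
picked = index_rows_vectorized(dense.to_sparse(), [2, 0, 2])
print(picked.to_dense())
# tensor([[0., 0., 4.],
#         [0., 2., 0.],
#         [0., 0., 4.]])
```

Column selection could reuse the same helper by passing `groups.t()` and transposing the result back.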