PyTorch indexing is very slow

Recently, in my project, I needed to lift a point cloud of shape (n, 3) from an image of shape (3, H, W) whose RGB channels encode the spatial coordinates. I also have a mask of shape (240, 320) that indicates which points I need. I used an indexing operation to lift it:

masked = mask > threshold
point_cloud = img[:, masked].transpose(0, 1)
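A self-contained version of the same operation, for reproducing the timing. It is a sketch that uses random stand-in data with H=240 and W=320 to match the mask. The threshold of 0.5 and the helper names are hypothetical. The second helper indexes an (H, W, 3) view instead, which should give the same (n, 3) points without the trailing transpose:

```python
import torch


def lift_point_cloud(img: torch.Tensor, mask: torch.Tensor, threshold: float) -> torch.Tensor:
    """Gather the (x, y, z) values stored in img's 3 channels wherever mask > threshold."""
    masked = mask > threshold                 # (H, W) boolean mask
    return img[:, masked].transpose(0, 1)     # (3, n) -> (n, 3), a non-contiguous view


def lift_point_cloud_hwc(img: torch.Tensor, mask: torch.Tensor, threshold: float) -> torch.Tensor:
    """Same points, but indexes an (H, W, 3) view so the result is (n, 3) directly."""
    masked = mask > threshold
    return img.permute(1, 2, 0)[masked]       # (H, W, 3)[(H, W) mask] -> (n, 3)


# Hypothetical stand-in data; H=240, W=320 to match the mask in the question.
torch.manual_seed(0)
img = torch.randn(3, 240, 320)
mask = torch.rand(240, 320)
threshold = 0.5

a = lift_point_cloud(img, mask, threshold)
b = lift_point_cloud_hwc(img, mask, threshold)
print(a.shape, torch.equal(a, b))
```

Both versions visit the selected pixels in row-major order, so the rows of the two results line up one to one.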

However, this operation is very slow: it takes nearly 0.1 seconds. Is there any way to improve its performance?
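Rule out one thing before optimizing. CUDA kernels launch asynchronously. Indexing with a boolean mask also has to know how many elements are selected (it uses `nonzero` internally), so the host waits for all GPU work queued before it. A timer around that single line can therefore also be charging for earlier kernels. The following is a minimal timing sketch that synchronizes before and after the measured loop. The helper names `time_op` and `lift` and the iteration counts are arbitrary choices, not from the thread:

```python
import time

import torch


def time_op(fn, *args, warmup: int = 5, iters: int = 50) -> float:
    """Return the mean wall-clock seconds per call of fn(*args).

    torch.cuda.synchronize() is called before starting and after finishing the
    timed loop so that pending GPU work from earlier code is not billed to fn.
    """
    use_cuda = torch.cuda.is_available()
    for _ in range(warmup):          # warm-up: CUDA context, allocator, kernel caches
        fn(*args)
    if use_cuda:
        torch.cuda.synchronize()     # drain work queued before the measurement
    start = time.perf_counter()
    for _ in range(iters):
        fn(*args)
    if use_cuda:
        torch.cuda.synchronize()     # wait for the timed kernels to actually finish
    return (time.perf_counter() - start) / iters


def lift(img, masked):
    # Same indexing as in the question: (3, H, W) -> (n, 3)
    return img[:, masked].transpose(0, 1)


device = "cuda" if torch.cuda.is_available() else "cpu"
img = torch.randn(3, 240, 320, device=device)
masked = torch.rand(240, 320, device=device) > 0.5
print(f"{device}: {time_op(lift, img, masked) * 1e3:.3f} ms per call")
```

If the synchronized per-call time is much smaller than 0.1 s, the cost probably belongs to whatever ran before the indexing, not to the indexing itself.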


I don’t know exactly why it’s slow, but have you tried indexing on a different device? For many of my use cases, indexing on the GPU is faster than on the CPU.
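Here is one way to try that suggestion: time the same indexing on every available device with `torch.utils.benchmark.Timer`, which synchronizes CUDA around its measurements. The tensor sizes and the 0.5 threshold are stand-ins. `num_threads` is set explicitly because `Timer` otherwise defaults to a single CPU thread:

```python
import torch
import torch.utils.benchmark as benchmark


def lift(img, masked):
    # Same indexing as in the question: (3, H, W) -> (n, 3)
    return img[:, masked].transpose(0, 1)


devices = ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])
results = {}
for device in devices:
    img = torch.randn(3, 240, 320, device=device)
    masked = torch.rand(240, 320, device=device) > 0.5
    timer = benchmark.Timer(
        stmt="lift(img, masked)",
        globals={"lift": lift, "img": img, "masked": masked},
        num_threads=torch.get_num_threads(),  # Timer defaults to 1 thread
    )
    # blocked_autorange chooses the number of runs; .median is seconds per run
    results[device] = timer.blocked_autorange(min_run_time=0.2).median
    print(f"{device}: {results[device] * 1e6:.1f} us per call")
```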

I’m working on a Tesla V100, so I don’t think the device is the problem. The mask selects roughly 10k indices.
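If the same mask is reused across many images (an assumption; the thread does not say), the boolean-to-index conversion can be done once. Precompute the flat positions with `nonzero`, then gather with `index_select` on each call. The output size then comes from a fixed index tensor instead of being recomputed from the mask every time. This is a sketch under that assumption, and the helper names are hypothetical:

```python
import torch


def precompute_indices(mask: torch.Tensor, threshold: float) -> torch.Tensor:
    """Flat (row-major) positions of the selected pixels; computed once per mask."""
    return (mask > threshold).flatten().nonzero().squeeze(1)


def lift_with_indices(img: torch.Tensor, flat_idx: torch.Tensor) -> torch.Tensor:
    """Gather (n, 3) points from a (3, H, W) image using precomputed indices."""
    # img.flatten(1) is (3, H*W); index_select keeps the selected columns -> (3, n)
    return img.flatten(1).index_select(1, flat_idx).t()


device = "cuda" if torch.cuda.is_available() else "cpu"
img = torch.randn(3, 240, 320, device=device)
mask = torch.rand(240, 320, device=device)
flat_idx = precompute_indices(mask, 0.5)   # the one-time nonzero (and its sync) happens here
points = lift_with_indices(img, flat_idx)  # output size is known from flat_idx
```

This only pays off if the mask really is shared. With a new mask per frame, the `nonzero` call, and the wait it implies, moves into every iteration again.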