Makes sense - thanks a lot!
The other thing I’m wondering about in the example above is that if I change the line
ixs = (torch.arange(100) > 50)
to
ixs = (torch.arange(100) > 99)
I get back to the baseline speed. So it seems the slowdown comes not from the indexing itself, but from a copy being made of the elements of the matmul output where ixs is True, and from the operations on that copy (hence the dependence on how many elements of ixs are True). Is that roughly correct?
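For reference, here is a minimal sketch of how this could be checked. It is not the original example: the tensor names `x`, `y`, `selected`, `sel` and the 100x100 shape are assumptions for illustration, and the mask is assumed to index rows of the matmul output. It shows that indexing with a boolean mask returns a new tensor whose size depends on how many entries of ixs are True, and that modifying it leaves the original untouched.

```python
import torch

x = torch.randn(100, 100)
y = x @ x  # hypothetical stand-in for the matmul output in the example above

for threshold in (50, 99):
    ixs = torch.arange(100) > threshold
    # Boolean-mask indexing returns a new tensor (a copy), not a view,
    # so its size scales with the number of True entries in ixs.
    selected = y[ixs]
    print(threshold, int(ixs.sum()), tuple(selected.shape))

# Modifying the indexed result in place does not change y,
# which is consistent with the result being a copy.
ixs = torch.arange(100) > 50
before = y.clone()
sel = y[ixs]
sel.zero_()
print(torch.equal(y, before))
```

With `> 99` the mask is all False, so the selected tensor is empty and there is essentially nothing to copy or operate on, which would fit with getting back to the baseline speed.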