I am micro-benchmarking `masked_select`, direct indexing, and `index_select` on a 1D tensor to compare their performance. After running my benchmarks, I noticed that all three operations exhibit very similar execution times. I also examined the TorchScript graph IR, but it closely resembles my original Python code, apart from the ATen function prefixes. To investigate further, I ran the PyTorch Profiler, and in the profiling results I observed that both `masked_select` and indexing include `aten::nonzero`, with a similar share of the overhead.
```python
import torch

# Note: "is" is a reserved keyword in Python, so the first helper is renamed here.
def index_select_fn(tensor, threshold: int):
    mask = torch.lt(tensor, threshold)
    vector = torch.nonzero(mask).squeeze()
    return torch.index_select(tensor, 0, vector)

def masked_select_fn(tensor, threshold: int):
    mask = torch.lt(tensor, threshold)
    return torch.masked_select(tensor, mask)

def indexing(tensor, threshold: int):
    return tensor[tensor < threshold]
```
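For reference, here is a minimal sketch of the kind of benchmark I mean, using `torch.utils.benchmark.Timer`; the tensor size, threshold, and iteration count here are illustrative assumptions, not my exact setup:

```python
import torch
from torch.utils import benchmark

# Illustrative setup: a 1D tensor and a fixed threshold of 0.
x = torch.randn(1_000_000)

# The three equivalent selection strategies, expressed as statements for Timer.
stmts = {
    "index_select":  "torch.index_select(x, 0, torch.nonzero(x < 0).squeeze())",
    "masked_select": "torch.masked_select(x, x < 0)",
    "indexing":      "x[x < 0]",
}

for label, stmt in stmts.items():
    timer = benchmark.Timer(stmt=stmt, globals={"torch": torch, "x": x})
    # timeit returns a Measurement with mean/median wall time.
    print(label, timer.timeit(100))
```

All three produce identical results; only the timings (and the profiler traces) differ.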
This raises a few questions:
- Do `masked_select` and direct tensor indexing (`tensor[mask]`) both rely on `aten::nonzero` after the graph IR transformation?
- If they indeed share the same implementation under the hood, does that mean they are ultimately performing an `index_select` operation?
- Is there a way to avoid the overhead of `aten::nonzero` when using `masked_select` or indexing?
Any insights on how PyTorch handles these operations internally, and possible workarounds to avoid `aten::nonzero`, would be greatly appreciated.
Thanks!