Need help optimizing the speed of a gather/scatter-heavy model

I am trying to optimize a model with a lot of gather/scatter operations.
So far, my optimizations to the code have focused on:

  • using in-place ops and the out= parameter to avoid extra copies
  • replacing long (int64) indices with int32 indices everywhere (CUDA seems to be designed for int indices, right?); my layers now use torch.index_select, torch.index_add and torch.searchsorted, all of which support int32 indices (see the sketch after this list)
  • fp16 everywhere, with padding to make shapes multiples of 64
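
For concreteness, here is a minimal sketch of the pattern my layers follow (sizes and tensor names are made up for illustration):

```python
import torch

device = "cuda"

# made-up sizes; the feature dim 256 is already a multiple of 64
src = torch.randn(100_000, 256, device=device, dtype=torch.float16)
gathered = torch.empty(50_000, 256, device=device, dtype=torch.float16)
acc = torch.zeros(100_000, 256, device=device, dtype=torch.float16)

# int32 indices instead of the default int64 (long)
idx = torch.randint(0, 100_000, (50_000,), device=device, dtype=torch.int32)

# gather rows, reusing `gathered` via out= to avoid an extra allocation
torch.index_select(src, 0, idx, out=gathered)

# scatter-add the rows back in place
acc.index_add_(0, idx, gathered)

# searchsorted can emit int32 positions directly with out_int32=True
boundaries = torch.arange(0, 100_000, 10, device=device, dtype=torch.int32)
positions = torch.searchsorted(boundaries, idx, out_int32=True)
```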

Using the revamped PyTorch profiler (thanks!), I can see that most of the time is still spent in index_add_ and index_select operations.
This is happening on an NVIDIA A100 GPU; on a P1000, most of the time is spent in algebra operations, as expected.
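
For reference, this is roughly how I am measuring it (a minimal sketch; `model` and `inputs` stand in for my actual code):

```python
import torch
from torch.profiler import profile, ProfilerActivity

# `model` and `inputs` are placeholders for my actual model and batch
with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    with torch.no_grad():
        out = model(inputs)
    torch.cuda.synchronize()

# sorted by total CUDA time, index_add_ / index_select dominate on the A100
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))
```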
Are there any tips, known limitations, or good practices for optimizing this kind of workload?