I am trying to optimize a model that makes heavy use of gather/scatter operations.
So far, my optimizations to the code have focused on:
- using in-place ops and the `out=` parameter to avoid extra copies
- replacing long (int64) indices with int (int32) indices everywhere (CUDA seems to be designed for int indices, right?); my layers now use `torch.index_select`, `torch.index_add` and `torch.searchsorted`, all of which support int32 indices
- using fp16 everywhere, with padding so that shapes are multiples of 64
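A minimal sketch of the setup described in the list above, assuming a graph-style layer that gathers rows by a source index and scatter-adds them by a destination index. The shapes, the index pattern, and the helpers `pad_last_dim` and `gather_scatter_step` are hypothetical. It reuses preallocated buffers through `out=` and in-place ops, keeps all indices in `torch.int32`, asks `torch.searchsorted` for int32 output via `out_int32=True`, and pads the feature dimension to a multiple of 64. fp16 is used only when CUDA is available; the CPU fallback uses fp32 so the snippet runs anywhere.

```python
import torch
import torch.nn.functional as F

# Assumption: fp16 on the GPU as in the post; fp32 on CPU so the sketch runs anywhere.
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32


def pad_last_dim(x: torch.Tensor, multiple: int = 64) -> torch.Tensor:
    """Hypothetical helper: zero-pad the last dim of x up to the next multiple of `multiple`."""
    pad = (-x.shape[-1]) % multiple
    return F.pad(x, (0, pad)) if pad else x


def gather_scatter_step(x, src, dst, gathered, accum):
    """Hypothetical layer body: one gather + scatter-add pass into preallocated buffers."""
    torch.index_select(x, 0, src, out=gathered)  # gather rows with int32 indices, no new allocation
    accum.zero_()                                # reset in place instead of allocating a new tensor
    accum.index_add_(0, dst, gathered)           # scatter-add rows with int32 indices
    return accum


num_nodes, num_edges, feat = 1000, 5000, 50
x = pad_last_dim(torch.randn(num_nodes, feat, device=device, dtype=dtype))  # feature dim 50 -> 64
src = torch.randint(0, num_nodes, (num_edges,), device=device, dtype=torch.int32)
dst = torch.randint(0, num_nodes, (num_edges,), device=device, dtype=torch.int32)

# Buffers allocated once and reused on every call.
gathered = torch.empty(num_edges, x.shape[-1], device=device, dtype=dtype)
accum = torch.zeros_like(x)
gather_scatter_step(x, src, dst, gathered, accum)

# searchsorted can emit int32 directly instead of the default int64.
boundaries = torch.tensor([0, 250, 500, 750], device=device, dtype=torch.int32)
bucket = torch.searchsorted(boundaries, dst, out_int32=True)
```

In a training loop, only `gather_scatter_step` would run per iteration; `gathered` and `accum` stay allocated, so the caching allocator is not exercised on every step.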
Using the revamped torch profiler (thanks!), I can see that most of the time is still spent in `index_add_` and `index_select` operations.
This happens on an NVIDIA A100 GPU; on a P1000, by contrast, most of the time is spent in linear algebra operations, as expected.
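For reference, a profiling run of the kind described above could look roughly like this. The workload is a hypothetical stand-in for one gather/scatter layer, and the table is sorted by CUDA time on the GPU (CPU time otherwise), so `aten::index_add_` and `aten::index_select` appear near the top if they dominate.

```python
import torch
from torch.profiler import profile, ProfilerActivity

device = "cuda" if torch.cuda.is_available() else "cpu"
activities = [ProfilerActivity.CPU]
if device == "cuda":
    activities.append(ProfilerActivity.CUDA)  # also record GPU kernel times

# Hypothetical workload standing in for one gather/scatter layer.
x = torch.randn(1000, 64, device=device)
idx = torch.randint(0, 1000, (5000,), device=device, dtype=torch.int32)
out = torch.zeros_like(x)

with profile(activities=activities, record_shapes=True) as prof:
    for _ in range(10):
        g = torch.index_select(x, 0, idx)  # gather
        out.index_add_(0, idx, g)          # scatter-add
    if device == "cuda":
        torch.cuda.synchronize()           # make sure queued kernels finish inside the profiled region

sort_key = "cuda_time_total" if device == "cuda" else "cpu_time_total"
print(prof.key_averages().table(sort_by=sort_key, row_limit=10))
```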
Are there any tips, known limitations, or good practices for optimizing this kind of workload?