EmbeddingBag optimizations for performance


I am trying to understand how optimized the embedding operators available in PyTorch are. I want to learn about the optimizations used to improve execution time on CPU/GPU. The basic functionality is a gather (a lookup into a big table across distant indices, making it an irregular operation) followed by a reduce (summation of contiguous embeddings, making it a regular operation).
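For concreteness, here is a minimal sketch of that naive two-step gather-then-reduce baseline I have in mind (the table size, indices, and bag offsets are made up for illustration):

```python
import torch

torch.manual_seed(0)
table = torch.randn(1000, 64)               # embedding table: (num_embeddings, dim)
indices = torch.tensor([1, 5, 900, 3, 42])  # flat lookup indices for all bags
offsets = [0, 2]                            # bag0 = indices[0:2], bag1 = indices[2:5]

# Step 1: gather -- irregular memory access across distant table rows,
# materializing a (5, 64) intermediate tensor.
gathered = table[indices]

# Step 2: reduce -- regular per-bag summation over contiguous rows.
ends = offsets[1:] + [len(indices)]
bags = torch.stack([gathered[s:e].sum(dim=0) for s, e in zip(offsets, ends)])
print(bags.shape)  # torch.Size([2, 64])
```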

I understand that one optimization is fusing the gather and the reduce, which avoids materializing the intermediate tensor. But apart from that, are there other optimizations, such as the use of AVX/SIMD instructions, or sorting the indices to exploit memory locality?
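As I understand it, the fused path is exposed as `torch.nn.functional.embedding_bag`; a small sketch checking that it matches the naive two-step computation (values here are arbitrary examples):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
table = torch.randn(1000, 64)
indices = torch.tensor([1, 5, 900, 3, 42])
offsets = torch.tensor([0, 2])  # two bags: indices[0:2] and indices[2:5]

# Fused gather+reduce: the (5, 64) intermediate is never materialized.
fused = F.embedding_bag(indices, table, offsets, mode='sum')

# Equivalent naive two-step version for comparison.
naive = torch.stack([table[indices[0:2]].sum(0), table[indices[2:5]].sum(0)])

assert torch.allclose(fused, naive, atol=1e-5)
```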

Could someone point out what optimizations are built on top of this basic/naive functionality and, if possible, point me to related references?

Rishabh Jain