Autograd of precomputed distance of a vector quantization codebook

Hi!
I am new to deep learning and PyTorch, and I am confused about how to set up the autograd behaviour I need.
One block in my network computes the cosine distance for about 1e10 pairs of vectors, which is extremely slow. To speed it up, I am trying to add a vector quantization module to my network with a codebook of only 512 vectors. That way I only need to compute the 512*512 pairwise cosine distances once and store them in a lookup table. Every vector is first quantized to a codeword, and the distance between any two vectors is then looked up in that table.
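To make the setup concrete, here is a minimal sketch of the precompute-then-look-up idea, assuming nearest-codeword assignment by Euclidean distance and an embedding size of 64 (both are guesses, not from my actual model). The helper names `build_distance_table`, `quantize` and `lookup_pairwise` are hypothetical. The table holds 1 - cosine similarity for every pair of codewords, and advanced indexing reads out all pairs between two sets of quantized vectors without computing any cosine distance per pair.

```python
import torch
import torch.nn.functional as F


def build_distance_table(codebook: torch.Tensor) -> torch.Tensor:
    """Cosine distance (1 - cosine similarity) between every pair of codewords."""
    normed = F.normalize(codebook, dim=1)   # (K, D), unit-length rows
    return 1.0 - normed @ normed.t()        # (K, K)


def quantize(x: torch.Tensor, codebook: torch.Tensor) -> torch.Tensor:
    """Index of the nearest codeword (Euclidean distance, an assumption) per row of x."""
    return torch.cdist(x, codebook).argmin(dim=1)  # (N,)


def lookup_pairwise(idx_a: torch.Tensor, idx_b: torch.Tensor,
                    table: torch.Tensor) -> torch.Tensor:
    """All-pairs distances between two index sets, read from the table."""
    return table[idx_a[:, None], idx_b[None, :]]   # (len(idx_a), len(idx_b))


torch.manual_seed(0)
K, D = 512, 64                          # 512 codewords as in the post; D is a guess
codebook = torch.randn(K, D)
table = build_distance_table(codebook)  # 512 * 512 distances, computed once

a = torch.randn(1000, D)                # random stand-ins for the real vectors
b = torch.randn(200, D)
ia, ib = quantize(a, codebook), quantize(b, codebook)
d = lookup_pairwise(ia, ib, table)      # (1000, 200) distances, pure indexing
```

Each entry of `d` should equal the cosine distance between the two assigned codewords, so the only approximation is the quantization step itself.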
My question is: how should I set up autograd for the lookup step so that the correct gradient of the cosine distance is passed back through the lookup table?
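One thing I am considering, sketched below under the assumption that the codebook is a trainable tensor: if the table is rebuilt from the codebook inside the forward pass (512*512 is cheap), indexing into it is an ordinary differentiable operation, so autograd should route the cosine-distance gradient to the codebook without a custom `autograd.Function`. The check compares gradients from the lookup path against computing the cosine distance directly on the gathered codewords. The sizes and indices are arbitrary toy values.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
K, D = 8, 4                                   # toy sizes for the check
codebook = torch.randn(K, D, requires_grad=True)


def distance_table(cb: torch.Tensor) -> torch.Tensor:
    """1 - cosine similarity between all codeword pairs, kept in the graph."""
    n = F.normalize(cb, dim=1)
    return 1.0 - n @ n.t()


ia = torch.tensor([0, 3, 3, 5])               # codeword indices of set A
ib = torch.tensor([1, 2, 7])                  # codeword indices of set B

# Path 1: build the table inside the graph, then index it.
table = distance_table(codebook)
lookup = table[ia[:, None], ib[None, :]]      # (4, 3)
(grad_lookup,) = torch.autograd.grad(lookup.sum(), codebook)

# Path 2: compute the cosine distance directly on the gathered codewords.
qa, qb = codebook[ia], codebook[ib]
direct = 1.0 - F.cosine_similarity(qa[:, None, :], qb[None, :, :], dim=-1)
(grad_direct,) = torch.autograd.grad(direct.sum(), codebook)

print(torch.allclose(grad_lookup, grad_direct, atol=1e-5))
```

This only covers gradients with respect to the codebook. The `argmin` in the quantization step has no gradient, so nothing reaches the vectors before quantization through this path; one common option from the VQ-VAE literature is a straight-through estimator, but I am not sure it is the right choice here.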

Thank you a lot!