How to perform 2D indexing (gather) in PyTorch?

Use case: I am writing an NLP program that requires 2D indexing at one point.
The setup is:

  • k_next_words has shape (batch_size, d1, d2) = (4,5,6)
  • ix has shape (batch_size, k_num_to_select, ix_dim=2) = (4,7,2),
    where the dim of size 4 is the batch, the dim of size 7 is the number of entries to look up, and the dim of size 2 holds a 2D index into d1 and d2 of k_next_words (so the first entry of each index is between 0 and 4 and the second entry is between 0 and 5).

Problem: For each of the 4 items in the batch, we want to look up 7 entries of its (5x6) slice of k_next_words, one for each 2D index in ix, so the result should have shape (4,7).
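
To pin down the intended result, here is a slow reference version written as explicit loops. The random data, the seed, and the names k and expected are made up for illustration; only the shapes come from the setup above. The output would have shape (4,7).

```python
import torch

torch.manual_seed(0)  # arbitrary seed, only so the example is repeatable
batch_size, d1, d2, k = 4, 5, 6, 7

# Example data with the shapes from the setup (values are random placeholders).
k_next_words = torch.randn(batch_size, d1, d2)

# Random 2D indices: first entry in [0, d1), second entry in [0, d2).
ix = torch.stack(
    [
        torch.randint(0, d1, (batch_size, k)),
        torch.randint(0, d2, (batch_size, k)),
    ],
    dim=-1,
)  # shape (4, 7, 2)

# Naive lookup: one scalar read per (batch item, index) pair.
expected = torch.empty(batch_size, k)
for b in range(batch_size):
    for j in range(k):
        i1, i2 = ix[b, j].tolist()
        expected[b, j] = k_next_words[b, i1, i2]
```

This is only meant as a specification of the desired values; a vectorized version should reproduce expected exactly.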

This is probably a use case for gather, but I haven’t gotten it to work so far!
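
One approach that might fit, as a rough sketch rather than a confirmed solution: flatten the (d1, d2) axes into one axis, convert each 2D index into a flat offset i1 * d2 + i2, and call torch.gather along that axis. The helper name gather_2d and the example data are hypothetical. Plain advanced indexing with a broadcast batch index is shown alongside as a cross-check.

```python
import torch


def gather_2d(values: torch.Tensor, ix: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: values is (B, d1, d2), ix is (B, K, 2) int64.

    Returns a (B, K) tensor with out[b, j] = values[b, ix[b, j, 0], ix[b, j, 1]].
    """
    B, d1, d2 = values.shape
    flat_values = values.reshape(B, d1 * d2)   # merge the two indexed axes
    flat_ix = ix[..., 0] * d2 + ix[..., 1]     # row-major offset, shape (B, K)
    return torch.gather(flat_values, 1, flat_ix)


# Example data with the shapes from the question (values are placeholders).
k_next_words = torch.arange(4 * 5 * 6, dtype=torch.float32).reshape(4, 5, 6)
ix = torch.stack(
    [torch.randint(0, 5, (4, 7)), torch.randint(0, 6, (4, 7))], dim=-1
)

out = gather_2d(k_next_words, ix)  # shape (4, 7)

# Cross-check with advanced indexing: a (4, 1) batch index broadcasts
# against the two (4, 7) index tensors.
batch_ix = torch.arange(4).unsqueeze(1)
alt = k_next_words[batch_ix, ix[..., 0], ix[..., 1]]
```

Both routes should give the same (4,7) tensor. The advanced-indexing form avoids the reshape, while the gather form keeps everything as a single gather call, which was the original idea.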