Use case: I am writing an NLP program that requires 2D indexing at some point.
The setup is:
k_next_words has shape (batch_size, d1, d2) = (4, 5, 6)
ix has shape (batch_size, k_num_to_select, ix_dim=2) = (4, 7, 2),
where 4 is the batch size, 7 is the number of samples to look up, and the last dimension (size 2) holds a 2D index into k_next_words (the first entry is between 0 and 4, and the second entry is between 0 and 5).
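To pin down the intended semantics, here is a minimal sketch assuming PyTorch, with random stand-in data of the shapes above. The explicit loop is only a slow reference for what a vectorized lookup should return: out[b, j] = k_next_words[b, ix[b, j, 0], ix[b, j, 1]].

```python
import torch

torch.manual_seed(0)
batch_size, d1, d2, k_num_to_select = 4, 5, 6, 7

# Random stand-in data with the shapes described in the setup.
k_next_words = torch.randn(batch_size, d1, d2)
# Column 0 indexes dim 1 (values 0..4), column 1 indexes dim 2 (values 0..5).
ix = torch.stack(
    [torch.randint(0, d1, (batch_size, k_num_to_select)),
     torch.randint(0, d2, (batch_size, k_num_to_select))],
    dim=-1,
)

# Reference semantics via an explicit (slow) loop over batch and samples.
out_loop = torch.empty(batch_size, k_num_to_select)
for b in range(batch_size):
    for j in range(k_num_to_select):
        r, c = ix[b, j]
        out_loop[b, j] = k_next_words[b, r, c]

print(out_loop.shape)  # torch.Size([4, 7])
```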
Problem: For each batch element, we want to look up the 7 entries of its 5x6 slice of k_next_words at the 2D indices given by ix, producing an output of shape (4, 7).
This is probably a use case for gather, but I haven't gotten it to work so far!
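One way to make gather work, sketched here assuming PyTorch: flatten each 5x6 slice into a row of 30 values, turn each 2D index (r, c) into the linear index r * d2 + c, and gather along dim 1. The function names gather_2d and index_2d are hypothetical; index_2d shows the equivalent advanced-indexing form, where a (batch_size, 1) batch index broadcasts against the (batch_size, k_num_to_select) row and column indices.

```python
import torch

def gather_2d(k_next_words: torch.Tensor, ix: torch.Tensor) -> torch.Tensor:
    """Return out[b, j] = k_next_words[b, ix[b, j, 0], ix[b, j, 1]]."""
    batch_size, d1, d2 = k_next_words.shape
    # Flatten the (d1, d2) grid so each 2D index becomes one linear index.
    flat = k_next_words.reshape(batch_size, d1 * d2)
    linear_ix = ix[..., 0] * d2 + ix[..., 1]  # (batch_size, k_num_to_select), int64
    # gather along dim 1: out[b, j] = flat[b, linear_ix[b, j]]
    return torch.gather(flat, 1, linear_ix)

def index_2d(k_next_words: torch.Tensor, ix: torch.Tensor) -> torch.Tensor:
    """Same result via advanced indexing instead of gather."""
    batch_idx = torch.arange(k_next_words.shape[0]).unsqueeze(1)  # (batch_size, 1)
    return k_next_words[batch_idx, ix[..., 0], ix[..., 1]]

# Usage with random stand-in data of the shapes from the question.
torch.manual_seed(0)
k_next_words = torch.randn(4, 5, 6)
ix = torch.stack([torch.randint(0, 5, (4, 7)), torch.randint(0, 6, (4, 7))], dim=-1)

out = gather_2d(k_next_words, ix)
print(out.shape)  # torch.Size([4, 7])
```

Both versions avoid a Python loop; the gather version needs ix to be int64 (the default for torch.randint), since torch.gather expects a LongTensor index.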