Use case: I am writing an NLP program that needs 2d indexing at some point.
The setup is:
- `k_next_words` has shape (batch_size, d1, d2) = (4, 5, 6).
- `ix` has shape (batch_size, k_num_to_select, ix_dim=2) = (4, 7, 2), where the 4 dim is the batch size, the 7 dim is the number of samples to look up, and the 2 dim is a 2d index into `d1` and `d2` of `k_next_words` (so the first entry of each index is between 0 and 4 and the second is between 0 and 5).
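To reproduce the setup, here is a minimal sketch with random data of the shapes above. It assumes PyTorch, since the post doesn't name the framework; the random values are purely illustrative.

```python
import torch

batch_size, d1, d2 = 4, 5, 6
k_num_to_select = 7

# Illustrative random data with the shapes described above (PyTorch assumed).
k_next_words = torch.randn(batch_size, d1, d2)

# ix[..., 0] indexes d1 (0..4) and ix[..., 1] indexes d2 (0..5); dtype is int64.
ix = torch.stack(
    [
        torch.randint(0, d1, (batch_size, k_num_to_select)),
        torch.randint(0, d2, (batch_size, k_num_to_select)),
    ],
    dim=-1,
)

print(k_next_words.shape)  # torch.Size([4, 5, 6])
print(ix.shape)            # torch.Size([4, 7, 2])
```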
Problem: For each item in the batch, we want to look up, in its (5x6) slice of `k_next_words`, the 7 entries at the 2d indices given by `ix`. The result should have shape (4, 7), with out[b, j] = k_next_words[b, ix[b, j, 0], ix[b, j, 1]].
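The intended lookup can be pinned down with a slow, explicit loop. This is a sketch assuming PyTorch; `lookup_loop` is a hypothetical helper name, mainly useful as a reference to check a vectorized version against.

```python
import torch

def lookup_loop(k_next_words: torch.Tensor, ix: torch.Tensor) -> torch.Tensor:
    """Slow reference: out[b, j] = k_next_words[b, ix[b, j, 0], ix[b, j, 1]]."""
    batch_size, k_num_to_select, _ = ix.shape
    out = k_next_words.new_empty((batch_size, k_num_to_select))
    for b in range(batch_size):
        for j in range(k_num_to_select):
            r, c = ix[b, j].tolist()  # r indexes d1, c indexes d2
            out[b, j] = k_next_words[b, r, c]
    return out

# Deterministic data so the result is easy to check: k_next_words[b, r, c] = b*30 + r*6 + c.
k_next_words = torch.arange(4 * 5 * 6, dtype=torch.float32).reshape(4, 5, 6)
ix = torch.zeros(4, 7, 2, dtype=torch.long)
ix[0, 0] = torch.tensor([4, 5])  # last row/column of batch item 0
ix[3, 6] = torch.tensor([2, 1])  # batch item 3
out = lookup_loop(k_next_words, ix)
print(out.shape)                           # torch.Size([4, 7])
print(out[0, 0].item(), out[3, 6].item())  # 29.0 103.0
```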
This is probably a use case for `gather`, but I haven’t gotten it to work so far!
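One possible way to make `gather` work is sketched below. It assumes this is PyTorch's `torch.gather` and that `ix` is an integer tensor. The idea is to flatten each (d1, d2) slice into a single axis of length d1 * d2, turn each 2d index (r, c) into the flat index r * d2 + c, and gather along that axis. `lookup_gather` and `lookup_indexing` are hypothetical names. The second shows the same lookup with advanced indexing, which avoids the flattening.

```python
import torch

def lookup_gather(k_next_words: torch.Tensor, ix: torch.Tensor) -> torch.Tensor:
    """out[b, j] = k_next_words[b, ix[b, j, 0], ix[b, j, 1]] using torch.gather."""
    batch_size, d1, d2 = k_next_words.shape
    # Row-major flattening: position (r, c) in a (d1, d2) grid becomes r * d2 + c.
    flat_ix = ix[..., 0] * d2 + ix[..., 1]              # (batch_size, k_num_to_select)
    flat_k = k_next_words.reshape(batch_size, d1 * d2)  # (batch_size, d1 * d2)
    # Gather along dim 1: out[b, j] = flat_k[b, flat_ix[b, j]].
    # torch.gather needs an int64 index with the same number of dims as flat_k.
    return torch.gather(flat_k, 1, flat_ix.long())

def lookup_indexing(k_next_words: torch.Tensor, ix: torch.Tensor) -> torch.Tensor:
    """Same result with advanced indexing instead of gather."""
    batch_size = k_next_words.shape[0]
    # (batch_size, 1) broadcasts against the (batch_size, k_num_to_select) index tensors.
    batch_idx = torch.arange(batch_size, device=ix.device).unsqueeze(1)
    return k_next_words[batch_idx, ix[..., 0], ix[..., 1]]

# Usage with the shapes from the question; here k_next_words[b, r, c] = b*30 + r*6 + c.
k_next_words = torch.arange(4 * 5 * 6, dtype=torch.float32).reshape(4, 5, 6)
ix = torch.stack(
    [torch.randint(0, 5, (4, 7)), torch.randint(0, 6, (4, 7))], dim=-1
)
out = lookup_gather(k_next_words, ix)
print(out.shape)  # torch.Size([4, 7])
```

The flattening step relies on `reshape` laying out each (d1, d2) slice in row-major order, which is exactly what the r * d2 + c formula assumes.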