Batched sequence indexing


I have a tensor which represents a batch of same-length sequences. Each item in the sequence is a vector of probabilities across the item space.
The tensor is of dimension (batch_size, seq_length, num_items).
I want to index it with an index tensor of dimension (batch_size, variable_length): for each batch item, a variable number of indices into the item space, applied in the same way to every element of the sequence at that batch index.
The goal is to set each of these indexed elements to the same scalar x.

I tried a simple my_tensor[index] = x, but this indexes my_tensor along the batch dimension. I am not sure how to achieve a 'repeated' indexing across the sequence length.


Could you post a minimal (slow) code snippet that shows the desired result using loops?

This is the desired behaviour, written with plain Python loops.

# my_batch is of dimension (batch_size, seq_length, num_items)
# indices is of dimension (batch_size, variable_length)
# x is some value

for i in range(batch_size):
  for j in range(seq_length):
    # use only the indices that belong to batch item i
    for index in indices[i]:
      my_batch[i, j, index] = x

I solved it using scatter_:

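# give indices the same number of dimensions as my_batch:
# (batch_size, variable_length) -> (batch_size, seq_length, variable_length)
indices = indices.unsqueeze(1).repeat(1, seq_length, 1)
# x is a scalar, so pass it as value (src expects a tensor); scatter_ works in place
my_batch.scatter_(dim=2, index=indices, value=x)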
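
For anyone who wants to check this, here is a minimal self-contained sketch that compares the scatter_ approach against the loop version on a small random batch. The sizes, the `ragged` index lists and the padding step are my own assumptions, not part of the original problem: since each batch item may have a different number of indices, the lists are padded to a rectangular tensor by repeating each row's first index, which is harmless here because the same scalar is written every time. It assumes every list is non-empty, and x is passed through the value argument because it is a scalar.

```python
import torch

# Illustrative sizes and scalar (assumptions, not from the original post).
batch_size, seq_length, num_items = 2, 3, 5
x = 0.0

# Hypothetical ragged index lists, one per batch item (each must be non-empty).
ragged = [[0, 3], [1, 2, 4]]

# Pad every row to the longest length by repeating its first index.
# Writing x to the same position twice does not change the result.
max_len = max(len(row) for row in ragged)
indices = torch.tensor(
    [row + [row[0]] * (max_len - len(row)) for row in ragged],
    dtype=torch.long,  # scatter_ requires int64 indices
)  # shape: (batch_size, max_len)

my_batch = torch.rand(batch_size, seq_length, num_items)

# Reference result computed with plain loops.
expected = my_batch.clone()
for i in range(batch_size):
    for j in range(seq_length):
        for index in ragged[i]:
            expected[i, j, index] = x

# Vectorised version: repeat the indices along the sequence dimension,
# then scatter the scalar along the item dimension (dim=2).
expanded = indices.unsqueeze(1).repeat(1, seq_length, 1)
result = my_batch.clone().scatter_(dim=2, index=expanded, value=x)

# True if both approaches produce the same tensor.
matches = torch.equal(result, expected)
```

If memory matters, `expand(-1, seq_length, -1)` could probably replace `repeat` here, since the index tensor is only read, but I have not verified that across PyTorch versions.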