I currently use the code below to compute top-k over the whole batch at once, rather than running a separate topk for each prediction.

Is it possible to do this faster than the code below? predictions has shape [batch_size, sentence_length, 30522], and previously I ran topk on predictions[batch_index, mask_index] separately for each batch element.

```
import torch

# Gather the logits at each example's masked position into one [batch_size, 30522] tensor.
the_big_tensor = torch.zeros([len(input_data), 30522], dtype=torch.float16, device='cuda')
for in_batch_index in range(len(input_data)):
    masked_index = masked_indexes[in_batch_index]
    the_big_tensor[in_batch_index] = predictions[in_batch_index, masked_index]
# A single topk call over the whole batch.
top_prediction2 = torch.topk(the_big_tensor, 30522)
```

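I suppose it is possible to transform the tensor predictions, of shape [batch_size, sentence_length, 30522], into the_big_tensor, of shape [batch_size, 30522], without the Python loop, given that I have masked_indexes (one masked position per batch element, as in the loop above), using some tensor operation such as a multiplication.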
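
To make the idea concrete, here is a minimal sketch of the kind of loop-free transformation I have in mind. It uses PyTorch advanced indexing rather than a multiplication. It assumes masked_indexes can be turned into a 1-D integer tensor with exactly one entry per batch element. The small shapes and the name batch_range are made up for illustration, and it runs on CPU so it can be tried anywhere.

```python
import torch

# Hypothetical small shapes for illustration; the real vocabulary size is 30522.
batch_size, sentence_length, vocab_size = 4, 7, 50
predictions = torch.randn(batch_size, sentence_length, vocab_size)
# Assumption: exactly one masked position per batch element.
masked_indexes = torch.tensor([1, 3, 0, 6])

# Advanced indexing pairs row i with masked_indexes[i],
# producing a tensor of shape [batch_size, vocab_size] without a Python loop.
batch_range = torch.arange(batch_size, device=predictions.device)
the_big_tensor = predictions[batch_range, masked_indexes]

# Reference result from the explicit loop, to check that both approaches agree.
looped = torch.stack([predictions[i, masked_indexes[i]] for i in range(batch_size)])
assert torch.equal(the_big_tensor, looped)

# topk over the last dimension processes every row of the batch in one call.
top_prediction2 = torch.topk(the_big_tensor, k=vocab_size, dim=-1)
```

If masked_indexes is a Python list, wrapping it in torch.tensor(...) on the same device as predictions should be enough for this to work. Whether it is actually faster on the GPU is something I have not measured.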