I have a batched tensor of matrices:

```
import torch

batch_size = 64
num_elements = 500
channels = 2
a = torch.rand(batch_size, num_elements, channels) # [64, 500, 2]
```

And a tensor of indices that describes how to index each of those matrices:

```
indices = 80
b = (torch.rand(batch_size, indices) * num_elements).long() # [64, 80]
```

I want to index each matrix in the batch with the corresponding row of indices in `b`, i.e. `result[i, j] = a[i, b[i, j]]`. The resulting tensor should have shape `[64, 80, 2]`.

How can I do that?
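A minimal sketch of two standard ways to express this, reusing the setup above and assuming `b` holds int64 indices in `[0, num_elements)`. The first uses advanced indexing with a broadcast batch index. The second uses `torch.gather` along dimension 1. The names `batch_idx`, `idx`, `out_adv` and `out_gather` are illustrative only.

```python
import torch

batch_size = 64
num_elements = 500
channels = 2
indices = 80

a = torch.rand(batch_size, num_elements, channels)               # [64, 500, 2]
b = (torch.rand(batch_size, indices) * num_elements).long()      # [64, 80], values in [0, 500)

# Option 1: advanced indexing.
# batch_idx has shape [64, 1] and broadcasts against b ([64, 80]),
# so out_adv[i, j] = a[i, b[i, j]]; the channel dim is carried along.
batch_idx = torch.arange(batch_size).unsqueeze(1)                # [64, 1]
out_adv = a[batch_idx, b]                                        # [64, 80, 2]

# Option 2: torch.gather along dim 1.
# gather needs an index with the same number of dims as `a`, so b is
# expanded over the channel dim: out_gather[i, j, k] = a[i, b[i, j], k].
idx = b.unsqueeze(-1).expand(-1, -1, channels)                   # [64, 80, 2]
out_gather = torch.gather(a, 1, idx)                             # [64, 80, 2]

print(out_adv.shape)                        # torch.Size([64, 80, 2])
print(torch.equal(out_adv, out_gather))     # True
```

Both produce the same values. Advanced indexing reads more directly, while `gather` makes the indexed dimension explicit at the cost of expanding `b` to match the shape of `a`.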