Slicing 3d tensor with multiple 2d tensors

Hello,

I am currently implementing a Q-learning algorithm for a routing problem, and I have run into a problem I don't know how to solve. I have a 3D tensor that holds multiple 2D distance matrices (so `batchsize x coord1 x coord2`). I also have a second tensor “B” that holds batches of sequences of indices. Each row in each batch of “B” consists of index tuples, and each tuple refers to one cell in the distance matrix of the same batch, so its shape is `batchsize x rows x ntuples x 2`. For each row in “B” (which consists of `ntuples` tuples), I now want to select the corresponding entries from the distance matrix of the same batch. I don't think `torch.gather` can help me here.
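As a side note, `torch.gather` can be made to work here if each distance matrix is flattened first: a pair `(i, j)` in a `coord1 x coord2` matrix corresponds to the row-major offset `i * coord2 + j`. The sketch below assumes `idx` is an int64 tensor of shape `batchsize x rows x ntuples x 2`; the helper name `select_with_gather` is hypothetical.

```python
import torch

def select_with_gather(dist, idx):
    # Hypothetical helper. dist: [batch, coord1, coord2], idx: [batch, rows, ntuples, 2] (int64)
    batch, coord1, coord2 = dist.shape
    _, rows, ntuples, _ = idx.shape
    # Turn each (i, j) pair into a row-major offset into the flattened matrix
    flat_idx = idx[..., 0] * coord2 + idx[..., 1]          # [batch, rows, ntuples]
    flat_dist = dist.reshape(batch, coord1 * coord2)       # [batch, coord1 * coord2]
    # gather needs an index with the same number of dims as its input,
    # so the (rows, ntuples) dims are flattened too and restored afterwards
    out = torch.gather(flat_dist, 1, flat_idx.reshape(batch, rows * ntuples))
    return out.reshape(batch, rows, ntuples)

# dist[b, i, j] == 16 * b + 4 * i + j, which makes the result easy to check by hand
dist = torch.arange(2 * 4 * 4, dtype=torch.float32).reshape(2, 4, 4)
idx = torch.tensor([[[[0, 1], [2, 3]]],
                    [[[3, 0], [1, 1]]]])  # [batch=2, rows=1, ntuples=2, 2]
out = select_with_gather(dist, idx)
print(out.tolist())  # [[[1.0, 11.0]], [[28.0, 21.0]]]
```

Because the batch dimension is kept as dim 0 of both `flat_dist` and the index, each batch entry only ever reads from its own distance matrix.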
A small example of batchsize 2 could look like this:

Basically, each list of tuples should result in a flat list of values from the corresponding distance matrix, so the output shape is `batchsize x rows x ntuples`, since each index tuple resolves to one entry in the distance matrix. It is important that each batch of “B” accesses the correct distance matrix, i.e. the one at the same batch index. I hope it is clear what I want to achieve.
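To pin down the intended semantics, here is a minimal, slow reference version written with plain Python loops; it computes `out[b, r, t] = dist[b, idx[b, r, t, 0], idx[b, r, t, 1]]`. The name `select_loop` is hypothetical, and this is only meant for checking a vectorized solution on small inputs.

```python
import torch

def select_loop(dist, idx):
    # Hypothetical reference: out[b, r, t] = dist[b, idx[b, r, t, 0], idx[b, r, t, 1]]
    batch, rows, ntuples, _ = idx.shape
    out = dist.new_empty((batch, rows, ntuples))
    for b in range(batch):
        for r in range(rows):
            for t in range(ntuples):
                i, j = idx[b, r, t].tolist()
                out[b, r, t] = dist[b, i, j]  # batch b reads only its own matrix
    return out

dist = torch.arange(2 * 4 * 4, dtype=torch.float32).reshape(2, 4, 4)
idx = torch.tensor([[[[0, 1], [2, 3]]],
                    [[[3, 0], [1, 1]]]])  # [batch=2, rows=1, ntuples=2, 2]
ref = select_loop(dist, idx)
print(ref.tolist())  # [[[1.0, 11.0]], [[28.0, 21.0]]]
```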

Any idea or help on how to solve this is well appreciated!

This code might work:

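```python
import torch

# idx: [batch=2, rows=3, ntuples=4, 2]; each [i, j] pair indexes one cell
# of the distance matrix belonging to the same batch entry
idx = torch.tensor([[[[0, 1], [1, 2], [2, 3], [3, 0]],
                     [[0, 1], [1, 2], [2, 3], [3, 0]],
                     [[1, 0], [0, 2], [2, 3], [3, 1]]],
                    [[[0, 0], [1, 1], [2, 2], [3, 3]],
                     [[0, 0], [1, 1], [2, 2], [3, 3]],
                     [[3, 3], [2, 2], [1, 1], [0, 0]]]])

x = torch.randn(2, 4, 4)  # distance matrices: [batch, coord1, coord2]
# Repeat each distance matrix once per row of idx: [batch, rows, coord1, coord2]
x = x.unsqueeze(1).expand(-1, idx.size(1), -1, -1)
# The four index tensors broadcast to [batch, rows, ntuples, 1]
res = x[torch.arange(x.size(0))[:, None, None, None],
        torch.arange(x.size(1))[:, None, None],
        idx[:, :, :, 0:1], idx[:, :, :, 1:2]]
res = res.squeeze(-1)  # drop the trailing singleton dim -> [batch, rows, ntuples]
```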
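A possibly simpler variant (a sketch, not taken from the thread) skips the `expand` entirely: a batch index of shape `[batch, 1, 1]` broadcasts against `idx[..., 0]` and `idx[..., 1]`, which are both `[batch, rows, ntuples]`, so advanced indexing returns the requested `batchsize x rows x ntuples` shape directly. The helper name `select_indexing` is hypothetical; the snippet checks it against the expand-based indexing above.

```python
import torch

def select_indexing(dist, idx):
    # Hypothetical helper. dist: [batch, coord1, coord2], idx: [batch, rows, ntuples, 2]
    # [batch, 1, 1] broadcasts with the two [batch, rows, ntuples] index tensors
    batch_idx = torch.arange(dist.size(0), device=dist.device)[:, None, None]
    return dist[batch_idx, idx[..., 0], idx[..., 1]]  # [batch, rows, ntuples]

x = torch.randn(2, 4, 4)
idx = torch.randint(0, 4, (2, 3, 4, 2))
res = select_indexing(x, idx)

# Expand-based indexing as in the answer, with the trailing singleton dim removed
xe = x.unsqueeze(1).expand(-1, idx.size(1), -1, -1)
res_expand = xe[torch.arange(xe.size(0))[:, None, None, None],
                torch.arange(xe.size(1))[:, None, None],
                idx[:, :, :, 0:1], idx[:, :, :, 1:2]].squeeze(-1)

print(res.shape)                     # torch.Size([2, 3, 4])
print(torch.equal(res, res_expand))  # True
```

Avoiding `expand` does not copy memory in either case (`expand` returns a view), but the indexing expression is shorter and the output needs no `squeeze`.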

Could you verify it, as I was too lazy to type in your distance matrices?

Thank you for your fast response!
The code seems to do exactly what I want. Is this easy to apply to different batch sizes?

Yes, it should work directly with different batch sizes, since I tried to avoid hard-coding the batch size (or any other dimension).
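One way to check this claim is to wrap the indexing from the answer in a function and run it on a few arbitrary shapes. The wrapper name `select_expand` and the final `squeeze(-1)` are additions for this sketch; the shapes are made up and assume square distance matrices.

```python
import torch

def select_expand(x, idx):
    # Hypothetical wrapper around the answer's indexing; no size is hard-coded,
    # only the last dim of idx must be 2 (row index, column index)
    x = x.unsqueeze(1).expand(-1, idx.size(1), -1, -1)
    res = x[torch.arange(x.size(0))[:, None, None, None],
            torch.arange(x.size(1))[:, None, None],
            idx[:, :, :, 0:1], idx[:, :, :, 1:2]]
    return res.squeeze(-1)  # [batch, rows, ntuples]

shapes = []
for batch, rows, ntuples, n in [(1, 3, 4, 4), (5, 2, 6, 7), (8, 10, 3, 5)]:
    x = torch.randn(batch, n, n)
    idx = torch.randint(0, n, (batch, rows, ntuples, 2))
    res = select_expand(x, idx)
    # Spot-check one entry against direct indexing
    b, r, t = batch - 1, rows - 1, ntuples - 1
    i, j = idx[b, r, t].tolist()
    assert torch.equal(res[b, r, t], x[b, i, j])
    shapes.append(tuple(res.shape))

print(shapes)  # [(1, 3, 4), (5, 2, 6), (8, 10, 3)]
```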

Of course, `idx` itself is currently hard-coded, as is its last dimension (size 2), but as I understand it, that's exactly your use case.

Let me know if you encounter any issues using this code with other shapes.