How to select indices over two dimensions?

Given a = torch.randn(3, 2, 4, 5), how can I select the sub-tensors (2, :, 0, :), (1, :, 1, :), (2, :, 2, :), (0, :, 3, :) into a single resulting tensor of size (2, 4, 5) or (4, 2, 5)?

Indexing a single pair works: a[2, :, 0, :] gives

 0.5580 -0.0337  1.0048 -0.5044  0.6784
-1.6117  1.0084  1.1886  0.1278  0.3739
[torch.FloatTensor of size 2x5]

However, a[[2, 1, 2, 0], :, [0, 1, 2, 3], :] raises

TypeError: Performing basic indexing on a tensor and encountered an error indexing dim 0 with an object of type list. The only supported types are integers, slices, numpy scalars, or if indexing with a torch.LongTensor or torch.ByteTensor only a single Tensor may be passed.

whereas NumPy returns a (4, 2, 5) array successfully.
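For reference, a minimal sketch of the NumPy behavior described above, next to the same expression in PyTorch. Note, beyond what this thread says, that later PyTorch releases added NumPy-style advanced indexing, so the expression that raised the TypeError here works there; the sketch assumes such a release. When the two index lists are separated by a slice, both libraries put the paired index dimension first, which is why the result is (4, 2, 5).

```python
import numpy as np
import torch

a = torch.randn(3, 2, 4, 5)
xs = [2, 1, 2, 0]  # indices into dim 0
ys = [0, 1, 2, 3]  # indices into dim 2, paired element-wise with xs

# NumPy: advanced indices separated by a slice -> the paired dimension goes first
np_out = a.numpy()[xs, :, ys, :]
print(np_out.shape)  # (4, 2, 5)

# PyTorch (releases with advanced indexing): same semantics
t_out = a[xs, :, ys, :]
print(t_out.shape)  # torch.Size([4, 2, 5])

# Slice i of the result is a[xs[i], :, ys[i], :] in both libraries
assert np.allclose(np_out, t_out.numpy())
```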


Any help on this?

I don’t think there’s a native way to do this, but you can do it with a list comprehension:

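import torch
a = torch.randn(3, 2, 4, 5)
xs = [2, 1, 2, 0]  # indices into dim 0
ys = [0, 1, 2, 3]  # indices into dim 2, paired with xs
# each a[x, :, y, :] is 2x5; unsqueeze(0) makes it 1x2x5, and cat joins them into 4x2x5
torch.cat([a[x, :, y, :].unsqueeze(0) for x, y in zip(xs, ys)])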
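A self-contained version of the list-comprehension answer, alongside torch.stack (an alternative not mentioned in the thread), which inserts the new dimension itself. Its dim argument also gives the (2, 4, 5) layout the question mentions. The variable names are illustrative only.

```python
import torch

a = torch.randn(3, 2, 4, 5)
xs = [2, 1, 2, 0]  # indices into dim 0
ys = [0, 1, 2, 3]  # indices into dim 2, paired with xs

# The answer above: make each 2x5 slice 1x2x5, then concatenate along dim 0
via_cat = torch.cat([a[x, :, y, :].unsqueeze(0) for x, y in zip(xs, ys)])

# torch.stack adds the new dimension itself; dim chooses where it goes
via_stack = torch.stack([a[x, :, y, :] for x, y in zip(xs, ys)])               # (4, 2, 5)
via_stack_dim1 = torch.stack([a[x, :, y, :] for x, y in zip(xs, ys)], dim=1)   # (2, 4, 5)

assert torch.equal(via_cat, via_stack)
print(via_stack.shape, via_stack_dim1.shape)  # torch.Size([4, 2, 5]) torch.Size([2, 4, 5])
```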

Thank you. That’s a bit frustrating, though.

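You can easily select twice, though. Note that this keeps every combination of the two index lists, so the result has shape (4, 2, 4, 5), with the requested slices on its diagonal over dims 0 and 2: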

torch.randn(3,2,4,5)[[2,1,2,0],:,:,:][:,:,[0,1,2,3],:]
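Indexing twice applies each list independently, so it keeps every (x, y) combination rather than only the pairs. A small sketch, assuming a PyTorch version with list-based advanced indexing, that indexes twice and then pulls the paired slices off the diagonal of dims 0 and 2 (torch.diagonal is a step not in the thread):

```python
import torch

a = torch.randn(3, 2, 4, 5)
xs = [2, 1, 2, 0]
ys = [0, 1, 2, 3]

# Two independent selections: both[i, :, k, :] == a[xs[i], :, ys[k], :]
both = a[xs, :, :, :][:, :, ys, :]
print(both.shape)  # torch.Size([4, 2, 4, 5])

# The wanted pairs sit at both[i, :, i, :]. torch.diagonal moves that diagonal
# to the last dimension (shape (2, 5, 4)); permute brings it back to the front.
pairs = torch.diagonal(both, dim1=0, dim2=2).permute(2, 0, 1)
print(pairs.shape)  # torch.Size([4, 2, 5])

expected = torch.stack([a[x, :, y, :] for x, y in zip(xs, ys)])
assert torch.equal(pairs, expected)
```

This still materializes the full (4, 2, 4, 5) intermediate, so it only makes sense when the index lists are short.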

Does selecting twice have a performance cost?

Update: it seems to be 4 times slower than torch.cat([a[x, :, y, :].unsqueeze(0) for x, y in zip(xs, ys)]).


Right, since it copies twice; the cat version copies only once.
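A rough way to check the timing claim yourself with the standard-library timeit module. The function names are hypothetical, and the numbers depend heavily on PyTorch version, tensor size, and hardware (for a tensor this small, the Python overhead of the list comprehension can dominate), so the sketch prints timings rather than asserting any particular ratio.

```python
import timeit
import torch

a = torch.randn(3, 2, 4, 5)
xs = [2, 1, 2, 0]
ys = [0, 1, 2, 3]

def select_twice():
    # two advanced-indexing steps, each allocating a new (4, 2, 4, 5) tensor
    return a[xs][:, :, ys]

def select_with_cat():
    # integer indexing and unsqueeze return views; only torch.cat allocates
    return torch.cat([a[x, :, y, :].unsqueeze(0) for x, y in zip(xs, ys)])

for fn in (select_twice, select_with_cat):
    seconds = timeit.timeit(fn, number=1000)
    print(f"{fn.__name__}: {seconds:.4f} s for 1000 calls")
```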

If these indices are fixed, you may want to look into masked_select, which returns a 1-D tensor that you then reshape: http://pytorch.org/docs/master/torch.html#torch.masked_select
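A sketch of the masked_select route, assuming a PyTorch version with boolean masks and broadcasting in masked_select (older releases used ByteTensor masks); the names mask and a_perm are hypothetical. masked_select returns the selected elements in row-major order as a 1-D tensor, so the sketch first moves dims 0 and 2 to the front, making each selected (x, y) pair contribute one contiguous 2x5 block.

```python
import torch

a = torch.randn(3, 2, 4, 5)
xs = [2, 1, 2, 0]
ys = [0, 1, 2, 3]

# Mark the wanted (dim 0, dim 2) pairs in a 3x4 boolean mask; build once if indices are fixed
mask = torch.zeros(3, 4, dtype=torch.bool)
mask[xs, ys] = True

# Move dims 0 and 2 to the front: a_perm[x, y] is the 2x5 slice a[x, :, y, :]
a_perm = a.permute(0, 2, 1, 3)                                # shape (3, 4, 2, 5)
flat = torch.masked_select(a_perm, mask[:, :, None, None])   # 1-D, 4 * 2 * 5 elements
out = flat.view(-1, 2, 5)                                     # shape (4, 2, 5)
print(out.shape)  # torch.Size([4, 2, 5])
```

Two caveats: the blocks come out sorted by (x, y), here (0, 3), (1, 1), (2, 0), (2, 2), not in the order of xs and ys, and a mask cannot select the same pair twice. It therefore matches the cat result only up to ordering, and only when the pairs are distinct.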
