I have a 5-dimensional tensor `S` of shape `(bs, a, b, c, d)`, which can be seen as a batch of (a x b) matrices whose entries are themselves (c x d) matrices. For simplicity, let's say `bs = 1` and `a = b = c = d = 3`.
I also have a vector `ind` of shape `(b,)` (say `ind = torch.tensor([2, 0, 1])`) that contains column indices into the inner (c x d) matrices. For each row i and column j of the outer (a x b) matrix, I need to gather column `ind[j]` of the (c x d) matrix at position (i, j), producing an (a x b x c) tensor (or (bs x a x b x c) with the batch dimension).
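A slow but unambiguous statement of the target helps when checking faster versions. This is a minimal sketch, assuming `ind` is a 1-D integer tensor whose values are below `d`. The wanted output is R[n, i, j, k] = S[n, i, j, k, ind[j]], with shape (bs, a, b, c). The helper name `gather_columns_loop` is hypothetical.

```python
import torch

def gather_columns_loop(S: torch.Tensor, ind: torch.Tensor) -> torch.Tensor:
    """Hypothetical reference helper: R[n, i, j, k] = S[n, i, j, k, ind[j]]."""
    bs, a, b, c, d = S.shape
    R = S.new_empty((bs, a, b, c))
    for n in range(bs):
        for i in range(a):
            for j in range(b):
                # Column ind[j] of the inner (c x d) matrix at outer position (i, j).
                R[n, i, j] = S[n, i, j, :, int(ind[j])]
    return R

S = torch.arange(1 * 3 * 3 * 3 * 3).reshape(1, 3, 3, 3, 3)
ind = torch.tensor([2, 0, 1])
R = gather_columns_loop(S, ind)
print(R.shape)   # torch.Size([1, 3, 3, 3])
```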
Example with `S`:

```
tensor([[[[[   0,    0,    0],
           [   0,    0,    0],
           [   0,    0,    0]],

          [[1211, 1212, 1213],
           [1221, 1222, 1223],
           [1231, 1232, 1233]],

          [[1311, 1312, 1313],
           [1321, 1322, 1323],
           [1331, 1332, 1333]]],


         [[[1211, 1221, 1231],
           [1212, 1222, 1232],
           [1213, 1223, 1233]],

          [[   0,    0,    0],
           [   0,    0,    0],
           [   0,    0,    0]],

          [[2311, 2312, 2313],
           [2321, 2322, 2323],
           [2331, 2332, 2333]]],


         [[[1311, 1321, 1331],
           [1312, 1322, 1332],
           [1313, 1323, 1333]],

          [[2311, 2321, 2331],
           [2312, 2322, 2332],
           [2313, 2323, 2333]],

          [[   0,    0,    0],
           [   0,    0,    0],
           [   0,    0,    0]]]]])
```
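To experiment with this, the example can be rebuilt in code. The sketch below assumes the pattern visible in the printout. For i < j, the entry at inner position (r, k) of block (i, j) is 1000(i+1) + 100(j+1) + 10(r+1) + (k+1). Block (j, i) is the transpose of block (i, j), and the diagonal blocks are zero. `make_example_S` is a hypothetical name.

```python
import torch

def make_example_S() -> torch.Tensor:
    """Hypothetical helper: rebuild the (1, 3, 3, 3, 3) example from its visible pattern."""
    n = 3
    S = torch.zeros(1, n, n, n, n, dtype=torch.long)   # diagonal blocks stay zero
    r = torch.arange(n).view(n, 1)   # inner row index, broadcast down columns
    k = torch.arange(n).view(1, n)   # inner column index, broadcast across rows
    for i in range(n):
        for j in range(i + 1, n):
            block = 1000 * (i + 1) + 100 * (j + 1) + 10 * (r + 1) + (k + 1)
            S[0, i, j] = block       # e.g. block (0, 1) holds 1211..1233
            S[0, j, i] = block.T     # block (j, i) is the transpose of block (i, j)
    return S

S = make_example_S()
print(S[0, 0, 1])
# tensor([[1211, 1212, 1213],
#         [1221, 1222, 1223],
#         [1231, 1232, 1233]])
```

The transpose relation between blocks (i, j) and (j, i) turns out to matter for the indexing expression below.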
I found that

```
S[:, torch.arange(3)[None, :], torch.arange(3)[:, None], ind[None, :]]
```

does the job for this specific case and returns:
```
tensor([[[[   0,    0,    0],
          [1211, 1221, 1231],
          [1312, 1322, 1332]],

         [[1231, 1232, 1233],
          [   0,    0,    0],
          [2312, 2322, 2332]],

         [[1331, 1332, 1333],
          [2311, 2312, 2313],
          [   0,    0,    0]]]])
```
However, I don't understand why it works or how broadcasting applies here. Each of my index tensors has shape (1,3) or (3,1), so together they should broadcast to a (3,3) shape. Yet the result has shape (1,3,3,3), i.e. (3,3,3) for the single batch element, which is what I need. Why is that?
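The shape follows from PyTorch's advanced-indexing rules, which mirror NumPy's. All index tensors are broadcast together (here to (3, 3)), and that broadcast shape replaces the indexed dimensions. In the expression above, three dims (a, b and c) are indexed and replaced by two. The leading `:` (bs) and the unindexed trailing dim d are kept, giving (bs, 3, 3, d) = (1, 3, 3, 3). Because the indexed dims are adjacent, the broadcast dims stay in place. If a slice separates two index tensors, the broadcast dims move to the front instead. This small sketch uses distinct sizes so that each axis is recognizable:

```python
import torch

bs, a, b, c, d = 2, 3, 4, 5, 6          # distinct sizes so each axis is recognizable
S = torch.zeros(bs, a, b, c, d)

rows = torch.arange(a)[:, None]          # (3, 1): indexes dim 1 (a)
cols = torch.arange(b)[None, :]          # (1, 4): indexes dim 2 (b)
inner = torch.tensor([[4, 0, 1, 2]])     # (1, 4): indexes dim 3 (c)

# The three index tensors broadcast to (3, 4). They index adjacent dims, so that
# (3, 4) block replaces dims 1..3 in place; dim 0 (the ':') and the unindexed
# trailing dim d are kept as full slices.
print(S[:, rows, cols, inner].shape)     # torch.Size([2, 3, 4, 6])

# With a ':' between two index tensors, the broadcast dims move to the front.
print(S[:, :, torch.arange(b), :, torch.tensor([5, 0, 2, 1])].shape)  # torch.Size([4, 2, 3, 5])
```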
What would be a general solution for arbitrary `bs, a, b, c, d`?
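Here is a minimal sketch of a general solution. It assumes `ind` is a 1-D `int64` tensor of length b, on the same device as `S`, with values below d. `torch.gather` along the last dim does this directly once `ind` is reshaped and expanded to `S`'s number of dims. The same result can also come from advanced indexing after swapping the two inner axes, so that dims 2 and 3 are indexed by adjacent index tensors of equal shape. Both function names are hypothetical.

```python
import torch

def gather_columns(S: torch.Tensor, ind: torch.Tensor) -> torch.Tensor:
    """R[n, i, j, k] = S[n, i, j, k, ind[j]]; S: (bs, a, b, c, d), ind: (b,) int64."""
    bs, a, b, c, d = S.shape
    # Reshape ind to 5 dims and expand it (a view, no copy) so that every
    # position (n, i, j, k) selects column ind[j] along the last dim.
    idx = ind.reshape(1, 1, b, 1, 1).expand(bs, a, b, c, 1)
    return torch.gather(S, 4, idx).squeeze(4)   # (bs, a, b, c, 1) -> (bs, a, b, c)

def gather_columns_indexing(S: torch.Tensor, ind: torch.Tensor) -> torch.Tensor:
    """Same result via advanced indexing: after swapping the inner axes, dims 2 (b)
    and 3 (d) are indexed by adjacent (b,)-shaped tensors, giving (bs, a, b, c)."""
    b = S.shape[2]
    return S.transpose(3, 4)[:, :, torch.arange(b, device=S.device), ind]

bs, a, b, c, d = 2, 3, 4, 5, 6
S = torch.randn(bs, a, b, c, d)
ind = torch.tensor([5, 0, 2, 1])        # one column index (< d) per outer column j

R = gather_columns(S, ind)
print(R.shape)                          # torch.Size([2, 3, 4, 5])
assert torch.equal(R, gather_columns_indexing(S, ind))
assert all(torch.equal(R[n, i, j], S[n, i, j, :, int(ind[j])])
           for n in range(bs) for i in range(a) for j in range(b))
```

On the example `S`, both functions return the same (1, 3, 3, 3) tensor as the original expression, because that `S` has the block-transpose symmetry. On a general `S`, the original expression does not give the wanted result.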
T.