Ah OK, so idx would be used in dim0:
import torch

M, D, K = 5, 4, 2
x = torch.randn(M, D)
print(x)
> tensor([[ 0.3545, -1.2674, -0.1321, -0.4732],
[-0.9573, 0.8507, -0.1691, 0.6376],
[ 0.1386, -1.4488, 0.5860, -0.0135],
[-0.9486, -0.2253, -1.0952, -0.0432],
[-0.0213, 0.2111, 0.1952, -0.8802]])
idx = torch.randint(0, M, (M, K))
print(idx)
> tensor([[0, 2],
[4, 1],
[1, 2],
[1, 3],
[3, 0]])
tmp = x[idx]
print(tmp)
> tensor([[[ 0.3545, -1.2674, -0.1321, -0.4732], # corresponds to first row in idx: [0, 2]
[ 0.1386, -1.4488, 0.5860, -0.0135]],
[[-0.0213, 0.2111, 0.1952, -0.8802],
[-0.9573, 0.8507, -0.1691, 0.6376]],
[[-0.9573, 0.8507, -0.1691, 0.6376],
[ 0.1386, -1.4488, 0.5860, -0.0135]],
[[-0.9573, 0.8507, -0.1691, 0.6376],
[-0.9486, -0.2253, -1.0952, -0.0432]],
[[-0.9486, -0.2253, -1.0952, -0.0432],
[ 0.3545, -1.2674, -0.1321, -0.4732]]])
print(tmp.mean(dim=[1, 2]))
> tensor([-0.2820, -0.0167, -0.0470, -0.2438, -0.4788])
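In this example the first row of idx is [0, 2], so tmp[0] is [x[0], x[2]]. Each entry of idx selects a whole row of x, which gives tmp the shape (M, K, D). The last line of code then takes the mean over dim1 and dim2, leaving one value per row of idx.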
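If it helps, here is a small sketch that checks this reading of x[idx] against an explicit loop. The seed and the names manual and out are my own additions for illustration, not part of the snippet above, so the printed values will differ from the ones shown earlier.

```python
import torch

torch.manual_seed(0)  # arbitrary seed (an assumption), only to make the run repeatable
M, D, K = 5, 4, 2
x = torch.randn(M, D)
idx = torch.randint(0, M, (M, K))

# idx indexes dim0 of x, so every entry of idx is replaced by a row of x.
tmp = x[idx]

# Build the same tensor by hand: for each row of idx, stack the selected rows of x.
manual = torch.stack(
    [torch.stack([x[j] for j in row.tolist()]) for row in idx]
)

# Averaging over dim1 (the K picks) and dim2 (the D features)
# leaves one value per row of idx.
out = tmp.mean(dim=[1, 2])

print(tmp.shape)                 # torch.Size([5, 2, 4])
print(torch.equal(tmp, manual))  # True
print(out.shape)                 # torch.Size([5])
```

The loop version is only there to make the indexing explicit; x[idx] does the same thing in a single call.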