Mean of a 2D tensor using an index tensor

Hello,
Let’s assume I have a tensor of shape [M, D], where M is the number of samples and D is the feature dimension. Additionally, I have a mask tensor of shape [M, K] that contains K different indices for each of the M samples.

My question is: how can I calculate, for each of the M samples, the mean over its K indices (the mask specifies which K indices to use for each sample)?

Thanks!

This should probably work, if I understand the use case correctly:

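import torch

M, D, K = 2, 4, 3
x = torch.randn(M, D)
idx = torch.randint(0, D, (M, K))  # K feature indices per sample
# Pair each row i with its K column indices, then average over K -> shape [M]
x[torch.arange(x.size(0)).unsqueeze(1), idx].mean(1)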
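For this first interpretation (indexing into the feature dimension), the same per-row selection can also be written with `torch.gather` along dim 1. A minimal sketch reusing the shapes above; `via_indexing` and `via_gather` are just illustrative names:

```python
import torch

# Same shapes as the example above; values are random.
M, D, K = 2, 4, 3
x = torch.randn(M, D)
idx = torch.randint(0, D, (M, K))  # K feature (column) indices per sample

# Advanced indexing: rows has shape [M, 1] and broadcasts against idx [M, K].
rows = torch.arange(x.size(0)).unsqueeze(1)
via_indexing = x[rows, idx].mean(1)  # shape [M]

# torch.gather along dim 1 picks x[i, idx[i, k]] for every (i, k).
via_gather = torch.gather(x, 1, idx).mean(1)  # shape [M]

print(torch.allclose(via_indexing, via_gather))
```

Both forms select exactly the same elements, so they differ only in readability.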

Thanks for your reply!
But note that your code uses the mask to pick feature entries; my goal is different.
For instance, if the first row of the mask contains [0, 10, 13, 200], I would like to average X[0, :], X[10, :], X[13, :], X[200, :], not X[0, 0], X[0, 10], X[0, 13], X[0, 200].
Thanks!

Ah, OK, so idx should index dim0:

import torch

M, D, K = 5, 4, 2
x = torch.randn(M, D)
print(x)
> tensor([[ 0.3545, -1.2674, -0.1321, -0.4732],
          [-0.9573,  0.8507, -0.1691,  0.6376],
          [ 0.1386, -1.4488,  0.5860, -0.0135],
          [-0.9486, -0.2253, -1.0952, -0.0432],
          [-0.0213,  0.2111,  0.1952, -0.8802]])

idx = torch.randint(0, M, (M, K))
print(idx)
> tensor([[0, 2],
          [4, 1],
          [1, 2],
          [1, 3],
          [3, 0]])

tmp = x[idx]  # each index selects a full row of x -> shape [M, K, D]
print(tmp)
> tensor([[[ 0.3545, -1.2674, -0.1321, -0.4732], # corresponds to first row in idx: [0, 2]
           [ 0.1386, -1.4488,  0.5860, -0.0135]],

          [[-0.0213,  0.2111,  0.1952, -0.8802],
           [-0.9573,  0.8507, -0.1691,  0.6376]],

          [[-0.9573,  0.8507, -0.1691,  0.6376],
           [ 0.1386, -1.4488,  0.5860, -0.0135]],

          [[-0.9573,  0.8507, -0.1691,  0.6376],
           [-0.9486, -0.2253, -1.0952, -0.0432]],

          [[-0.9486, -0.2253, -1.0952, -0.0432],
           [ 0.3545, -1.2674, -0.1321, -0.4732]]])

print(tmp.mean(dim=[1, 2]))  # average over the K rows and D features -> shape [M]
> tensor([-0.2820, -0.0167, -0.0470, -0.2438, -0.4788])

In this example, the first row of idx is [0, 2], which selects [x[0], x[2]]. The last line then takes the mean over dim1 (the K selected rows) and dim2 (the D features), yielding one scalar per sample.
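If the goal is instead one averaged feature vector per sample (i.e. the element-wise mean of X[0, :], X[10, :], X[13, :], X[200, :] from the example above), only dim1 should be reduced. A minimal sketch under that assumption; it also shows `torch.nn.functional.embedding_bag`, which treats each row of a 2D index tensor as a fixed-size bag and averages the selected rows of its weight matrix without building the intermediate [M, K, D] tensor in user code. The variable names are illustrative:

```python
import torch
import torch.nn.functional as F

# Shapes follow the example above; values are random.
M, D, K = 5, 4, 2
x = torch.randn(M, D)
idx = torch.randint(0, M, (M, K))  # K row indices into x for each sample

tmp = x[idx]  # shape [M, K, D]

# One scalar per sample: mean over the K rows and the D features (as above).
scalar_mean = tmp.mean(dim=[1, 2])  # shape [M]

# One D-dim vector per sample: mean over the K selected rows only.
row_mean = tmp.mean(dim=1)  # shape [M, D]

# embedding_bag with a 2D index tensor: each row of idx is one bag of size K,
# and mode="mean" averages the corresponding rows of x (used as the weight).
bag_mean = F.embedding_bag(idx, x, mode="mean")  # shape [M, D]

# Explicit loop as a reference for the per-sample feature mean.
loop_mean = torch.stack([x[idx[i]].mean(dim=0) for i in range(M)])

print(torch.allclose(row_mean, bag_mean, atol=1e-6))
print(torch.allclose(row_mean, loop_mean, atol=1e-6))
```

Because every sample averages the same number of values, further reducing `row_mean` over its last dimension recovers the scalar `tmp.mean(dim=[1, 2])`.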
