# Mean of 2d tensor using indices tensor

Hello,
Let's assume I have a tensor of shape [M, D], where M is the number of samples and D is the feature dimension. Additionally, I have a mask tensor of shape [M, K], which contains K different indices for each of the M samples.

My question is: how can I calculate, for each of the M samples, the mean over its K indices (the mask tells us which K indices to use for each sample)?

Thanks!

This should probably work, if I understand the use case correctly:

```python
import torch

M, D, K = 2, 4, 3
x = torch.randn(M, D)
idx = torch.randint(0, D, (M, K))
# pick K feature entries per row (x[m, idx[m, k]]), then average them -> shape [M]
x[torch.arange(x.size(0)).unsqueeze(1), idx].mean(1)
```
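For reference, the advanced-indexing line above selects `x[m, idx[m, k]]`: the row index of shape [M, 1] broadcasts against `idx` of shape [M, K]. `torch.gather` along dim 1 picks the same entries, which some readers may find easier to follow. A minimal sketch comparing the two, using the same shapes as the snippet above:

```python
import torch

M, D, K = 2, 4, 3
x = torch.randn(M, D)
idx = torch.randint(0, D, (M, K))

# Advanced indexing: the [M, 1] row index broadcasts against idx ([M, K]),
# so picked_indexing[m, k] = x[m, idx[m, k]].
picked_indexing = x[torch.arange(M).unsqueeze(1), idx]

# torch.gather along dim 1 selects the same entries: x[m, idx[m, k]].
picked_gather = torch.gather(x, 1, idx)

result = picked_gather.mean(1)  # one value per sample, shape [M]
print(result.shape)  # torch.Size([2])
```

Note that this still averages feature entries within each row, not whole rows, which is what the follow-up below points out.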

But notice: in your code you are using the mask to get feature entries, and my goal is different.
For instance, if the first row of the mask contains `[0, 10, 13, 200]`, I would like to take the mean of `X[0, :]`, `X[10, :]`, `X[13, :]`, `X[200, :]`, and not of `X[0, 0]`, `X[0, 10]`, `X[0, 13]`, `X[0, 200]`.
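To make the requested result concrete, here is a slow but explicit loop version. It assumes that "the mean of `X[0, :]`, `X[10, :]`, ..." means the element-wise average of those rows, so each sample gets a D-dimensional vector; the helper name `mean_rows_loop` and the sizes are hypothetical, chosen only for illustration:

```python
import torch

def mean_rows_loop(x: torch.Tensor, idx: torch.Tensor) -> torch.Tensor:
    """Hypothetical reference helper: out[m] = mean over k of x[idx[m, k], :]."""
    out = torch.empty(idx.size(0), x.size(1), dtype=x.dtype, device=x.device)
    for m in range(idx.size(0)):
        rows = x[idx[m]]           # [K, D]: the rows selected for sample m
        out[m] = rows.mean(dim=0)  # average the K rows feature-wise -> [D]
    return out

M, D, K = 300, 8, 4
x = torch.randn(M, D)
idx = torch.randint(0, M, (M, K))      # row indices into x, one set per sample
idx[0] = torch.tensor([0, 10, 13, 200])  # first row taken from the example above
out = mean_rows_loop(x, idx)
print(out.shape)  # torch.Size([300, 8])
```

A Python loop over M samples will be slow for large M; it is only meant as a readable reference to check a vectorized version against.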
Thanks!

Ah OK, so `idx` would be used in `dim0`:

```python
import torch

M, D, K = 5, 4, 2
x = torch.randn(M, D)
print(x)
# example output (values are random):
# tensor([[ 0.3545, -1.2674, -0.1321, -0.4732],
#         [-0.9573,  0.8507, -0.1691,  0.6376],
#         [ 0.1386, -1.4488,  0.5860, -0.0135],
#         [-0.9486, -0.2253, -1.0952, -0.0432],
#         [-0.0213,  0.2111,  0.1952, -0.8802]])

# row indices into x: idx[m] lists the K rows used for sample m
idx = torch.randint(0, M, (M, K))
print(idx)
# tensor([[0, 2],
#         [4, 1],
#         [1, 2],
#         [1, 3],
#         [3, 0]])

tmp = x[idx]  # shape [M, K, D]: tmp[m, k] = x[idx[m, k]]
print(tmp)
# tensor([[[ 0.3545, -1.2674, -0.1321, -0.4732],  # corresponds to first row in idx: [0, 2]
#          [ 0.1386, -1.4488,  0.5860, -0.0135]],
#
#         [[-0.0213,  0.2111,  0.1952, -0.8802],
#          [-0.9573,  0.8507, -0.1691,  0.6376]],
#
#         [[-0.9573,  0.8507, -0.1691,  0.6376],
#          [ 0.1386, -1.4488,  0.5860, -0.0135]],
#
#         [[-0.9573,  0.8507, -0.1691,  0.6376],
#          [-0.9486, -0.2253, -1.0952, -0.0432]],
#
#         [[-0.9486, -0.2253, -1.0952, -0.0432],
#          [ 0.3545, -1.2674, -0.1321, -0.4732]]])

# average over the K selected rows (dim1) and the D features (dim2)
print(tmp.mean(dim=[1, 2]))
# tensor([-0.2820, -0.0167, -0.0470, -0.2438, -0.4788])
```

In this example, the first row of `idx` is `[0, 2]`, so the first entry of `x[idx]` is `[x[0], x[2]]`. The last line of code then takes the `mean` over `dim1` and `dim2`.
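Since `x[idx]` has shape [M, K, D], averaging over `dim1` and `dim2` collapses both the K selected rows and the D features, so each sample ends up with a single scalar. If the goal is instead the average feature vector per sample (shape [M, D]), reducing only over `dim1` should do it. The sketch below compares the two reductions and also tries `torch.nn.functional.embedding_bag` with `mode="mean"`, an alternative not mentioned in the thread: given a 2D index tensor, it treats each row as a fixed-length bag of row indices into the weight matrix and returns the mean row per bag.

```python
import torch
import torch.nn.functional as F

M, D, K = 5, 4, 2
x = torch.randn(M, D)
idx = torch.randint(0, M, (M, K))

tmp = x[idx]                              # [M, K, D]: tmp[m, k] = x[idx[m, k]]
per_sample_vec = tmp.mean(dim=1)          # [M, D]: element-wise mean of the K rows
per_sample_scalar = tmp.mean(dim=[1, 2])  # [M]: additionally averages over D

# With a 2D input, embedding_bag treats each row of idx as one bag of K row
# indices into x and, with mode="mean", returns the mean row per bag: [M, D].
bagged = F.embedding_bag(idx, x, mode="mean")

print(tmp.shape, per_sample_vec.shape, per_sample_scalar.shape, bagged.shape)
# torch.Size([5, 2, 4]) torch.Size([5, 4]) torch.Size([5]) torch.Size([5, 4])
```

Because every sample uses the same number of rows K, the scalar result equals the per-sample vector averaged once more over its D features.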
