Is there an easy way to compute mean of every two tensors in a 2D tensor?

Suppose I have a tensor that looks like this:

a = torch.tensor([
                  [1., 2.],
                  [3., 4.],
                  [5., 6.],
                  [7., 8.]
                 ])

Is there a way to compute the mean of every two row tensors (without overlap) in `a` without looping?

The expected output tensor would be:

torch.tensor([
              [2., 3.], # mean of 1st and 2nd row tensors in `a`
              [6., 7.]  # mean of 3rd and 4th row tensors in `a`
            ])

Something like this would work:

b = torch.stack([a_.mean(0) for a_ in a.split(2, 0)], 0)
print(b)
# tensor([[2., 3.],
#         [6., 7.]])

or using scatter_reduce_:

torch.zeros(2, 2).scatter_reduce_(0, torch.tensor([[0, 0], [0, 0], [1, 1], [1, 1]]), a, reduce="mean", include_self=False)
# tensor([[2., 3.], [6., 7.]])

or index_put_ with accumulate=True and a scaling:

c = torch.zeros(2, 2)
c.index_put_((torch.tensor([0, 0, 1, 1]), ), a, accumulate=True)  # sums rows that share an index
c /= 2  # divide by the group size to get the mean
# c is tensor([[2., 3.], [6., 7.]])

Depending on your actual use case, how the indices are calculated, how large the tensor is, etc., one approach might perform better than the others.
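
If performance matters, you could time the approaches on your actual shapes with torch.utils.benchmark. Here is a rough sketch; the tensor size and the way the index tensors are built are just assumptions for illustration:

import torch
import torch.utils.benchmark as benchmark

# Assumed larger input for timing purposes
a = torch.randn(100_000, 64)
idx1d = torch.arange(a.size(0)) // 2                               # row i of `a` goes to output row i // 2
idx2d = idx1d.unsqueeze(1).expand(-1, a.size(1)).contiguous()      # same mapping, repeated per column
out_shape = (a.size(0) // 2, a.size(1))

t_split = benchmark.Timer(
    stmt="torch.stack([a_.mean(0) for a_ in a.split(2, 0)], 0)",
    globals={"torch": torch, "a": a},
)
t_scatter = benchmark.Timer(
    stmt='torch.zeros(out_shape).scatter_reduce_(0, idx2d, a, reduce="mean", include_self=False)',
    globals={"torch": torch, "a": a, "idx2d": idx2d, "out_shape": out_shape},
)
t_index_put = benchmark.Timer(
    stmt="torch.zeros(out_shape).index_put_((idx1d,), a, accumulate=True) / 2",
    globals={"torch": torch, "a": a, "idx1d": idx1d, "out_shape": out_shape},
)

print(t_split.timeit(10))
print(t_scatter.timeit(10))
print(t_index_put.timeit(10))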
