Incorrect mean of slice of tensor on MPS

Hi, I noticed strange and incorrect results when taking the mean of a slice of a tensor on MPS.
See the code and results below. The calculation is correct when done on CPU, but on MPS the mean is incorrect, even though printing the slice shows the correct part of the tensor.

The behavior also depends on the size of the tensor. In the example below, the results for the first tensor, m1 with shape (3,4,2), are correct on both CPU and MPS. But the results for m2, with shape (3,5,2), are NOT correct for slice 1 on MPS.
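A quick way to probe the size dependence without reading the printouts by eye is to compare each slice mean against the CPU result with `torch.allclose`. The sketch below assumes a tolerance of `1e-6` and a fixed seed; the helper name `slice_mean_mismatches` and the extra shapes (3,6,2) and (3,7,2) are hypothetical additions, not part of the original report. On a machine without MPS it falls back to CPU, in which case the list of mismatches is trivially empty.

```python
import torch


def slice_mean_mismatches(shape, device, rows=2, atol=1e-6, seed=0):
    """Hypothetical helper: return the (slice index, dim) pairs for which
    m[i][:rows].mean(dim) on `device` differs from the CPU result."""
    torch.manual_seed(seed)
    m = torch.rand(*shape)   # reference tensor, created on CPU
    m_dev = m.to(device)     # same values on the device under test
    mismatches = []
    for i in range(shape[0]):
        for dim in (0, 1):
            expected = m[i][:rows].mean(dim=dim)
            # copy the device result back to CPU so both sides can be compared
            actual = m_dev[i][:rows].mean(dim=dim).cpu()
            if not torch.allclose(expected, actual, atol=atol):
                mismatches.append((i, dim))
    return mismatches


if __name__ == "__main__":
    # Fall back to CPU so the script still runs on machines without MPS.
    device = "mps" if torch.backends.mps.is_available() else "cpu"
    for shape in [(3, 4, 2), (3, 5, 2), (3, 6, 2), (3, 7, 2)]:
        print(shape, device, slice_mean_mismatches(shape, device))
```

An empty list for a shape means every slice mean matched the CPU result within the assumed tolerance.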

==> What can be done about this?

Below is the code. I’m running torch version 2.0.0.dev20230205 on an Apple M1 Pro (32GB, 10-core CPU, 16-core GPU).

import torch

print("Torch version: ", torch.__version__)
m1 = torch.rand(3,4,2)  # shape (3,4,2): results match on CPU and MPS
m2 = torch.rand(3,5,2)  # shape (3,5,2): slice 1 is wrong on MPS
for m in (m1, m2):
    print('='*40)
    print(m)
    for device in ['cpu', 'mps']:
        print('-' * 40)
        print("Mean of slice on ", device)
        print('-' * 40)
        m = m.to(device)  # same values, moved to the device under test
        for i in range(3):
            print("Slice ", i)
            print(m[i][:2])  # first two rows of slice i, shape (2,2)
            print("Mean with dim=0: ", m[i][:2].mean(dim=0))
            print("Mean with dim=1: ", m[i][:2].mean(dim=1))

Here are the results. Notice that the calculations are correct on CPU, and on MPS the means for tensor m1 are identical to the CPU results. For tensor m2, however, slices 0 and 2 are correct but slice 1 is not! How is this possible?

Torch version:  2.0.0.dev20230205
========================================
tensor([[[0.8846, 0.1703],
         [0.3438, 0.1784],
         [0.0500, 0.5442],
         [0.8398, 0.3753]],

        [[0.8590, 0.1662],
         [0.8247, 0.0485],
         [0.9754, 0.5172],
         [0.8726, 0.3522]],

        [[0.0371, 0.4527],
         [0.2520, 0.5700],
         [0.1812, 0.6419],
         [0.1710, 0.8662]]])
----------------------------------------
Mean of slice on  cpu
----------------------------------------
Slice  0
tensor([[0.8846, 0.1703],
        [0.3438, 0.1784]])
Mean with dim=0:  tensor([0.6142, 0.1743])
Mean with dim=1:  tensor([0.5274, 0.2611])
Slice  1
tensor([[0.8590, 0.1662],
        [0.8247, 0.0485]])
Mean with dim=0:  tensor([0.8418, 0.1074])
Mean with dim=1:  tensor([0.5126, 0.4366])
Slice  2
tensor([[0.0371, 0.4527],
        [0.2520, 0.5700]])
Mean with dim=0:  tensor([0.1446, 0.5113])
Mean with dim=1:  tensor([0.2449, 0.4110])
----------------------------------------
Mean of slice on  mps
----------------------------------------
Slice  0
tensor([[0.8846, 0.1703],
        [0.3438, 0.1784]], device='mps:0')
Mean with dim=0:  tensor([0.6142, 0.1743], device='mps:0')
Mean with dim=1:  tensor([0.5274, 0.2611], device='mps:0')
Slice  1
tensor([[0.8590, 0.1662],
        [0.8247, 0.0485]], device='mps:0')
Mean with dim=0:  tensor([0.8418, 0.1074], device='mps:0')
Mean with dim=1:  tensor([0.5126, 0.4366], device='mps:0')
Slice  2
tensor([[0.0371, 0.4527],
        [0.2520, 0.5700]], device='mps:0')
Mean with dim=0:  tensor([0.1446, 0.5113], device='mps:0')
Mean with dim=1:  tensor([0.2449, 0.4110], device='mps:0')
========================================
tensor([[[0.5245, 0.7926],
         [0.6072, 0.3903],
         [0.1381, 0.5081],
         [0.9186, 0.9196],
         [0.0205, 0.8702]],

        [[0.1767, 0.2295],
         [0.4399, 0.8822],
         [0.1794, 0.5968],
         [0.4745, 0.1653],
         [0.4785, 0.3840]],

        [[0.5404, 0.5616],
         [0.6746, 0.7644],
         [0.5006, 0.3931],
         [0.6984, 0.1798],
         [0.6535, 0.1407]]])
----------------------------------------
Mean of slice on  cpu
----------------------------------------
Slice  0
tensor([[0.5245, 0.7926],
        [0.6072, 0.3903]])
Mean with dim=0:  tensor([0.5659, 0.5915])
Mean with dim=1:  tensor([0.6585, 0.4988])
Slice  1
tensor([[0.1767, 0.2295],
        [0.4399, 0.8822]])
Mean with dim=0:  tensor([0.3083, 0.5558])
Mean with dim=1:  tensor([0.2031, 0.6610])
Slice  2
tensor([[0.5404, 0.5616],
        [0.6746, 0.7644]])
Mean with dim=0:  tensor([0.6075, 0.6630])
Mean with dim=1:  tensor([0.5510, 0.7195])
----------------------------------------
Mean of slice on  mps
----------------------------------------
Slice  0
tensor([[0.5245, 0.7926],
        [0.6072, 0.3903]], device='mps:0')
Mean with dim=0:  tensor([0.5659, 0.5915], device='mps:0')
Mean with dim=1:  tensor([0.6585, 0.4988], device='mps:0')
Slice  1
tensor([[0.1767, 0.2295],
        [0.4399, 0.8822]], device='mps:0')
**Mean with dim=0:  tensor([0.0986, 0.5499], device='mps:0')**
**Mean with dim=1:  tensor([0.4453, 0.2031], device='mps:0')**
Slice  2
tensor([[0.5404, 0.5616],
        [0.6746, 0.7644]], device='mps:0')
Mean with dim=0:  tensor([0.6075, 0.6630], device='mps:0')
Mean with dim=1:  tensor([0.5510, 0.7195], device='mps:0')
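As a sanity check on which results are right, the means of slice 1 of m2 can be recomputed in plain Python from the printed values. Those values are rounded to 4 decimals, so agreement is only expected to about 1e-4:

```python
# Values of m2[1][:2] as printed above (rounded to 4 decimals).
s = [[0.1767, 0.2295],
     [0.4399, 0.8822]]

# mean(dim=0): average down each column, one value per column
col_means = [(s[0][j] + s[1][j]) / 2 for j in range(2)]
# mean(dim=1): average across each row, one value per row
row_means = [sum(row) / len(row) for row in s]

print("dim=0:", [f"{x:.4f}" for x in col_means])
print("dim=1:", [f"{x:.4f}" for x in row_means])
```

Within rounding this reproduces the CPU output, (0.3083, 0.5558) for dim=0 and (0.2031, 0.6610) for dim=1, and not the MPS output.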

Thanks for reporting this; it might be a real bug in the MPS backend. Could you please create a GitHub issue so that the code owners can track and fix it?
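Until the bug is tracked down, one possible stopgap is to avoid reducing the view directly. This is an untested assumption: nothing here confirms the cause, but if the problem is tied to the slice being a view into the larger tensor, copying it first may avoid it. Note that `m[i][:2]` of a contiguous (3,5,2) tensor is already contiguous (it simply starts at an offset into the storage), so `.contiguous()` would return the same view, whereas `.clone()` forces a copy into new storage. Both helper names below are hypothetical, and the CPU fallback is the safer of the two:

```python
import torch


def slice_mean_cloned(t, i, rows, dim):
    """Hypothetical workaround: copy the slice with .clone() before reducing.
    Not confirmed to avoid the MPS bug; verify against a CPU result."""
    return t[i][:rows].clone().mean(dim=dim)


def slice_mean_on_cpu(t, i, rows, dim):
    """Fallback: do the reduction on CPU, then move the result back to t's device."""
    return t[i][:rows].cpu().mean(dim=dim).to(t.device)
```

The CPU fallback costs two device transfers per call, which is negligible for slices this small.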

That is fine. Where do I do that?

You can create an issue in the PyTorch GitHub repository (pytorch/pytorch) by selecting “New issue”. The template will then walk you through the needed information.
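For the issue itself, a short, seeded reproducer that asserts instead of printing everything is easier for the code owners to act on. The sketch below is one way to write it; the function name `reproduce` and the seed are arbitrary choices, and whether a given seed triggers the mismatch on a particular machine is not confirmed. The bug-report template typically also asks for the environment details printed by `python -m torch.utils.collect_env`.

```python
import torch


def reproduce(seed=0):
    """Mean over the first two rows of slice 1 of a (3, 5, 2) tensor,
    computed on CPU and on MPS (or CPU again if MPS is unavailable)."""
    torch.manual_seed(seed)
    m = torch.rand(3, 5, 2)
    device = "mps" if torch.backends.mps.is_available() else "cpu"
    expected = m[1][:2].mean(dim=0)
    actual = m.to(device)[1][:2].mean(dim=0).cpu()  # bring back for comparison
    return expected, actual, device


if __name__ == "__main__":
    expected, actual, device = reproduce()
    print("torch", torch.__version__, "| device:", device)
    print("cpu:", expected)
    print(device + ":", actual)
    assert torch.allclose(expected, actual), "mean of slice differs from CPU"
```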