Hi, I noticed incorrect results when taking the mean of a slice of a tensor on MPS.
See the code and results below. The calculation is correct on CPU, but on MPS the mean is wrong, even though printing the slice shows the correct part of the tensor.
The behavior also depends on the size of the tensor. Results for the first tensor m1, with shape (3,4,2), are correct on both CPU and MPS. But the results for m2, with shape (3,5,2), are NOT correct for slice 1 on MPS.
==> What can be done about this?
Below is the code. I’m running torch version 2.0.0.dev20230205 on an Apple M1 Pro (32 GB, 10-core CPU, 16-core GPU).
import torch
print("Torch version: ", torch.__version__)
m1 = torch.rand(3,4,2)
m2 = torch.rand(3,5,2)
for m in (m1, m2):
    print('='*40)
    print(m)
    for device in ['cpu', 'mps']:
        print('-' * 40)
        print("Mean of slice on ", device)
        print('-' * 40)
        m = m.to(device)
        for i in range(3):
            print("Slice ", i)
            print(m[i][:2])
            print("Mean with dim=0: ", m[i][:2].mean(dim=0))
            print("Mean with dim=1: ", m[i][:2].mean(dim=1))
Here are the results. The calculations are correct on CPU. On MPS, the means for tensor m1 are identical to the CPU results. For tensor m2, slice 0 and slice 2 are correct on MPS, but slice 1 is incorrect: both the dim=0 and dim=1 means differ from the CPU values, even though the printed slice itself matches. How is this possible?
Torch version: 2.0.0.dev20230205
========================================
tensor([[[0.8846, 0.1703],
[0.3438, 0.1784],
[0.0500, 0.5442],
[0.8398, 0.3753]],
[[0.8590, 0.1662],
[0.8247, 0.0485],
[0.9754, 0.5172],
[0.8726, 0.3522]],
[[0.0371, 0.4527],
[0.2520, 0.5700],
[0.1812, 0.6419],
[0.1710, 0.8662]]])
----------------------------------------
Mean of slice on cpu
----------------------------------------
Slice 0
tensor([[0.8846, 0.1703],
[0.3438, 0.1784]])
Mean with dim=0: tensor([0.6142, 0.1743])
Mean with dim=1: tensor([0.5274, 0.2611])
Slice 1
tensor([[0.8590, 0.1662],
[0.8247, 0.0485]])
Mean with dim=0: tensor([0.8418, 0.1074])
Mean with dim=1: tensor([0.5126, 0.4366])
Slice 2
tensor([[0.0371, 0.4527],
[0.2520, 0.5700]])
Mean with dim=0: tensor([0.1446, 0.5113])
Mean with dim=1: tensor([0.2449, 0.4110])
----------------------------------------
Mean of slice on mps
----------------------------------------
Slice 0
tensor([[0.8846, 0.1703],
[0.3438, 0.1784]], device='mps:0')
Mean with dim=0: tensor([0.6142, 0.1743], device='mps:0')
Mean with dim=1: tensor([0.5274, 0.2611], device='mps:0')
Slice 1
tensor([[0.8590, 0.1662],
[0.8247, 0.0485]], device='mps:0')
Mean with dim=0: tensor([0.8418, 0.1074], device='mps:0')
Mean with dim=1: tensor([0.5126, 0.4366], device='mps:0')
Slice 2
tensor([[0.0371, 0.4527],
[0.2520, 0.5700]], device='mps:0')
Mean with dim=0: tensor([0.1446, 0.5113], device='mps:0')
Mean with dim=1: tensor([0.2449, 0.4110], device='mps:0')
========================================
tensor([[[0.5245, 0.7926],
[0.6072, 0.3903],
[0.1381, 0.5081],
[0.9186, 0.9196],
[0.0205, 0.8702]],
[[0.1767, 0.2295],
[0.4399, 0.8822],
[0.1794, 0.5968],
[0.4745, 0.1653],
[0.4785, 0.3840]],
[[0.5404, 0.5616],
[0.6746, 0.7644],
[0.5006, 0.3931],
[0.6984, 0.1798],
[0.6535, 0.1407]]])
----------------------------------------
Mean of slice on cpu
----------------------------------------
Slice 0
tensor([[0.5245, 0.7926],
[0.6072, 0.3903]])
Mean with dim=0: tensor([0.5659, 0.5915])
Mean with dim=1: tensor([0.6585, 0.4988])
Slice 1
tensor([[0.1767, 0.2295],
[0.4399, 0.8822]])
Mean with dim=0: tensor([0.3083, 0.5558])
Mean with dim=1: tensor([0.2031, 0.6610])
Slice 2
tensor([[0.5404, 0.5616],
[0.6746, 0.7644]])
Mean with dim=0: tensor([0.6075, 0.6630])
Mean with dim=1: tensor([0.5510, 0.7195])
----------------------------------------
Mean of slice on mps
----------------------------------------
Slice 0
tensor([[0.5245, 0.7926],
[0.6072, 0.3903]], device='mps:0')
Mean with dim=0: tensor([0.5659, 0.5915], device='mps:0')
Mean with dim=1: tensor([0.6585, 0.4988], device='mps:0')
Slice 1
tensor([[0.1767, 0.2295],
[0.4399, 0.8822]], device='mps:0')
**Mean with dim=0: tensor([0.0986, 0.5499], device='mps:0')**
**Mean with dim=1: tensor([0.4453, 0.2031], device='mps:0')**
Slice 2
tensor([[0.5404, 0.5616],
[0.6746, 0.7644]], device='mps:0')
Mean with dim=0: tensor([0.6075, 0.6630], device='mps:0')
Mean with dim=1: tensor([0.5510, 0.7195], device='mps:0')
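To narrow down which slices are affected without reading the printed output by eye, the CPU-vs-MPS comparison above could be automated along these lines. This is only a sketch: `compare_slice_means` and `mean_of_slice_via_cpu` are hypothetical helper names, not part of PyTorch, and the CPU round-trip is a possible stopgap that relies on the observation above that the slice values themselves come back correctly from MPS. It is not a confirmed fix for the underlying problem.

```python
import torch


def compare_slice_means(t, rows=2, device="mps", atol=1e-6):
    """Return (slice index, dim) pairs where the mean on `device` differs from CPU.

    Hypothetical helper: it repeats the experiment from the post, but checks
    the results with torch.allclose instead of printing them.
    """
    mismatches = []
    t_cpu = t.cpu()
    t_dev = t.to(device)
    for i in range(t.shape[0]):
        for dim in (0, 1):
            expected = t_cpu[i][:rows].mean(dim=dim)
            actual = t_dev[i][:rows].mean(dim=dim).cpu()
            if not torch.allclose(actual, expected, atol=atol):
                mismatches.append((i, dim))
    return mismatches


def mean_of_slice_via_cpu(t, i, rows=2, dim=0):
    """Possible stopgap (an assumption, not a confirmed fix): copy the slice
    to CPU, reduce there, and move the result back to the tensor's device."""
    return t[i][:rows].cpu().mean(dim=dim).to(t.device)


# Fall back to CPU so the script still runs on machines without MPS.
device = "mps" if torch.backends.mps.is_available() else "cpu"

for shape in [(3, 4, 2), (3, 5, 2)]:
    t = torch.rand(*shape)
    print(shape, "mismatching (slice, dim) pairs:", compare_slice_means(t, device=device))
```

An empty list means the device agrees with CPU for every slice. On the setup described above, the (3,5,2) tensor would be expected to report slice 1, while the (3,4,2) tensor would not.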