Mismatched results for partitioned operators

I found that the result of matmul is inconsistent with the result of a partitioned matmul computed with torch APIs.

Consider the matmul Z = XY. It can be partitioned as Z = [X1 X2][Y1; Y2] = X1Y1 + X2Y2, where X is split along its second dimension (columns) and Y is split along its first dimension (rows, with Y1 stacked on top of Y2). The partial products are then summed (reduced) to get Z.
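As a sanity check on the identity itself, here is a minimal sketch in pure Python with nested lists (no torch; the helpers `matmul`, `add` and `partitioned_matmul` here are hypothetical illustrations, not torch APIs). With integer entries the arithmetic is exact, so the partitioned product matches XY exactly. This suggests any mismatch seen on the GPU comes from floating-point rounding, not from the decomposition.

```python
def matmul(a, b):
    """Naive matrix product of nested lists: (n x k) @ (k x m)."""
    n, k, m = len(a), len(b), len(b[0])
    return [[sum(a[i][p] * b[p][j] for p in range(k)) for j in range(m)]
            for i in range(n)]

def add(a, b):
    """Elementwise sum of two equally shaped matrices."""
    return [[u + v for u, v in zip(ra, rb)] for ra, rb in zip(a, b)]

def partitioned_matmul(x, y, split):
    # X1 = columns [0, split) of X, X2 = columns [split, k)
    x1 = [row[:split] for row in x]
    x2 = [row[split:] for row in x]
    # Y1 = rows [0, split) of Y, Y2 = rows [split, k)
    y1, y2 = y[:split], y[split:]
    # Z = X1 Y1 + X2 Y2
    return add(matmul(x1, y1), matmul(x2, y2))

x = [[1, 2, 3, 4], [5, 6, 7, 8]]
y = [[1, 0], [0, 1], [2, 3], [4, 5]]
print(matmul(x, y))                          # [[23, 31], [51, 67]]
print(partitioned_matmul(x, y, 2) == matmul(x, y))  # True
```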

Below is the implementation. I ran it with PyTorch 1.13 on a single NVIDIA V100-32GB GPU, but it fails the assertion with atol=1e-2.

Can anyone help explain this?

import torch

def matmul(x, y):
    return torch.matmul(x, y)

def partitioned_matmul(x, y, num):
    # Split X along its columns (dim=1) and Y along its rows (dim=0)
    xs = x.chunk(num, dim=1)
    ys = y.chunk(num, dim=0)
    # The accumulator uses the input dtype (float16 here), so every "+=" rounds
    out = torch.zeros(x.size(0), y.size(1), dtype=x.dtype, device=x.device)
    for xi, yi in zip(xs, ys):
        out += matmul(xi, yi)  # add the partial product X_i @ Y_i
    return out

if __name__ == '__main__':
    x = torch.randn(128, 128, dtype=torch.float16).cuda()
    y = torch.randn(128, 128, dtype=torch.float16).cuda()
    z = matmul(x, y)
    pz = partitioned_matmul(x, y, 2)

    assert torch.allclose(z, pz, atol=1e-1)
    assert torch.allclose(z, pz, atol=1e-2)  # this assertion fails
    assert torch.allclose(z, pz, atol=1e-3)
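To see where such a mismatch can come from, the following pure-Python sketch simulates both computations for float16 inputs, using the half-precision format 'e' of the `struct` module to round values. It rests on a modeling assumption, not a statement about how cuBLAS actually works: the unpartitioned kernel is modeled as accumulating in higher precision and rounding to float16 once, while the partitioned version rounds each partial product and the running sum to float16, as `out += ...` does with a float16 `out`. Both are compared against a float64 reference. The function `simulate` and its parameters are hypothetical.

```python
import random
import struct

def to_f16(v):
    """Round a Python float to the nearest IEEE float16 value."""
    return struct.unpack('e', struct.pack('e', v))[0]

def dot(a, b):
    # Accumulate in float64 (Python float): a stand-in for a kernel that
    # accumulates in higher precision than its float16 output (assumption).
    return sum(u * v for u, v in zip(a, b))

def simulate(n=32, k=128, m=32, seed=0):
    rng = random.Random(seed)
    # float16 inputs, similar to torch.randn(..., dtype=torch.float16)
    x = [[to_f16(rng.gauss(0.0, 1.0)) for _ in range(k)] for _ in range(n)]
    y = [[to_f16(rng.gauss(0.0, 1.0)) for _ in range(m)] for _ in range(k)]
    yt = list(zip(*y))  # columns of Y
    half = k // 2
    err_full = err_part = diff = 0.0
    for i in range(n):
        for j in range(m):
            exact = dot(x[i], yt[j])  # float64 reference
            full = to_f16(exact)      # a single rounding to float16
            # Each partial product is rounded to float16, then the two are
            # added and the sum is rounded to float16 again.
            p1 = to_f16(dot(x[i][:half], yt[j][:half]))
            p2 = to_f16(dot(x[i][half:], yt[j][half:]))
            part = to_f16(p1 + p2)
            err_full = max(err_full, abs(full - exact))
            err_part = max(err_part, abs(part - exact))
            diff = max(diff, abs(full - part))
    return err_full, err_part, diff

err_full, err_part, diff = simulate()
print(f"max |full - exact|        = {err_full:.4g}")
print(f"max |partitioned - exact| = {err_part:.4g}")
print(f"max |full - partitioned|  = {diff:.4g}")
```

Under this model, neither result is "the" correct one: both sit within a few float16 rounding steps of the reference, they are just rounded differently. For outputs in [8, 16), which is typical for a 128-term dot product of standard normals (standard deviation about 11.3), the float16 spacing is 2^-7 ≈ 0.0078, so a difference above atol=1e-2 is plausible.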

These errors are expected in float16, and you could use a wider dtype to increase the precision.
E.g., float32 yields a max absolute error ((z - pz).abs().max()) of ~1e-5 and float64 of ~1e-14.
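One way to see why widening the dtype helps is to compare the machine epsilon of each format, i.e. the gap between 1.0 and the next representable value, which bounds the relative rounding error of a single operation. The sketch below computes it in pure Python by rounding through `struct` formats 'e' (float16), 'f' (float32) and 'd' (float64); the helper names are hypothetical, and in PyTorch `torch.finfo(dtype).eps` reports the same values. The error magnitudes quoted above roughly track these epsilons.

```python
import struct

def round_to(fmt, v):
    """Round a Python float to the precision of a struct format code."""
    return struct.unpack(fmt, struct.pack(fmt, v))[0]

def machine_eps(fmt):
    # Halve eps until 1 + eps/2 rounds back to 1.0 in the given format
    eps = 1.0
    while round_to(fmt, 1.0 + eps / 2) != 1.0:
        eps /= 2
    return eps

eps = {name: machine_eps(fmt)
       for name, fmt in [("float16", "e"), ("float32", "f"), ("float64", "d")]}
print(eps)
# {'float16': 0.0009765625, 'float32': 1.1920928955078125e-07, 'float64': 2.220446049250313e-16}
```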

Thank you! I'm curious: is this mismatch due to a different order of the multiplications and reductions in the underlying hardware?

Yes. Because floating-point precision is limited, floating-point addition is not associative, so a difference in the order of operations can cause numerical mismatches, as also seen in this small example:

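import torch

x = torch.randn(100, 100)
s1 = x.sum()                                # reduce all elements at once
s2 = x.sum(0).sum(0)                        # reduce columns first, then the results
s3 = x[torch.randperm(x.size(0))].sum()     # same elements, rows shuffled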
print(s1 - s2)
# tensor(-5.7220e-06)
print(s1 - s3)
# tensor(1.1444e-05)
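The same effect shows up with plain Python floats (float64), with no GPU involved. The values below are chosen for illustration: regrouping three additions changes the last bit, and reordering a sum whose terms differ greatly in magnitude can change the result entirely. `math.fsum` computes an exactly rounded sum for comparison.

```python
import math

a = (0.1 + 0.2) + 0.3
b = 0.1 + (0.2 + 0.3)
print(a, b, a == b)                # 0.6000000000000001 0.6 False

vals = [1e16, 1.0, -1e16]
print(sum(vals))                   # 0.0  (1e16 + 1.0 rounds back to 1e16)
print(sum([1e16, -1e16, 1.0]))     # 1.0
print(math.fsum(vals))             # 1.0  (exactly rounded sum)
```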