I found that the matmul result is inconsistent with the partitioned-matmul result using torch APIs.
Consider the matmul Z = XY, which can be partitioned as Z = [X1 X2][Y1 // Y2] = X1Y1 + X2Y2, where X is split along its second dimension and Y along its first. The partial products then need to be reduced (summed) to get Z.
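As a sanity check of the identity itself, separate from any fp16/GPU effects, here is a minimal sketch verifying the block decomposition in float64 on CPU, where accumulation-order differences are far below any reasonable tolerance:

```python
import torch

# Verify Z = [X1 X2][Y1 // Y2] = X1Y1 + X2Y2 in float64 on CPU.
x = torch.randn(8, 8, dtype=torch.float64)
y = torch.randn(8, 8, dtype=torch.float64)

x1, x2 = x.chunk(2, dim=1)  # split x along its columns
y1, y2 = y.chunk(2, dim=0)  # split y along its rows

z = x @ y
pz = x1 @ y1 + x2 @ y2

print(torch.allclose(z, pz))  # prints True: the identity holds numerically
```

In double precision the two results agree to well within `torch.allclose`'s default tolerances; the discrepancy in the question only becomes visible at float16 precision.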
Below is the implementation. I ran it with PyTorch 1.13 on one NVIDIA V100-32GB GPU, but it fails at an assertion.
Can anyone help explain this?
```python
import torch

def matmul(x, y):
    return torch.matmul(x, y)

def partitioned_matmul(x, y, num):
    # Split x along columns and y along rows, then sum the partial products.
    xs = x.chunk(num, dim=1)
    ys = y.chunk(num, dim=0)
    out = torch.zeros(x.size(0), y.size(1), dtype=x.dtype, device=x.device)
    for xi, yi in zip(xs, ys):
        out += matmul(xi, yi)
    return out

if __name__ == '__main__':
    x = torch.randn(128, 128, dtype=torch.float16).cuda()
    y = torch.randn(128, 128, dtype=torch.float16).cuda()
    z = matmul(x, y)
    pz = partitioned_matmul(x, y, 2)
    print(z)
    print(pz)
    assert torch.allclose(z, pz, atol=1e-1)
    assert torch.allclose(z, pz, atol=1e-2)
    assert torch.allclose(z, pz, atol=1e-3)
```