I found that the matmul result is inconsistent with the partitioned-matmul result using torch APIs.
Consider the matmul calculation Z = XY, which can be partitioned as Z = [X1 X2][Y1 // Y2] = X1Y1 + X2Y2, where X is partitioned along its second dimension and Y along its first. The partial products are then summed (reduced) to get Z.
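For reference, the block identity itself is exact in real arithmetic. Here is a minimal sketch of the same partitioning in NumPy, in float64 on CPU (the seed and sizes are arbitrary), where both results agree to tight tolerance:

import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal((128, 128))
y = rng.standard_normal((128, 128))

# Split x along columns and y along rows, then sum the partial products.
x1, x2 = np.hsplit(x, 2)
y1, y2 = np.vsplit(y, 2)
z_partitioned = x1 @ y1 + x2 @ y2

# In float64 the partitioned result matches the direct matmul closely.
assert np.allclose(x @ y, z_partitioned, atol=1e-8)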
Below is the implementation. I ran it with PyTorch 1.13 on a single NVIDIA V100-32GB GPU, but it failed at the assertion with atol=1e-2.
Can anyone help explain this?
import torch

def matmul(x, y):
    return torch.matmul(x, y)

def partitioned_matmul(x, y, num):
    # Split x along columns and y along rows, then sum the partial products.
    xs = x.chunk(num, dim=1)
    ys = y.chunk(num, dim=0)
    out = torch.zeros(x.size(0), y.size(1), dtype=x.dtype, device=x.device)
    for xi, yi in zip(xs, ys):
        out += matmul(xi, yi)
    return out

if __name__ == '__main__':
    x = torch.randn(128, 128, dtype=torch.float16).cuda()
    y = torch.randn(128, 128, dtype=torch.float16).cuda()
    z = matmul(x, y)
    pz = partitioned_matmul(x, y, 2)
    print(z)
    print(pz)
    assert torch.allclose(z, pz, atol=1e-1)
    assert torch.allclose(z, pz, atol=1e-2)  # fails here
    assert torch.allclose(z, pz, atol=1e-3)