CPU/GPU results inconsistent with matrix multiplication

I always thought 32-bit floats should be sufficient for most ML calculations. I tested the actual precision of a simple matrix multiplication on NumPy, PyTorch CPU, and PyTorch CUDA. Here is my code:

import numpy as np
import torch

np.random.seed(0)
M = np.random.randn(1000, 1000).astype(np.float64)

for T_np, T_cuda in [(np.float64, torch.float64), (np.float32, torch.float32)]:
    print('')
    M_np, M_pt, M_cuda = M.astype(T_np), torch.from_numpy(M).type(T_cuda), torch.from_numpy(M).type(T_cuda).cuda()
    M_pt.requires_grad = M_cuda.requires_grad = False
    print(M_np.dtype, M_np.shape, M_pt.dtype, M_pt.size(), M_cuda.dtype, M_cuda.size(), flush=True)

    MM = np.matmul(M_np, M_np)
    print(MM[0, : 5], flush=True)

    MM_pt = torch.matmul(M_pt, M_pt)
    print(MM_pt[0, : 5], flush=True)

    diff = np.absolute(M_np - M_pt.numpy())
    print('error:', diff.sum().item(), diff.mean().item(), diff.std().item())
    diff = np.absolute(MM - MM_pt.numpy())
    print('error:', diff.sum().item(), diff.mean().item(), diff.std().item())

    MM_cuda = torch.matmul(M_cuda, M_cuda)
    print(MM_cuda[0, : 5], flush=True)

    diff = np.absolute(M_np - M_cuda.cpu().numpy())
    print('error:', diff.sum().item(), diff.mean().item(), diff.std().item())
    diff = np.absolute(MM - MM_cuda.cpu().numpy())
    print('error:', diff.sum().item(), diff.mean().item(), diff.std().item())

On my Windows machine with an RTX 30-series GPU, the output is:

float64 (1000, 1000) torch.float64 torch.Size([1000, 1000]) torch.float64 torch.Size([1000, 1000])
[-20.90456457  63.66601763   4.60530742 -38.41815213 -50.70685021]
tensor([-20.9046,  63.6660,   4.6053, -38.4182, -50.7069], dtype=torch.float64)
error: 0.0 0.0 0.0
error: 0.0 0.0 0.0
tensor([-20.9046,  63.6660,   4.6053, -38.4182, -50.7069], device='cuda:0',
       dtype=torch.float64)
error: 0.0 0.0 0.0
error: 1.8810475736069643e-08 1.8810475736069643e-14 1.685592703582923e-14

float32 (1000, 1000) torch.float32 torch.Size([1000, 1000]) torch.float32 torch.Size([1000, 1000])
[-20.904562  63.666027   4.605306 -38.418148 -50.706852]
tensor([-20.9046,  63.6660,   4.6053, -38.4181, -50.7069])
error: 0.0 0.0 0.0
error: 0.0 0.0 0.0
tensor([-20.8954,  63.6586,   4.5888, -38.4053, -50.7082], device='cuda:0')
error: 0.0 0.0 0.0
error: 7389.58349609375 0.007389583624899387 0.005598879884928465

It shows that for 64-bit floats, the CUDA results are consistent with both the CPU tensors and the NumPy matrices. But for 32-bit floats, the CUDA tensors give quite different results, while the CPU tensors match the NumPy results.

I have another machine with a GTX 1080 Ti running Ubuntu. With the same code, the results are:

float64 (1000, 1000) torch.float64 torch.Size([1000, 1000]) torch.float64 torch.Size([1000, 1000])
[-20.90456457  63.66601763   4.60530742 -38.41815213 -50.70685021]
tensor([-20.9046,  63.6660,   4.6053, -38.4182, -50.7069], dtype=torch.float64)
error: 0.0 0.0 0.0
error: 0.0 0.0 0.0
tensor([-20.9046,  63.6660,   4.6053, -38.4182, -50.7069], device='cuda:0',
       dtype=torch.float64)
error: 0.0 0.0 0.0
error: 2.4674370506863396e-08 2.4674370506863395e-14 2.4656188542088475e-14

float32 (1000, 1000) torch.float32 torch.Size([1000, 1000]) torch.float32 torch.Size([1000, 1000])
[-20.904562  63.666027   4.605306 -38.418148 -50.706852]
tensor([-20.9046,  63.6660,   4.6053, -38.4181, -50.7069])
error: 0.0 0.0 0.0
error: 0.0 0.0 0.0
tensor([-20.9046,  63.6660,   4.6053, -38.4181, -50.7068], device='cuda:0')
error: 0.0 0.0 0.0
error: 10.073770523071289 1.0073770681628957e-05 9.016605872602668e-06

The difference between the CUDA and CPU results is much smaller this time. Both machines run PyTorch 1.10 with CUDA toolkit 11.3.

From the results, the difference comes from the matrix multiplication itself, not from copying tensors from RAM to the GPU. On the Windows machine the error is really high for 32-bit floats, so I don't think the results are very reliable anymore. I also tested matrix addition, but there was no error at all. Can anyone familiar with the underlying GPU implementation of PyTorch explain this? Thank you!
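
(Not part of the original script, but a minimal sketch, assuming the same 1000x1000 random matrix, of how one could measure which float32 result is actually closer to the true product: compare both float32 outputs against a float64 NumPy reference instead of comparing them to each other.)

import numpy as np
import torch

np.random.seed(0)
M64 = np.random.randn(1000, 1000)          # float64 "ground truth" input
M32 = M64.astype(np.float32)

ref = np.matmul(M64, M64)                   # float64 reference product

MM_np = np.matmul(M32, M32)                 # float32 product on CPU (NumPy)
M_gpu = torch.from_numpy(M32).cuda()
MM_gpu = torch.matmul(M_gpu, M_gpu).cpu().numpy()   # float32 product on GPU

for name, out in [('numpy float32', MM_np), ('cuda float32', MM_gpu)]:
    err = np.abs(out.astype(np.float64) - ref)
    print(name, 'mean abs error vs float64 reference:', err.mean())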

Hi Zvant!

Depending on your GPU, nvidia might be switching you over by default to the misleadingly (dishonestly?) named “tf32” floating-point arithmetic. (tf32 is essentially half-precision floating-point.)

You can try turning tf32 off with:

torch.backends.cuda.matmul.allow_tf32 = False
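
(A minimal sketch of where such switches would go; the cudnn flag for convolutions is an extra detail not mentioned above, included on the assumption that you may want to control it too.)

import torch

# Disable TF32 for matmuls on Ampere GPUs (defaults to True in PyTorch 1.10).
torch.backends.cuda.matmul.allow_tf32 = False

# Convolutions have a separate TF32 switch in cuDNN.
torch.backends.cudnn.allow_tf32 = False

a = torch.randn(1000, 1000, device='cuda')
print((a @ a)[0, :5])   # now computed in true float32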

See the following thread and the github issue @tom references in it:

Best.

K. Frank

I set the flag to False and the error level is now reasonable. I never knew PyTorch on Ampere GPUs had this behaviour. I guess in most cases tf32 should be enough, since most networks have all kinds of regularization anyway. Thank you!

While I have reservations about PyTorch enabling tf32 by default, note that calling tf32 “essentially half-precision floating-point” is a bit of an oversimplification. I’d probably describe it as “tf32 has the dynamic range of fp32 but the relative precision of fp16”. For many applications, the dynamic range of fp16 has been the troublesome part (half overflowing to +/-inf or underflowing to 0), which is why AMP does gradient scaling for backpropagation. In places where fp16 fails due to a lack of relative precision, tf32 won’t cut the mustard, either.
If tf32’s characteristics are acceptable, it does seem to offer a formidable speed-up.
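
(A quick numeric illustration of that range-vs-precision distinction, not from the original reply: torch.finfo reports the largest value and machine epsilon for the dtypes PyTorch exposes. There is no torch.tf32 dtype to query, so TF32's 10 mantissa bits, matching fp16's, are stated here as a known property of the format rather than read from the library.)

import torch

# Largest representable value (dynamic range) and machine epsilon (relative precision).
for dtype in (torch.float32, torch.float16, torch.bfloat16):
    info = torch.finfo(dtype)
    print(dtype, 'max =', info.max, 'eps =', info.eps)

# TF32 keeps fp32's 8 exponent bits (so its max is ~3.4e38, like float32),
# but only 10 mantissa bits, so its relative precision is 2**-10 ~ 9.8e-4,
# the same as float16's eps printed above.
print('tf32 effective eps ~', 2.0 ** -10)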

Best regards

Thomas