CUDA 11.7 causing large calculation discrepancy

Last year, I wrote a post about the SE module, in which I posted a piece of code that compares an SE module written with Linear against one written with Conv2d. When I revisited it recently, I noticed that the same piece of code now produces very different results depending on the CUDA version.
output with CUDA 11.6:

tensor(0.0006, device='cuda:0')
0.9538939730009588
1.5663262130001385

output with CUDA 11.7:

tensor(0.4165, device='cuda:0')
0.9701259670000582
1.7940559319995373

Notice that for CUDA 11.7, the difference between the Linear implementation and the Conv2d implementation increased to 0.4165, which I would not consider simply "floating-point error".

If I understand it correctly, Conv2d with a kernel size of 1 should in this case be equivalent to Linear, and in fact, if you search for open-source SE module implementations online, you can see Linear and Conv2d used interchangeably.
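For example, on a (N, C, 1, 1) tensor a 1x1 convolution computes the same thing as a Linear layer applied to the flattened (N, C) tensor once the weights are shared. A minimal check (the variable names are just for illustration):

import torch
import torch.nn as nn

x = torch.randn(8, 512, 1, 1)
conv = nn.Conv2d(512, 64, kernel_size=1)
lin = nn.Linear(512, 64)
lin.weight.data = conv.weight.data.squeeze()  # (64, 512, 1, 1) -> (64, 512)
lin.bias.data = conv.bias.data
# on CPU in FP32 the two outputs match to within normal rounding error
print(torch.allclose(conv(x).flatten(1), lin(x.flatten(1)), atol=1e-6))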

Both runs use the same PyTorch version (1.13.1), on the same machine and platform, under WSL. What may be the cause?

Was your PyTorch package also updated along with CUDA? I would check whether this is a consequence of TF32 being used by cuDNN but not by cuBLAS. For example, you can run your linear implementation with:

torch.set_float32_matmul_precision("high")

and see if it changes the difference or with

torch.backends.cudnn.allow_tf32 = False

and see if that changes the difference.
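If neither of those changes anything, you could also disable TF32 in both backends explicitly and compare; roughly something like this (just a sketch of the relevant switches):

import torch

# TF32 used by cuBLAS matmuls (nn.Linear)
torch.backends.cuda.matmul.allow_tf32 = False
# TF32 used by cuDNN convolutions (nn.Conv2d)
torch.backends.cudnn.allow_tf32 = False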

Thanks for the reply. Surprisingly, using torch.set_float32_matmul_precision("high") changes the result significantly!
output for CUDA 11.6:

tensor(0.6099, device='cuda:0')
1.2398140280001826
1.602404725001179

output for CUDA 11.7:

tensor(0.6175, device='cuda:0')
1.066491745001258
1.7795323549999011

Notice that now both versions have a large, unexpected discrepancy.

Also, I am sure both used the same PyTorch version, 1.13.1 (both installed via pip using the command from Start Locally | PyTorch).

edit: I tried torch.backends.cudnn.allow_tf32 = False, and now both have an error of 0.0007. So does that mean the problem comes from TF32? Why does TF32 have such a large error? (I mean, an error of ~0.5 after a sigmoid unit is really huge.)

Could you share some more details about how you are computing the error? Is it a sum of the absolute differences following the sigmoid? In that case it might be expected, as TF32 effectively has FP32 range but only FP16 precision.
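Note that summing over all 64 * 512 = 32768 output elements amplifies the picture: a total of 0.4165 corresponds to only ~1.3e-5 absolute error per element. A per-element metric would be more informative, e.g. something along these lines with the res1/res2 tensors from your script:

diff = (res1 - res2).abs()
print(diff.max())    # worst-case per-element absolute error
print(diff.mean())   # average per-element absolute error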

The code is the same as in the link I posted:

import torch
import torch.nn as nn
import time
from torch.backends import cudnn

cudnn.benchmark = True

@torch.no_grad()
def main():
    data = torch.randn(64, 512, 1, 1).float().to("cuda")
    shape = data.shape
    # SE excitation implemented with Linear layers
    m1 = nn.Sequential(
        nn.Linear(512, 64),
        nn.ReLU(True),
        nn.Linear(64, 512),
        nn.Sigmoid()
    ).float().to("cuda")
    # the same module implemented with 1x1 convolutions
    m2 = nn.Sequential(
        nn.Conv2d(512, 64, 1),
        nn.ReLU(True),
        nn.Conv2d(64, 512, 1),
        nn.Sigmoid()
    ).float().to("cuda")
    # copy the conv weights into the linear layers so both compute the same function
    m1[0].weight.data = m2[0].weight.data.squeeze()
    m1[0].bias.data = m2[0].bias.data
    m1[2].weight.data = m2[2].weight.data.squeeze()
    m1[2].bias.data = m2[2].bias.data
    res1 = data.squeeze()
    res1 = m1(res1)
    res1 = res1.reshape(shape)
    res2 = m2(data)
    # total absolute difference between the two implementations
    print(torch.abs(res1 - res2).sum())

    for _ in range(2):  # warmup
        data = data.squeeze()
        data = m1(data)
        data = data.reshape(shape)
    torch.cuda.synchronize()
    t = time.perf_counter()
    for _ in range(10000):
        data = data.squeeze()
        data = m1(data)
        data = data.reshape(shape)
    torch.cuda.synchronize()
    print(time.perf_counter() - t)

    for _ in range(2):  # warmup
        data = m2(data)
    torch.cuda.synchronize()
    t = time.perf_counter()
    for _ in range(10000):
        data = m2(data)
    torch.cuda.synchronize()
    print(time.perf_counter() - t)

if __name__ == "__main__":
    main()

Basically, it creates two SE (squeeze-and-excitation) modules, each consisting of Linear-ReLU-Linear-Sigmoid. However, m2 uses Conv2d with a kernel size of 1 instead of Linear, which should theoretically be equivalent. I assigned the same weight matrices to both modules and measured the total absolute error between their outputs.

So if TF32 has only FP16 precision, that is probably the reason. I did not expect it to cause such a large error.
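For what it's worth, the magnitude of the TF32 effect can be seen with a plain matmul by toggling the cuBLAS flag (just a sketch; TF32 only takes effect on an Ampere or newer GPU):

import torch

a = torch.randn(64, 512, device="cuda")
w = torch.randn(512, 512, device="cuda")

torch.backends.cuda.matmul.allow_tf32 = True
out_tf32 = a @ w

torch.backends.cuda.matmul.allow_tf32 = False
out_fp32 = a @ w

diff = (out_tf32 - out_fp32).abs()
# per-element error vs. the sum over all 64*512 elements
print(diff.max(), diff.mean(), diff.sum())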