Large difference in results between CPU and CUDA

Hi, I have an issue where I’m getting substantially different results from my NN model when running it on the CPU vs. CUDA, despite setting all seeds. I understand that small differences are expected, but these are quite large. The code is relatively simple and I pasted it below. Could someone help me understand whether there’s something I’m doing wrong that causes these differences between CPU and GPU? For example, I found this question where the author solves their problem by replacing x /= 255 with x = x / 255, but I don’t really understand the principle behind this, so it’s possible I’m missing something in the code that would be obvious to others. Thanks for any help.

By “large differences” I mean the printout of “error” in every epoch (not necessarily just the final one) in the code below.

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim, autograd
import time
from math import pi
import numpy as np
import sobol_seq as sobol
 
class Block(nn.Module):
 

    def __init__(self, in_N, width, out_N):
        super(Block, self).__init__()
        self.L1 = nn.Linear(in_N, width)
        self.L2 = nn.Linear(width, out_N)
        self.phi = nn.Tanh()

    def forward(self, x):
        return self.phi(self.L2(self.phi(self.L1(x)))) + x


class drrnn(nn.Module):
 
    def __init__(self, in_N, m, out_N, depth=4):
        super(drrnn, self).__init__()
        self.in_N = in_N
        self.m = m
        self.out_N = out_N
        self.depth = depth
        self.phi = torch.nn.Tanh()
        self.stack = torch.nn.ModuleList()
        self.stack.append(torch.nn.Linear(in_N,m))

        for i in range(depth):
            self.stack.append(Block(m,m, m))
        # last layer
        self.stack.append(torch.nn.Linear(m, out_N))

    def forward(self, x):
        # pass the input through the first layer, the residual blocks, and the output layer
        for i in range(len(self.stack)):
            x = self.stack[i](x)
        return x


def get_interior_points(N=128):
 
    x1 = sobol.i4_sobol_generate(2, N)  
    return torch.from_numpy(x1).float() 


def get_boundary_points(N=33):
    index = sobol.i4_sobol_generate(1, N)
    xb1 = np.concatenate((index, np.zeros_like(index)), 1)
    xb2 = np.concatenate((index, np.ones_like(index)), 1)
    xb4 = np.concatenate((np.zeros_like(index), index), 1)
    xb6 = np.concatenate((np.ones_like(index), index), 1)
    xb = torch.from_numpy(np.concatenate((xb1, xb2, xb4, xb6), 0)).float()

    return xb


def weights_init(m):
    if isinstance(m, (nn.Conv2d, nn.Linear)):
        nn.init.xavier_normal_(m.weight)
        nn.init.constant_(m.bias, 0.0)


def exact_sol(x):
    value = torch.where(x[:, 0: 1] > 0.5, (1 - x[:, 0: 1]) ** 2, x[:, 0: 1] ** 2)  
    return value


def function_l_exact(x):
    return x[:, 0: 1] * x[:, 1: 2] * (1 - x[:, 0: 1]) * (1 - x[:, 1: 2])


def function_f():
    return -2


def gradients(input, output):
    # d(output)/d(input); create_graph=True so these gradients can themselves be differentiated in the loss
    return autograd.grad(outputs=output, inputs=input,
                         grad_outputs=torch.ones_like(output),
                         create_graph=True, retain_graph=True, only_inputs=True)[0]

def error_l2(x, y):
 
    return torch.norm(x - y) / torch.norm(y)


def runmodel(epochs: int,
             lr1, lr2,
             gamma1, gamma2,
             step_size1, step_size2, N_interior, N_boundary):

    seed = 123
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

    in_N = 2
    m = 40
    out_N = 1

    print(torch.cuda.is_available())
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    soln_nn = drrnn(in_N, m, out_N).to(device) # solution
    soln_nn.apply(weights_init)
    test_function_nn = drrnn(in_N, m, out_N).to(device) # the test function
    test_function_nn.apply(weights_init)
    optimizer_solution = optim.Adam(soln_nn.parameters(), lr=lr1)
    optimizer_test_function = optim.Adam(test_function_nn.parameters(), lr=lr2)


    StepLR1 = torch.optim.lr_scheduler.StepLR(optimizer_solution, step_size=step_size1, gamma=gamma1)
    StepLR2 = torch.optim.lr_scheduler.StepLR(optimizer_test_function, step_size=step_size2, gamma=gamma2)
    best_loss, best_epoch = 1000, 0
    tt = time.time()
    f_value = function_f()
    xr = get_interior_points(N_interior)
    xb = get_boundary_points(N_boundary)
    xr = xr.to(device)
    xb = xb.to(device)
    for epoch in range(epochs+1):

        xr.requires_grad_()
        output_r = soln_nn(xr)
        output_b = soln_nn(xb)
        output_phi_r = test_function_nn(xr) * function_l_exact(xr)
        exact_b = exact_sol(xb)
        grads_u = gradients(xr, output_r)
        grads_phi = gradients(xr, output_phi_r)

        loss_r = torch.square(torch.mean(torch.sum(grads_u * grads_phi, dim=1) - f_value * output_phi_r)) / torch.mean(torch.square(output_phi_r))
        loss_b = 10 * torch.mean(torch.abs(output_b - exact_b))
        
        loss1 = loss_r + loss_b
        loss2 = - loss_r + torch.square(torch.mean(torch.square(output_phi_r)) - 1)

        if epoch % 3 == 2:
            optimizer_test_function.zero_grad()
            loss2.backward()
            optimizer_test_function.step()
            StepLR2.step()
        else:
            optimizer_solution.zero_grad()
            loss1.backward()
            optimizer_solution.step()
            StepLR1.step()

        if epoch % 100 == 0:
            err = error_l2(soln_nn(xr), exact_sol(xr))
            print('epoch:', epoch, 'loss1:', loss1.item(), 'loss2:', loss2.item(), 'error', err.item())
            tt = time.time()

    with torch.no_grad():
        N0 = 1000
        x1 = np.linspace(0, 1, N0 + 1)

        xs1, ys1 = np.meshgrid(x1, x1)
        # move the evaluation grid to the same device as the model before the forward pass
        Z1 = torch.from_numpy(np.concatenate((xs1.flatten()[:, None], ys1.flatten()[:, None]), 1)).float().to(device)
        pred = torch.reshape(soln_nn(Z1), [N0 + 1, N0 + 1]).cpu().numpy()
        exact = torch.reshape(exact_sol(Z1), [N0 + 1, N0 + 1]).cpu().numpy()

    err = np.sqrt(np.sum(np.square(exact - pred)) / np.sum(np.square(exact)))
    print("Error:", err)

def main():
    epochs = 2000

    N_interior = 1000
    N_boundary = 200


    lr1 = 1e-2
    lr2 = 1e-2
    gamma1= 0.5
    gamma2 =0.5
    step_size1 = 1000
    step_size2 = 1000
    runmodel(epochs, lr1, lr2, gamma1, gamma2, step_size1, step_size2, N_interior, N_boundary)


if __name__ == '__main__':
    main()

Hi Christopher!

It’s not clear to me how big these differences are nor after how much
computation they occur.

I would suggest printing out (to full precision) xr.sum(), output_r.sum(),
output_phi_r.sum(), grads_u.sum(), grads_phi.sum(), loss1, and
loss2 for both the CPU and GPU cases the first four times through your
for epoch in range(epochs+1) loop.
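
Something along these lines would do, as a sketch (it reuses the tensor names from your training loop):

torch.set_printoptions(precision=10)  # print tensors with enough digits to compare runs

# inside the training loop, only for the first few epochs
if epoch < 4:
    print('epoch', epoch)
    print(xr.sum(), output_r.sum(), output_phi_r.sum())
    print(grads_u.sum(), grads_phi.sum())
    print(loss1, loss2)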

Do these values really start out deviating by more than a few times
round-off error, or does the deviation only become “large” as you
perform more computation?

Also, let us know what version of PyTorch and what GPU you are using.

Best.

K. Frank

Hi KFrank

Thank you so much for the reply, and sorry for getting back to you so late. I did what you suggested and I already see big differences in the first epoch between CPU and GPU. Below I print, in this order: xr.sum(), output_r.sum(), output_phi_r.sum(), grads_u.sum(), grads_phi.sum(), loss1, loss2,
and I get:

On Cuda:

tensor(999.7226562500, device='cuda:0', grad_fn=<SumBackward0>)
tensor(-1414.1257324219, device='cuda:0', grad_fn=<SumBackward0>)
tensor(-8.1109323502, device='cuda:0', grad_fn=<SumBackward0>)
tensor(-2020.3367919922, device='cuda:0', grad_fn=<SumBackward0>)
tensor(-0.0266622305, device='cuda:0', grad_fn=<SumBackward0>)
tensor(14.0622110367, device='cuda:0', grad_fn=<AddBackward0>)
tensor(0.2760158777, device='cuda:0', grad_fn=<AddBackward0>)

On CPU:

tensor(999.7226562500, grad_fn=<SumBackward0>)
tensor(-594.3301391602, grad_fn=<SumBackward0>)
tensor(-24.9474697113, grad_fn=<SumBackward0>)
tensor(-940.2386474609, grad_fn=<SumBackward0>)
tensor(-0.1921596527, grad_fn=<SumBackward0>)
tensor(6.2187952995, grad_fn=<AddBackward0>)
tensor(0.6313534975, grad_fn=<AddBackward0>)

Should they really be so large? I don’t see what I’m doing wrong. Thanks for any help.

My torch version is ‘2.0.0+cu117’ and the GPU I’m using is an NVIDIA A100-SXM4 (40 GB).

I also tried adding

torch.use_deterministic_algorithms(True)

and

os.environ["CUBLAS_WORKSPACE_CONFIG"]=":4096:8"

but there is a similar difference between CPU and GPU.
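
For completeness, a sketch of how these settings fit together near the top of the script (the environment variable is set before any CUDA work is done):

import os
import torch

# select a deterministic cuBLAS workspace; set this before the first CUDA/cuBLAS call
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"

# raise an error if a nondeterministic op is used, instead of running it silently
torch.use_deterministic_algorithms(True)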

Hi Christopher!

Could you print some of the values during the first forward pass before
any optimization steps? The reason is that if there are small numerical
differences, as you train, the two training paths can diverge (generally
giving statistically equivalent results that nonetheless might be numerically
quite different).

Also, in addition to printing out, for example, output_r.sum(), print out
some specific individual element of output_r.

Maybe: if, as noted above, small numerical differences have been
“amplified” by the training paths diverging over the course of many
optimization steps, then differences of this size are plausible.

But these differences are significant; they are not just (unamplified)
numerical round-off error.

Perhaps your issue is being exacerbated (or caused) by the dreaded
TF32 floating-point format. Pytorch chooses (unwisely) by default
to silently degrade the precision of what you thought was standard float32
floating-point arithmetic.

Quoting from the above link:

# The flag below controls whether to allow TF32 on matmul. This flag defaults to False
# in PyTorch 1.12 and later.
torch.backends.cuda.matmul.allow_tf32 = True


# The flag below controls whether to allow TF32 on cuDNN. This flag defaults to True.
torch.backends.cudnn.allow_tf32 = True

Consider setting torch.backends.cudnn.allow_tf32 = False and
rerunning your tests.
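
As a sketch, both TF32 switches can be forced off near the top of your script, before any GPU work is done:

import torch

# use full float32 precision on the GPU for matmuls and cuDNN ops
# (turns off the TF32 fast path discussed above)
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False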

Best.

K. Frank

Hi K Frank,
Thanks again. I now printed the values before any optimisation steps (and I get the same results with or without the TF32 setting):

tensor(999.7226562500, device='cuda:0', grad_fn=<SumBackward0>)
tensor(-192.5105590820, device='cuda:0', grad_fn=<SumBackward0>)
tensor(15.4092960358, device='cuda:0', grad_fn=<SumBackward0>)
tensor(74.8240737915, device='cuda:0', grad_fn=<SumBackward0>)
tensor(0.3192231655, device='cuda:0', grad_fn=<SumBackward0>)
tensor(4.1089110374, device='cuda:0', grad_fn=<AddBackward0>)
tensor(-0.9975129962, device='cuda:0', grad_fn=<AddBackward0>)

and on CPU:

tensor(999.7226562500, grad_fn=<SumBackward0>)
tensor(-229.0813446045, grad_fn=<SumBackward0>)
tensor(15.2660779953, grad_fn=<SumBackward0>)
tensor(-599.9611816406, grad_fn=<SumBackward0>)
tensor(0.1353473663, grad_fn=<SumBackward0>)
tensor(6.5842385292, grad_fn=<AddBackward0>)
tensor(-2.6843986511, grad_fn=<AddBackward0>)

So there is still a large difference.

I was thinking: the initialisation of the neural networks is done with Xavier initialisation (as in the code), which draws random numbers from a distribution. I read that these random numbers are different on CPU and GPU. Do you think this could be the reason for such large differences? In any case, I’m surprised that this topic doesn’t seem to have been discussed much by anyone.
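
One way to check this hypothesis would be (a sketch, not part of the script above) to draw the same Xavier initialisation on both devices with the same seed and compare:

import torch
import torch.nn as nn

# same seed and shape, but the CPU and CUDA generators are different PRNG
# implementations, so the sampled values are not expected to match
torch.manual_seed(123)
w_cpu = nn.init.xavier_normal_(torch.empty(40, 2))

torch.manual_seed(123)
torch.cuda.manual_seed_all(123)
w_gpu = nn.init.xavier_normal_(torch.empty(40, 2, device='cuda'))

print(torch.allclose(w_cpu, w_gpu.cpu()))  # typically False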

Your code is unfortunately not executable for me, so I cannot reproduce the values.
However, you are right that there is no guarantee that the same random numbers are sampled on different devices. Could you rerun the code on the CPU with two different seeds and compare the differences?

Hi ptrblck,

Thanks for your reply. With two different seeds (both runs on the CPU), there are still large differences:

First seed:

tensor(999.7226562500, grad_fn=<SumBackward0>)
tensor(-396.2010498047, grad_fn=<SumBackward0>)
tensor(12.9433746338, grad_fn=<SumBackward0>)
tensor(-776.6835937500, grad_fn=<SumBackward0>)
tensor(0.1150979996, grad_fn=<SumBackward0>)
tensor(8.2682046890, grad_fn=<AddBackward0>)
tensor(-2.6870155334, grad_fn=<AddBackward0>)

Second seed:

tensor(999.7226562500, grad_fn=<SumBackward0>)
tensor(368.0512084961, grad_fn=<SumBackward0>)
tensor(-13.6386232376, grad_fn=<SumBackward0>)
tensor(707.1499023438, grad_fn=<SumBackward0>)
tensor(-0.1424732208, grad_fn=<SumBackward0>)
tensor(6.1664748192, grad_fn=<AddBackward0>)
tensor(-1.9697301388, grad_fn=<AddBackward0>)

I don’t know why the code is not executable; I tried it on a different machine and it works there.

Thanks for running this test. In this case it seems your code might just be sensitive to the random seed, and, assuming you are sampling on the CPU vs. the GPU, the effect might be comparable to using different seeds due to the difference in the PRNG implementations.
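
If you need both runs to start from literally the same weights, one workaround (a sketch based on your script) is to create and initialize the model on the CPU first and move it to the GPU only afterwards, so both runs draw the initialization from the same CPU generator:

torch.manual_seed(123)
soln_nn = drrnn(in_N, m, out_N)   # build and keep on the CPU for initialization
soln_nn.apply(weights_init)       # Xavier init sampled from the CPU generator
soln_nn = soln_nn.to(device)      # copy the already-initialized weights to the GPU

The forward and backward passes will still differ by floating-point round-off between devices, but the starting point will be identical.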