Vastly different results for float, yet only on CPU

I’m not sure if this is a bug or if I’m being dumb. Running this code:

## Code example

import numpy as np
import torchvision
import torch
import time

seed = int(time.perf_counter())
print('seed:',seed)
torch.manual_seed(seed)
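# force deterministic cuDNN algorithms so repeated GPU runs are reproducible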
torch.backends.cudnn.deterministic=True
#Define Model
class MNISTAutoencoder(torch.nn.Module):
    def __init__(self):
        super(MNISTAutoencoder, self).__init__()
        self.encoder = torch.nn.Sequential(torch.nn.Linear(28*28,512,bias=True),
                                      torch.nn.Softplus(beta=1),
                                      torch.nn.Linear(512,256,bias=True),
                                      torch.nn.Softplus(beta=1),
                                      torch.nn.Linear(256,128,bias=True),
                                      torch.nn.Softplus(beta=1),
                                      torch.nn.Linear(128,32,bias=True),
                                      torch.nn.Softplus(beta=1),)
        self.decoder = torch.nn.Sequential(torch.nn.Linear(32,128,bias=True),
                                      torch.nn.Softplus(beta=1),
                                      torch.nn.Linear(128,256,bias=True),
                                      torch.nn.Softplus(beta=1),
                                      torch.nn.Linear(256,512,bias=True),
                                      torch.nn.Softplus(beta=1),
                                      torch.nn.Linear(512,28*28,bias=True),
                                      torch.nn.Sigmoid(),)
    def forward(self,x):
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        return decoded


net = MNISTAutoencoder()
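# MSELoss averages the squared error over all elements by default (size_average=True in 0.3)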
loss = torch.nn.MSELoss()

#Load Data
trainset = torchvision.datasets.MNIST("./data/",download=True,train=True)
X = [torchvision.transforms.ToTensor()(s[0]) for s in trainset]
X = torch.cat(X).view(-1,28*28) 
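# X is a 60000 x 784 float tensor; ToTensor scales the MNIST pixels to [0, 1]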
nTrials = 11

#Get Parameters to make sure same is used across Trials
pars = [p.data.clone() for p in net.parameters()]

#1: CUDA Float
if torch.cuda.is_available():
    for i in range(nTrials):
        net_cuda = net.cuda()
        for (cuda_p,p) in zip(net_cuda.parameters(),pars):
            cuda_p.data = p.clone().cuda()
        X_cuda = X.cuda()
        X_cuda_var = torch.autograd.Variable(X_cuda,requires_grad=False)
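        # note: CUDA launches are asynchronous, so stop - start below may not
        # capture full kernel execution; calling torch.cuda.synchronize()
        # before reading the clock would give exact timings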
        start = time.perf_counter() 
        res = loss(net_cuda(X_cuda_var), X_cuda_var)
        stop = time.perf_counter() 
        print('CUDA Float: {:.32f}'.format(np.float64(res)),type(net_cuda.parameters().__next__().data),type(X_cuda),'took',stop-start,'seconds')

#2: CPU Float

for i in range(nTrials):
    net_float = net.float()
    for (float_p,p) in zip(net_float.parameters(),pars):
        float_p.data = p.clone().float()
    X_float = X.float()
    X_float_var = torch.autograd.Variable(X_float,requires_grad=False)
    start = time.perf_counter()
    res = loss(net_float(X_float_var), X_float_var)
    stop = time.perf_counter()
    print('CPU Float: {:.32f}'.format(np.float64(res)),type(net_float.parameters().__next__().data),type(X_float),'took',stop-start,'seconds')

#3: CUDA Double

if torch.cuda.is_available():
    for i in range(nTrials):
        net_cuda_double = net.double().cuda()
        for (cuda_double_p,p) in zip(net_cuda_double.parameters(),pars):
            cuda_double_p.data = p.clone().double().cuda()
        X_cuda_double = X.double().cuda()
        X_cuda_double_var = torch.autograd.Variable(X_cuda_double,requires_grad=False)
        start = time.perf_counter()
        res = loss(net_cuda_double(X_cuda_double_var), X_cuda_double_var)
        stop = time.perf_counter()
        print('CUDA Double: {:.32f}'.format(np.float64(res)),type(net_cuda_double.parameters().__next__().data),type(X_cuda_double),'took',stop-start,'seconds')

#4: CPU Double

for i in range(nTrials):
    net_double = net.double()
    for (double_p,p) in zip(net_double.parameters(),pars):
        double_p.data = p.clone().double()
    X_double = X.double()
    X_double_var = torch.autograd.Variable(X_double,requires_grad=False)
    start = time.perf_counter()
    res = loss(net_double(X_double_var), X_double_var)
    stop = time.perf_counter()
    print('CPU Double: {:.32f}'.format(np.float64(res)),type(net_double.parameters().__next__().data),type(X_double),'took',stop-start,'seconds')

Gives a ~24% lower loss for 32-bit floats, but ONLY when run on the CPU. (This is consistent across different seeds.)
The runtimes suggest that the intended precision is used on both GPU and CPU, yet the difference between the CPU-float result and the other runs seems far too large to be explained by lower precision alone. Does anybody see what’s going on? Thank you!

seed: 7197644
CUDA Float: 0.24038340151309967041015625000000 <class 'torch.cuda.FloatTensor'> <class 'torch.cuda.FloatTensor'> took 0.27543109748512506 seconds
CUDA Float: 0.24038340151309967041015625000000 <class 'torch.cuda.FloatTensor'> <class 'torch.cuda.FloatTensor'> took 0.031048773787915707 seconds
CUDA Float: 0.24038340151309967041015625000000 <class 'torch.cuda.FloatTensor'> <class 'torch.cuda.FloatTensor'> took 0.023138870485126972 seconds
CUDA Float: 0.24038340151309967041015625000000 <class 'torch.cuda.FloatTensor'> <class 'torch.cuda.FloatTensor'> took 0.02370053343474865 seconds
CUDA Float: 0.24038340151309967041015625000000 <class 'torch.cuda.FloatTensor'> <class 'torch.cuda.FloatTensor'> took 0.022951023653149605 seconds
CUDA Float: 0.24038340151309967041015625000000 <class 'torch.cuda.FloatTensor'> <class 'torch.cuda.FloatTensor'> took 0.022918391972780228 seconds
CUDA Float: 0.24038340151309967041015625000000 <class 'torch.cuda.FloatTensor'> <class 'torch.cuda.FloatTensor'> took 0.02192889992147684 seconds
CUDA Float: 0.24038340151309967041015625000000 <class 'torch.cuda.FloatTensor'> <class 'torch.cuda.FloatTensor'> took 0.021906440146267414 seconds
CUDA Float: 0.24038340151309967041015625000000 <class 'torch.cuda.FloatTensor'> <class 'torch.cuda.FloatTensor'> took 0.02196166105568409 seconds
CUDA Float: 0.24038340151309967041015625000000 <class 'torch.cuda.FloatTensor'> <class 'torch.cuda.FloatTensor'> took 0.021798920817673206 seconds
CUDA Float: 0.24038340151309967041015625000000 <class 'torch.cuda.FloatTensor'> <class 'torch.cuda.FloatTensor'> took 0.02138354256749153 seconds
CPU Float: 0.18305869400501251220703125000000 <class 'torch.FloatTensor'> <class 'torch.FloatTensor'> took 8.28088248707354 seconds
CPU Float: 0.18305869400501251220703125000000 <class 'torch.FloatTensor'> <class 'torch.FloatTensor'> took 8.273789366707206 seconds
CPU Float: 0.18305869400501251220703125000000 <class 'torch.FloatTensor'> <class 'torch.FloatTensor'> took 8.301917650736868 seconds
CPU Float: 0.18305869400501251220703125000000 <class 'torch.FloatTensor'> <class 'torch.FloatTensor'> took 8.366850168444216 seconds
CPU Float: 0.18305869400501251220703125000000 <class 'torch.FloatTensor'> <class 'torch.FloatTensor'> took 8.251176925376058 seconds
CPU Float: 0.18305869400501251220703125000000 <class 'torch.FloatTensor'> <class 'torch.FloatTensor'> took 8.254637469537556 seconds
CPU Float: 0.18305869400501251220703125000000 <class 'torch.FloatTensor'> <class 'torch.FloatTensor'> took 8.248044191859663 seconds
CPU Float: 0.18305869400501251220703125000000 <class 'torch.FloatTensor'> <class 'torch.FloatTensor'> took 8.26183510106057 seconds
CPU Float: 0.18305869400501251220703125000000 <class 'torch.FloatTensor'> <class 'torch.FloatTensor'> took 8.272330325096846 seconds
CPU Float: 0.18305869400501251220703125000000 <class 'torch.FloatTensor'> <class 'torch.FloatTensor'> took 8.187176392413676 seconds
CPU Float: 0.18305869400501251220703125000000 <class 'torch.FloatTensor'> <class 'torch.FloatTensor'> took 8.140838886611164 seconds
CUDA Double: 0.24038340743650432607125821959926 <class 'torch.cuda.DoubleTensor'> <class 'torch.cuda.DoubleTensor'> took 0.5673027914017439 seconds
CUDA Double: 0.24038340743650432607125821959926 <class 'torch.cuda.DoubleTensor'> <class 'torch.cuda.DoubleTensor'> took 0.41811269149184227 seconds
CUDA Double: 0.24038340743650432607125821959926 <class 'torch.cuda.DoubleTensor'> <class 'torch.cuda.DoubleTensor'> took 0.41235586535185575 seconds
CUDA Double: 0.24038340743650432607125821959926 <class 'torch.cuda.DoubleTensor'> <class 'torch.cuda.DoubleTensor'> took 0.4123460128903389 seconds
CUDA Double: 0.24038340743650432607125821959926 <class 'torch.cuda.DoubleTensor'> <class 'torch.cuda.DoubleTensor'> took 0.41235347278416157 seconds
CUDA Double: 0.24038340743650432607125821959926 <class 'torch.cuda.DoubleTensor'> <class 'torch.cuda.DoubleTensor'> took 0.41240311600267887 seconds
CUDA Double: 0.24038340743650432607125821959926 <class 'torch.cuda.DoubleTensor'> <class 'torch.cuda.DoubleTensor'> took 0.4124773908406496 seconds
CUDA Double: 0.24038340743650432607125821959926 <class 'torch.cuda.DoubleTensor'> <class 'torch.cuda.DoubleTensor'> took 0.41238985676318407 seconds
CUDA Double: 0.24038340743650432607125821959926 <class 'torch.cuda.DoubleTensor'> <class 'torch.cuda.DoubleTensor'> took 0.41240973211824894 seconds
CUDA Double: 0.24038340743650432607125821959926 <class 'torch.cuda.DoubleTensor'> <class 'torch.cuda.DoubleTensor'> took 0.4142442727461457 seconds
CUDA Double: 0.24038340743650432607125821959926 <class 'torch.cuda.DoubleTensor'> <class 'torch.cuda.DoubleTensor'> took 0.41422911919653416 seconds
CPU Double: 0.24038340743650035702394518466463 <class 'torch.DoubleTensor'> <class 'torch.DoubleTensor'> took 10.737960833124816 seconds
CPU Double: 0.24038340743650035702394518466463 <class 'torch.DoubleTensor'> <class 'torch.DoubleTensor'> took 10.918829107657075 seconds
CPU Double: 0.24038340743650035702394518466463 <class 'torch.DoubleTensor'> <class 'torch.DoubleTensor'> took 10.934827781282365 seconds
CPU Double: 0.24038340743650035702394518466463 <class 'torch.DoubleTensor'> <class 'torch.DoubleTensor'> took 10.825709821656346 seconds
CPU Double: 0.24038340743650035702394518466463 <class 'torch.DoubleTensor'> <class 'torch.DoubleTensor'> took 10.902172627858818 seconds
CPU Double: 0.24038340743650035702394518466463 <class 'torch.DoubleTensor'> <class 'torch.DoubleTensor'> took 10.822428305633366 seconds
CPU Double: 0.24038340743650035702394518466463 <class 'torch.DoubleTensor'> <class 'torch.DoubleTensor'> took 10.829005297273397 seconds
CPU Double: 0.24038340743650035702394518466463 <class 'torch.DoubleTensor'> <class 'torch.DoubleTensor'> took 10.847982819192111 seconds
CPU Double: 0.24038340743650035702394518466463 <class 'torch.DoubleTensor'> <class 'torch.DoubleTensor'> took 10.84829668328166 seconds
CPU Double: 0.24038340743650035702394518466463 <class 'torch.DoubleTensor'> <class 'torch.DoubleTensor'> took 10.882735094055533 seconds
CPU Double: 0.24038340743650035702394518466463 <class 'torch.DoubleTensor'> <class 'torch.DoubleTensor'> took 10.890244227834046 seconds

## System Info

PyTorch version: 0.3.0.post4
Is debug build: No
CUDA used to build PyTorch: 8.0.61

OS: CentOS Linux release 7.4.1708 (Core)
GCC version: (GCC) 4.8.5 20150623 (Red Hat 4.8.5-16)
CMake version: version 2.8.12.2

Python version: 3.6
Is CUDA available: Yes
CUDA runtime version: 9.0.176
GPU models and configuration: GPU 0: GeForce GTX 1080 Ti
Nvidia driver version: 387.26
cuDNN version: Could not collect

Versions of relevant libraries:
[pip3] numpy (1.14.2)
[pip3] torch (0.3.0.post4)
[pip3] torchvision (0.2.0)
[conda] Could not collect

I tried to debug your code, and I think this problem is related to this issue (I have only tested the CPU side so far).
You’ll get the same results if you calculate the MSELoss manually:

torch.mean((output32 - X_float_var)**2)
> tensor(0.2477)
torch.mean((output64 - X_double_var)**2)
> tensor(0.2477, dtype=torch.float64)

loss(output32, X_float_var)
> tensor(0.1844)
loss(output64, X_double_var)
> tensor(0.2477, dtype=torch.float64)
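
If you need a workaround until this is fixed: a minimal sketch, assuming the culprit is a single-precision accumulation inside MSELoss’s CPU reduction (the helper name mse_stable is mine, purely for illustration), is to do the reduction yourself and upcast to double first:

import torch

def mse_stable(output, target):
    # reduce in float64: over ~47M elements (60000 x 784), a naive
    # float32 accumulation can lose enough precision to shift the mean
    diff = output.double() - target.double()
    return (diff * diff).mean()

# usage with the names from the original post:
# res = mse_stable(net_float(X_float_var), X_float_var)

Since the plain float32 torch.mean above already matches the double result, the upcast is mostly belt and braces, but it keeps the reduction safe for even larger tensors.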

That solves it, thank you!