Something wrong with PyTorch 1.7.1 and CUDA 11.0

I am getting weird results when I use PyTorch 1.7.1 with CUDA 11.0 on a P100. Steps to reproduce:

import torch
import torch.nn as nn
torch.manual_seed(42)
# channel widths for a stack of 3x3 convolutions ending in a 1-channel output
l = [3, 32, 64, 128, 256, 512, 256, 128, 64, 32, 1]
model = nn.Sequential(
    *[nn.Conv2d(l[i], l[i+1], 3, padding=1) for i in range(len(l)-1)],
    nn.AdaptiveAvgPool2d((1, 1))
    ).to('cuda:0')
torch.manual_seed(42)
a = torch.randn(16, 3, 224, 224)
b = model(a.to('cuda:0')).squeeze()
b  # inspect the output

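Side note: the runs below all use the default cuDNN flags. To rule out run-to-run algorithm selection, the flags could be pinned before the forward pass; a minimal sketch, which I have not verified on the affected setup:

import torch
# force deterministic cuDNN algorithms so autotuning cannot pick
# different kernels between runs or environments
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
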
When I run this on 1.7.1 with CUDA 11.0, I get

>>> import torch
>>> import torch.nn as nn
>>> torch.manual_seed(42)
<torch._C.Generator object at 0x7f56c3859290>
>>> l = [3, 32, 64, 128, 256, 512, 256, 128, 64, 32, 1]
>>> model = nn.Sequential(
...     *[nn.Conv2d(l[i], l[i+1], 3, padding=1) for i in range(len(l)-1)],
...     nn.AdaptiveAvgPool2d((1, 1))
...     ).to('cuda:0')
>>> torch.manual_seed(42)
<torch._C.Generator object at 0x7f56c3859290>
>>> a = torch.randn(16, 3, 224, 224)
>>> b = model(a.to('cuda:0')).squeeze()
>>> torch.__version__
'1.7.1'
>>> torch.version.cuda
'11.0'
>>> b
tensor([0.0105, 0.0105, 0.0105, 0.0105, 0.0105, 0.0105, 0.0105, 0.0105, 0.0105,
        0.0105, 0.0105, 0.0105, 0.0105, 0.0105, 0.0105, 0.0105],
       device='cuda:0', grad_fn=<SqueezeBackward0>)
>>>

But when I run it on 1.7.0 with CUDA 11.0, I get

>>> import torch
>>> import torch.nn as nn
>>> torch.manual_seed(42)
<torch._C.Generator object at 0x7f88ae14b050>
>>> l = [3, 32, 64, 128, 256, 512, 256, 128, 64, 32, 1]
>>> model = nn.Sequential(
...     *[nn.Conv2d(l[i], l[i+1], 3, padding=1) for i in range(len(l)-1)],
...     nn.AdaptiveAvgPool2d((1, 1))
...     ).to('cuda:0')
>>> torch.manual_seed(42)
<torch._C.Generator object at 0x7f88ae14b050>
>>> a = torch.randn(16, 3, 224, 224)
>>> b = model(a.to('cuda:0')).squeeze()
>>> torch.__version__
'1.7.0'
>>> torch.version.cuda
'11.0'
>>> b
tensor([0.0109, 0.0109, 0.0109, 0.0109, 0.0109, 0.0109, 0.0109, 0.0109, 0.0109,
        0.0109, 0.0109, 0.0109, 0.0109, 0.0108, 0.0108, 0.0109],
       device='cuda:0', grad_fn=<SqueezeBackward0>)

Then I tried PyTorch 1.7.1 with CUDA 10.2 and got the same result as 1.7.0:

>>> import torch
>>> import torch.nn as nn
>>> torch.manual_seed(42)
<torch._C.Generator object at 0x7f87f1b50270>
>>> l = [3, 32, 64, 128, 256, 512, 256, 128, 64, 32, 1]
>>> model = nn.Sequential(
...     *[nn.Conv2d(l[i], l[i+1], 3, padding=1) for i in range(len(l)-1)],
...     nn.AdaptiveAvgPool2d((1, 1))
...     ).to('cuda:0')
>>> torch.manual_seed(42)
<torch._C.Generator object at 0x7f87f1b50270>
>>> a = torch.randn(16, 3, 224, 224)
>>> b = model(a.to('cuda:0')).squeeze()
>>> torch.__version__
'1.7.1'
>>> torch.version.cuda
'10.2'
>>> b
tensor([0.0109, 0.0109, 0.0109, 0.0109, 0.0109, 0.0109, 0.0109, 0.0109, 0.0109,
        0.0109, 0.0109, 0.0109, 0.0109, 0.0108, 0.0108, 0.0109],
       device='cuda:0', grad_fn=<SqueezeBackward0>)
>>>

Then I tried on the CPU…

>>> import torch
>>> import torch.nn as nn
>>> torch.manual_seed(42)
<torch._C.Generator object at 0x7f56c3859290>
>>> l = [3, 32, 64, 128, 256, 512, 256, 128, 64, 32, 1]
>>> model = nn.Sequential(
...     *[nn.Conv2d(l[i], l[i+1], 3, padding=1) for i in range(len(l)-1)],
...     nn.AdaptiveAvgPool2d((1, 1))
...     )
>>> torch.manual_seed(42)
<torch._C.Generator object at 0x7f56c3859290>
>>> a = torch.randn(16, 3, 224, 224)
>>> b = model(a).squeeze()
>>> torch.__version__
'1.7.1'
>>> torch.version.cuda
'11.0'
>>> b
tensor([0.0109, 0.0109, 0.0109, 0.0109, 0.0109, 0.0109, 0.0109, 0.0109, 0.0109,
        0.0109, 0.0109, 0.0109, 0.0109, 0.0108, 0.0108, 0.0109],
       grad_fn=<SqueezeBackward0>)

As we can see, only 1.7.1 with CUDA 11.0 produces a different result. This is really weird, and as far as I have seen it only occurs with a high number of filters and a large batch size (I suspect it has something to do with how much memory is being used, but I am not sure).
I encountered this while training my regular UNet models on 1.7.1: training with CUDA 10.2 was fine, but after I moved to CUDA 11.0 my loss suddenly jumped to NaN very early in training. It happened with multiple models and different datasets! When I switched back to training on 10.2, it did not occur.
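
To check whether cuDNN is involved at all, a minimal sketch (which I have not run on the affected machine) would be to compare the same forward pass with cuDNN enabled and disabled in the CUDA 11.0 environment:

import torch
import torch.nn as nn

torch.manual_seed(42)
l = [3, 32, 64, 128, 256, 512, 256, 128, 64, 32, 1]
model = nn.Sequential(
    *[nn.Conv2d(l[i], l[i+1], 3, padding=1) for i in range(len(l)-1)],
    nn.AdaptiveAvgPool2d((1, 1))
    ).to('cuda:0')
torch.manual_seed(42)
a = torch.randn(16, 3, 224, 224).to('cuda:0')

with torch.no_grad():
    b_cudnn = model(a).squeeze()           # cuDNN path (default)
    torch.backends.cudnn.enabled = False   # fall back to the native CUDA conv kernels
    b_native = model(a).squeeze()
print((b_cudnn - b_native).abs().max())    # a large gap would point at a cuDNN kernel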

@ptrblck, could you please look into this?

I cannot reproduce this issue on a recent source build using CUDA 11.2 and an internal cudnn version, nor with the PyTorch 1.7.1 conda binaries with CUDA 11.0; I always get:

tensor([0.0109, 0.0109, 0.0109, 0.0109, 0.0109, 0.0109, 0.0109, 0.0109, 0.0109,
        0.0109, 0.0109, 0.0109, 0.0109, 0.0108, 0.0108, 0.0109],

on a node with P100-SXM2 16GB GPUs.

EDIT: could you store the state_dict as well as the input from the CUDA 10.2 run, load them in the CUDA 11.0 environment, and check the outputs again?
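
It could also help to compare the exact library versions in both environments; a quick sketch:

import torch
# run in both the CUDA 10.2 and CUDA 11.0 environments and compare
print(torch.__version__)                # PyTorch build
print(torch.version.cuda)               # CUDA version the binary ships with
print(torch.backends.cudnn.version())   # bundled cudnn version
print(torch.cuda.get_device_name(0))    # GPU model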

Hi, thank you so much for checking this. I tried it on two instances and it happened on both, so I assumed it would be reproducible. I saved the model and 'a' under CUDA 10.2 using

>>> import torch
>>> import torch.nn as nn
>>> torch.__version__
'1.7.1'
>>> torch.version.cuda
'10.2'
>>> torch.manual_seed(42)
<torch._C.Generator object at 0x7f02f91602b0>
>>> l = [3, 32, 64, 128, 256, 512, 256, 128, 64, 32, 1]
>>> model = nn.Sequential(
...     *[nn.Conv2d(l[i], l[i+1], 3, padding=1) for i in range(len(l)-1)],
...     nn.AdaptiveAvgPool2d((1, 1))
...     ).to('cuda:0')
>>> torch.manual_seed(42)
<torch._C.Generator object at 0x7f02f91602b0>
>>> a = torch.randn(16, 3, 224, 224)
>>> b = model(a.to('cuda:0')).squeeze()
>>> b
tensor([0.0109, 0.0109, 0.0109, 0.0109, 0.0109, 0.0109, 0.0109, 0.0109, 0.0109,
        0.0109, 0.0109, 0.0109, 0.0109, 0.0108, 0.0108, 0.0109],
       device='cuda:0', grad_fn=<SqueezeBackward0>)
>>> torch.save({'a': a}, 'a_10_2.pth')
>>> torch.save(model.state_dict(), 'model_10_2.pth')

and then loaded them in the CUDA 11.0 environment as

>>> import torch
>>> import torch.nn as nn
>>> torch.__version__
'1.7.1'
>>> torch.version.cuda
'11.0'
>>> l = [3, 32, 64, 128, 256, 512, 256, 128, 64, 32, 1]
>>> model = nn.Sequential(
...     *[nn.Conv2d(l[i], l[i+1], 3, padding=1) for i in range(len(l)-1)],
...     nn.AdaptiveAvgPool2d((1, 1))
...     ).to('cuda:0')
>>> model.load_state_dict(torch.load('model_10_2.pth'))
<All keys matched successfully>
>>> a = torch.load('a_10_2.pth')['a']
>>> b = model(a.to('cuda:0')).squeeze()
>>> b
tensor([0.0105, 0.0105, 0.0105, 0.0105, 0.0105, 0.0105, 0.0105, 0.0105, 0.0105,
        0.0105, 0.0105, 0.0105, 0.0105, 0.0105, 0.0105, 0.0105],
       device='cuda:0', grad_fn=<SqueezeBackward0>)

Note that this does not happen when the batch size is small (although the original problem occurred while I was using a batch size of 2 for segmentation, so I think it has more to do with how much memory is being used than with the batch size itself). A small sweep to narrow this down is sketched at the end of this post.
On 10.2:

>>> import torch
>>> import torch.nn as nn
>>> torch.__version__
'1.7.1'
>>> torch.version.cuda
'10.2'
>>> torch.manual_seed(42)
<torch._C.Generator object at 0x7f32829412b0>
>>> l = [3, 32, 64, 128, 256, 512, 256, 128, 64, 32, 1]
>>> model = nn.Sequential(
...     *[nn.Conv2d(l[i], l[i+1], 3, padding=1) for i in range(len(l)-1)],
...     nn.AdaptiveAvgPool2d((1, 1))
...     ).to('cuda:0')
>>> torch.manual_seed(42)
<torch._C.Generator object at 0x7f32829412b0>
>>> a = torch.randn(4, 3, 224, 224)
>>> b = model(a.to('cuda:0')).squeeze()
>>> b
tensor([0.0109, 0.0109, 0.0109, 0.0109], device='cuda:0',
       grad_fn=<SqueezeBackward0>)

On 11.0:

>>> import torch
>>> import torch.nn as nn
>>> torch.__version__
'1.7.1'
>>> torch.version.cuda
'11.0'
>>> torch.manual_seed(42)
<torch._C.Generator object at 0x7f1da0aaa290>
>>> l = [3, 32, 64, 128, 256, 512, 256, 128, 64, 32, 1]
>>> model = nn.Sequential(
...     *[nn.Conv2d(l[i], l[i+1], 3, padding=1) for i in range(len(l)-1)],
...     nn.AdaptiveAvgPool2d((1, 1))
...     ).to('cuda:0')
>>> torch.manual_seed(42)
<torch._C.Generator object at 0x7f1da0aaa290>
>>> a = torch.randn(4, 3, 224, 224)
>>> b = model(a.to('cuda:0')).squeeze()
>>> b
tensor([0.0109, 0.0109, 0.0109, 0.0109], device='cuda:0',
       grad_fn=<SqueezeBackward0>)
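
To narrow down where the divergence starts, a small batch-size sweep could be run in both environments and the saved outputs diffed offline. A minimal sketch (the list of batch sizes is just a guess):

import torch
import torch.nn as nn

torch.manual_seed(42)
l = [3, 32, 64, 128, 256, 512, 256, 128, 64, 32, 1]
model = nn.Sequential(
    *[nn.Conv2d(l[i], l[i+1], 3, padding=1) for i in range(len(l)-1)],
    nn.AdaptiveAvgPool2d((1, 1))
    ).to('cuda:0')

torch.manual_seed(42)
with torch.no_grad():
    for bs in [2, 4, 8, 12, 16]:
        a = torch.randn(bs, 3, 224, 224).to('cuda:0')  # CPU RNG, so identical across builds
        b = model(a).squeeze()
        tag = torch.version.cuda.replace('.', '_')
        torch.save(b.cpu(), f'out_bs{bs}_cuda{tag}.pth')  # diff these files between runs
        print(bs, b.mean().item())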