Bizarre CUDA behavior with nn.Linear

I am encountering some particularly strange behavior with a very simple use of nn.Linear on CUDA, and I wanted to post to see if others have run into something similar. I’ve posted some information about my environment below the example.

Example:

I start with a trivial nn.Linear and a Tensor, both on CUDA.

import torch
import torch.nn as nn

linear = nn.Linear(1, 1, bias=False).to('cuda:0')
x = torch.ones(3, 1, device='cuda:0')

There’s only one parameter for linear:

>>> linear.state_dict()
OrderedDict([('weight', tensor([[0.3293]], device='cuda:0'))])

This default weight is fine for the example. The bizarre behavior appears when I call linear(x):

>>> linear(x)  # This call returns nonsense
tensor([[0.0000],
        [1.8750],
        [1.0000]], device='cuda:0', grad_fn=<MmBackward>)

(Side note: if I repeatedly call linear(x), the result can fluctuate dramatically.)
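
To make that fluctuation concrete, here is the kind of quick check I’d expect to pass on a healthy install (a minimal sketch reusing linear and x from above):

first = linear(x)
for _ in range(5):
    # On a working build, the same inputs should give the same output every time.
    print(torch.equal(linear(x), first))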

Just as a reality-check, when running on CPU, the result comes out correctly:

>>> linear.to('cpu')
>>> x = x.cpu()
>>> linear(x)
tensor([[0.3293],
        [0.3293],
        [0.3293]], grad_fn=<MmBackward>)
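
In other words, the cross-device consistency check that should trivially pass fails on this instance (a sketch continuing from the lines above):

out_cpu = linear(x)                              # linear and x are both on CPU at this point
out_cuda = linear.to('cuda:0')(x.to('cuda:0'))   # move back to the GPU and rerun
print(torch.allclose(out_cpu, out_cuda.cpu()))   # should print True on a healthy install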

Environment
I am using an EC2 instance of type p2.xlarge. I haven’t made any changes to the environment: it’s a newly spun-up instance, and I’m using the default virtual environment (source activate pytorch_latest_p37). The GPU is a Tesla K80. Here is the output from nvidia-smi:

$ nvidia-smi
Fri Jul  2 16:04:45 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 450.119.03   Driver Version: 450.119.03   CUDA Version: 11.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  Tesla K80           On   | 00000000:00:1E.0 Off |                    0 |
| N/A   38C    P0    54W / 149W |    500MiB / 11441MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|    0   N/A  N/A     12914      C   ...rch_latest_p37/bin/python      497MiB |
+-----------------------------------------------------------------------------+

Thanks in advance for any suggestions about what could be going on here, and sorry if I’m making some silly obvious mistake.

I can’t seem to reproduce this on my end with:

import torch

linearcuda = torch.nn.Linear(1,1,bias=False,device='cuda')
linearcpu = torch.nn.Linear(1,1,bias=False)
onescuda = torch.ones(3,1,device='cuda')
onescpu = torch.ones(3,1)
print(linearcuda(onescuda))
print(linearcpu(onescpu))
tensor([[-0.6604],
        [-0.6604],
        [-0.6604]], device='cuda:0', grad_fn=<MmBackward>)
tensor([[-0.6166],
        [-0.6166],
        [-0.6166]], grad_fn=<MmBackward>)

Can you give some more details about your setup, e.g.:

print(torch.__version__)
print(torch.version.cuda)

Thanks very much for looking at this. Can’t say I’m surprised that it isn’t reproducible for you.

>>> print(torch.__version__)
1.8.1+cu111
>>> print(torch.version.cuda)
11.1
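
In case it helps, the full environment report can be dumped with the helper that ships with PyTorch (a quick sketch):

# Prints OS, driver, CUDA, cuDNN and wheel details in one go.
from torch.utils import collect_env
print(collect_env.get_pretty_env_info())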

Is there any other information I can provide?

The issue seems to be related to the PyTorch version.

Using the same version as above:

>>> print(torch.__version__)
1.8.1+cu111
>>> print(torch.version.cuda)
11.1

This is the result:

>>> linearcuda = torch.nn.Linear(1,1,bias=False).to('cuda:0')
>>> linearcpu = torch.nn.Linear(1,1,bias=False)
>>> onescuda = torch.ones(3,1,device='cuda:0')
>>> onescpu = torch.ones(3,1)
>>> print(linearcuda(onescuda))
tensor([[1.],
        [1.],
        [1.]], device='cuda:0', grad_fn=<MmBackward>)
>>> print(linearcpu(onescpu))
tensor([[0.2037],
        [0.2037],
        [0.2037]], grad_fn=<MmBackward>)
>>> print(linearcuda.state_dict())
OrderedDict([('weight', tensor([[0.0904]], device='cuda:0'))])

While with an older version:

>>> print(torch.__version__)
1.4.0
>>> print(torch.version.cuda)
10.1
>>> linearcuda = torch.nn.Linear(1,1,bias=False).to('cuda:0')
>>> linearcpu = torch.nn.Linear(1,1,bias=False)
>>> onescuda = torch.ones(3,1,device='cuda:0')
>>> onescpu = torch.ones(3,1)
>>> print(linearcuda(onescuda))
tensor([[0.2613],
        [0.2613],
        [0.2613]], device='cuda:0', grad_fn=<MmBackward>)
>>> print(linearcpu(onescpu))
tensor([[0.8475],
        [0.8475],
        [0.8475]], grad_fn=<MmBackward>)
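
To narrow it down further, the same discrepancy should be reproducible without nn.Linear at all, since the forward pass is just a matmul. A quick sketch reusing the tensors above (on the affected 1.8.1 build I’d expect the CUDA line to be off in the same way):

w = linearcuda.weight.detach()
print(onescuda @ w.t())        # CUDA matmul path (the suspect one)
print(onescpu @ w.cpu().t())   # CPU reference using the same weight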

Could you update to 1.9.0 and rerun the code, please?
We’ve seen some issues in the 1.8.0 and 1.8.1 pip wheels, in particular for sm_61, caused by leaking cuBLAS symbols, which might also be visible on the K80.
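
After upgrading (e.g. pip install -U torch, then restarting the interpreter so the new wheel is actually loaded), a quick sanity check along these lines should pass (rough sketch):

import torch

linear = torch.nn.Linear(1, 1, bias=False).to('cuda:0')
x = torch.ones(3, 1, device='cuda:0')
out_cuda = linear(x)
out_cpu = linear.cpu()(x.cpu())
print(torch.allclose(out_cuda.cpu(), out_cpu))   # True on a healthy build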

You’re right, pip install -U torch seems to have fixed the problem. Thanks! Odd that the instance ships with a default configuration that has such an issue…
