Memory consumption varies greatly depending on the GPU used

While running my training code, I got RuntimeError: cuDNN error: CUDNN_STATUS_INTERNAL_ERROR during the forward pass (net(data)).
The error message also included the following minimal snippet to reproduce the exception:

import torch
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.deterministic = False
torch.backends.cudnn.allow_tf32 = True
data = torch.randn([1, 256, 200, 200], dtype=torch.float, device='cuda', requires_grad=True)
net = torch.nn.Conv2d(256, 256, kernel_size=[3, 3], padding=[1, 1], stride=[1, 1], dilation=[1, 1], groups=1)
net = net.cuda().float()
out = net(data)
out.backward(torch.randn_like(out))
torch.cuda.synchronize()

Running this snippet by itself reproduces the error.

Curiously, the error occurs only with CUDA_VISIBLE_DEVICES=0; with CUDA_VISIBLE_DEVICES=1 the code runs fine.
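For context, CUDA_VISIBLE_DEVICES both hides and renumbers GPUs: the listed physical devices become cuda:0, cuda:1, and so on inside the process, so device='cuda' in the snippet refers to a different physical card under each setting. The plain-Python sketch below models that mapping. The helper name visible_devices is hypothetical, the device list is copied from the nvidia-smi output at the end of this post, and it assumes CUDA enumerates the cards in the same PCI-bus order as nvidia-smi (setting CUDA_DEVICE_ORDER=PCI_BUS_ID makes that explicit).

```python
from typing import Dict, List

# Physical GPUs in nvidia-smi (PCI bus) order, as listed at the end of this post.
PHYSICAL_GPUS = ["Quadro RTX 8000", "Quadro RTX 4000"]


def visible_devices(env_value: str, physical: List[str]) -> Dict[str, str]:
    """Hypothetical helper: map in-process CUDA ordinals to physical GPUs.

    Assumes CUDA enumerates devices in the same order as nvidia-smi
    (i.e. CUDA_DEVICE_ORDER=PCI_BUS_ID) and that every listed index is valid.
    """
    mapping = {}
    for ordinal, index in enumerate(env_value.split(",")):
        # The n-th listed physical GPU becomes cuda:n inside the process.
        mapping[f"cuda:{ordinal}"] = physical[int(index.strip())]
    return mapping


for setting in ("0", "1"):
    print(f"CUDA_VISIBLE_DEVICES={setting} ->", visible_devices(setting, PHYSICAL_GPUS))
```

Under that ordering assumption, the failing run (CUDA_VISIBLE_DEVICES=0) uses the RTX 8000 and the working run (CUDA_VISIBLE_DEVICES=1) uses the RTX 4000.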

To investigate further, I used pytorch_memlab to print the memory consumption of each line:

import torch
from pytorch_memlab import profile

@profile
def main():  
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.deterministic = False
    torch.backends.cudnn.allow_tf32 = True
    data = torch.randn([1, 256, 200, 200], dtype=torch.float, device='cuda', requires_grad=True)
    net = torch.nn.Conv2d(256, 256, kernel_size=[3, 3], padding=[1, 1], stride=[1, 1], dilation=[1, 1], groups=1)
    net = net.cuda().float()
    out = net(data)
    out.backward(torch.randn_like(out))
    torch.cuda.synchronize()

main()

When I ran this script on GPU 1, no problems occurred:

# CUDA_VISIBLE_DEVICES=1 python cudnn_error_repro.py 
/opt/conda/lib/python3.7/site-packages/pytorch_memlab/line_profiler/line_records.py:63: FutureWarning: In a future version of pandas all arguments of DataFrame.drop except for the argument 'labels' will be keyword-only
  .drop(['code_hash', 'num_alloc_retries', 'num_ooms', 'prev_record_idx'], 1))
/opt/conda/lib/python3.7/site-packages/pytorch_memlab/line_profiler/line_records.py:189: FutureWarning: In a future version of pandas all arguments of DataFrame.drop except for the argument 'labels' will be keyword-only
  merged = merged.drop('code', 1, level=0)
## main

active_bytes reserved_bytes line code                                                                                                              
         all            all                                                                                                                        
        peak           peak                                                                                                                        
       0.00B          0.00B    4 @profile                                                                                                          
                               5 def main():                                                                                                       
       0.00B          0.00B    6     torch.backends.cuda.matmul.allow_tf32 = True                                                                  
       0.00B          0.00B    7     torch.backends.cudnn.benchmark = True                                                                         
       0.00B          0.00B    8     torch.backends.cudnn.deterministic = False                                                                    
       0.00B          0.00B    9     torch.backends.cudnn.allow_tf32 = True                                                                        
      40.00M         40.00M   10     data = torch.randn([1, 256, 200, 200], dtype=torch.float, device='cuda', requires_grad=True)                  
      40.00M         40.00M   11     net = torch.nn.Conv2d(256, 256, kernel_size=[3, 3], padding=[1, 1], stride=[1, 1], dilation=[1, 1], groups=1) 
      42.25M         62.00M   12     net = net.cuda().float()                                                                                      
       2.84G          2.86G   13     out = net(data)                                                                                               
       2.91G          5.61G   14     out.backward(torch.randn_like(out))                                                                           
     123.56M          2.86G   15     torch.cuda.synchronize()                                

According to the report, the forward pass (out = net(data)) reserved 2.86G of memory.
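As a sanity check on that number, the tensors in the snippet are small. The sketch below computes their raw float32 sizes (4 bytes per element) from the shapes in the repro; with padding=1 and stride=1 the output keeps the 200x200 spatial size. The input comes to 39.0625 MiB and the weight to 2.25 MiB, which roughly match the 40.00M and 42.25M rows above (the small differences are presumably the CUDA caching allocator rounding block sizes). Input, weight, bias, and output together are only about 80 MiB, so nearly all of the 2.86G must come from other allocations made during the convolution.

```python
BYTES_PER_FLOAT32 = 4
MIB = 1024 ** 2

# Shapes from the repro snippet: Conv2d(256, 256, kernel_size=3, padding=1)
# applied to a 1x256x200x200 float32 input.
SHAPES = {
    "input (data)": (1, 256, 200, 200),
    "weight": (256, 256, 3, 3),
    "bias": (256,),
    "output (out)": (1, 256, 200, 200),
}


def numel(shape):
    """Number of elements in a tensor of the given shape."""
    n = 1
    for dim in shape:
        n *= dim
    return n


def size_mib(shape, bytes_per_element=BYTES_PER_FLOAT32):
    """Raw tensor size in MiB, ignoring any allocator rounding."""
    return numel(shape) * bytes_per_element / MIB


total = 0.0
for name, shape in SHAPES.items():
    mib = size_mib(shape)
    total += mib
    print(f"{name:14s} {str(shape):22s} {mib:9.4f} MiB")
print(f"{'total':14s} {'':22s} {total:9.4f} MiB")
```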

On GPU 0, however, the script fails:

# CUDA_VISIBLE_DEVICES=0 python cudnn_error_repro.py 
Traceback (most recent call last):
  File "cudnn_error_repro.py", line 17, in <module>
    main()
  File "cudnn_error_repro.py", line 13, in main
    out = net(data)
  File "/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/opt/conda/lib/python3.7/site-packages/torch/nn/modules/conv.py", line 399, in forward
    return self._conv_forward(input, self.weight, self.bias)
  File "/opt/conda/lib/python3.7/site-packages/torch/nn/modules/conv.py", line 396, in _conv_forward
    self.padding, self.dilation, self.groups)
RuntimeError: cuDNN error: CUDNN_STATUS_INTERNAL_ERROR
You can try to repro this exception using the following code snippet. If that doesn't trigger the error, please include your original repro script when reporting this issue.

import torch
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.deterministic = False
torch.backends.cudnn.allow_tf32 = True
data = torch.randn([1, 256, 200, 200], dtype=torch.float, device='cuda', requires_grad=True)
net = torch.nn.Conv2d(256, 256, kernel_size=[3, 3], padding=[1, 1], stride=[1, 1], dilation=[1, 1], groups=1)
net = net.cuda().float()
out = net(data)
out.backward(torch.randn_like(out))
torch.cuda.synchronize()

ConvolutionParams 
    data_type = CUDNN_DATA_FLOAT
    padding = [1, 1, 0]
    stride = [1, 1, 0]
    dilation = [1, 1, 0]
    groups = 1
    deterministic = false
    allow_tf32 = true
input: TensorDescriptor 0x564de74e3730
    type = CUDNN_DATA_FLOAT
    nbDims = 4
    dimA = 1, 256, 200, 200, 
    strideA = 10240000, 40000, 200, 1, 
output: TensorDescriptor 0x564de74efba0
    type = CUDNN_DATA_FLOAT
    nbDims = 4
    dimA = 1, 256, 200, 200, 
    strideA = 10240000, 40000, 200, 1, 
weight: FilterDescriptor 0x564de86adc00
    type = CUDNN_DATA_FLOAT
    tensor_format = CUDNN_TENSOR_NCHW
    nbDims = 4
    dimA = 256, 256, 3, 3, 
Pointer addresses: 
    input: 0x7f46a7200000
    output: 0x7f46ab000000
    weight: 0x7f46a9a00000

/opt/conda/lib/python3.7/site-packages/pytorch_memlab/line_profiler/line_records.py:63: FutureWarning: In a future version of pandas all arguments of DataFrame.drop except for the argument 'labels' will be keyword-only
  .drop(['code_hash', 'num_alloc_retries', 'num_ooms', 'prev_record_idx'], 1))
/opt/conda/lib/python3.7/site-packages/pytorch_memlab/line_profiler/line_records.py:189: FutureWarning: In a future version of pandas all arguments of DataFrame.drop except for the argument 'labels' will be keyword-only
  merged = merged.drop('code', 1, level=0)
## main

active_bytes reserved_bytes line code                                                                                                              
         all            all                                                                                                                        
        peak           peak                                                                                                                        
       0.00B          0.00B    4 @profile                                                                                                          
                               5 def main():                                                                                                       
       0.00B          0.00B    6     torch.backends.cuda.matmul.allow_tf32 = True                                                                  
       0.00B          0.00B    7     torch.backends.cudnn.benchmark = True                                                                         
       0.00B          0.00B    8     torch.backends.cudnn.deterministic = False                                                                    
       0.00B          0.00B    9     torch.backends.cudnn.allow_tf32 = True                                                                        
      40.00M         40.00M   10     data = torch.randn([1, 256, 200, 200], dtype=torch.float, device='cuda', requires_grad=True)                  
      40.00M         40.00M   11     net = torch.nn.Conv2d(256, 256, kernel_size=[3, 3], padding=[1, 1], stride=[1, 1], dilation=[1, 1], groups=1) 
      42.25M         62.00M   12     net = net.cuda().float()                                                                                      
      32.46G         32.48G   13     out = net(data)                                                                                               
                              14     out.backward(torch.randn_like(out))                                                                           
                              15     torch.cuda.synchronize()                                                                    

During the forward pass (out = net(data)), 32.48G of memory was reserved on GPU 0, compared with less than 3G on GPU 1! So it seems that the RuntimeError (cuDNN error) was actually an out-of-memory (OOM) error.

Does anyone have ideas about the cause of this weird issue?

My machine has two GPUs: GPU 0 is a Quadro RTX 8000 and GPU 1 is a Quadro RTX 4000.

# nvidia-smi
Tue Dec 14 09:54:40 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.63.01    Driver Version: 470.63.01    CUDA Version: 11.4     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  Quadro RTX 8000     On   | 00000000:1A:00.0 Off |                  Off |
| 33%   35C    P8    29W / 260W |     10MiB / 48601MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   1  Quadro RTX 4000     On   | 00000000:68:00.0  On |                  N/A |
| 30%   41C    P8    13W / 125W |    630MiB /  7974MiB |      9%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
+-----------------------------------------------------------------------------+
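To put the two forward-pass peaks next to the cards' capacities, the sketch below parses the Memory-Usage column from the nvidia-smi output above and converts the reserved_bytes peaks at line 13 (out = net(data)) of the two pytorch_memlab reports into MiB. It assumes memlab's "G" suffix means GiB (2^30 bytes); the helper name parse_usage is hypothetical, and the numbers are copied from this post rather than measured.

```python
import re

# Memory-Usage column ("used / total") from the nvidia-smi output above.
NVIDIA_SMI_USAGE = {
    "GPU 0 (Quadro RTX 8000)": "10MiB / 48601MiB",
    "GPU 1 (Quadro RTX 4000)": "630MiB /  7974MiB",
}

# Peak reserved_bytes at the forward pass (line 13) in the memlab reports,
# assumed to be GiB.
PEAK_RESERVED_GIB = {
    "GPU 0 (Quadro RTX 8000)": 32.48,
    "GPU 1 (Quadro RTX 4000)": 2.86,
}


def parse_usage(text):
    """Hypothetical helper: parse 'usedMiB / totalMiB' into two ints."""
    match = re.fullmatch(r"\s*(\d+)MiB\s*/\s*(\d+)MiB\s*", text)
    if match is None:
        raise ValueError(f"unexpected Memory-Usage format: {text!r}")
    return int(match.group(1)), int(match.group(2))


for gpu, usage in NVIDIA_SMI_USAGE.items():
    _, total_mib = parse_usage(usage)
    peak_mib = PEAK_RESERVED_GIB[gpu] * 1024
    print(f"{gpu}: forward-pass peak {peak_mib:8.0f} MiB of {total_mib} MiB "
          f"({100 * peak_mib / total_mib:.0f}%)")
```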

Thank you for your help!