Data corruption with matrix.inverse() after other_tensor.to('cuda', non_blocking=True)

Hello. I get incorrect results from Tensor.inverse() if some other tensor is transferred to the GPU with non_blocking=True just before the call:

import os  # only needed if Fix 1 below is uncommented
import torch

device = 'cuda:0'

# Each of the commented-out 'Fixes' below seems to eliminate the problem
# os.environ['CUDA_LAUNCH_BLOCKING'] = '1'  # Fix 1: set this (before any CUDA call)
the_matrix = [[2, 0], [0, 1]]
the_inverse = [[.5, 0], [.0, 1]]
the_inverse = torch.tensor(the_inverse, dtype=torch.float32, device=device)

for _ in range(10_000):
    batch_size = 2
    # batch_size = 4  # Fix 2: Set batch_size != 2
    matrix = torch.tensor([the_matrix] * batch_size, dtype=torch.float32)
    # Large 'ballast' tensor, copied with non_blocking=True below to trigger the issue
    ballast = torch.ones([batch_size, 2**16], dtype=torch.float32)

    matrix = matrix.to(device)
    ballast = ballast.to(device, non_blocking=True)
    # torch.cuda.synchronize()  # Fix 3: call cuda.synchronize here

    # matrix = matrix.cpu().to(matrix.device)  # Fix 4: Move the matrix to cpu and back
    matrix_inv = matrix.inverse()
    # matrix_inv = torch.stack([m.inverse() for m in matrix])  # Fix 5: Do inversion one by one

    if (matrix_inv != the_inverse).any():
        print(f'CAUGHT BROKEN INVERSE\n'
              f'Matrix\n'
              f'------\n'
              f'{matrix}\n'
              f'Inverse\n'
              f'-------\n'
              f'{matrix_inv}\n'
              f'True inverse\n'
              f'------------\n'
              f'{the_inverse}')
        break

The problem seems to go away if I uncomment any one of the ‘fixes’.
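For anyone who needs to stay on 1.7.0 for now, the snippet below is just a standalone restatement of Fix 3, using the same shapes as the repro above: an explicit torch.cuda.synchronize() between the non-blocking copy and the inverse() call ensures the pending transfer has completed before the kernel is launched.

import torch

device = 'cuda:0'
matrix = torch.tensor([[[2., 0.], [0., 1.]]] * 2).to(device)
ballast = torch.ones([2, 2**16]).to(device, non_blocking=True)

# Equivalent to Fix 3 above: block the host until all queued CUDA work,
# including the pending non-blocking copy, has finished.
torch.cuda.synchronize()
matrix_inv = matrix.inverse()  # with the sync in place, the corruption no longer shows up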

I managed to reproduce this problem in the following environments:

PyTorch version: 1.7.0
Is debug build: True
CUDA used to build PyTorch: 10.1
ROCM used to build PyTorch: N/A

OS: Ubuntu 18.04.3 LTS (x86_64)
GCC version: (Ubuntu 7.4.0-1ubuntu1~18.04.1) 7.4.0
Clang version: Could not collect
CMake version: Could not collect

Python version: 3.8 (64-bit runtime)
Is CUDA available: True
CUDA runtime version: 10.1.243
GPU models and configuration: 
GPU 0: GeForce GTX 1080 Ti
GPU 1: GeForce GTX 1080 Ti

Nvidia driver version: 430.64
cuDNN version: /usr/lib/x86_64-linux-gnu/libcudnn.so.7.6.5
HIP runtime version: N/A
MIOpen runtime version: N/A

Versions of relevant libraries:
[pip3] numpy==1.19.2
[pip3] torch==1.7.0
[pip3] torchaudio==0.7.0a0+ac17b64
[pip3] torchvision==0.8.1
[conda] blas                      1.0                         mkl  
[conda] cudatoolkit               10.1.243             h6bb024c_0  
[conda] mkl                       2020.2                      256  
[conda] mkl-service               2.3.0            py38he904b0f_0  
[conda] mkl_fft                   1.2.0            py38h23d657b_0  
[conda] mkl_random                1.1.1            py38h0573a6f_0  
[conda] numpy                     1.19.2           py38h54aff64_0  
[conda] numpy-base                1.19.2           py38hfa32c7d_0  
[conda] pytorch                   1.7.0           py3.8_cuda10.1.243_cudnn7.6.3_0    pytorch
[conda] torchaudio                0.7.0                      py38    pytorch
[conda] torchvision               0.8.1                py38_cu101    pytorch

and in a second environment, identical except for:

GPU models and configuration: GPU 0: Tesla V100-SXM2-16GB
Nvidia driver version: 455.32.00

and occasionally in this Colab.

@ptrblck, would you mind having a look at this issue?

I’ve faced the same behaviour.

Could you update to the nightly version and rerun the code, as we’ve recently fixed a race condition?
CC @voyleg
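(To double-check which build ends up being used after updating, something like this quick check — purely illustrative — prints the installed version; a nightly build reports a .dev date suffix.)

import torch

print(torch.__version__)   # a nightly build shows something like '1.8.0.devYYYYMMDD'
print(torch.version.cuda)  # CUDA version the wheel was built against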


Cannot reproduce with 1.8.0.dev20201126, thank you.
