PyTorch>=1.6.0 cannot coexist with PyCuda

Since torch==1.6.0, using PyTorch and PyCuda together causes unpredictable failures, even if they never share any memory.

After updating PyTorch, I encountered many CUDA errors in various places (illegal memory access, misaligned address, cuDNN error: CUDNN_STATUS_MAPPING_ERROR, etc.). I was finally able to pin it down to PyCuda usage: removing all the PyCuda calls and imports fixes the problem. I prepared a simplified MNIST example (based on https://github.com/pytorch/examples/tree/master/mnist) that fails after an unrelated time measurement using PyCuda events.

Of course, this is far from a real use case; it just shows the kinds of operations that cause issues.

I tested it with 1.4.0 (ok), 1.5.0 (ok), 1.6.0 (fails), 1.7.0 (fails), 1.7.1 (fails) and 1.8.0 nightly (fails).

import time

import pycuda.driver as cuda
import torch
import torch.nn as nn
import torch.nn.functional as F
from pycuda.autoinit import context as pycuda_ctx
from torchvision import datasets, transforms


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        output = F.log_softmax(x, dim=1)
        return output


def pycuda_dummy_measure_time():
    pycuda_ctx.push()
    event_start = cuda.Event().record()
    pycuda_ctx.pop()

    time.sleep(2)

    pycuda_ctx.push()
    event_stop = cuda.Event().record().synchronize()
    print(event_stop.time_since(event_start))
    pycuda_ctx.pop()


def main():
    torch.cuda.init()  # any torch cuda initialization before pycuda calls, torch.randn(10).cuda() works too

    pycuda_dummy_measure_time()  # measures time of a 2-second sleep using pycuda Events

    # normal MNIST training below
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])
    dataset1 = datasets.MNIST('../data', train=True, download=True, transform=transform)
    train_loader = torch.utils.data.DataLoader(dataset1, batch_size=64, num_workers=1, pin_memory=True)

    model = Net().cuda()
    model.train()

    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.cuda(), target.cuda()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        if batch_idx % 10 == 0:
            print('[{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(batch_idx * len(data), len(train_loader.dataset),
                                                           100. * batch_idx / len(train_loader), loss.item()))


if __name__ == '__main__':
    main()

On a failing configuration it produces:

Traceback (most recent call last):
  File "mnist.py", line 77, in <module>
    main()
  File "mnist.py", line 68, in main
    output = model(data)
  File "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "mnist.py", line 22, in forward
    x = self.conv1(x)
  File "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/conv.py", line 423, in forward
    return self._conv_forward(input, self.weight)
  File "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/conv.py", line 420, in _conv_forward
    self.padding, self.dilation, self.groups)
RuntimeError: cuDNN error: CUDNN_STATUS_MAPPING_ERROR

All the configurations have:

Ubuntu 18.04
Python 3.6.9
pycuda==2020.1

Some of the configurations I got the failures in:

Tesla T4
Driver 440.64.00
CUDA 10.2, V10.2.89
torch==1.6.0
torchvision==0.7.0

GeForce GTX 1070
Driver 450.80.02
CUDA 11.0, V11.0.221
torch==1.7.1+cu110
torchvision==0.8.2+cu110

Some of the configurations with no problems:

Tesla T4
Driver 440.64.00
CUDA 10.2, V10.2.89
torch==1.4.0
torchvision==0.5.0

GeForce GTX 1070
Driver 450.80.02
CUDA 11.0, V11.0.221
torch==1.5.0
torchvision==0.6.0
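
If it helps, most of these details can also be dumped automatically with torch’s built-in environment script:

python -m torch.utils.collect_env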

I also ran some tests with different drivers and CUDA versions, but I didn’t write down the full configurations, so I don’t want to cause confusion. If there is a particular configuration I should check, please let me know.

Thanks for all your help!

Are you using the same CUDA versions in PyTorch and PyCUDA?

Yes, they both use the default CUDA installation (/usr/local/cuda).
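
In case it is useful, here is a small check (not part of the repro, assuming a single GPU at device 0) that prints the CUDA versions each library reports:

import torch
import pycuda
import pycuda.driver as cuda

cuda.init()

print('torch', torch.__version__, 'built with CUDA', torch.version.cuda)
print('cuDNN', torch.backends.cudnn.version())
print('pycuda', pycuda.VERSION_TEXT, 'built with CUDA', cuda.get_version())
print('driver API version', cuda.get_driver_version())
print('device', cuda.Device(0).name())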

I found a solution that works for me; I’m posting it here for any future readers.

It turns out that since PyCuda 2020.1 (released in October 2020) it is no longer necessary to create a new PyCuda context: the retain_primary_context method was added, which returns the device’s primary context. Using retain_primary_context instead of import pycuda.autoinit or make_default_context prevents the creation of a new context and all the problems related to it.

It still doesn’t explain why a separate PyCuda context could coexist with torch up to 1.5.0 and stopped working afterwards, but I think that won’t matter in most cases; retain_primary_context is cleaner than creating a new context anyway.
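
To show the change in isolation (the full snippet follows below), the only difference is how the context is obtained, assuming device 0 as in the examples:

import pycuda.driver as cuda

# old approach: creates a new context and clashes with torch>=1.6.0
# from pycuda.autoinit import context as pycuda_ctx

# new approach (pycuda>=2020.1): reuse the device's primary context
cuda.init()
pycuda_ctx = cuda.Device(0).retain_primary_context()

Note that retain_primary_context does not make the returned context current, so the push()/pop() calls around the PyCuda code stay exactly as they were.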

An improved version of the snippet above that works for both older and newer torch versions with pycuda>=2020.1:

import time

import pycuda.driver as cuda
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms

cuda.init()
pycuda_ctx = cuda.Device(0).retain_primary_context()


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        output = F.log_softmax(x, dim=1)
        return output


def pycuda_dummy_measure_time():
    pycuda_ctx.push()
    event_start = cuda.Event().record()
    pycuda_ctx.pop()

    time.sleep(2)

    pycuda_ctx.push()
    event_stop = cuda.Event().record().synchronize()
    print(event_stop.time_since(event_start))
    pycuda_ctx.pop()


def main():
    torch.cuda.init()  # any torch cuda initialization before pycuda calls, torch.randn(10).cuda() works too

    pycuda_dummy_measure_time()  # measures time of a 2-second sleep using pycuda Events

    # normal MNIST training below
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])
    dataset1 = datasets.MNIST('../data', train=True, download=True, transform=transform)
    train_loader = torch.utils.data.DataLoader(dataset1, batch_size=64, num_workers=1, pin_memory=True)

    model = Net().cuda()
    model.train()

    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.cuda(), target.cuda()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        if batch_idx % 10 == 0:
            print('[{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(batch_idx * len(data), len(train_loader.dataset),
                                                           100. * batch_idx / len(train_loader), loss.item()))


if __name__ == '__main__':
    main()

Thanks for the update! 🙂