[Solved] DataParallel hangs with multiple V100s

I’m having trouble getting multi-GPU training to work across several V100s. Here’s the code:

BATCH_SIZE = 800

import torch
import torchvision
import torchvision.transforms as transforms
from pathlib import Path

transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])


testset = torchvision.datasets.CIFAR10(Path.home()/"data", train=False,
                                       download=True, transform=transform)

testloader = torch.utils.data.DataLoader(testset, batch_size=BATCH_SIZE,
                                         shuffle=False, num_workers=0)

import torch.nn as nn
import torch.nn.functional as F
criterion = nn.CrossEntropyLoss()
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 20, 3)
        self.conv2 = nn.Conv2d(20, 80, 3)
        self.conv3 = nn.Conv2d(80, 160, 5)
        self.pool = nn.MaxPool2d(3, 3)
        self.fc1 = nn.Linear(5120*2, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.relu(self.conv3(x))
        x = self.pool(x)
        x = x.reshape(-1, 5120*2)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

dataiter = iter(testloader)
inputs, labels = next(dataiter)
inputs, labels = inputs.cuda(), labels.cuda()

num_device = torch.cuda.device_count()
multi_time = None
print('making net')
net = nn.DataParallel(Net().cuda(), device_ids=tuple(range(num_device)))
print('made net')
optim = torch.optim.SGD(net.parameters(), lr=1)
# burn-in
for i in range(100):
    print(i)
    out = net(inputs)
    print('got output')
    loss = criterion(out, labels)
    print('got loss')
    loss.backward() 

The output I get is

making net
made net
0

and then it hangs. I can see the appropriate amount of memory allocated on all my GPUs. If I interrupt the kernel, the message I get is:

---------------------------------------------------------------------------
KeyboardInterrupt                         Traceback (most recent call last)
<ipython-input-4-8bd966a1882f> in <module>()
      8 for i in range(100):
      9     print(i)
---> 10     out = net(inputs)
     11     print('got output')
     12     loss = criterion(out, labels)

~/anaconda3/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    475             result = self._slow_forward(*input, **kwargs)
    476         else:
--> 477             result = self.forward(*input, **kwargs)
    478         for hook in self._forward_hooks.values():
    479             hook_result = hook(self, input, result)

~/anaconda3/lib/python3.6/site-packages/torch/nn/parallel/data_parallel.py in forward(self, *inputs, **kwargs)
    121             return self.module(*inputs[0], **kwargs[0])
    122         replicas = self.replicate(self.module, self.device_ids[:len(inputs)])
--> 123         outputs = self.parallel_apply(replicas, inputs, kwargs)
    124         return self.gather(outputs, self.output_device)
    125 

~/anaconda3/lib/python3.6/site-packages/torch/nn/parallel/data_parallel.py in parallel_apply(self, replicas, inputs, kwargs)
    131 
    132     def parallel_apply(self, replicas, inputs, kwargs):
--> 133         return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
    134 
    135     def gather(self, outputs, output_device):

~/anaconda3/lib/python3.6/site-packages/torch/nn/parallel/parallel_apply.py in parallel_apply(modules, inputs, kwargs_tup, devices)
     67             thread.start()
     68         for thread in threads:
---> 69             thread.join()
     70     else:
     71         _worker(0, modules[0], inputs[0], kwargs_tup[0], devices[0])

~/anaconda3/lib/python3.6/threading.py in join(self, timeout)
   1054 
   1055         if timeout is None:
-> 1056             self._wait_for_tstate_lock()
   1057         else:
   1058             # the behavior of a negative timeout isn't documented, but

~/anaconda3/lib/python3.6/threading.py in _wait_for_tstate_lock(self, block, timeout)
   1070         if lock is None:  # already determined that the C code is done
   1071             assert self._is_stopped
-> 1072         elif lock.acquire(block, timeout):
   1073             lock.release()
   1074             self._stop()

KeyboardInterrupt: 

which makes me think that somehow a lock isn’t being properly released. Has anybody had success getting multiple V100s working? This exact same code works fine on a different machine with multiple 1080Tis.

Disabling ACS (PCIe Access Control Services) solved this issue.


Hey, could you tell me how to disable ACS? I also get hangs when I use DataParallel.

If you run sudo lspci -vvvv | grep -i plx you’ll get a listing of all the relevant PCI bridges, for example 19:08.0. Then go through each one with sudo lspci -s 19:08.0 -vvv | grep -i acs and look at the ACSCtl line (the last line of output). All of the flags should have a - sign after them; if any show a +, clear them with sudo setpci -s 19:08.0 f2a.2=0000. Once you’ve done this for every bridge, you should be able to use DataParallel.
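
In case it’s useful, below is a rough Python sketch that automates the check above (the helper names are just for illustration, and I haven’t tested it on your exact setup). It lists the PLX bridges via lspci, prints each one’s ACSCtl line, and for any bridge that still shows a + flag prints the setpci command from the previous paragraph so you can run it yourself; the script itself only reads PCI config. It assumes lspci and setpci are installed and that you run it with sudo, since lspci only shows the capability details as root.

import re
import subprocess

def plx_bridges():
    # PCI addresses (e.g. '19:08.0') of PLX bridges, as reported by plain lspci
    out = subprocess.check_output(["lspci"], universal_newlines=True)
    return [line.split()[0] for line in out.splitlines() if "plx" in line.lower()]

def acs_ctl(addr):
    # The ACSCtl line for one device, or None if lspci shows no ACS capability
    out = subprocess.check_output(["lspci", "-s", addr, "-vvv"],
                                  universal_newlines=True)
    match = re.search(r"ACSCtl:.*", out)
    return match.group(0).strip() if match else None

for addr in plx_bridges():
    ctl = acs_ctl(addr)
    if ctl is None:
        continue
    print(addr, ctl)
    if "+" in ctl:  # any '+' flag means ACS is still enabled on this bridge
        print("  -> sudo setpci -s {} f2a.2=0000".format(addr))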

Small addition: in my case f2a.w=0000 worked, i.e. using the literal width specifier ‘w’ rather than a number like the ‘2’ above.