Error in my minimal working example with multiple GPUs?

I am trying to exploit multiple GPUs on AWS via DataParallel. This is on AWS SageMaker with 4 GPUs, PyTorch 1.8 (GPU Optimized), and Python 3.6.

I have searched the forum and read through the data parallel tutorial, but I could not find such a minimal working example or an explanation of this error.

Do you know what is wrong?

x = torch.rand(300, 400, 500).cuda()

model = torch.nn.Sequential(torch.nn.Linear(500, 900), torch.nn.Linear(900, 1))
model = torch.nn.DataParallel(model, device_ids=[0,1])
y = model(x)

I get an error:


> ---------------------------------------------------------------------------
> RuntimeError                              Traceback (most recent call last)
> <ipython-input-…> in <module>
>       4 model = torch.nn.Sequential(torch.nn.Linear(500, 900), torch.nn.Linear(900, 1))
>       5 model = torch.nn.DataParallel(model, device_ids=[0,1])
> ----> 6 y = model(x)
> 
> /opt/conda/lib/python3.6/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
>     916             result = self._slow_forward(*input, **kwargs)
>     917         else:
> --> 918             result = self.forward(*input, **kwargs)
>     919         for hook in itertools.chain(
>     920                 _global_forward_hooks.values(),
> 
> /opt/conda/lib/python3.6/site-packages/torch/nn/parallel/data_parallel.py in forward(self, *inputs, **kwargs)
>     153                 raise RuntimeError("module must have its parameters and buffers "
>     154                                    "on device {} (device_ids[0]) but found one of "
> --> 155                                    "them on device: {}".format(self.src_device_obj, t.device))
>     156 
>     157         inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
> 
> RuntimeError: module must have its parameters and buffers on device cuda:0 (device_ids[0]) but found one of them on device: cpu

For these variants, I get these errors:

import torch
x = torch.rand(300, 400, 500)

model = torch.nn.Sequential(torch.nn.Linear(500, 900), torch.nn.Linear(900, 1))
model = torch.nn.DataParallel(model, device_ids=[0,1])
y = model(x)

> ---------------------------------------------------------------------------
> RuntimeError                              Traceback (most recent call last)
> <ipython-input-133-2ed596eb6192> in <module>
>       4 model = torch.nn.Sequential(torch.nn.Linear(500, 900), torch.nn.Linear(900, 1))
>       5 model = torch.nn.DataParallel(model, device_ids=[0,1])
> ----> 6 y = model(x)
> 
> /opt/conda/lib/python3.6/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
>     916             result = self._slow_forward(*input, **kwargs)
>     917         else:
> --> 918             result = self.forward(*input, **kwargs)
>     919         for hook in itertools.chain(
>     920                 _global_forward_hooks.values(),
> 
> /opt/conda/lib/python3.6/site-packages/torch/nn/parallel/data_parallel.py in forward(self, *inputs, **kwargs)
>     153                 raise RuntimeError("module must have its parameters and buffers "
>     154                                    "on device {} (device_ids[0]) but found one of "
> --> 155                                    "them on device: {}".format(self.src_device_obj, t.device))
>     156 
>     157         inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
> 
> RuntimeError: module must have its parameters and buffers on device cuda:0 (device_ids[0]) but found one of them on device: cpu

import torch
x = torch.rand(300, 400, 500)

model = torch.nn.Sequential(torch.nn.Linear(500, 900), torch.nn.Linear(900, 1))
model = torch.nn.DataParallel(model, device_ids=[0,1]).cuda()
y = model(x)

> ---------------------------------------------------------------------------
> RuntimeError                              Traceback (most recent call last)
> <ipython-input-134-16dea105c595> in <module>
>       4 model = torch.nn.Sequential(torch.nn.Linear(500, 900), torch.nn.Linear(900, 1))
>       5 model = torch.nn.DataParallel(model, device_ids=[0,1]).cuda()
> ----> 6 y = model(x)
> 
> /opt/conda/lib/python3.6/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
>     916             result = self._slow_forward(*input, **kwargs)
>     917         else:
> --> 918             result = self.forward(*input, **kwargs)
>     919         for hook in itertools.chain(
>     920                 _global_forward_hooks.values(),
> 
> /opt/conda/lib/python3.6/site-packages/torch/nn/parallel/data_parallel.py in forward(self, *inputs, **kwargs)
>     164         if len(self.device_ids) == 1:
>     165             return self.module(*inputs[0], **kwargs[0])
> --> 166         replicas = self.replicate(self.module, self.device_ids[:len(inputs)])
>     167         outputs = self.parallel_apply(replicas, inputs, kwargs)
>     168         return self.gather(outputs, self.output_device)
> 
> /opt/conda/lib/python3.6/site-packages/torch/nn/parallel/data_parallel.py in replicate(self, module, device_ids)
>     169 
>     170     def replicate(self, module, device_ids):
> --> 171         return replicate(module, device_ids, not torch.is_grad_enabled())
>     172 
>     173     def scatter(self, inputs, kwargs, device_ids):
> 
> /opt/conda/lib/python3.6/site-packages/torch/nn/parallel/replicate.py in replicate(network, devices, detach)
>      89     params = list(network.parameters())
>      90     param_indices = {param: idx for idx, param in enumerate(params)}
> ---> 91     param_copies = _broadcast_coalesced_reshape(params, devices, detach)
>      92 
>      93     buffers = list(network.buffers())
> 
> /opt/conda/lib/python3.6/site-packages/torch/nn/parallel/replicate.py in _broadcast_coalesced_reshape(tensors, devices, detach)
>      69         # Use the autograd function to broadcast if not detach
>      70         if len(tensors) > 0:
> ---> 71             tensor_copies = Broadcast.apply(devices, *tensors)
>      72             return [tensor_copies[i:i + len(tensors)]
>      73                     for i in range(0, len(tensor_copies), len(tensors))]
> 
> /opt/conda/lib/python3.6/site-packages/torch/nn/parallel/_functions.py in forward(ctx, target_gpus, *inputs)
>      21         ctx.num_inputs = len(inputs)
>      22         ctx.input_device = inputs[0].get_device()
> ---> 23         outputs = comm.broadcast_coalesced(inputs, ctx.target_gpus)
>      24         non_differentiables = []
>      25         for idx, input_requires_grad in enumerate(ctx.needs_input_grad[1:]):
> 
> /opt/conda/lib/python3.6/site-packages/torch/nn/parallel/comm.py in broadcast_coalesced(tensors, devices, buffer_size)
>      56     devices = [_get_device_index(d) for d in devices]
>      57     tensors = [_handle_complex(t) for t in tensors]
> ---> 58     return torch._C._broadcast_coalesced(tensors, devices, buffer_size)
>      59 
>      60 
> 
> RuntimeError: NCCL Error 2: unhandled system error

import torch
x = torch.rand(300, 400, 500).cuda()

model = torch.nn.Sequential(torch.nn.Linear(500, 900), torch.nn.Linear(900, 1))
model = torch.nn.DataParallel(model, device_ids=[0,1]).cuda()
y = model(x)

> ---------------------------------------------------------------------------
> RuntimeError                              Traceback (most recent call last)
> <ipython-input-136-439fd34aeaf9> in <module>
>       4 model = torch.nn.Sequential(torch.nn.Linear(500, 900), torch.nn.Linear(900, 1))
>       5 model = torch.nn.DataParallel(model, device_ids=[0,1]).cuda()
> ----> 6 y = model(x)
> 
> /opt/conda/lib/python3.6/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
>     916             result = self._slow_forward(*input, **kwargs)
>     917         else:
> --> 918             result = self.forward(*input, **kwargs)
>     919         for hook in itertools.chain(
>     920                 _global_forward_hooks.values(),
> 
> /opt/conda/lib/python3.6/site-packages/torch/nn/parallel/data_parallel.py in forward(self, *inputs, **kwargs)
>     164         if len(self.device_ids) == 1:
>     165             return self.module(*inputs[0], **kwargs[0])
> --> 166         replicas = self.replicate(self.module, self.device_ids[:len(inputs)])
>     167         outputs = self.parallel_apply(replicas, inputs, kwargs)
>     168         return self.gather(outputs, self.output_device)
> 
> /opt/conda/lib/python3.6/site-packages/torch/nn/parallel/data_parallel.py in replicate(self, module, device_ids)
>     169 
>     170     def replicate(self, module, device_ids):
> --> 171         return replicate(module, device_ids, not torch.is_grad_enabled())
>     172 
>     173     def scatter(self, inputs, kwargs, device_ids):
> 
> /opt/conda/lib/python3.6/site-packages/torch/nn/parallel/replicate.py in replicate(network, devices, detach)
>      89     params = list(network.parameters())
>      90     param_indices = {param: idx for idx, param in enumerate(params)}
> ---> 91     param_copies = _broadcast_coalesced_reshape(params, devices, detach)
>      92 
>      93     buffers = list(network.buffers())
> 
> /opt/conda/lib/python3.6/site-packages/torch/nn/parallel/replicate.py in _broadcast_coalesced_reshape(tensors, devices, detach)
>      69         # Use the autograd function to broadcast if not detach
>      70         if len(tensors) > 0:
> ---> 71             tensor_copies = Broadcast.apply(devices, *tensors)
>      72             return [tensor_copies[i:i + len(tensors)]
>      73                     for i in range(0, len(tensor_copies), len(tensors))]
> 
> /opt/conda/lib/python3.6/site-packages/torch/nn/parallel/_functions.py in forward(ctx, target_gpus, *inputs)
>      21         ctx.num_inputs = len(inputs)
>      22         ctx.input_device = inputs[0].get_device()
> ---> 23         outputs = comm.broadcast_coalesced(inputs, ctx.target_gpus)
>      24         non_differentiables = []
>      25         for idx, input_requires_grad in enumerate(ctx.needs_input_grad[1:]):
> 
> /opt/conda/lib/python3.6/site-packages/torch/nn/parallel/comm.py in broadcast_coalesced(tensors, devices, buffer_size)
>      56     devices = [_get_device_index(d) for d in devices]
>      57     tensors = [_handle_complex(t) for t in tensors]
> ---> 58     return torch._C._broadcast_coalesced(tensors, devices, buffer_size)
>      59 
>      60 
> 
> RuntimeError: NCCL Error 2: unhandled system error

import torch
x = torch.rand(300, 400, 500).cuda()

model = torch.nn.Sequential(torch.nn.Linear(500, 900), torch.nn.Linear(900, 1))
model = torch.nn.DataParallel(model.cuda(), device_ids=[0,1])
y = model(x)

> ---------------------------------------------------------------------------
> RuntimeError                              Traceback (most recent call last)
> <ipython-input-137-a1695a6de8c1> in <module>
>       4 model = torch.nn.Sequential(torch.nn.Linear(500, 900), torch.nn.Linear(900, 1))
>       5 model = torch.nn.DataParallel(model.cuda(), device_ids=[0,1])
> ----> 6 y = model(x)
> 
> /opt/conda/lib/python3.6/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
>     916             result = self._slow_forward(*input, **kwargs)
>     917         else:
> --> 918             result = self.forward(*input, **kwargs)
>     919         for hook in itertools.chain(
>     920                 _global_forward_hooks.values(),
> 
> /opt/conda/lib/python3.6/site-packages/torch/nn/parallel/data_parallel.py in forward(self, *inputs, **kwargs)
>     164         if len(self.device_ids) == 1:
>     165             return self.module(*inputs[0], **kwargs[0])
> --> 166         replicas = self.replicate(self.module, self.device_ids[:len(inputs)])
>     167         outputs = self.parallel_apply(replicas, inputs, kwargs)
>     168         return self.gather(outputs, self.output_device)
> 
> /opt/conda/lib/python3.6/site-packages/torch/nn/parallel/data_parallel.py in replicate(self, module, device_ids)
>     169 
>     170     def replicate(self, module, device_ids):
> --> 171         return replicate(module, device_ids, not torch.is_grad_enabled())
>     172 
>     173     def scatter(self, inputs, kwargs, device_ids):
> 
> /opt/conda/lib/python3.6/site-packages/torch/nn/parallel/replicate.py in replicate(network, devices, detach)
>      89     params = list(network.parameters())
>      90     param_indices = {param: idx for idx, param in enumerate(params)}
> ---> 91     param_copies = _broadcast_coalesced_reshape(params, devices, detach)
>      92 
>      93     buffers = list(network.buffers())
> 
> /opt/conda/lib/python3.6/site-packages/torch/nn/parallel/replicate.py in _broadcast_coalesced_reshape(tensors, devices, detach)
>      69         # Use the autograd function to broadcast if not detach
>      70         if len(tensors) > 0:
> ---> 71             tensor_copies = Broadcast.apply(devices, *tensors)
>      72             return [tensor_copies[i:i + len(tensors)]
>      73                     for i in range(0, len(tensor_copies), len(tensors))]
> 
> /opt/conda/lib/python3.6/site-packages/torch/nn/parallel/_functions.py in forward(ctx, target_gpus, *inputs)
>      21         ctx.num_inputs = len(inputs)
>      22         ctx.input_device = inputs[0].get_device()
> ---> 23         outputs = comm.broadcast_coalesced(inputs, ctx.target_gpus)
>      24         non_differentiables = []
>      25         for idx, input_requires_grad in enumerate(ctx.needs_input_grad[1:]):
> 
> /opt/conda/lib/python3.6/site-packages/torch/nn/parallel/comm.py in broadcast_coalesced(tensors, devices, buffer_size)
>      56     devices = [_get_device_index(d) for d in devices]
>      57     tensors = [_handle_complex(t) for t in tensors]
> ---> 58     return torch._C._broadcast_coalesced(tensors, devices, buffer_size)
>      59 
>      60 
> 
> RuntimeError: NCCL Error 2: unhandled system error

Rerun your script with NCCL_DEBUG=INFO python script.py args and check what causes the system error. The first error messages complaining about the device were raised because you didn't push the parameters to the device via .cuda().
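
For reference, a minimal sketch of that suggestion (script.py is just a placeholder name): call .cuda() on the module so its parameters live on cuda:0, i.e. device_ids[0], before the DataParallel forward pass.

# script.py -- run as: NCCL_DEBUG=INFO python script.py
import torch

x = torch.rand(300, 400, 500).cuda()  # input batch on cuda:0

# .cuda() moves the parameters and buffers to cuda:0 (device_ids[0])
model = torch.nn.Sequential(torch.nn.Linear(500, 900), torch.nn.Linear(900, 1)).cuda()
model = torch.nn.DataParallel(model, device_ids=[0, 1])

y = model(x)    # scatters x along dim 0, replicates the model, gathers the outputs
print(y.shape)  # torch.Size([300, 400, 1])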

Thank you!

On torch 1.6, the following works:

#%env NCCL_DEBUG=INFO
import torch
a,b,c,d = 300,400,500,6000
x = torch.rand(a, b, c).cuda()

model = torch.nn.Sequential(torch.nn.Linear(c, d))
model = torch.nn.DataParallel(model.cuda(), device_ids=[0,1,2,3])

%timeit -n 5 y = model(x)

But on torch 1.8, SageMaker prints the following warning:

INFO profiler_config_parser.py:102] Unable to find config at /opt/ml/input/config/profilerconfig.json. Profiler is disabled.

and this error:

RuntimeError: NCCL Error 2: unhandled system error

Do you have any idea what causes the trouble?

I guess setting the env variable inside the code didn't work, as NCCL would otherwise have printed the entire setup of the node etc.
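
As a sketch (assuming nothing in the process has touched CUDA/NCCL yet), the variable can either be exported in the shell that launches the script or set via os.environ before the first DataParallel forward pass, since NCCL reads it when its communicators are created:

import os
os.environ["NCCL_DEBUG"] = "INFO"  # must be set before NCCL is initialized

import torch

x = torch.rand(32, 1024).cuda()
model = torch.nn.DataParallel(torch.nn.Linear(1024, 1024).cuda())
y = model(x)  # the NCCL setup info is printed during this first forward pass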

Hi ptrblck,
Here is the output with NCCL_DEBUG=INFO. Any ideas? This is on an AWS SageMaker system with 4 GPUs. The code is:

import torch
device = torch.device("cuda")

bs = 2**5
k = 2**10

data = torch.rand(bs, k).to(device)

model = torch.nn.Linear(k, k)
model = torch.nn.DataParallel(model)
model.to(device)

output = model(data)

[2022-02-18 23:12:08.559 pytorch-1-8-gpu:844 INFO utils.py:27] RULE_JOB_STOP_SIGNAL_FILENAME: None
[2022-02-18 23:12:08.588 pytorch-1-8-gpu:844 INFO profiler_config_parser.py:102] Unable to find config at /opt/ml/input/config/profilerconfig.json. Profiler is disabled.
pytorch-1-8-gpu:844:844 [0] NCCL INFO Bootstrap : Using [0]lo:127.0.0.1<0> [1]veth-app1-2:169.255.254.2<0>

pytorch-1-8-gpu:844:844 [0] ofi_init:1136 NCCL WARN NET/OFI Only EFA provider is supported
pytorch-1-8-gpu:844:844 [0] NCCL INFO NET/IB : No device found.
pytorch-1-8-gpu:844:844 [0] NCCL INFO NET/Socket : Using [0]lo:127.0.0.1<0> [1]veth-app1-2:169.255.254.2<0>
pytorch-1-8-gpu:844:844 [0] NCCL INFO Using network Socket
NCCL version 2.7.8+cuda11.1
pytorch-1-8-gpu:844:907 [0] NCCL INFO Channel 00/02 : 0 1 2 3
pytorch-1-8-gpu:844:909 [2] NCCL INFO threadThresholds 8/8/64 | 32/8/64 | 8/8/64
pytorch-1-8-gpu:844:908 [1] NCCL INFO threadThresholds 8/8/64 | 32/8/64 | 8/8/64
pytorch-1-8-gpu:844:910 [3] NCCL INFO threadThresholds 8/8/64 | 32/8/64 | 8/8/64
pytorch-1-8-gpu:844:907 [0] NCCL INFO Channel 01/02 : 0 1 2 3
pytorch-1-8-gpu:844:909 [2] NCCL INFO Trees [0] 3/-1/-1->2->1|1->2->3/-1/-1 [1] 3/-1/-1->2->1|1->2->3/-1/-1
pytorch-1-8-gpu:844:908 [1] NCCL INFO Trees [0] 2/-1/-1->1->0|0->1->2/-1/-1 [1] 2/-1/-1->1->0|0->1->2/-1/-1
pytorch-1-8-gpu:844:910 [3] NCCL INFO Trees [0] -1/-1/-1->3->2|2->3->-1/-1/-1 [1] -1/-1/-1->3->2|2->3->-1/-1/-1
pytorch-1-8-gpu:844:907 [0] NCCL INFO threadThresholds 8/8/64 | 32/8/64 | 8/8/64
pytorch-1-8-gpu:844:907 [0] NCCL INFO Trees [0] 1/-1/-1->0->-1|-1->0->1/-1/-1 [1] 1/-1/-1->0->-1|-1->0->1/-1/-1
pytorch-1-8-gpu:844:909 [2] NCCL INFO Could not enable P2P between dev 2(=1d0) and dev 1(=1c0)
pytorch-1-8-gpu:844:910 [3] NCCL INFO Could not enable P2P between dev 3(=1e0) and dev 2(=1d0)
pytorch-1-8-gpu:844:908 [1] NCCL INFO Could not enable P2P between dev 1(=1c0) and dev 0(=1b0)
pytorch-1-8-gpu:844:907 [0] NCCL INFO Could not enable P2P between dev 0(=1b0) and dev 3(=1e0)

pytorch-1-8-gpu:844:907 [0] include/shm.h:28 NCCL WARN Call to posix_fallocate failed : No space left on device

pytorch-1-8-gpu:844:908 [1] include/shm.h:28 NCCL WARN Call to posix_fallocate failed : No space left on device
pytorch-1-8-gpu:844:908 [1] NCCL INFO include/shm.h:41 -> 2
pytorch-1-8-gpu:844:907 [0] NCCL INFO include/shm.h:41 -> 2

pytorch-1-8-gpu:844:908 [1] include/shm.h:48 NCCL WARN Error while creating shared memory segment nccl-shm-recv-68dbfe5912d345c9-0-0-1 (size 9637888)

pytorch-1-8-gpu:844:907 [0] include/shm.h:48 NCCL WARN Error while creating shared memory segment nccl-shm-recv-68dbfe5912d345c9-0-3-0 (size 9637888)

pytorch-1-8-gpu:844:910 [3] include/shm.h:28 NCCL WARN Call to posix_fallocate failed : No space left on device
pytorch-1-8-gpu:844:908 [1] NCCL INFO transport/shm.cc:101 -> 2
pytorch-1-8-gpu:844:907 [0] NCCL INFO transport/shm.cc:101 -> 2

pytorch-1-8-gpu:844:909 [2] include/shm.h:28 NCCL WARN Call to posix_fallocate failed : No space left on device
pytorch-1-8-gpu:844:908 [1] NCCL INFO transport.cc:30 -> 2
pytorch-1-8-gpu:844:910 [3] NCCL INFO include/shm.h:41 -> 2
pytorch-1-8-gpu:844:907 [0] NCCL INFO transport.cc:30 -> 2
pytorch-1-8-gpu:844:909 [2] NCCL INFO include/shm.h:41 -> 2
pytorch-1-8-gpu:844:908 [1] NCCL INFO transport.cc:49 -> 2

pytorch-1-8-gpu:844:910 [3] include/shm.h:48 NCCL WARN Error while creating shared memory segment nccl-shm-recv-68dbfe5912d345c9-0-2-3 (size 9637888)

pytorch-1-8-gpu:844:907 [0] NCCL INFO transport.cc:49 -> 2
pytorch-1-8-gpu:844:910 [3] NCCL INFO transport/shm.cc:101 -> 2
pytorch-1-8-gpu:844:908 [1] NCCL INFO init.cc:766 -> 2

pytorch-1-8-gpu:844:909 [2] include/shm.h:48 NCCL WARN Error while creating shared memory segment nccl-shm-recv-68dbfe5912d345c9-0-1-2 (size 9637888)

pytorch-1-8-gpu:844:908 [1] NCCL INFO init.cc:840 -> 2
pytorch-1-8-gpu:844:907 [0] NCCL INFO init.cc:766 -> 2
pytorch-1-8-gpu:844:909 [2] NCCL INFO transport/shm.cc:101 -> 2
pytorch-1-8-gpu:844:907 [0] NCCL INFO init.cc:840 -> 2
pytorch-1-8-gpu:844:910 [3] NCCL INFO transport.cc:30 -> 2
pytorch-1-8-gpu:844:908 [1] NCCL INFO group.cc:73 -> 2 [Async thread]
pytorch-1-8-gpu:844:910 [3] NCCL INFO transport.cc:49 -> 2
pytorch-1-8-gpu:844:909 [2] NCCL INFO transport.cc:30 -> 2
pytorch-1-8-gpu:844:910 [3] NCCL INFO init.cc:766 -> 2
pytorch-1-8-gpu:844:909 [2] NCCL INFO transport.cc:49 -> 2
pytorch-1-8-gpu:844:907 [0] NCCL INFO group.cc:73 -> 2 [Async thread]
pytorch-1-8-gpu:844:910 [3] NCCL INFO init.cc:840 -> 2
pytorch-1-8-gpu:844:909 [2] NCCL INFO init.cc:766 -> 2
pytorch-1-8-gpu:844:909 [2] NCCL INFO init.cc:840 -> 2
pytorch-1-8-gpu:844:910 [3] NCCL INFO group.cc:73 -> 2 [Async thread]
pytorch-1-8-gpu:844:909 [2] NCCL INFO group.cc:73 -> 2 [Async thread]
pytorch-1-8-gpu:844:844 [0] NCCL INFO init.cc:906 -> 2
Traceback (most recent call last):
  File "test.py", line 13, in <module>
    output = model(data)
  File "/opt/conda/lib/python3.6/site-packages/torch/nn/modules/module.py", line 918, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/opt/conda/lib/python3.6/site-packages/torch/nn/parallel/data_parallel.py", line 166, in forward
    replicas = self.replicate(self.module, self.device_ids[:len(inputs)])
  File "/opt/conda/lib/python3.6/site-packages/torch/nn/parallel/data_parallel.py", line 171, in replicate
    return replicate(module, device_ids, not torch.is_grad_enabled())
  File "/opt/conda/lib/python3.6/site-packages/torch/nn/parallel/replicate.py", line 91, in replicate
    param_copies = _broadcast_coalesced_reshape(params, devices, detach)
  File "/opt/conda/lib/python3.6/site-packages/torch/nn/parallel/replicate.py", line 71, in _broadcast_coalesced_reshape
    tensor_copies = Broadcast.apply(devices, *tensors)
  File "/opt/conda/lib/python3.6/site-packages/torch/nn/parallel/_functions.py", line 23, in forward
    outputs = comm.broadcast_coalesced(inputs, ctx.target_gpus)
  File "/opt/conda/lib/python3.6/site-packages/torch/nn/parallel/comm.py", line 58, in broadcast_coalesced
    return torch._C._broadcast_coalesced(tensors, devices, buffer_size)
RuntimeError: NCCL Error 2: unhandled system error

Thanks for the update.
The issue is caused by:

NCCL WARN Call to posix_fallocate failed : No space left on device

which is raised e.g. if NCCL tries to allocate shared memory in /dev/shm and isn’t able to do so.
I guess you might be using a container and are not properly configuring the shared memory usage (e.g. via --ipc=host).
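
If it helps, one way to check this from inside the running container is to look at the size of /dev/shm (a sketch, assuming the default mount point) and compare it with the ~9.6 MB segments NCCL failed to allocate above; containers often default to a 64 MB /dev/shm, which can be raised via --ipc=host or --shm-size when the container is started.

import shutil

# Report the size of the shared-memory mount NCCL uses for its intra-node
# transport; "No space left on device" from posix_fallocate points to this
# mount being too small for the requested segments.
total, used, free = shutil.disk_usage("/dev/shm")
print(f"/dev/shm: total={total / 2**20:.0f} MiB, free={free / 2**20:.0f} MiB")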