import os
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"  # synchronous kernel launches, so CUDA errors point at the real line

import torch
from torch import nn
from torch.nn import DataParallel


class Test(nn.Module):
    def __init__(self):
        super(Test, self).__init__()

    def forward(self, feed, index, out, mask=None):
        # Gather rows of `feed` for every entry of `index` -> (*index.size(), feed_dim)
        feed = torch.index_select(feed, 0, index.flatten()).view(*index.size(), -1)
        # Repeat each row of `out` once per column of `index` -> (*index.size(), out_dim)
        o = out.repeat(1, index.size(1)).view(*index.size(), -1)
        # Concatenate along the feature dimension
        feed = torch.cat([feed, o], dim=-1)
        return feed


feed = torch.tensor([[0., 1., 2.], [2., 3., 4.], [4., 5., 6.]])  # (3, 3)
index = torch.tensor([[0, 1], [1, 0], [1, 2], [2, 1]])           # (4, 2), values index rows of `feed`
out = torch.tensor([[0., 1.], [1., 2.], [3., 4.], [5., 6.]])     # (4, 2)

feed = feed.cuda()
index = index.cuda()
out = out.cuda()

test = DataParallel(Test().cuda())
test(feed, index, out)
RuntimeError Traceback (most recent call last)
/tmp/ipykernel_12948/3522224863.py in <module>
----> 1 test(feed, index, out)
~/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
1049 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1050 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1051 return forward_call(*input, **kwargs)
1052 # Do not call functions when jit is used
1053 full_backward_hooks, non_full_backward_hooks = [], []
~/lib/python3.7/site-packages/torch/nn/parallel/data_parallel.py in forward(self, *inputs, **kwargs)
166 return self.module(*inputs[0], **kwargs[0])
167 replicas = self.replicate(self.module, self.device_ids[:len(inputs)])
--> 168 outputs = self.parallel_apply(replicas, inputs, kwargs)
169 return self.gather(outputs, self.output_device)
170
~/lib/python3.7/site-packages/torch/nn/parallel/data_parallel.py in parallel_apply(self, replicas, inputs, kwargs)
176
177 def parallel_apply(self, replicas, inputs, kwargs):
--> 178 return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
179
180 def gather(self, outputs, output_device):
~/lib/python3.7/site-packages/torch/nn/parallel/parallel_apply.py in parallel_apply(modules, inputs, kwargs_tup, devices)
84 output = results[i]
85 if isinstance(output, ExceptionWrapper):
---> 86 output.reraise()
87 outputs.append(output)
88 return outputs
~/lib/python3.7/site-packages/torch/_utils.py in reraise(self)
423 # have message field
424 raise self.exc_type(message=msg)
--> 425 raise self.exc_type(msg)
426
427
RuntimeError: Caught RuntimeError in replica 1 on device 1.
Original Traceback (most recent call last):
File "~/lib/python3.7/site-packages/torch/nn/parallel/parallel_apply.py", line 61, in _worker
output = module(*input, **kwargs)
File "~/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
return forward_call(*input, **kwargs)
File "/tmp/ipykernel_12948/397267949.py", line 11, in forward
feed = torch.index_select(feed, 0, index.flatten()).view(*index.size(), -1)
RuntimeError: CUDA error: device-side assert triggered
The code above runs fine on the CPU and on a single GPU, but raises this error as soon as it runs on multiple GPUs through DataParallel.
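
Presumably the assert comes from how DataParallel scatters its inputs: every positional tensor is chunked along dim 0, one chunk per GPU, so each replica sees only a slice of feed while its slice of index still holds global row numbers. A minimal CPU sketch of that split (assuming two devices; torch.chunk splits along dim 0 the same way the scatter does):

    import torch

    feed = torch.tensor([[0., 1., 2.], [2., 3., 4.], [4., 5., 6.]])  # (3, 3)
    index = torch.tensor([[0, 1], [1, 0], [1, 2], [2, 1]])           # (4, 2)

    # Simulate a 2-GPU scatter: chunk every input along dim 0.
    feed_chunks = torch.chunk(feed, 2, dim=0)    # shapes (2, 3) and (1, 3)
    index_chunks = torch.chunk(index, 2, dim=0)  # shapes (2, 2) and (2, 2)

    for i, (f, idx) in enumerate(zip(feed_chunks, index_chunks)):
        # Replica 1 gets only feed row [2:3] but index values up to 2,
        # so index_select would read out of range there.
        valid = bool(idx.max() < f.size(0))
        print(f"replica {i}: feed rows={f.size(0)}, "
              f"max index={idx.max().item()}, valid={valid}")

This prints valid=False for replica 1, which matches the device-side assert raised by index_select in replica 1 of the traceback. If that is indeed the cause, one way around it would be to keep feed whole on every replica (for example by registering it as a buffer on the module instead of passing it as a forward argument, since DataParallel replicates buffers rather than scattering them).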